#include "constants.h"
#include "factory_resolver.h"
#include "future_compat.h"
#include "multistep.h"
#include "person.h"
#include "rand48.h"
#include "rl_sim_cpp.h"
#include "simulation_stats.h"
#include "trace_logger.h"
 
#include <boost/uuid/random_generator.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <chrono>
#include <cmath>
#include <functional>
#include <thread>
 
using namespace std;
 
std::string get_dist_str(const reinforcement_learning::slot_response& response);
std::string get_dist_str(const reinforcement_learning::slot_entry& response);
 
int rl_sim::loop()
{
  // Set up the inference API and the simulated world before running anything.
  if (!init()) { return -1; }

  // Dispatch to the loop implementation matching the configured problem type.
  switch (_loop_kind)
  {
    case CB: return cb_loop();
    case CA: return ca_loop();
    case CCB: return ccb_loop();
    case Slates: return slates_loop();
    case Multistep: return multistep_loop();
    default:
      std::cout << "Invalid loop kind:" << _loop_kind << std::endl;
      return -1;
  }
}
 
// Contextual-bandit simulation loop: choose an action for a random user,
// simulate the user's reaction, and report significant outcomes back.
int rl_sim::cb_loop()
{
  r::ranking_response response;
  simulation_stats<size_t> stats;

  while (_run_loop)
  {
    // Build the context: a randomly chosen user plus the fixed action set.
    auto& p = pick_a_random_person();
    const auto context_features = p.get_features();
    const auto action_features = get_action_features();
    const auto context_json = create_context_json(context_features, action_features);
    const auto req_id = create_event_id();
    r::api_status status;

    // Ask the inference API to rank the actions for this context.
    if (_rl->choose_rank(req_id.c_str(), context_json.c_str(), response, &status) != err::success)
    {
      std::cout << status.get_error_msg() << std::endl;
      continue;
    }

    // Pull the chosen action out of the ranking response.
    size_t chosen_action = 0;
    if (response.get_chosen_action_id(chosen_action) != err::success)
    {
      std::cout << status.get_error_msg() << std::endl;
      continue;
    }

    // Simulate the user's reaction (reward) to the chosen topic.
    const auto outcome = p.get_outcome(_topics[chosen_action], _random_seed);

    // Only report outcomes that are significant. The threshold is tested
    // BEFORE the call so that negligible outcomes are never sent; the
    // previous operand order called report_outcome unconditionally.
    if (outcome > 0.00001f && _rl->report_outcome(req_id.c_str(), outcome, &status) != err::success)
    {
      std::cout << status.get_error_msg() << std::endl;
      continue;
    }

    stats.record(p.id(), chosen_action, outcome);

    if (!_quiet)
    {
      std::cout << " " << stats.count() << ", ctxt, " << p.id() << ", action, " << chosen_action << ", outcome, "
                << outcome << ", dist, " << get_dist_str(response) << ", " << stats.get_stats(p.id(), chosen_action)
                << std::endl;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
  }

  return 0;
}
 
// Build a deterministic episode identifier of the form "episode-<n>".
std::string create_episode_id(size_t episode_id)
{
  std::ostringstream id_stream;
  id_stream << "episode-" << episode_id;
  return id_stream.str();
}
 
// Multistep simulation loop: each episode chains episode_length events and
// reports a single aggregate outcome for the whole episode.
int rl_sim::multistep_loop()
{
  r::ranking_response response;
  simulation_stats<size_t> stats;

  const size_t episode_length = 2;
  size_t episode_indx = 0;

  // _num_events budgets individual events; scale it so the configured number
  // of episodes (episode_length events each) runs to completion.
  _num_events *= episode_length;

  while (_run_loop)
  {
    const std::string episode_id = create_episode_id(episode_indx++);
    r::api_status status;
    r::episode_state episode(episode_id.c_str());

    std::string previous_id;
    float episodic_outcome = 0;

    for (size_t i = 0; i < episode_length; ++i)
    {
      // Build the context for this step of the episode.
      auto& p = pick_a_random_person();
      const auto context_features = p.get_features();
      const auto action_features = get_action_features();
      const auto context_json = create_context_json(context_features, action_features);
      const auto req_id = create_event_id();

      // The first step of an episode has no previous event to chain from.
      // (The unused local ranking_response that shadowed `response` here
      // has been removed.)
      if (_rl->request_episodic_decision(req_id.c_str(), i == 0 ? nullptr : previous_id.c_str(), context_json.c_str(),
              response, episode, &status) != err::success)
      {
        std::cout << status.get_error_msg() << std::endl;
        return -1;
      }

      size_t chosen_action = 0;
      if (response.get_chosen_action_id(chosen_action) != err::success)
      {
        std::cout << status.get_error_msg() << std::endl;
        continue;
      }

      // Simulate this step's reward and accumulate it into the episode total.
      const auto outcome_per_step = p.get_outcome(_topics[chosen_action], _random_seed);
      stats.record(p.id(), chosen_action, outcome_per_step);

      if (!_quiet)
      {
        std::cout << " " << stats.count() << ", ctxt, " << p.id() << ", action, " << chosen_action << ", outcome, "
                  << outcome_per_step << ", dist, " << get_dist_str(response) << ", "
                  << stats.get_stats(p.id(), chosen_action) << std::endl;
      }

      episodic_outcome += outcome_per_step;
      previous_id = req_id;
    }

    // Report one aggregate outcome keyed by the episode id.
    if (_rl->report_outcome(episode.get_episode_id(), episodic_outcome, &status) != err::success)
    {
      std::cout << status.get_error_msg() << std::endl;
      continue;
    }

    std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
  }
  return 0;
}
 
// Continuous-action simulation loop: request a continuous action for a random
// robot joint and report significant outcomes back.
int rl_sim::ca_loop()
{
  r::continuous_action_response response;
  simulation_stats<float> stats;
  while (_run_loop)
  {
    // Build the context from a randomly chosen joint's features.
    auto& joint = pick_a_random_joint();
    const auto context_features = joint.get_features();
    const auto context_json = create_context_json(context_features);
    const auto req_id = create_event_id();
    r::api_status status;

    RL_IGNORE_DEPRECATED_USAGE_START
    if (_rl->request_continuous_action(req_id.c_str(), context_json, response, &status) != err::success)
    {
      std::cout << status.get_error_msg() << std::endl;
      continue;
    }
    RL_IGNORE_DEPRECATED_USAGE_END
    const auto chosen_action = response.get_chosen_action();
    const auto outcome = joint.get_outcome(chosen_action, _random_seed);
    // Only report outcomes that are significant. The threshold is tested
    // BEFORE the call so that negligible outcomes are never sent; the
    // previous operand order called report_outcome unconditionally.
    if (outcome > 0.00001f && _rl->report_outcome(req_id.c_str(), outcome, &status) != err::success)
    {
      std::cout << status.get_error_msg() << std::endl;
    }

    stats.record(joint.id(), chosen_action, outcome);

    if (!_quiet)
    {
      std::cout << " " << stats.count() << " - ctxt: " << joint.id() << ", action: " << chosen_action
                << ", outcome: " << outcome << ", dist: " << response.get_chosen_action_pdf_value() << ", "
                << stats.get_stats(joint.id(), chosen_action) << std::endl;
    }

    std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
  }
  return 0;
}
 
// Conditional contextual-bandit loop: request one decision covering all
// slots, then simulate and report a per-slot outcome.
int rl_sim::ccb_loop()
{
  r::multi_slot_response decision;
  simulation_stats<size_t> stats;

  while (_run_loop)
  {
    // Build the CCB context: user features, the action set and the slot set.
    auto& p = pick_a_random_person();
    const auto context_features = p.get_features();
    const auto action_features = get_action_features();
    const auto event_id = create_event_id();
    const auto slot_json = get_slot_features();
    const auto context_json = create_context_json(context_features, action_features, slot_json);
    std::cout << context_json << std::endl;
    r::api_status status;

    // Ask the inference API for a decision across all slots.
    RL_IGNORE_DEPRECATED_USAGE_START
    if (_rl->request_multi_slot_decision(event_id.c_str(), context_json, decision, &status) != err::success)
    {
      std::cout << status.get_error_msg() << std::endl;
      continue;
    }
    RL_IGNORE_DEPRECATED_USAGE_END

    int index = 0;
    for (auto& response : decision)
    {
      const auto chosen_action = response.get_action_id();
      const auto outcome = p.get_outcome(_topics[chosen_action], _random_seed);

      // Only report outcomes that are significant. The threshold is tested
      // BEFORE the call so that negligible outcomes are never sent; the
      // previous operand order called report_outcome unconditionally.
      if (outcome > 0.00001f && _rl->report_outcome(event_id.c_str(), index, outcome, &status) != err::success)
      {
        std::cout << status.get_error_msg() << std::endl;
        continue;
      }

      stats.record(p.id(), chosen_action, outcome);

      if (!_quiet)
      {
        std::cout << " " << stats.count() << ", ctxt, " << p.id() << ", action, " << chosen_action << ", slot, "
                  << index << ", outcome, " << outcome << ", dist, " << get_dist_str(response) << ", "
                  << stats.get_stats(p.id(), chosen_action) << std::endl;
      }
      index++;
    }

    std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
  }

  return 0;
}
 
// Build the slates slot json fragment, e.g. for slot_count == 2:
//   "_slots": [ { "slot_id":"__0"}, { "slot_id":"__1"}]
// Fixes: the last slot previously printed id `slot_count` instead of
// `slot_count - 1`, breaking the 0-based numbering used by the loop; the
// loop index is now size_t (no signed/unsigned comparison) and an empty
// slot set no longer underflows.
std::string get_slates_slot_features(size_t slot_count)
{
  if (slot_count == 0) { return R"("_slots": [ ] )"; }

  std::ostringstream oss;
  oss << R"("_slots": [ )";
  for (size_t idx = 0; idx + 1 < slot_count; ++idx) { oss << R"({ "slot_id":"__)" << idx << R"("}, )"; }
  oss << R"({ "slot_id":"__)" << (slot_count - 1) << R"("}] )";
  return oss.str();
}
 
// Slates simulation loop: one decision covers all slots; per-slot outcomes
// are summed and reported as a single aggregate outcome.
int rl_sim::slates_loop()
{
  r::multi_slot_response decision;
  simulation_stats<size_t> stats;

  while (_run_loop)
  {
    // Build the slates context: user features, slot-tagged actions and slots.
    auto& p = pick_a_random_person();
    const auto context_features = p.get_features();
    const auto action_features = get_slates_action_features();
    const auto event_id = create_event_id();

    const auto slot_json = get_slates_slot_features(NUM_SLATES_SLOTS);
    const auto context_json = create_context_json(context_features, action_features, slot_json);
    std::cout << context_json << std::endl;
    r::api_status status;

    // Ask the inference API for a slate decision.
    RL_IGNORE_DEPRECATED_USAGE_START
    if (_rl->request_multi_slot_decision(event_id.c_str(), context_json.c_str(), decision, &status) != err::success)
    {
      std::cout << status.get_error_msg() << std::endl;
      continue;
    }
    RL_IGNORE_DEPRECATED_USAGE_END

    float outcome = 0;
    int index = 0;
    // Actions are partitioned evenly across slots; a slot's action id is an
    // offset within its partition, so convert it to a global topic index.
    auto actions_per_slot = _topics.size() / NUM_SLATES_SLOTS;

    for (auto& response : decision)
    {
      const auto chosen_action = response.get_action_id() + index * actions_per_slot;
      const auto slot_outcome = p.get_outcome(_topics[chosen_action], _random_seed);
      // Key the stats by person id so they line up with the get_stats lookup
      // below (they were previously recorded under event_id, which never
      // matched the person-keyed read).
      stats.record(p.id(), chosen_action, slot_outcome);
      outcome += slot_outcome;

      if (!_quiet)
      {
        std::cout << " " << stats.count() << ", ctxt, " << p.id() << ", action, " << chosen_action << ", slot, "
                  << index << ", outcome, " << outcome << ", dist, " << get_dist_str(response) << ", "
                  << stats.get_stats(p.id(), chosen_action) << std::endl;
      }
      index++;
    }

    // Only report the aggregate outcome when it is significant. The threshold
    // is tested BEFORE the call so negligible outcomes are never sent.
    if (outcome > 0.00001f && _rl->report_outcome(event_id.c_str(), outcome, &status) != err::success)
    {
      std::cout << status.get_error_msg() << std::endl;
      continue;
    }

    std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
  }

  return 0;
}
 
// Pick a user uniformly at random. Scaling by size() (not size() + 1) keeps
// the draw uniform — the previous +1 combined with the clamp gave the last
// person double weight. The clamp remains as a guard in case rand48 ever
// yields a value that scales up to size().
person& rl_sim::pick_a_random_person()
{
  size_t idx = static_cast<size_t>(rand48(_random_seed) * _people.size());
  return _people[std::min(idx, _people.size() - 1)];
}
 
// Pick a robot joint uniformly at random. Scaling by size() (not size() + 1)
// keeps the draw uniform — the previous +1 combined with the clamp gave the
// last joint double weight. The clamp remains as a defensive guard.
joint& rl_sim::pick_a_random_joint()
{
  size_t idx = static_cast<size_t>(rand48(_random_seed) * _robot_joints.size());
  return _robot_joints[std::min(idx, _robot_joints.size() - 1)];
}
 
int rl_sim::load_config_from_json(const std::string& file_name, u::configuration& config, r::api_status* status)
{
  std::string config_str;
 
  
 
  
  return cfg::create_from_json(config_str, config, nullptr, status);
}
 
// Read the entire contents of file_name into config_str.
// Returns err::success, or an invalid_argument error if the file cannot be
// opened.
int rl_sim::load_file(const std::string& file_name, std::string& config_str, r::api_status* status)
{
  std::ifstream input_stream(file_name);
  if (!input_stream.good()) { RETURN_ERROR_LS(nullptr, status, invalid_argument) << "Cannot open file: " << file_name; }

  std::stringstream contents;
  contents << input_stream.rdbuf();
  config_str = contents.str();
  return err::success;
}
 
 
struct throughput_tracking_sender : public reinforcement_learning::i_sender
{
  throughput_tracking_sender(std::unique_ptr<reinforcement_learning::i_sender> inner_sender, std::string name,
      reinforcement_learning::i_trace* trace, int print_interval)
      : _inner_sender(std::move(inner_sender)), _name(std::move(name)), _trace(trace), _print_interval(print_interval)
  {
  }
 
  int init(
  {
    _begin_time = std::chrono::steady_clock::now();
    return _inner_sender->init(config, status);
  };
 
protected:
  {
    _bytes_sent += data->buffer_filled_size();
    _messages_sent++;
    _print_counter++;
    if (_print_counter >= _print_interval)
    {
      auto now = std::chrono::steady_clock::now();
      auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(now - _begin_time).count();
      std::ostringstream oss;
      oss << _name << " throughput: " << _bytes_sent / elapsed << " bytes/s, "
          << static_cast<float>(_messages_sent) / elapsed << " messages/s, " << _bytes_sent << " total bytes, "
          << _messages_sent << " total messages." << std::endl;
      TRACE_INFO(_trace, oss.str());
      _print_counter = 0;
    }
 
    return _inner_sender->send(data, status);
  };
 
private:
  std::unique_ptr<reinforcement_learning::i_sender> _inner_sender;
  std::string _name;
  size_t _bytes_sent = 0;
  size_t _messages_sent = 0;
  std::chrono::steady_clock::time_point _begin_time;
  reinforcement_learning::i_trace* _trace;
  int _print_interval;
  int _print_counter = 0;
};
 
// Returns a factory function that creates the named sender from the stock
// sender factory and wraps it in a throughput_tracking_sender.
// Fixes: the lambda was missing its api_status* parameter, and the result of
// create() was ignored — a failed create would have wrapped a null sender.
reinforcement_learning::sender_factory_t::create_fn wrap_sender_generate_for_throughput_sender(const std::string& name)
{
  return [=](std::unique_ptr<reinforcement_learning::i_sender>& retval, const u::configuration& cfg,
             reinforcement_learning::error_callback_fn* error_cb, reinforcement_learning::i_trace* trace_logger,
             r::api_status* status)
  {
    std::unique_ptr<reinforcement_learning::i_sender> sender;
    const auto res =
        reinforcement_learning::sender_factory.create(sender, name, cfg, error_cb, trace_logger, status);
    // Propagate creation failures instead of wrapping a null sender.
    if (res != err::success) { return res; }
    auto print_interval = cfg.get_int("thoughputsender.printinterval", 1);
    retval.reset(new throughput_tracking_sender(std::move(sender), name, trace_logger, print_interval));
    return 0;
  };
}
 
// Build the library configuration from command-line options, construct the
// live_model with the chosen factories, and initialize it.
// Returns err::success, or -1 on configuration/initialization failure.
int rl_sim::init_rl()
{
  r::api_status status;
  u::configuration config;
  // Load the API configuration from the json file named on the command line.
  const auto config_file = _options["json_config"].as<std::string>();
  if (load_config_from_json(config_file, config, &status) != err::success)
  {
    std::cout << status.get_error_msg() << std::endl;
    return -1;
  }

  if (_options["log_to_file"].as<bool>())
  {
    // Route interaction and observation events to file senders instead of
    // the configured network senders.
    config.
set(r::name::INTERACTION_SENDER_IMPLEMENTATION, r::value::INTERACTION_FILE_SENDER);

    config.set(r::name::OBSERVATION_SENDER_IMPLEMENTATION, r::value::OBSERVATION_FILE_SENDER);
  }

  if (!_options["get_model"].as<bool>())
  {
    // Run without fetching model data (NO_MODEL_DATA model source).
    config.set(r::name::MODEL_SRC, r::value::NO_MODEL_DATA);
  }

  if (_options["log_timestamp"].as<bool>())
  {
    // Use the clock-based time provider for event timestamps.
    config.set(r::name::TIME_PROVIDER_IMPLEMENTATION, r::value::CLOCK_TIME_PROVIDER);
  }

  // Trace to the console unless quiet mode was requested.
  if (!_quiet) { config.set(r::name::TRACE_LOG_IMPLEMENTATION, r::value::CONSOLE_TRACE_LOGGER); }

  // Default to the library's stock sender factory; when --throughput is set,
  // use a local factory whose senders are wrapped with throughput tracking.
  reinforcement_learning::sender_factory_t* sender_factory = &reinforcement_learning::sender_factory;
  reinforcement_learning::sender_factory_t factory;
  if (_options.count("throughput") != 0u)
  {
    factory.register_type(reinforcement_learning::value::OBSERVATION_EH_SENDER,
        wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::OBSERVATION_EH_SENDER));
    factory.register_type(reinforcement_learning::value::INTERACTION_EH_SENDER,
        wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::INTERACTION_EH_SENDER));
    factory.register_type(reinforcement_learning::value::EPISODE_EH_SENDER,
        wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::EPISODE_EH_SENDER));
    factory.register_type(reinforcement_learning::value::OBSERVATION_HTTP_API_SENDER,
        wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::OBSERVATION_HTTP_API_SENDER));
    factory.register_type(reinforcement_learning::value::INTERACTION_HTTP_API_SENDER,
        wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::INTERACTION_HTTP_API_SENDER));
    factory.register_type(reinforcement_learning::value::EPISODE_HTTP_API_SENDER,
        wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::EPISODE_HTTP_API_SENDER));
    sender_factory = &factory;
  }

#ifdef LINK_AZURE_LIBS
  // Register a token-fetching callback with the default factories, bound to
  // this simulator's credentials provider.
  using namespace std::placeholders;
  reinforcement_learning::oauth_callback_t callback =
      std::bind(&azure_credentials_provider_t::get_token, &_creds, _1, _2, _3, _4);
  reinforcement_learning::register_default_factories_callback(callback);
#endif

  // Construct the API instance with the (possibly wrapped) sender factory;
  // _on_error receives background errors with `this` as the context.
  _rl = std::unique_ptr<r::live_model>(new r::live_model(config, _on_error, this,
      &reinforcement_learning::trace_logger_factory, &reinforcement_learning::data_transport_factory,
      &reinforcement_learning::model_factory, sender_factory, &reinforcement_learning::time_provider_factory));

  if (_rl->init(&status) != err::success)
  {
    std::cout << status.get_error_msg() << std::endl;
    return -1;
  }

  if (!_quiet) { std::cout << " API Config " << config; }

  return err::success;
}
 
bool rl_sim::init_sim_world()
{
  
  _topics = {"SkiConditions-VT", "HerbGarden", "BeyBlades", "NYCLiving", "MachineLearning"};
 
  _slot_sizes = {"large", "medium", "small"};
 
  
  person::topic_prob tp = {
      {_topics[0], 0.08f}, {_topics[1], 0.03f}, {_topics[2], 0.05f}, {_topics[3], 0.03f}, {_topics[4], 0.25f}};
  _people.emplace_back("rnc", "engineering", "hiking", "spock", tp);
 
  
  tp = {{_topics[0], 0.08f}, {_topics[1], 0.30f}, {_topics[2], 0.02f}, {_topics[3], 0.02f}, {_topics[4], 0.10f}};
  _people.emplace_back("mk", "psychology", "kids", "7of9", tp);
 
  return true;
}
 
bool rl_sim::init_continuous_sim_world()
{
  
 
  _friction = {25.4f, 41.2f, 66.5f, 81.9f, 104.4f};
 
  
  
  
 
  
  joint::friction_prob fb = {{_friction[0], 0.08f}, {_friction[1], 0.03f}, {_friction[2], 0.05f}, {_friction[3], 0.03f},
      {_friction[4], 0.25f}};
 
  _robot_joints.emplace_back("j1", 20.3, 102.4, -10.2, fb);
 
  
  fb = {{_friction[0], 0.08f}, {_friction[1], 0.30f}, {_friction[2], 0.02f}, {_friction[3], 0.02f},
      {_friction[4], 0.10f}};
 
  _robot_joints.emplace_back("j2", 40.6, 30.8, 98.5, fb);
 
  return true;
}
 
// Initialize everything the simulator needs; each stage must succeed in
// order (short-circuiting exactly as the original sequential checks did).
bool rl_sim::init()
{
  return init_rl() == err::success && init_sim_world() && init_continuous_sim_world();
}
 
// Build the actions json fragment:
//   "_multi": [ { "TAction":{"topic":"..."} }, ... ]
// All but the last action are comma-terminated; the last closes the array.
// Fixes: the loop index is now size_t (the old `auto idx = 0` was an int
// compared against _topics.size(), a signed/unsigned mismatch) and the
// `idx + 1 < size` bound cannot underflow.
std::string rl_sim::get_action_features()
{
  std::ostringstream oss;
  oss << R"("_multi": [ )";
  for (size_t idx = 0; idx + 1 < _topics.size(); ++idx)
  {
    oss << R"({ "TAction":{"topic":")" << _topics[idx] << R"("} }, )";
  }
  oss << R"({ "TAction":{"topic":")" << _topics.back() << R"("} } ])";
  return oss.str();
}
 
// Build the slates actions json fragment, tagging each action with the slot
// it belongs to:
//   "_multi": [ { "_slot_id":0, "TAction":{"topic":"..."} }, ... ]
std::string rl_sim::get_slates_action_features()
{
  std::ostringstream oss;
  // Slot ids for all but the last action are derived by integer division of
  // the action index by NUM_SLATES_SLOTS.
  oss << R"("_multi": [ )";
  for (auto idx = 0; idx < _topics.size() - 1; ++idx)
  {
    oss << R"({ "_slot_id":)" << (idx / NUM_SLATES_SLOTS);
    oss << R"(, "TAction":{"topic":")" << _topics[idx] << R"("} }, )";
  }
  // NOTE(review): the final action's slot id is hard-wired to the last slot
  // (NUM_SLATES_SLOTS - 1) rather than computed from its index — presumably
  // this assumes the topics partition evenly across the slots; confirm
  // against the slates_loop action/slot mapping.
  oss << R"({ "_slot_id":)" << (NUM_SLATES_SLOTS - 1);
  oss << R"(, "TAction":{"topic":")" << _topics.back() << R"("} } ])";
  return oss.str();
}
 
// Build the CCB slot json fragment:
//   "_slots": [ { "_size":"large"}, ... ]
// All but the last slot are comma-terminated; the last closes the array.
std::string rl_sim::get_slot_features()
{
  std::ostringstream json;
  json << R"("_slots": [ )";
  for (auto idx = 0; idx < NUM_SLOTS - 1; ++idx)
  {
    json << R"({ "_size":")" << _slot_sizes[idx] << R"("}, )";
  }
  json << R"({ "_size":")" << _slot_sizes.back() << R"("}] )";
  return json.str();
}
 
{
  std::cout << 
"Background error in Inference API: " << status.
get_error_msg() << std::endl;
 
  std::cout << "Exiting simulation loop." << std::endl;
  _run_loop = false;
}
 
// Wrap the shared-context fragment in a top-level json object.
std::string rl_sim::create_context_json(const std::string& cntxt)
{
  return "{ " + cntxt + " }";
}
 
// Combine the shared-context and action fragments into one json object.
std::string rl_sim::create_context_json(const std::string& cntxt, const std::string& action)
{
  return "{ " + cntxt + ", " + action + " }";
}
 
// Combine the shared-context, action and slot fragments into one json object.
std::string rl_sim::create_context_json(const std::string& cntxt, const std::string& action, const std::string& slots)
{
  return "{ " + cntxt + ", " + action + ", " + slots + " }";
}
 
// Produce the id for the next event. Counts the event against the configured
// budget (stopping the loop once _num_events is reached), and returns either
// a random uuid or a sequential "event_<n>" id depending on _random_ids.
std::string rl_sim::create_event_id()
{
  if (_num_events > 0 && ++_current_events >= _num_events) { _run_loop = false; }

  if (_random_ids) { return boost::uuids::to_string(boost::uuids::random_generator()()); }

  return "event_" + std::to_string(_current_events);
}
 
// Construct the simulator from parsed command-line options.
// The loop kind defaults to CB; the first matching flag wins, with ccb
// taking precedence over slates, ca and multistep.
rl_sim::rl_sim(boost::program_options::variables_map vm) : _options(std::move(vm)), _loop_kind(CB)
{
  if (_options["ccb"].as<bool>()) { _loop_kind = CCB; }
  else if (_options["slates"].as<bool>()) { _loop_kind = Slates; }
  else if (_options["ca"].as<bool>()) { _loop_kind = CA; }
  else if (_options["multistep"].as<bool>()) { _loop_kind = Multistep; }

  // Cache frequently used option values in members.
  _num_events = _options["num_events"].as<int>();
  _random_seed = _options["random_seed"].as<uint64_t>();
  _delay = _options["delay"].as<int64_t>();
  _quiet = _options["quiet"].as<bool>();
  _random_ids = _options["random_ids"].as<bool>();
}
 
{
  std::string ret;
  ret += "(";
  for (const auto& ap_pair : response)
  {
    ret += "[" + to_string(ap_pair.action_id) + ",";
    ret += to_string(ap_pair.probability) + "]";
    ret += " ,";
  }
  ret += ")";
  return ret;
}
 
// Format a slot_response as "([slot_id,action,prob] ,)".
std::string get_dist_str(const reinforcement_learning::slot_response& response)
{
  std::string out("(");
  out.append("[").append(response.get_slot_id()).append(",");
  out.append(to_string(response.get_action_id())).append(",");
  out.append(to_string(response.get_probability())).append("]");
  out.append(" ,");
  out.append(")");
  return out;
}
 
// Format a slot_entry as "([id,action,prob] ,)".
std::string get_dist_str(const reinforcement_learning::slot_entry& response)
{
  std::string out("(");
  out.append("[").append(response.get_id()).append(",");
  out.append(to_string(response.get_action_id())).append(",");
  out.append(to_string(response.get_probability())).append("]");
  out.append(" ,");
  out.append(")");
  return out;
}
// Format a decision_response by concatenating each slot's distribution
// string, wrapped in parentheses.
std::string get_dist_str(const reinforcement_learning::decision_response& response)
{
  std::string out("(");
  for (const auto& slot : response)
  {
    out += get_dist_str(slot);
    out += " ,";
  }
  out += ")";
  return out;
}
/* Reference notes (extracted API documentation fragments, preserved as a
 * comment so the file remains valid C++):
 *
 * api_status — used to return error information to the caller; reports the
 *   status of all API calls (api_status.h:22).
 * RETURN_IF_FAIL(x) — error-reporting macro to test and return on error
 *   (api_status.h:244).
 * RETURN_ERROR_LS(trace, status, code) — error-reporting macro used with the
 *   left-shift operator (api_status.h:237).
 * const char* get_error_msg() const — gets the error message string; all API
 *   calls return a status code.
 * choose_rank() — returns the action choice via ranking_response, which
 *   contains all the actions (ranking_response.h:29).
 * configuration — class to initialize the API; a collection of (name, value)
 *   pairs (configuration.h:31).
 * void set(const char* name, const char* value) — sets the value for a given
 *   name, overriding any existing values for that name.
 * RL Inference API definition.
 */