#include "constants.h"
#include "factory_resolver.h"
#include "future_compat.h"
#include "multistep.h"
#include "person.h"
#include "rand48.h"
#include "rl_sim_cpp.h"
#include "simulation_stats.h"
#include "trace_logger.h"
#include <boost/uuid/random_generator.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <chrono>
#include <cmath>
#include <functional>
#include <thread>
using namespace std;
std::string get_dist_str(const reinforcement_learning::slot_response& response);
std::string get_dist_str(const reinforcement_learning::slot_entry& response);
/// Entry point of the simulator: initializes the RL client and simulated worlds,
/// then runs the simulation loop selected on the command line.
/// @return 0 on normal termination, -1 on initialization failure or unknown loop kind.
int rl_sim::loop()
{
  if (!init()) { return -1; }

  // Dispatch to the loop matching the configured problem type.
  switch (_loop_kind)
  {
    case CB: return cb_loop();
    case CA: return ca_loop();
    case CCB: return ccb_loop();
    case Slates: return slates_loop();
    case Multistep: return multistep_loop();
    default: break;
  }

  std::cout << "Invalid loop kind:" << _loop_kind << std::endl;
  return -1;
}
/// Contextual-bandit simulation loop: repeatedly samples a person, asks the model to rank the
/// topic actions, simulates the person's reaction and reports it back as the outcome.
/// Runs until _run_loop is cleared (event budget reached or background error).
/// @return 0 when the loop ends.
int rl_sim::cb_loop()
{
  r::ranking_response response;
  simulation_stats<size_t> stats;
  while (_run_loop)
  {
    // Build the decision context: person features plus one action entry per topic.
    auto& p = pick_a_random_person();
    const auto context_features = p.get_features();
    const auto action_features = get_action_features();
    const auto context_json = create_context_json(context_features, action_features);
    const auto req_id = create_event_id();
    r::api_status status;

    if (_rl->choose_rank(req_id.c_str(), context_json.c_str(), response, &status) != err::success)
    {
      std::cout << status.get_error_msg() << std::endl;
      continue;
    }

    // Pass &status so a failure here prints its own message instead of a stale one.
    size_t chosen_action = 0;
    if (response.get_chosen_action_id(chosen_action, &status) != err::success)
    {
      std::cout << status.get_error_msg() << std::endl;
      continue;
    }

    // Simulate the person's reaction to the chosen topic.
    const auto outcome = p.get_outcome(_topics[chosen_action], _random_seed);

    // report_outcome always runs; its error is only surfaced for non-negligible outcomes.
    if (_rl->report_outcome(req_id.c_str(), outcome, &status) != err::success && outcome > 0.00001f)
    {
      std::cout << status.get_error_msg() << std::endl;
      continue;
    }

    stats.record(p.id(), chosen_action, outcome);
    if (!_quiet)
    {
      std::cout << " " << stats.count() << ", ctxt, " << p.id() << ", action, " << chosen_action << ", outcome, "
                << outcome << ", dist, " << get_dist_str(response) << ", " << stats.get_stats(p.id(), chosen_action)
                << std::endl;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
  }
  return 0;
}
/// Builds the textual id of a multistep episode.
/// @param episode_id sequential episode number
/// @return "episode-<episode_id>", e.g. "episode-3"
std::string create_episode_id(size_t episode_id)
{
  return std::string("episode") + "-" + std::to_string(episode_id);
}
/// Multistep (episodic) simulation loop. Each episode consists of episode_length decisions,
/// each linked to the previous one; the per-step outcomes are summed and reported once
/// against the episode id.
/// @return 0 when the loop ends, -1 if an episodic decision request fails.
int rl_sim::multistep_loop()
{
  r::ranking_response response;
  simulation_stats<size_t> stats;
  const size_t episode_length = 2;
  size_t episode_indx = 0;
  // Every step consumes one event id, so scale the event budget by the episode length.
  _num_events *= episode_length;
  while (_run_loop)
  {
    const std::string episode_id = create_episode_id(episode_indx++);
    r::api_status status;
    r::episode_state episode(episode_id.c_str());
    std::string previous_id;
    float episodic_outcome = 0;
    for (size_t i = 0; i < episode_length; ++i)
    {
      auto& p = pick_a_random_person();
      const auto context_features = p.get_features();
      const auto action_features = get_action_features();
      const auto context_json = create_context_json(context_features, action_features);
      const auto req_id = create_event_id();

      // The first step of an episode has no predecessor.
      if (_rl->request_episodic_decision(req_id.c_str(), i == 0 ? nullptr : previous_id.c_str(), context_json.c_str(),
              response, episode, &status) != err::success)
      {
        std::cout << status.get_error_msg() << std::endl;
        return -1;
      }

      // Pass &status so a failure here prints its own message instead of a stale one.
      size_t chosen_action = 0;
      if (response.get_chosen_action_id(chosen_action, &status) != err::success)
      {
        std::cout << status.get_error_msg() << std::endl;
        continue;
      }

      const auto outcome_per_step = p.get_outcome(_topics[chosen_action], _random_seed);
      stats.record(p.id(), chosen_action, outcome_per_step);
      if (!_quiet)
      {
        std::cout << " " << stats.count() << ", ctxt, " << p.id() << ", action, " << chosen_action << ", outcome, "
                  << outcome_per_step << ", dist, " << get_dist_str(response) << ", "
                  << stats.get_stats(p.id(), chosen_action) << std::endl;
      }
      episodic_outcome += outcome_per_step;
      previous_id = req_id;
    }

    // The whole episode is rewarded once, with the sum of its step outcomes.
    if (_rl->report_outcome(episode.get_episode_id(), episodic_outcome, &status) != err::success)
    {
      std::cout << status.get_error_msg() << std::endl;
      continue;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
  }
  return 0;
}
int rl_sim::ca_loop()
{
r::continuous_action_response response;
simulation_stats<float> stats;
while (_run_loop)
{
auto& joint = pick_a_random_joint();
const auto context_features = joint.get_features();
const auto context_json = create_context_json(context_features);
const auto req_id = create_event_id();
r::api_status status;
RL_IGNORE_DEPRECATED_USAGE_START
if (_rl->request_continuous_action(req_id.c_str(), context_json, response, &status) != err::success)
{
std::cout << status.get_error_msg() << std::endl;
continue;
}
RL_IGNORE_DEPRECATED_USAGE_END
const auto chosen_action = response.get_chosen_action();
const auto outcome = joint.get_outcome(chosen_action, _random_seed);
if (_rl->report_outcome(req_id.c_str(), outcome, &status) != err::success && outcome > 0.00001f)
{
std::cout << status.get_error_msg() << std::endl;
}
stats.record(joint.id(), chosen_action, outcome);
if (!_quiet)
{
std::cout << " " << stats.count() << " - ctxt: " << joint.id() << ", action: " << chosen_action
<< ", outcome: " << outcome << ", dist: " << response.get_chosen_action_pdf_value() << ", "
<< stats.get_stats(joint.id(), chosen_action) << std::endl;
}
std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
}
return 0;
}
int rl_sim::ccb_loop()
{
r::multi_slot_response decision;
simulation_stats<size_t> stats;
while (_run_loop)
{
auto& p = pick_a_random_person();
const auto context_features = p.get_features();
const auto action_features = get_action_features();
const auto event_id = create_event_id();
const auto slot_json = get_slot_features();
const auto context_json = create_context_json(context_features, action_features, slot_json);
std::cout << context_json << std::endl;
r::api_status status;
RL_IGNORE_DEPRECATED_USAGE_START
if (_rl->request_multi_slot_decision(event_id.c_str(), context_json, decision, &status) != err::success)
{
std::cout << status.get_error_msg() << std::endl;
continue;
}
RL_IGNORE_DEPRECATED_USAGE_END
auto index = 0;
for (auto& response : decision)
{
const auto chosen_action = response.get_action_id();
const auto outcome = p.get_outcome(_topics[chosen_action], _random_seed);
if (_rl->report_outcome(event_id.c_str(), index, outcome, &status) != err::success && outcome > 0.00001f)
{
std::cout << status.get_error_msg() << std::endl;
continue;
}
stats.record(p.id(), chosen_action, outcome);
if (!_quiet)
{
std::cout << " " << stats.count() << ", ctxt, " << p.id() << ", action, " << chosen_action << ", slot, "
<< index << ", outcome, " << outcome << ", dist, " << get_dist_str(response) << ", "
<< stats.get_stats(p.id(), chosen_action) << std::endl;
}
index++;
}
std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
}
return 0;
}
/// Builds the "_slots" JSON fragment for a slates decision with @p slot_count slots.
/// Slot ids are "__0" .. "__<slot_count-1>"; zero slots yields an empty array.
/// Example (slot_count = 2): "_slots": [ { "slot_id":"__0"}, { "slot_id":"__1"}]
std::string get_slates_slot_features(size_t slot_count)
{
  std::ostringstream oss;
  oss << R"("_slots": [ )";
  for (size_t idx = 0; idx < slot_count; ++idx)
  {
    if (idx != 0) { oss << ", "; }
    oss << R"({ "slot_id":"__)" << idx << R"("})";
  }
  oss << "] ";
  return oss.str();
}
/// Slates simulation loop: requests one action per slate slot, sums the per-slot outcomes and
/// reports the total once per event.
/// @return 0 when the loop ends.
int rl_sim::slates_loop()
{
  r::multi_slot_response decision;
  simulation_stats<size_t> stats;
  while (_run_loop)
  {
    auto& p = pick_a_random_person();
    const auto context_features = p.get_features();
    const auto action_features = get_slates_action_features();
    const auto event_id = create_event_id();
    const auto slot_json = get_slates_slot_features(NUM_SLATES_SLOTS);
    const auto context_json = create_context_json(context_features, action_features, slot_json);
    // Respect --quiet for the context dump as well.
    if (!_quiet) { std::cout << context_json << std::endl; }
    r::api_status status;
    RL_IGNORE_DEPRECATED_USAGE_START
    if (_rl->request_multi_slot_decision(event_id.c_str(), context_json.c_str(), decision, &status) != err::success)
    {
      std::cout << status.get_error_msg() << std::endl;
      continue;
    }
    RL_IGNORE_DEPRECATED_USAGE_END
    float outcome = 0;
    int index = 0;
    // The action ids returned per slot are slot-local; map them back onto _topics indices.
    auto actions_per_slot = _topics.size() / NUM_SLATES_SLOTS;
    for (auto& response : decision)
    {
      const auto chosen_action = response.get_action_id() + index * actions_per_slot;
      const auto slot_outcome = p.get_outcome(_topics[chosen_action], _random_seed);
      // Record under the person id, which is the key get_stats() reads back below.
      stats.record(p.id(), chosen_action, slot_outcome);
      outcome += slot_outcome;
      if (!_quiet)
      {
        std::cout << " " << stats.count() << ", ctxt, " << p.id() << ", action, " << chosen_action << ", slot, "
                  << index << ", outcome, " << outcome << ", dist, " << get_dist_str(response) << ", "
                  << stats.get_stats(p.id(), chosen_action) << std::endl;
      }
      index++;
    }
    if (_rl->report_outcome(event_id.c_str(), outcome, &status) != err::success && outcome > 0.00001f)
    {
      std::cout << status.get_error_msg() << std::endl;
      continue;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
  }
  return 0;
}
/// Returns a uniformly chosen simulated person, drawing from _random_seed.
/// Assumes _people is non-empty (it is populated by init_sim_world).
person& rl_sim::pick_a_random_person()
{
  // Scale by size() (not size() + 1): the old "+ 1" plus the clamp made the last person twice as
  // likely. The clamp is kept in case rand48 ever returns exactly 1.0.
  const size_t idx = static_cast<size_t>(rand48(_random_seed) * _people.size());
  return _people[std::min(idx, _people.size() - 1)];
}
/// Returns a uniformly chosen simulated robot joint, drawing from _random_seed.
/// Assumes _robot_joints is non-empty (it is populated by init_continuous_sim_world).
joint& rl_sim::pick_a_random_joint()
{
  // Scale by size() (not size() + 1) so the last joint is not over-represented; the clamp only
  // guards against rand48 returning exactly 1.0.
  const size_t idx = static_cast<size_t>(rand48(_random_seed) * _robot_joints.size());
  return _robot_joints[std::min(idx, _robot_joints.size() - 1)];
}
int rl_sim::load_config_from_json(const std::string& file_name, u::configuration& config, r::api_status* status)
{
std::string config_str;
return cfg::create_from_json(config_str, config, nullptr, status);
}
/// Reads the whole file @p file_name into @p config_str.
/// @return err::success, or invalid_argument (with a message in @p status) if the file cannot be opened.
int rl_sim::load_file(const std::string& file_name, std::string& config_str, r::api_status* status)
{
  std::ifstream input(file_name);
  if (!input.good())
  {
    RETURN_ERROR_LS(nullptr, status, invalid_argument) << "Cannot open file: " << file_name;
  }

  // Copy the entire stream buffer in one go.
  std::stringstream contents;
  contents << input.rdbuf();
  config_str = contents.str();
  return err::success;
}
struct throughput_tracking_sender : public reinforcement_learning::i_sender
{
throughput_tracking_sender(std::unique_ptr<reinforcement_learning::i_sender> inner_sender, std::string name,
reinforcement_learning::i_trace* trace, int print_interval)
: _inner_sender(std::move(inner_sender)), _name(std::move(name)), _trace(trace), _print_interval(print_interval)
{
}
int init(
{
_begin_time = std::chrono::steady_clock::now();
return _inner_sender->init(config, status);
};
protected:
{
_bytes_sent += data->buffer_filled_size();
_messages_sent++;
_print_counter++;
if (_print_counter >= _print_interval)
{
auto now = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(now - _begin_time).count();
std::ostringstream oss;
oss << _name << " throughput: " << _bytes_sent / elapsed << " bytes/s, "
<< static_cast<float>(_messages_sent) / elapsed << " messages/s, " << _bytes_sent << " total bytes, "
<< _messages_sent << " total messages." << std::endl;
TRACE_INFO(_trace, oss.str());
_print_counter = 0;
}
return _inner_sender->send(data, status);
};
private:
std::unique_ptr<reinforcement_learning::i_sender> _inner_sender;
std::string _name;
size_t _bytes_sent = 0;
size_t _messages_sent = 0;
std::chrono::steady_clock::time_point _begin_time;
reinforcement_learning::i_trace* _trace;
int _print_interval;
int _print_counter = 0;
};
/// Returns a sender-factory function that creates the sender registered as @p name in the
/// global sender_factory and wraps it in a throughput_tracking_sender.
/// The trace interval is read from the config key "thoughputsender.printinterval" (default 1).
reinforcement_learning::sender_factory_t::create_fn wrap_sender_generate_for_throughput_sender(const std::string& name)
{
  // Capture by copy: the returned function outlives this call.
  return [=](std::unique_ptr<reinforcement_learning::i_sender>& retval, const u::configuration& cfg,
             reinforcement_learning::error_callback_fn* error_cb, reinforcement_learning::i_trace* trace_logger,
             reinforcement_learning::api_status* status) -> int
  {
    std::unique_ptr<reinforcement_learning::i_sender> sender;
    const auto res = reinforcement_learning::sender_factory.create(sender, name, cfg, error_cb, trace_logger, status);
    // Propagate creation failures instead of wrapping a null sender and reporting success.
    if (res != err::success) { return res; }
    // Key spelling ("thoughput", sic) kept as-is: existing config files may rely on it.
    const auto print_interval = cfg.get_int("thoughputsender.printinterval", 1);
    retval.reset(new throughput_tracking_sender(std::move(sender), name, trace_logger, print_interval));
    return err::success;
  };
}
int rl_sim::init_rl()
{
r::api_status status;
u::configuration config;
const auto config_file = _options["json_config"].as<std::string>();
if (load_config_from_json(config_file, config, &status) != err::success)
{
std::cout << status.get_error_msg() << std::endl;
return -1;
}
if (_options["log_to_file"].as<bool>())
{
config.
set(r::name::INTERACTION_SENDER_IMPLEMENTATION, r::value::INTERACTION_FILE_SENDER);
config.set(r::name::OBSERVATION_SENDER_IMPLEMENTATION, r::value::OBSERVATION_FILE_SENDER);
}
if (!_options["get_model"].as<bool>())
{
config.set(r::name::MODEL_SRC, r::value::NO_MODEL_DATA);
}
if (_options["log_timestamp"].as<bool>())
{
config.set(r::name::TIME_PROVIDER_IMPLEMENTATION, r::value::CLOCK_TIME_PROVIDER);
}
if (!_quiet) { config.set(r::name::TRACE_LOG_IMPLEMENTATION, r::value::CONSOLE_TRACE_LOGGER); }
reinforcement_learning::sender_factory_t* sender_factory = &reinforcement_learning::sender_factory;
reinforcement_learning::sender_factory_t factory;
if (_options.count("throughput") != 0u)
{
factory.register_type(reinforcement_learning::value::OBSERVATION_EH_SENDER,
wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::OBSERVATION_EH_SENDER));
factory.register_type(reinforcement_learning::value::INTERACTION_EH_SENDER,
wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::INTERACTION_EH_SENDER));
factory.register_type(reinforcement_learning::value::EPISODE_EH_SENDER,
wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::EPISODE_EH_SENDER));
factory.register_type(reinforcement_learning::value::OBSERVATION_HTTP_API_SENDER,
wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::OBSERVATION_HTTP_API_SENDER));
factory.register_type(reinforcement_learning::value::INTERACTION_HTTP_API_SENDER,
wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::INTERACTION_HTTP_API_SENDER));
factory.register_type(reinforcement_learning::value::EPISODE_HTTP_API_SENDER,
wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::EPISODE_HTTP_API_SENDER));
sender_factory = &factory;
}
#ifdef LINK_AZURE_LIBS
using namespace std::placeholders;
reinforcement_learning::oauth_callback_t callback =
std::bind(&azure_credentials_provider_t::get_token, &_creds, _1, _2, _3, _4);
reinforcement_learning::register_default_factories_callback(callback);
#endif
_rl = std::unique_ptr<r::live_model>(new r::live_model(config, _on_error, this,
&reinforcement_learning::trace_logger_factory, &reinforcement_learning::data_transport_factory,
&reinforcement_learning::model_factory, sender_factory, &reinforcement_learning::time_provider_factory));
if (_rl->init(&status) != err::success)
{
std::cout << status.get_error_msg() << std::endl;
return -1;
}
if (!_quiet) { std::cout << " API Config " << config; }
return err::success;
}
bool rl_sim::init_sim_world()
{
_topics = {"SkiConditions-VT", "HerbGarden", "BeyBlades", "NYCLiving", "MachineLearning"};
_slot_sizes = {"large", "medium", "small"};
person::topic_prob tp = {
{_topics[0], 0.08f}, {_topics[1], 0.03f}, {_topics[2], 0.05f}, {_topics[3], 0.03f}, {_topics[4], 0.25f}};
_people.emplace_back("rnc", "engineering", "hiking", "spock", tp);
tp = {{_topics[0], 0.08f}, {_topics[1], 0.30f}, {_topics[2], 0.02f}, {_topics[3], 0.02f}, {_topics[4], 0.10f}};
_people.emplace_back("mk", "psychology", "kids", "7of9", tp);
return true;
}
bool rl_sim::init_continuous_sim_world()
{
_friction = {25.4f, 41.2f, 66.5f, 81.9f, 104.4f};
joint::friction_prob fb = {{_friction[0], 0.08f}, {_friction[1], 0.03f}, {_friction[2], 0.05f}, {_friction[3], 0.03f},
{_friction[4], 0.25f}};
_robot_joints.emplace_back("j1", 20.3, 102.4, -10.2, fb);
fb = {{_friction[0], 0.08f}, {_friction[1], 0.30f}, {_friction[2], 0.02f}, {_friction[3], 0.02f},
{_friction[4], 0.10f}};
_robot_joints.emplace_back("j2", 40.6, 30.8, 98.5, fb);
return true;
}
/// Initializes the RL client, then both simulated worlds, stopping at the first failure.
/// @return true when every step succeeded.
bool rl_sim::init()
{
  return init_rl() == err::success && init_sim_world() && init_continuous_sim_world();
}
/// Builds the "_multi" JSON fragment with one TAction entry per topic in _topics, e.g.
/// "_multi": [ { "TAction":{"topic":"A"} }, { "TAction":{"topic":"B"} } ]
/// An empty _topics yields an empty array instead of underflowing.
std::string rl_sim::get_action_features()
{
  std::ostringstream oss;
  oss << R"("_multi": [ )";
  for (size_t idx = 0; idx < _topics.size(); ++idx)
  {
    if (idx != 0) { oss << ", "; }
    oss << R"({ "TAction":{"topic":")" << _topics[idx] << R"("} })";
  }
  oss << " ]";
  return oss.str();
}
/// Builds the slates "_multi" JSON fragment: one TAction per topic, each tagged with a
/// "_slot_id". Action i gets slot i / NUM_SLATES_SLOTS, except the final action, which is
/// pinned to slot NUM_SLATES_SLOTS - 1. An empty _topics yields an empty array.
std::string rl_sim::get_slates_action_features()
{
  std::ostringstream oss;
  oss << R"("_multi": [ )";
  const size_t count = _topics.size();
  for (size_t idx = 0; idx < count; ++idx)
  {
    if (idx != 0) { oss << ", "; }
    oss << R"({ "_slot_id":)";
    if (idx + 1 == count) { oss << (NUM_SLATES_SLOTS - 1); }
    else { oss << (idx / NUM_SLATES_SLOTS); }
    oss << R"(, "TAction":{"topic":")" << _topics[idx] << R"("} })";
  }
  oss << " ]";
  return oss.str();
}
/// Builds the CCB "_slots" JSON fragment: the first NUM_SLOTS - 1 entries of _slot_sizes,
/// followed by the last element of _slot_sizes.
std::string rl_sim::get_slot_features()
{
  std::ostringstream json;
  json << R"("_slots": [ )";
  auto idx = 0;
  while (idx < NUM_SLOTS - 1)
  {
    json << R"({ "_size":")" << _slot_sizes[idx] << R"("}, )";
    ++idx;
  }
  json << R"({ "_size":")" << _slot_sizes.back() << R"("}] )";
  return json.str();
}
{
std::cout <<
"Background error in Inference API: " << status.
get_error_msg() << std::endl;
std::cout << "Exiting simulation loop." << std::endl;
_run_loop = false;
}
/// Wraps a context-features fragment in a JSON object: "{ <cntxt> }".
std::string rl_sim::create_context_json(const std::string& cntxt)
{
  return "{ " + cntxt + " }";
}
/// Joins context and action fragments into one JSON object: "{ <cntxt>, <action> }".
std::string rl_sim::create_context_json(const std::string& cntxt, const std::string& action)
{
  return "{ " + cntxt + ", " + action + " }";
}
/// Joins context, action and slot fragments into one JSON object:
/// "{ <cntxt>, <action>, <slots> }".
std::string rl_sim::create_context_json(const std::string& cntxt, const std::string& action, const std::string& slots)
{
  return "{ " + cntxt + ", " + action + ", " + slots + " }";
}
/// Produces the id for the next event and enforces the event budget: once _num_events
/// (when > 0) events have been generated, _run_loop is cleared.
/// @return a random UUID when --random_ids is set, otherwise "event_<n>" with n = 1, 2, ...
std::string rl_sim::create_event_id()
{
  // Count every event, so sequential ids stay unique even when no event limit is configured
  // (previously the counter only advanced when _num_events > 0, yielding "event_0" forever).
  ++_current_events;
  if (_num_events > 0 && _current_events >= _num_events) { _run_loop = false; }
  if (_random_ids) { return boost::uuids::to_string(boost::uuids::random_generator()()); }
  std::ostringstream oss;
  oss << "event_" << _current_events;
  return oss.str();
}
/// Configures the simulator from parsed command-line options.
/// Loop kind: the first enabled switch among ccb, slates, ca, multistep wins; CB otherwise.
rl_sim::rl_sim(boost::program_options::variables_map vm) : _options(std::move(vm)), _loop_kind(CB)
{
  const auto is_set = [this](const char* option) { return _options[option].as<bool>(); };

  if (is_set("ccb")) { _loop_kind = CCB; }
  else if (is_set("slates")) { _loop_kind = Slates; }
  else if (is_set("ca")) { _loop_kind = CA; }
  else if (is_set("multistep")) { _loop_kind = Multistep; }

  // Remaining run parameters.
  _num_events = _options["num_events"].as<int>();
  _random_seed = _options["random_seed"].as<uint64_t>();
  _delay = _options["delay"].as<int64_t>();
  _quiet = is_set("quiet");
  _random_ids = is_set("random_ids");
}
{
std::string ret;
ret += "(";
for (const auto& ap_pair : response)
{
ret += "[" + to_string(ap_pair.action_id) + ",";
ret += to_string(ap_pair.probability) + "]";
ret += " ,";
}
ret += ")";
return ret;
}
/// Formats a single slate slot response as "([slot_id,action_id,probability] ,)".
std::string get_dist_str(const reinforcement_learning::slot_response& response)
{
  std::string text = "([";
  text += std::string(response.get_slot_id());
  text += ",";
  text += to_string(response.get_action_id());
  text += ",";
  text += to_string(response.get_probability());
  text += "] ,)";
  return text;
}
/// Formats a single CCB slot entry as "([id,action_id,probability] ,)".
std::string get_dist_str(const reinforcement_learning::slot_entry& response)
{
  std::string text = "([";
  text += string(response.get_id());
  text += ",";
  text += to_string(response.get_action_id());
  text += ",";
  text += to_string(response.get_probability());
  text += "] ,)";
  return text;
}
/// Formats every element of a decision response (via the per-element get_dist_str overload),
/// each followed by " ,", all enclosed in parentheses.
std::string get_dist_str(const reinforcement_learning::decision_response& response)
{
  std::string text = "(";
  for (const auto& element : response) { text += get_dist_str(element) + " ,"; }
  text += ")";
  return text;
}
/*
 * Doxygen cross-reference notes for API symbols used in this file:
 *
 * api_status (api_status.h:22) — reports the status of all API calls; used to return error
 *   information to the caller.
 * RETURN_IF_FAIL(x) (api_status.h:244) — error-reporting macro that tests a result and returns on error.
 * RETURN_ERROR_LS(trace, status, code) (api_status.h:237) — error-reporting macro used with the
 *   left-shift operator.
 * const char* get_error_msg() const — (API error codes) gets the error message string; all API
 *   calls return a status code.
 * ranking_response (ranking_response.h:29) — choose_rank() returns the action choice using
 *   ranking_response, which contains all the actions (description truncated in the source).
 * Configuration class (configuration.h:31) — used to initialize the API; represents a collection of
 *   (name, value) pairs (description truncated in the source).
 * void set(const char* name, const char* value) — sets the value for a given name, overriding any
 *   existing value for that name.
 * RL Inference API definition.
 */