Vowpal Wabbit
Classes | Functions
CB_ADF Namespace Reference

Classes

struct  cb_adf
 

Functions

CB::cb_class get_observed_cost (multi_ex &examples)
 
example * test_adf_sequence (multi_ex &ec_seq)
 
void global_print_newline (const v_array< int > &final_prediction_sink)
 
void output_example (vw &all, cb_adf &c, example &ec, multi_ex *ec_seq)
 
void output_rank_example (vw &all, cb_adf &c, example &ec, multi_ex *ec_seq)
 
void output_example_seq (vw &all, cb_adf &data, multi_ex &ec_seq)
 
void finish_multiline_example (vw &all, cb_adf &data, multi_ex &ec_seq)
 
void save_load (cb_adf &c, io_buf &model_file, bool read, bool text)
 
void learn (cb_adf &c, multi_learner &base, multi_ex &ec_seq)
 
void predict (cb_adf &c, multi_learner &base, multi_ex &ec_seq)
 

Function Documentation

◆ finish_multiline_example()

void CB_ADF::finish_multiline_example ( vw & all,
cb_adf & data,
multi_ex & ec_seq 
)

Definition at line 451 of file cb_adf.cc.

References vw::final_prediction_sink, VW::finish_example(), global_print_newline(), and output_example_seq().

Referenced by cb_adf_setup().

452 {
453  if (!ec_seq.empty())
454  {
455  output_example_seq(all, data, ec_seq);
456  global_print_newline(all.final_prediction_sink);
457  }
458  VW::finish_example(all, ec_seq);
459 }
void output_example_seq(vw &all, multi_ex &ec_seq)
Definition: cbify.cc:356
v_array< int > final_prediction_sink
Definition: global_data.h:518
void global_print_newline(const v_array< int > &final_prediction_sink)
Definition: cb_adf.cc:342
void finish_example(vw &, example &)
Definition: parser.cc:881

◆ get_observed_cost()

CB::cb_class CB_ADF::get_observed_cost ( multi_ex & examples)

Definition at line 99 of file cb_adf.cc.

References CB::cb_class::action, CB::label::costs, and CB::cb_class::probability.

Referenced by CB_ADF::cb_adf::do_actual_learning(), EXPLORE_EVAL::do_actual_learning(), VW::cb_explore_adf::cb_explore_adf_base< ExploreType >::learn(), CB_ALGS::learn_eval(), CB_EXPLORE::output_example(), VW::cb_explore_adf::cb_explore_adf_base< ExploreType >::predict(), CB_ALGS::predict_or_learn(), CB_EXPLORE::predict_or_learn_cover(), and VW::cb_explore_adf::cover::cb_explore_adf_cover::predict_or_learn_impl().

100 {
101  CB::label ld;
102  ld.costs = v_init<cb_class>();
103  int index = -1;
104  CB::cb_class known_cost;
105 
106  size_t i = 0;
107  for (example*& ec : examples)
108  {
109  if (ec->l.cb.costs.size() == 1 && ec->l.cb.costs[0].cost != FLT_MAX && ec->l.cb.costs[0].probability > 0)
110  {
111  ld = ec->l.cb;
112  index = (int)i;
113  }
114  ++i;
115  }
116 
117  // handle -1 case.
118  if (index == -1)
119  {
120  known_cost.probability = -1;
121  return known_cost;
122  // std::cerr << "None of the examples has known cost. Exiting." << std::endl;
123  // throw exception();
124  }
125 
126  known_cost = ld.costs[0];
127  known_cost.action = index;
128  return known_cost;
129 }
v_array< cb_class > costs
Definition: cb.h:27
uint32_t action
Definition: cb.h:18
float probability
Definition: cb.h:19
Definition: cb.h:25

◆ global_print_newline()

void CB_ADF::global_print_newline ( const v_array< int > &  final_prediction_sink)

Definition at line 342 of file cb_adf.cc.

References f, and io_buf::write_file_or_socket().

Referenced by VW::cb_explore_adf::cb_explore_adf_base< ExploreType >::finish_multiline_example(), EXPLORE_EVAL::finish_multiline_example(), finish_multiline_example(), and CCB::finish_multiline_example().

343 {
344  char temp[1];
345  temp[0] = '\n';
346  for (auto f : final_prediction_sink)
347  {
348  ssize_t t;
349  t = io_buf::write_file_or_socket(f, temp, 1);
350  if (t != 1)
351  std::cerr << "write error: " << strerror(errno) << std::endl;
352  }
353 }
static ssize_t write_file_or_socket(int f, const void *buf, size_t nbytes)
Definition: io_buf.cc:140
float f
Definition: cache.cc:40

◆ learn()

void CB_ADF::learn ( cb_adf & c,
multi_learner & base,
multi_ex & ec_seq 
)

Definition at line 475 of file cb_adf.cc.

References CB_ADF::cb_adf::do_actual_learning().

Referenced by cb_adf_setup().

475 { c.do_actual_learning<true>(base, ec_seq); }
void do_actual_learning(LEARNER::multi_learner &base, multi_ex &ec_seq)
Definition: cb_adf.cc:295

◆ output_example()

void CB_ADF::output_example ( vw & all,
cb_adf & c,
example & ec,
multi_ex * ec_seq 
)

Definition at line 379 of file cb_adf.cc.

References polyprediction::a_s, polylabel::cb, CB::label::costs, LEARNER::example_is_newline_not_header(), vw::final_prediction_sink, example::l, example::pred, vw::print, vw::print_text, CB::print_update(), vw::raw_prediction, v_array< T >::size(), example::tag, and CB_ADF::cb_adf::update_statistics().

Referenced by output_example_seq().

380 {
381  if (example_is_newline_not_header(ec))
382  return;
383 
384  bool labeled_example = c.update_statistics(ec, ec_seq);
385 
386  uint32_t action = ec.pred.a_s[0].action;
387  for (int sink : all.final_prediction_sink) all.print(sink, (float)action, 0, ec.tag);
388 
389  if (all.raw_prediction > 0)
390  {
391  std::string outputString;
392  std::stringstream outputStringStream(outputString);
393  v_array<CB::cb_class> costs = ec.l.cb.costs;
394 
395  for (size_t i = 0; i < costs.size(); i++)
396  {
397  if (i > 0)
398  outputStringStream << ' ';
399  outputStringStream << costs[i].action << ':' << costs[i].partial_prediction;
400  }
401  all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag);
402  }
403 
404  CB::print_update(all, !labeled_example, ec, ec_seq, true);
405 }
v_array< char > tag
Definition: example.h:63
int raw_prediction
Definition: global_data.h:519
bool example_is_newline_not_header(example const &ec)
Definition: cb_algs.h:80
ACTION_SCORE::action_scores a_s
Definition: example.h:47
CB::label cb
Definition: example.h:31
v_array< int > final_prediction_sink
Definition: global_data.h:518
v_array< cb_class > costs
Definition: cb.h:27
uint32_t action
Definition: search.h:19
size_t size() const
Definition: v_array.h:68
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
polylabel l
Definition: example.h:57
bool update_statistics(example &ec, multi_ex *ec_seq)
Definition: cb_adf.cc:357
polyprediction pred
Definition: example.h:60
void(* print)(int, float, float, v_array< char >)
Definition: global_data.h:521

◆ output_example_seq()

void CB_ADF::output_example_seq ( vw & all,
cb_adf & data,
multi_ex & ec_seq 
)

Definition at line 435 of file cb_adf.cc.

References CB_ADF::cb_adf::get_rank_all(), output_example(), output_rank_example(), vw::print_text, and vw::raw_prediction.

436 {
437  if (!ec_seq.empty())
438  {
439  if (data.get_rank_all())
440  output_rank_example(all, data, **(ec_seq.begin()), &(ec_seq));
441  else
442  {
443  output_example(all, data, **(ec_seq.begin()), &(ec_seq));
444 
445  if (all.raw_prediction > 0)
446  all.print_text(all.raw_prediction, "", ec_seq[0]->tag);
447  }
448  }
449 }
int raw_prediction
Definition: global_data.h:519
void output_example(vw &all, cb_adf &c, example &ec, multi_ex *ec_seq)
Definition: cb_adf.cc:379
bool get_rank_all() const
Definition: cb_adf.cc:68
void output_rank_example(vw &all, cb_adf &c, example &ec, multi_ex *ec_seq)
Definition: cb_adf.cc:407
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522

◆ output_rank_example()

void CB_ADF::output_rank_example ( vw & all,
cb_adf & c,
example & ec,
multi_ex * ec_seq 
)

Definition at line 407 of file cb_adf.cc.

References polyprediction::a_s, polylabel::cb, CB::label::costs, LEARNER::example_is_newline_not_header(), vw::final_prediction_sink, example::l, example::pred, ACTION_SCORE::print_action_score(), vw::print_text, CB::print_update(), vw::raw_prediction, v_array< T >::size(), example::tag, and CB_ADF::cb_adf::update_statistics().

Referenced by output_example_seq().

408 {
409  label& ld = ec.l.cb;
410  v_array<CB::cb_class> costs = ld.costs;
411 
412  if (example_is_newline_not_header(ec))
413  return;
414 
415  bool labeled_example = c.update_statistics(ec, ec_seq);
416 
417  for (int sink : all.final_prediction_sink) print_action_score(sink, ec.pred.a_s, ec.tag);
418 
419  if (all.raw_prediction > 0)
420  {
421  std::string outputString;
422  std::stringstream outputStringStream(outputString);
423  for (size_t i = 0; i < costs.size(); i++)
424  {
425  if (i > 0)
426  outputStringStream << ' ';
427  outputStringStream << costs[i].action << ':' << costs[i].partial_prediction;
428  }
429  all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag);
430  }
431 
432  CB::print_update(all, !labeled_example, ec, ec_seq, true);
433 }
v_array< char > tag
Definition: example.h:63
int raw_prediction
Definition: global_data.h:519
bool example_is_newline_not_header(example const &ec)
Definition: cb_algs.h:80
ACTION_SCORE::action_scores a_s
Definition: example.h:47
CB::label cb
Definition: example.h:31
v_array< int > final_prediction_sink
Definition: global_data.h:518
v_array< cb_class > costs
Definition: cb.h:27
size_t size() const
Definition: v_array.h:68
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
polylabel l
Definition: example.h:57
Definition: cb.h:25
bool update_statistics(example &ec, multi_ex *ec_seq)
Definition: cb_adf.cc:357
polyprediction pred
Definition: example.h:60
void print_action_score(int f, v_array< action_score > &a_s, v_array< char > &tag)
Definition: action_score.cc:8

◆ predict()

void CB_ADF::predict ( cb_adf & c,
multi_learner & base,
multi_ex & ec_seq 
)

Definition at line 477 of file cb_adf.cc.

References CB_ADF::cb_adf::do_actual_learning().

Referenced by cb_adf_setup(), and CB_ADF::cb_adf::learn_MTR().

477 { c.do_actual_learning<false>(base, ec_seq); }
void do_actual_learning(LEARNER::multi_learner &base, multi_ex &ec_seq)
Definition: cb_adf.cc:295

◆ save_load()

void CB_ADF::save_load ( cb_adf & c,
io_buf & model_file,
bool  read,
bool  text 
)

Definition at line 461 of file cb_adf.cc.

References GEN_CS::cb_to_cs_adf::action_sum, bin_text_read_write_fixed(), GEN_CS::cb_to_cs_adf::event_sum, CB_ADF::cb_adf::get_gen_cs(), CB_ADF::cb_adf::get_model_file_ver(), and VERSION_FILE_WITH_CB_ADF_SAVE.

Referenced by cb_adf_setup().

462 {
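463  if (c.get_model_file_ver() != nullptr && *c.get_model_file_ver() < VERSION_FILE_WITH_CB_ADF_SAVE)
464  return;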
465  std::stringstream msg;
466  msg << "event_sum " << c.get_gen_cs().event_sum << "\n";
467  bin_text_read_write_fixed(
468  model_file, (char*)&c.get_gen_cs().event_sum, sizeof(c.get_gen_cs().event_sum), "", read, msg, text);
469 
470  msg << "action_sum " << c.get_gen_cs().action_sum << "\n";
471  bin_text_read_write_fixed(
472  model_file, (char*)&c.get_gen_cs().action_sum, sizeof(c.get_gen_cs().action_sum), "", read, msg, text);
473 }
#define VERSION_FILE_WITH_CB_ADF_SAVE
Definition: vw_versions.h:20
const cb_to_cs_adf & get_gen_cs() const
Definition: cb_adf.cc:70
const VW::version_struct * get_model_file_ver() const
Definition: cb_adf.cc:72
size_t bin_text_read_write_fixed(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
Definition: io_buf.h:326

◆ test_adf_sequence()

example * CB_ADF::test_adf_sequence ( multi_ex & ec_seq)

Definition at line 268 of file cb_adf.cc.

References THROW.

Referenced by CB_ADF::cb_adf::do_actual_learning(), EXPLORE_EVAL::do_actual_learning(), VW::cb_explore_adf::cb_explore_adf_base< ExploreType >::learn(), and VW::cb_explore_adf::cb_explore_adf_base< ExploreType >::predict().

269 {
270  if (ec_seq.empty())
271  THROW("cb_adf: At least one action must be provided for an example to be valid.");
272 
273  uint32_t count = 0;
274  example* ret = nullptr;
275  for (auto* ec : ec_seq)
276  {
277  // Check if there is more than one cost for this example.
278  if (ec->l.cb.costs.size() > 1)
279  THROW("cb_adf: badly formatted example, only one cost can be known.");
280 
281  // Check whether the cost was initialized to a value.
282  if (ec->l.cb.costs.size() == 1 && ec->l.cb.costs[0].cost != FLT_MAX)
283  {
284  ret = ec;
285  count += 1;
286  if (count > 1)
287  THROW("cb_adf: badly formatted example, only one line can have a cost");
288  }
289  }
290 
291  return ret;
292 }
#define THROW(args)
Definition: vw_exception.h:181