Vowpal Wabbit
cb_adf.cc
1 /*
2  Copyright (c) by respective owners including Yahoo!, Microsoft, and
3  individual contributors. All rights reserved. Released under a BSD (revised)
4  license as described in the file LICENSE.
5 */
6 #include <cfloat>
7 #include <cerrno>
8 #include <algorithm>
9 
10 #include "reductions.h"
11 #include "v_hashmap.h"
12 #include "label_dictionary.h"
13 #include "vw.h"
14 #include "cb_algs.h"
15 #include "vw_exception.h"
16 #include "gen_cs_example.h"
17 #include "vw_versions.h"
18 #include "explore.h"
19 
20 using namespace LEARNER;
21 using namespace CB;
22 using namespace ACTION_SCORE;
23 using namespace GEN_CS;
24 using namespace CB_ALGS;
25 using namespace VW::config;
26 using namespace exploration;
27 
28 namespace CB_ADF
29 {
30 struct cb_adf
31 {
32  private:
33  shared_data* _sd;
34  // model_file_ver is only used to conditionally run save_load(). In the setup function
35  // model_file_ver is not always set.
36  VW::version_struct* _model_file_ver;
37 
38  cb_to_cs_adf _gen_cs;
39  v_array<CB::label> _cb_labels;
40  COST_SENSITIVE::label _cs_labels;
41  v_array<COST_SENSITIVE::label> _prepped_cs_labels;
42 
43  action_scores _a_s; // temporary storage for mtr and sm
44  action_scores _a_s_mtr_cs; // temporary storage for mtr cost sensitive example
45  action_scores _prob_s; // temporary storage for sm; stores softmax values
46  v_array<uint32_t> _backup_nf; // temporary storage for sm; backup for numFeatures in examples
47  v_array<float> _backup_weights; // temporary storage for sm; backup for weights in examples
48 
49  uint64_t _offset;
50  const bool _no_predict;
51  const bool _rank_all;
52  const float _clip_p;
53 
54  public:
55  template <bool is_learn>
56  void do_actual_learning(multi_learner& base, multi_ex& ec_seq);
57  bool update_statistics(example& ec, multi_ex* ec_seq);
58 
59  cb_adf(
60  shared_data* sd, size_t cb_type, VW::version_struct* model_file_ver, bool rank_all, float clip_p, bool no_predict)
61  : _sd(sd), _model_file_ver(model_file_ver), _no_predict(no_predict), _rank_all(rank_all), _clip_p(clip_p)
62  {
63  _gen_cs.cb_type = cb_type;
64  }
65 
66  void set_scorer(LEARNER::single_learner* scorer) { _gen_cs.scorer = scorer; }
67 
68  bool get_rank_all() const { return _rank_all; }
69 
70  const cb_to_cs_adf& get_gen_cs() const { return _gen_cs; }
71 
72  const VW::version_struct* get_model_file_ver() const { return _model_file_ver; }
73 
74  ~cb_adf()
75  {
76  _cb_labels.delete_v();
77  for (auto& prepped_cs_label : _prepped_cs_labels) prepped_cs_label.costs.delete_v();
78  _prepped_cs_labels.delete_v();
79  _cs_labels.costs.delete_v();
80  _backup_weights.delete_v();
81  _backup_nf.delete_v();
82  _prob_s.delete_v();
83 
84  _a_s.delete_v();
85  _a_s_mtr_cs.delete_v();
86  _gen_cs.pred_scores.costs.delete_v();
87  }
88 
89  private:
90  void learn_IPS(multi_learner& base, multi_ex& examples);
91  void learn_DR(multi_learner& base, multi_ex& examples);
92  void learn_DM(multi_learner& base, multi_ex& examples);
93  void learn_SM(multi_learner& base, multi_ex& examples);
94 
95  template <bool predict>
96  void learn_MTR(multi_learner& base, multi_ex& examples);
97 };
98 
99 CB::cb_class get_observed_cost(multi_ex& examples)
100 {
101  CB::label ld;
102  ld.costs = v_init<cb_class>();
103  int index = -1;
104  CB::cb_class known_cost;
105 
106  size_t i = 0;
107  for (example*& ec : examples)
108  {
109  if (ec->l.cb.costs.size() == 1 && ec->l.cb.costs[0].cost != FLT_MAX && ec->l.cb.costs[0].probability > 0)
110  {
111  ld = ec->l.cb;
112  index = (int)i;
113  }
114  ++i;
115  }
116 
117  // handle -1 case.
118  if (index == -1)
119  {
120  known_cost.probability = -1;
121  return known_cost;
122  // std::cerr << "None of the examples has known cost. Exiting." << std::endl;
123  // throw exception();
124  }
125 
126  known_cost = ld.costs[0];
127  known_cost.action = index;
128  return known_cost;
129 }
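
// Illustrative note, not part of the original source: for an event whose second example
// (index 1) carries the label 0:0.5:0.25, get_observed_cost returns a cb_class with
// action = 1 (the index of the labeled line), cost = 0.5 and probability = 0.25.  An
// event with no labeled line comes back with probability = -1, which callers such as
// update_statistics treat as an unlabeled event.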
130 
131 void cb_adf::learn_IPS(multi_learner& base, multi_ex& examples)
132 {
133  gen_cs_example_ips(examples, _cs_labels, _clip_p);
134  call_cs_ldf<true>(base, examples, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
135 }
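
// Illustrative sketch, not part of the original source: the per-action transform that
// gen_cs_example_ips applies when building _cs_labels.  The chosen action receives the
// inverse-propensity-score estimate cost / p (with p optionally clipped from below by
// clip_p) and every other action receives 0.  ips_cost_estimate is a hypothetical helper
// name used only in this sketch.
static float ips_cost_estimate(
    uint32_t action, uint32_t chosen_action, float observed_cost, float logged_prob, float clip_p)
{
  if (action != chosen_action)
    return 0.f;
  const float p = std::max(logged_prob, clip_p);  // guard against tiny logged probabilities
  return observed_cost / p;
}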
136 
137 void cb_adf::learn_SM(multi_learner& base, multi_ex& examples)
138 {
139  gen_cs_test_example(examples, _cs_labels); // create test labels.
140  call_cs_ldf<false>(base, examples, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
141 
142  // Can probably do this more efficiently than 6 loops over the examples...
143  //[1: initialize temporary storage;
144  // 2: find chosen action;
145  // 3: create cs_labels (gen_cs_example_sm);
146  // 4: get probability of chosen action;
147  // 5: backup example wts;
148  // 6: restore example wts]
149  _a_s.clear();
150  _prob_s.clear();
151  // TODO: Check that predicted scores are always stored with the first example
152  for (uint32_t i = 0; i < examples[0]->pred.a_s.size(); i++)
153  {
154  _a_s.push_back({examples[0]->pred.a_s[i].action, examples[0]->pred.a_s[i].score});
155  _prob_s.push_back({examples[0]->pred.a_s[i].action, 0.0});
156  }
157 
158  float sign_offset = 1.0; // To account for negative rewards/costs
159  uint32_t chosen_action = 0;
160  float example_weight = 1.0;
161 
162  for (uint32_t i = 0; i < examples.size(); i++)
163  {
164  CB::label ld = examples[i]->l.cb;
165  if (ld.costs.size() == 1 && ld.costs[0].cost != FLT_MAX)
166  {
167  chosen_action = i;
168  example_weight = ld.costs[0].cost / safe_probability(ld.costs[0].probability);
169 
170  // Importance weights of examples cannot be negative.
171  // So we use a trick: set |w| as weight, and use sign(w) as an offset in the regression target.
172  if (ld.costs[0].cost < 0.0)
173  {
174  sign_offset = -1.0;
175  example_weight = -example_weight;
176  }
177  break;
178  }
179  }
180 
181  gen_cs_example_sm(examples, chosen_action, sign_offset, _a_s, _cs_labels);
182 
183  // Lambda is -1 in the call to generate_softmax because in vw, lower score is better; for softmax higher score is
184  // better.
185  generate_softmax(-1.0, begin_scores(_a_s), end_scores(_a_s), begin_scores(_prob_s), end_scores(_prob_s));
186 
187  // TODO: Check Marco's example that causes VW to report prob > 1.
188 
189  for (auto const& action_score : _prob_s) // Scale example_wt by prob of chosen action
190  {
191  if (action_score.action == chosen_action)
192  {
193  example_weight *= action_score.score;
194  break;
195  }
196  }
197 
198  _backup_weights.clear();
199  _backup_nf.clear();
200  for (auto const& action_score : _prob_s)
201  {
202  uint32_t current_action = action_score.action;
203  _backup_weights.push_back(examples[current_action]->weight);
204  _backup_nf.push_back((uint32_t)examples[current_action]->num_features);
205 
206  if (current_action == chosen_action)
207  examples[current_action]->weight = example_weight * (1.0f - action_score.score);
208  else
209  examples[current_action]->weight = example_weight * action_score.score;
210 
211  if (examples[current_action]->weight <= 1e-15)
212  examples[current_action]->weight = 0;
213  }
214 
215  // Do actual training
216  call_cs_ldf<true>(base, examples, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
217 
218  // Restore example weights and numFeatures
219  for (size_t i = 0; i < _prob_s.size(); i++)
220  {
221  uint32_t current_action = _prob_s[i].action;
222  examples[current_action]->weight = _backup_weights[i];
223  examples[current_action]->num_features = _backup_nf[i];
224  }
225 }
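
// Illustrative sketch, not part of the original source: the sign/weight trick used by
// learn_SM above.  Regression example weights must be non-negative, so the magnitude
// |cost / p| becomes the example weight while sign(cost) is moved into the regression
// target through sign_offset; learn_SM then scales the weight by the softmax probability
// of the chosen action.  sm_importance_weight is a hypothetical helper name used only in
// this sketch.
static float sm_importance_weight(float cost, float logged_prob, float& sign_offset_out)
{
  float w = cost / safe_probability(logged_prob);  // raw importance weight, may be negative
  sign_offset_out = 1.f;
  if (cost < 0.f)
  {
    sign_offset_out = -1.f;  // fold the sign into the regression target
    w = -w;                  // keep the example weight non-negative
  }
  return w;
}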
226 
227 void cb_adf::learn_DR(multi_learner& base, multi_ex& examples)
228 {
229  gen_cs_example_dr<true>(_gen_cs, examples, _cs_labels, _clip_p);
230  call_cs_ldf<true>(base, examples, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
231 }
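
// Illustrative sketch, not part of the original source: the doubly-robust estimate that
// gen_cs_example_dr combines from two sources.  Every action starts from the regressor's
// predicted cost (the direct-method part); the action whose cost was actually observed
// additionally gets an importance-weighted correction (the IPS part), which keeps the
// estimate unbiased even when the regressor is wrong.  dr_cost_estimate is a hypothetical
// helper name used only in this sketch.
static float dr_cost_estimate(uint32_t action, uint32_t chosen_action, float predicted_cost,
    float observed_cost, float logged_prob, float clip_p)
{
  float estimate = predicted_cost;  // direct-method baseline for every action
  if (action == chosen_action)
    estimate += (observed_cost - predicted_cost) / std::max(logged_prob, clip_p);  // IPS correction
  return estimate;
}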
232 
233 void cb_adf::learn_DM(multi_learner& base, multi_ex& examples)
234 {
235  gen_cs_example_dm(examples, _cs_labels);
236  call_cs_ldf<true>(base, examples, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
237 }
238 
239 template <bool predict>
240 void cb_adf::learn_MTR(multi_learner& base, multi_ex& examples)
241 {
242  // uint32_t action = 0;
243  if (predict) // first get the prediction to return
244  {
245  gen_cs_example_ips(examples, _cs_labels);
246  call_cs_ldf<false>(base, examples, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
247  std::swap(examples[0]->pred.a_s, _a_s);
248  }
249  // second train on _one_ action (which requires up to 3 examples).
250  // We must go through the cost sensitive classifier layer to get
251  // proper feature handling.
252  gen_cs_example_mtr(_gen_cs, examples, _cs_labels);
253  uint32_t nf = (uint32_t)examples[_gen_cs.mtr_example]->num_features;
254  float old_weight = examples[_gen_cs.mtr_example]->weight;
255  const float clipped_p = std::max(examples[_gen_cs.mtr_example]->l.cb.costs[0].probability, _clip_p);
256  examples[_gen_cs.mtr_example]->weight *= 1.f / clipped_p * ((float)_gen_cs.event_sum / (float)_gen_cs.action_sum);
257 
258  std::swap(_gen_cs.mtr_ec_seq[0]->pred.a_s, _a_s_mtr_cs);
259  // TODO!!! cb_labels are not getting properly restored (empty costs are dropped)
260  GEN_CS::call_cs_ldf<true>(base, _gen_cs.mtr_ec_seq, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
261  examples[_gen_cs.mtr_example]->num_features = nf;
262  examples[_gen_cs.mtr_example]->weight = old_weight;
263  std::swap(_gen_cs.mtr_ec_seq[0]->pred.a_s, _a_s_mtr_cs);
264  std::swap(examples[0]->pred.a_s, _a_s);
265 }
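
// Illustrative sketch, not part of the original source: the reweighting learn_MTR applies
// to the single example it trains on.  The 1/p factor is the usual importance weight for
// the logged action, and event_sum / action_sum compensates for the fact that only one of
// the presented actions produces an update per event.  mtr_example_weight is a
// hypothetical helper name used only in this sketch.
static float mtr_example_weight(
    float base_weight, float logged_prob, float clip_p, uint64_t event_sum, uint64_t action_sum)
{
  const float clipped_p = std::max(logged_prob, clip_p);
  return base_weight * (1.f / clipped_p) * ((float)event_sum / (float)action_sum);
}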
266 
267 // Validates a multiline example collection as a valid sequence for action dependent features format.
268 example* test_adf_sequence(multi_ex& ec_seq)
269 {
270  if (ec_seq.empty())
271  THROW("cb_adf: At least one action must be provided for an example to be valid.");
272 
273  uint32_t count = 0;
274  example* ret = nullptr;
275  for (auto* ec : ec_seq)
276  {
277  // Check if there is more than one cost for this example.
278  if (ec->l.cb.costs.size() > 1)
279  THROW("cb_adf: badly formatted example, only one cost can be known.");
280 
281  // Check whether the cost was initialized to a value.
282  if (ec->l.cb.costs.size() == 1 && ec->l.cb.costs[0].cost != FLT_MAX)
283  {
284  ret = ec;
285  count += 1;
286  if (count > 1)
287  THROW("cb_adf: badly formatted example, only one line can have a cost");
288  }
289  }
290 
291  return ret;
292 }
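
// Illustrative example, not part of the original source: a multiline cb_adf event in VW
// text format that passes test_adf_sequence.  Exactly one action line carries a label,
// written as action:cost:probability (the action id is ignored in ADF; the line's
// position identifies the action), and an optional "shared" line holds features common
// to all actions.
//
//   shared | user_age:25 time_of_day=evening
//   0:1.0:0.5 | article=sports
//   | article=politics
//   | article=music
//
// Here the sports article was shown, incurred cost 1.0, and was chosen by the logging
// policy with probability 0.5.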
293 
294 template <bool is_learn>
295 void cb_adf::do_actual_learning(multi_learner& base, multi_ex& ec_seq)
296 {
297  _offset = ec_seq[0]->ft_offset;
298  _gen_cs.known_cost = get_observed_cost(ec_seq); // need to set for test case
299  if (is_learn && test_adf_sequence(ec_seq) != nullptr)
300  {
301  /* v_array<float> temp_scores;
302  temp_scores = v_init<float>();
303  do_actual_learning<false>(data,base);
304  for (size_t i = 0; i < data.ec_seq[0]->pred.a_s.size(); i++)
305  temp_scores.push_back(data.ec_seq[0]->pred.a_s[i].score);*/
306  switch (_gen_cs.cb_type)
307  {
308  case CB_TYPE_IPS:
309  learn_IPS(base, ec_seq);
310  break;
311  case CB_TYPE_DR:
312  learn_DR(base, ec_seq);
313  break;
314  case CB_TYPE_DM:
315  learn_DM(base, ec_seq);
316  break;
317  case CB_TYPE_MTR:
318  if (_no_predict)
319  learn_MTR<false>(base, ec_seq);
320  else
321  learn_MTR<true>(base, ec_seq);
322  break;
323  case CB_TYPE_SM:
324  learn_SM(base, ec_seq);
325  break;
326  default:
327  THROW("Unknown cb_type specified for contextual bandit learning: " << _gen_cs.cb_type);
328  }
329 
330  /* for (size_t i = 0; i < temp_scores.size(); i++)
331  if (temp_scores[i] != data.ec_seq[0]->pred.a_s[i].score)
332  std::cout << "problem! " << temp_scores[i] << " != " << data.ec_seq[0]->pred.a_s[i].score << " for " <<
333  data.ec_seq[0]->pred.a_s[i].action << std::endl; temp_scores.delete_v();*/
334  }
335  else
336  {
337  gen_cs_test_example(ec_seq, _cs_labels); // create test labels.
338  call_cs_ldf<false>(base, ec_seq, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
339  }
340 }
341 
342 void global_print_newline(const v_array<int>& final_prediction_sink)
343 {
344  char temp[1];
345  temp[0] = '\n';
346  for (auto f : final_prediction_sink)
347  {
348  ssize_t t;
349  t = io_buf::write_file_or_socket(f, temp, 1);
350  if (t != 1)
351  std::cerr << "write error: " << strerror(errno) << std::endl;
352  }
353 }
354 
355 // how to
356 
357 bool cb_adf::update_statistics(example& ec, multi_ex* ec_seq)
358 {
359  size_t num_features = 0;
360 
361  uint32_t action = ec.pred.a_s[0].action;
362  for (const auto& example : *ec_seq) num_features += example->num_features;
363 
364  float loss = 0.;
365 
366  bool labeled_example = true;
367  if (_gen_cs.known_cost.probability > 0)
368  loss = get_cost_estimate(&(_gen_cs.known_cost), _gen_cs.pred_scores, action);
369  else
370  labeled_example = false;
371 
372  bool holdout_example = labeled_example;
373  for (auto const& i : *ec_seq) holdout_example &= i->test_only;
374 
375  _sd->update(holdout_example, labeled_example, loss, ec.weight, num_features);
376  return labeled_example;
377 }
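
// Illustrative sketch, not part of the original source: the quantity update_statistics
// reports to shared_data.  For a labeled event the loss is a cost estimate for the action
// the learned policy actually predicted (ec.pred.a_s[0]), so the progressive loss VW
// prints is an off-policy estimate of that policy's cost.  reported_loss is a
// hypothetical helper name that uses a plain IPS estimate in place of get_cost_estimate.
static float reported_loss(uint32_t predicted_action, uint32_t logged_action, float logged_cost, float logged_prob)
{
  if (logged_prob <= 0.f)
    return 0.f;  // unlabeled event: nothing to estimate
  return predicted_action == logged_action ? logged_cost / logged_prob : 0.f;
}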
378 
379 void output_example(vw& all, cb_adf& c, example& ec, multi_ex* ec_seq)
380 {
381  if (example_is_newline_not_header(ec, all))
382  return;
383 
384  bool labeled_example = c.update_statistics(ec, ec_seq);
385 
386  uint32_t action = ec.pred.a_s[0].action;
387  for (int sink : all.final_prediction_sink) all.print(sink, (float)action, 0, ec.tag);
388 
389  if (all.raw_prediction > 0)
390  {
391  std::string outputString;
392  std::stringstream outputStringStream(outputString);
393  v_array<CB::cb_class> costs = ec.l.cb.costs;
394 
395  for (size_t i = 0; i < costs.size(); i++)
396  {
397  if (i > 0)
398  outputStringStream << ' ';
399  outputStringStream << costs[i].action << ':' << costs[i].partial_prediction;
400  }
401  all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag);
402  }
403 
404  CB::print_update(all, !labeled_example, ec, ec_seq, true);
405 }
406 
407 void output_rank_example(vw& all, cb_adf& c, example& ec, multi_ex* ec_seq)
408 {
409  label& ld = ec.l.cb;
410  v_array<CB::cb_class> costs = ld.costs;
411 
412  if (example_is_newline_not_header(ec, all))
413  return;
414 
415  bool labeled_example = c.update_statistics(ec, ec_seq);
416 
417  for (int sink : all.final_prediction_sink) print_action_score(sink, ec.pred.a_s, ec.tag);
418 
419  if (all.raw_prediction > 0)
420  {
421  std::string outputString;
422  std::stringstream outputStringStream(outputString);
423  for (size_t i = 0; i < costs.size(); i++)
424  {
425  if (i > 0)
426  outputStringStream << ' ';
427  outputStringStream << costs[i].action << ':' << costs[i].partial_prediction;
428  }
429  all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag);
430  }
431 
432  CB::print_update(all, !labeled_example, ec, ec_seq, true);
433 }
434 
435 void output_example_seq(vw& all, cb_adf& data, multi_ex& ec_seq)
436 {
437  if (!ec_seq.empty())
438  {
439  if (data.get_rank_all())
440  output_rank_example(all, data, **(ec_seq.begin()), &(ec_seq));
441  else
442  {
443  output_example(all, data, **(ec_seq.begin()), &(ec_seq));
444 
445  if (all.raw_prediction > 0)
446  all.print_text(all.raw_prediction, "", ec_seq[0]->tag);
447  }
448  }
449 }
450 
451 void finish_multiline_example(vw& all, cb_adf& data, multi_ex& ec_seq)
452 {
453  if (!ec_seq.empty())
454  {
455  output_example_seq(all, data, ec_seq);
456  global_print_newline(all.final_prediction_sink);
457  }
458  VW::finish_example(all, ec_seq);
459 }
460 
461 void save_load(cb_adf& c, io_buf& model_file, bool read, bool text)
462 {
463  if (c.get_model_file_ver() != nullptr && *c.get_model_file_ver() < VERSION_FILE_WITH_CB_ADF_SAVE)
464  return;
465  std::stringstream msg;
466  msg << "event_sum " << c.get_gen_cs().event_sum << "\n";
467  bin_text_read_write_fixed(
468  model_file, (char*)&c.get_gen_cs().event_sum, sizeof(c.get_gen_cs().event_sum), "", read, msg, text);
469 
470  msg << "action_sum " << c.get_gen_cs().action_sum << "\n";
471  bin_text_read_write_fixed(
472  model_file, (char*)&c.get_gen_cs().action_sum, sizeof(c.get_gen_cs().action_sum), "", read, msg, text);
473 }
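
// Note, not part of the original source: event_sum and action_sum are the running
// counters consumed by the MTR weight in learn_MTR, so they are serialized with the model
// (guarded by VERSION_FILE_WITH_CB_ADF_SAVE) and restored on load, keeping the weighting
// consistent when training resumes from a saved model.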
474 
475 void learn(cb_adf& c, multi_learner& base, multi_ex& ec_seq) { c.do_actual_learning<true>(base, ec_seq); }
476 
477 void predict(cb_adf& c, multi_learner& base, multi_ex& ec_seq) { c.do_actual_learning<false>(base, ec_seq); }
478 
479 } // namespace CB_ADF
480 using namespace CB_ADF;
481 base_learner* cb_adf_setup(options_i& options, vw& all)
482 {
483  bool cb_adf_option = false;
484  std::string type_string = "mtr";
485 
486  size_t cb_type;
487  bool rank_all;
488  float clip_p;
489  bool no_predict;
490 
491  option_group_definition new_options("Contextual Bandit with Action Dependent Features");
492  new_options
493  .add(make_option("cb_adf", cb_adf_option)
494  .keep()
495  .help("Do Contextual Bandit learning with multiline action dependent features."))
496  .add(make_option("rank_all", rank_all).keep().help("Return actions sorted by score order"))
497  .add(make_option("no_predict", no_predict).help("Do not do a prediction when training"))
498  .add(make_option("clip_p", clip_p)
499  .keep()
500  .default_value(0.f)
501  .help("Clipping probability in importance weight. Default: 0.f (no clipping)."))
502  .add(make_option("cb_type", type_string)
503  .keep()
504  .help("contextual bandit method to use in {ips, dm, dr, mtr, sm}. Default: mtr"));
505  options.add_and_parse(new_options);
506 
507  if (!cb_adf_option)
508  return nullptr;
509 
510  // Ensure serialization of this option in all cases.
511  if (!options.was_supplied("cb_type"))
512  {
513  options.insert("cb_type", type_string);
514  options.add_and_parse(new_options);
515  }
516 
517  // number of weight vectors needed
518  size_t problem_multiplier = 1; // default for IPS
519  bool check_baseline_enabled = false;
520 
521  if (type_string == "dr")
522  {
523  cb_type = CB_TYPE_DR;
524  problem_multiplier = 2;
525  // only use baseline when manually enabled for loss estimation
526  check_baseline_enabled = true;
527  }
528  else if (type_string == "ips")
529  cb_type = CB_TYPE_IPS;
530  else if (type_string == "mtr")
531  cb_type = CB_TYPE_MTR;
532  else if (type_string == "dm")
533  cb_type = CB_TYPE_DM;
534  else if (type_string == "sm")
535  cb_type = CB_TYPE_SM;
536  else
537  {
538  all.trace_message << "warning: cb_type must be in {'ips','dr','mtr','dm','sm'}; resetting to mtr." << std::endl;
539  cb_type = CB_TYPE_MTR;
540  }
541 
542  if (clip_p > 0.f && cb_type == CB_TYPE_SM)
543  all.trace_message << "warning: clipping probability not yet implemented for cb_type sm; p will not be clipped."
544  << std::endl;
545 
546  all.delete_prediction = ACTION_SCORE::delete_action_scores;
547 
548  // Push necessary flags.
549  if ((!options.was_supplied("csoaa_ldf") && !options.was_supplied("wap_ldf")) || rank_all ||
550  !options.was_supplied("csoaa_rank"))
551  {
552  if (!options.was_supplied("csoaa_ldf"))
553  {
554  options.insert("csoaa_ldf", "multiline");
555  }
556 
557  if (!options.was_supplied("csoaa_rank"))
558  {
559  options.insert("csoaa_rank", "");
560  }
561  }
562 
563  if (options.was_supplied("baseline") && check_baseline_enabled)
564  {
565  options.insert("check_enabled", "");
566  }
567 
568  auto ld = scoped_calloc_or_throw<cb_adf>(all.sd, cb_type, &all.model_file_ver, rank_all, clip_p, no_predict);
569 
570  auto base = as_multiline(setup_base(options, all));
571  all.p->lp = CB::cb_label;
572  all.label_type = label_type::cb;
573 
574  cb_adf* bare = ld.get();
575  learner<cb_adf, multi_ex>& l =
576  init_learner(ld, base, learn, predict, problem_multiplier, prediction_type::action_scores);
577  l.set_finish_example(CB_ADF::finish_multiline_example);
578 
579  bare->set_scorer(all.scorer);
580 
581  l.set_save_load(CB_ADF::save_load);
582  return make_base(l);
583 }
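
// Usage sketch, not part of the original source: the reduction registered above is
// enabled from the command line, for example
//
//   vw --cb_adf -d cb_adf_data.txt --cb_type mtr
//   vw --cb_adf -d cb_adf_data.txt --cb_type dr --rank_all --clip_p 0.1
//
// --cb_type selects which learn_* routine do_actual_learning dispatches to, --rank_all
// switches the printed prediction to the full scored action list (output_rank_example),
// --no_predict skips the extra prediction pass during MTR training, and --clip_p
// lower-bounds the logged probability used in the importance weights.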