Vowpal Wabbit
Classes | Functions
cbify.cc File Reference
#include "reductions.h"
#include "cb_algs.h"
#include "vw.h"
#include "hash.h"
#include "explore.h"
#include <vector>

Go to the source code of this file.

Classes

struct  cbify_adf_data
 
struct  cbify
 

Functions

float loss (cbify &data, uint32_t label, uint32_t final_prediction)
 
float loss_cs (cbify &data, v_array< COST_SENSITIVE::wclass > &costs, uint32_t final_prediction)
 
float loss_csldf (cbify &data, std::vector< v_array< COST_SENSITIVE::wclass >> &cs_costs, uint32_t final_prediction)
 
void copy_example_to_adf (cbify &data, example &ec)
 
template<bool is_learn, bool use_cs>
void predict_or_learn (cbify &data, single_learner &base, example &ec)
 
template<bool is_learn, bool use_cs>
void predict_or_learn_adf (cbify &data, multi_learner &base, example &ec)
 
void init_adf_data (cbify &data, const size_t num_actions)
 
template<bool is_learn>
void do_actual_learning_ldf (cbify &data, multi_learner &base, multi_ex &ec_seq)
 
void output_example (vw &all, example &ec, bool &hit_loss, multi_ex *ec_seq)
 
void output_example_seq (vw &all, multi_ex &ec_seq)
 
void finish_multiline_example (vw &all, cbify &, multi_ex &ec_seq)
 
base_learner * cbify_setup (options_i &options, vw &all)
 
base_learner * cbifyldf_setup (options_i &options, vw &all)
 

Function Documentation

◆ cbify_setup()

base_learner* cbify_setup ( options_i & options,
vw & all 
)

Definition at line 383 of file cbify.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), LEARNER::as_multiline(), LEARNER::as_singleline(), vw::delete_prediction, f, init_adf_data(), LEARNER::init_cost_sensitive_learner(), LEARNER::init_multiclass_learner(), VW::config::options_i::insert(), LEARNER::make_base(), VW::config::make_option(), vw::p, setup_base(), prediction_type::to_string(), uniform_hash(), and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

384 {
385  uint32_t num_actions = 0;
386  auto data = scoped_calloc_or_throw<cbify>();
387  bool use_cs;
388 
389  option_group_definition new_options("Make Multiclass into Contextual Bandit");
390  new_options
391  .add(make_option("cbify", num_actions)
392  .keep()
393  .help("Convert multiclass on <k> classes into a contextual bandit problem"))
394  .add(make_option("cbify_cs", use_cs).help("consume cost-sensitive classification examples instead of multiclass"))
395  .add(make_option("loss0", data->loss0).default_value(0.f).help("loss for correct label"))
396  .add(make_option("loss1", data->loss1).default_value(1.f).help("loss for incorrect label"));
397  options.add_and_parse(new_options);
398 
399  if (!options.was_supplied("cbify"))
400  return nullptr;
401 
402  data->use_adf = options.was_supplied("cb_explore_adf");
403  data->app_seed = uniform_hash("vw", 2, 0);
404  data->a_s = v_init<action_score>();
405  data->all = &all;
406 
407  if (data->use_adf)
408  init_adf_data(*data, num_actions);
409 
410  if (!options.was_supplied("cb_explore") && !data->use_adf)
411  {
412  std::stringstream ss;
413  ss << num_actions;
414  options.insert("cb_explore", ss.str());
415  }
416 
417  if (data->use_adf)
418  {
419  options.insert("cb_min_cost", std::to_string(data->loss0));
420  options.insert("cb_max_cost", std::to_string(data->loss1));
421  }
422 
423  if (options.was_supplied("baseline"))
424  {
425  std::stringstream ss;
426  ss << std::max(std::abs(data->loss0), std::abs(data->loss1)) / (data->loss1 - data->loss0);
427  options.insert("lr_multiplier", ss.str());
428  }
429 
430  learner<cbify, example>* l;
431 
432  if (data->use_adf)
433  {
434  multi_learner* base = as_multiline(setup_base(options, all));
435  if (use_cs)
436  l = &init_cost_sensitive_learner(
437  data, base, predict_or_learn_adf<true, true>, predict_or_learn_adf<false, true>, all.p, 1);
438  else
439  l = &init_multiclass_learner(
440  data, base, predict_or_learn_adf<true, false>, predict_or_learn_adf<false, false>, all.p, 1);
441  }
442  else
443  {
444  single_learner* base = as_singleline(setup_base(options, all));
445  if (use_cs)
446  l = &init_cost_sensitive_learner(
447  data, base, predict_or_learn<true, true>, predict_or_learn<false, true>, all.p, 1);
448  else
449  l = &init_multiclass_learner(data, base, predict_or_learn<true, false>, predict_or_learn<false, false>, all.p, 1);
450  }
451  all.delete_prediction = nullptr;
452 
453  return make_base(*l);
454 }
void(* delete_prediction)(void *)
Definition: global_data.h:485
void init_adf_data(cbify &data, const size_t num_actions)
Definition: cbify.cc:226
learner< T, E > & init_cost_sensitive_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:450
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
Definition: hash.h:67
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
parser * p
Definition: global_data.h:377
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
virtual bool was_supplied(const std::string &key)=0
virtual void insert(const std::string &key, const std::string &value)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
float f
Definition: cache.cc:40
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
const char * to_string(prediction_type_t prediction_type)
Definition: learner.cc:12

◆ cbifyldf_setup()

base_learner* cbifyldf_setup ( options_i & options,
vw & all 
)

Definition at line 456 of file cbify.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), LEARNER::as_multiline(), COST_SENSITIVE::cs_label, vw::delete_prediction, f, finish_multiline_example(), LEARNER::init_learner(), VW::config::options_i::insert(), parser::lp, LEARNER::make_base(), VW::config::make_option(), prediction_type::multiclass, vw::p, LEARNER::learner< T, E >::set_finish_example(), setup_base(), prediction_type::to_string(), uniform_hash(), and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

457 {
458  auto data = scoped_calloc_or_throw<cbify>();
459  bool cbify_ldf_option = false;
460 
461  option_group_definition new_options("Make csoaa_ldf into Contextual Bandit");
462  new_options
463  .add(make_option("cbify_ldf", cbify_ldf_option).keep().help("Convert csoaa_ldf into a contextual bandit problem"))
464  .add(make_option("loss0", data->loss0).default_value(0.f).help("loss for correct label"))
465  .add(make_option("loss1", data->loss1).default_value(1.f).help("loss for incorrect label"));
466  options.add_and_parse(new_options);
467 
468  if (!options.was_supplied("cbify_ldf"))
469  return nullptr;
470 
471  data->app_seed = uniform_hash("vw", 2, 0);
472  data->all = &all;
473  data->use_adf = true;
474 
475  if (!options.was_supplied("cb_explore_adf"))
476  {
477  options.insert("cb_explore_adf", "");
478  }
479  options.insert("cb_min_cost", std::to_string(data->loss0));
480  options.insert("cb_max_cost", std::to_string(data->loss1));
481 
482  if (options.was_supplied("baseline"))
483  {
484  std::stringstream ss;
485  ss << std::max(std::abs(data->loss0), std::abs(data->loss1)) / (data->loss1 - data->loss0);
486  options.insert("lr_multiplier", ss.str());
487  }
488 
489  multi_learner* base = as_multiline(setup_base(options, all));
490  learner<cbify, multi_ex>& l = init_learner(
491  data, base, do_actual_learning_ldf<true>, do_actual_learning_ldf<false>, 1, prediction_type::multiclass);
492 
493  l.set_finish_example(finish_multiline_example);
494  all.p->lp = COST_SENSITIVE::cs_label;
495  all.delete_prediction = nullptr;
496 
497  return make_base(l);
498 }
void(* delete_prediction)(void *)
Definition: global_data.h:485
label_parser cs_label
void finish_multiline_example(vw &all, cbify &, multi_ex &ec_seq)
Definition: cbify.cc:373
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
Definition: hash.h:67
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
parser * p
Definition: global_data.h:377
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
virtual bool was_supplied(const std::string &key)=0
virtual void insert(const std::string &key, const std::string &value)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
float f
Definition: cache.cc:40
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
const char * to_string(prediction_type_t prediction_type)
Definition: learner.cc:12
label_parser lp
Definition: parser.h:102

◆ copy_example_to_adf()

void copy_example_to_adf ( cbify & data,
example & ec 
)

Definition at line 96 of file cbify.cc.

References a, cbify::adf_data, cbify::all, CB::cb_label, VW::copy_example_data(), label_parser::default_label, CB_ALGS::example_is_newline_not_header(), parameters::mask(), parameters::stride_shift(), label_parser::test_label, and vw::weights.

Referenced by predict_or_learn_adf().

97 {
98  auto& adf_data = data.adf_data;
99  const uint64_t ss = data.all->weights.stride_shift();
100  const uint64_t mask = data.all->weights.mask();
101 
102  for (size_t a = 0; a < adf_data.num_actions; ++a)
103  {
104  auto& eca = *adf_data.ecs[a];
105  // clear label
106  auto& lab = eca.l.cb;
107  CB::cb_label.default_label(&lab);
108 
109  // copy data
110  VW::copy_example_data(false, &eca, &ec);
111 
112  // offset indices for given action
113  for (features& fs : eca)
114  {
115  for (feature_index& idx : fs.indicies)
116  {
117  idx = ((((idx >> ss) * 28904713) + 4832917 * (uint64_t)a) << ss) & mask;
118  }
119  }
120 
121  // avoid empty example by adding a tag (hacky)
122  if (CB_ALGS::example_is_newline_not_header(eca) && CB::cb_label.test_label(&eca.l))
123  {
124  eca.tag.push_back('n');
125  }
126  }
127 }
bool example_is_newline_not_header(example const &ec)
Definition: cb_algs.h:80
parameters weights
Definition: global_data.h:537
void copy_example_data(bool audit, example *dst, example *src)
Definition: example.cc:72
void(* default_label)(void *)
Definition: label_parser.h:12
bool(* test_label)(void *)
Definition: label_parser.h:22
the core definition of a set of features.
cbify_adf_data adf_data
Definition: cbify.cc:33
uint64_t feature_index
Definition: feature_group.h:21
label_parser cb_label
Definition: cb.cc:167
constexpr uint64_t a
Definition: rand48.cc:11
uint32_t stride_shift()
uint64_t mask()
vw * all
Definition: cbify.cc:31

◆ do_actual_learning_ldf()

template<bool is_learn>
void do_actual_learning_ldf ( cbify & data,
multi_learner & base,
multi_ex & ec_seq 
)

Definition at line 242 of file cbify.cc.

References CB::cb_class::action, cbify::app_seed, ACTION_SCORE::begin_scores(), cbify::cb_as, cbify::cb_costs, cbify::cb_label, CB::cb_class::cost, CB::label::costs, cbify::cs_costs, ACTION_SCORE::end_scores(), cbify::example_counter, LEARNER::learner< T, E >::learn(), loss_csldf(), LEARNER::learner< T, E >::predict(), CB::cb_class::probability, exploration::sample_after_normalizing(), and THROW.

243 {
244  // change label and pred data for cb
245  if (data.cs_costs.size() < ec_seq.size())
246  data.cs_costs.resize(ec_seq.size());
247  if (data.cb_costs.size() < ec_seq.size())
248  data.cb_costs.resize(ec_seq.size());
249  if (data.cb_as.size() < ec_seq.size())
250  data.cb_as.resize(ec_seq.size());
251  for (size_t i = 0; i < ec_seq.size(); ++i)
252  {
253  auto& ec = *ec_seq[i];
254  data.cs_costs[i] = ec.l.cs.costs;
255  data.cb_costs[i].clear();
256  data.cb_as[i].clear();
257  ec.l.cb.costs = data.cb_costs[i];
258  ec.pred.a_s = data.cb_as[i];
259  }
260 
261  base.predict(ec_seq);
262 
263  auto& out_ec = *ec_seq[0];
264 
265  uint32_t chosen_action;
266  if (sample_after_normalizing(data.app_seed + data.example_counter++, begin_scores(out_ec.pred.a_s),
267  end_scores(out_ec.pred.a_s), chosen_action))
268  THROW("Failed to sample from pdf");
269 
270  CB::cb_class cl;
271  cl.action = out_ec.pred.a_s[chosen_action].action + 1;
272  cl.probability = out_ec.pred.a_s[chosen_action].score;
273 
274  if (!cl.action)
275  THROW("No action with non-zero probability found!");
276 
277  cl.cost = loss_csldf(data, data.cs_costs, cl.action);
278 
279  // add cb label to chosen action
280  data.cb_label.costs.clear();
281  data.cb_label.costs.push_back(cl);
282  data.cb_costs[cl.action - 1] = ec_seq[cl.action - 1]->l.cb.costs;
283  ec_seq[cl.action - 1]->l.cb = data.cb_label;
284 
285  base.learn(ec_seq);
286 
287  // set cs prediction and reset cs costs
288  for (size_t i = 0; i < ec_seq.size(); ++i)
289  {
290  auto& ec = *ec_seq[i];
291  data.cb_as[i] = ec.pred.a_s; // store action_score vector for later reuse.
292  if (i == cl.action - 1)
293  data.cb_label = ec.l.cb;
294  else
295  data.cb_costs[i] = ec.l.cb.costs;
296  ec.l.cs.costs = data.cs_costs[i];
297  if (i == cl.action - 1)
298  ec.pred.multiclass = cl.action;
299  else
300  ec.pred.multiclass = 0;
301  }
302 }
void predict(E &ec, size_t i=0)
Definition: learner.h:169
uint64_t app_seed
Definition: cbify.cc:27
float loss_csldf(cbify &data, std::vector< v_array< COST_SENSITIVE::wclass >> &cs_costs, uint32_t final_prediction)
Definition: cbify.cc:82
std::vector< ACTION_SCORE::action_scores > cb_as
Definition: cbify.cc:40
int sample_after_normalizing(uint64_t seed, It pdf_first, It pdf_last, uint32_t &chosen_index)
Sample an index from the provided pdf. If the pdf is not normalized it will be updated in-place...
v_array< cb_class > costs
Definition: cb.h:27
score_iterator begin_scores(action_scores &a_s)
Definition: action_score.h:43
score_iterator end_scores(action_scores &a_s)
Definition: action_score.h:45
std::vector< v_array< COST_SENSITIVE::wclass > > cs_costs
Definition: cbify.cc:38
uint32_t action
Definition: cb.h:18
float probability
Definition: cb.h:19
CB::label cb_label
Definition: cbify.cc:26
size_t example_counter
Definition: cbify.cc:30
float cost
Definition: cb.h:17
std::vector< v_array< CB::cb_class > > cb_costs
Definition: cbify.cc:39
void learn(E &ec, size_t i=0)
Definition: learner.h:160
#define THROW(args)
Definition: vw_exception.h:181

◆ finish_multiline_example()

void finish_multiline_example ( vw & all,
cbify & ,
multi_ex & ec_seq 
)

Definition at line 373 of file cbify.cc.

References VW::finish_example(), and output_example_seq().

Referenced by cbifyldf_setup(), CCB::ccb_explore_adf_setup(), CSOAA::csldf_setup(), explore_eval_setup(), VW::cb_explore_adf::softmax::setup(), VW::cb_explore_adf::greedy::setup(), VW::cb_explore_adf::first::setup(), VW::cb_explore_adf::bag::setup(), VW::cb_explore_adf::cover::setup(), VW::cb_explore_adf::regcb::setup(), and Search::setup().

374 {
375  if (!ec_seq.empty())
376  {
377  output_example_seq(all, ec_seq);
378  // global_print_newline(all);
379  }
380  VW::finish_example(all, ec_seq);
381 }
void output_example_seq(vw &all, multi_ex &ec_seq)
Definition: cbify.cc:356
void finish_example(vw &, example &)
Definition: parser.cc:881

◆ init_adf_data()

void init_adf_data ( cbify & data,
const size_t  num_actions 
)

Definition at line 226 of file cbify.cc.

References a, cbify::adf_data, cbify::all, VW::alloc_examples(), CB::cb_label, label_parser::default_label, vw::interactions, and cbify_adf_data::num_actions.

Referenced by cbify_setup().

227 {
228  auto& adf_data = data.adf_data;
229  adf_data.num_actions = num_actions;
230 
231  adf_data.ecs.resize(num_actions);
232  for (size_t a = 0; a < num_actions; ++a)
233  {
234  adf_data.ecs[a] = VW::alloc_examples(CB::cb_label.label_size, 1);
235  auto& lab = adf_data.ecs[a]->l.cb;
236  CB::cb_label.default_label(&lab);
237  adf_data.ecs[a]->interactions = &data.all->interactions;
238  }
239 }
void(* default_label)(void *)
Definition: label_parser.h:12
cbify_adf_data adf_data
Definition: cbify.cc:33
example * alloc_examples(size_t, size_t count=1)
Definition: example.cc:204
label_parser cb_label
Definition: cb.cc:167
constexpr uint64_t a
Definition: rand48.cc:11
std::vector< std::string > interactions
Definition: global_data.h:457
size_t num_actions
Definition: cbify.cc:21
vw * all
Definition: cbify.cc:31

◆ loss()

float loss ( cbify & data,
uint32_t  label,
uint32_t  final_prediction 
)

Definition at line 60 of file cbify.cc.

◆ loss_cs()

float loss_cs ( cbify & data,
v_array< COST_SENSITIVE::wclass > &  costs,
uint32_t  final_prediction 
)

Definition at line 68 of file cbify.cc.

References cbify::loss0, and cbify::loss1.

Referenced by predict_or_learn(), and predict_or_learn_adf().

69 {
70  float cost = 0.;
71  for (auto wc : costs)
72  {
73  if (wc.class_index == final_prediction)
74  {
75  cost = wc.x;
76  break;
77  }
78  }
79  return data.loss0 + (data.loss1 - data.loss0) * cost;
80 }
float loss1
Definition: cbify.cc:35
float loss0
Definition: cbify.cc:34

◆ loss_csldf()

float loss_csldf ( cbify & data,
std::vector< v_array< COST_SENSITIVE::wclass >> &  cs_costs,
uint32_t  final_prediction 
)

Definition at line 82 of file cbify.cc.

References cbify::loss0, and cbify::loss1.

Referenced by do_actual_learning_ldf().

83 {
84  float cost = 0.;
85  for (auto costs : cs_costs)
86  {
87  if (costs[0].class_index == final_prediction)
88  {
89  cost = costs[0].x;
90  break;
91  }
92  }
93  return data.loss0 + (data.loss1 - data.loss0) * cost;
94 }
float loss1
Definition: cbify.cc:35
float loss0
Definition: cbify.cc:34

◆ output_example()

void output_example ( vw & all,
example & ec,
bool &  hit_loss,
multi_ex * ec_seq 
)

Definition at line 304 of file cbify.cc.

References COST_SENSITIVE::label::costs, polylabel::cs, COST_SENSITIVE::cs_label, COST_SENSITIVE::ec_is_example_header(), example_is_newline(), vw::final_prediction_sink, example::l, loss(), polyprediction::multiclass, example::num_features, example::pred, vw::print, vw::print_text, COST_SENSITIVE::print_update(), vw::raw_prediction, vw::sd, v_array< T >::size(), shared_data::sum_loss, shared_data::sum_loss_since_last_dump, example::tag, test_label(), and shared_data::total_features.

Referenced by output_example_seq().

305 {
306  COST_SENSITIVE::label& ld = ec.l.cs;
307  v_array<COST_SENSITIVE::wclass> costs = ld.costs;
308 
309  if (example_is_newline(ec))
310  return;
311  if (COST_SENSITIVE::ec_is_example_header(ec))
312  return;
313 
314  all.sd->total_features += ec.num_features;
315 
316  float loss = 0.;
317 
318  uint32_t predicted_class = ec.pred.multiclass;
319 
320  if (!COST_SENSITIVE::cs_label.test_label(&ec.l))
321  {
322  for (auto const& cost : costs)
323  {
324  if (hit_loss)
325  break;
326  if (predicted_class == cost.class_index)
327  {
328  loss = cost.x;
329  hit_loss = true;
330  }
331  }
332 
333  all.sd->sum_loss += loss;
334  all.sd->sum_loss_since_last_dump += loss;
335  }
336 
337  for (int sink : all.final_prediction_sink) all.print(sink, (float)ec.pred.multiclass, 0, ec.tag);
338 
339  if (all.raw_prediction > 0)
340  {
341  std::string outputString;
342  std::stringstream outputStringStream(outputString);
343  for (size_t i = 0; i < costs.size(); i++)
344  {
345  if (i > 0)
346  outputStringStream << ' ';
347  outputStringStream << costs[i].class_index << ':' << costs[i].partial_prediction;
348  }
349  // outputStringStream << std::endl;
350  all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag);
351  }
352 
353  COST_SENSITIVE::print_update(all, COST_SENSITIVE::cs_label.test_label(&ec.l), ec, ec_seq, false, predicted_class);
354 }
double sum_loss
Definition: global_data.h:145
v_array< char > tag
Definition: example.h:63
int raw_prediction
Definition: global_data.h:519
uint32_t multiclass
Definition: example.h:49
label_parser cs_label
v_array< int > final_prediction_sink
Definition: global_data.h:518
bool ec_is_example_header(example const &ec)
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
int example_is_newline(example const &ec)
Definition: example.h:104
double sum_loss_since_last_dump
Definition: global_data.h:146
COST_SENSITIVE::label cs
Definition: example.h:30
shared_data * sd
Definition: global_data.h:375
size_t num_features
Definition: example.h:67
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
polylabel l
Definition: example.h:57
bool test_label(void *v)
Definition: simple_label.cc:70
polyprediction pred
Definition: example.h:60
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores, uint32_t prediction)
v_array< wclass > costs
void(* print)(int, float, float, v_array< char >)
Definition: global_data.h:521
uint64_t total_features
Definition: global_data.h:138

◆ output_example_seq()

void output_example_seq ( vw & all,
multi_ex & ec_seq 
)

Definition at line 356 of file cbify.cc.

References shared_data::example_number, output_example(), vw::print_text, vw::raw_prediction, vw::sd, and shared_data::weighted_labeled_examples.

Referenced by EXPLORE_EVAL::finish_multiline_example(), finish_multiline_example(), CB_ADF::finish_multiline_example(), and CSOAA::finish_multiline_example().

357 {
358  if (ec_seq.empty())
359  return;
360  all.sd->weighted_labeled_examples += ec_seq[0]->weight;
361  all.sd->example_number++;
362 
363  bool hit_loss = false;
364  for (example* ec : ec_seq) output_example(all, *ec, hit_loss, &(ec_seq));
365 
366  if (all.raw_prediction > 0)
367  {
368  v_array<char> empty = {nullptr, nullptr, nullptr, 0};
369  all.print_text(all.raw_prediction, "", empty);
370  }
371 }
int raw_prediction
Definition: global_data.h:519
shared_data * sd
Definition: global_data.h:375
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
uint64_t example_number
Definition: global_data.h:137
double weighted_labeled_examples
Definition: global_data.h:141
void output_example(vw &all, example &ec, bool &hit_loss, multi_ex *ec_seq)
Definition: cbify.cc:304

◆ predict_or_learn()

template<bool is_learn, bool use_cs>
void predict_or_learn ( cbify & data,
single_learner & base,
example & ec 
)

Definition at line 130 of file cbify.cc.

References cbify::a_s, polyprediction::a_s, CB::cb_class::action, cbify::app_seed, ACTION_SCORE::begin_scores(), polylabel::cb, cbify::cb_label, v_array< T >::clear(), CB::cb_class::cost, CB::label::costs, COST_SENSITIVE::label::costs, polylabel::cs, ACTION_SCORE::end_scores(), cbify::example_counter, example::l, MULTICLASS::label_t::label, LEARNER::learner< T, E >::learn(), loss(), loss_cs(), polylabel::multi, polyprediction::multiclass, example::pred, LEARNER::learner< T, E >::predict(), CB::cb_class::probability, exploration::sample_after_normalizing(), and THROW.

131 {
132  // Store the multiclass or cost-sensitive input label
133  MULTICLASS::label_t ld;
134  COST_SENSITIVE::label csl;
135  if (use_cs)
136  csl = ec.l.cs;
137  else
138  ld = ec.l.multi;
139 
140  data.cb_label.costs.clear();
141  ec.l.cb = data.cb_label;
142  ec.pred.a_s = data.a_s;
143 
144  // Call the cb_explore algorithm. It returns a vector of probabilities for each action
145  base.predict(ec);
146  // data.probs = ec.pred.scalars;
147 
148  uint32_t chosen_action;
149  if (sample_after_normalizing(
150  data.app_seed + data.example_counter++, begin_scores(ec.pred.a_s), end_scores(ec.pred.a_s), chosen_action))
151  THROW("Failed to sample from pdf");
152 
153  CB::cb_class cl;
154  cl.action = chosen_action + 1;
155  cl.probability = ec.pred.a_s[chosen_action].score;
156 
157  if (!cl.action)
158  THROW("No action with non-zero probability found!");
159  if (use_cs)
160  cl.cost = loss_cs(data, csl.costs, cl.action);
161  else
162  cl.cost = loss(data, ld.label, cl.action);
163 
164  // Create a new cb label
165  data.cb_label.costs.push_back(cl);
166  ec.l.cb = data.cb_label;
167 
168  if (is_learn)
169  base.learn(ec);
170 
171  data.a_s.clear();
172  data.a_s = ec.pred.a_s;
173 
174  if (use_cs)
175  ec.l.cs = csl;
176  else
177  ec.l.multi = ld;
178 
179  ec.pred.multiclass = cl.action;
180 }
uint32_t multiclass
Definition: example.h:49
ACTION_SCORE::action_scores a_s
Definition: example.h:47
void predict(E &ec, size_t i=0)
Definition: learner.h:169
uint64_t app_seed
Definition: cbify.cc:27
action_scores a_s
Definition: cbify.cc:28
int sample_after_normalizing(uint64_t seed, It pdf_first, It pdf_last, uint32_t &chosen_index)
Sample an index from the provided pdf. If the pdf is not normalized it will be updated in-place...
CB::label cb
Definition: example.h:31
v_array< cb_class > costs
Definition: cb.h:27
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
score_iterator begin_scores(action_scores &a_s)
Definition: action_score.h:43
score_iterator end_scores(action_scores &a_s)
Definition: action_score.h:45
MULTICLASS::label_t multi
Definition: example.h:29
uint32_t action
Definition: cb.h:18
COST_SENSITIVE::label cs
Definition: example.h:30
float probability
Definition: cb.h:19
void clear()
Definition: v_array.h:88
float loss_cs(cbify &data, v_array< COST_SENSITIVE::wclass > &costs, uint32_t final_prediction)
Definition: cbify.cc:68
CB::label cb_label
Definition: cbify.cc:26
size_t example_counter
Definition: cbify.cc:30
polylabel l
Definition: example.h:57
float cost
Definition: cb.h:17
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
v_array< wclass > costs
#define THROW(args)
Definition: vw_exception.h:181

◆ predict_or_learn_adf()

template<bool is_learn, bool use_cs>
void predict_or_learn_adf ( cbify & data,
multi_learner & base,
example & ec 
)

Definition at line 183 of file cbify.cc.

References CB::cb_class::action, cbify::adf_data, cbify::app_seed, ACTION_SCORE::begin_scores(), copy_example_to_adf(), CB::cb_class::cost, COST_SENSITIVE::label::costs, polylabel::cs, cbify_adf_data::ecs, ACTION_SCORE::end_scores(), cbify::example_counter, example::l, MULTICLASS::label_t::label, LEARNER::learner< T, E >::learn(), loss(), loss_cs(), polylabel::multi, polyprediction::multiclass, example::pred, LEARNER::learner< T, E >::predict(), CB::cb_class::probability, exploration::sample_after_normalizing(), and THROW.

184 {
185  // Store the multiclass or cost-sensitive input label
186  MULTICLASS::label_t ld;
187  COST_SENSITIVE::label csl;
188  if (use_cs)
189  csl = ec.l.cs;
190  else
191  ld = ec.l.multi;
192 
193  copy_example_to_adf(data, ec);
194  base.predict(data.adf_data.ecs);
195 
196  auto& out_ec = *data.adf_data.ecs[0];
197 
198  uint32_t chosen_action;
199  if (sample_after_normalizing(data.app_seed + data.example_counter++, begin_scores(out_ec.pred.a_s),
200  end_scores(out_ec.pred.a_s), chosen_action))
201  THROW("Failed to sample from pdf");
202 
203  CB::cb_class cl;
204  cl.action = out_ec.pred.a_s[chosen_action].action + 1;
205  cl.probability = out_ec.pred.a_s[chosen_action].score;
206 
207  if (!cl.action)
208  THROW("No action with non-zero probability found!");
209 
210  if (use_cs)
211  cl.cost = loss_cs(data, csl.costs, cl.action);
212  else
213  cl.cost = loss(data, ld.label, cl.action);
214 
215  // add cb label to chosen action
216  auto& lab = data.adf_data.ecs[cl.action - 1]->l.cb;
217  lab.costs.clear();
218  lab.costs.push_back(cl);
219 
220  if (is_learn)
221  base.learn(data.adf_data.ecs);
222 
223  ec.pred.multiclass = cl.action;
224 }
multi_ex ecs
Definition: cbify.cc:20
uint32_t multiclass
Definition: example.h:49
void predict(E &ec, size_t i=0)
Definition: learner.h:169
uint64_t app_seed
Definition: cbify.cc:27
int sample_after_normalizing(uint64_t seed, It pdf_first, It pdf_last, uint32_t &chosen_index)
Sample an index from the provided pdf. If the pdf is not normalized it will be updated in-place...
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
cbify_adf_data adf_data
Definition: cbify.cc:33
void copy_example_to_adf(cbify &data, example &ec)
Definition: cbify.cc:96
score_iterator begin_scores(action_scores &a_s)
Definition: action_score.h:43
score_iterator end_scores(action_scores &a_s)
Definition: action_score.h:45
MULTICLASS::label_t multi
Definition: example.h:29
uint32_t action
Definition: cb.h:18
COST_SENSITIVE::label cs
Definition: example.h:30
float probability
Definition: cb.h:19
float loss_cs(cbify &data, v_array< COST_SENSITIVE::wclass > &costs, uint32_t final_prediction)
Definition: cbify.cc:68
size_t example_counter
Definition: cbify.cc:30
polylabel l
Definition: example.h:57
float cost
Definition: cb.h:17
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
v_array< wclass > costs
#define THROW(args)
Definition: vw_exception.h:181