Vowpal Wabbit
Functions
cbify.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner* cbify_setup (VW::config::options_i &options, vw &all)
 
LEARNER::base_learner* cbifyldf_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ cbify_setup()

LEARNER::base_learner* cbify_setup ( VW::config::options_i & options,
vw & all
)

Definition at line 383 of file cbify.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), LEARNER::as_multiline(), LEARNER::as_singleline(), vw::delete_prediction, f, init_adf_data(), LEARNER::init_cost_sensitive_learner(), LEARNER::init_multiclass_learner(), VW::config::options_i::insert(), LEARNER::make_base(), VW::config::make_option(), vw::p, setup_base(), prediction_type::to_string(), uniform_hash(), and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

384 {
385  uint32_t num_actions = 0;
386  auto data = scoped_calloc_or_throw<cbify>();
387  bool use_cs;
388 
389  option_group_definition new_options("Make Multiclass into Contextual Bandit");
390  new_options
391  .add(make_option("cbify", num_actions)
392  .keep()
393  .help("Convert multiclass on <k> classes into a contextual bandit problem"))
394  .add(make_option("cbify_cs", use_cs).help("consume cost-sensitive classification examples instead of multiclass"))
395  .add(make_option("loss0", data->loss0).default_value(0.f).help("loss for correct label"))
396  .add(make_option("loss1", data->loss1).default_value(1.f).help("loss for incorrect label"));
397  options.add_and_parse(new_options);
398 
399  if (!options.was_supplied("cbify"))
400  return nullptr;
401 
402  data->use_adf = options.was_supplied("cb_explore_adf");
403  data->app_seed = uniform_hash("vw", 2, 0);
404  data->a_s = v_init<action_score>();
405  data->all = &all;
406 
407  if (data->use_adf)
408  init_adf_data(*data, num_actions);
409 
410  if (!options.was_supplied("cb_explore") && !data->use_adf)
411  {
412  std::stringstream ss;
413  ss << num_actions;
414  options.insert("cb_explore", ss.str());
415  }
416 
417  if (data->use_adf)
418  {
419  options.insert("cb_min_cost", std::to_string(data->loss0));
420  options.insert("cb_max_cost", std::to_string(data->loss1));
421  }
422 
423  if (options.was_supplied("baseline"))
424  {
425  std::stringstream ss;
426  ss << std::max(std::abs(data->loss0), std::abs(data->loss1)) / (data->loss1 - data->loss0);
427  options.insert("lr_multiplier", ss.str());
428  }
429 
430  learner<cbify, example>* l;
431 
432  if (data->use_adf)
433  {
434  multi_learner* base = as_multiline(setup_base(options, all));
435  if (use_cs)
436  l = &init_cost_sensitive_learner(
437  data, base, predict_or_learn_adf<true, true>, predict_or_learn_adf<false, true>, all.p, 1);
438  else
439  l = &init_multiclass_learner(
440  data, base, predict_or_learn_adf<true, false>, predict_or_learn_adf<false, false>, all.p, 1);
441  }
442  else
443  {
444  single_learner* base = as_singleline(setup_base(options, all));
445  if (use_cs)
446  l = &init_cost_sensitive_learner(
447  data, base, predict_or_learn<true, true>, predict_or_learn<false, true>, all.p, 1);
448  else
449  l = &init_multiclass_learner(data, base, predict_or_learn<true, false>, predict_or_learn<false, false>, all.p, 1);
450  }
451  all.delete_prediction = nullptr;
452 
453  return make_base(*l);
454 }
void(* delete_prediction)(void *)
Definition: global_data.h:485
void init_adf_data(cbify &data, const size_t num_actions)
Definition: cbify.cc:226
learner< T, E > & init_cost_sensitive_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:450
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
Definition: hash.h:67
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
parser * p
Definition: global_data.h:377
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
virtual bool was_supplied(const std::string &key)=0
virtual void insert(const std::string &key, const std::string &value)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
float f
Definition: cache.cc:40
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
const char * to_string(prediction_type_t prediction_type)
Definition: learner.cc:12

◆ cbifyldf_setup()

LEARNER::base_learner* cbifyldf_setup ( VW::config::options_i & options,
vw & all
)

Definition at line 456 of file cbify.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), LEARNER::as_multiline(), COST_SENSITIVE::cs_label, vw::delete_prediction, f, finish_multiline_example(), LEARNER::init_learner(), VW::config::options_i::insert(), parser::lp, LEARNER::make_base(), VW::config::make_option(), prediction_type::multiclass, vw::p, LEARNER::learner< T, E >::set_finish_example(), setup_base(), prediction_type::to_string(), uniform_hash(), and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

457 {
458  auto data = scoped_calloc_or_throw<cbify>();
459  bool cbify_ldf_option = false;
460 
461  option_group_definition new_options("Make csoaa_ldf into Contextual Bandit");
462  new_options
463  .add(make_option("cbify_ldf", cbify_ldf_option).keep().help("Convert csoaa_ldf into a contextual bandit problem"))
464  .add(make_option("loss0", data->loss0).default_value(0.f).help("loss for correct label"))
465  .add(make_option("loss1", data->loss1).default_value(1.f).help("loss for incorrect label"));
466  options.add_and_parse(new_options);
467 
468  if (!options.was_supplied("cbify_ldf"))
469  return nullptr;
470 
471  data->app_seed = uniform_hash("vw", 2, 0);
472  data->all = &all;
473  data->use_adf = true;
474 
475  if (!options.was_supplied("cb_explore_adf"))
476  {
477  options.insert("cb_explore_adf", "");
478  }
479  options.insert("cb_min_cost", std::to_string(data->loss0));
480  options.insert("cb_max_cost", std::to_string(data->loss1));
481 
482  if (options.was_supplied("baseline"))
483  {
484  std::stringstream ss;
485  ss << std::max(std::abs(data->loss0), std::abs(data->loss1)) / (data->loss1 - data->loss0);
486  options.insert("lr_multiplier", ss.str());
487  }
488 
489  multi_learner* base = as_multiline(setup_base(options, all));
490  learner<cbify, multi_ex>& l = init_learner(
491  data, base, do_actual_learning_ldf<true>, do_actual_learning_ldf<false>, 1, prediction_type::multiclass);
492 
493  l.set_finish_example(finish_multiline_example);
494  all.p->lp = COST_SENSITIVE::cs_label;
495  all.delete_prediction = nullptr;
496 
497  return make_base(l);
498 }
void(* delete_prediction)(void *)
Definition: global_data.h:485
label_parser cs_label
void finish_multiline_example(vw &all, cbify &, multi_ex &ec_seq)
Definition: cbify.cc:373
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
Definition: hash.h:67
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
parser * p
Definition: global_data.h:377
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
virtual bool was_supplied(const std::string &key)=0
virtual void insert(const std::string &key, const std::string &value)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
float f
Definition: cache.cc:40
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
const char * to_string(prediction_type_t prediction_type)
Definition: learner.cc:12
label_parser lp
Definition: parser.h:102