Vowpal Wabbit
Functions
warm_cb.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner* warm_cb_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ warm_cb_setup()

LEARNER::base_learner* warm_cb_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 552 of file warm_cb.cc.

References ABS_CENTRAL, VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), LEARNER::as_multiline(), vw::delete_prediction, f, finish(), vw::get_random_state(), init_adf_data(), LEARNER::init_cost_sensitive_learner(), LEARNER::init_multiclass_learner(), VW::config::options_i::insert(), LEARNER::make_base(), VW::config::make_option(), vw::p, LEARNER::learner< T, E >::set_finish(), setup_base(), THROW, prediction_type::to_string(), UAR, uniform_hash(), and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

553 {
554  uint32_t num_actions = 0;
555  auto data = scoped_calloc_or_throw<warm_cb>();
556  bool use_cs;
557 
558  option_group_definition new_options("Make Multiclass into Warm-starting Contextual Bandit");
559 
560  new_options
561  .add(make_option("warm_cb", num_actions)
562  .keep()
563  .help("Convert multiclass on <k> classes into a contextual bandit problem"))
564  .add(make_option("warm_cb_cs", use_cs)
565  .help("consume cost-sensitive classification examples instead of multiclass"))
566  .add(make_option("loss0", data->loss0).default_value(0.f).help("loss for correct label"))
567  .add(make_option("loss1", data->loss1).default_value(1.f).help("loss for incorrect label"))
568  .add(make_option("warm_start", data->ws_period)
569  .default_value(0U)
570  .help("number of training examples for warm start phase"))
571  .add(make_option("epsilon", data->epsilon).keep().help("epsilon-greedy exploration"))
572  .add(make_option("interaction", data->inter_period)
573  .default_value(UINT32_MAX)
574  .help("number of examples for the interactive contextual bandit learning phase"))
575  .add(make_option("warm_start_update", data->upd_ws).help("indicator of warm start updates"))
576  .add(make_option("interaction_update", data->upd_inter).help("indicator of interaction updates"))
577  .add(make_option("corrupt_type_warm_start", data->cor_type_ws)
578  .default_value(UAR)
579  .help("type of label corruption in the warm start phase (1: uniformly at random, 2: circular, 3: "
580  "replacing with overwriting label)"))
581  .add(make_option("corrupt_prob_warm_start", data->cor_prob_ws)
582  .default_value(0.f)
583  .help("probability of label corruption in the warm start phase"))
584  .add(make_option("choices_lambda", data->choices_lambda)
585  .default_value(1U)
586  .help("the number of candidate lambdas to aggregate (lambda is the importance weight parameter between "
587  "the two sources)"))
588  .add(make_option("lambda_scheme", data->lambda_scheme)
589  .default_value(ABS_CENTRAL)
590  .help("The scheme for generating candidate lambda set (1: center lambda=0.5, 2: center lambda=0.5, min "
591  "lambda=0, max lambda=1, 3: center lambda=epsilon/(1+epsilon), 4: center "
592  "lambda=epsilon/(1+epsilon), min lambda=0, max lambda=1); the rest of candidate lambda values are "
593  "generated using a doubling scheme"))
594  .add(make_option("overwrite_label", data->overwrite_label)
595  .default_value(1U)
596  .help("the label used by type 3 corruptions (overwriting)"))
597  .add(make_option("sim_bandit", data->sim_bandit)
598  .help("simulate contextual bandit updates on warm start examples"));
599 
600  options.add_and_parse(new_options);
601 
602  if (use_cs && (options.was_supplied("corrupt_type_warm_start") || options.was_supplied("corrupt_prob_warm_start")))
603  {
604  THROW("label corruption on cost-sensitive examples not currently supported");
605  }
606 
607  if (!options.was_supplied("warm_cb"))
608  {
609  return nullptr;
610  }
611 
612  data->app_seed = uniform_hash("vw", 2, 0);
613  data->a_s = v_init<action_score>();
614  data->all = &all;
615  data->_random_state = all.get_random_state();
616  data->use_cs = use_cs;
617 
618  init_adf_data(*data.get(), num_actions);
619 
620  options.insert("cb_min_cost", std::to_string(data->loss0));
621  options.insert("cb_max_cost", std::to_string(data->loss1));
622 
623  if (options.was_supplied("baseline"))
624  {
625  std::stringstream ss;
626  ss << std::max(std::abs(data->loss0), std::abs(data->loss1)) / (data->loss1 - data->loss0);
627  options.insert("lr_multiplier", ss.str());
628  }
629 
631 
632  multi_learner* base = as_multiline(setup_base(options, all));
633  // Note: the current version of warm start CB can only support epsilon-greedy exploration
634  // We need to wait for the epsilon value to be passed from the base
635  // cb_explore learner, if there is one
636 
637  if (!options.was_supplied("epsilon"))
638  {
639  std::cerr << "Warning: no epsilon (greedy parameter) specified; resetting to 0.05" << std::endl;
640  data->epsilon = 0.05f;
641  }
642 
643  if (use_cs)
645  data, base, predict_or_learn_adf<true, true>, predict_or_learn_adf<false, true>, all.p, data->choices_lambda);
646  else
648  data, base, predict_or_learn_adf<true, false>, predict_or_learn_adf<false, false>, all.p, data->choices_lambda);
649 
650  l->set_finish(finish);
651  all.delete_prediction = nullptr;
652 
653  return make_base(*l);
654 }
void(* delete_prediction)(void *)
Definition: global_data.h:485
#define UAR
Definition: warm_cb.cc:26
learner< T, E > & init_cost_sensitive_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:450
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
Definition: hash.h:67
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
void finish(warm_cb &data)
Definition: warm_cb.cc:152
virtual void add_and_parse(const option_group_definition &group)=0
parser * p
Definition: global_data.h:377
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
#define ABS_CENTRAL
Definition: warm_cb.cc:30
virtual bool was_supplied(const std::string &key)=0
virtual void insert(const std::string &key, const std::string &value)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void set_finish(void(*f)(T &))
Definition: learner.h:265
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define THROW(args)
Definition: vw_exception.h:181
void init_adf_data(warm_cb &data, const uint32_t num_actions)
Definition: warm_cb.cc:517
float f
Definition: cache.cc:40
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
const char * to_string(prediction_type_t prediction_type)
Definition: learner.cc:12