Vowpal Wabbit
Classes | Namespaces | Functions
cb_explore.h File Reference

Go to the source code of this file.

Classes

struct  LEARNER::learner< T, E >
 

Namespaces

 LEARNER
 

Functions

LEARNER::base_learner * cb_explore_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ cb_explore_setup()

LEARNER::base_learner* cb_explore_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 274 of file cb_explore.cc.

References prediction_type::action_probs, add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), CB_TYPE_DR, vw::cost_sensitive, ACTION_SCORE::delete_action_scores(), vw::delete_prediction, f, CB_EXPLORE::finish_example(), vw::get_random_state(), LEARNER::init_learner(), VW::config::options_i::insert(), LEARNER::make_base(), VW::config::make_option(), vw::scorer, LEARNER::learner< T, E >::set_finish_example(), setup_base(), and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

275 {
276  auto data = scoped_calloc_or_throw<cb_explore>();
277  option_group_definition new_options("Contextual Bandit Exploration");
278  new_options
279  .add(make_option("cb_explore", data->cbcs.num_actions)
280  .keep()
281  .help("Online explore-exploit for a <k> action contextual bandit problem"))
282  .add(make_option("first", data->tau).keep().help("tau-first exploration"))
283  .add(make_option("epsilon", data->epsilon).keep().default_value(0.05f).help("epsilon-greedy exploration"))
284  .add(make_option("bag", data->bag_size).keep().help("bagging-based exploration"))
285  .add(make_option("cover", data->cover_size).keep().help("Online cover based exploration"))
286  .add(make_option("psi", data->psi).keep().default_value(1.0f).help("disagreement parameter for cover"));
287  options.add_and_parse(new_options);
288 
289  if (!options.was_supplied("cb_explore"))
290  return nullptr;
291 
292  data->_random_state = all.get_random_state();
293  uint32_t num_actions = data->cbcs.num_actions;
294 
295  if (!options.was_supplied("cb"))
296  {
297  std::stringstream ss;
298  ss << data->cbcs.num_actions;
299  options.insert("cb", ss.str());
300  }
301 
303  data->cbcs.cb_type = CB_TYPE_DR;
304 
305  single_learner* base = as_singleline(setup_base(options, all));
306  data->cbcs.scorer = all.scorer;
307 
309  if (options.was_supplied("cover"))
310  {
312  data->second_cs_label.costs.resize(num_actions);
313  data->second_cs_label.costs.end() = data->second_cs_label.costs.begin() + num_actions;
314  data->cover_probs = v_init<float>();
315  data->cover_probs.resize(num_actions);
316  data->preds = v_init<uint32_t>();
317  data->preds.resize(data->cover_size);
318  l = &init_learner(data, base, predict_or_learn_cover<true>, predict_or_learn_cover<false>, data->cover_size + 1,
319  prediction_type::action_probs);
320  }
321  else if (options.was_supplied("bag"))
322  l = &init_learner(data, base, predict_or_learn_bag<true>, predict_or_learn_bag<false>, data->bag_size,
323  prediction_type::action_probs);
324  else if (options.was_supplied("first"))
325  l = &init_learner(
326  data, base, predict_or_learn_first<true>, predict_or_learn_first<false>, 1, prediction_type::action_probs);
327  else // greedy
328  l = &init_learner(
329  data, base, predict_or_learn_greedy<true>, predict_or_learn_greedy<false>, 1, prediction_type::action_probs);
330 
332  return make_base(*l);
333 }
LEARNER::base_learner * cost_sensitive
Definition: global_data.h:385
void(* delete_prediction)(void *)
Definition: global_data.h:485
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
#define CB_TYPE_DR
Definition: cb_algs.h:13
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void delete_action_scores(void *v)
Definition: action_score.cc:29
virtual bool was_supplied(const std::string &key)=0
LEARNER::single_learner * scorer
Definition: global_data.h:384
virtual void insert(const std::string &key, const std::string &value)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
void finish_example(vw &all, cb_explore &c, example &ec)
Definition: cb_explore.cc:266
float f
Definition: cache.cc:40