Vowpal Wabbit
Classes | Functions
VW::cb_explore_adf::regcb Namespace Reference

Classes

struct  cb_explore_adf_regcb
 

Functions

LEARNER::base_learner * setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ setup()

LEARNER::base_learner * VW::cb_explore_adf::regcb::setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 229 of file cb_explore_adf_regcb.cc.

References prediction_type::action_probs, VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::as_multiline(), label_type::cb, CB::cb_label, ACTION_SCORE::delete_action_scores(), vw::delete_prediction, f, finish_multiline_example(), LEARNER::init_learner(), VW::config::options_i::insert(), vw::label_type, learn(), parser::lp, LEARNER::make_base(), VW::config::make_option(), vw::p, predict(), VW::config::options_i::replace(), setup_base(), vw::trace_message, and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

230 {
231  using config::make_option;
232  bool cb_explore_adf_option = false;
233  bool regcb = false;
234  const std::string mtr = "mtr";
235  std::string type_string(mtr);
236  bool regcbopt = false;
237  float c0 = 0.;
238  bool first_only = false;
239  float min_cb_cost = 0.;
240  float max_cb_cost = 0.;
241  config::option_group_definition new_options("Contextual Bandit Exploration with Action Dependent Features");
242  new_options
243  .add(make_option("cb_explore_adf", cb_explore_adf_option)
244  .keep()
245  .help("Online explore-exploit for a contextual bandit problem with multiline action dependent features"))
246  .add(make_option("regcb", regcb).keep().help("RegCB-elim exploration"))
247  .add(make_option("regcbopt", regcbopt).keep().help("RegCB optimistic exploration"))
248  .add(make_option("mellowness", c0).keep().default_value(0.1f).help("RegCB mellowness parameter c_0. Default 0.1"))
249  .add(make_option("cb_min_cost", min_cb_cost).keep().default_value(0.f).help("lower bound on cost"))
250  .add(make_option("cb_max_cost", max_cb_cost).keep().default_value(1.f).help("upper bound on cost"))
251  .add(make_option("first_only", first_only).keep().help("Only explore the first action in a tie-breaking event"))
252  .add(make_option("cb_type", type_string)
253  .keep()
254  .help("contextual bandit method to use in {ips,dr,mtr}. Default: mtr"));
255  options.add_and_parse(new_options);
256 
257  if (!cb_explore_adf_option || !(options.was_supplied("regcb") || options.was_supplied("regcbopt")))
258  return nullptr;
259 
260  // Ensure serialization of cb_adf in all cases.
261  if (!options.was_supplied("cb_adf"))
262  {
263  options.insert("cb_adf", "");
264  }
265  if (type_string != mtr)
266  {
267  all.trace_message << "warning: bad cb_type, RegCB only supports mtr; resetting to mtr." << std::endl;
268  options.replace("cb_type", mtr);
269  }
270 
272 
273  // Set explore_type
274  size_t problem_multiplier = 1;
275 
276  LEARNER::multi_learner* base = as_multiline(setup_base(options, all));
277  all.p->lp = CB::cb_label;
279 
280  using explore_type = cb_explore_adf_base<cb_explore_adf_regcb>;
281  auto data = scoped_calloc_or_throw<explore_type>(regcbopt, c0, first_only, min_cb_cost, max_cb_cost);
284 
285  l.set_finish_example(explore_type::finish_multiline_example);
286  return make_base(l);
287 }
void(* delete_prediction)(void *)
Definition: global_data.h:485
void finish_multiline_example(vw &all, cbify &, multi_ex &ec_seq)
Definition: cbify.cc:373
virtual void replace(const std::string &key, const std::string &value)=0
label_type::label_type_t label_type
Definition: global_data.h:550
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
parser * p
Definition: global_data.h:377
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void delete_action_scores(void *v)
Definition: action_score.cc:29
vw_ostream trace_message
Definition: global_data.h:424
virtual bool was_supplied(const std::string &key)=0
virtual void insert(const std::string &key, const std::string &value)=0
label_parser cb_label
Definition: cb.cc:167
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
void predict(bfgs &b, base_learner &, example &ec)
Definition: bfgs.cc:956
void learn(bfgs &b, base_learner &base, example &ec)
Definition: bfgs.cc:965
float f
Definition: cache.cc:40
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
label_parser lp
Definition: parser.h:102