Vowpal Wabbit
Classes | Functions
VW::cb_explore_adf::cover Namespace Reference

Classes

struct  cb_explore_adf_cover
 

Functions

LEARNER::base_learner * setup (config::options_i &options, vw &all)
 

Function Documentation

◆ setup()

LEARNER::base_learner * VW::cb_explore_adf::cover::setup ( config::options_i & options,
vw & all 
)

Definition at line 174 of file cb_explore_adf_cover.cc.

References prediction_type::action_probs, VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::as_multiline(), label_type::cb, CB::cb_label, CB_TYPE_DR, CB_TYPE_IPS, CB_TYPE_MTR, vw::cost_sensitive, ACTION_SCORE::delete_action_scores(), vw::delete_prediction, f, finish_multiline_example(), LEARNER::init_learner(), VW::config::options_i::insert(), vw::label_type, learn(), parser::lp, LEARNER::make_base(), VW::config::make_option(), vw::p, predict(), VW::config::options_i::replace(), vw::scorer, setup_base(), vw::trace_message, and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

175 {
176  using config::make_option;
177 
178  bool cb_explore_adf_option = false;
179  std::string type_string = "mtr";
180  size_t cover_size = 0;
181  float psi = 0.;
182  bool nounif = false;
183  bool first_only = false;
184 
185  config::option_group_definition new_options("Contextual Bandit Exploration with Action Dependent Features");
186  new_options
187  .add(make_option("cb_explore_adf", cb_explore_adf_option)
188  .keep()
189  .help("Online explore-exploit for a contextual bandit problem with multiline action dependent features"))
190  .add(make_option("cover", cover_size).keep().help("Online cover based exploration"))
191  .add(make_option("psi", psi).keep().default_value(1.0f).help("disagreement parameter for cover"))
192  .add(make_option("nounif", nounif).keep().help("do not explore uniformly on zero-probability actions in cover"))
193  .add(make_option("first_only", first_only).keep().help("Only explore the first action in a tie-breaking event"))
194  .add(make_option("cb_type", type_string)
195  .keep()
196  .help("contextual bandit method to use in {ips,dr,mtr}. Default: mtr"));
197  options.add_and_parse(new_options);
198 
199  if (!cb_explore_adf_option || !options.was_supplied("cover"))
200  return nullptr;
201 
202  // Ensure serialization of cb_type in all cases.
203  if (!options.was_supplied("cb_type"))
204  {
205  options.insert("cb_type", type_string);
206  options.add_and_parse(new_options);
207  }
208 
209  // Ensure serialization of cb_adf in all cases.
210  if (!options.was_supplied("cb_adf"))
211  {
212  options.insert("cb_adf", "");
213  }
214 
216 
217  // Set cb_type
218  size_t cb_type_enum;
219  if (type_string.compare("dr") == 0)
220  cb_type_enum = CB_TYPE_DR;
221  else if (type_string.compare("ips") == 0)
222  cb_type_enum = CB_TYPE_IPS;
223  else if (type_string.compare("mtr") == 0)
224  {
225  all.trace_message << "warning: currently, mtr is only used for the first policy in cover, other policies use dr"
226  << std::endl;
227  cb_type_enum = CB_TYPE_MTR;
228  }
229  else
230  {
231  all.trace_message << "warning: cb_type must be in {'ips','dr','mtr'}; resetting to mtr." << std::endl;
232  options.replace("cb_type", "mtr");
233  cb_type_enum = CB_TYPE_MTR;
234  }
235 
236  // Set explore_type
237  size_t problem_multiplier = cover_size + 1;
238 
240  all.p->lp = CB::cb_label;
242 
243  using explore_type = cb_explore_adf_base<cb_explore_adf_cover>;
244  auto data = scoped_calloc_or_throw<explore_type>(
245  cover_size, psi, nounif, first_only, as_multiline(all.cost_sensitive), all.scorer, cb_type_enum);
246 
249 
250  l.set_finish_example(explore_type::finish_multiline_example);
251  return make_base(l);
252 }
#define CB_TYPE_IPS
Definition: cb_algs.h:15
LEARNER::base_learner * cost_sensitive
Definition: global_data.h:385
void(* delete_prediction)(void *)
Definition: global_data.h:485
void finish_multiline_example(vw &all, cbify &, multi_ex &ec_seq)
Definition: cbify.cc:373
label_type::label_type_t label_type
Definition: global_data.h:550
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
#define CB_TYPE_DR
Definition: cb_algs.h:13
parser * p
Definition: global_data.h:377
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void delete_action_scores(void *v)
Definition: action_score.cc:29
vw_ostream trace_message
Definition: global_data.h:424
LEARNER::single_learner * scorer
Definition: global_data.h:384
label_parser cb_label
Definition: cb.cc:167
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
void predict(bfgs &b, base_learner &, example &ec)
Definition: bfgs.cc:956
void learn(bfgs &b, base_learner &base, example &ec)
Definition: bfgs.cc:965
float f
Definition: cache.cc:40
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
label_parser lp
Definition: parser.h:102
#define CB_TYPE_MTR
Definition: cb_algs.h:16