Vowpal Wabbit
Classes | Namespaces | Functions
cb_algs.cc File Reference
#include <float.h>
#include "vw.h"
#include "reductions.h"
#include "cb_algs.h"
#include "vw_exception.h"
#include "gen_cs_example.h"

Go to the source code of this file.

Classes

struct  CB_ALGS::cb
 

Namespaces

 CB_ALGS
 

Functions

bool CB_ALGS::know_all_cost_example (CB::label &ld)
 
template<bool is_learn>
void CB_ALGS::predict_or_learn (cb &data, single_learner &base, example &ec)
 
void CB_ALGS::predict_eval (cb &, single_learner &, example &)
 
void CB_ALGS::learn_eval (cb &data, single_learner &, example &ec)
 
void CB_ALGS::output_example (vw &all, cb &data, example &ec, CB::label &ld)
 
void CB_ALGS::finish_example (vw &all, cb &c, example &ec)
 
void CB_ALGS::eval_finish_example (vw &all, cb &c, example &ec)
 
base_learner* cb_algs_setup (options_i &options, vw &all)
 

Function Documentation

◆ cb_algs_setup()

base_learner* cb_algs_setup ( options_i & options,
vw & all 
)

Definition at line 132 of file cb_algs.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), c, label_type::cb, CB_EVAL::cb_eval, label_type::cb_eval, CB::cb_label, GEN_CS::cb_to_cs::cb_type, CB_TYPE_DM, CB_TYPE_DR, CB_TYPE_IPS, CB_ALGS::eval_finish_example(), CB_ALGS::finish_example(), LEARNER::init_learner(), VW::config::options_i::insert(), vw::label_type, CB_ALGS::learn_eval(), parser::lp, LEARNER::make_base(), VW::config::make_option(), prediction_type::multiclass, vw::p, CB_ALGS::predict_eval(), GEN_CS::cb_to_cs::scorer, vw::scorer, LEARNER::learner< T, E >::set_finish_example(), setup_base(), THROW, and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

133 {
134  auto data = scoped_calloc_or_throw<cb>();
135  std::string type_string = "dr";
136  bool eval = false;
137 
138  option_group_definition new_options("Contextual Bandit Options");
139  new_options
140  .add(make_option("cb", data->cbcs.num_actions).keep().help("Use contextual bandit learning with <k> costs"))
141  .add(make_option("cb_type", type_string).keep().help("contextual bandit method to use in {ips,dm,dr}"))
142  .add(make_option("eval", eval).help("Evaluate a policy rather than optimizing."));
143  options.add_and_parse(new_options);
144 
145  if (!options.was_supplied("cb"))
146  return nullptr;
147 
148  // Ensure serialization of this option in all cases.
149  if (!options.was_supplied("cb_type"))
150  {
151  options.insert("cb_type", type_string);
152  options.add_and_parse(new_options);
153  }
154 
155  cb_to_cs& c = data->cbcs;
156 
157  size_t problem_multiplier = 2; // default for DR
158  if (type_string.compare("dr") == 0)
159  c.cb_type = CB_TYPE_DR;
160  else if (type_string.compare("dm") == 0)
161  {
162  if (eval)
163  THROW("direct method can not be used for evaluation --- it is biased.");
164  c.cb_type = CB_TYPE_DM;
165  problem_multiplier = 1;
166  }
167  else if (type_string.compare("ips") == 0)
168  {
169  c.cb_type = CB_TYPE_IPS;
170  problem_multiplier = 1;
171  }
172  else
173  {
174  std::cerr << "warning: cb_type must be in {'ips','dm','dr'}; resetting to dr." << std::endl;
175  c.cb_type = CB_TYPE_DR;
176  }
177 
178  if (!options.was_supplied("csoaa"))
179  {
180  std::stringstream ss;
181  ss << data->cbcs.num_actions;
182  options.insert("csoaa", ss.str());
183  }
184 
185  auto base = as_singleline(setup_base(options, all));
186  if (eval)
187  {
188  all.p->lp = CB_EVAL::cb_eval;
189  all.label_type = label_type::cb_eval;
190  }
191  else
192  {
193  all.p->lp = CB::cb_label;
194  all.label_type = label_type::cb;
195  }
196 
197  learner<cb, example>* l;
198  if (eval)
199  {
200  l = &init_learner(data, base, learn_eval, predict_eval, problem_multiplier, prediction_type::multiclass);
201  l->set_finish_example(eval_finish_example);
202  }
203  else
204  {
205  l = &init_learner(
206  data, base, predict_or_learn<true>, predict_or_learn<false>, problem_multiplier, prediction_type::multiclass);
207  l->set_finish_example(finish_example);
208  }
209  c.scorer = all.scorer;
210 
211  return make_base(*l);
212 }
label_parser cb_eval
Definition: cb.cc:292
#define CB_TYPE_IPS
Definition: cb_algs.h:15
void finish_example(vw &all, cb &c, example &ec)
Definition: cb_algs.cc:119
label_type::label_type_t label_type
Definition: global_data.h:550
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
#define CB_TYPE_DM
Definition: cb_algs.h:14
virtual void add_and_parse(const option_group_definition &group)=0
void learn_eval(cb &data, single_learner &, example &ec)
Definition: cb_algs.cc:77
#define CB_TYPE_DR
Definition: cb_algs.h:13
parser * p
Definition: global_data.h:377
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
virtual bool was_supplied(const std::string &key)=0
void predict_eval(cb &, single_learner &, example &)
Definition: cb_algs.cc:75
LEARNER::single_learner * scorer
LEARNER::single_learner * scorer
Definition: global_data.h:384
virtual void insert(const std::string &key, const std::string &value)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
label_parser cb_label
Definition: cb.cc:167
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void eval_finish_example(vw &all, cb &c, example &ec)
Definition: cb_algs.cc:125
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define THROW(args)
Definition: vw_exception.h:181
constexpr uint64_t c
Definition: rand48.cc:12
label_parser lp
Definition: parser.h:102