Vowpal Wabbit
Classes | Namespaces | Functions
mwt.cc File Reference
#include "vw.h"
#include "reductions.h"
#include "gd.h"
#include "cb_algs.h"
#include "io_buf.h"

Go to the source code of this file.

Classes

struct  MWT::policy_data
 
struct  MWT::mwt
 

Namespaces

 MWT
 

Functions

bool MWT::observed_cost (CB::cb_class *cl)
 
CB::cb_class * MWT::get_observed_cost (CB::label &ld)
 
void MWT::value_policy (mwt &c, float val, uint64_t index)
 
template<bool learn, bool exclude, bool is_learn>
void MWT::predict_or_learn (mwt &c, single_learner &base, example &ec)
 
void MWT::print_scalars (int f, v_array< float > &scalars, v_array< char > &tag)
 
void MWT::finish_example (vw &all, mwt &c, example &ec)
 
void MWT::save_load (mwt &c, io_buf &model_file, bool read, bool text)
 
base_learner * mwt_setup (options_i &options, vw &all)
 

Function Documentation

◆ mwt_setup()

base_learner* mwt_setup ( options_i & options,
vw & all 
)

Definition at line 236 of file mwt.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), c, calloc_reserve(), label_type::cb, CB::cb_label, vw::delete_prediction, delete_scalars(), MWT::finish_example(), LEARNER::init_learner(), VW::config::options_i::insert(), vw::label_type, vw::length(), parser::lp, LEARNER::make_base(), VW::config::make_option(), vw::p, MWT::save_load(), prediction_type::scalars, LEARNER::learner< T, E >::set_finish_example(), LEARNER::learner< T, E >::set_save_load(), setup_base(), and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

237 {
238  auto c = scoped_calloc_or_throw<mwt>();
239  std::string s;
240  bool exclude_eval = false;
241  option_group_definition new_options("Multiworld Testing Options");
242  new_options.add(make_option("multiworld_test", s).keep().help("Evaluate features as a policies"))
243  .add(make_option("learn", c->num_classes).help("Do Contextual Bandit learning on <n> classes."))
244  .add(make_option("exclude_eval", exclude_eval).help("Discard mwt policy features before learning"));
245  options.add_and_parse(new_options);
246 
247  if (!options.was_supplied("multiworld_test"))
248  return nullptr;
249 
250  for (char i : s) c->namespaces[(unsigned char)i] = true;
251  c->all = &all;
252 
253  calloc_reserve(c->evals, all.length());
254  c->evals.end() = c->evals.begin() + all.length();
255 
256  all.delete_prediction = delete_scalars;
257  all.p->lp = CB::cb_label;
258  all.label_type = label_type::cb;
259 
260  if (c->num_classes > 0)
261  {
262  c->learn = true;
263 
264  if (!options.was_supplied("cb"))
265  {
266  std::stringstream ss;
267  ss << c->num_classes;
268  options.insert("cb", ss.str());
269  }
270  }
271 
272  learner<mwt, example>* l;
273  if (c->learn)
274  if (exclude_eval)
275  l = &init_learner(c, as_singleline(setup_base(options, all)), predict_or_learn<true, true, true>,
276  predict_or_learn<true, true, false>, 1, prediction_type::scalars);
277  else
278  l = &init_learner(c, as_singleline(setup_base(options, all)), predict_or_learn<true, false, true>,
279  predict_or_learn<true, false, false>, 1, prediction_type::scalars);
280  else
281  l = &init_learner(c, as_singleline(setup_base(options, all)), predict_or_learn<false, false, true>,
282  predict_or_learn<false, false, false>, 1, prediction_type::scalars);
283 
284  l->set_save_load(save_load);
285  l->set_finish_example(finish_example);
286  return make_base(*l);
287 }
size_t length()
Definition: global_data.h:513
void(* delete_prediction)(void *)
Definition: global_data.h:485
label_type::label_type_t label_type
Definition: global_data.h:550
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
parser * p
Definition: global_data.h:377
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void delete_scalars(void *v)
Definition: example.h:37
virtual bool was_supplied(const std::string &key)=0
void finish_example(vw &all, mwt &c, example &ec)
Definition: mwt.cc:175
void calloc_reserve(v_array< T > &v, size_t length)
Definition: v_array.h:220
virtual void insert(const std::string &key, const std::string &value)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
label_parser cb_label
Definition: cb.cc:167
void save_load(mwt &c, io_buf &model_file, bool read, bool text)
Definition: mwt.cc:195
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
constexpr uint64_t c
Definition: rand48.cc:12
label_parser lp
Definition: parser.h:102