Vowpal Wabbit
Classes | Namespaces | Functions
cb_sample.cc File Reference
#include "reductions.h"
#include "cb_sample.h"
#include "explore.h"
#include "rand48.h"

Go to the source code of this file.

Classes

struct  VW::cb_sample_data
 

Namespaces

 VW
 

Functions

template<bool is_learn>
void learn_or_predict (cb_sample_data &data, multi_learner &base, multi_ex &examples)
 
base_learner * cb_sample_setup (options_i &options, vw &all)
 

Function Documentation

◆ cb_sample_setup()

base_learner* cb_sample_setup ( options_i & options,
vw & all 
)

Definition at line 97 of file cb_sample.cc.

References prediction_type::action_probs, VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::as_multiline(), vw::get_random_state(), LEARNER::init_learner(), LEARNER::make_base(), VW::config::make_option(), setup_base(), THROW, and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

98 {
99  bool cb_sample_option = false;
100 
101  option_group_definition new_options("CB Sample");
102  new_options.add(make_option("cb_sample", cb_sample_option).keep().help("Sample from CB pdf and swap top action."));
103  options.add_and_parse(new_options);
104 
105  if (!cb_sample_option)
106  return nullptr;
107 
108  if (options.was_supplied("no_predict"))
109  {
110  THROW("cb_sample cannot be used with no_predict, as there would be no predictions to sample.");
111  }
112 
113  auto data = scoped_calloc_or_throw<cb_sample_data>(all.get_random_state());
114  return make_base(init_learner(data, as_multiline(setup_base(options, all)), learn_or_predict<true>,
115  learn_or_predict<false>, 1 /* weights */, prediction_type::action_probs));
116 }
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
virtual bool was_supplied(const std::string &key)=0
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define THROW(args)
Definition: vw_exception.h:181
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468

◆ learn_or_predict()

template<bool is_learn>
void learn_or_predict ( cb_sample_data & data,
multi_learner & base,
multi_ex & examples 
)

Definition at line 92 of file cb_sample.cc.

References VW::cb_sample_data::learn_or_predict().

93 {
94  data.learn_or_predict<is_learn>(base, examples);
95 }
void learn_or_predict(multi_learner &base, multi_ex &examples)
Definition: cb_sample.cc:20