Vowpal Wabbit
Classes | Functions
ExpReplay Namespace Reference

Classes

struct  expreplay
 

Functions

template<bool is_learn, label_parser & lp>
void predict_or_learn (expreplay< lp > &er, LEARNER::single_learner &base, example &ec)
 
template<label_parser & lp>
void multipredict (expreplay< lp > &, LEARNER::single_learner &base, example &ec, size_t count, size_t step, polyprediction *pred, bool finalize_predictions)
 
template<label_parser & lp>
void end_pass (expreplay< lp > &er)
 
template<char er_level, label_parser & lp>
LEARNER::base_learner * expreplay_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ end_pass()

template<label_parser & lp>
void ExpReplay::end_pass ( expreplay< lp > &  er)

Definition at line 69 of file expreplay.h.

References ExpReplay::expreplay< lp >::base, ExpReplay::expreplay< lp >::buf, ExpReplay::expreplay< lp >::filled, LEARNER::learner< T, E >::learn(), and ExpReplay::expreplay< lp >::N.

70 { // we need to go through and learn on everyone who remains
71  // also need to clean up remaining examples
72  for (size_t n = 0; n < er.N; n++)
73  if (er.filled[n])
74  { // TODO: if er.replay_count > 1 do we need to play these more?
75  er.base->learn(er.buf[n]);
76  er.filled[n] = false;
77  }
78 }

◆ expreplay_setup()

template<char er_level, label_parser & lp>
LEARNER::base_learner* ExpReplay::expreplay_setup ( VW::config::options_i &  options,
vw &  all 
)

Definition at line 81 of file expreplay.h.

References add(), VW::config::options_i::add_and_parse(), ExpReplay::expreplay< lp >::all, VW::alloc_examples(), LEARNER::as_singleline(), vw::get_random_state(), LEARNER::init_learner(), vw::interactions, LEARNER::make_base(), VW::config::make_option(), vw::quiet, LEARNER::learner< T, E >::set_end_pass(), setup_base(), and VW::config::options_i::was_supplied().

82 {
83  std::string replay_string = "replay_";
84  replay_string += er_level;
85  std::string replay_count_string = replay_string;
86  replay_count_string += "_count";
87 
88  auto er = scoped_calloc_or_throw<expreplay<lp>>();
89  VW::config::option_group_definition new_options("Experience Replay");
90  new_options
91  .add(VW::config::make_option(replay_string, er->N)
92  .keep()
93  .help("use experience replay at a specified level [b=classification/regression, m=multiclass, c=cost "
94  "sensitive] with specified buffer size"))
95  .add(VW::config::make_option(replay_count_string, er->replay_count)
96  .default_value(1)
97  .help("how many times (in expectation) should each example be played (default: 1 = permuting)"));
98  options.add_and_parse(new_options);
99 
100  if (!options.was_supplied(replay_string) || er->N == 0)
101  return nullptr;
102 
103  er->all = &all;
104  er->_random_state = all.get_random_state();
105  er->buf = VW::alloc_examples(1, er->N);
106  er->buf->interactions = &all.interactions;
107 
108  if (er_level == 'c')
109  for (size_t n = 0; n < er->N; n++) er->buf[n].l.cs.costs = v_init<COST_SENSITIVE::wclass>();
110 
111  er->filled = calloc_or_throw<bool>(er->N);
112 
113  if (!all.quiet)
114  std::cerr << "experience replay level=" << er_level << ", buffer=" << er->N << ", replay count=" << er->replay_count
115  << std::endl;
116 
117  er->base = LEARNER::as_singleline(setup_base(options, all));
118  LEARNER::learner<expreplay<lp>, example>* l =
119  &init_learner(er, er->base, predict_or_learn<true, lp>, predict_or_learn<false, lp>);
120  l->set_end_pass(end_pass<lp>);
121 
122  return make_base(*l);
123 }
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
bool quiet
Definition: global_data.h:487
virtual void add_and_parse(const option_group_definition &group)=0
example * alloc_examples(size_t, size_t count=1)
Definition: example.cc:204
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
virtual bool was_supplied(const std::string &key)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void set_end_pass(void(*f)(T &))
Definition: learner.h:286
std::vector< std::string > interactions
Definition: global_data.h:457
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222

◆ multipredict()

template<label_parser & lp>
void ExpReplay::multipredict ( expreplay< lp > &  ,
LEARNER::single_learner &  base,
example &  ec,
size_t  count,
size_t  step,
polyprediction *  pred,
bool  finalize_predictions 
)

Definition at line 62 of file expreplay.h.

References LEARNER::learner< T, E >::multipredict().

64 {
65  base.multipredict(ec, count, step, pred, finalize_predictions);
66 }
void multipredict(E &ec, size_t lo, size_t count, polyprediction *pred, bool finalize_predictions)
Definition: learner.h:178

◆ predict_or_learn()

template<bool is_learn, label_parser & lp>
void ExpReplay::predict_or_learn ( expreplay< lp > &  er,
LEARNER::single_learner &  base,
example &  ec 
)

Definition at line 35 of file expreplay.h.

References ExpReplay::expreplay< lp >::_random_state, ExpReplay::expreplay< lp >::all, vw::audit, ExpReplay::expreplay< lp >::buf, VW::copy_example_data(), ExpReplay::expreplay< lp >::filled, example::l, LEARNER::learner< T, E >::learn(), ExpReplay::expreplay< lp >::N, LEARNER::learner< T, E >::predict(), and ExpReplay::expreplay< lp >::replay_count.

36 { // regardless of what happens, we must predict
37  base.predict(ec);
38  // if we're not learning, that's all that has to happen
39  if (!is_learn || lp.get_weight(&ec.l) == 0.)
40  return;
41 
42  for (size_t replay = 1; replay < er.replay_count; replay++)
43  {
44  size_t n = (size_t)(er._random_state->get_and_update_random() * (float)er.N);
45  if (er.filled[n])
46  base.learn(er.buf[n]);
47  }
48 
49  size_t n = (size_t)(er._random_state->get_and_update_random() * (float)er.N);
50  if (er.filled[n])
51  base.learn(er.buf[n]);
52 
53  er.filled[n] = true;
54  VW::copy_example_data(er.all->audit, &er.buf[n], &ec); // don't copy the label
55  if (lp.copy_label)
56  lp.copy_label(&er.buf[n].l, &ec.l);
57  else
58  er.buf[n].l = ec.l;
59 }
void predict(E &ec, size_t i=0)
Definition: learner.h:169
void copy_example_data(bool audit, example *dst, example *src)
Definition: example.cc:72
polylabel l
Definition: example.h:57
void learn(E &ec, size_t i=0)
Definition: learner.h:160