Vowpal Wabbit
Classes | Functions
EXPLORE_EVAL Namespace Reference

Classes

struct  explore_eval
 

Functions

void finish (explore_eval &data)
 
void output_example (vw &all, explore_eval &c, example &ec, multi_ex *ec_seq)
 
void output_example_seq (vw &all, explore_eval &data, multi_ex &ec_seq)
 
void finish_multiline_example (vw &all, explore_eval &data, multi_ex &ec_seq)
 
template<bool is_learn>
void do_actual_learning (explore_eval &data, multi_learner &base, multi_ex &ec_seq)
 

Function Documentation

◆ do_actual_learning()

template<bool is_learn>
void EXPLORE_EVAL::do_actual_learning ( explore_eval &data,
multi_learner &base,
multi_ex &ec_seq 
)

Definition at line 124 of file explore_eval.cc.

References EXPLORE_EVAL::explore_eval::_random_state, CB::cb_class::action, EXPLORE_EVAL::explore_eval::action_label, polylabel::cb, CB::label::costs, EXPLORE_EVAL::explore_eval::empty_label, EXPLORE_EVAL::explore_eval::fixed_multiplier, CB_ADF::get_observed_cost(), EXPLORE_EVAL::explore_eval::known_cost, example::l, EXPLORE_EVAL::explore_eval::multiplier, EXPLORE_EVAL::explore_eval::offset, CB::cb_class::probability, v_array< T >::size(), CB_ADF::test_adf_sequence(), EXPLORE_EVAL::explore_eval::update_count, EXPLORE_EVAL::explore_eval::violations, and example::weight.

125 {
126  example* label_example = CB_ADF::test_adf_sequence(ec_seq);
127 
128  if (label_example != nullptr) // extract label
129  {
130  data.action_label = label_example->l.cb;
131  label_example->l.cb = data.empty_label;
132  }
133  multiline_learn_or_predict<false>(base, ec_seq, data.offset);
134 
135  if (label_example != nullptr) // restore label
136  label_example->l.cb = data.action_label;
137 
138  data.known_cost = CB_ADF::get_observed_cost(ec_seq);
139  if (label_example != nullptr && is_learn)
140  {
141  ACTION_SCORE::action_scores& a_s = ec_seq[0]->pred.a_s;
142 
143  float action_probability = 0;
144  for (size_t i = 0; i < a_s.size(); i++)
145  if (data.known_cost.action == a_s[i].action)
146  action_probability = a_s[i].score;
147 
148  float threshold = action_probability / data.known_cost.probability;
149 
150  if (!data.fixed_multiplier)
151  data.multiplier = std::min(data.multiplier, 1 / threshold);
152  else
153  threshold *= data.multiplier;
154 
155  if (threshold > 1. + 1e-6)
156  data.violations++;
157 
158  if (data._random_state->get_and_update_random() < threshold)
159  {
160  example* ec_found = nullptr;
161  for (example*& ec : ec_seq)
162  {
163  if (ec->l.cb.costs.size() == 1 && ec->l.cb.costs[0].cost != FLT_MAX && ec->l.cb.costs[0].probability > 0)
164  ec_found = ec;
165  if (threshold > 1)
166  ec->weight *= threshold;
167  }
168  ec_found->l.cb.costs[0].probability = action_probability;
169 
170  multiline_learn_or_predict<true>(base, ec_seq, data.offset);
171 
172  if (threshold > 1)
173  {
174  float inv_threshold = 1.f / threshold;
175  for (auto& ec : ec_seq) ec->weight *= inv_threshold;
176  }
177  ec_found->l.cb.costs[0].probability = data.known_cost.probability;
178  data.update_count++;
179  }
180  }
181 }
CB::label cb
Definition: example.h:31
v_array< cb_class > costs
Definition: cb.h:27
example * test_adf_sequence(multi_ex &ec_seq)
Definition: cb_adf.cc:268
size_t size() const
Definition: v_array.h:68
std::shared_ptr< rand_state > _random_state
Definition: explore_eval.cc:23
CB::cb_class get_observed_cost(multi_ex &examples)
Definition: cb_adf.cc:99
uint32_t action
Definition: cb.h:18
float probability
Definition: cb.h:19
polylabel l
Definition: example.h:57
float weight
Definition: example.h:62

◆ finish()

void EXPLORE_EVAL::finish ( explore_eval &data)

Definition at line 36 of file explore_eval.cc.

References EXPLORE_EVAL::explore_eval::all, EXPLORE_EVAL::explore_eval::fixed_multiplier, EXPLORE_EVAL::explore_eval::multiplier, vw::quiet, vw::trace_message, EXPLORE_EVAL::explore_eval::update_count, and EXPLORE_EVAL::explore_eval::violations.

Referenced by explore_eval_setup().

37 {
38  if (!data.all->quiet)
39  {
40  data.all->trace_message << "update count = " << data.update_count << std::endl;
41  if (data.violations > 0)
42  data.all->trace_message << "violation count = " << data.violations << std::endl;
43  if (!data.fixed_multiplier)
44  data.all->trace_message << "final multiplier = " << data.multiplier << std::endl;
45  }
46 }
bool quiet
Definition: global_data.h:487
vw_ostream trace_message
Definition: global_data.h:424

◆ finish_multiline_example()

void EXPLORE_EVAL::finish_multiline_example ( vw &all,
explore_eval &data,
multi_ex &ec_seq 
)

Definition at line 113 of file explore_eval.cc.

References vw::final_prediction_sink, VW::finish_example(), CB_ADF::global_print_newline(), and output_example_seq().

114 {
115  if (ec_seq.size() > 0)
116  {
117  output_example_seq(all, data, ec_seq);
118  CB_ADF::global_print_newline(all.final_prediction_sink);
119  }
120  VW::finish_example(all, ec_seq);
121 }
void output_example_seq(vw &all, explore_eval &data, multi_ex &ec_seq)
Definition: explore_eval.cc:103
v_array< int > final_prediction_sink
Definition: global_data.h:518
void global_print_newline(const v_array< int > &final_prediction_sink)
Definition: cb_adf.cc:342
void finish_example(vw &, example &)
Definition: parser.cc:881

◆ output_example()

void EXPLORE_EVAL::output_example ( vw &all,
explore_eval &c,
example &ec,
multi_ex *ec_seq 
)

Definition at line 52 of file explore_eval.cc.

References polyprediction::a_s, polylabel::cb, CB::label::costs, CB::ec_is_example_header(), LEARNER::example_is_newline_not_header(), vw::final_prediction_sink, CB_ALGS::get_cost_estimate(), EXPLORE_EVAL::explore_eval::known_cost, example::l, loss(), example::pred, ACTION_SCORE::print_action_score(), vw::print_text, CB::print_update(), CB::cb_class::probability, vw::raw_prediction, vw::sd, v_array< T >::size(), example::tag, shared_data::update(), and example::weight.

Referenced by output_example_seq().

53 {
54  if (example_is_newline_not_header(ec))
55  return;
56 
57  size_t num_features = 0;
58 
59  float loss = 0.;
60  ACTION_SCORE::action_scores preds = (*ec_seq)[0]->pred.a_s;
61 
62  for (size_t i = 0; i < (*ec_seq).size(); i++)
63  if (!CB::ec_is_example_header(*(*ec_seq)[i]))
64  num_features += (*ec_seq)[i]->num_features;
65 
66  bool labeled_example = true;
67  if (c.known_cost.probability > 0)
68  {
69  for (uint32_t i = 0; i < preds.size(); i++)
70  {
71  float l = get_cost_estimate(&c.known_cost, preds[i].action);
72  loss += l * preds[i].score;
73  }
74  }
75  else
76  labeled_example = false;
77 
78  bool holdout_example = labeled_example;
79  for (size_t i = 0; i < ec_seq->size(); i++) holdout_example &= (*ec_seq)[i]->test_only;
80 
81  all.sd->update(holdout_example, labeled_example, loss, ec.weight, num_features);
82 
83  for (int sink : all.final_prediction_sink) print_action_score(sink, ec.pred.a_s, ec.tag);
84 
85  if (all.raw_prediction > 0)
86  {
87  std::string outputString;
88  std::stringstream outputStringStream(outputString);
89  v_array<CB::cb_class> costs = ec.l.cb.costs;
90 
91  for (size_t i = 0; i < costs.size(); i++)
92  {
93  if (i > 0)
94  outputStringStream << ' ';
95  outputStringStream << costs[i].action << ':' << costs[i].partial_prediction;
96  }
97  all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag);
98  }
99 
100  CB::print_update(all, !labeled_example, ec, ec_seq, true);
101 }
v_array< char > tag
Definition: example.h:63
int raw_prediction
Definition: global_data.h:519
bool example_is_newline_not_header(example const &ec)
Definition: cb_algs.h:80
ACTION_SCORE::action_scores a_s
Definition: example.h:47
bool ec_is_example_header(example const &ec)
Definition: cb.cc:170
CB::label cb
Definition: example.h:31
v_array< int > final_prediction_sink
Definition: global_data.h:518
v_array< cb_class > costs
Definition: cb.h:27
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
float get_cost_estimate(CB::cb_class *observation, uint32_t action, float offset=0.)
Definition: cb_algs.h:58
size_t size() const
Definition: v_array.h:68
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
shared_data * sd
Definition: global_data.h:375
float probability
Definition: cb.h:19
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
polylabel l
Definition: example.h:57
polyprediction pred
Definition: example.h:60
void print_action_score(int f, v_array< action_score > &a_s, v_array< char > &tag)
Definition: action_score.cc:8
float weight
Definition: example.h:62

◆ output_example_seq()

void EXPLORE_EVAL::output_example_seq ( vw &all,
explore_eval &data,
multi_ex &ec_seq 
)

Definition at line 103 of file explore_eval.cc.

References output_example(), vw::print_text, and vw::raw_prediction.

104 {
105  if (ec_seq.size() > 0)
106  {
107  output_example(all, data, **(ec_seq.begin()), &(ec_seq));
108  if (all.raw_prediction > 0)
109  all.print_text(all.raw_prediction, "", ec_seq[0]->tag);
110  }
111 }
int raw_prediction
Definition: global_data.h:519
void output_example(vw &all, explore_eval &c, example &ec, multi_ex *ec_seq)
Definition: explore_eval.cc:52
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522