Vowpal Wabbit
explore_eval.cc
Go to the documentation of this file.
1 #include "reductions.h"
2 #include "cb_algs.h"
3 #include "vw.h"
4 #include "cb_adf.h"
5 #include "rand48.h"
6 #include "gen_cs_example.h"
7 #include <memory>
8 
9 // Do evaluation of nonstationary policies.
10 // input = contextual bandit label
11 // output = chosen ranking
12 
13 using namespace LEARNER;
14 using namespace CB_ALGS;
15 using namespace VW::config;
16 
17 namespace EXPLORE_EVAL
18 {
// Per-instance state of the explore_eval reduction.
// NOTE(review): this extract is missing the `struct explore_eval` header line and
// several member lines (original lines 19, 21, 25-27, 33). The functions in this
// file also read known_cost, action_label, empty_label and fixed_multiplier.
20 {
// Global VW instance; used for trace output and the quiet flag.
22  vw* all;
// Random source shared with the VW instance (set from all.get_random_state() in
// setup); used by do_actual_learning for the accept/reject draw.
23  std::shared_ptr<rand_state> _random_state;
// Passed through to multiline_learn_or_predict. It is not assigned in the
// visible setup code — TODO confirm where (or whether) it is set.
24  uint64_t offset;
28 
// Number of sequences accepted by rejection sampling and used for a learn call.
29  size_t update_count;
// Number of sequences whose acceptance threshold exceeded 1 (+1e-6 tolerance).
30  size_t violations;
// Rejection-sampling multiplier. It is user-fixed via --multiplier; otherwise it
// starts at 1 and is lowered to min(multiplier, 1/threshold) as examples arrive.
31  float multiplier;
32 
34 };
35 
36 void finish(explore_eval& data)
37 {
38  if (!data.all->quiet)
39  {
40  data.all->trace_message << "update count = " << data.update_count << std::endl;
41  if (data.violations > 0)
42  data.all->trace_message << "violation count = " << data.violations << std::endl;
43  if (!data.fixed_multiplier)
44  data.all->trace_message << "final multiplier = " << data.multiplier << std::endl;
45  }
46 }
47 
48 // Semantics: Currently we compute the IPS loss no matter what flags
49 // are specified. We print the first action and probability, based on
50 // ordering by scores in the final output.
51 
52 void output_example(vw& all, explore_eval& c, example& ec, multi_ex* ec_seq)
53 {
55  return;
56 
57  size_t num_features = 0;
58 
59  float loss = 0.;
60  ACTION_SCORE::action_scores preds = (*ec_seq)[0]->pred.a_s;
61 
62  for (size_t i = 0; i < (*ec_seq).size(); i++)
63  if (!CB::ec_is_example_header(*(*ec_seq)[i]))
64  num_features += (*ec_seq)[i]->num_features;
65 
66  bool labeled_example = true;
67  if (c.known_cost.probability > 0)
68  {
69  for (uint32_t i = 0; i < preds.size(); i++)
70  {
71  float l = get_cost_estimate(&c.known_cost, preds[i].action);
72  loss += l * preds[i].score;
73  }
74  }
75  else
76  labeled_example = false;
77 
78  bool holdout_example = labeled_example;
79  for (size_t i = 0; i < ec_seq->size(); i++) holdout_example &= (*ec_seq)[i]->test_only;
80 
81  all.sd->update(holdout_example, labeled_example, loss, ec.weight, num_features);
82 
83  for (int sink : all.final_prediction_sink) print_action_score(sink, ec.pred.a_s, ec.tag);
84 
85  if (all.raw_prediction > 0)
86  {
87  std::string outputString;
88  std::stringstream outputStringStream(outputString);
89  v_array<CB::cb_class> costs = ec.l.cb.costs;
90 
91  for (size_t i = 0; i < costs.size(); i++)
92  {
93  if (i > 0)
94  outputStringStream << ' ';
95  outputStringStream << costs[i].action << ':' << costs[i].partial_prediction;
96  }
97  all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag);
98  }
99 
100  CB::print_update(all, !labeled_example, ec, ec_seq, true);
101 }
102 
103 void output_example_seq(vw& all, explore_eval& data, multi_ex& ec_seq)
104 {
105  if (ec_seq.size() > 0)
106  {
107  output_example(all, data, **(ec_seq.begin()), &(ec_seq));
108  if (all.raw_prediction > 0)
109  all.print_text(all.raw_prediction, "", ec_seq[0]->tag);
110  }
111 }
112 
// Finish a multi-line sequence: write its output (if non-empty), then hand the
// examples back to VW via VW::finish_example.
// NOTE(review): this extract is missing the function header (original line 113)
// and original line 118 inside the if-block. Judging by the cross-reference
// list, line 118 may have been a CB_ADF::global_print_newline call — confirm.
114 {
115  if (ec_seq.size() > 0)
116  {
117  output_example_seq(all, data, ec_seq);
119  }
// Runs even for an empty sequence.
120  VW::finish_example(all, ec_seq);
121 }
122 
// Predict on a multi-line sequence and, when learning, perform a
// rejection-sampled update.
// The sequence is first predicted with its label hidden. The score the
// prediction assigns to the logged action, divided by the logged probability,
// is the acceptance threshold (scaled by a fixed --multiplier if one was
// given). The sequence is then used for a learn call with probability
// `threshold`.
// NOTE(review): the function header (original line 124) is missing from this
// extract.
123 template <bool is_learn>
125 {
126  example* label_example = CB_ADF::test_adf_sequence(ec_seq);
127 
// Hide the label during the first pass so the base produces an unlabeled prediction.
128  if (label_example != nullptr) // extract label
129  {
130  data.action_label = label_example->l.cb;
131  label_example->l.cb = data.empty_label;
132  }
133  multiline_learn_or_predict<false>(base, ec_seq, data.offset);
134 
135  if (label_example != nullptr) // restore label
136  label_example->l.cb = data.action_label;
137 
// Observed cost/action/probability of the sequence; output_example uses it for the loss.
138  data.known_cost = CB_ADF::get_observed_cost(ec_seq);
139  if (label_example != nullptr && is_learn)
140  {
141  ACTION_SCORE::action_scores& a_s = ec_seq[0]->pred.a_s;
142 
// Score the prediction gave to the logged action (0 if it does not appear).
143  float action_probability = 0;
144  for (size_t i = 0; i < a_s.size(); i++)
145  if (data.known_cost.action == a_s[i].action)
146  action_probability = a_s[i].score;
147 
// Acceptance ratio for rejection sampling.
// NOTE(review): divides by known_cost.probability. If that can be 0 here, the
// result is inf/NaN — confirm get_observed_cost guarantees a positive value.
148  float threshold = action_probability / data.known_cost.probability;
149 
// Adaptive mode: shrink the multiplier so that multiplier * threshold <= 1.
// Fixed mode: scale the threshold by the user-supplied multiplier.
150  if (!data.fixed_multiplier)
151  data.multiplier = std::min(data.multiplier, 1 / threshold);
152  else
153  threshold *= data.multiplier;
154 
// A threshold above 1 cannot be honored as a probability; count it.
155  if (threshold > 1. + 1e-6)
156  data.violations++;
157 
// Accept with probability `threshold`, assuming get_and_update_random() draws
// uniformly from [0, 1) — TODO confirm.
158  if (data._random_state->get_and_update_random() < threshold)
159  {
// Find the labeled example. When threshold > 1, also up-weight every example
// to compensate for the capped acceptance probability.
// NOTE(review): ec_found stays nullptr if no example matches the condition
// below (e.g. logged probability 0), and it is dereferenced unconditionally
// afterwards.
160  example* ec_found = nullptr;
161  for (example*& ec : ec_seq)
162  {
163  if (ec->l.cb.costs.size() == 1 && ec->l.cb.costs[0].cost != FLT_MAX && ec->l.cb.costs[0].probability > 0)
164  ec_found = ec;
165  if (threshold > 1)
166  ec->weight *= threshold;
167  }
// Learn as if the logged action had been chosen with the predicted probability.
168  ec_found->l.cb.costs[0].probability = action_probability;
169 
170  multiline_learn_or_predict<true>(base, ec_seq, data.offset);
171 
// Undo the weight scaling. Multiplying by the reciprocal may not restore the
// original weights bit-exactly.
172  if (threshold > 1)
173  {
174  float inv_threshold = 1.f / threshold;
175  for (auto& ec : ec_seq) ec->weight *= inv_threshold;
176  }
// Restore the logged probability on the label.
177  ec_found->l.cb.costs[0].probability = data.known_cost.probability;
178  data.update_count++;
179  }
180  }
181 }
182 } // namespace EXPLORE_EVAL
183 
184 using namespace EXPLORE_EVAL;
185 
// Setup for the explore_eval reduction.
// Parses --explore_eval and --multiplier, and returns nullptr when
// --explore_eval is absent. Otherwise it forces cb_explore_adf into the stack
// beneath this reduction and installs the CB label parser.
// NOTE(review): this extract is missing the function header (original line
// 186) and original lines 214, 216 and 219. Line 216 evidently declared the
// `l` used below from init_learner's result. Per the cross-reference list,
// 214/219 may have set all.label_type and the finish-example callback —
// confirm.
187 {
// NOTE(review): zero-fills a struct that has a std::shared_ptr member. Confirm
// whether scoped_calloc_or_throw also constructs the object.
188  auto data = scoped_calloc_or_throw<explore_eval>();
189  bool explore_eval_option = false;
190  option_group_definition new_options("Explore evaluation");
191  new_options.add(make_option("explore_eval", explore_eval_option).keep().help("Evaluate explore_eval adf policies"))
192  .add(make_option("multiplier", data->multiplier)
193  .help("Multiplier used to make all rejection sample probabilities <= 1"));
194  options.add_and_parse(new_options);
195 
196  if (!explore_eval_option)
197  return nullptr;
198 
199  data->all = &all;
200  data->_random_state = all.get_random_state();
201 
// A user-supplied multiplier stays fixed. Otherwise start at 1 and let
// do_actual_learning lower it adaptively.
202  if (options.was_supplied("multiplier"))
203  data->fixed_multiplier = true;
204  else
205  data->multiplier = 1;
206 
// explore_eval sits on top of cb_explore_adf; enable it if the user did not.
207  if (!options.was_supplied("cb_explore_adf"))
208  options.insert("cb_explore_adf", "");
209 
// NOTE(review): clears the global prediction deleter. The reason is not visible here.
210  all.delete_prediction = nullptr;
211 
212  multi_learner* base = as_multiline(setup_base(options, all));
213  all.p->lp = CB::cb_label;
215 
// Learn and predict entry points are the two instantiations of
// do_actual_learning; predictions are action probabilities.
217  init_learner(data, base, do_actual_learning<true>, do_actual_learning<false>, 1, prediction_type::action_probs);
218 
220  l.set_finish(finish);
221  return make_base(l);
222 }
v_array< char > tag
Definition: example.h:63
int raw_prediction
Definition: global_data.h:519
ACTION_SCORE::action_scores a_s
Definition: example.h:47
bool ec_is_example_header(example const &ec)
Definition: cb.cc:170
base_learner * explore_eval_setup(options_i &options, vw &all)
void(* delete_prediction)(void *)
Definition: global_data.h:485
void output_example_seq(vw &all, multi_ex &ec_seq)
Definition: cbify.cc:356
void finish_multiline_example(vw &all, cbify &, multi_ex &ec_seq)
Definition: cbify.cc:373
CB::label cb
Definition: example.h:31
void output_example(vw &all, explore_eval &c, example &ec, multi_ex *ec_seq)
Definition: explore_eval.cc:52
label_type::label_type_t label_type
Definition: global_data.h:550
v_array< int > final_prediction_sink
Definition: global_data.h:518
v_array< cb_class > costs
Definition: cb.h:27
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
example * test_adf_sequence(multi_ex &ec_seq)
Definition: cb_adf.cc:268
bool quiet
Definition: global_data.h:487
virtual void add_and_parse(const option_group_definition &group)=0
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
float get_cost_estimate(CB::cb_class *observation, uint32_t action, float offset=0.)
Definition: cb_algs.h:58
size_t size() const
Definition: v_array.h:68
std::shared_ptr< rand_state > _random_state
Definition: explore_eval.cc:23
CB::cb_class get_observed_cost(multi_ex &examples)
Definition: cb_adf.cc:99
parser * p
Definition: global_data.h:377
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
uint32_t action
Definition: cb.h:18
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
shared_data * sd
Definition: global_data.h:375
float probability
Definition: cb.h:19
vw_ostream trace_message
Definition: global_data.h:424
virtual bool was_supplied(const std::string &key)=0
void global_print_newline(const v_array< int > &final_prediction_sink)
Definition: cb_adf.cc:342
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
void finish_example(vw &, example &)
Definition: parser.cc:881
virtual void insert(const std::string &key, const std::string &value)=0
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
option_group_definition & add(T &&op)
Definition: options.h:90
std::vector< example * > multi_ex
Definition: example.h:122
label_parser cb_label
Definition: cb.cc:167
polylabel l
Definition: example.h:57
void do_actual_learning(explore_eval &data, multi_learner &base, multi_ex &ec_seq)
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
bool example_is_newline_not_header(example &ec, vw &all)
Definition: learner.cc:68
Definition: cb.h:25
void set_finish(void(*f)(T &))
Definition: learner.h:265
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
void finish(explore_eval &data)
Definition: explore_eval.cc:36
polyprediction pred
Definition: example.h:60
void print_action_score(int f, v_array< action_score > &a_s, v_array< char > &tag)
Definition: action_score.cc:8
float weight
Definition: example.h:62
constexpr uint64_t c
Definition: rand48.cc:12
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
label_parser lp
Definition: parser.h:102