Vowpal Wabbit
cb_explore_adf_common.h
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD
4 license as described in the file LICENSE.
5 */
6 #pragma once
7 #include <stdint.h>
8 #include <algorithm>
9 
10 // Most of these includes are required because templated functions are using the objects defined in them
11 // A few options to get rid of them:
12 // - Use virtual function calls in predict/learn to get rid of the templates entirely (con: virtual function calls)
13 // - Cut out the portions of code that actually use the objects and put them into new functions
14 // defined in the cc file (con: can't inline those functions)
15 // - templatize all input parameters (con: no type safety)
16 #include "v_array.h" // required by action_score.h
17 #include "action_score.h" // used in sort_action_probs
18 #include "cb.h" // required for CB::label
19 #include "cb_adf.h" // used for function call in predict/learn
20 #include "example.h" // used in predict
21 #include "gen_cs_example.h" // required for GEN_CS::cb_to_cs_adf
22 #include "reductions_fwd.h"
23 
24 namespace VW
25 {
26 namespace cb_explore_adf
27 {
28 // Free functions
29 inline void sort_action_probs(v_array<ACTION_SCORE::action_score>& probs, const std::vector<float>& scores)
30 {
31  // We want to preserve the score order in the returned action_probs if possible. To do this,
32  // sort top_actions and action_probs by the order induced in scores.
33  std::sort(probs.begin(), probs.end(),
34  [&scores](const ACTION_SCORE::action_score& as1, const ACTION_SCORE::action_score& as2) {
35  if (as1.score > as2.score)
36  return true;
37  else if (as1.score < as2.score)
38  return false;
39  // equal probabilities
40  if (scores[as1.action] < scores[as2.action])
41  return true;
42  else if (scores[as1.action] > scores[as2.action])
43  return false;
44  // equal probabilities and equal cost estimates
45  return as1.action < as2.action;
46  });
47 }
49 {
50  if (preds.size() == 0)
51  return 0;
52  size_t ret = 1;
53  for (size_t i = 1; i < preds.size(); ++i)
54  if (preds[i].score == preds[0].score)
55  ++ret;
56  else
57  return ret;
58  return ret;
59 }
60 
61 // Object
62 template <typename ExploreType>
63 // data common to all cb_explore_adf reductions
65 {
66  private:
68  // used in output_example
71  ExploreType explore;
72 
73  public:
74  template <typename... Args>
75  cb_explore_adf_base(Args&&... args) : explore(std::forward<Args>(args)...)
76  {
77  }
78  static void finish_multiline_example(vw& all, cb_explore_adf_base<ExploreType>& data, multi_ex& ec_seq);
79  static void predict(cb_explore_adf_base<ExploreType>& data, LEARNER::multi_learner& base, multi_ex& examples);
80  static void learn(cb_explore_adf_base<ExploreType>& data, LEARNER::multi_learner& base, multi_ex& examples);
81 
82  private:
83  void output_example_seq(vw& all, multi_ex& ec_seq);
84  void output_example(vw& all, multi_ex& ec_seq);
85 };
86 
87 template <typename ExploreType>
90 {
91  example* label_example = CB_ADF::test_adf_sequence(examples);
92  data._known_cost = CB_ADF::get_observed_cost(examples);
93 
94  if (label_example != nullptr)
95  {
96  // predict path, replace the label example with an empty one
97  data._action_label = label_example->l.cb;
98  label_example->l.cb = data._empty_label;
99  }
100 
101  data.explore.predict(base, examples);
102 
103  if (label_example != nullptr)
104  {
105  // predict path, restore label
106  label_example->l.cb = data._action_label;
107  }
108 }
109 
110 template <typename ExploreType>
113 {
114  example* label_example = CB_ADF::test_adf_sequence(examples);
115  if (label_example != nullptr)
116  {
117  data._known_cost = CB_ADF::get_observed_cost(examples);
118  // learn iff label_example != nullptr
119  data.explore.learn(base, examples);
120  }
121  else
122  {
123  predict(data, base, examples);
124  }
125 }
126 
127 template <typename ExploreType>
129 {
130  if (ec_seq.size() <= 0)
131  return;
132 
133  size_t num_features = 0;
134 
135  float loss = 0.;
136 
137  auto& ec = *ec_seq[0];
138  ACTION_SCORE::action_scores preds = ec.pred.a_s;
139 
140  for (const auto& example : ec_seq)
141  {
142  num_features += example->num_features;
143  }
144 
145  bool labeled_example = true;
146  if (_known_cost.probability > 0)
147  {
148  for (uint32_t i = 0; i < preds.size(); i++)
149  {
150  float l = CB_ALGS::get_cost_estimate(&_known_cost, preds[i].action);
151  loss += l * preds[i].score;
152  }
153  }
154  else
155  labeled_example = false;
156 
157  bool holdout_example = labeled_example;
158  for (size_t i = 0; i < ec_seq.size(); i++) holdout_example &= ec_seq[i]->test_only;
159 
160  all.sd->update(holdout_example, labeled_example, loss, ec.weight, num_features);
161 
162  for (auto sink : all.final_prediction_sink) ACTION_SCORE::print_action_score(sink, ec.pred.a_s, ec.tag);
163 
164  if (all.raw_prediction > 0)
165  {
166  std::string outputString;
167  std::stringstream outputStringStream(outputString);
168  v_array<CB::cb_class> costs = ec.l.cb.costs;
169 
170  for (size_t i = 0; i < costs.size(); i++)
171  {
172  if (i > 0)
173  outputStringStream << ' ';
174  outputStringStream << costs[i].action << ':' << costs[i].partial_prediction;
175  }
176  all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag);
177  }
178 
179  CB::print_update(all, !labeled_example, ec, &ec_seq, true);
180 }
181 
182 template <typename ExploreType>
184 {
185  if (ec_seq.size() > 0)
186  {
187  output_example(all, ec_seq);
188  if (all.raw_prediction > 0)
189  all.print_text(all.raw_prediction, "", ec_seq[0]->tag);
190  }
191 }
192 
193 template <typename ExploreType>
195  vw& all, cb_explore_adf_base<ExploreType>& data, multi_ex& ec_seq)
196 {
197  if (ec_seq.size() > 0)
198  {
199  data.output_example_seq(all, ec_seq);
201  }
202 
203  VW::finish_example(all, ec_seq);
204 }
205 } // namespace cb_explore_adf
206 } // namespace VW
int raw_prediction
Definition: global_data.h:519
static void learn(cb_explore_adf_base< ExploreType > &data, LEARNER::multi_learner &base, multi_ex &examples)
CB::label cb
Definition: example.h:31
v_array< int > final_prediction_sink
Definition: global_data.h:518
example * test_adf_sequence(multi_ex &ec_seq)
Definition: cb_adf.cc:268
uint32_t action
Definition: search.h:19
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
float get_cost_estimate(CB::cb_class *observation, uint32_t action, float offset=0.)
Definition: cb_algs.h:58
T *& begin()
Definition: v_array.h:42
size_t size() const
Definition: v_array.h:68
CB::cb_class get_observed_cost(multi_ex &examples)
Definition: cb_adf.cc:99
size_t fill_tied(v_array< ACTION_SCORE::action_score > &preds)
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
shared_data * sd
Definition: global_data.h:375
float probability
Definition: cb.h:19
size_t num_features
Definition: example.h:67
void global_print_newline(const v_array< int > &final_prediction_sink)
Definition: cb_adf.cc:342
void output_example(vw &all, multi_ex &ec_seq)
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
void finish_example(vw &, example &)
Definition: parser.cc:881
T *& end()
Definition: v_array.h:43
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
std::vector< example * > multi_ex
Definition: example.h:122
polylabel l
Definition: example.h:57
Definition: cb.h:25
Definition: autolink.cc:11
static void finish_multiline_example(vw &all, cb_explore_adf_base< ExploreType > &data, multi_ex &ec_seq)
static void predict(cb_explore_adf_base< ExploreType > &data, LEARNER::multi_learner &base, multi_ex &examples)
void sort_action_probs(v_array< ACTION_SCORE::action_score > &probs, const std::vector< float > &scores)
void print_action_score(int f, v_array< action_score > &a_s, v_array< char > &tag)
Definition: action_score.cc:8
void output_example_seq(vw &all, multi_ex &ec_seq)