Vowpal Wabbit
cb_explore.cc
Go to the documentation of this file.
1 #include "reductions.h"
2 #include "cb_algs.h"
3 #include "rand48.h"
4 #include "bs.h"
5 #include "gen_cs_example.h"
6 #include "explore.h"
7 #include <memory>
8 
9 using namespace LEARNER;
10 using namespace ACTION_SCORE;
11 using namespace GEN_CS;
12 using namespace CB_ALGS;
13 using namespace exploration;
14 using namespace VW::config;
15 // All exploration algorithms return a vector of probabilities, to be used by GenericExplorer downstream
16 
17 namespace CB_EXPLORE
18 {
// State shared by the cb_explore exploration strategies defined in this file.
// NOTE(review): other code in this file also reads data.cbcs, data.preds, data.cover_probs,
// data.cb_label, data.cs_label, data.second_cs_label and data.cs; their declarations are not
// among the lines visible here.
19 struct cb_explore
20 {
  // Random source: assigned from all.get_random_state() during setup and passed to
  // BS::weight_gen by the bagging strategy.
21  std::shared_ptr<rand_state> _random_state;
25 
29 
31 
  // "first" option: number of remaining rounds that return a uniform distribution;
  // decremented once per call while it is positive.
32  size_t tau;
  // "epsilon" option (registered with default 0.05).
33  float epsilon;
  // "bag" option: number of base predictors the bagging strategy iterates over.
34  size_t bag_size;
  // "cover" option: number of cover policies queried and updated.
35  size_t cover_size;
  // "psi" option (registered with default 1.0): scales the min_prob-based term in the cover pseudo-cost.
36  float psi;
37 
  // Incremented once per get_cover_probabilities call; feeds the 1/sqrt(counter * num_actions) bound.
38  size_t counter;
39 
  // NOTE(review): the line introducing this body is not visible here — presumably the
  // destructor; confirm. It releases the v_arrays and the second cost-sensitive label.
41  {
42  preds.delete_v();
43  cover_probs.delete_v();
46  COST_SENSITIVE::cs_label.delete_label(&second_cs_label);
47  }
48 };
49 
// Tau-first exploration: while data.tau is positive, output a uniform distribution over
// data.cbcs.num_actions and consume one unit of tau; afterwards put all probability on the
// action predicted by the base learner.
// NOTE(review): the function's signature line is not among the visible lines; the body uses
// `data`, `base` and `ec`.
50 template <bool is_learn>
52 {
53  // Explore tau times, then act according to optimal.
54  action_scores probs = ec.pred.a_s;
55 
  // Only learn when the first logged cost carries a probability below 1; otherwise just predict.
56  if (is_learn && ec.l.cb.costs[0].probability < 1)
57  base.learn(ec);
58  else
59  base.predict(ec);
60 
61  probs.clear();
62  if (data.tau > 0)
63  {
  // Still exploring: every action gets 1 / num_actions.
64  float prob = 1.f / (float)data.cbcs.num_actions;
65  for (uint32_t i = 0; i < data.cbcs.num_actions; i++) probs.push_back({i, prob});
66  data.tau--;
67  }
68  else
69  {
  // Exploit: the base prediction is shifted down by one to index into probs.
70  uint32_t chosen = ec.pred.multiclass - 1;
71  for (uint32_t i = 0; i < data.cbcs.num_actions; i++) probs.push_back({i, 0.});
72  probs[chosen].score = 1.0;
73  }
74 
  // Publish the distribution as this example's prediction.
75  ec.pred.a_s = probs;
76 }
77 
// Epsilon-greedy exploration entry point: runs the base learner (learn or predict depending on
// is_learn) and writes a num_actions-long action/score list into ec.pred.a_s.
// NOTE(review): the function's signature line is not among the visible lines; the body uses
// `data`, `base` and `ec`.
78 template <bool is_learn>
80 {
81  // Explore uniform random an epsilon fraction of the time.
82  // TODO: pointers are copied here. What happens if base.learn/base.predict re-allocs?
83  // ec.pred.a_s = probs; will then restore the freed memory
84  action_scores probs = ec.pred.a_s;
85  probs.clear();
86 
87  if (is_learn)
88  base.learn(ec);
89  else
90  base.predict(ec);
91 
92  // pre-allocate pdf
  // assumes v_array::resize only reserves storage, so the push_back calls fill indices
  // 0..num_actions-1 — TODO confirm against v_array.h
93  probs.resize(data.cbcs.num_actions);
94  for (uint32_t i = 0; i < data.cbcs.num_actions; i++) probs.push_back({i, 0});
  // NOTE(review): as visible here every score is still 0 when it is published below; the step
  // that spreads epsilon over the actions (presumably exploration::generate_epsilon_greedy)
  // is not among the visible lines — confirm.
96 
97  ec.pred.a_s = probs;
98 }
99 
// Bagging-based exploration: queries data.bag_size sub-learners of `base` and gives each
// one's chosen action an equal 1 / bag_size share of the probability mass.
// NOTE(review): the function's signature line is not among the visible lines; the body uses
// `data`, `base` and `ec`.
100 template <bool is_learn>
102 {
103  // Randomize over predictions from a base set of predictors
104  action_scores probs = ec.pred.a_s;
105  probs.clear();
106 
  // Start from an all-zero score for every action.
107  for (uint32_t i = 0; i < data.cbcs.num_actions; i++) probs.push_back({i, 0.});
108  float prob = 1.f / (float)data.bag_size;
109  for (size_t i = 0; i < data.bag_size; i++)
110  {
  // count is how many times sub-learner i trains on this example.
111  uint32_t count = BS::weight_gen(data._random_state);
  // The first learn call (or a plain predict when count is 0 or is_learn is false) also
  // yields the prediction that is tallied below.
112  if (is_learn && count > 0)
113  base.learn(ec, i);
114  else
115  base.predict(ec, i);
116  uint32_t chosen = ec.pred.multiclass - 1;
117  probs[chosen].score += prob;
  // Apply the remaining count - 1 updates after the prediction has been recorded.
118  if (is_learn)
119  for (uint32_t j = 1; j < count; j++) base.learn(ec, i);
120  }
121 
122  ec.pred.a_s = probs;
123 }
124 
// Builds the cover distribution in `probs`: each of the data.cover_size policies reached via
// data.cs votes 1 / cover_size for its predicted action, the predictions are recorded in
// data.preds, a minimum probability is enforced, and data.counter is advanced.
// NOTE(review): the function's signature line is not among the visible lines; the body uses
// `data`, `ec` and `probs`.
126 {
127  float additive_probability = 1.f / (float)data.cover_size;
128  data.preds.clear();
129 
130  for (uint32_t i = 0; i < data.cbcs.num_actions; i++) probs.push_back({i, 0.});
131 
132  for (size_t i = 0; i < data.cover_size; i++)
133  {
134  // get predicted cost-sensitive predictions
  // Policy 0 uses offset 0 and policy i > 0 uses offset i + 1, so offset 1 is never
  // queried here — presumably reserved for another learner; confirm.
135  if (i == 0)
136  data.cs->predict(ec, i);
137  else
138  data.cs->predict(ec, i + 1);
139  uint32_t pred = ec.pred.multiclass;
  // Predictions are shifted down by one to index into probs.
140  probs[pred - 1].score += additive_probability;
141  data.preds.push_back((uint32_t)pred);
142  }
143  uint32_t num_actions = data.cbcs.num_actions;
144 
  // Per-action floor: the smaller of 1 / num_actions and 1 / sqrt(counter * num_actions).
  // NOTE(review): with counter == 0 the second term divides by zero; presumably this yields
  // +inf so the first term wins — confirm.
145  float min_prob = std::min(1.f / num_actions, 1.f / (float)std::sqrt(data.counter * num_actions));
146 
  // The first argument is the total uniform mass (min_prob per action times num_actions).
147  enforce_minimum_probability(min_prob * num_actions, false, begin_scores(probs), end_scores(probs));
148 
149  data.counter++;
150 }
151 
// Online-cover exploration: obtains the action distribution from get_cover_probabilities and,
// when is_learn is set, updates the base learner and then the cover policies reached through
// data.cs using pseudo-costs that reward actions the cover has not yet put mass on.
// NOTE(review): the function's signature line is not among the visible lines; the body uses
// `data`, `base` and `ec`.
152 template <bool is_learn>
154 {
155  // Randomize over predictions from a base set of predictors
156  // Use cost sensitive oracle to cover actions to form distribution.
157 
158  uint32_t num_actions = data.cbcs.num_actions;
159 
160  action_scores probs = ec.pred.a_s;
161  probs.clear();
162  data.cs_label.costs.clear();
163 
  // One placeholder entry per action, carrying FLT_MAX and the 1-based class index j + 1.
164  for (uint32_t j = 0; j < num_actions; j++) data.cs_label.costs.push_back({FLT_MAX, j + 1, 0., 0.});
165 
166  size_t cover_size = data.cover_size;
  // Snapshot taken before get_cover_probabilities increments data.counter, so min_prob below
  // is computed from the same counter value that function uses.
167  size_t counter = data.counter;
168  v_array<float>& probabilities = data.cover_probs;
169  v_array<uint32_t>& predictions = data.preds;
170 
171  float additive_probability = 1.f / (float)cover_size;
172 
173  float min_prob = std::min(1.f / num_actions, 1.f / (float)std::sqrt(counter * num_actions));
174 
  // Save the contextual-bandit label, then present the cost-sensitive label to the policies.
175  data.cb_label = ec.l.cb;
176 
177  ec.l.cs = data.cs_label;
178  get_cover_probabilities(data, base, ec, probs);
179 
180  if (is_learn)
181  {
  // Restore the contextual-bandit label before updating the base learner.
182  ec.l.cb = data.cb_label;
183  base.learn(ec);
184 
185  // Now update oracles
186 
187  // 1. Compute loss vector
188  data.cs_label.costs.clear();
189  float norm = min_prob * num_actions;
190  ec.l.cb = data.cb_label;
  // NOTE(review): the listing skips a line at this point, so a statement may be missing
  // from the visible code — confirm against the original file.
192  gen_cs_example<false>(data.cbcs, ec, data.cb_label, data.cs_label);
  // Reset the running per-action mass accumulated over the policies processed so far.
193  for (uint32_t i = 0; i < num_actions; i++) probabilities[i] = 0;
194 
195  ec.l.cs = data.second_cs_label;
196  // 2. Update functions
197  for (size_t i = 0; i < cover_size; i++)
198  {
199  // Create costs of each action based on online cover
200  for (uint32_t j = 0; j < num_actions; j++)
201  {
  // Estimated cost minus a psi-scaled bonus that shrinks as action j's accumulated
  // (floored at min_prob, normalized by norm) probability grows, plus 1.
202  float pseudo_cost =
203  data.cs_label.costs[j].x - data.psi * min_prob / (std::max(probabilities[j], min_prob) / norm) + 1;
204  data.second_cs_label.costs[j].class_index = j + 1;
205  data.second_cs_label.costs[j].x = pseudo_cost;
206  }
  // Policy 0 is not updated in this loop; policy i > 0 is trained at offset i + 1,
  // matching the offsets used when predicting.
207  if (i != 0)
208  data.cs->learn(ec, i + 1);
  // Fold policy i's recorded prediction into the running mass: norm only grows by the
  // part of additive_probability that lifts the action above the min_prob floor.
209  if (probabilities[predictions[i] - 1] < min_prob)
210  norm += std::max(0.f, additive_probability - (min_prob - probabilities[predictions[i] - 1]));
211  else
212  norm += additive_probability;
213  probabilities[predictions[i] - 1] += additive_probability;
214  }
215  }
216 
  // Leave the example with its original contextual-bandit label and the new distribution.
217  ec.l.cb = data.cb_label;
218  ec.pred.a_s = probs;
219 }
220 
221 void print_update_cb_explore(vw& all, bool is_test, example& ec, std::stringstream& pred_string)
222 {
223  if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs)
224  {
225  std::stringstream label_string;
226  if (is_test)
227  label_string << " unknown";
228  else
229  label_string << ec.l.cb.costs[0].action;
230  all.sd->print_update(all.holdout_set_off, all.current_pass, label_string.str(), pred_string.str(), ec.num_features,
231  all.progress_add, all.progress_arg);
232  }
233 }
234 
235 void output_example(vw& all, cb_explore& data, example& ec, CB::label& ld)
236 {
237  float loss = 0.;
238 
239  cb_to_cs& c = data.cbcs;
240 
241  if ((c.known_cost = get_observed_cost(ld)) != nullptr)
242  for (uint32_t i = 0; i < ec.pred.a_s.size(); i++)
243  loss += get_cost_estimate(c.known_cost, c.pred_scores, i + 1) * ec.pred.a_s[i].score;
244 
245  all.sd->update(ec.test_only, get_observed_cost(ld) != nullptr, loss, 1.f, ec.num_features);
246 
247  std::stringstream ss;
248  float maxprob = 0.;
249  uint32_t maxid = 0;
250  for (uint32_t i = 0; i < ec.pred.a_s.size(); i++)
251  {
252  ss << std::fixed << ec.pred.a_s[i].score << " ";
253  if (ec.pred.a_s[i].score > maxprob)
254  {
255  maxprob = ec.pred.a_s[i].score;
256  maxid = i + 1;
257  }
258  }
259  for (int sink : all.final_prediction_sink) all.print_text(sink, ss.str(), ec.tag);
260 
261  std::stringstream sso;
262  sso << maxid << ":" << std::fixed << maxprob;
263  print_update_cb_explore(all, CB::cb_label.test_label(&ld), ec, sso);
264 }
265 
// Finish hook for this reduction: reports the example via output_example using its own
// contextual-bandit label, then hands it to VW::finish_example.
// NOTE(review): the function's signature line is not among the visible lines; the body uses
// `all`, `c` and `ec`.
267 {
268  output_example(all, c, ec, ec.l.cb);
269  VW::finish_example(all, ec);
270 }
271 } // namespace CB_EXPLORE
272 using namespace CB_EXPLORE;
273 
// Setup for the cb_explore reduction: registers and parses its options, returns nullptr when
// "cb_explore" was not supplied, otherwise builds the base learner stack and wraps it in the
// learner for the selected exploration strategy.
// NOTE(review): the function's signature line and several statement/continuation lines are
// not among the visible lines (e.g. the declaration of `l` and the tails of two init_learner
// calls); the body uses `options` and `all`.
275 {
276  auto data = scoped_calloc_or_throw<cb_explore>();
277  option_group_definition new_options("Contextual Bandit Exploration");
278  new_options
279  .add(make_option("cb_explore", data->cbcs.num_actions)
280  .keep()
281  .help("Online explore-exploit for a <k> action contextual bandit problem"))
282  .add(make_option("first", data->tau).keep().help("tau-first exploration"))
283  .add(make_option("epsilon", data->epsilon).keep().default_value(0.05f).help("epsilon-greedy exploration"))
284  .add(make_option("bag", data->bag_size).keep().help("bagging-based exploration"))
285  .add(make_option("cover", data->cover_size).keep().help("Online cover based exploration"))
286  .add(make_option("psi", data->psi).keep().default_value(1.0f).help("disagreement parameter for cover"));
287  options.add_and_parse(new_options);
288 
  // This reduction is only active when --cb_explore was given.
289  if (!options.was_supplied("cb_explore"))
290  return nullptr;
291 
292  data->_random_state = all.get_random_state();
293  uint32_t num_actions = data->cbcs.num_actions;
294 
  // Make sure the "cb" option is present, using the same number of actions.
295  if (!options.was_supplied("cb"))
296  {
297  std::stringstream ss;
298  ss << data->cbcs.num_actions;
299  options.insert("cb", ss.str());
300  }
301 
303  data->cbcs.cb_type = CB_TYPE_DR;
304 
305  single_learner* base = as_singleline(setup_base(options, all));
306  data->cbcs.scorer = all.scorer;
307 
  // Strategy precedence: cover, then bag, then first, otherwise epsilon-greedy.
309  if (options.was_supplied("cover"))
310  {
  // Size the cover buffers up front: num_actions pseudo-costs and probabilities,
  // cover_size recorded predictions.
312  data->second_cs_label.costs.resize(num_actions);
  // Moves the costs array's end marker to begin + num_actions so entries 0..num_actions-1
  // can be written by index.
313  data->second_cs_label.costs.end() = data->second_cs_label.costs.begin() + num_actions;
314  data->cover_probs = v_init<float>();
315  data->cover_probs.resize(num_actions);
316  data->preds = v_init<uint32_t>();
317  data->preds.resize(data->cover_size);
  // cover_size + 1 is passed as init_learner's ws argument.
318  l = &init_learner(data, base, predict_or_learn_cover<true>, predict_or_learn_cover<false>, data->cover_size + 1,
320  }
321  else if (options.was_supplied("bag"))
322  l = &init_learner(data, base, predict_or_learn_bag<true>, predict_or_learn_bag<false>, data->bag_size,
324  else if (options.was_supplied("first"))
325  l = &init_learner(
326  data, base, predict_or_learn_first<true>, predict_or_learn_first<false>, 1, prediction_type::action_probs);
327  else // greedy
328  l = &init_learner(
329  data, base, predict_or_learn_greedy<true>, predict_or_learn_greedy<false>, 1, prediction_type::action_probs);
330 
332  return make_base(*l);
333 }
void resize(size_t length)
Definition: v_array.h:69
v_array< char > tag
Definition: example.h:63
uint32_t multiclass
Definition: example.h:49
ACTION_SCORE::action_scores a_s
Definition: example.h:47
COST_SENSITIVE::label pred_scores
void predict(E &ec, size_t i=0)
Definition: learner.h:169
uint32_t weight_gen(std::shared_ptr< rand_state > &state)
Definition: bs.h:17
LEARNER::base_learner * cost_sensitive
Definition: global_data.h:385
void(* delete_prediction)(void *)
Definition: global_data.h:485
label_parser cs_label
void(* delete_label)(void *)
Definition: label_parser.h:16
COST_SENSITIVE::label second_cs_label
Definition: cb_explore.cc:28
COST_SENSITIVE::label cs_label
Definition: cb_explore.cc:27
CB::label cb
Definition: example.h:31
void predict_or_learn_greedy(cb_explore &data, single_learner &base, example &ec)
Definition: cb_explore.cc:79
v_array< int > final_prediction_sink
Definition: global_data.h:518
v_array< cb_class > costs
Definition: cb.h:27
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
bool quiet
Definition: global_data.h:487
virtual void add_and_parse(const option_group_definition &group)=0
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
float get_cost_estimate(CB::cb_class *observation, uint32_t action, float offset=0.)
Definition: cb_algs.h:58
bool holdout_set_off
Definition: global_data.h:499
void print_update_cb_explore(vw &all, bool is_test, example &ec, std::stringstream &pred_string)
Definition: cb_explore.cc:221
bool progress_add
Definition: global_data.h:545
#define CB_TYPE_DR
Definition: cb_algs.h:13
size_t size() const
Definition: v_array.h:68
score_iterator begin_scores(action_scores &a_s)
Definition: action_score.h:43
CB::cb_class get_observed_cost(multi_ex &examples)
Definition: cb_adf.cc:99
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
score_iterator end_scores(action_scores &a_s)
Definition: action_score.h:45
int generate_epsilon_greedy(float epsilon, uint32_t top_action, It pdf_first, It pdf_last)
Generates epsilon-greedy style exploration distribution.
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
base_learner * cb_explore_setup(options_i &options, vw &all)
Definition: cb_explore.cc:274
CB::cb_class * known_cost
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void delete_action_scores(void *v)
Definition: action_score.cc:29
void push_back(const T &new_ele)
Definition: v_array.h:107
COST_SENSITIVE::label cs
Definition: example.h:30
shared_data * sd
Definition: global_data.h:375
float progress_arg
Definition: global_data.h:546
void clear()
Definition: v_array.h:88
std::shared_ptr< rand_state > _random_state
Definition: cb_explore.cc:21
void print_update(bool holdout_set_off, size_t current_pass, float label, float prediction, size_t num_features, bool progress_add, float progress_arg)
Definition: global_data.h:225
bool bfgs
Definition: global_data.h:412
size_t num_features
Definition: example.h:67
virtual bool was_supplied(const std::string &key)=0
void predict_or_learn_cover(cb_explore &data, single_learner &base, example &ec)
Definition: cb_explore.cc:153
uint32_t num_actions
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
learner< cb_explore, example > * cs
Definition: cb_explore.cc:30
int enforce_minimum_probability(float minimum_uniform, bool update_zero_elements, It pdf_first, It pdf_last)
Updates the pdf to ensure each action is explored with at least minimum_uniform/num_actions.
uint64_t current_pass
Definition: global_data.h:396
void output_example(vw &all, cb_explore &data, example &ec, CB::label &ld)
Definition: cb_explore.cc:235
v_array< float > cover_probs
Definition: cb_explore.cc:24
void finish_example(vw &, example &)
Definition: parser.cc:881
LEARNER::single_learner * scorer
Definition: global_data.h:384
virtual void insert(const std::string &key, const std::string &value)=0
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
label_parser cb_label
Definition: cb.cc:167
polylabel l
Definition: example.h:57
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
Definition: cb.h:25
void get_cover_probabilities(cb_explore &data, single_learner &, example &ec, v_array< action_score > &probs)
Definition: cb_explore.cc:125
v_array< uint32_t > preds
Definition: cb_explore.cc:23
bool test_label(void *v)
Definition: simple_label.cc:70
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
polyprediction pred
Definition: example.h:60
void predict_or_learn_first(cb_explore &data, single_learner &base, example &ec)
Definition: cb_explore.cc:51
void delete_v()
Definition: v_array.h:98
void learn(E &ec, size_t i=0)
Definition: learner.h:160
v_array< wclass > costs
void finish_example(vw &all, cb_explore &c, example &ec)
Definition: cb_explore.cc:266
double weighted_examples()
Definition: global_data.h:188
float dump_interval
Definition: global_data.h:147
constexpr uint64_t c
Definition: rand48.cc:12
void predict_or_learn_bag(cb_explore &data, single_learner &base, example &ec)
Definition: cb_explore.cc:101
float f
Definition: cache.cc:40
bool test_only
Definition: example.h:76