Vowpal Wabbit
cb_explore_adf_regcb.cc
#include "cb_explore_adf_regcb.h"
#include "reductions.h"
#include "cb_adf.h"
#include "rand48.h"
#include "bs.h"
#include "gen_cs_example.h"
#include "cb_explore.h"
#include "explore.h"
#include "action_score.h"
#include "cb.h"
#include <vector>
#include <algorithm>
#include <cmath>

// All exploration algorithms return a vector of (action id, probability) tuples, sorted in order of scores. The
// probabilities are the probability with which each action should be chosen (i.e. swapped to the top of the list).
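// For example, with three actions the first example's prediction might end up as
//   pred.a_s = { {action=2, score=0.85}, {action=0, score=0.10}, {action=1, score=0.05} }
// i.e. the scores form a probability distribution over the actions (illustrative values only).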

#define B_SEARCH_MAX_ITER 20

namespace VW
{
namespace cb_explore_adf
{
namespace regcb
{
struct cb_explore_adf_regcb
{
 private:
  size_t _counter;
  bool _regcbopt;  // use optimistic variant of RegCB
  float _c0;       // mellowness parameter for RegCB
  bool _first_only;
  float _min_cb_cost;
  float _max_cb_cost;

  std::vector<float> _min_costs;
  std::vector<float> _max_costs;

  // for backing up cb example data when computing sensitivities
  std::vector<ACTION_SCORE::action_scores> _ex_as;
  std::vector<v_array<CB::cb_class>> _ex_costs;

 public:
  cb_explore_adf_regcb(bool regcbopt, float c0, bool first_only, float min_cb_cost, float max_cb_cost);
  ~cb_explore_adf_regcb() = default;

  // Should be called through cb_explore_adf_base for pre/post-processing
  void predict(LEARNER::multi_learner& base, multi_ex& examples) { predict_or_learn_impl<false>(base, examples); }
  void learn(LEARNER::multi_learner& base, multi_ex& examples) { predict_or_learn_impl<true>(base, examples); }

 private:
  template <bool is_learn>
  void predict_or_learn_impl(LEARNER::multi_learner& base, multi_ex& examples);

  void get_cost_ranges(float delta, LEARNER::multi_learner& base, multi_ex& examples, bool min_only);
  float binary_search(float fhat, float delta, float sens, float tol = 1e-6);
};

cb_explore_adf_regcb::cb_explore_adf_regcb(
    bool regcbopt, float c0, bool first_only, float min_cb_cost, float max_cb_cost)
    : _regcbopt(regcbopt), _c0(c0), _first_only(first_only), _min_cb_cost(min_cb_cost), _max_cb_cost(max_cb_cost)
{
}

// TODO: same as cs_active.cc, move to shared place
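// binary_search: given a prediction fhat and the model's sensitivity sens (change in prediction per unit of
// importance weight), bisect for the largest importance weight w whose implied squared-loss change
// w * (fhat^2 - (fhat - sens * w)^2) stays within the confidence budget delta.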
float cb_explore_adf_regcb::binary_search(float fhat, float delta, float sens, float tol)
{
  const float maxw = (std::min)(fhat / sens, FLT_MAX);

  if (maxw * fhat * fhat <= delta)
    return maxw;

  float l = 0;
  float u = maxw;
  float w, v;

  for (int iter = 0; iter < B_SEARCH_MAX_ITER; iter++)
  {
    w = (u + l) / 2.f;
    v = w * (fhat * fhat - (fhat - sens * w) * (fhat - sens * w)) - delta;
    if (v > 0)
      u = w;
    else
      l = w;
    if (fabs(v) <= tol || u - l <= tol)
      break;
  }

  return l;
}

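// get_cost_ranges: compute per-action confidence bounds [_min_costs[a], _max_costs[a]] on the expected cost.
// Each bound comes from relabeling the action's example near cmin (or cmax), querying the base regressor's
// sensitivity, and moving the prediction as far as the budget delta allows via binary_search.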
void cb_explore_adf_regcb::get_cost_ranges(float delta, LEARNER::multi_learner& base, multi_ex& examples, bool min_only)
{
  const size_t num_actions = examples[0]->pred.a_s.size();
  _min_costs.resize(num_actions);
  _max_costs.resize(num_actions);

  _ex_as.clear();
  _ex_costs.clear();

  // backup cb example data
  for (const auto& ex : examples)
  {
    _ex_as.push_back(ex->pred.a_s);
    _ex_costs.push_back(ex->l.cb.costs);
  }

  // set regressor predictions
  for (const auto& as : _ex_as[0])
  {
    examples[as.action]->pred.scalar = as.score;
  }

  const float cmin = _min_cb_cost;
  const float cmax = _max_cb_cost;

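  // Probe each action: label it just below cmin (resp. just above cmax), read off the regressor's sensitivity,
  // and let binary_search determine how far the predicted cost can credibly move; clip the result to [cmin, cmax].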
  for (size_t a = 0; a < num_actions; ++a)
  {
    example* ec = examples[a];
    ec->l.simple.label = cmin - 1;
    float sens = base.sensitivity(*ec);
    float w = 0;  // importance weight

    if (ec->pred.scalar < cmin || std::isnan(sens) || std::isinf(sens))
      _min_costs[a] = cmin;
    else
    {
      w = binary_search(ec->pred.scalar - cmin + 1, delta, sens);
      _min_costs[a] = (std::max)(ec->pred.scalar - sens * w, cmin);
      if (_min_costs[a] > cmax)
        _min_costs[a] = cmax;
    }

    if (!min_only)
    {
      ec->l.simple.label = cmax + 1;
      sens = base.sensitivity(*ec);
      if (ec->pred.scalar > cmax || std::isnan(sens) || std::isinf(sens))
      {
        _max_costs[a] = cmax;
      }
      else
      {
        w = binary_search(cmax + 1 - ec->pred.scalar, delta, sens);
        _max_costs[a] = (std::min)(ec->pred.scalar + sens * w, cmax);
        if (_max_costs[a] < cmin)
          _max_costs[a] = cmin;
      }
    }
  }

  // reset cb example data
  for (size_t i = 0; i < examples.size(); ++i)
  {
    examples[i]->pred.a_s = _ex_as[i];
    examples[i]->l.cb.costs = _ex_costs[i];
  }
}

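// predict_or_learn_impl: on the learn path, the observed cost keeps probability 1 (no importance weighting, as
// expected by the mtr cb_type) and the base learner is updated; on the predict path, cost ranges are computed and
// converted into action scores for the chosen RegCB variant.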
template <bool is_learn>
void cb_explore_adf_regcb::predict_or_learn_impl(LEARNER::multi_learner& base, multi_ex& examples)
{
  if (is_learn)
  {
    for (size_t i = 0; i < examples.size() - 1; ++i)
    {
      CB::label& ld = examples[i]->l.cb;
      if (ld.costs.size() == 1)
        ld.costs[0].probability = 1.f;  // no importance weighting
    }

    LEARNER::multiline_learn_or_predict<true>(base, examples, examples[0]->ft_offset);
    ++_counter;
  }
  else
    LEARNER::multiline_learn_or_predict<false>(base, examples, examples[0]->ft_offset);

  v_array<ACTION_SCORE::action_score>& preds = examples[0]->pred.a_s;
  uint32_t num_actions = (uint32_t)preds.size();

  const float max_range = _max_cb_cost - _min_cb_cost;
  // threshold on empirical loss difference
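  // (grows logarithmically in the number of actions times learn calls and scales with the squared cost range;
  //  a larger mellowness c0 widens the confidence intervals and therefore increases exploration)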
  const float delta = _c0 * log((float)(num_actions * _counter)) * pow(max_range, 2);

  if (!is_learn)
  {
    get_cost_ranges(delta, base, examples, /*min_only=*/_regcbopt);

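    // RegCB-opt plays an action with the smallest lower bound on cost; RegCB-elim keeps every action whose lower
    // bound does not exceed the smallest upper bound and spreads probability uniformly over the survivors.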
    if (_regcbopt)  // optimistic variant
    {
      float min_cost = FLT_MAX;
      size_t a_opt = 0;  // optimistic action
      for (size_t a = 0; a < num_actions; ++a)
      {
        if (_min_costs[a] < min_cost)
        {
          min_cost = _min_costs[a];
          a_opt = a;
        }
      }
      for (size_t i = 0; i < preds.size(); ++i)
      {
        if (preds[i].action == a_opt || (!_first_only && _min_costs[preds[i].action] == min_cost))
          preds[i].score = 1;
        else
          preds[i].score = 0;
      }
    }
    else  // elimination variant
    {
      float min_max_cost = FLT_MAX;
      for (size_t a = 0; a < num_actions; ++a)
        if (_max_costs[a] < min_max_cost)
          min_max_cost = _max_costs[a];
      for (size_t i = 0; i < preds.size(); ++i)
      {
        if (_min_costs[preds[i].action] <= min_max_cost)
          preds[i].score = 1;
        else
          preds[i].score = 0;
        // explore uniformly on support
        exploration::enforce_minimum_probability(
            1.0, /*update_zero_elements=*/false, begin_scores(preds), end_scores(preds));
      }
    }
  }
}

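// setup: registers the options below and wires cb_explore_adf_regcb into the reduction stack. Illustrative
// invocations (the data file name is hypothetical):
//   vw --cb_explore_adf --regcb --mellowness 0.1 -d train.dat
//   vw --cb_explore_adf --regcbopt --first_only -d train.dat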
LEARNER::base_learner* setup(VW::config::options_i& options, vw& all)
{
  using config::make_option;
  bool cb_explore_adf_option = false;
  bool regcb = false;
  const std::string mtr = "mtr";
  std::string type_string(mtr);
  bool regcbopt = false;
  float c0 = 0.;
  bool first_only = false;
  float min_cb_cost = 0.;
  float max_cb_cost = 0.;
  config::option_group_definition new_options("Contextual Bandit Exploration with Action Dependent Features");
  new_options
      .add(make_option("cb_explore_adf", cb_explore_adf_option)
               .keep()
               .help("Online explore-exploit for a contextual bandit problem with multiline action dependent features"))
      .add(make_option("regcb", regcb).keep().help("RegCB-elim exploration"))
      .add(make_option("regcbopt", regcbopt).keep().help("RegCB optimistic exploration"))
      .add(make_option("mellowness", c0).keep().default_value(0.1f).help("RegCB mellowness parameter c_0. Default 0.1"))
      .add(make_option("cb_min_cost", min_cb_cost).keep().default_value(0.f).help("lower bound on cost"))
      .add(make_option("cb_max_cost", max_cb_cost).keep().default_value(1.f).help("upper bound on cost"))
      .add(make_option("first_only", first_only).keep().help("Only explore the first action in a tie-breaking event"))
      .add(make_option("cb_type", type_string)
               .keep()
               .help("contextual bandit method to use in {ips,dr,mtr}. Default: mtr"));
  options.add_and_parse(new_options);

  if (!cb_explore_adf_option || !(options.was_supplied("regcb") || options.was_supplied("regcbopt")))
    return nullptr;

  // Ensure serialization of cb_adf in all cases.
  if (!options.was_supplied("cb_adf"))
  {
    options.insert("cb_adf", "");
  }
  if (type_string != mtr)
  {
    all.trace_message << "warning: bad cb_type, RegCB only supports mtr; resetting to mtr." << std::endl;
    options.replace("cb_type", mtr);
  }

  all.delete_prediction = ACTION_SCORE::delete_action_scores;

  // Set explore_type
  size_t problem_multiplier = 1;

  LEARNER::multi_learner* base = as_multiline(setup_base(options, all));
  all.p->lp = CB::cb_label;
  all.label_type = label_type::cb;

  using explore_type = cb_explore_adf_base<cb_explore_adf_regcb>;
  auto data = scoped_calloc_or_throw<explore_type>(regcbopt, c0, first_only, min_cb_cost, max_cb_cost);
  LEARNER::learner<explore_type, multi_ex>& l = LEARNER::init_learner(
      data, base, explore_type::learn, explore_type::predict, problem_multiplier, prediction_type::action_probs);

  l.set_finish_example(explore_type::finish_multiline_example);
  return make_base(l);
}

}  // namespace regcb
}  // namespace cb_explore_adf
}  // namespace VW