Vowpal Wabbit
warm_cb.cc
Go to the documentation of this file.
1 #include <float.h>
2 #include "reductions.h"
3 #include "cb_algs.h"
4 #include "rand48.h"
5 #include "bs.h"
6 #include "vw.h"
7 #include "hash.h"
8 #include "explore.h"
9 #include "vw_exception.h"
10 
11 #include <vector>
12 #include <memory>
13 
14 using namespace LEARNER;
15 using namespace exploration;
16 using namespace ACTION_SCORE;
17 using namespace VW::config;
18 
19 #define WARM_START 1
20 #define INTERACTION 2
21 #define SKIP 3
22 
23 #define SUPERVISED_WS 1
24 #define BANDIT_WS 2
25 
26 #define UAR 1
27 #define CIRCULAR 2
28 #define OVERWRITE 3
29 
30 #define ABS_CENTRAL 1
31 #define ABS_CENTRAL_ZEROONE 2
32 #define MINIMAX_CENTRAL 3
33 #define MINIMAX_CENTRAL_ZEROONE 4
34 
35 struct warm_cb
36 {
38  uint64_t app_seed;
40  // used as the seed
42  vw* all;
43  std::shared_ptr<rand_state> _random_state;
45  float loss0;
46  float loss1;
47 
48  // warm start parameters
49  uint32_t ws_period;
50  uint32_t inter_period;
51  uint32_t choices_lambda;
52  bool upd_ws;
53  bool upd_inter;
55  float cor_prob_ws;
57  int wt_scheme;
59  uint32_t overwrite_label;
60  int ws_type;
61  bool sim_bandit;
62 
63  // auxiliary variables
64  uint32_t num_actions;
65  float epsilon;
66  std::vector<float> lambdas;
68  std::vector<float> cumulative_costs;
70  uint32_t ws_train_size;
71  uint32_t ws_vali_size;
72  std::vector<example*> ws_vali;
73  float cumu_var;
74  uint32_t ws_iter;
75  uint32_t inter_iter;
80  bool use_cs;
81 
83  {
84  CB::cb_label.delete_label(&cb_label);
85  a_s.delete_v();
86 
87  for (size_t a = 0; a < num_actions; ++a)
88  {
90  }
91  free(csls);
92  free(cbls);
93 
94  for (size_t a = 0; a < num_actions; ++a)
95  {
96  ecs[a]->pred.a_s.delete_v();
98  free_it(ecs[a]);
99  }
100 
101  a_s_adf.delete_v();
102  for (size_t i = 0; i < ws_vali.size(); ++i)
103  {
104  if (use_cs)
106  else
107  VW::dealloc_example(MULTICLASS::mc_label.delete_label, *ws_vali[i]);
108  free(ws_vali[i]);
109  }
110  }
111 };
112 
113 float loss(warm_cb& data, uint32_t label, uint32_t final_prediction)
114 {
115  if (label != final_prediction)
116  return data.loss1;
117  else
118  return data.loss0;
119 }
120 
121 float loss_cs(warm_cb& data, v_array<COST_SENSITIVE::wclass>& costs, uint32_t final_prediction)
122 {
123  float cost = 0.;
124  for (auto wc : costs)
125  {
126  if (wc.class_index == final_prediction)
127  {
128  cost = wc.x;
129  break;
130  }
131  }
132  return data.loss0 + (data.loss1 - data.loss0) * cost;
133 }
134 
// Returns the index of the smallest element of `arr`; on ties the first
// occurrence wins. Returns 0 for an empty vector.
//
// The running minimum is seeded from the data itself rather than from
// FLT_MAX, so the result is also correct when every element is >= FLT_MAX
// (e.g. +inf costs) and for element types other than float. A running
// minimum that is unordered (NaN) is replaced by the next element, so NaN
// entries never win against ordinary values.
// The vector is taken by const reference to avoid copying it on every call.
template <class T>
uint32_t find_min(const std::vector<T>& arr)
{
  uint32_t argmin = 0;

  for (uint32_t i = 1; i < arr.size(); i++)
  {
    // x == x is false only for NaN-like values.
    const bool current_min_unordered = !(arr[argmin] == arr[argmin]);
    if (current_min_unordered || arr[i] < arr[argmin])
      argmin = i;
  }
  return argmin;
}
151 
152 void finish(warm_cb& data)
153 {
154  uint32_t argmin = find_min(data.cumulative_costs);
155 
156  if (!data.all->quiet)
157  {
158  std::cerr << "average variance estimate = " << data.cumu_var / data.inter_iter << std::endl;
159  std::cerr << "theoretical average variance = " << data.num_actions / data.epsilon << std::endl;
160  std::cerr << "last lambda chosen = " << data.lambdas[argmin] << " among lambdas ranging from " << data.lambdas[0]
161  << " to " << data.lambdas[data.choices_lambda - 1] << std::endl;
162  }
163 }
164 
166 {
167  const uint64_t ss = data.all->weights.stride_shift();
168  const uint64_t mask = data.all->weights.mask();
169 
170  for (size_t a = 0; a < data.num_actions; ++a)
171  {
172  auto& eca = *data.ecs[a];
173  // clear label
174  auto& lab = eca.l.cb;
176 
177  // copy data
178  VW::copy_example_data(false, &eca, &ec);
179 
180  // offset indicies for given action
181  for (features& fs : eca)
182  {
183  for (feature_index& idx : fs.indicies)
184  {
185  idx = ((((idx >> ss) * 28904713) + 4832917 * (uint64_t)a) << ss) & mask;
186  }
187  }
188 
189  // avoid empty example by adding a tag (hacky)
191  {
192  eca.tag.push_back('n');
193  }
194  }
195 }
196 
// Center value for the candidate lambdas: epsilon / (1 + epsilon).
// The minimax value was changed from eps/(K+eps) to eps/(1+eps) to
// accommodate the weight scaling of bandit examples by a factor 1/K in the
// mtr reduction.
float minimax_lambda(float epsilon)
{
  const float denominator = 1.0f + epsilon;
  return epsilon / denominator;
}
201 
203 {
204  // The lambdas are arranged in ascending order
205  std::vector<float>& lambdas = data.lambdas;
206  for (uint32_t i = 0; i < data.choices_lambda; i++) lambdas.push_back(0.f);
207 
208  // interaction only: set all lambda's to be identically 1
209  if (!data.upd_ws && data.upd_inter)
210  {
211  for (uint32_t i = 0; i < data.choices_lambda; i++) lambdas[i] = 1.0;
212  return;
213  }
214 
215  // warm start only: set all lambda's to be identically 0
216  if (!data.upd_inter && data.upd_ws)
217  {
218  for (uint32_t i = 0; i < data.choices_lambda; i++) lambdas[i] = 0.0;
219  return;
220  }
221 
222  uint32_t mid = data.choices_lambda / 2;
223 
225  lambdas[mid] = 0.5;
226  else
227  lambdas[mid] = minimax_lambda(data.epsilon);
228 
229  for (uint32_t i = mid; i > 0; i--) lambdas[i - 1] = lambdas[i] / 2.0f;
230 
231  for (uint32_t i = mid + 1; i < data.choices_lambda; i++) lambdas[i] = 1.f - (1.f - lambdas[i - 1]) / 2.0f;
232 
234  {
235  lambdas[0] = 0.0;
236  lambdas[data.choices_lambda - 1] = 1.0;
237  }
238 }
239 
// Body of generate_uar_action.
// NOTE(review): the signature line (listing line 240) is missing from this
// extract; the listing's index gives it as
// `uint32_t generate_uar_action(warm_cb &data)` - confirm against the real file.
//
// Draws one random number from data._random_state and returns the first
// i in 1..num_actions for which the draw is <= i / num_actions; if no i
// qualifies, num_actions is returned.
241 {
242  float randf = data._random_state->get_and_update_random();
243 
244  for (uint32_t i = 1; i <= data.num_actions; i++)
245  {
246  if (randf <= float(i) / data.num_actions)
247  return i;
248  }
// Fallback when the draw exceeded every threshold above.
249  return data.num_actions;
250 }
251 
252 uint32_t corrupt_action(warm_cb& data, uint32_t action, int ec_type)
253 {
254  float cor_prob = 0.;
255  uint32_t cor_type = UAR;
256  uint32_t cor_action;
257 
258  if (ec_type == WARM_START)
259  {
260  cor_prob = data.cor_prob_ws;
261  cor_type = data.cor_type_ws;
262  }
263 
264  float randf = data._random_state->get_and_update_random();
265  if (randf < cor_prob)
266  {
267  if (cor_type == UAR)
268  cor_action = generate_uar_action(data);
269  else if (cor_type == OVERWRITE)
270  cor_action = data.overwrite_label;
271  else
272  cor_action = (action % data.num_actions) + 1;
273  }
274  else
275  cor_action = action;
276  return cor_action;
277 }
278 
279 bool ind_update(warm_cb& data, int ec_type)
280 {
281  if (ec_type == WARM_START)
282  return data.upd_ws;
283  else
284  return data.upd_inter;
285 }
286 
287 float compute_weight_multiplier(warm_cb& data, size_t i, int ec_type)
288 {
289  float weight_multiplier;
290  float ws_train_size = (float)data.ws_train_size;
291  float inter_train_size = (float)data.inter_period;
292  float total_train_size = ws_train_size + inter_train_size;
293  float total_weight = (1 - data.lambdas[i]) * ws_train_size + data.lambdas[i] * inter_train_size;
294 
295  if (ec_type == WARM_START)
296  weight_multiplier = (1 - data.lambdas[i]) * total_train_size / (total_weight + FLT_MIN);
297  else
298  weight_multiplier = data.lambdas[i] * total_train_size / (total_weight + FLT_MIN);
299 
300  return weight_multiplier;
301 }
302 
303 uint32_t predict_sublearner_adf(warm_cb& data, multi_learner& base, example& ec, uint32_t i)
304 {
305  copy_example_to_adf(data, ec);
306  base.predict(data.ecs, i);
307  return data.ecs[0]->pred.a_s[0].action + 1;
308 }
309 
// Body of accumu_costs_iv_adf.
// NOTE(review): the signature line (listing line 310) is missing from this
// extract; the listing's index gives it as
// `void accumu_costs_iv_adf(warm_cb &data, multi_learner &base, example &ec)`
// - confirm against the real file.
//
// For every candidate lambda i, predicts with sub-learner i and, when that
// prediction equals the action stored in data.cl_adf, adds
// cost / probability of data.cl_adf to data.cumulative_costs[i].
311 {
312  CB::cb_class& cl = data.cl_adf;
313  // IPS for approximating the cumulative costs for all lambdas
314  for (uint32_t i = 0; i < data.choices_lambda; i++)
315  {
316  uint32_t action = predict_sublearner_adf(data, base, ec, i);
317 
// Only sub-learners that agree with the recorded action contribute.
318  if (action == cl.action)
319  data.cumulative_costs[i] += cl.cost / cl.probability;
320  }
321 }
322 
323 template <bool use_cs>
324 void add_to_vali(warm_cb& data, example& ec)
325 {
326  // TODO: set the first parameter properly
327  example* ec_copy = VW::alloc_examples(sizeof(polylabel), 1);
328 
329  if (use_cs)
331  else
332  VW::copy_example_data(false, ec_copy, &ec, 0, MULTICLASS::mc_label.copy_label);
333 
334  data.ws_vali.push_back(ec_copy);
335 }
336 
337 uint32_t predict_sup_adf(warm_cb& data, multi_learner& base, example& ec)
338 {
339  uint32_t argmin = find_min(data.cumulative_costs);
340  return predict_sublearner_adf(data, base, ec, argmin);
341 }
342 
343 template <bool use_cs>
344 void learn_sup_adf(warm_cb& data, example& ec, int ec_type)
345 {
346  copy_example_to_adf(data, ec);
347  // generate cost-sensitive label (for cost-sensitive learner's temporary use)
348  auto& csls = data.csls;
349  auto& cbls = data.cbls;
350  for (uint32_t a = 0; a < data.num_actions; ++a)
351  {
352  csls[a].costs[0].class_index = a + 1;
353  if (use_cs)
354  csls[a].costs[0].x = loss_cs(data, ec.l.cs.costs, a + 1);
355  else
356  csls[a].costs[0].x = loss(data, ec.l.multi.label, a + 1);
357  }
358  for (size_t a = 0; a < data.num_actions; ++a)
359  {
360  cbls[a] = data.ecs[a]->l.cb;
361  data.ecs[a]->l.cs = csls[a];
362  }
363 
364  std::vector<float> old_weights;
365  for (size_t a = 0; a < data.num_actions; ++a) old_weights.push_back(data.ecs[a]->weight);
366 
367  for (uint32_t i = 0; i < data.choices_lambda; i++)
368  {
369  float weight_multiplier = compute_weight_multiplier(data, i, ec_type);
370  for (size_t a = 0; a < data.num_actions; ++a) data.ecs[a]->weight = old_weights[a] * weight_multiplier;
371  multi_learner* cs_learner = as_multiline(data.all->cost_sensitive);
372  cs_learner->learn(data.ecs, i);
373  }
374 
375  for (size_t a = 0; a < data.num_actions; ++a) data.ecs[a]->weight = old_weights[a];
376 
377  for (size_t a = 0; a < data.num_actions; ++a) data.ecs[a]->l.cb = cbls[a];
378 }
379 
380 template <bool use_cs>
381 void predict_or_learn_sup_adf(warm_cb& data, multi_learner& base, example& ec, int ec_type)
382 {
383  uint32_t action = predict_sup_adf(data, base, ec);
384 
385  if (ind_update(data, ec_type))
386  learn_sup_adf<use_cs>(data, ec, ec_type);
387 
388  ec.pred.multiclass = action;
389 }
390 
391 uint32_t predict_bandit_adf(warm_cb& data, multi_learner& base, example& ec)
392 {
393  uint32_t argmin = find_min(data.cumulative_costs);
394 
395  copy_example_to_adf(data, ec);
396  base.predict(data.ecs, argmin);
397 
398  auto& out_ec = *data.ecs[0];
399  uint32_t chosen_action;
400  if (sample_after_normalizing(data.app_seed + data.example_counter++, begin_scores(out_ec.pred.a_s),
401  end_scores(out_ec.pred.a_s), chosen_action))
402  THROW("Failed to sample from pdf");
403 
404  auto& a_s = data.a_s_adf;
405  copy_array<action_score>(a_s, out_ec.pred.a_s);
406 
407  return chosen_action;
408 }
409 
410 void learn_bandit_adf(warm_cb& data, multi_learner& base, example& ec, int ec_type)
411 {
412  copy_example_to_adf(data, ec);
413 
414  // add cb label to chosen action
415  auto& cl = data.cl_adf;
416  auto& lab = data.ecs[cl.action - 1]->l.cb;
417  lab.costs.push_back(cl);
418 
419  std::vector<float> old_weights;
420  for (size_t a = 0; a < data.num_actions; ++a) old_weights.push_back(data.ecs[a]->weight);
421 
422  for (uint32_t i = 0; i < data.choices_lambda; i++)
423  {
424  float weight_multiplier = compute_weight_multiplier(data, i, ec_type);
425  for (size_t a = 0; a < data.num_actions; ++a) data.ecs[a]->weight = old_weights[a] * weight_multiplier;
426  base.learn(data.ecs, i);
427  }
428 
429  for (size_t a = 0; a < data.num_actions; ++a) data.ecs[a]->weight = old_weights[a];
430 }
431 
432 template <bool use_cs>
433 void predict_or_learn_bandit_adf(warm_cb& data, multi_learner& base, example& ec, int ec_type)
434 {
435  uint32_t chosen_action = predict_bandit_adf(data, base, ec);
436 
437  auto& cl = data.cl_adf;
438  auto& a_s = data.a_s_adf;
439  cl.action = a_s[chosen_action].action + 1;
440  cl.probability = a_s[chosen_action].score;
441 
442  if (!cl.action)
443  THROW("No action with non-zero probability found!");
444 
445  if (use_cs)
446  cl.cost = loss_cs(data, ec.l.cs.costs, cl.action);
447  else
448  cl.cost = loss(data, ec.l.multi.label, cl.action);
449 
450  if (ec_type == INTERACTION)
451  accumu_costs_iv_adf(data, base, ec);
452 
453  if (ind_update(data, ec_type))
454  learn_bandit_adf(data, base, ec, ec_type);
455 
456  ec.pred.multiclass = cl.action;
457 }
458 
// Body of accumu_var_adf.
// NOTE(review): the signature line (listing line 459) is missing from this
// extract; the listing's index gives it as
// `void accumu_var_adf(warm_cb &data, multi_learner &base, example &ec)`
// - confirm against the real file.
//
// Predicts with predict_sup_adf, looks for the entry of data.a_s_adf whose
// 1-based action equals that prediction, and adds 1 / score of that entry to
// data.cumu_var (0 is added when no entry matches).
460 {
461  size_t pred_best_approx = predict_sup_adf(data, base, ec);
462  float temp_var = 0.f;
463 
// a_s_adf stores 0-based actions, hence the + 1 in the comparison.
464  for (size_t a = 0; a < data.num_actions; ++a)
465  if (pred_best_approx == data.a_s_adf[a].action + 1)
466  temp_var = 1.0f / data.a_s_adf[a].score;
467 
468  data.cumu_var += temp_var;
469 }
470 
// Per-example entry point of the reduction; its instantiations are what
// warm_cb_setup hands to the learner constructor as learn/predict functions.
// NOTE(review): the signature line (listing line 472) is missing from this
// extract; the listing's index gives it as
// `void predict_or_learn_adf(warm_cb &data, multi_learner &base, example &ec)`
// - confirm against the real file.
//
// The example's label is saved on entry and restored on exit. In between,
// the example is routed by phase:
//   - warm start   (ws_iter < ws_period): supervised or simulated-bandit
//     handling depending on ws_type; the example weight is set to 0.
//   - interaction  (inter_iter < inter_period): bandit handling plus variance
//     accumulation.
//   - otherwise: the example is skipped (weight 0, prediction 1).
// The `is_learn` template parameter is not referenced in the visible body.
471 template <bool is_learn, bool use_cs>
473 {
474  // Corrupt labels (only corrupting multiclass labels as of now)
475  if (use_cs)
476  data.cs_label = ec.l.cs;
477  else
478  {
479  data.mc_label = ec.l.multi;
// Corruption is only applied while still in the warm start phase.
480  if (data.ws_iter < data.ws_period)
481  ec.l.multi.label = corrupt_action(data, data.mc_label.label, WARM_START);
482  }
483 
484  // Warm start phase
485  if (data.ws_iter < data.ws_period)
486  {
487  if (data.ws_type == SUPERVISED_WS)
488  predict_or_learn_sup_adf<use_cs>(data, base, ec, WARM_START);
489  else if (data.ws_type == BANDIT_WS)
490  predict_or_learn_bandit_adf<use_cs>(data, base, ec, WARM_START);
491 
492  ec.weight = 0;
493  data.ws_iter++;
494  }
495  // Interaction phase
496  else if (data.inter_iter < data.inter_period)
497  {
498  predict_or_learn_bandit_adf<use_cs>(data, base, ec, INTERACTION);
// accumu_var_adf reads data.a_s_adf, so it must run before the clear below.
499  accumu_var_adf(data, base, ec);
500  data.a_s_adf.clear();
501  data.inter_iter++;
502  }
503  // Skipping the rest of the examples
504  else
505  {
506  ec.weight = 0;
507  ec.pred.multiclass = 1;
508  }
509 
510  // Restore the original labels
511  if (use_cs)
512  ec.l.cs = data.cs_label;
513  else
514  ec.l.multi = data.mc_label;
515 }
516 
517 void init_adf_data(warm_cb& data, const uint32_t num_actions)
518 {
519  data.num_actions = num_actions;
520  if (data.sim_bandit)
521  data.ws_type = BANDIT_WS;
522  else
523  data.ws_type = SUPERVISED_WS;
524  data.ecs.resize(num_actions);
525  for (size_t a = 0; a < num_actions; ++a)
526  {
527  data.ecs[a] = VW::alloc_examples(CB::cb_label.label_size, 1);
528  auto& lab = data.ecs[a]->l.cb;
530  }
531 
532  // The rest of the initialization is for warm start CB
533  data.csls = calloc_or_throw<COST_SENSITIVE::label>(num_actions);
534  for (uint32_t a = 0; a < num_actions; ++a)
535  {
537  data.csls[a].costs.push_back({0, a + 1, 0, 0});
538  }
539  data.cbls = calloc_or_throw<CB::label>(num_actions);
540 
541  data.ws_train_size = data.ws_period;
542  data.ws_vali_size = 0;
543 
544  data.ws_iter = 0;
545  data.inter_iter = 0;
546 
547  setup_lambdas(data);
548  for (uint32_t i = 0; i < data.choices_lambda; i++) data.cumulative_costs.push_back(0.f);
549  data.cumu_var = 0.f;
550 }
551 
553 {
554  uint32_t num_actions = 0;
555  auto data = scoped_calloc_or_throw<warm_cb>();
556  bool use_cs;
557 
558  option_group_definition new_options("Make Multiclass into Warm-starting Contextual Bandit");
559 
560  new_options
561  .add(make_option("warm_cb", num_actions)
562  .keep()
563  .help("Convert multiclass on <k> classes into a contextual bandit problem"))
564  .add(make_option("warm_cb_cs", use_cs)
565  .help("consume cost-sensitive classification examples instead of multiclass"))
566  .add(make_option("loss0", data->loss0).default_value(0.f).help("loss for correct label"))
567  .add(make_option("loss1", data->loss1).default_value(1.f).help("loss for incorrect label"))
568  .add(make_option("warm_start", data->ws_period)
569  .default_value(0U)
570  .help("number of training examples for warm start phase"))
571  .add(make_option("epsilon", data->epsilon).keep().help("epsilon-greedy exploration"))
572  .add(make_option("interaction", data->inter_period)
573  .default_value(UINT32_MAX)
574  .help("number of examples for the interactive contextual bandit learning phase"))
575  .add(make_option("warm_start_update", data->upd_ws).help("indicator of warm start updates"))
576  .add(make_option("interaction_update", data->upd_inter).help("indicator of interaction updates"))
577  .add(make_option("corrupt_type_warm_start", data->cor_type_ws)
578  .default_value(UAR)
579  .help("type of label corruption in the warm start phase (1: uniformly at random, 2: circular, 3: "
580  "replacing with overwriting label)"))
581  .add(make_option("corrupt_prob_warm_start", data->cor_prob_ws)
582  .default_value(0.f)
583  .help("probability of label corruption in the warm start phase"))
584  .add(make_option("choices_lambda", data->choices_lambda)
585  .default_value(1U)
586  .help("the number of candidate lambdas to aggregate (lambda is the importance weight parameter between "
587  "the two sources)"))
588  .add(make_option("lambda_scheme", data->lambda_scheme)
589  .default_value(ABS_CENTRAL)
590  .help("The scheme for generating candidate lambda set (1: center lambda=0.5, 2: center lambda=0.5, min "
591  "lambda=0, max lambda=1, 3: center lambda=epsilon/(1+epsilon), 4: center "
592  "lambda=epsilon/(1+epsilon), min lambda=0, max lambda=1); the rest of candidate lambda values are "
593  "generated using a doubling scheme"))
594  .add(make_option("overwrite_label", data->overwrite_label)
595  .default_value(1U)
596  .help("the label used by type 3 corruptions (overwriting)"))
597  .add(make_option("sim_bandit", data->sim_bandit)
598  .help("simulate contextual bandit updates on warm start examples"));
599 
600  options.add_and_parse(new_options);
601 
602  if (use_cs && (options.was_supplied("corrupt_type_warm_start") || options.was_supplied("corrupt_prob_warm_start")))
603  {
604  THROW("label corruption on cost-sensitive examples not currently supported");
605  }
606 
607  if (!options.was_supplied("warm_cb"))
608  {
609  return nullptr;
610  }
611 
612  data->app_seed = uniform_hash("vw", 2, 0);
613  data->a_s = v_init<action_score>();
614  data->all = &all;
615  data->_random_state = all.get_random_state();
616  data->use_cs = use_cs;
617 
618  init_adf_data(*data.get(), num_actions);
619 
620  options.insert("cb_min_cost", std::to_string(data->loss0));
621  options.insert("cb_max_cost", std::to_string(data->loss1));
622 
623  if (options.was_supplied("baseline"))
624  {
625  std::stringstream ss;
626  ss << std::max(std::abs(data->loss0), std::abs(data->loss1)) / (data->loss1 - data->loss0);
627  options.insert("lr_multiplier", ss.str());
628  }
629 
631 
632  multi_learner* base = as_multiline(setup_base(options, all));
633  // Note: the current version of warm start CB can only support epsilon-greedy exploration
634  // We need to wait for the epsilon value to be passed from the base
635  // cb_explore learner, if there is one
636 
637  if (!options.was_supplied("epsilon"))
638  {
639  std::cerr << "Warning: no epsilon (greedy parameter) specified; resetting to 0.05" << std::endl;
640  data->epsilon = 0.05f;
641  }
642 
643  if (use_cs)
645  data, base, predict_or_learn_adf<true, true>, predict_or_learn_adf<false, true>, all.p, data->choices_lambda);
646  else
648  data, base, predict_or_learn_adf<true, false>, predict_or_learn_adf<false, false>, all.p, data->choices_lambda);
649 
650  l->set_finish(finish);
651  all.delete_prediction = nullptr;
652 
653  return make_base(*l);
654 }
float loss1
Definition: warm_cb.cc:46
void copy_label(void *dst, void *src)
Definition: cb.cc:104
uint32_t choices_lambda
Definition: warm_cb.cc:51
#define WARM_START
Definition: warm_cb.cc:19
bool example_is_newline_not_header(example const &ec)
Definition: cb_algs.h:80
void predict_or_learn_sup_adf(warm_cb &data, multi_learner &base, example &ec, int ec_type)
Definition: warm_cb.cc:381
uint32_t multiclass
Definition: example.h:49
std::vector< float > cumulative_costs
Definition: warm_cb.cc:68
uint32_t predict_sublearner_adf(warm_cb &data, multi_learner &base, example &ec, uint32_t i)
Definition: warm_cb.cc:303
parameters weights
Definition: global_data.h:537
void predict(E &ec, size_t i=0)
Definition: learner.h:169
LEARNER::base_learner * cost_sensitive
Definition: global_data.h:385
void(* delete_prediction)(void *)
Definition: global_data.h:485
void learn_sup_adf(warm_cb &data, example &ec, int ec_type)
Definition: warm_cb.cc:344
int sample_after_normalizing(uint64_t seed, It pdf_first, It pdf_last, uint32_t &chosen_index)
Sample an index from the provided pdf. If the pdf is not normalized it will be updated in-place...
label_parser cs_label
uint32_t num_actions
Definition: warm_cb.cc:64
void(* delete_label)(void *)
Definition: label_parser.h:16
uint32_t inter_iter
Definition: warm_cb.cc:75
void copy_example_data(bool audit, example *dst, example *src)
Definition: example.cc:72
float epsilon
Definition: warm_cb.cc:65
void predict_or_learn_bandit_adf(warm_cb &data, multi_learner &base, example &ec, int ec_type)
Definition: warm_cb.cc:433
void(* default_label)(void *)
Definition: label_parser.h:12
void dealloc_example(void(*delete_label)(void *), example &ec, void(*delete_prediction)(void *))
Definition: example.cc:219
bool(* test_label)(void *)
Definition: label_parser.h:22
#define ABS_CENTRAL_ZEROONE
Definition: warm_cb.cc:31
uint32_t inter_period
Definition: warm_cb.cc:50
#define UAR
Definition: warm_cb.cc:26
learner< T, E > & init_cost_sensitive_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:450
the core definition of a set of features.
v_array< cb_class > costs
Definition: cb.h:27
void accumu_costs_iv_adf(warm_cb &data, multi_learner &base, example &ec)
Definition: warm_cb.cc:310
void learn_bandit_adf(warm_cb &data, multi_learner &base, example &ec, int ec_type)
Definition: warm_cb.cc:410
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
Definition: hash.h:67
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
void delete_label(void *v)
Definition: cb.cc:98
void finish(warm_cb &data)
Definition: warm_cb.cc:152
uint32_t action
Definition: search.h:19
void setup_lambdas(warm_cb &data)
Definition: warm_cb.cc:202
bool quiet
Definition: global_data.h:487
uint32_t overwrite_label
Definition: warm_cb.cc:59
virtual void add_and_parse(const option_group_definition &group)=0
void predict_or_learn_adf(warm_cb &data, multi_learner &base, example &ec)
Definition: warm_cb.cc:472
int wt_scheme
Definition: warm_cb.cc:57
std::vector< float > lambdas
Definition: warm_cb.cc:66
float minimax_lambda(float epsilon)
Definition: warm_cb.cc:200
std::shared_ptr< rand_state > _random_state
Definition: warm_cb.cc:43
CB::label cb_label
Definition: warm_cb.cc:37
#define MINIMAX_CENTRAL_ZEROONE
Definition: warm_cb.cc:33
score_iterator begin_scores(action_scores &a_s)
Definition: action_score.h:43
example * alloc_examples(size_t, size_t count=1)
Definition: example.cc:204
action_scores a_s_adf
Definition: warm_cb.cc:67
void free_it(void *ptr)
Definition: memory.h:94
label_parser mc_label
Definition: multiclass.cc:93
parser * p
Definition: global_data.h:377
uint64_t app_seed
Definition: warm_cb.cc:38
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
score_iterator end_scores(action_scores &a_s)
Definition: action_score.h:45
MULTICLASS::label_t multi
Definition: example.h:29
uint32_t predict_sup_adf(warm_cb &data, multi_learner &base, example &ec)
Definition: warm_cb.cc:337
bool upd_ws
Definition: warm_cb.cc:52
float compute_weight_multiplier(warm_cb &data, size_t i, int ec_type)
Definition: warm_cb.cc:287
~warm_cb()
Definition: warm_cb.cc:82
bool upd_inter
Definition: warm_cb.cc:53
float loss0
Definition: warm_cb.cc:45
uint32_t action
Definition: cb.h:18
#define ABS_CENTRAL
Definition: warm_cb.cc:30
uint32_t predict_bandit_adf(warm_cb &data, multi_learner &base, example &ec)
Definition: warm_cb.cc:391
COST_SENSITIVE::label cs
Definition: example.h:30
#define INTERACTION
Definition: warm_cb.cc:20
float probability
Definition: cb.h:19
uint32_t ws_period
Definition: warm_cb.cc:49
virtual bool was_supplied(const std::string &key)=0
void add_to_vali(warm_cb &data, example &ec)
Definition: warm_cb.cc:324
uint64_t feature_index
Definition: feature_group.h:21
std::vector< example * > ws_vali
Definition: warm_cb.cc:72
CB::cb_class cl_adf
Definition: warm_cb.cc:69
action_scores a_s
Definition: warm_cb.cc:39
COST_SENSITIVE::label * csls
Definition: warm_cb.cc:78
bool sim_bandit
Definition: warm_cb.cc:61
void copy_example_to_adf(warm_cb &data, example &ec)
Definition: warm_cb.cc:165
CB::label * cbls
Definition: warm_cb.cc:79
bool use_cs
Definition: warm_cb.cc:80
virtual void insert(const std::string &key, const std::string &value)=0
base_learner * warm_cb_setup(options_i &options, vw &all)
Definition: warm_cb.cc:552
int ws_type
Definition: warm_cb.cc:60
void accumu_var_adf(warm_cb &data, multi_learner &base, example &ec)
Definition: warm_cb.cc:459
option_group_definition & add(T &&op)
Definition: options.h:90
uint32_t ws_train_size
Definition: warm_cb.cc:70
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
std::vector< example * > multi_ex
Definition: example.h:122
label_parser cb_label
Definition: cb.cc:167
polylabel l
Definition: example.h:57
constexpr uint64_t a
Definition: rand48.cc:11
uint32_t find_min(std::vector< T > arr)
Definition: warm_cb.cc:136
float cor_prob_ws
Definition: warm_cb.cc:55
#define SUPERVISED_WS
Definition: warm_cb.cc:23
int lambda_scheme
Definition: warm_cb.cc:58
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
float loss(warm_cb &data, uint32_t label, uint32_t final_prediction)
Definition: warm_cb.cc:113
Definition: cb.h:25
float cost
Definition: cb.h:17
void set_finish(void(*f)(T &))
Definition: learner.h:265
COST_SENSITIVE::label cs_label
Definition: warm_cb.cc:77
MULTICLASS::label_t mc_label
Definition: warm_cb.cc:76
#define OVERWRITE
Definition: warm_cb.cc:28
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
uint32_t stride_shift()
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
int cor_type_ws
Definition: warm_cb.cc:54
polyprediction pred
Definition: example.h:60
void delete_v()
Definition: v_array.h:98
#define BANDIT_WS
Definition: warm_cb.cc:24
float loss_cs(warm_cb &data, v_array< COST_SENSITIVE::wclass > &costs, uint32_t final_prediction)
Definition: warm_cb.cc:121
void learn(E &ec, size_t i=0)
Definition: learner.h:160
multi_ex ecs
Definition: warm_cb.cc:44
v_array< wclass > costs
float weight
Definition: example.h:62
bool ind_update(warm_cb &data, int ec_type)
Definition: warm_cb.cc:279
uint32_t corrupt_action(warm_cb &data, uint32_t action, int ec_type)
Definition: warm_cb.cc:252
uint64_t mask()
#define THROW(args)
Definition: vw_exception.h:181
void init_adf_data(warm_cb &data, const uint32_t num_actions)
Definition: warm_cb.cc:517
int vali_method
Definition: warm_cb.cc:56
vw * all
Definition: warm_cb.cc:42
float f
Definition: cache.cc:40
float cumu_var
Definition: warm_cb.cc:73
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
uint32_t ws_vali_size
Definition: warm_cb.cc:71
const char * to_string(prediction_type_t prediction_type)
Definition: learner.cc:12
uint32_t ws_iter
Definition: warm_cb.cc:74
uint32_t generate_uar_action(warm_cb &data)
Definition: warm_cb.cc:240
size_t example_counter
Definition: warm_cb.cc:41