Vowpal Wabbit
Classes | Macros | Functions
warm_cb.cc File Reference
#include <float.h>
#include "reductions.h"
#include "cb_algs.h"
#include "rand48.h"
#include "bs.h"
#include "vw.h"
#include "hash.h"
#include "explore.h"
#include "vw_exception.h"
#include <vector>
#include <memory>

Go to the source code of this file.

Classes

struct  warm_cb
 

Macros

#define WARM_START   1
 
#define INTERACTION   2
 
#define SKIP   3
 
#define SUPERVISED_WS   1
 
#define BANDIT_WS   2
 
#define UAR   1
 
#define CIRCULAR   2
 
#define OVERWRITE   3
 
#define ABS_CENTRAL   1
 
#define ABS_CENTRAL_ZEROONE   2
 
#define MINIMAX_CENTRAL   3
 
#define MINIMAX_CENTRAL_ZEROONE   4
 

Functions

float loss (warm_cb &data, uint32_t label, uint32_t final_prediction)
 
float loss_cs (warm_cb &data, v_array< COST_SENSITIVE::wclass > &costs, uint32_t final_prediction)
 
template<class T >
uint32_t find_min (std::vector< T > arr)
 
void finish (warm_cb &data)
 
void copy_example_to_adf (warm_cb &data, example &ec)
 
float minimax_lambda (float epsilon)
 
void setup_lambdas (warm_cb &data)
 
uint32_t generate_uar_action (warm_cb &data)
 
uint32_t corrupt_action (warm_cb &data, uint32_t action, int ec_type)
 
bool ind_update (warm_cb &data, int ec_type)
 
float compute_weight_multiplier (warm_cb &data, size_t i, int ec_type)
 
uint32_t predict_sublearner_adf (warm_cb &data, multi_learner &base, example &ec, uint32_t i)
 
void accumu_costs_iv_adf (warm_cb &data, multi_learner &base, example &ec)
 
template<bool use_cs>
void add_to_vali (warm_cb &data, example &ec)
 
uint32_t predict_sup_adf (warm_cb &data, multi_learner &base, example &ec)
 
template<bool use_cs>
void learn_sup_adf (warm_cb &data, example &ec, int ec_type)
 
template<bool use_cs>
void predict_or_learn_sup_adf (warm_cb &data, multi_learner &base, example &ec, int ec_type)
 
uint32_t predict_bandit_adf (warm_cb &data, multi_learner &base, example &ec)
 
void learn_bandit_adf (warm_cb &data, multi_learner &base, example &ec, int ec_type)
 
template<bool use_cs>
void predict_or_learn_bandit_adf (warm_cb &data, multi_learner &base, example &ec, int ec_type)
 
void accumu_var_adf (warm_cb &data, multi_learner &base, example &ec)
 
template<bool is_learn, bool use_cs>
void predict_or_learn_adf (warm_cb &data, multi_learner &base, example &ec)
 
void init_adf_data (warm_cb &data, const uint32_t num_actions)
 
base_learner * warm_cb_setup (options_i &options, vw &all)
 

Macro Definition Documentation

◆ ABS_CENTRAL

#define ABS_CENTRAL   1

Definition at line 30 of file warm_cb.cc.

Referenced by setup_lambdas(), and warm_cb_setup().

◆ ABS_CENTRAL_ZEROONE

#define ABS_CENTRAL_ZEROONE   2

Definition at line 31 of file warm_cb.cc.

Referenced by setup_lambdas().

◆ BANDIT_WS

#define BANDIT_WS   2

Definition at line 24 of file warm_cb.cc.

Referenced by init_adf_data(), and predict_or_learn_adf().

◆ CIRCULAR

#define CIRCULAR   2

Definition at line 27 of file warm_cb.cc.

◆ INTERACTION

#define INTERACTION   2

Definition at line 20 of file warm_cb.cc.

Referenced by predict_or_learn_adf(), and predict_or_learn_bandit_adf().

◆ MINIMAX_CENTRAL

#define MINIMAX_CENTRAL   3

Definition at line 32 of file warm_cb.cc.

◆ MINIMAX_CENTRAL_ZEROONE

#define MINIMAX_CENTRAL_ZEROONE   4

Definition at line 33 of file warm_cb.cc.

Referenced by setup_lambdas().

◆ OVERWRITE

#define OVERWRITE   3

Definition at line 28 of file warm_cb.cc.

Referenced by corrupt_action().

◆ SKIP

#define SKIP   3

Definition at line 21 of file warm_cb.cc.

◆ SUPERVISED_WS

#define SUPERVISED_WS   1

Definition at line 23 of file warm_cb.cc.

Referenced by init_adf_data(), and predict_or_learn_adf().

◆ UAR

#define UAR   1

Definition at line 26 of file warm_cb.cc.

Referenced by corrupt_action(), and warm_cb_setup().

◆ WARM_START

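#define WARM_START   1

Definition at line 19 of file warm_cb.cc.

Referenced by compute_weight_multiplier(), corrupt_action(), ind_update(), and predict_or_learn_adf().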

Function Documentation

◆ accumu_costs_iv_adf()

void accumu_costs_iv_adf ( warm_cb & data,
multi_learner & base,
example & ec 
)

Definition at line 310 of file warm_cb.cc.

References CB::cb_class::action, warm_cb::choices_lambda, warm_cb::cl_adf, CB::cb_class::cost, warm_cb::cumulative_costs, predict_sublearner_adf(), and CB::cb_class::probability.

Referenced by predict_or_learn_bandit_adf().

311 {
312  CB::cb_class& cl = data.cl_adf;
313  // IPS for approximating the cumulative costs for all lambdas
314  for (uint32_t i = 0; i < data.choices_lambda; i++)
315  {
316  uint32_t action = predict_sublearner_adf(data, base, ec, i);
317 
318  if (action == cl.action)
319  data.cumulative_costs[i] += cl.cost / cl.probability;
320  }
321 }
uint32_t choices_lambda
Definition: warm_cb.cc:51
std::vector< float > cumulative_costs
Definition: warm_cb.cc:68
uint32_t predict_sublearner_adf(warm_cb &data, multi_learner &base, example &ec, uint32_t i)
Definition: warm_cb.cc:303
uint32_t action
Definition: search.h:19
uint32_t action
Definition: cb.h:18
float probability
Definition: cb.h:19
CB::cb_class cl_adf
Definition: warm_cb.cc:69
float cost
Definition: cb.h:17

◆ accumu_var_adf()

void accumu_var_adf ( warm_cb & data,
multi_learner & base,
example & ec 
)

Definition at line 459 of file warm_cb.cc.

References a, warm_cb::a_s_adf, warm_cb::cumu_var, warm_cb::num_actions, and predict_sup_adf().

Referenced by predict_or_learn_adf().

460 {
461  size_t pred_best_approx = predict_sup_adf(data, base, ec);
462  float temp_var = 0.f;
463 
464  for (size_t a = 0; a < data.num_actions; ++a)
465  if (pred_best_approx == data.a_s_adf[a].action + 1)
466  temp_var = 1.0f / data.a_s_adf[a].score;
467 
468  data.cumu_var += temp_var;
469 }
uint32_t num_actions
Definition: warm_cb.cc:64
action_scores a_s_adf
Definition: warm_cb.cc:67
uint32_t predict_sup_adf(warm_cb &data, multi_learner &base, example &ec)
Definition: warm_cb.cc:337
constexpr uint64_t a
Definition: rand48.cc:11
float cumu_var
Definition: warm_cb.cc:73

◆ add_to_vali()

template<bool use_cs>
void add_to_vali ( warm_cb & data,
example & ec 
)

Definition at line 324 of file warm_cb.cc.

References VW::alloc_examples(), VW::copy_example_data(), CB::copy_label(), COST_SENSITIVE::cs_label, MULTICLASS::mc_label, and warm_cb::ws_vali.

325 {
326  // TODO: set the first parameter properly
327  example* ec_copy = VW::alloc_examples(sizeof(polylabel), 1);
328 
329  if (use_cs)
330  VW::copy_example_data(false, ec_copy, &ec, 0, COST_SENSITIVE::cs_label.copy_label);
331  else
332  VW::copy_example_data(false, ec_copy, &ec, 0, MULTICLASS::mc_label.copy_label);
333 
334  data.ws_vali.push_back(ec_copy);
335 }
void copy_label(void *dst, void *src)
Definition: cb.cc:104
label_parser cs_label
void copy_example_data(bool audit, example *dst, example *src)
Definition: example.cc:72
example * alloc_examples(size_t, size_t count=1)
Definition: example.cc:204
label_parser mc_label
Definition: multiclass.cc:93
std::vector< example * > ws_vali
Definition: warm_cb.cc:72

◆ compute_weight_multiplier()

float compute_weight_multiplier ( warm_cb & data,
size_t  i,
int  ec_type 
)

Definition at line 287 of file warm_cb.cc.

References warm_cb::inter_period, warm_cb::lambdas, WARM_START, and warm_cb::ws_train_size.

Referenced by learn_bandit_adf(), and learn_sup_adf().

288 {
289  float weight_multiplier;
290  float ws_train_size = (float)data.ws_train_size;
291  float inter_train_size = (float)data.inter_period;
292  float total_train_size = ws_train_size + inter_train_size;
293  float total_weight = (1 - data.lambdas[i]) * ws_train_size + data.lambdas[i] * inter_train_size;
294 
295  if (ec_type == WARM_START)
296  weight_multiplier = (1 - data.lambdas[i]) * total_train_size / (total_weight + FLT_MIN);
297  else
298  weight_multiplier = data.lambdas[i] * total_train_size / (total_weight + FLT_MIN);
299 
300  return weight_multiplier;
301 }
#define WARM_START
Definition: warm_cb.cc:19
uint32_t inter_period
Definition: warm_cb.cc:50
std::vector< float > lambdas
Definition: warm_cb.cc:66
uint32_t ws_train_size
Definition: warm_cb.cc:70

◆ copy_example_to_adf()

void copy_example_to_adf ( warm_cb & data,
example & ec 
)

Definition at line 165 of file warm_cb.cc.

References a, warm_cb::all, CB::cb_label, VW::copy_example_data(), label_parser::default_label, warm_cb::ecs, CB_ALGS::example_is_newline_not_header(), parameters::mask(), warm_cb::num_actions, parameters::stride_shift(), label_parser::test_label, and vw::weights.

Referenced by learn_bandit_adf(), learn_sup_adf(), predict_bandit_adf(), and predict_sublearner_adf().

166 {
167  const uint64_t ss = data.all->weights.stride_shift();
168  const uint64_t mask = data.all->weights.mask();
169 
170  for (size_t a = 0; a < data.num_actions; ++a)
171  {
172  auto& eca = *data.ecs[a];
173  // clear label
174  auto& lab = eca.l.cb;
175  CB::cb_label.default_label(&lab);
176 
177  // copy data
178  VW::copy_example_data(false, &eca, &ec);
179 
180  // offset indicies for given action
181  for (features& fs : eca)
182  {
183  for (feature_index& idx : fs.indicies)
184  {
185  idx = ((((idx >> ss) * 28904713) + 4832917 * (uint64_t)a) << ss) & mask;
186  }
187  }
188 
189  // avoid empty example by adding a tag (hacky)
190  if (CB_ALGS::example_is_newline_not_header(eca) && CB::cb_label.test_label(&eca.l))
191  {
192  eca.tag.push_back('n');
193  }
194  }
195 }
bool example_is_newline_not_header(example const &ec)
Definition: cb_algs.h:80
parameters weights
Definition: global_data.h:537
uint32_t num_actions
Definition: warm_cb.cc:64
void copy_example_data(bool audit, example *dst, example *src)
Definition: example.cc:72
void(* default_label)(void *)
Definition: label_parser.h:12
bool(* test_label)(void *)
Definition: label_parser.h:22
the core definition of a set of features.
uint64_t feature_index
Definition: feature_group.h:21
label_parser cb_label
Definition: cb.cc:167
constexpr uint64_t a
Definition: rand48.cc:11
uint32_t stride_shift()
multi_ex ecs
Definition: warm_cb.cc:44
uint64_t mask()
vw * all
Definition: warm_cb.cc:42

◆ corrupt_action()

uint32_t corrupt_action ( warm_cb & data,
uint32_t  action,
int  ec_type 
)

Definition at line 252 of file warm_cb.cc.

References warm_cb::_random_state, warm_cb::cor_prob_ws, warm_cb::cor_type_ws, generate_uar_action(), warm_cb::num_actions, OVERWRITE, warm_cb::overwrite_label, UAR, and WARM_START.

Referenced by predict_or_learn_adf().

253 {
254  float cor_prob = 0.;
255  uint32_t cor_type = UAR;
256  uint32_t cor_action;
257 
258  if (ec_type == WARM_START)
259  {
260  cor_prob = data.cor_prob_ws;
261  cor_type = data.cor_type_ws;
262  }
263 
264  float randf = data._random_state->get_and_update_random();
265  if (randf < cor_prob)
266  {
267  if (cor_type == UAR)
268  cor_action = generate_uar_action(data);
269  else if (cor_type == OVERWRITE)
270  cor_action = data.overwrite_label;
271  else
272  cor_action = (action % data.num_actions) + 1;
273  }
274  else
275  cor_action = action;
276  return cor_action;
277 }
#define WARM_START
Definition: warm_cb.cc:19
uint32_t num_actions
Definition: warm_cb.cc:64
#define UAR
Definition: warm_cb.cc:26
uint32_t action
Definition: search.h:19
uint32_t overwrite_label
Definition: warm_cb.cc:59
std::shared_ptr< rand_state > _random_state
Definition: warm_cb.cc:43
float cor_prob_ws
Definition: warm_cb.cc:55
#define OVERWRITE
Definition: warm_cb.cc:28
int cor_type_ws
Definition: warm_cb.cc:54
uint32_t generate_uar_action(warm_cb &data)
Definition: warm_cb.cc:240

◆ find_min()

template<class T >
uint32_t find_min ( std::vector< T >  arr)

Definition at line 136 of file warm_cb.cc.

Referenced by finish(), predict_bandit_adf(), and predict_sup_adf().

137 {
138  T min_val = FLT_MAX;
139  uint32_t argmin = 0;
140 
141  for (uint32_t i = 0; i < arr.size(); i++)
142  {
143  if (arr[i] < min_val)
144  {
145  min_val = arr[i];
146  argmin = i;
147  }
148  }
149  return argmin;
150 }

◆ finish()

void finish ( warm_cb & data)

Definition at line 152 of file warm_cb.cc.

References warm_cb::all, warm_cb::choices_lambda, warm_cb::cumu_var, warm_cb::cumulative_costs, warm_cb::epsilon, find_min(), warm_cb::inter_iter, warm_cb::lambdas, warm_cb::num_actions, and vw::quiet.

Referenced by warm_cb_setup().

153 {
154  uint32_t argmin = find_min(data.cumulative_costs);
155 
156  if (!data.all->quiet)
157  {
158  std::cerr << "average variance estimate = " << data.cumu_var / data.inter_iter << std::endl;
159  std::cerr << "theoretical average variance = " << data.num_actions / data.epsilon << std::endl;
160  std::cerr << "last lambda chosen = " << data.lambdas[argmin] << " among lambdas ranging from " << data.lambdas[0]
161  << " to " << data.lambdas[data.choices_lambda - 1] << std::endl;
162  }
163 }
uint32_t choices_lambda
Definition: warm_cb.cc:51
std::vector< float > cumulative_costs
Definition: warm_cb.cc:68
uint32_t num_actions
Definition: warm_cb.cc:64
uint32_t inter_iter
Definition: warm_cb.cc:75
float epsilon
Definition: warm_cb.cc:65
bool quiet
Definition: global_data.h:487
std::vector< float > lambdas
Definition: warm_cb.cc:66
uint32_t find_min(std::vector< T > arr)
Definition: warm_cb.cc:136
vw * all
Definition: warm_cb.cc:42
float cumu_var
Definition: warm_cb.cc:73

◆ generate_uar_action()

uint32_t generate_uar_action ( warm_cb & data)

Definition at line 240 of file warm_cb.cc.

References warm_cb::_random_state, and warm_cb::num_actions.

Referenced by corrupt_action().

241 {
242  float randf = data._random_state->get_and_update_random();
243 
244  for (uint32_t i = 1; i <= data.num_actions; i++)
245  {
246  if (randf <= float(i) / data.num_actions)
247  return i;
248  }
249  return data.num_actions;
250 }
uint32_t num_actions
Definition: warm_cb.cc:64
std::shared_ptr< rand_state > _random_state
Definition: warm_cb.cc:43

◆ ind_update()

bool ind_update ( warm_cb & data,
int  ec_type 
)

Definition at line 279 of file warm_cb.cc.

References warm_cb::upd_inter, warm_cb::upd_ws, and WARM_START.

Referenced by predict_or_learn_bandit_adf(), and predict_or_learn_sup_adf().

280 {
281  if (ec_type == WARM_START)
282  return data.upd_ws;
283  else
284  return data.upd_inter;
285 }
#define WARM_START
Definition: warm_cb.cc:19
bool upd_ws
Definition: warm_cb.cc:52
bool upd_inter
Definition: warm_cb.cc:53

◆ init_adf_data()

void init_adf_data ( warm_cb & data,
const uint32_t  num_actions 
)

Definition at line 517 of file warm_cb.cc.

References a, VW::alloc_examples(), BANDIT_WS, CB::cb_label, warm_cb::cbls, warm_cb::choices_lambda, COST_SENSITIVE::label::costs, COST_SENSITIVE::cs_label, warm_cb::csls, warm_cb::cumu_var, warm_cb::cumulative_costs, label_parser::default_label, warm_cb::ecs, warm_cb::inter_iter, warm_cb::num_actions, setup_lambdas(), warm_cb::sim_bandit, SUPERVISED_WS, warm_cb::ws_iter, warm_cb::ws_period, warm_cb::ws_train_size, warm_cb::ws_type, and warm_cb::ws_vali_size.

Referenced by warm_cb_setup().

518 {
519  data.num_actions = num_actions;
520  if (data.sim_bandit)
521  data.ws_type = BANDIT_WS;
522  else
523  data.ws_type = SUPERVISED_WS;
524  data.ecs.resize(num_actions);
525  for (size_t a = 0; a < num_actions; ++a)
526  {
527  data.ecs[a] = VW::alloc_examples(CB::cb_label.label_size, 1);
528  auto& lab = data.ecs[a]->l.cb;
529  CB::cb_label.default_label(&lab);
530  }
531 
532  // The rest of the initialization is for warm start CB
533  data.csls = calloc_or_throw<COST_SENSITIVE::label>(num_actions);
534  for (uint32_t a = 0; a < num_actions; ++a)
535  {
536  COST_SENSITIVE::cs_label.default_label(&data.csls[a]);
537  data.csls[a].costs.push_back({0, a + 1, 0, 0});
538  }
539  data.cbls = calloc_or_throw<CB::label>(num_actions);
540 
541  data.ws_train_size = data.ws_period;
542  data.ws_vali_size = 0;
543 
544  data.ws_iter = 0;
545  data.inter_iter = 0;
546 
547  setup_lambdas(data);
548  for (uint32_t i = 0; i < data.choices_lambda; i++) data.cumulative_costs.push_back(0.f);
549  data.cumu_var = 0.f;
550 }
uint32_t choices_lambda
Definition: warm_cb.cc:51
std::vector< float > cumulative_costs
Definition: warm_cb.cc:68
label_parser cs_label
uint32_t num_actions
Definition: warm_cb.cc:64
uint32_t inter_iter
Definition: warm_cb.cc:75
void(* default_label)(void *)
Definition: label_parser.h:12
void setup_lambdas(warm_cb &data)
Definition: warm_cb.cc:202
example * alloc_examples(size_t, size_t count=1)
Definition: example.cc:204
uint32_t ws_period
Definition: warm_cb.cc:49
COST_SENSITIVE::label * csls
Definition: warm_cb.cc:78
bool sim_bandit
Definition: warm_cb.cc:61
CB::label * cbls
Definition: warm_cb.cc:79
int ws_type
Definition: warm_cb.cc:60
uint32_t ws_train_size
Definition: warm_cb.cc:70
label_parser cb_label
Definition: cb.cc:167
constexpr uint64_t a
Definition: rand48.cc:11
#define SUPERVISED_WS
Definition: warm_cb.cc:23
#define BANDIT_WS
Definition: warm_cb.cc:24
multi_ex ecs
Definition: warm_cb.cc:44
v_array< wclass > costs
float cumu_var
Definition: warm_cb.cc:73
uint32_t ws_vali_size
Definition: warm_cb.cc:71
uint32_t ws_iter
Definition: warm_cb.cc:74

◆ learn_bandit_adf()

void learn_bandit_adf ( warm_cb & data,
multi_learner & base,
example & ec,
int  ec_type 
)

Definition at line 410 of file warm_cb.cc.

References a, warm_cb::choices_lambda, warm_cb::cl_adf, compute_weight_multiplier(), copy_example_to_adf(), warm_cb::ecs, LEARNER::learner< T, E >::learn(), and warm_cb::num_actions.

Referenced by predict_or_learn_bandit_adf().

411 {
412  copy_example_to_adf(data, ec);
413 
414  // add cb label to chosen action
415  auto& cl = data.cl_adf;
416  auto& lab = data.ecs[cl.action - 1]->l.cb;
417  lab.costs.push_back(cl);
418 
419  std::vector<float> old_weights;
420  for (size_t a = 0; a < data.num_actions; ++a) old_weights.push_back(data.ecs[a]->weight);
421 
422  for (uint32_t i = 0; i < data.choices_lambda; i++)
423  {
424  float weight_multiplier = compute_weight_multiplier(data, i, ec_type);
425  for (size_t a = 0; a < data.num_actions; ++a) data.ecs[a]->weight = old_weights[a] * weight_multiplier;
426  base.learn(data.ecs, i);
427  }
428 
429  for (size_t a = 0; a < data.num_actions; ++a) data.ecs[a]->weight = old_weights[a];
430 }
uint32_t choices_lambda
Definition: warm_cb.cc:51
uint32_t num_actions
Definition: warm_cb.cc:64
float compute_weight_multiplier(warm_cb &data, size_t i, int ec_type)
Definition: warm_cb.cc:287
CB::cb_class cl_adf
Definition: warm_cb.cc:69
void copy_example_to_adf(warm_cb &data, example &ec)
Definition: warm_cb.cc:165
constexpr uint64_t a
Definition: rand48.cc:11
void learn(E &ec, size_t i=0)
Definition: learner.h:160
multi_ex ecs
Definition: warm_cb.cc:44

◆ learn_sup_adf()

template<bool use_cs>
void learn_sup_adf ( warm_cb & data,
example & ec,
int  ec_type 
)

Definition at line 344 of file warm_cb.cc.

References a, warm_cb::all, LEARNER::as_multiline(), warm_cb::cbls, warm_cb::choices_lambda, compute_weight_multiplier(), copy_example_to_adf(), vw::cost_sensitive, CB::label::costs, COST_SENSITIVE::label::costs, polylabel::cs, warm_cb::csls, warm_cb::ecs, example::l, MULTICLASS::label_t::label, LEARNER::learner< T, E >::learn(), loss(), loss_cs(), polylabel::multi, and warm_cb::num_actions.

345 {
346  copy_example_to_adf(data, ec);
347  // generate cost-sensitive label (for cost-sensitive learner's temporary use)
348  auto& csls = data.csls;
349  auto& cbls = data.cbls;
350  for (uint32_t a = 0; a < data.num_actions; ++a)
351  {
352  csls[a].costs[0].class_index = a + 1;
353  if (use_cs)
354  csls[a].costs[0].x = loss_cs(data, ec.l.cs.costs, a + 1);
355  else
356  csls[a].costs[0].x = loss(data, ec.l.multi.label, a + 1);
357  }
358  for (size_t a = 0; a < data.num_actions; ++a)
359  {
360  cbls[a] = data.ecs[a]->l.cb;
361  data.ecs[a]->l.cs = csls[a];
362  }
363 
364  std::vector<float> old_weights;
365  for (size_t a = 0; a < data.num_actions; ++a) old_weights.push_back(data.ecs[a]->weight);
366 
367  for (uint32_t i = 0; i < data.choices_lambda; i++)
368  {
369  float weight_multiplier = compute_weight_multiplier(data, i, ec_type);
370  for (size_t a = 0; a < data.num_actions; ++a) data.ecs[a]->weight = old_weights[a] * weight_multiplier;
371  multi_learner* cs_learner = as_multiline(data.all->cost_sensitive);
372  cs_learner->learn(data.ecs, i);
373  }
374 
375  for (size_t a = 0; a < data.num_actions; ++a) data.ecs[a]->weight = old_weights[a];
376 
377  for (size_t a = 0; a < data.num_actions; ++a) data.ecs[a]->l.cb = cbls[a];
378 }
uint32_t choices_lambda
Definition: warm_cb.cc:51
LEARNER::base_learner * cost_sensitive
Definition: global_data.h:385
uint32_t num_actions
Definition: warm_cb.cc:64
v_array< cb_class > costs
Definition: cb.h:27
MULTICLASS::label_t multi
Definition: example.h:29
float compute_weight_multiplier(warm_cb &data, size_t i, int ec_type)
Definition: warm_cb.cc:287
COST_SENSITIVE::label cs
Definition: example.h:30
COST_SENSITIVE::label * csls
Definition: warm_cb.cc:78
void copy_example_to_adf(warm_cb &data, example &ec)
Definition: warm_cb.cc:165
CB::label * cbls
Definition: warm_cb.cc:79
polylabel l
Definition: example.h:57
constexpr uint64_t a
Definition: rand48.cc:11
float loss(warm_cb &data, uint32_t label, uint32_t final_prediction)
Definition: warm_cb.cc:113
float loss_cs(warm_cb &data, v_array< COST_SENSITIVE::wclass > &costs, uint32_t final_prediction)
Definition: warm_cb.cc:121
void learn(E &ec, size_t i=0)
Definition: learner.h:160
multi_ex ecs
Definition: warm_cb.cc:44
v_array< wclass > costs
vw * all
Definition: warm_cb.cc:42
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468

◆ loss()

float loss ( warm_cb & data,
uint32_t  label,
uint32_t  final_prediction 
)

Definition at line 113 of file warm_cb.cc.

References warm_cb::loss0, and warm_cb::loss1.

Referenced by learn_sup_adf(), and predict_or_learn_bandit_adf().

114 {
115  if (label != final_prediction)
116  return data.loss1;
117  else
118  return data.loss0;
119 }
float loss1
Definition: warm_cb.cc:46
float loss0
Definition: warm_cb.cc:45

◆ loss_cs()

float loss_cs ( warm_cb & data,
v_array< COST_SENSITIVE::wclass > &  costs,
uint32_t  final_prediction 
)

Definition at line 121 of file warm_cb.cc.

References warm_cb::loss0, and warm_cb::loss1.

Referenced by learn_sup_adf(), and predict_or_learn_bandit_adf().

122 {
123  float cost = 0.;
124  for (auto wc : costs)
125  {
126  if (wc.class_index == final_prediction)
127  {
128  cost = wc.x;
129  break;
130  }
131  }
132  return data.loss0 + (data.loss1 - data.loss0) * cost;
133 }
float loss1
Definition: warm_cb.cc:46
float loss0
Definition: warm_cb.cc:45

◆ minimax_lambda()

float minimax_lambda ( float  epsilon)

Definition at line 200 of file warm_cb.cc.

Referenced by setup_lambdas().

200 { return epsilon / (1.0f + epsilon); }

◆ predict_bandit_adf()

uint32_t predict_bandit_adf ( warm_cb & data,
multi_learner & base,
example & ec 
)

Definition at line 391 of file warm_cb.cc.

References warm_cb::a_s_adf, warm_cb::app_seed, ACTION_SCORE::begin_scores(), copy_example_to_adf(), warm_cb::cumulative_costs, warm_cb::ecs, ACTION_SCORE::end_scores(), warm_cb::example_counter, find_min(), LEARNER::learner< T, E >::predict(), exploration::sample_after_normalizing(), and THROW.

Referenced by predict_or_learn_bandit_adf().

392 {
393  uint32_t argmin = find_min(data.cumulative_costs);
394 
395  copy_example_to_adf(data, ec);
396  base.predict(data.ecs, argmin);
397 
398  auto& out_ec = *data.ecs[0];
399  uint32_t chosen_action;
400  if (sample_after_normalizing(data.app_seed + data.example_counter++, begin_scores(out_ec.pred.a_s),
401  end_scores(out_ec.pred.a_s), chosen_action))
402  THROW("Failed to sample from pdf");
403 
404  auto& a_s = data.a_s_adf;
405  copy_array<action_score>(a_s, out_ec.pred.a_s);
406 
407  return chosen_action;
408 }
std::vector< float > cumulative_costs
Definition: warm_cb.cc:68
void predict(E &ec, size_t i=0)
Definition: learner.h:169
int sample_after_normalizing(uint64_t seed, It pdf_first, It pdf_last, uint32_t &chosen_index)
Sample an index from the provided pdf. If the pdf is not normalized it will be updated in-place...
score_iterator begin_scores(action_scores &a_s)
Definition: action_score.h:43
action_scores a_s_adf
Definition: warm_cb.cc:67
uint64_t app_seed
Definition: warm_cb.cc:38
score_iterator end_scores(action_scores &a_s)
Definition: action_score.h:45
void copy_example_to_adf(warm_cb &data, example &ec)
Definition: warm_cb.cc:165
uint32_t find_min(std::vector< T > arr)
Definition: warm_cb.cc:136
multi_ex ecs
Definition: warm_cb.cc:44
#define THROW(args)
Definition: vw_exception.h:181
size_t example_counter
Definition: warm_cb.cc:41

◆ predict_or_learn_adf()

template<bool is_learn, bool use_cs>
void predict_or_learn_adf ( warm_cb & data,
multi_learner & base,
example & ec 
)

Definition at line 472 of file warm_cb.cc.

References accumu_var_adf(), BANDIT_WS, corrupt_action(), polylabel::cs, warm_cb::cs_label, warm_cb::inter_iter, warm_cb::inter_period, INTERACTION, example::l, MULTICLASS::label_t::label, warm_cb::mc_label, polylabel::multi, polyprediction::multiclass, example::pred, SUPERVISED_WS, WARM_START, example::weight, warm_cb::ws_iter, warm_cb::ws_period, and warm_cb::ws_type.

473 {
474  // Corrupt labels (only corrupting multiclass labels as of now)
475  if (use_cs)
476  data.cs_label = ec.l.cs;
477  else
478  {
479  data.mc_label = ec.l.multi;
480  if (data.ws_iter < data.ws_period)
481  ec.l.multi.label = corrupt_action(data, data.mc_label.label, WARM_START);
482  }
483 
484  // Warm start phase
485  if (data.ws_iter < data.ws_period)
486  {
487  if (data.ws_type == SUPERVISED_WS)
488  predict_or_learn_sup_adf<use_cs>(data, base, ec, WARM_START);
489  else if (data.ws_type == BANDIT_WS)
490  predict_or_learn_bandit_adf<use_cs>(data, base, ec, WARM_START);
491 
492  ec.weight = 0;
493  data.ws_iter++;
494  }
495  // Interaction phase
496  else if (data.inter_iter < data.inter_period)
497  {
498  predict_or_learn_bandit_adf<use_cs>(data, base, ec, INTERACTION);
499  accumu_var_adf(data, base, ec);
500  data.a_s_adf.clear();
501  data.inter_iter++;
502  }
503  // Skipping the rest of the examples
504  else
505  {
506  ec.weight = 0;
507  ec.pred.multiclass = 1;
508  }
509 
510  // Restore the original labels
511  if (use_cs)
512  ec.l.cs = data.cs_label;
513  else
514  ec.l.multi = data.mc_label;
515 }
#define WARM_START
Definition: warm_cb.cc:19
MULTICLASS::label_t multi
Definition: example.h:29
COST_SENSITIVE::label cs
Definition: example.h:30
#define INTERACTION
Definition: warm_cb.cc:20
uint32_t ws_period
Definition: warm_cb.cc:49
int ws_type
Definition: warm_cb.cc:60
void accumu_var_adf(warm_cb &data, multi_learner &base, example &ec)
Definition: warm_cb.cc:459
polylabel l
Definition: example.h:57
#define SUPERVISED_WS
Definition: warm_cb.cc:23
COST_SENSITIVE::label cs_label
Definition: warm_cb.cc:77
MULTICLASS::label_t mc_label
Definition: warm_cb.cc:76
#define BANDIT_WS
Definition: warm_cb.cc:24
uint32_t corrupt_action(warm_cb &data, uint32_t action, int ec_type)
Definition: warm_cb.cc:252
uint32_t ws_iter
Definition: warm_cb.cc:74

◆ predict_or_learn_bandit_adf()

template<bool use_cs>
void predict_or_learn_bandit_adf ( warm_cb & data,
multi_learner & base,
example & ec,
int  ec_type 
)

Definition at line 433 of file warm_cb.cc.

References warm_cb::a_s_adf, accumu_costs_iv_adf(), warm_cb::cl_adf, COST_SENSITIVE::label::costs, polylabel::cs, ind_update(), INTERACTION, example::l, MULTICLASS::label_t::label, learn_bandit_adf(), loss(), loss_cs(), polylabel::multi, polyprediction::multiclass, example::pred, predict_bandit_adf(), and THROW.

434 {
435  uint32_t chosen_action = predict_bandit_adf(data, base, ec);
436 
437  auto& cl = data.cl_adf;
438  auto& a_s = data.a_s_adf;
439  cl.action = a_s[chosen_action].action + 1;
440  cl.probability = a_s[chosen_action].score;
441 
442  if (!cl.action)
443  THROW("No action with non-zero probability found!");
444 
445  if (use_cs)
446  cl.cost = loss_cs(data, ec.l.cs.costs, cl.action);
447  else
448  cl.cost = loss(data, ec.l.multi.label, cl.action);
449 
450  if (ec_type == INTERACTION)
451  accumu_costs_iv_adf(data, base, ec);
452 
453  if (ind_update(data, ec_type))
454  learn_bandit_adf(data, base, ec, ec_type);
455 
456  ec.pred.multiclass = cl.action;
457 }
uint32_t multiclass
Definition: example.h:49
void accumu_costs_iv_adf(warm_cb &data, multi_learner &base, example &ec)
Definition: warm_cb.cc:310
void learn_bandit_adf(warm_cb &data, multi_learner &base, example &ec, int ec_type)
Definition: warm_cb.cc:410
action_scores a_s_adf
Definition: warm_cb.cc:67
MULTICLASS::label_t multi
Definition: example.h:29
uint32_t predict_bandit_adf(warm_cb &data, multi_learner &base, example &ec)
Definition: warm_cb.cc:391
COST_SENSITIVE::label cs
Definition: example.h:30
#define INTERACTION
Definition: warm_cb.cc:20
CB::cb_class cl_adf
Definition: warm_cb.cc:69
polylabel l
Definition: example.h:57
float loss(warm_cb &data, uint32_t label, uint32_t final_prediction)
Definition: warm_cb.cc:113
polyprediction pred
Definition: example.h:60
float loss_cs(warm_cb &data, v_array< COST_SENSITIVE::wclass > &costs, uint32_t final_prediction)
Definition: warm_cb.cc:121
v_array< wclass > costs
bool ind_update(warm_cb &data, int ec_type)
Definition: warm_cb.cc:279
#define THROW(args)
Definition: vw_exception.h:181

◆ predict_or_learn_sup_adf()

template<bool use_cs>
void predict_or_learn_sup_adf ( warm_cb & data,
multi_learner & base,
example & ec,
int  ec_type 
)

Definition at line 381 of file warm_cb.cc.

References ind_update(), and predict_sup_adf().

382 {
383  uint32_t action = predict_sup_adf(data, base, ec);
384 
385  if (ind_update(data, ec_type))
386  learn_sup_adf<use_cs>(data, ec, ec_type);
387 
388  ec.pred.multiclass = action;
389 }
uint32_t action
Definition: search.h:19
uint32_t predict_sup_adf(warm_cb &data, multi_learner &base, example &ec)
Definition: warm_cb.cc:337
bool ind_update(warm_cb &data, int ec_type)
Definition: warm_cb.cc:279

◆ predict_sublearner_adf()

uint32_t predict_sublearner_adf ( warm_cb & data,
multi_learner & base,
example & ec,
uint32_t  i 
)

Definition at line 303 of file warm_cb.cc.

References copy_example_to_adf(), warm_cb::ecs, and LEARNER::learner< T, E >::predict().

Referenced by accumu_costs_iv_adf(), and predict_sup_adf().

304 {
305  copy_example_to_adf(data, ec);
306  base.predict(data.ecs, i);
307  return data.ecs[0]->pred.a_s[0].action + 1;
308 }
void predict(E &ec, size_t i=0)
Definition: learner.h:169
void copy_example_to_adf(warm_cb &data, example &ec)
Definition: warm_cb.cc:165
multi_ex ecs
Definition: warm_cb.cc:44

◆ predict_sup_adf()

uint32_t predict_sup_adf ( warm_cb & data,
multi_learner & base,
example & ec 
)

Definition at line 337 of file warm_cb.cc.

References warm_cb::cumulative_costs, find_min(), and predict_sublearner_adf().

Referenced by accumu_var_adf(), and predict_or_learn_sup_adf().

338 {
339  uint32_t argmin = find_min(data.cumulative_costs);
340  return predict_sublearner_adf(data, base, ec, argmin);
341 }
std::vector< float > cumulative_costs
Definition: warm_cb.cc:68
uint32_t predict_sublearner_adf(warm_cb &data, multi_learner &base, example &ec, uint32_t i)
Definition: warm_cb.cc:303
uint32_t find_min(std::vector< T > arr)
Definition: warm_cb.cc:136

◆ setup_lambdas()

void setup_lambdas ( warm_cb & data)

Definition at line 202 of file warm_cb.cc.

References ABS_CENTRAL, ABS_CENTRAL_ZEROONE, warm_cb::choices_lambda, warm_cb::epsilon, f, warm_cb::lambda_scheme, warm_cb::lambdas, MINIMAX_CENTRAL_ZEROONE, minimax_lambda(), warm_cb::upd_inter, and warm_cb::upd_ws.

Referenced by init_adf_data().

203 {
204  // The lambdas are arranged in ascending order
205  std::vector<float>& lambdas = data.lambdas;
206  for (uint32_t i = 0; i < data.choices_lambda; i++) lambdas.push_back(0.f);
207 
208  // interaction only: set all lambda's to be identically 1
209  if (!data.upd_ws && data.upd_inter)
210  {
211  for (uint32_t i = 0; i < data.choices_lambda; i++) lambdas[i] = 1.0;
212  return;
213  }
214 
215  // warm start only: set all lambda's to be identically 0
216  if (!data.upd_inter && data.upd_ws)
217  {
218  for (uint32_t i = 0; i < data.choices_lambda; i++) lambdas[i] = 0.0;
219  return;
220  }
221 
222  uint32_t mid = data.choices_lambda / 2;
223 
224  if (data.lambda_scheme == ABS_CENTRAL || data.lambda_scheme == ABS_CENTRAL_ZEROONE)
225  lambdas[mid] = 0.5;
226  else
227  lambdas[mid] = minimax_lambda(data.epsilon);
228 
229  for (uint32_t i = mid; i > 0; i--) lambdas[i - 1] = lambdas[i] / 2.0f;
230 
231  for (uint32_t i = mid + 1; i < data.choices_lambda; i++) lambdas[i] = 1.f - (1.f - lambdas[i - 1]) / 2.0f;
232 
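233  if (data.lambda_scheme == ABS_CENTRAL_ZEROONE || data.lambda_scheme == MINIMAX_CENTRAL_ZEROONE)
234  {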
235  lambdas[0] = 0.0;
236  lambdas[data.choices_lambda - 1] = 1.0;
237  }
238 }
uint32_t choices_lambda
Definition: warm_cb.cc:51
float epsilon
Definition: warm_cb.cc:65
#define ABS_CENTRAL_ZEROONE
Definition: warm_cb.cc:31
std::vector< float > lambdas
Definition: warm_cb.cc:66
float minimax_lambda(float epsilon)
Definition: warm_cb.cc:200
#define MINIMAX_CENTRAL_ZEROONE
Definition: warm_cb.cc:33
bool upd_ws
Definition: warm_cb.cc:52
bool upd_inter
Definition: warm_cb.cc:53
#define ABS_CENTRAL
Definition: warm_cb.cc:30
int lambda_scheme
Definition: warm_cb.cc:58
float f
Definition: cache.cc:40

◆ warm_cb_setup()

base_learner* warm_cb_setup ( options_i & options,
vw & all 
)

Definition at line 552 of file warm_cb.cc.

References ABS_CENTRAL, VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), LEARNER::as_multiline(), vw::delete_prediction, f, finish(), vw::get_random_state(), init_adf_data(), LEARNER::init_cost_sensitive_learner(), LEARNER::init_multiclass_learner(), VW::config::options_i::insert(), LEARNER::make_base(), VW::config::make_option(), vw::p, LEARNER::learner< T, E >::set_finish(), setup_base(), THROW, prediction_type::to_string(), UAR, uniform_hash(), and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

553 {
554  uint32_t num_actions = 0;
555  auto data = scoped_calloc_or_throw<warm_cb>();
556  bool use_cs;
557 
558  option_group_definition new_options("Make Multiclass into Warm-starting Contextual Bandit");
559 
560  new_options
561  .add(make_option("warm_cb", num_actions)
562  .keep()
563  .help("Convert multiclass on <k> classes into a contextual bandit problem"))
564  .add(make_option("warm_cb_cs", use_cs)
565  .help("consume cost-sensitive classification examples instead of multiclass"))
566  .add(make_option("loss0", data->loss0).default_value(0.f).help("loss for correct label"))
567  .add(make_option("loss1", data->loss1).default_value(1.f).help("loss for incorrect label"))
568  .add(make_option("warm_start", data->ws_period)
569  .default_value(0U)
570  .help("number of training examples for warm start phase"))
571  .add(make_option("epsilon", data->epsilon).keep().help("epsilon-greedy exploration"))
572  .add(make_option("interaction", data->inter_period)
573  .default_value(UINT32_MAX)
574  .help("number of examples for the interactive contextual bandit learning phase"))
575  .add(make_option("warm_start_update", data->upd_ws).help("indicator of warm start updates"))
576  .add(make_option("interaction_update", data->upd_inter).help("indicator of interaction updates"))
577  .add(make_option("corrupt_type_warm_start", data->cor_type_ws)
578  .default_value(UAR)
579  .help("type of label corruption in the warm start phase (1: uniformly at random, 2: circular, 3: "
580  "replacing with overwriting label)"))
581  .add(make_option("corrupt_prob_warm_start", data->cor_prob_ws)
582  .default_value(0.f)
583  .help("probability of label corruption in the warm start phase"))
584  .add(make_option("choices_lambda", data->choices_lambda)
585  .default_value(1U)
586  .help("the number of candidate lambdas to aggregate (lambda is the importance weight parameter between "
587  "the two sources)"))
588  .add(make_option("lambda_scheme", data->lambda_scheme)
589  .default_value(ABS_CENTRAL)
590  .help("The scheme for generating candidate lambda set (1: center lambda=0.5, 2: center lambda=0.5, min "
591  "lambda=0, max lambda=1, 3: center lambda=epsilon/(1+epsilon), 4: center "
592  "lambda=epsilon/(1+epsilon), min lambda=0, max lambda=1); the rest of candidate lambda values are "
593  "generated using a doubling scheme"))
594  .add(make_option("overwrite_label", data->overwrite_label)
595  .default_value(1U)
596  .help("the label used by type 3 corruptions (overwriting)"))
597  .add(make_option("sim_bandit", data->sim_bandit)
598  .help("simulate contextual bandit updates on warm start examples"));
599 
600  options.add_and_parse(new_options);
601 
602  if (use_cs && (options.was_supplied("corrupt_type_warm_start") || options.was_supplied("corrupt_prob_warm_start")))
603  {
604  THROW("label corruption on cost-sensitive examples not currently supported");
605  }
606 
607  if (!options.was_supplied("warm_cb"))
608  {
609  return nullptr;
610  }
611 
612  data->app_seed = uniform_hash("vw", 2, 0);
613  data->a_s = v_init<action_score>();
614  data->all = &all;
615  data->_random_state = all.get_random_state();
616  data->use_cs = use_cs;
617 
618  init_adf_data(*data.get(), num_actions);
619 
620  options.insert("cb_min_cost", std::to_string(data->loss0));
621  options.insert("cb_max_cost", std::to_string(data->loss1));
622 
623  if (options.was_supplied("baseline"))
624  {
625  std::stringstream ss;
626  ss << std::max(std::abs(data->loss0), std::abs(data->loss1)) / (data->loss1 - data->loss0);
627  options.insert("lr_multiplier", ss.str());
628  }
629 
630  learner<warm_cb, example>* l;
631 
632  multi_learner* base = as_multiline(setup_base(options, all));
633  // Note: the current version of warm start CB can only support epsilon-greedy exploration
634  // We need to wait for the epsilon value to be passed from the base
635  // cb_explore learner, if there is one
636 
637  if (!options.was_supplied("epsilon"))
638  {
639  std::cerr << "Warning: no epsilon (greedy parameter) specified; resetting to 0.05" << std::endl;
640  data->epsilon = 0.05f;
641  }
642 
643  if (use_cs)
644  l = &init_cost_sensitive_learner(
645  data, base, predict_or_learn_adf<true, true>, predict_or_learn_adf<false, true>, all.p, data->choices_lambda);
646  else
647  l = &init_multiclass_learner(
648  data, base, predict_or_learn_adf<true, false>, predict_or_learn_adf<false, false>, all.p, data->choices_lambda);
649 
650  l->set_finish(finish);
651  all.delete_prediction = nullptr;
652 
653  return make_base(*l);
654 }
void(* delete_prediction)(void *)
Definition: global_data.h:485
#define UAR
Definition: warm_cb.cc:26
learner< T, E > & init_cost_sensitive_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:450
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
Definition: hash.h:67
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
void finish(warm_cb &data)
Definition: warm_cb.cc:152
virtual void add_and_parse(const option_group_definition &group)=0
parser * p
Definition: global_data.h:377
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
#define ABS_CENTRAL
Definition: warm_cb.cc:30
virtual bool was_supplied(const std::string &key)=0
virtual void insert(const std::string &key, const std::string &value)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void set_finish(void(*f)(T &))
Definition: learner.h:265
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define THROW(args)
Definition: vw_exception.h:181
void init_adf_data(warm_cb &data, const uint32_t num_actions)
Definition: warm_cb.cc:517
float f
Definition: cache.cc:40
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
const char * to_string(prediction_type_t prediction_type)
Definition: learner.cc:12