Vowpal Wabbit
boosting.cc
Go to the documentation of this file.
1 /*
2  Copyright (c) by respective owners including Yahoo!, Microsoft, and
3  individual contributors. All rights reserved. Released under a BSD (revised)
4  license as described in the file LICENSE.
5  */
6 
7 /*
8  * Implementation of online boosting algorithms from
9  * Beygelzimer, Kale, Luo: Optimal and adaptive algorithms for online boosting,
10  * ICML-2015.
11  */
12 
13 #include <float.h>
14 #include <limits.h>
15 #include <math.h>
16 #include "correctedMath.h"
17 #include <stdio.h>
18 #include <string>
19 #include <sstream>
20 #include <vector>
21 #include <memory>
22 
23 #include "reductions.h"
24 #include "vw.h"
25 #include "rand48.h"
26 
27 using namespace LEARNER;
28 using namespace VW::config;
29 
30 using std::cerr;
31 using std::endl;
32 
// Collapses a real-valued score to a {-1, +1} label.
// Zero and every negative value map to -1; everything else maps to +1.
inline float sign(float w) { return (w <= 0.) ? -1.f : 1.f; }
40 
// Binomial coefficient C(n, k) computed with exact 64-bit integer arithmetic.
//
// Returns 0 when k is negative or larger than n (this also covers negative n),
// and 1 for k == 0 or k == n.
//
// The running product is built as r = r * n / d, which is exact at every step
// because r always holds C(n0, d) after the division.  The product is taken
// over min(k, n - k) factors: C(n, k) == C(n, n - k), and walking the shorter
// side keeps the intermediate value far smaller, so inputs such as
// choose(62, 60) no longer overflow on the way to a small result.
int64_t choose(int64_t n, int64_t k)
{
  if (k < 0 || k > n)
    return 0;

  // Use the symmetric, shorter side of Pascal's triangle.
  if (k > n - k)
    k = n - k;

  int64_t r = 1;
  for (int64_t d = 1; d <= k; ++d)
  {
    r *= n--;
    r /= d;
  }
  return r;
}
59 
60 struct boosting
61 {
62  int N;
63  float gamma;
64  std::string alg;
65  vw* all;
66  std::shared_ptr<rand_state> _random_state;
67  std::vector<std::vector<int64_t> > C;
68  std::vector<float> alpha;
69  std::vector<float> v;
70  int t;
71 };
72 
73 //---------------------------------------------------
74 // Online Boost-by-Majority (BBM)
75 // --------------------------------------------------
76 template <bool is_learn>
78 {
79  label_data& ld = ec.l.simple;
80 
81  float final_prediction = 0;
82 
83  float s = 0;
84  float u = ec.weight;
85 
86  if (is_learn)
87  o.t++;
88 
89  for (int i = 0; i < o.N; i++)
90  {
91  if (is_learn)
92  {
93  float k = floorf((float)(o.N - i - s) / 2);
94  int64_t c;
95  if (o.N - (i + 1) < 0)
96  c = 0;
97  else if (k > o.N - (i + 1))
98  c = 0;
99  else if (k < 0)
100  c = 0;
101  else if (o.C[o.N - (i + 1)][(int64_t)k] != -1)
102  c = o.C[o.N - (i + 1)][(int64_t)k];
103  else
104  {
105  c = choose(o.N - (i + 1), (int64_t)k);
106  o.C[o.N - (i + 1)][(int64_t)k] = c;
107  }
108 
109  float w = c * (float)pow((double)(0.5 + o.gamma), (double)k) *
110  (float)pow((double)0.5 - o.gamma, (double)(o.N - (i + 1) - k));
111 
112  // update ec.weight, weight for learner i (starting from 0)
113  ec.weight = u * w;
114 
115  base.predict(ec, i);
116 
117  // ec.pred.scalar is now the i-th learner prediction on this example
118  s += ld.label * ec.pred.scalar;
119 
120  final_prediction += ec.pred.scalar;
121 
122  base.learn(ec, i);
123  }
124  else
125  {
126  base.predict(ec, i);
127  final_prediction += ec.pred.scalar;
128  }
129  }
130 
131  ec.weight = u;
132  ec.partial_prediction = final_prediction;
133  ec.pred.scalar = sign(final_prediction);
134 
135  if (ld.label == ec.pred.scalar)
136  ec.loss = 0.;
137  else
138  ec.loss = ec.weight;
139 }
140 
141 //-----------------------------------------------------------------
142 // Logistic boost
143 //-----------------------------------------------------------------
144 template <bool is_learn>
146 {
147  label_data& ld = ec.l.simple;
148 
149  float final_prediction = 0;
150 
151  float s = 0;
152  float u = ec.weight;
153 
154  if (is_learn)
155  o.t++;
156  float eta = 4.f / sqrtf((float)o.t);
157 
158  for (int i = 0; i < o.N; i++)
159  {
160  if (is_learn)
161  {
162  float w = 1 / (1 + correctedExp(s));
163 
164  ec.weight = u * w;
165 
166  base.predict(ec, i);
167  float z;
168  z = ld.label * ec.pred.scalar;
169 
170  s += z * o.alpha[i];
171 
172  // if ld.label * ec.pred.scalar < 0, learner i made a mistake
173 
174  final_prediction += ec.pred.scalar * o.alpha[i];
175 
176  // update alpha
177  o.alpha[i] += eta * z / (1 + correctedExp(s));
178  if (o.alpha[i] > 2.)
179  o.alpha[i] = 2;
180  if (o.alpha[i] < -2.)
181  o.alpha[i] = -2;
182 
183  base.learn(ec, i);
184  }
185  else
186  {
187  base.predict(ec, i);
188  final_prediction += ec.pred.scalar * o.alpha[i];
189  }
190  }
191 
192  ec.weight = u;
193  ec.partial_prediction = final_prediction;
194  ec.pred.scalar = sign(final_prediction);
195 
196  if (ld.label == ec.pred.scalar)
197  ec.loss = 0.;
198  else
199  ec.loss = ec.weight;
200 }
201 
202 template <bool is_learn>
204 {
205  label_data& ld = ec.l.simple;
206 
207  float final_prediction = 0, partial_prediction = 0;
208 
209  float s = 0;
210  float v_normalization = 0, v_partial_sum = 0;
211  float u = ec.weight;
212 
213  if (is_learn)
214  o.t++;
215  float eta = 4.f / (float)sqrtf((float)o.t);
216 
217  float stopping_point = o._random_state->get_and_update_random();
218 
219  for (int i = 0; i < o.N; i++)
220  {
221  if (is_learn)
222  {
223  float w = 1 / (1 + correctedExp(s));
224 
225  ec.weight = u * w;
226 
227  base.predict(ec, i);
228  float z;
229 
230  z = ld.label * ec.pred.scalar;
231 
232  s += z * o.alpha[i];
233 
234  if (v_partial_sum <= stopping_point)
235  {
236  final_prediction += ec.pred.scalar * o.alpha[i];
237  }
238 
239  partial_prediction += ec.pred.scalar * o.alpha[i];
240 
241  v_partial_sum += o.v[i];
242 
243  // update v, exp(-1) = 0.36788
244  if (ld.label * partial_prediction < 0)
245  {
246  o.v[i] *= 0.36788f;
247  }
248  v_normalization += o.v[i];
249 
250  // update alpha
251  o.alpha[i] += eta * z / (1 + correctedExp(s));
252  if (o.alpha[i] > 2.)
253  o.alpha[i] = 2;
254  if (o.alpha[i] < -2.)
255  o.alpha[i] = -2;
256 
257  base.learn(ec, i);
258  }
259  else
260  {
261  base.predict(ec, i);
262  if (v_partial_sum <= stopping_point)
263  {
264  final_prediction += ec.pred.scalar * o.alpha[i];
265  }
266  else
267  {
268  // stopping at learner i
269  break;
270  }
271  v_partial_sum += o.v[i];
272  }
273  }
274 
275  // normalize v vector in training
276  if (is_learn)
277  {
278  for (int i = 0; i < o.N; i++)
279  {
280  if (v_normalization)
281  o.v[i] /= v_normalization;
282  }
283  }
284 
285  ec.weight = u;
286  ec.partial_prediction = final_prediction;
287  ec.pred.scalar = sign(final_prediction);
288 
289  if (ld.label == ec.pred.scalar)
290  ec.loss = 0.;
291  else
292  ec.loss = ec.weight;
293 }
294 
295 void save_load_sampling(boosting& o, io_buf& model_file, bool read, bool text)
296 {
297  if (model_file.files.size() == 0)
298  return;
299  std::stringstream os;
300  os << "boosts " << o.N << endl;
301  bin_text_read_write_fixed(model_file, (char*)&(o.N), sizeof(o.N), "", read, os, text);
302 
303  if (read)
304  {
305  o.alpha.resize(o.N);
306  o.v.resize(o.N);
307  }
308 
309  for (int i = 0; i < o.N; i++)
310  if (read)
311  {
312  float f;
313  model_file.bin_read_fixed((char*)&f, sizeof(f), "");
314  o.alpha[i] = f;
315  }
316  else
317  {
318  std::stringstream os2;
319  os2 << "alpha " << o.alpha[i] << endl;
320  bin_text_write_fixed(model_file, (char*)&(o.alpha[i]), sizeof(o.alpha[i]), os2, text);
321  }
322 
323  for (int i = 0; i < o.N; i++)
324  if (read)
325  {
326  float f;
327  model_file.bin_read_fixed((char*)&f, sizeof(f), "");
328  o.v[i] = f;
329  }
330  else
331  {
332  std::stringstream os2;
333  os2 << "v " << o.v[i] << endl;
334  bin_text_write_fixed(model_file, (char*)&(o.v[i]), sizeof(o.v[i]), os2, text);
335  }
336 
337  if (read)
338  {
339  cerr << "Loading alpha and v: " << endl;
340  }
341  else
342  {
343  cerr << "Saving alpha and v, current weighted_examples = "
345  }
346  for (int i = 0; i < o.N; i++)
347  {
348  cerr << o.alpha[i] << " " << o.v[i] << endl;
349  }
350  cerr << endl;
351 }
352 
353 void return_example(vw& all, boosting& /* a */, example& ec)
354 {
356  VW::finish_example(all, ec);
357 }
358 
359 void save_load(boosting& o, io_buf& model_file, bool read, bool text)
360 {
361  if (model_file.files.size() == 0)
362  return;
363  std::stringstream os;
364  os << "boosts " << o.N << endl;
365  bin_text_read_write_fixed(model_file, (char*)&(o.N), sizeof(o.N), "", read, os, text);
366 
367  if (read)
368  o.alpha.resize(o.N);
369 
370  for (int i = 0; i < o.N; i++)
371  if (read)
372  {
373  float f;
374  model_file.bin_read_fixed((char*)&f, sizeof(f), "");
375  o.alpha[i] = f;
376  }
377  else
378  {
379  std::stringstream os2;
380  os2 << "alpha " << o.alpha[i] << endl;
381  bin_text_write_fixed(model_file, (char*)&(o.alpha[i]), sizeof(o.alpha[i]), os2, text);
382  }
383 
384  if (!o.all->quiet)
385  {
386  if (read)
387  cerr << "Loading alpha: " << endl;
388  else
389  cerr << "Saving alpha, current weighted_examples = " << o.all->sd->weighted_examples() << endl;
390  for (int i = 0; i < o.N; i++) cerr << o.alpha[i] << " " << endl;
391 
392  cerr << endl;
393  }
394 }
395 
397 {
398  free_ptr<boosting> data = scoped_calloc_or_throw<boosting>();
399  option_group_definition new_options("Boosting");
400  new_options.add(make_option("boosting", data->N).keep().help("Online boosting with <N> weak learners"))
401  .add(make_option("gamma", data->gamma)
402  .default_value(0.1f)
403  .help("weak learner's edge (=0.1), used only by online BBM"))
404  .add(
405  make_option("alg", data->alg)
406  .keep()
407  .default_value("BBM")
408  .help("specify the boosting algorithm: BBM (default), logistic (AdaBoost.OL.W), adaptive (AdaBoost.OL)"));
409  options.add_and_parse(new_options);
410 
411  if (!options.was_supplied("boosting"))
412  return nullptr;
413 
414  // Description of options:
415  // "BBM" implements online BBM (Algorithm 1 in BLK'15)
416  // "logistic" implements AdaBoost.OL.W (importance weighted version
417  // of Algorithm 2 in BLK'15)
418  // "adaptive" implements AdaBoost.OL (Algorithm 2 in BLK'15,
419  // using sampling rather than importance weighting)
420 
421  if (!all.quiet)
422  cerr << "Number of weak learners = " << data->N << endl;
423  if (!all.quiet)
424  cerr << "Gamma = " << data->gamma << endl;
425 
426  data->C = std::vector<std::vector<int64_t> >(data->N, std::vector<int64_t>(data->N, -1));
427  data->t = 0;
428  data->all = &all;
429  data->_random_state = all.get_random_state();
430  data->alpha = std::vector<float>(data->N, 0);
431  data->v = std::vector<float>(data->N, 1);
432 
434  if (data->alg == "BBM")
435  l = &init_learner<boosting, example>(
436  data, as_singleline(setup_base(options, all)), predict_or_learn<true>, predict_or_learn<false>, data->N);
437  else if (data->alg == "logistic")
438  {
439  l = &init_learner<boosting, example>(data, as_singleline(setup_base(options, all)), predict_or_learn_logistic<true>,
440  predict_or_learn_logistic<false>, data->N);
442  }
443  else if (data->alg == "adaptive")
444  {
445  l = &init_learner<boosting, example>(data, as_singleline(setup_base(options, all)), predict_or_learn_adaptive<true>,
446  predict_or_learn_adaptive<false>, data->N);
447  l->set_save_load(save_load_sampling);
448  }
449  else
450  THROW("Unrecognized boosting algorithm: \'" << data->alg << "\' Bailing!");
451 
452  l->set_finish_example(return_example);
453 
454  return make_base(*l);
455 }
#define correctedExp
Definition: correctedMath.h:27
void predict(E &ec, size_t i=0)
Definition: learner.h:169
float scalar
Definition: example.h:45
int t
Definition: boosting.cc:70
double weighted_unlabeled_examples
Definition: global_data.h:143
void output_and_account_example(vw &all, active &a, example &ec)
Definition: active.cc:105
void predict_or_learn(boosting &o, LEARNER::single_learner &base, example &ec)
Definition: boosting.cc:77
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
LEARNER::base_learner * boosting_setup(options_i &options, vw &all)
Definition: boosting.cc:396
float partial_prediction
Definition: example.h:68
bool quiet
Definition: global_data.h:487
vw * all
Definition: boosting.cc:65
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
float label
Definition: simple_label.h:14
void predict_or_learn_logistic(boosting &o, LEARNER::single_learner &base, example &ec)
Definition: boosting.cc:145
label_data simple
Definition: example.h:28
void save_load(boosting &o, io_buf &model_file, bool read, bool text)
Definition: boosting.cc:359
size_t size() const
Definition: v_array.h:68
std::string alg
Definition: boosting.cc:64
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
std::unique_ptr< T, free_fn > free_ptr
Definition: memory.h:34
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
size_t bin_text_write_fixed(io_buf &io, char *data, size_t len, std::stringstream &msg, bool text)
Definition: io_buf.h:313
size_t bin_read_fixed(char *data, size_t len, const char *read_message)
Definition: io_buf.h:230
shared_data * sd
Definition: global_data.h:375
void return_example(vw &all, boosting &, example &ec)
Definition: boosting.cc:353
void predict_or_learn_adaptive(boosting &o, LEARNER::single_learner &base, example &ec)
Definition: boosting.cc:203
v_array< int > files
Definition: io_buf.h:64
virtual bool was_supplied(const std::string &key)=0
std::vector< float > alpha
Definition: boosting.cc:68
std::shared_ptr< rand_state > _random_state
Definition: boosting.cc:66
float sign(float w)
Definition: boosting.cc:33
std::vector< std::vector< int64_t > > C
Definition: boosting.cc:67
Definition: io_buf.h:54
void finish_example(vw &, example &)
Definition: parser.cc:881
int64_t choose(int64_t n, int64_t k)
Definition: boosting.cc:41
float loss
Definition: example.h:70
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
polylabel l
Definition: example.h:57
void save_load_sampling(boosting &o, io_buf &model_file, bool read, bool text)
Definition: boosting.cc:295
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
double weighted_labeled_examples
Definition: global_data.h:141
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
polyprediction pred
Definition: example.h:60
std::vector< float > v
Definition: boosting.cc:69
float gamma
Definition: boosting.cc:63
void learn(E &ec, size_t i=0)
Definition: learner.h:160
float weight
Definition: example.h:62
double weighted_examples()
Definition: global_data.h:188
size_t bin_text_read_write_fixed(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
Definition: io_buf.h:326
#define THROW(args)
Definition: vw_exception.h:181
constexpr uint64_t c
Definition: rand48.cc:12
float f
Definition: cache.cc:40
int N
Definition: boosting.cc:62