Vowpal Wabbit
Functions
boosting.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner* boosting_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ boosting_setup()

LEARNER::base_learner* boosting_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 396 of file boosting.cc.

References add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), f, vw::get_random_state(), LEARNER::make_base(), VW::config::make_option(), vw::quiet, return_example(), save_load(), save_load_sampling(), LEARNER::learner< T, E >::set_save_load(), setup_base(), THROW, and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

397 {
398  free_ptr<boosting> data = scoped_calloc_or_throw<boosting>();
399  option_group_definition new_options("Boosting");
400  new_options.add(make_option("boosting", data->N).keep().help("Online boosting with <N> weak learners"))
401  .add(make_option("gamma", data->gamma)
402  .default_value(0.1f)
403  .help("weak learner's edge (=0.1), used only by online BBM"))
404  .add(
405  make_option("alg", data->alg)
406  .keep()
407  .default_value("BBM")
408  .help("specify the boosting algorithm: BBM (default), logistic (AdaBoost.OL.W), adaptive (AdaBoost.OL)"));
409  options.add_and_parse(new_options);
410 
411  if (!options.was_supplied("boosting"))
412  return nullptr;
413 
414  // Description of options:
415  // "BBM" implements online BBM (Algorithm 1 in BLK'15)
416  // "logistic" implements AdaBoost.OL.W (importance weighted version
417  // of Algorithm 2 in BLK'15)
418  // "adaptive" implements AdaBoost.OL (Algorithm 2 in BLK'15,
419  // using sampling rather than importance weighting)
420 
421  if (!all.quiet)
422  cerr << "Number of weak learners = " << data->N << endl;
423  if (!all.quiet)
424  cerr << "Gamma = " << data->gamma << endl;
425 
426  data->C = std::vector<std::vector<int64_t> >(data->N, std::vector<int64_t>(data->N, -1));
427  data->t = 0;
428  data->all = &all;
429  data->_random_state = all.get_random_state();
430  data->alpha = std::vector<float>(data->N, 0);
431  data->v = std::vector<float>(data->N, 1);
432 
433  learner<boosting, example>* l;
434  if (data->alg == "BBM")
435  l = &init_learner<boosting, example>(
436  data, as_singleline(setup_base(options, all)), predict_or_learn<true>, predict_or_learn<false>, data->N);
437  else if (data->alg == "logistic")
438  {
439  l = &init_learner<boosting, example>(data, as_singleline(setup_base(options, all)), predict_or_learn_logistic<true>,
440  predict_or_learn_logistic<false>, data->N);
441  l->set_save_load(save_load);
442  }
443  else if (data->alg == "adaptive")
444  {
445  l = &init_learner<boosting, example>(data, as_singleline(setup_base(options, all)), predict_or_learn_adaptive<true>,
446  predict_or_learn_adaptive<false>, data->N);
447  l->set_save_load(save_load_sampling);
448  }
449  else
450  THROW("Unrecognized boosting algorithm: \'" << data->alg << "\' Bailing!");
451 
452  l->set_finish_example(return_example);
453 
454  return make_base(*l);
455 }
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
bool quiet
Definition: global_data.h:487
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
void save_load(boosting &o, io_buf &model_file, bool read, bool text)
Definition: boosting.cc:359
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
std::unique_ptr< T, free_fn > free_ptr
Definition: memory.h:34
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void return_example(vw &all, boosting &, example &ec)
Definition: boosting.cc:353
virtual bool was_supplied(const std::string &key)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
void save_load_sampling(boosting &o, io_buf &model_file, bool read, bool text)
Definition: boosting.cc:295
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define THROW(args)
Definition: vw_exception.h:181
float f
Definition: cache.cc:40