Vowpal Wabbit
Functions
ftrl.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner * ftrl_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ ftrl_setup()

LEARNER::base_learner* ftrl_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 335 of file ftrl.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), vw::audit, LEARNER::end_pass(), VW::config::options_i::get_typed_option(), vw::hash_inv, shared_data::holdout_best_loss, vw::holdout_set_off, LEARNER::init_learner(), learn_cb(), learn_pistol(), LEARNER::make_base(), VW::config::make_option(), vw::normalized_sum_norm_x, vw::quiet, save_load(), vw::sd, sensitivity(), parameters::stride_shift(), UINT64_ONE, VW::config::options_i::was_supplied(), and vw::weights.

Referenced by parse_reductions().

336 {
337  auto b = scoped_calloc_or_throw<ftrl>();
338  bool ftrl_option = false;
339  bool pistol = false;
340  bool coin = false;
341 
342  option_group_definition new_options("Follow the Regularized Leader");
343  new_options.add(make_option("ftrl", ftrl_option).keep().help("FTRL: Follow the Proximal Regularized Leader"))
344  .add(make_option("coin", coin).keep().help("Coin betting optimizer"))
345  .add(make_option("pistol", pistol).keep().help("PiSTOL: Parameter-free STOchastic Learning"))
346  .add(make_option("ftrl_alpha", b->ftrl_alpha).help("Learning rate for FTRL optimization"))
347  .add(make_option("ftrl_beta", b->ftrl_beta).help("Learning rate for FTRL optimization"));
348  options.add_and_parse(new_options);
349 
350  if (!ftrl_option && !pistol && !coin)
351  {
352  return nullptr;
353  }
354 
355  // Defaults that are specific to the mode that was chosen.
356  if (ftrl_option)
357  {
358  b->ftrl_alpha = options.was_supplied("ftrl_alpha") ? b->ftrl_alpha : 0.005f;
359  b->ftrl_beta = options.was_supplied("ftrl_beta") ? b->ftrl_beta : 0.1f;
360  }
361  else if (pistol)
362  {
363  b->ftrl_alpha = options.was_supplied("ftrl_alpha") ? b->ftrl_alpha : 1.0f;
364  b->ftrl_beta = options.was_supplied("ftrl_beta") ? b->ftrl_beta : 0.5f;
365  }
366  else if (coin)
367  {
368  b->ftrl_alpha = options.was_supplied("ftrl_alpha") ? b->ftrl_alpha : 4.0f;
369  b->ftrl_beta = options.was_supplied("ftrl_beta") ? b->ftrl_beta : 1.0f;
370  }
371 
372  b->all = &all;
373  b->no_win_counter = 0;
374  b->all->normalized_sum_norm_x = 0;
375  b->total_weight = 0;
376 
377  void (*learn_ptr)(ftrl&, single_learner&, example&) = nullptr;
378 
379  std::string algorithm_name;
380  if (ftrl_option)
381  {
382  algorithm_name = "Proximal-FTRL";
383  if (all.audit)
384  learn_ptr = learn_proximal<true>;
385  else
386  learn_ptr = learn_proximal<false>;
387  all.weights.stride_shift(2); // NOTE: for more parameter storage
388  b->ftrl_size = 3;
389  }
390  else if (pistol)
391  {
392  algorithm_name = "PiSTOL";
393  learn_ptr = learn_pistol;
394  all.weights.stride_shift(2); // NOTE: for more parameter storage
395  b->ftrl_size = 4;
396  }
397  else if (coin)
398  {
399  algorithm_name = "Coin Betting";
400  learn_ptr = learn_cb;
401  all.weights.stride_shift(3); // NOTE: for more parameter storage
402  b->ftrl_size = 6;
403  }
404 
405  b->data.ftrl_alpha = b->ftrl_alpha;
406  b->data.ftrl_beta = b->ftrl_beta;
407  b->data.l1_lambda = b->all->l1_lambda;
408  b->data.l2_lambda = b->all->l2_lambda;
409 
410  if (!all.quiet)
411  {
412  std::cerr << "Enabling FTRL based optimization" << std::endl;
413  std::cerr << "Algorithm used: " << algorithm_name << std::endl;
414  std::cerr << "ftrl_alpha = " << b->ftrl_alpha << std::endl;
415  std::cerr << "ftrl_beta = " << b->ftrl_beta << std::endl;
416  }
417 
418  if (!all.holdout_set_off)
419  {
420  all.sd->holdout_best_loss = FLT_MAX;
421  b->early_stop_thres = options.get_typed_option<size_t>("early_terminate").value();
422  }
423 
424  learner<ftrl, example>* l;
425  if (all.audit || all.hash_inv)
426  l = &init_learner(b, learn_ptr, predict<true>, UINT64_ONE << all.weights.stride_shift());
427  else
428  l = &init_learner(b, learn_ptr, predict<false>, UINT64_ONE << all.weights.stride_shift());
429  l->set_sensitivity(sensitivity);
430  if (all.audit || all.hash_inv)
431  l->set_multipredict(multipredict<true>);
432  else
433  l->set_multipredict(multipredict<false>);
434  l->set_save_load(save_load);
435  l->set_end_pass(end_pass);
436  return make_base(*l);
437 }
parameters weights
Definition: global_data.h:537
Definition: ftrl.cc:31
bool hash_inv
Definition: global_data.h:541
double holdout_best_loss
Definition: global_data.h:161
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
float sensitivity(ftrl &b, base_learner &, example &ec)
Definition: ftrl.cc:71
bool quiet
Definition: global_data.h:487
virtual void add_and_parse(const option_group_definition &group)=0
bool holdout_set_off
Definition: global_data.h:499
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
shared_data * sd
Definition: global_data.h:375
typed_option< T > & get_typed_option(const std::string &key)
Definition: options.h:120
virtual bool was_supplied(const std::string &key)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
constexpr uint64_t UINT64_ONE
void end_pass(ftrl &g)
Definition: ftrl.cc:321
void learn_pistol(ftrl &a, single_learner &base, example &ec)
Definition: ftrl.cc:279
uint32_t stride_shift()
bool audit
Definition: global_data.h:486
void learn_cb(ftrl &a, single_learner &base, example &ec)
Definition: ftrl.cc:290
void save_load(ftrl &b, io_buf &model_file, bool read, bool text)
Definition: ftrl.cc:301
double normalized_sum_norm_x
Definition: global_data.h:420