Vowpal Wabbit
Functions
gd_mf.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner* gd_mf_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ gd_mf_setup()

LEARNER::base_learner* gd_mf_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 327 of file gd_mf.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::end_pass(), vw::eta, f, VW::config::options_i::get_typed_option(), shared_data::holdout_best_loss, vw::holdout_set_off, LEARNER::init_learner(), vw::initial_t, learn(), LEARNER::make_base(), VW::config::make_option(), vw::power_t, ldamath::powf(), predict(), vw::random_weights, save_load(), vw::sd, LEARNER::learner< T, E >::set_end_pass(), LEARNER::learner< T, E >::set_save_load(), parameters::stride_shift(), shared_data::t, THROW, UINT64_ONE, VW::config::options_i::was_supplied(), and vw::weights.

Referenced by parse_reductions().

328 {
329  auto data = scoped_calloc_or_throw<gdmf>();
330 
331  bool bfgs = false;
332  bool conjugate_gradient = false;
333  option_group_definition gf_md_options("Gradient Descent Matrix Factorization");
334  gf_md_options.add(make_option("rank", data->rank).keep().help("rank for matrix factorization."));
335 
336  // Not supported, need to be checked to be false.
337  gf_md_options.add(make_option("bfgs", bfgs).help("Option not supported by this reduction"));
338  gf_md_options.add(
339  make_option("conjugate_gradient", conjugate_gradient).help("Option not supported by this reduction"));
340  options.add_and_parse(gf_md_options);
341 
342  if (!options.was_supplied("rank"))
343  return nullptr;
344 
345  if (options.was_supplied("adaptive"))
346  THROW("adaptive is not implemented for matrix factorization");
347  if (options.was_supplied("normalized"))
348  THROW("normalized is not implemented for matrix factorization");
349  if (options.was_supplied("exact_adaptive_norm"))
350  THROW("normalized adaptive updates is not implemented for matrix factorization");
351 
352  if (bfgs || conjugate_gradient)
353  THROW("bfgs is not implemented for matrix factorization");
354 
355  data->all = &all;
356  data->no_win_counter = 0;
357 
358  // store linear + 2*rank weights per index, round up to power of two
359  float temp = ceilf(logf((float)(data->rank * 2 + 1)) / logf(2.f));
360  all.weights.stride_shift((size_t)temp);
361  all.random_weights = true;
362 
363  if (!all.holdout_set_off)
364  {
365  all.sd->holdout_best_loss = FLT_MAX;
366  data->early_stop_thres = options.get_typed_option<size_t>("early_terminate").value();
367  }
368 
369  if (!options.was_supplied("learning_rate") && !options.was_supplied("l"))
370  all.eta = 10; // default learning rate to 10 for non default update rule
371 
372  // default initial_t to 1 instead of 0
373  if (!options.was_supplied("initial_t"))
374  {
375  all.sd->t = 1.f;
376  all.initial_t = 1.f;
377  }
378  all.eta *= powf((float)(all.sd->t), all.power_t);
379 
383 
384  return make_base(l);
385 }
parameters weights
Definition: global_data.h:537
float initial_t
Definition: global_data.h:530
float power_t
Definition: global_data.h:447
double holdout_best_loss
Definition: global_data.h:161
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
Definition: bfgs.cc:62
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
bool holdout_set_off
Definition: global_data.h:499
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
shared_data * sd
Definition: global_data.h:375
typed_option< T > & get_typed_option(const std::string &key)
Definition: options.h:120
T powf(T, T)
Definition: lda_core.cc:428
virtual bool was_supplied(const std::string &key)=0
bool random_weights
Definition: global_data.h:492
void save_load(gdmf &d, io_buf &model_file, bool read, bool text)
Definition: gd_mf.cc:248
void predict(gdmf &d, single_learner &, example &ec)
Definition: gd_mf.cc:316
float eta
Definition: global_data.h:531
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
constexpr uint64_t UINT64_ONE
void set_end_pass(void(*f)(T &))
Definition: learner.h:286
uint32_t stride_shift()
void learn(gdmf &d, single_learner &, example &ec)
Definition: gd_mf.cc:318
#define THROW(args)
Definition: vw_exception.h:181
void end_pass(gdmf &d)
Definition: gd_mf.cc:298
float f
Definition: cache.cc:40