Vowpal Wabbit
Functions
svrg.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner* svrg_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ svrg_setup()

LEARNER::base_learner* svrg_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 168 of file svrg.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::init_learner(), SVRG::learn(), LEARNER::make_base(), VW::config::make_option(), SVRG::predict(), SVRG::save_load(), LEARNER::learner< T, E >::set_save_load(), parameters::stride_shift(), UINT64_ONE, and vw::weights.

Referenced by parse_reductions().

169 {
170  auto s = scoped_calloc_or_throw<svrg>();
171 
172  bool svrg_option = false;
173  option_group_definition new_options("Stochastic Variance Reduced Gradient");
174  new_options.add(make_option("svrg", svrg_option).keep().help("Streaming Stochastic Variance Reduced Gradient"))
175  .add(make_option("stage_size", s->stage_size).default_value(1).help("Number of passes per SVRG stage"));
176  options.add_and_parse(new_options);
177 
178  if (!svrg_option)
179  {
180  return nullptr;
181  }
182 
183  s->all = &all;
184  s->prev_pass = -1;
185  s->stable_grad_count = 0;
186 
187  // Request more parameter storage (4 floats per feature)
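188  all.weights.stride_shift(2);
189  learner<svrg, example>& l = init_learner(s, learn, predict, UINT64_ONE);
190  l.set_save_load(save_load);
191  return make_base(l);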
192 }
parameters weights
Definition: global_data.h:537
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void predict(svrg &s, single_learner &, example &ec)
Definition: svrg.cc:55
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void learn(svrg &s, single_learner &base, example &ec)
Definition: svrg.cc:105
constexpr uint64_t UINT64_ONE
uint32_t stride_shift()
void save_load(svrg &s, io_buf &model_file, bool read, bool text)
Definition: svrg.cc:142