Vowpal Wabbit
Classes | Namespaces | Macros | Functions
svrg.cc File Reference
#include <cassert>
#include <iostream>
#include "gd.h"
#include "vw.h"
#include "reductions.h"

Go to the source code of this file.

Classes

struct  SVRG::svrg
 
struct  SVRG::update
 

Namespaces

 SVRG
 

Macros

#define W_INNER   0
 
#define W_STABLE   1
 
#define W_STABLEGRAD   2
 

Functions

template<int offset>
void SVRG::vec_add (float &p, const float x, float &w)
 
template<int offset>
float SVRG::inline_predict (vw &all, example &ec)
 
float SVRG::predict_stable (const svrg &s, example &ec)
 
void SVRG::predict (svrg &s, single_learner &, example &ec)
 
float SVRG::gradient_scalar (const svrg &s, const example &ec, float pred)
 
void SVRG::update_inner_feature (update &u, float x, float &w)
 
void SVRG::update_stable_feature (float &g_scalar, float x, float &w)
 
void SVRG::update_inner (const svrg &s, example &ec)
 
void SVRG::update_stable (const svrg &s, example &ec)
 
void SVRG::learn (svrg &s, single_learner &base, example &ec)
 
void SVRG::save_load (svrg &s, io_buf &model_file, bool read, bool text)
 
base_learner * svrg_setup (options_i &options, vw &all)
 

Macro Definition Documentation

◆ W_INNER

#define W_INNER   0

Definition at line 14 of file svrg.cc.

Referenced by SVRG::learn().

◆ W_STABLE

#define W_STABLE   1

Definition at line 15 of file svrg.cc.

Referenced by SVRG::learn().

◆ W_STABLEGRAD

#define W_STABLEGRAD   2

Definition at line 16 of file svrg.cc.

Referenced by SVRG::learn(), SVRG::update_inner_feature(), and SVRG::update_stable_feature().

Function Documentation

◆ svrg_setup()

base_learner* svrg_setup ( options_i & options,
vw & all 
)

Definition at line 168 of file svrg.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::init_learner(), SVRG::learn(), LEARNER::make_base(), VW::config::make_option(), SVRG::predict(), SVRG::save_load(), LEARNER::learner< T, E >::set_save_load(), parameters::stride_shift(), UINT64_ONE, and vw::weights.

Referenced by parse_reductions().

169 {
170  auto s = scoped_calloc_or_throw<svrg>();
171 
172  bool svrg_option = false;
173  option_group_definition new_options("Stochastic Variance Reduced Gradient");
174  new_options.add(make_option("svrg", svrg_option).keep().help("Streaming Stochastic Variance Reduced Gradient"))
175  .add(make_option("stage_size", s->stage_size).default_value(1).help("Number of passes per SVRG stage"));
176  options.add_and_parse(new_options);
177 
178  if (!svrg_option)
179  {
180  return nullptr;
181  }
182 
183  s->all = &all;
184  s->prev_pass = -1;
185  s->stable_grad_count = 0;
186 
187  // Request more parameter storage (4 floats per feature)
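188  all.weights.stride_shift(2);
189  learner<svrg, example>& l = init_learner(s, learn, predict, UINT64_ONE);
190  l.set_save_load(save_load);
191  return make_base(l);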
192 }
parameters weights
Definition: global_data.h:537
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void predict(svrg &s, single_learner &, example &ec)
Definition: svrg.cc:55
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void learn(svrg &s, single_learner &base, example &ec)
Definition: svrg.cc:105
constexpr uint64_t UINT64_ONE
uint32_t stride_shift()
void save_load(svrg &s, io_buf &model_file, bool read, bool text)
Definition: svrg.cc:142