Vowpal Wabbit
Classes | Functions
SVRG Namespace Reference

Classes

struct  svrg
 
struct  update
 

Functions

template<int offset>
void vec_add (float &p, const float x, float &w)
 
template<int offset>
float inline_predict (vw &all, example &ec)
 
float predict_stable (const svrg &s, example &ec)
 
void predict (svrg &s, single_learner &, example &ec)
 
float gradient_scalar (const svrg &s, const example &ec, float pred)
 
void update_inner_feature (update &u, float x, float &w)
 
void update_stable_feature (float &g_scalar, float x, float &w)
 
void update_inner (const svrg &s, example &ec)
 
void update_stable (const svrg &s, example &ec)
 
void learn (svrg &s, single_learner &base, example &ec)
 
void save_load (svrg &s, io_buf &model_file, bool read, bool text)
 

Function Documentation

◆ gradient_scalar()

float SVRG::gradient_scalar ( const svrg & s,
const example & ec,
float  pred 
)

Definition at line 61 of file svrg.cc.

References SVRG::svrg::all, loss_function::first_derivative(), example::l, label_data::label, vw::loss, vw::sd, polylabel::simple, and example::weight.

Referenced by update_inner(), and update_stable().

62 {
63  return s.all->loss->first_derivative(s.all->sd, pred, ec.l.simple.label) * ec.weight;
64 }
loss_function * loss
Definition: global_data.h:523
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
virtual float first_derivative(shared_data *, float prediction, float label)=0
vw * all
Definition: svrg.cc:27
shared_data * sd
Definition: global_data.h:375
polylabel l
Definition: example.h:57
float weight
Definition: example.h:62

◆ inline_predict()

template<int offset>
float SVRG::inline_predict ( vw & all,
example & ec 
)
inline

Definition at line 41 of file svrg.cc.

References label_data::initial, example::l, and polylabel::simple.

42 {
43  float acc = ec.l.simple.initial;
44  GD::foreach_feature<float, vec_add<offset> >(all, ec, acc);
45  return acc;
46 }
label_data simple
Definition: example.h:28
float initial
Definition: simple_label.h:16
polylabel l
Definition: example.h:57

◆ learn()

void SVRG::learn ( svrg & s,
single_learner & base,
example & ec 
)

Definition at line 105 of file svrg.cc.

References SVRG::svrg::all, VW::get_weight(), example::in_use, VW::num_weights(), vw::passes_complete, predict(), SVRG::svrg::prev_pass, vw::quiet, VW::set_weight(), SVRG::svrg::stable_grad_count, SVRG::svrg::stage_size, update_inner(), update_stable(), W_INNER, W_STABLE, and W_STABLEGRAD.

Referenced by svrg_setup().

106 {
107  assert(ec.in_use);
108 
109  predict(s, base, ec);
110 
111  const int pass = (int)s.all->passes_complete;
112 
113  if (pass % (s.stage_size + 1) == 0) // Compute exact gradient
114  {
115  if (s.prev_pass != pass && !s.all->quiet)
116  {
117  std::cout << "svrg pass " << pass << ": committing stable point" << std::endl;
118  for (uint32_t j = 0; j < VW::num_weights(*s.all); j++)
119  {
120  float w = VW::get_weight(*s.all, j, W_INNER);
121  VW::set_weight(*s.all, j, W_STABLE, w);
122  VW::set_weight(*s.all, j, W_STABLEGRAD, 0.f);
123  }
124  s.stable_grad_count = 0;
125  std::cout << "svrg pass " << pass << ": computing exact gradient" << std::endl;
126  }
127  update_stable(s, ec);
128  s.stable_grad_count++;
129  }
130  else // Perform updates
131  {
132  if (s.prev_pass != pass && !s.all->quiet)
133  {
134  std::cout << "svrg pass " << pass << ": taking steps" << std::endl;
135  }
136  update_inner(s, ec);
137  }
138 
139  s.prev_pass = pass;
140 }
int prev_pass
Definition: svrg.cc:21
void update_stable(const svrg &s, example &ec)
Definition: svrg.cc:99
#define W_STABLE
Definition: svrg.cc:15
bool quiet
Definition: global_data.h:487
#define W_INNER
Definition: svrg.cc:14
vw * all
Definition: svrg.cc:27
int stable_grad_count
Definition: svrg.cc:22
float get_weight(vw &all, uint32_t index, uint32_t offset)
Definition: vw.h:177
void set_weight(vw &all, uint32_t index, uint32_t offset, float value)
Definition: vw.h:182
void predict(svrg &s, single_learner &, example &ec)
Definition: svrg.cc:55
size_t passes_complete
Definition: global_data.h:452
bool in_use
Definition: example.h:79
void update_inner(const svrg &s, example &ec)
Definition: svrg.cc:88
uint32_t num_weights(vw &all)
Definition: vw.h:187
#define W_STABLEGRAD
Definition: svrg.cc:16
int stage_size
Definition: svrg.cc:20

◆ predict()

void SVRG::predict ( svrg & s,
single_learner & ,
example & ec 
)

Definition at line 55 of file svrg.cc.

References SVRG::svrg::all, GD::finalize_prediction(), example::partial_prediction, example::pred, polyprediction::scalar, and vw::sd.

Referenced by learn(), and svrg_setup().

56 {
57  ec.partial_prediction = inline_predict<W_INNER>(*s.all, ec);
58  ec.pred.scalar = GD::finalize_prediction(s.all->sd, ec.partial_prediction);
59 }
float finalize_prediction(shared_data *sd, float ret)
Definition: gd.cc:339
float scalar
Definition: example.h:45
float partial_prediction
Definition: example.h:68
vw * all
Definition: svrg.cc:27
shared_data * sd
Definition: global_data.h:375
polyprediction pred
Definition: example.h:60

◆ predict_stable()

float SVRG::predict_stable ( const svrg & s,
example & ec 
)

Definition at line 50 of file svrg.cc.

References SVRG::svrg::all, GD::finalize_prediction(), and vw::sd.

Referenced by update_inner(), and update_stable().

51 {
52  return GD::finalize_prediction(s.all->sd, inline_predict<W_STABLE>(*s.all, ec));
53 }
float finalize_prediction(shared_data *sd, float ret)
Definition: gd.cc:339
vw * all
Definition: svrg.cc:27
shared_data * sd
Definition: global_data.h:375

◆ save_load()

void SVRG::save_load ( svrg & s,
io_buf & model_file,
bool  read,
bool  text 
)

Definition at line 142 of file svrg.cc.

References SVRG::svrg::all, bin_text_read_write_fixed(), v_array< T >::empty(), io_buf::files, initialize_regressor(), GD::save_load_online_state(), GD::save_load_regressor(), and vw::save_resume.

Referenced by svrg_setup().

143 {
144  if (read)
145  {
146  initialize_regressor(*s.all);
147  }
148 
149  if (!model_file.files.empty())
150  {
151  bool resume = s.all->save_resume;
152  std::stringstream msg;
153  msg << ":" << resume << "\n";
154  bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, msg, text);
155 
156  double temp = 0.;
157  if (resume)
158  GD::save_load_online_state(*s.all, model_file, read, text, temp);
159  else
160  GD::save_load_regressor(*s.all, model_file, read, text);
161  }
162 }
void initialize_regressor(vw &all, T &weights)
void save_load_online_state(vw &all, io_buf &model_file, bool read, bool text, gd *g, std::stringstream &msg, uint32_t ftrl_size, T &weights)
Definition: gd.cc:776
vw * all
Definition: svrg.cc:27
v_array< int > files
Definition: io_buf.h:64
bool empty() const
Definition: v_array.h:59
bool save_resume
Definition: global_data.h:415
void save_load_regressor(vw &all, io_buf &model_file, bool read, bool text, T &weights)
Definition: gd.cc:707
size_t bin_text_read_write_fixed(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
Definition: io_buf.h:326

◆ update_inner()

void SVRG::update_inner ( const svrg & s,
example & ec 
)

Definition at line 88 of file svrg.cc.

References SVRG::svrg::all, SVRG::update::eta, vw::eta, SVRG::update::g_scalar_inner, SVRG::update::g_scalar_stable, gradient_scalar(), SVRG::update::norm, example::pred, predict_stable(), polyprediction::scalar, and SVRG::svrg::stable_grad_count.

Referenced by learn().

89 {
90  update u;
91  // |ec| already has prediction according to inner weights.
94  u.eta = s.all->eta;
95  u.norm = (float)s.stable_grad_count;
96  GD::foreach_feature<update, update_inner_feature>(*s.all, ec, u);
97 }
float norm
Definition: svrg.cc:73
float scalar
Definition: example.h:45
float predict_stable(const svrg &s, example &ec)
Definition: svrg.cc:50
float eta
Definition: svrg.cc:72
float g_scalar_inner
Definition: svrg.cc:71
vw * all
Definition: svrg.cc:27
int stable_grad_count
Definition: svrg.cc:22
float g_scalar_stable
Definition: svrg.cc:70
float gradient_scalar(const svrg &s, const example &ec, float pred)
Definition: svrg.cc:61
float eta
Definition: global_data.h:531
polyprediction pred
Definition: example.h:60

◆ update_inner_feature()

void SVRG::update_inner_feature ( update & u,
float  x,
float &  w 
)
inline

Definition at line 76 of file svrg.cc.

References SVRG::update::eta, SVRG::update::g_scalar_inner, SVRG::update::g_scalar_stable, SVRG::update::norm, and W_STABLEGRAD.

77 {
78  float* ws = &w;
79  w -= u.eta * ((u.g_scalar_inner - u.g_scalar_stable) * x + ws[W_STABLEGRAD] / u.norm);
80 }
float norm
Definition: svrg.cc:73
float eta
Definition: svrg.cc:72
float g_scalar_inner
Definition: svrg.cc:71
float g_scalar_stable
Definition: svrg.cc:70
#define W_STABLEGRAD
Definition: svrg.cc:16

◆ update_stable()

void SVRG::update_stable ( const svrg & s,
example & ec 
)

Definition at line 99 of file svrg.cc.

References SVRG::svrg::all, gradient_scalar(), and predict_stable().

Referenced by learn().

100 {
101  float g = gradient_scalar(s, ec, predict_stable(s, ec));
102  GD::foreach_feature<float, update_stable_feature>(*s.all, ec, g);
103 }
float predict_stable(const svrg &s, example &ec)
Definition: svrg.cc:50
vw * all
Definition: svrg.cc:27
float gradient_scalar(const svrg &s, const example &ec, float pred)
Definition: svrg.cc:61

◆ update_stable_feature()

void SVRG::update_stable_feature ( float &  g_scalar,
float  x,
float &  w 
)
inline

Definition at line 82 of file svrg.cc.

References W_STABLEGRAD.

83 {
84  float* ws = &w;
85  ws[W_STABLEGRAD] += g_scalar * x;
86 }
#define W_STABLEGRAD
Definition: svrg.cc:16

◆ vec_add()

template<int offset>
void SVRG::vec_add ( float &  p,
const float  x,
float &  w 
)
inline

Definition at line 34 of file svrg.cc.

35 {
36  float* ws = &w;
37  p += x * ws[offset];
38 }