Vowpal Wabbit
svrg.cc
Go to the documentation of this file.
1 
2 #include <cassert>
3 #include <iostream>
4 
5 #include "gd.h"
6 #include "vw.h"
7 #include "reductions.h"
8 
9 using namespace LEARNER;
10 using namespace VW::config;
11 
12 namespace SVRG
13 {
14 #define W_INNER 0 // working "inner-loop" weights, updated per example
15 #define W_STABLE 1 // stable weights, updated per stage
16 #define W_STABLEGRAD 2 // gradient corresponding to stable weights
17 
18 struct svrg
19 {
20  int stage_size; // Number of data passes per stage.
21  int prev_pass; // To detect that we're in a new pass.
22  int stable_grad_count; // Number of data points that
23  // contributed to the stable gradient
24  // calculation.
25 
26  // The VW process' global state.
27  vw* all;
28 };
29 
30 // Mimic GD::inline_predict but with offset for predicting with either
31 // stable versus inner weights.
32 
// Adds to |p| the product of |x| and the float stored |offset| slots after
// |w|. |w| is the first float of a feature's weight group, so |offset|
// selects which of the group's values takes part in the dot product.
template <int offset>
inline void vec_add(float& p, const float x, float& w)
{
  const float* slot = &w + offset;
  p += x * (*slot);
}
39 
40 template <int offset>
41 inline float inline_predict(vw& all, example& ec)
42 {
43  float acc = ec.l.simple.initial;
44  GD::foreach_feature<float, vec_add<offset> >(all, ec, acc);
45  return acc;
46 }
47 
48 // -- Prediction, using inner vs. stable weights --
49 
50 float predict_stable(const svrg& s, example& ec)
51 {
52  return GD::finalize_prediction(s.all->sd, inline_predict<W_STABLE>(*s.all, ec));
53 }
54 
56 {
57  ec.partial_prediction = inline_predict<W_INNER>(*s.all, ec);
59 }
60 
61 float gradient_scalar(const svrg& s, const example& ec, float pred)
62 {
63  return s.all->loss->first_derivative(s.all->sd, pred, ec.l.simple.label) * ec.weight;
64 }
65 
66 // -- Updates, taking inner steps vs. accumulating a full gradient --
67 
// Per-example values shared by every update_inner_feature call made for one
// example.
struct update
{
  float g_scalar_stable;  // gradient scalar of the example at the stable weights
  float g_scalar_inner;   // gradient scalar of the example at the inner weights
  float eta;              // learning rate
  float norm;             // divisor applied to the accumulated stable gradient
};
75 
76 inline void update_inner_feature(update& u, float x, float& w)
77 {
78  float* ws = &w;
79  w -= u.eta * ((u.g_scalar_inner - u.g_scalar_stable) * x + ws[W_STABLEGRAD] / u.norm);
80 }
81 
82 inline void update_stable_feature(float& g_scalar, float x, float& w)
83 {
84  float* ws = &w;
85  ws[W_STABLEGRAD] += g_scalar * x;
86 }
87 
88 void update_inner(const svrg& s, example& ec)
89 {
90  update u;
91  // |ec| already has prediction according to inner weights.
94  u.eta = s.all->eta;
95  u.norm = (float)s.stable_grad_count;
96  GD::foreach_feature<update, update_inner_feature>(*s.all, ec, u);
97 }
98 
99 void update_stable(const svrg& s, example& ec)
100 {
101  float g = gradient_scalar(s, ec, predict_stable(s, ec));
102  GD::foreach_feature<float, update_stable_feature>(*s.all, ec, g);
103 }
104 
105 void learn(svrg& s, single_learner& base, example& ec)
106 {
107  assert(ec.in_use);
108 
109  predict(s, base, ec);
110 
111  const int pass = (int)s.all->passes_complete;
112 
113  if (pass % (s.stage_size + 1) == 0) // Compute exact gradient
114  {
115  if (s.prev_pass != pass && !s.all->quiet)
116  {
117  std::cout << "svrg pass " << pass << ": committing stable point" << std::endl;
118  for (uint32_t j = 0; j < VW::num_weights(*s.all); j++)
119  {
120  float w = VW::get_weight(*s.all, j, W_INNER);
121  VW::set_weight(*s.all, j, W_STABLE, w);
122  VW::set_weight(*s.all, j, W_STABLEGRAD, 0.f);
123  }
124  s.stable_grad_count = 0;
125  std::cout << "svrg pass " << pass << ": computing exact gradient" << std::endl;
126  }
127  update_stable(s, ec);
128  s.stable_grad_count++;
129  }
130  else // Perform updates
131  {
132  if (s.prev_pass != pass && !s.all->quiet)
133  {
134  std::cout << "svrg pass " << pass << ": taking steps" << std::endl;
135  }
136  update_inner(s, ec);
137  }
138 
139  s.prev_pass = pass;
140 }
141 
142 void save_load(svrg& s, io_buf& model_file, bool read, bool text)
143 {
144  if (read)
145  {
147  }
148 
149  if (!model_file.files.empty())
150  {
151  bool resume = s.all->save_resume;
152  std::stringstream msg;
153  msg << ":" << resume << "\n";
154  bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, msg, text);
155 
156  double temp = 0.;
157  if (resume)
158  GD::save_load_online_state(*s.all, model_file, read, text, temp);
159  else
160  GD::save_load_regressor(*s.all, model_file, read, text);
161  }
162 }
163 
164 } // namespace SVRG
165 
166 using namespace SVRG;
167 
169 {
170  auto s = scoped_calloc_or_throw<svrg>();
171 
172  bool svrg_option = false;
173  option_group_definition new_options("Stochastic Variance Reduced Gradient");
174  new_options.add(make_option("svrg", svrg_option).keep().help("Streaming Stochastic Variance Reduced Gradient"))
175  .add(make_option("stage_size", s->stage_size).default_value(1).help("Number of passes per SVRG stage"));
176  options.add_and_parse(new_options);
177 
178  if (!svrg_option)
179  {
180  return nullptr;
181  }
182 
183  s->all = &all;
184  s->prev_pass = -1;
185  s->stable_grad_count = 0;
186 
187  // Request more parameter storage (4 floats per feature)
188  all.weights.stride_shift(2);
191  return make_base(l);
192 }
float finalize_prediction(shared_data *sd, float ret)
Definition: gd.cc:339
parameters weights
Definition: global_data.h:537
loss_function * loss
Definition: global_data.h:523
void initialize_regressor(vw &all, T &weights)
base_learner * svrg_setup(options_i &options, vw &all)
Definition: svrg.cc:168
float norm
Definition: svrg.cc:73
float scalar
Definition: example.h:45
float predict_stable(const svrg &s, example &ec)
Definition: svrg.cc:50
int prev_pass
Definition: svrg.cc:21
void update_stable(const svrg &s, example &ec)
Definition: svrg.cc:99
float eta
Definition: svrg.cc:72
#define W_STABLE
Definition: svrg.cc:15
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
float partial_prediction
Definition: example.h:68
bool quiet
Definition: global_data.h:487
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
void update_stable_feature(float &g_scalar, float x, float &w)
Definition: svrg.cc:82
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
virtual float first_derivative(shared_data *, float prediction, float label)=0
void save_load_online_state(vw &all, io_buf &model_file, bool read, bool text, gd *g, std::stringstream &msg, uint32_t ftrl_size, T &weights)
Definition: gd.cc:776
float g_scalar_inner
Definition: svrg.cc:71
#define W_INNER
Definition: svrg.cc:14
Definition: svrg.cc:12
vw * all
Definition: svrg.cc:27
float inline_predict(vw &all, example &ec)
Definition: svrg.cc:41
int stable_grad_count
Definition: svrg.cc:22
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
shared_data * sd
Definition: global_data.h:375
v_array< int > files
Definition: io_buf.h:64
float get_weight(vw &all, uint32_t index, uint32_t offset)
Definition: vw.h:177
void update_inner_feature(update &u, float x, float &w)
Definition: svrg.cc:76
void set_weight(vw &all, uint32_t index, uint32_t offset, float value)
Definition: vw.h:182
float g_scalar_stable
Definition: svrg.cc:70
float initial
Definition: simple_label.h:16
Definition: io_buf.h:54
float gradient_scalar(const svrg &s, const example &ec, float pred)
Definition: svrg.cc:61
float eta
Definition: global_data.h:531
option_group_definition & add(T &&op)
Definition: options.h:90
void predict(svrg &s, single_learner &, example &ec)
Definition: svrg.cc:55
polylabel l
Definition: example.h:57
size_t passes_complete
Definition: global_data.h:452
bool in_use
Definition: example.h:79
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void learn(svrg &s, single_learner &base, example &ec)
Definition: svrg.cc:105
constexpr uint64_t UINT64_ONE
bool empty() const
Definition: v_array.h:59
void update_inner(const svrg &s, example &ec)
Definition: svrg.cc:88
bool save_resume
Definition: global_data.h:415
uint32_t stride_shift()
void save_load(svrg &s, io_buf &model_file, bool read, bool text)
Definition: svrg.cc:142
void vec_add(float &p, const float x, float &w)
Definition: svrg.cc:34
polyprediction pred
Definition: example.h:60
uint32_t num_weights(vw &all)
Definition: vw.h:187
#define W_STABLEGRAD
Definition: svrg.cc:16
int stage_size
Definition: svrg.cc:20
void save_load_regressor(vw &all, io_buf &model_file, bool read, bool text, T &weights)
Definition: gd.cc:707
float weight
Definition: example.h:62
size_t bin_text_read_write_fixed(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
Definition: io_buf.h:326