Vowpal Wabbit
ftrl.cc
Go to the documentation of this file.
1 /*
2  Copyright (c) by respective owners including Yahoo!, Microsoft, and
3  individual contributors. All rights reserved. Released under a BSD (revised)
4  license as described in the file LICENSE.
5  */
6 #include <string>
7 #include "correctedMath.h"
8 #include "gd.h"
9 
10 using namespace LEARNER;
11 using namespace VW::config;
12 
// Per-weight parameter slots: each feature weight is a small array of
// `ftrl_size` floats, indexed by the offsets below.
#define W_XT 0  // current parameter
#define W_ZT 1  // in proximal is "accumulated z(t) = z(t-1) + g(t) + sigma*w(t)", in general is the dual weight vector
#define W_G2 2  // accumulated gradient information
#define W_MX 3  // maximum absolute value
#define W_WE 4  // Wealth
#define W_MG 5  // maximum gradient
19 
// Scratch state threaded through GD::foreach_feature during the per-feature
// predict/update passes below.
struct update_data
{
  float update;      // loss derivative scaled by example weight
  float ftrl_alpha;  // copied from ftrl at setup time
  float ftrl_beta;   // copied from ftrl at setup time
  float l1_lambda;   // copied from vw at setup time
  float l2_lambda;   // copied from vw at setup time
  float predict;     // dot-product accumulated during a prediction pass
  // Restored from the cross-reference index (original line 28); accumulated
  // in inner_update_cb_state_and_predict as sum of x^2 / max|x|^2.
  float normalized_squared_norm_x;
};
30 
31 struct ftrl
32 {
33  vw* all; // features, finalize, l1, l2,
34  float ftrl_alpha;
35  float ftrl_beta;
36  struct update_data data;
39  uint32_t ftrl_size;
40  double total_weight;
41 };
42 
44 {
45  float pred;
46  float score;
47  ftrl& b;
48  uncertainty(ftrl& ftrlb) : b(ftrlb)
49  {
50  pred = 0;
51  score = 0;
52  }
53 };
54 
// Branchless-intent sign: -1 for strictly negative inputs, +1 otherwise.
// Note sign(0) == +1, which callers below rely on when w[W_ZT] is zero.
inline float sign(float w) { return (w < 0.f) ? -1.f : 1.f; }
62 
63 inline void predict_with_confidence(uncertainty& d, const float fx, float& fw)
64 {
65  float* w = &fw;
66  d.pred += w[W_XT] * fx;
67  float sqrtf_ng2 = sqrtf(w[W_G2]);
68  float uncertain = ((d.b.data.ftrl_beta + sqrtf_ng2) / d.b.data.ftrl_alpha + d.b.data.l2_lambda);
69  d.score += (1 / uncertain) * sign(fx);
70 }
71 float sensitivity(ftrl& b, base_learner& /* base */, example& ec)
72 {
73  uncertainty uncetain(b);
74  GD::foreach_feature<uncertainty, predict_with_confidence>(*(b.all), ec, uncetain);
75  return uncetain.score;
76 }
77 
78 template <bool audit>
80 {
83  if (audit)
84  GD::print_audit_features(*(b.all), ec);
85 }
86 
87 template <bool audit>
89  ftrl& b, base_learner&, example& ec, size_t count, size_t step, polyprediction* pred, bool finalize_predictions)
90 {
91  vw& all = *b.all;
92  for (size_t c = 0; c < count; c++) pred[c].scalar = ec.l.simple.initial;
93  if (b.all->weights.sparse)
94  {
96  count, step, pred, all.weights.sparse_weights, (float)all.sd->gravity};
97  GD::foreach_feature<GD::multipredict_info<sparse_parameters>, uint64_t, GD::vec_add_multipredict>(all, ec, mp);
98  }
99  else
100  {
101  GD::multipredict_info<dense_parameters> mp = {count, step, pred, all.weights.dense_weights, (float)all.sd->gravity};
102  GD::foreach_feature<GD::multipredict_info<dense_parameters>, uint64_t, GD::vec_add_multipredict>(all, ec, mp);
103  }
104  if (all.sd->contraction != 1.)
105  for (size_t c = 0; c < count; c++) pred[c].scalar *= (float)all.sd->contraction;
106  if (finalize_predictions)
107  for (size_t c = 0; c < count; c++) pred[c].scalar = GD::finalize_prediction(all.sd, pred[c].scalar);
108  if (audit)
109  {
110  for (size_t c = 0; c < count; c++)
111  {
112  ec.pred.scalar = pred[c].scalar;
113  GD::print_audit_features(all, ec);
114  ec.ft_offset += (uint64_t)step;
115  }
116  ec.ft_offset -= (uint64_t)(step * count);
117  }
118 }
119 
120 void inner_update_proximal(update_data& d, float x, float& wref)
121 {
122  float* w = &wref;
123  float gradient = d.update * x;
124  float ng2 = w[W_G2] + gradient * gradient;
125  float sqrt_ng2 = sqrtf(ng2);
126  float sqrt_wW_G2 = sqrtf(w[W_G2]);
127  float sigma = (sqrt_ng2 - sqrt_wW_G2) / d.ftrl_alpha;
128  w[W_ZT] += gradient - sigma * w[W_XT];
129  w[W_G2] = ng2;
130  sqrt_wW_G2 = sqrt_ng2;
131  float flag = sign(w[W_ZT]);
132  float fabs_zt = w[W_ZT] * flag;
133  if (fabs_zt <= d.l1_lambda)
134  w[W_XT] = 0.;
135  else
136  {
137  float step = 1 / (d.l2_lambda + (d.ftrl_beta + sqrt_wW_G2) / d.ftrl_alpha);
138  w[W_XT] = step * flag * (d.l1_lambda - fabs_zt);
139  }
140 }
141 
143 {
144  float* w = &wref;
145 
146  float fabs_x = fabs(x);
147  if (fabs_x > w[W_MX])
148  w[W_MX] = fabs_x;
149 
150  float squared_theta = w[W_ZT] * w[W_ZT];
151  float tmp = 1.f / (d.ftrl_alpha * w[W_MX] * (w[W_G2] + w[W_MX]));
152  w[W_XT] = std::sqrt(w[W_G2]) * d.ftrl_beta * w[W_ZT] * correctedExp(squared_theta / 2.f * tmp) * tmp;
153 
154  d.predict += w[W_XT] * x;
155 }
156 
157 void inner_update_pistol_post(update_data& d, float x, float& wref)
158 {
159  float* w = &wref;
160  float gradient = d.update * x;
161 
162  w[W_ZT] += -gradient;
163  w[W_G2] += fabs(gradient);
164 }
165 
166 // Coin betting vectors
167 // W_XT 0 current parameter
168 // W_ZT 1 sum negative gradients
169 // W_G2 2 sum of absolute value of gradients
170 // W_MX 3 maximum absolute value
171 // W_WE 4 Wealth
172 // W_MG 5 Maximum Lipschitz constant
173 void inner_update_cb_state_and_predict(update_data& d, float x, float& wref)
174 {
175  float* w = &wref;
176  float w_mx = w[W_MX];
177  float w_xt = 0.0;
178 
179  float fabs_x = fabs(x);
180  if (fabs_x > w_mx)
181  {
182  w_mx = fabs_x;
183  }
184 
185  // COCOB update without sigmoid
186  if (w[W_MG] * w_mx > 0)
187  w_xt = (d.ftrl_alpha + w[W_WE]) * w[W_ZT] / (w[W_MG] * w_mx * (w[W_MG] * w_mx + w[W_G2]));
188 
189  d.predict += w_xt * x;
190  if (w_mx > 0)
191  d.normalized_squared_norm_x += x * x / (w_mx * w_mx);
192 }
193 
194 void inner_update_cb_post(update_data& d, float x, float& wref)
195 {
196  float* w = &wref;
197  float fabs_x = fabs(x);
198  float gradient = d.update * x;
199 
200  if (fabs_x > w[W_MX])
201  {
202  w[W_MX] = fabs_x;
203  }
204 
205  float fabs_gradient = fabs(d.update);
206  if (fabs_gradient > w[W_MG])
207  w[W_MG] = fabs_gradient > d.ftrl_beta ? fabs_gradient : d.ftrl_beta;
208 
209  // COCOB update without sigmoid.
210  // If a new Lipschitz constant and/or magnitude of x is found, the w is
211  // recalculated and used in the update of the wealth below.
212  if (w[W_MG] * w[W_MX] > 0)
213  w[W_XT] = (d.ftrl_alpha + w[W_WE]) * w[W_ZT] / (w[W_MG] * w[W_MX] * (w[W_MG] * w[W_MX] + w[W_G2]));
214  else
215  w[W_XT] = 0;
216 
217  w[W_ZT] += -gradient;
218  w[W_G2] += fabs(gradient);
219  w[W_WE] += (-gradient * w[W_XT]);
220 }
221 
223 {
224  b.data.predict = 0;
226 
227  GD::foreach_feature<update_data, inner_update_cb_state_and_predict>(*b.all, ec, b.data);
228 
230  b.total_weight += ec.weight;
231 
232  ec.partial_prediction = b.data.predict / ((float)((b.all->normalized_sum_norm_x + 1e-6) / b.total_weight));
233 
235 }
236 
238 {
239  b.data.predict = 0;
240 
241  GD::foreach_feature<update_data, inner_update_pistol_state_and_predict>(*b.all, ec, b.data);
244 }
245 
247 {
248  b.data.update = b.all->loss->first_derivative(b.all->sd, ec.pred.scalar, ec.l.simple.label) * ec.weight;
249 
250  GD::foreach_feature<update_data, inner_update_proximal>(*b.all, ec, b.data);
251 }
252 
254 {
255  b.data.update = b.all->loss->first_derivative(b.all->sd, ec.pred.scalar, ec.l.simple.label) * ec.weight;
256 
257  GD::foreach_feature<update_data, inner_update_pistol_post>(*b.all, ec, b.data);
258 }
259 
261 {
262  b.data.update = b.all->loss->first_derivative(b.all->sd, ec.pred.scalar, ec.l.simple.label) * ec.weight;
263 
264  GD::foreach_feature<update_data, inner_update_cb_post>(*b.all, ec, b.data);
265 }
266 
267 template <bool audit>
269 {
270  assert(ec.in_use);
271 
272  // predict with confidence
273  predict<audit>(a, base, ec);
274 
275  // update state based on the prediction
277 }
278 
280 {
281  assert(ec.in_use);
282 
283  // update state based on the example and predict
284  update_state_and_predict_pistol(a, base, ec);
285 
286  // update state based on the prediction
288 }
289 
291 {
292  assert(ec.in_use);
293 
294  // update state based on the example and predict
295  update_state_and_predict_cb(a, base, ec);
296 
297  // update state based on the prediction
299 }
300 
301 void save_load(ftrl& b, io_buf& model_file, bool read, bool text)
302 {
303  vw* all = b.all;
304  if (read)
305  initialize_regressor(*all);
306 
307  if (!model_file.files.empty())
308  {
309  bool resume = all->save_resume;
310  std::stringstream msg;
311  msg << ":" << resume << "\n";
312  bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, msg, text);
313 
314  if (resume)
315  GD::save_load_online_state(*all, model_file, read, text, b.total_weight, nullptr, b.ftrl_size);
316  else
317  GD::save_load_regressor(*all, model_file, read, text);
318  }
319 }
320 
321 void end_pass(ftrl& g)
322 {
323  vw& all = *g.all;
324 
325  if (!all.holdout_set_off)
326  {
329  if ((g.early_stop_thres == g.no_win_counter) &&
331  set_done(all);
332  }
333 }
334 
336 {
337  auto b = scoped_calloc_or_throw<ftrl>();
338  bool ftrl_option = false;
339  bool pistol = false;
340  bool coin = false;
341 
342  option_group_definition new_options("Follow the Regularized Leader");
343  new_options.add(make_option("ftrl", ftrl_option).keep().help("FTRL: Follow the Proximal Regularized Leader"))
344  .add(make_option("coin", coin).keep().help("Coin betting optimizer"))
345  .add(make_option("pistol", pistol).keep().help("PiSTOL: Parameter-free STOchastic Learning"))
346  .add(make_option("ftrl_alpha", b->ftrl_alpha).help("Learning rate for FTRL optimization"))
347  .add(make_option("ftrl_beta", b->ftrl_beta).help("Learning rate for FTRL optimization"));
348  options.add_and_parse(new_options);
349 
350  if (!ftrl_option && !pistol && !coin)
351  {
352  return nullptr;
353  }
354 
355  // Defaults that are specific to the mode that was chosen.
356  if (ftrl_option)
357  {
358  b->ftrl_alpha = options.was_supplied("ftrl_alpha") ? b->ftrl_alpha : 0.005f;
359  b->ftrl_beta = options.was_supplied("ftrl_beta") ? b->ftrl_beta : 0.1f;
360  }
361  else if (pistol)
362  {
363  b->ftrl_alpha = options.was_supplied("ftrl_alpha") ? b->ftrl_alpha : 1.0f;
364  b->ftrl_beta = options.was_supplied("ftrl_beta") ? b->ftrl_beta : 0.5f;
365  }
366  else if (coin)
367  {
368  b->ftrl_alpha = options.was_supplied("ftrl_alpha") ? b->ftrl_alpha : 4.0f;
369  b->ftrl_beta = options.was_supplied("ftrl_beta") ? b->ftrl_beta : 1.0f;
370  }
371 
372  b->all = &all;
373  b->no_win_counter = 0;
374  b->all->normalized_sum_norm_x = 0;
375  b->total_weight = 0;
376 
377  void (*learn_ptr)(ftrl&, single_learner&, example&) = nullptr;
378 
379  std::string algorithm_name;
380  if (ftrl_option)
381  {
382  algorithm_name = "Proximal-FTRL";
383  if (all.audit)
384  learn_ptr = learn_proximal<true>;
385  else
386  learn_ptr = learn_proximal<false>;
387  all.weights.stride_shift(2); // NOTE: for more parameter storage
388  b->ftrl_size = 3;
389  }
390  else if (pistol)
391  {
392  algorithm_name = "PiSTOL";
393  learn_ptr = learn_pistol;
394  all.weights.stride_shift(2); // NOTE: for more parameter storage
395  b->ftrl_size = 4;
396  }
397  else if (coin)
398  {
399  algorithm_name = "Coin Betting";
400  learn_ptr = learn_cb;
401  all.weights.stride_shift(3); // NOTE: for more parameter storage
402  b->ftrl_size = 6;
403  }
404 
405  b->data.ftrl_alpha = b->ftrl_alpha;
406  b->data.ftrl_beta = b->ftrl_beta;
407  b->data.l1_lambda = b->all->l1_lambda;
408  b->data.l2_lambda = b->all->l2_lambda;
409 
410  if (!all.quiet)
411  {
412  std::cerr << "Enabling FTRL based optimization" << std::endl;
413  std::cerr << "Algorithm used: " << algorithm_name << std::endl;
414  std::cerr << "ftrl_alpha = " << b->ftrl_alpha << std::endl;
415  std::cerr << "ftrl_beta = " << b->ftrl_beta << std::endl;
416  }
417 
418  if (!all.holdout_set_off)
419  {
420  all.sd->holdout_best_loss = FLT_MAX;
421  b->early_stop_thres = options.get_typed_option<size_t>("early_terminate").value();
422  }
423 
425  if (all.audit || all.hash_inv)
426  l = &init_learner(b, learn_ptr, predict<true>, UINT64_ONE << all.weights.stride_shift());
427  else
428  l = &init_learner(b, learn_ptr, predict<false>, UINT64_ONE << all.weights.stride_shift());
429  l->set_sensitivity(sensitivity);
430  if (all.audit || all.hash_inv)
431  l->set_multipredict(multipredict<true>);
432  else
433  l->set_multipredict(multipredict<false>);
434  l->set_save_load(save_load);
435  l->set_end_pass(end_pass);
436  return make_base(*l);
437 }
void update_after_prediction_pistol(ftrl &b, example &ec)
Definition: ftrl.cc:253
void update_after_prediction_cb(ftrl &b, example &ec)
Definition: ftrl.cc:260
size_t early_stop_thres
Definition: ftrl.cc:38
float update
Definition: ftrl.cc:22
void update_state_and_predict_pistol(ftrl &b, single_learner &, example &ec)
Definition: ftrl.cc:237
float finalize_prediction(shared_data *sd, float ret)
Definition: gd.cc:339
uncertainty(ftrl &ftrlb)
Definition: ftrl.cc:48
#define correctedExp
Definition: correctedMath.h:27
void set_done(vw &all)
Definition: parser.cc:578
parameters weights
Definition: global_data.h:537
loss_function * loss
Definition: global_data.h:523
float l1_lambda
Definition: ftrl.cc:25
void print_audit_features(vw &all, example &ec)
Definition: gd.cc:331
void initialize_regressor(vw &all, T &weights)
void predict_with_confidence(uncertainty &d, const float fx, float &fw)
Definition: ftrl.cc:63
float scalar
Definition: example.h:45
Definition: ftrl.cc:31
bool hash_inv
Definition: global_data.h:541
#define W_XT
Definition: ftrl.cc:13
void inner_update_cb_post(update_data &d, float x, float &wref)
Definition: ftrl.cc:194
#define W_MX
Definition: ftrl.cc:16
double holdout_best_loss
Definition: global_data.h:161
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
float partial_prediction
Definition: example.h:68
float sensitivity(ftrl &b, base_learner &, example &ec)
Definition: ftrl.cc:71
bool quiet
Definition: global_data.h:487
void vec_add_multipredict(multipredict_info< T > &mp, const float fx, uint64_t fi)
Definition: gd.h:40
double contraction
Definition: global_data.h:149
void predict(ftrl &b, single_learner &, example &ec)
Definition: ftrl.cc:79
void finalize_regressor(vw &all, std::string reg_name)
virtual void add_and_parse(const option_group_definition &group)=0
#define W_G2
Definition: ftrl.cc:15
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
#define W_WE
Definition: ftrl.cc:17
bool holdout_set_off
Definition: global_data.h:499
size_t check_holdout_every_n_passes
Definition: global_data.h:503
float sign(float w)
Definition: ftrl.cc:55
float ftrl_alpha
Definition: ftrl.cc:34
double total_weight
Definition: ftrl.cc:40
bool summarize_holdout_set(vw &all, size_t &no_win_counter)
void inner_update_pistol_state_and_predict(update_data &d, float x, float &wref)
Definition: ftrl.cc:142
virtual float first_derivative(shared_data *, float prediction, float label)=0
float pred
Definition: ftrl.cc:45
float inline_predict(vw &all, example &ec)
Definition: gd.h:98
float predict
Definition: ftrl.cc:27
void save_load_online_state(vw &all, io_buf &model_file, bool read, bool text, gd *g, std::stringstream &msg, uint32_t ftrl_size, T &weights)
Definition: gd.cc:776
void inner_update_cb_state_and_predict(update_data &d, float x, float &wref)
Definition: ftrl.cc:173
float ftrl_beta
Definition: ftrl.cc:24
float ftrl_beta
Definition: ftrl.cc:35
float normalized_squared_norm_x
Definition: ftrl.cc:28
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
shared_data * sd
Definition: global_data.h:375
typed_option< T > & get_typed_option(const std::string &key)
Definition: options.h:120
void end_pass(example &ec, vw &all)
Definition: learner.cc:44
v_array< int > files
Definition: io_buf.h:64
virtual bool was_supplied(const std::string &key)=0
base_learner * ftrl_setup(options_i &options, vw &all)
Definition: ftrl.cc:335
dense_parameters dense_weights
#define W_MG
Definition: ftrl.cc:18
uint64_t current_pass
Definition: global_data.h:396
float initial
Definition: simple_label.h:16
float l2_lambda
Definition: ftrl.cc:26
Definition: io_buf.h:54
uint32_t ftrl_size
Definition: ftrl.cc:39
vw * all
Definition: ftrl.cc:33
size_t no_win_counter
Definition: ftrl.cc:37
void update_after_prediction_proximal(ftrl &b, example &ec)
Definition: ftrl.cc:246
float score
Definition: ftrl.cc:46
option_group_definition & add(T &&op)
Definition: options.h:90
void inner_update_pistol_post(update_data &d, float x, float &wref)
Definition: ftrl.cc:157
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
polylabel l
Definition: example.h:57
constexpr uint64_t a
Definition: rand48.cc:11
bool in_use
Definition: example.h:79
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
constexpr uint64_t UINT64_ONE
void learn_pistol(ftrl &a, single_learner &base, example &ec)
Definition: ftrl.cc:279
sparse_parameters sparse_weights
double gravity
Definition: global_data.h:148
bool empty() const
Definition: v_array.h:59
bool save_resume
Definition: global_data.h:415
uint32_t stride_shift()
void update_state_and_predict_cb(ftrl &b, single_learner &, example &ec)
Definition: ftrl.cc:222
#define W_ZT
Definition: ftrl.cc:14
bool audit
Definition: global_data.h:486
struct update_data data
Definition: ftrl.cc:36
void multipredict(ftrl &b, base_learner &, example &ec, size_t count, size_t step, polyprediction *pred, bool finalize_predictions)
Definition: ftrl.cc:88
void inner_update_proximal(update_data &d, float x, float &wref)
Definition: ftrl.cc:120
polyprediction pred
Definition: example.h:60
void save_load_regressor(vw &all, io_buf &model_file, bool read, bool text, T &weights)
Definition: gd.cc:707
std::string final_regressor_name
Definition: global_data.h:535
void learn_cb(ftrl &a, single_learner &base, example &ec)
Definition: ftrl.cc:290
void save_load(ftrl &b, io_buf &model_file, bool read, bool text)
Definition: ftrl.cc:301
float ftrl_alpha
Definition: ftrl.cc:23
float weight
Definition: example.h:62
size_t bin_text_read_write_fixed(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
Definition: io_buf.h:326
constexpr uint64_t c
Definition: rand48.cc:12
void learn_proximal(ftrl &a, single_learner &base, example &ec)
Definition: ftrl.cc:268
ftrl & b
Definition: ftrl.cc:47
float f
Definition: cache.cc:40
double normalized_sum_norm_x
Definition: global_data.h:420