Vowpal Wabbit
Classes | Macros | Functions
ftrl.cc File Reference
#include <string>
#include "correctedMath.h"
#include "gd.h"

Go to the source code of this file.

Classes

struct  update_data
 
struct  ftrl
 
struct  uncertainty
 

Macros

#define W_XT   0
 
#define W_ZT   1
 
#define W_G2   2
 
#define W_MX   3
 
#define W_WE   4
 
#define W_MG   5
 

Functions

float sign (float w)
 
void predict_with_confidence (uncertainty &d, const float fx, float &fw)
 
float sensitivity (ftrl &b, base_learner &, example &ec)
 
template<bool audit>
void predict (ftrl &b, single_learner &, example &ec)
 
template<bool audit>
void multipredict (ftrl &b, base_learner &, example &ec, size_t count, size_t step, polyprediction *pred, bool finalize_predictions)
 
void inner_update_proximal (update_data &d, float x, float &wref)
 
void inner_update_pistol_state_and_predict (update_data &d, float x, float &wref)
 
void inner_update_pistol_post (update_data &d, float x, float &wref)
 
void inner_update_cb_state_and_predict (update_data &d, float x, float &wref)
 
void inner_update_cb_post (update_data &d, float x, float &wref)
 
void update_state_and_predict_cb (ftrl &b, single_learner &, example &ec)
 
void update_state_and_predict_pistol (ftrl &b, single_learner &, example &ec)
 
void update_after_prediction_proximal (ftrl &b, example &ec)
 
void update_after_prediction_pistol (ftrl &b, example &ec)
 
void update_after_prediction_cb (ftrl &b, example &ec)
 
template<bool audit>
void learn_proximal (ftrl &a, single_learner &base, example &ec)
 
void learn_pistol (ftrl &a, single_learner &base, example &ec)
 
void learn_cb (ftrl &a, single_learner &base, example &ec)
 
void save_load (ftrl &b, io_buf &model_file, bool read, bool text)
 
void end_pass (ftrl &g)
 
base_learner * ftrl_setup (options_i &options, vw &all)
 

Macro Definition Documentation

◆ W_G2

#define W_G2   2

◆ W_MG

#define W_MG   5

Definition at line 18 of file ftrl.cc.

Referenced by inner_update_cb_post(), and inner_update_cb_state_and_predict().

◆ W_MX

#define W_MX   3

◆ W_WE

#define W_WE   4

Definition at line 17 of file ftrl.cc.

Referenced by inner_update_cb_post(), and inner_update_cb_state_and_predict().

◆ W_XT

#define W_XT   0

◆ W_ZT

#define W_ZT   1

Function Documentation

◆ end_pass()

void end_pass ( ftrl & g)

Definition at line 321 of file ftrl.cc.

References ftrl::all, vw::check_holdout_every_n_passes, vw::current_pass, ftrl::early_stop_thres, vw::final_regressor_name, finalize_regressor(), vw::holdout_set_off, ftrl::no_win_counter, set_done(), and summarize_holdout_set().

322 {
323  vw& all = *g.all;
324 
325  if (!all.holdout_set_off)
326  {
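327  if (summarize_holdout_set(all, g.no_win_counter))
328  finalize_regressor(all, all.final_regressor_name);
329  if ((g.early_stop_thres == g.no_win_counter) &&
330  ((all.check_holdout_every_n_passes <= 1) || ((all.current_pass % all.check_holdout_every_n_passes) == 0)))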
331  set_done(all);
332  }
333 }
size_t early_stop_thres
Definition: ftrl.cc:38
void set_done(vw &all)
Definition: parser.cc:578
void finalize_regressor(vw &all, std::string reg_name)
bool holdout_set_off
Definition: global_data.h:499
size_t check_holdout_every_n_passes
Definition: global_data.h:503
bool summarize_holdout_set(vw &all, size_t &no_win_counter)
uint64_t current_pass
Definition: global_data.h:396
vw * all
Definition: ftrl.cc:33
size_t no_win_counter
Definition: ftrl.cc:37
std::string final_regressor_name
Definition: global_data.h:535

◆ ftrl_setup()

base_learner* ftrl_setup ( options_i & options,
vw & all 
)

Definition at line 335 of file ftrl.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), vw::audit, LEARNER::end_pass(), VW::config::options_i::get_typed_option(), vw::hash_inv, shared_data::holdout_best_loss, vw::holdout_set_off, LEARNER::init_learner(), learn_cb(), learn_pistol(), LEARNER::make_base(), VW::config::make_option(), vw::normalized_sum_norm_x, vw::quiet, save_load(), vw::sd, sensitivity(), parameters::stride_shift(), UINT64_ONE, VW::config::options_i::was_supplied(), and vw::weights.

Referenced by parse_reductions().

336 {
337  auto b = scoped_calloc_or_throw<ftrl>();
338  bool ftrl_option = false;
339  bool pistol = false;
340  bool coin = false;
341 
342  option_group_definition new_options("Follow the Regularized Leader");
343  new_options.add(make_option("ftrl", ftrl_option).keep().help("FTRL: Follow the Proximal Regularized Leader"))
344  .add(make_option("coin", coin).keep().help("Coin betting optimizer"))
345  .add(make_option("pistol", pistol).keep().help("PiSTOL: Parameter-free STOchastic Learning"))
346  .add(make_option("ftrl_alpha", b->ftrl_alpha).help("Learning rate for FTRL optimization"))
347  .add(make_option("ftrl_beta", b->ftrl_beta).help("Learning rate for FTRL optimization"));
348  options.add_and_parse(new_options);
349 
350  if (!ftrl_option && !pistol && !coin)
351  {
352  return nullptr;
353  }
354 
355  // Defaults that are specific to the mode that was chosen.
356  if (ftrl_option)
357  {
358  b->ftrl_alpha = options.was_supplied("ftrl_alpha") ? b->ftrl_alpha : 0.005f;
359  b->ftrl_beta = options.was_supplied("ftrl_beta") ? b->ftrl_beta : 0.1f;
360  }
361  else if (pistol)
362  {
363  b->ftrl_alpha = options.was_supplied("ftrl_alpha") ? b->ftrl_alpha : 1.0f;
364  b->ftrl_beta = options.was_supplied("ftrl_beta") ? b->ftrl_beta : 0.5f;
365  }
366  else if (coin)
367  {
368  b->ftrl_alpha = options.was_supplied("ftrl_alpha") ? b->ftrl_alpha : 4.0f;
369  b->ftrl_beta = options.was_supplied("ftrl_beta") ? b->ftrl_beta : 1.0f;
370  }
371 
372  b->all = &all;
373  b->no_win_counter = 0;
374  b->all->normalized_sum_norm_x = 0;
375  b->total_weight = 0;
376 
377  void (*learn_ptr)(ftrl&, single_learner&, example&) = nullptr;
378 
379  std::string algorithm_name;
380  if (ftrl_option)
381  {
382  algorithm_name = "Proximal-FTRL";
383  if (all.audit)
384  learn_ptr = learn_proximal<true>;
385  else
386  learn_ptr = learn_proximal<false>;
387  all.weights.stride_shift(2); // NOTE: for more parameter storage
388  b->ftrl_size = 3;
389  }
390  else if (pistol)
391  {
392  algorithm_name = "PiSTOL";
393  learn_ptr = learn_pistol;
394  all.weights.stride_shift(2); // NOTE: for more parameter storage
395  b->ftrl_size = 4;
396  }
397  else if (coin)
398  {
399  algorithm_name = "Coin Betting";
400  learn_ptr = learn_cb;
401  all.weights.stride_shift(3); // NOTE: for more parameter storage
402  b->ftrl_size = 6;
403  }
404 
405  b->data.ftrl_alpha = b->ftrl_alpha;
406  b->data.ftrl_beta = b->ftrl_beta;
407  b->data.l1_lambda = b->all->l1_lambda;
408  b->data.l2_lambda = b->all->l2_lambda;
409 
410  if (!all.quiet)
411  {
412  std::cerr << "Enabling FTRL based optimization" << std::endl;
413  std::cerr << "Algorithm used: " << algorithm_name << std::endl;
414  std::cerr << "ftrl_alpha = " << b->ftrl_alpha << std::endl;
415  std::cerr << "ftrl_beta = " << b->ftrl_beta << std::endl;
416  }
417 
418  if (!all.holdout_set_off)
419  {
420  all.sd->holdout_best_loss = FLT_MAX;
421  b->early_stop_thres = options.get_typed_option<size_t>("early_terminate").value();
422  }
423 
424  learner<ftrl, example>* l;
425  if (all.audit || all.hash_inv)
426  l = &init_learner(b, learn_ptr, predict<true>, UINT64_ONE << all.weights.stride_shift());
427  else
428  l = &init_learner(b, learn_ptr, predict<false>, UINT64_ONE << all.weights.stride_shift());
429  l->set_sensitivity(sensitivity);
430  if (all.audit || all.hash_inv)
431  l->set_multipredict(multipredict<true>);
432  else
433  l->set_multipredict(multipredict<false>);
434  l->set_save_load(save_load);
435  l->set_end_pass(end_pass);
436  return make_base(*l);
437 }
parameters weights
Definition: global_data.h:537
Definition: ftrl.cc:31
bool hash_inv
Definition: global_data.h:541
double holdout_best_loss
Definition: global_data.h:161
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
float sensitivity(ftrl &b, base_learner &, example &ec)
Definition: ftrl.cc:71
bool quiet
Definition: global_data.h:487
virtual void add_and_parse(const option_group_definition &group)=0
bool holdout_set_off
Definition: global_data.h:499
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
shared_data * sd
Definition: global_data.h:375
typed_option< T > & get_typed_option(const std::string &key)
Definition: options.h:120
virtual bool was_supplied(const std::string &key)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
constexpr uint64_t UINT64_ONE
void end_pass(ftrl &g)
Definition: ftrl.cc:321
void learn_pistol(ftrl &a, single_learner &base, example &ec)
Definition: ftrl.cc:279
uint32_t stride_shift()
bool audit
Definition: global_data.h:486
void learn_cb(ftrl &a, single_learner &base, example &ec)
Definition: ftrl.cc:290
void save_load(ftrl &b, io_buf &model_file, bool read, bool text)
Definition: ftrl.cc:301
double normalized_sum_norm_x
Definition: global_data.h:420

◆ inner_update_cb_post()

void inner_update_cb_post ( update_data & d,
float  x,
float &  wref 
)

Definition at line 194 of file ftrl.cc.

References update_data::ftrl_alpha, update_data::ftrl_beta, update_data::update, W_G2, W_MG, W_MX, W_WE, W_XT, and W_ZT.

195 {
196  float* w = &wref;
197  float fabs_x = fabs(x);
198  float gradient = d.update * x;
199 
200  if (fabs_x > w[W_MX])
201  {
202  w[W_MX] = fabs_x;
203  }
204 
205  float fabs_gradient = fabs(d.update);
206  if (fabs_gradient > w[W_MG])
207  w[W_MG] = fabs_gradient > d.ftrl_beta ? fabs_gradient : d.ftrl_beta;
208 
209  // COCOB update without sigmoid.
210  // If a new Lipschitz constant and/or magnitude of x is found, the w is
211  // recalculated and used in the update of the wealth below.
212  if (w[W_MG] * w[W_MX] > 0)
213  w[W_XT] = (d.ftrl_alpha + w[W_WE]) * w[W_ZT] / (w[W_MG] * w[W_MX] * (w[W_MG] * w[W_MX] + w[W_G2]));
214  else
215  w[W_XT] = 0;
216 
217  w[W_ZT] += -gradient;
218  w[W_G2] += fabs(gradient);
219  w[W_WE] += (-gradient * w[W_XT]);
220 }
float update
Definition: ftrl.cc:22
#define W_XT
Definition: ftrl.cc:13
#define W_MX
Definition: ftrl.cc:16
#define W_G2
Definition: ftrl.cc:15
#define W_WE
Definition: ftrl.cc:17
float ftrl_beta
Definition: ftrl.cc:24
#define W_MG
Definition: ftrl.cc:18
#define W_ZT
Definition: ftrl.cc:14
float ftrl_alpha
Definition: ftrl.cc:23

◆ inner_update_cb_state_and_predict()

void inner_update_cb_state_and_predict ( update_data & d,
float  x,
float &  wref 
)

Definition at line 173 of file ftrl.cc.

References update_data::ftrl_alpha, update_data::normalized_squared_norm_x, update_data::predict, W_G2, W_MG, W_MX, W_WE, and W_ZT.

174 {
175  float* w = &wref;
176  float w_mx = w[W_MX];
177  float w_xt = 0.0;
178 
179  float fabs_x = fabs(x);
180  if (fabs_x > w_mx)
181  {
182  w_mx = fabs_x;
183  }
184 
185  // COCOB update without sigmoid
186  if (w[W_MG] * w_mx > 0)
187  w_xt = (d.ftrl_alpha + w[W_WE]) * w[W_ZT] / (w[W_MG] * w_mx * (w[W_MG] * w_mx + w[W_G2]));
188 
189  d.predict += w_xt * x;
190  if (w_mx > 0)
191  d.normalized_squared_norm_x += x * x / (w_mx * w_mx);
192 }
#define W_MX
Definition: ftrl.cc:16
#define W_G2
Definition: ftrl.cc:15
#define W_WE
Definition: ftrl.cc:17
float predict
Definition: ftrl.cc:27
float normalized_squared_norm_x
Definition: ftrl.cc:28
#define W_MG
Definition: ftrl.cc:18
#define W_ZT
Definition: ftrl.cc:14
float ftrl_alpha
Definition: ftrl.cc:23

◆ inner_update_pistol_post()

void inner_update_pistol_post ( update_data & d,
float  x,
float &  wref 
)

Definition at line 157 of file ftrl.cc.

References update_data::update, W_G2, and W_ZT.

158 {
159  float* w = &wref;
160  float gradient = d.update * x;
161 
162  w[W_ZT] += -gradient;
163  w[W_G2] += fabs(gradient);
164 }
float update
Definition: ftrl.cc:22
#define W_G2
Definition: ftrl.cc:15
#define W_ZT
Definition: ftrl.cc:14

◆ inner_update_pistol_state_and_predict()

void inner_update_pistol_state_and_predict ( update_data & d,
float  x,
float &  wref 
)

Definition at line 142 of file ftrl.cc.

References correctedExp, f, update_data::ftrl_alpha, update_data::ftrl_beta, update_data::predict, W_G2, W_MX, W_XT, and W_ZT.

143 {
144  float* w = &wref;
145 
146  float fabs_x = fabs(x);
147  if (fabs_x > w[W_MX])
148  w[W_MX] = fabs_x;
149 
150  float squared_theta = w[W_ZT] * w[W_ZT];
151  float tmp = 1.f / (d.ftrl_alpha * w[W_MX] * (w[W_G2] + w[W_MX]));
152  w[W_XT] = std::sqrt(w[W_G2]) * d.ftrl_beta * w[W_ZT] * correctedExp(squared_theta / 2.f * tmp) * tmp;
153 
154  d.predict += w[W_XT] * x;
155 }
#define correctedExp
Definition: correctedMath.h:27
#define W_XT
Definition: ftrl.cc:13
#define W_MX
Definition: ftrl.cc:16
#define W_G2
Definition: ftrl.cc:15
float predict
Definition: ftrl.cc:27
float ftrl_beta
Definition: ftrl.cc:24
#define W_ZT
Definition: ftrl.cc:14
float ftrl_alpha
Definition: ftrl.cc:23
float f
Definition: cache.cc:40

◆ inner_update_proximal()

void inner_update_proximal ( update_data & d,
float  x,
float &  wref 
)

Definition at line 120 of file ftrl.cc.

References update_data::ftrl_alpha, update_data::ftrl_beta, update_data::l1_lambda, update_data::l2_lambda, sign(), update_data::update, W_G2, W_XT, and W_ZT.

121 {
122  float* w = &wref;
123  float gradient = d.update * x;
124  float ng2 = w[W_G2] + gradient * gradient;
125  float sqrt_ng2 = sqrtf(ng2);
126  float sqrt_wW_G2 = sqrtf(w[W_G2]);
127  float sigma = (sqrt_ng2 - sqrt_wW_G2) / d.ftrl_alpha;
128  w[W_ZT] += gradient - sigma * w[W_XT];
129  w[W_G2] = ng2;
130  sqrt_wW_G2 = sqrt_ng2;
131  float flag = sign(w[W_ZT]);
132  float fabs_zt = w[W_ZT] * flag;
133  if (fabs_zt <= d.l1_lambda)
134  w[W_XT] = 0.;
135  else
136  {
137  float step = 1 / (d.l2_lambda + (d.ftrl_beta + sqrt_wW_G2) / d.ftrl_alpha);
138  w[W_XT] = step * flag * (d.l1_lambda - fabs_zt);
139  }
140 }
float update
Definition: ftrl.cc:22
float l1_lambda
Definition: ftrl.cc:25
#define W_XT
Definition: ftrl.cc:13
#define W_G2
Definition: ftrl.cc:15
float sign(float w)
Definition: ftrl.cc:55
float ftrl_beta
Definition: ftrl.cc:24
float l2_lambda
Definition: ftrl.cc:26
#define W_ZT
Definition: ftrl.cc:14
float ftrl_alpha
Definition: ftrl.cc:23

◆ learn_cb()

void learn_cb ( ftrl & a,
single_learner & base,
example & ec 
)

Definition at line 290 of file ftrl.cc.

References example::in_use, update_after_prediction_cb(), and update_state_and_predict_cb().

Referenced by ftrl_setup().

291 {
292  assert(ec.in_use);
293 
294  // update state based on the example and predict
295  update_state_and_predict_cb(a, base, ec);
296 
297  // update state based on the prediction
298  update_after_prediction_cb(a, ec);
299 }
void update_after_prediction_cb(ftrl &b, example &ec)
Definition: ftrl.cc:260
bool in_use
Definition: example.h:79
void update_state_and_predict_cb(ftrl &b, single_learner &, example &ec)
Definition: ftrl.cc:222

◆ learn_pistol()

void learn_pistol ( ftrl & a,
single_learner & base,
example & ec 
)

Definition at line 279 of file ftrl.cc.

References example::in_use, update_after_prediction_pistol(), and update_state_and_predict_pistol().

Referenced by ftrl_setup().

280 {
281  assert(ec.in_use);
282 
283  // update state based on the example and predict
284  update_state_and_predict_pistol(a, base, ec);
285 
286  // update state based on the prediction
287  update_after_prediction_pistol(a, ec);
288 }
void update_after_prediction_pistol(ftrl &b, example &ec)
Definition: ftrl.cc:253
void update_state_and_predict_pistol(ftrl &b, single_learner &, example &ec)
Definition: ftrl.cc:237
bool in_use
Definition: example.h:79

◆ learn_proximal()

template<bool audit>
void learn_proximal ( ftrl & a,
single_learner & base,
example & ec 
)

Definition at line 268 of file ftrl.cc.

References a, example::in_use, and update_after_prediction_proximal().

269 {
270  assert(ec.in_use);
271 
272  // predict with confidence
273  predict<audit>(a, base, ec);
274 
275  // update state based on the prediction
276  update_after_prediction_proximal(a, ec);
277 }
void update_after_prediction_proximal(ftrl &b, example &ec)
Definition: ftrl.cc:246
constexpr uint64_t a
Definition: rand48.cc:11
bool in_use
Definition: example.h:79

◆ multipredict()

template<bool audit>
void multipredict ( ftrl & b,
base_learner & ,
example & ec,
size_t  count,
size_t  step,
polyprediction * pred,
bool  finalize_predictions 
)

Definition at line 88 of file ftrl.cc.

References ftrl::all, c, shared_data::contraction, parameters::dense_weights, GD::finalize_prediction(), example_predict::ft_offset, shared_data::gravity, label_data::initial, example::l, example::pred, GD::print_audit_features(), prediction_type::scalar, polyprediction::scalar, vw::sd, polylabel::simple, parameters::sparse, parameters::sparse_weights, GD::vec_add_multipredict(), and vw::weights.

90 {
91  vw& all = *b.all;
92  for (size_t c = 0; c < count; c++) pred[c].scalar = ec.l.simple.initial;
93  if (b.all->weights.sparse)
94  {
95  GD::multipredict_info<sparse_parameters> mp = {
96  count, step, pred, all.weights.sparse_weights, (float)all.sd->gravity};
97  GD::foreach_feature<GD::multipredict_info<sparse_parameters>, uint64_t, GD::vec_add_multipredict>(all, ec, mp);
98  }
99  else
100  {
101  GD::multipredict_info<dense_parameters> mp = {count, step, pred, all.weights.dense_weights, (float)all.sd->gravity};
102  GD::foreach_feature<GD::multipredict_info<dense_parameters>, uint64_t, GD::vec_add_multipredict>(all, ec, mp);
103  }
104  if (all.sd->contraction != 1.)
105  for (size_t c = 0; c < count; c++) pred[c].scalar *= (float)all.sd->contraction;
106  if (finalize_predictions)
107  for (size_t c = 0; c < count; c++) pred[c].scalar = GD::finalize_prediction(all.sd, pred[c].scalar);
108  if (audit)
109  {
110  for (size_t c = 0; c < count; c++)
111  {
112  ec.pred.scalar = pred[c].scalar;
113  GD::print_audit_features(all, ec);
114  ec.ft_offset += (uint64_t)step;
115  }
116  ec.ft_offset -= (uint64_t)(step * count);
117  }
118 }
float finalize_prediction(shared_data *sd, float ret)
Definition: gd.cc:339
parameters weights
Definition: global_data.h:537
void print_audit_features(vw &all, example &ec)
Definition: gd.cc:331
float scalar
Definition: example.h:45
void vec_add_multipredict(multipredict_info< T > &mp, const float fx, uint64_t fi)
Definition: gd.h:40
double contraction
Definition: global_data.h:149
label_data simple
Definition: example.h:28
shared_data * sd
Definition: global_data.h:375
dense_parameters dense_weights
float initial
Definition: simple_label.h:16
vw * all
Definition: ftrl.cc:33
polylabel l
Definition: example.h:57
sparse_parameters sparse_weights
double gravity
Definition: global_data.h:148
polyprediction pred
Definition: example.h:60
constexpr uint64_t c
Definition: rand48.cc:12

◆ predict()

template<bool audit>
void predict ( ftrl & b,
single_learner & ,
example & ec 
)

Definition at line 79 of file ftrl.cc.

References ftrl::all, GD::finalize_prediction(), GD::inline_predict(), example::partial_prediction, example::pred, GD::print_audit_features(), polyprediction::scalar, and vw::sd.

80 {
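81  ec.partial_prediction = GD::inline_predict(*b.all, ec);
82  ec.pred.scalar = GD::finalize_prediction(b.all->sd, ec.partial_prediction);
83  if (audit)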
84  GD::print_audit_features(*(b.all), ec);
85 }
float finalize_prediction(shared_data *sd, float ret)
Definition: gd.cc:339
void print_audit_features(vw &all, example &ec)
Definition: gd.cc:331
float scalar
Definition: example.h:45
float partial_prediction
Definition: example.h:68
float inline_predict(vw &all, example &ec)
Definition: gd.h:98
shared_data * sd
Definition: global_data.h:375
vw * all
Definition: ftrl.cc:33
polyprediction pred
Definition: example.h:60

◆ predict_with_confidence()

void predict_with_confidence ( uncertainty & d,
const float  fx,
float &  fw 
)
inline

Definition at line 63 of file ftrl.cc.

References uncertainty::b, ftrl::data, update_data::ftrl_alpha, update_data::ftrl_beta, update_data::l2_lambda, uncertainty::pred, uncertainty::score, sign(), W_G2, and W_XT.

64 {
65  float* w = &fw;
66  d.pred += w[W_XT] * fx;
67  float sqrtf_ng2 = sqrtf(w[W_G2]);
68  float uncertain = ((d.b.data.ftrl_beta + sqrtf_ng2) / d.b.data.ftrl_alpha + d.b.data.l2_lambda);
69  d.score += (1 / uncertain) * sign(fx);
70 }
#define W_XT
Definition: ftrl.cc:13
#define W_G2
Definition: ftrl.cc:15
float sign(float w)
Definition: ftrl.cc:55
float pred
Definition: ftrl.cc:45
float ftrl_beta
Definition: ftrl.cc:24
float l2_lambda
Definition: ftrl.cc:26
float score
Definition: ftrl.cc:46
struct update_data data
Definition: ftrl.cc:36
float ftrl_alpha
Definition: ftrl.cc:23
ftrl & b
Definition: ftrl.cc:47

◆ save_load()

void save_load ( ftrl & b,
io_buf & model_file,
bool  read,
bool  text 
)

Definition at line 301 of file ftrl.cc.

References ftrl::all, bin_text_read_write_fixed(), v_array< T >::empty(), io_buf::files, ftrl::ftrl_size, initialize_regressor(), GD::save_load_online_state(), GD::save_load_regressor(), vw::save_resume, and ftrl::total_weight.

Referenced by ftrl_setup().

302 {
303  vw* all = b.all;
304  if (read)
305  initialize_regressor(*all);
306 
307  if (!model_file.files.empty())
308  {
309  bool resume = all->save_resume;
310  std::stringstream msg;
311  msg << ":" << resume << "\n";
312  bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, msg, text);
313 
314  if (resume)
315  GD::save_load_online_state(*all, model_file, read, text, b.total_weight, nullptr, b.ftrl_size);
316  else
317  GD::save_load_regressor(*all, model_file, read, text);
318  }
319 }
void initialize_regressor(vw &all, T &weights)
double total_weight
Definition: ftrl.cc:40
void save_load_online_state(vw &all, io_buf &model_file, bool read, bool text, gd *g, std::stringstream &msg, uint32_t ftrl_size, T &weights)
Definition: gd.cc:776
v_array< int > files
Definition: io_buf.h:64
uint32_t ftrl_size
Definition: ftrl.cc:39
vw * all
Definition: ftrl.cc:33
bool empty() const
Definition: v_array.h:59
bool save_resume
Definition: global_data.h:415
void save_load_regressor(vw &all, io_buf &model_file, bool read, bool text, T &weights)
Definition: gd.cc:707
size_t bin_text_read_write_fixed(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
Definition: io_buf.h:326

◆ sensitivity()

float sensitivity ( ftrl & b,
base_learner & ,
example & ec 
)

Definition at line 71 of file ftrl.cc.

References ftrl::all, and uncertainty::score.

Referenced by ftrl_setup().

72 {
73  uncertainty uncetain(b);
74  GD::foreach_feature<uncertainty, predict_with_confidence>(*(b.all), ec, uncetain);
75  return uncetain.score;
76 }
vw * all
Definition: ftrl.cc:33

◆ sign()

float sign ( float  w)
inline

Definition at line 55 of file ftrl.cc.

Referenced by inner_update_proximal(), and predict_with_confidence().

56 {
57  if (w < 0.)
58  return -1.;
59  else
60  return 1.;
61 }

◆ update_after_prediction_cb()

void update_after_prediction_cb ( ftrl & b,
example & ec 
)

Definition at line 260 of file ftrl.cc.

References ftrl::all, ftrl::data, loss_function::first_derivative(), example::l, label_data::label, vw::loss, example::pred, polyprediction::scalar, vw::sd, polylabel::simple, update_data::update, and example::weight.

Referenced by learn_cb().

261 {
262  b.data.update = b.all->loss->first_derivative(b.all->sd, ec.pred.scalar, ec.l.simple.label) * ec.weight;
263 
264  GD::foreach_feature<update_data, inner_update_cb_post>(*b.all, ec, b.data);
265 }
float update
Definition: ftrl.cc:22
loss_function * loss
Definition: global_data.h:523
float scalar
Definition: example.h:45
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
virtual float first_derivative(shared_data *, float prediction, float label)=0
shared_data * sd
Definition: global_data.h:375
vw * all
Definition: ftrl.cc:33
polylabel l
Definition: example.h:57
struct update_data data
Definition: ftrl.cc:36
polyprediction pred
Definition: example.h:60
float weight
Definition: example.h:62

◆ update_after_prediction_pistol()

void update_after_prediction_pistol ( ftrl & b,
example & ec 
)

Definition at line 253 of file ftrl.cc.

References ftrl::all, ftrl::data, loss_function::first_derivative(), example::l, label_data::label, vw::loss, example::pred, polyprediction::scalar, vw::sd, polylabel::simple, update_data::update, and example::weight.

Referenced by learn_pistol().

254 {
255  b.data.update = b.all->loss->first_derivative(b.all->sd, ec.pred.scalar, ec.l.simple.label) * ec.weight;
256 
257  GD::foreach_feature<update_data, inner_update_pistol_post>(*b.all, ec, b.data);
258 }
float update
Definition: ftrl.cc:22
loss_function * loss
Definition: global_data.h:523
float scalar
Definition: example.h:45
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
virtual float first_derivative(shared_data *, float prediction, float label)=0
shared_data * sd
Definition: global_data.h:375
vw * all
Definition: ftrl.cc:33
polylabel l
Definition: example.h:57
struct update_data data
Definition: ftrl.cc:36
polyprediction pred
Definition: example.h:60
float weight
Definition: example.h:62

◆ update_after_prediction_proximal()

void update_after_prediction_proximal ( ftrl & b,
example & ec 
)

Definition at line 246 of file ftrl.cc.

References ftrl::all, ftrl::data, loss_function::first_derivative(), example::l, label_data::label, vw::loss, example::pred, polyprediction::scalar, vw::sd, polylabel::simple, update_data::update, and example::weight.

Referenced by learn_proximal().

247 {
248  b.data.update = b.all->loss->first_derivative(b.all->sd, ec.pred.scalar, ec.l.simple.label) * ec.weight;
249 
250  GD::foreach_feature<update_data, inner_update_proximal>(*b.all, ec, b.data);
251 }
float update
Definition: ftrl.cc:22
loss_function * loss
Definition: global_data.h:523
float scalar
Definition: example.h:45
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
virtual float first_derivative(shared_data *, float prediction, float label)=0
shared_data * sd
Definition: global_data.h:375
vw * all
Definition: ftrl.cc:33
polylabel l
Definition: example.h:57
struct update_data data
Definition: ftrl.cc:36
polyprediction pred
Definition: example.h:60
float weight
Definition: example.h:62

◆ update_state_and_predict_cb()

void update_state_and_predict_cb ( ftrl & b,
single_learner & ,
example & ec 
)

Definition at line 222 of file ftrl.cc.

References ftrl::all, ftrl::data, GD::finalize_prediction(), update_data::normalized_squared_norm_x, vw::normalized_sum_norm_x, example::partial_prediction, example::pred, update_data::predict, polyprediction::scalar, vw::sd, ftrl::total_weight, and example::weight.

Referenced by learn_cb().

223 {
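224  b.data.predict = 0;
225  b.data.normalized_squared_norm_x = 0;
226 
227  GD::foreach_feature<update_data, inner_update_cb_state_and_predict>(*b.all, ec, b.data);
228 
229  b.all->normalized_sum_norm_x += ((double)ec.weight) * b.data.normalized_squared_norm_x;
230  b.total_weight += ec.weight;
231 
232  ec.partial_prediction = b.data.predict / ((float)((b.all->normalized_sum_norm_x + 1e-6) / b.total_weight));
233 
234  ec.pred.scalar = GD::finalize_prediction(b.all->sd, ec.partial_prediction);
235 }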
float finalize_prediction(shared_data *sd, float ret)
Definition: gd.cc:339
float scalar
Definition: example.h:45
float partial_prediction
Definition: example.h:68
double total_weight
Definition: ftrl.cc:40
float predict
Definition: ftrl.cc:27
float normalized_squared_norm_x
Definition: ftrl.cc:28
shared_data * sd
Definition: global_data.h:375
vw * all
Definition: ftrl.cc:33
struct update_data data
Definition: ftrl.cc:36
polyprediction pred
Definition: example.h:60
float weight
Definition: example.h:62
double normalized_sum_norm_x
Definition: global_data.h:420

◆ update_state_and_predict_pistol()

void update_state_and_predict_pistol ( ftrl & b,
single_learner & ,
example & ec 
)

Definition at line 237 of file ftrl.cc.

References ftrl::all, ftrl::data, GD::finalize_prediction(), example::partial_prediction, example::pred, update_data::predict, polyprediction::scalar, and vw::sd.

Referenced by learn_pistol().

238 {
239  b.data.predict = 0;
240 
241  GD::foreach_feature<update_data, inner_update_pistol_state_and_predict>(*b.all, ec, b.data);
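242  ec.partial_prediction = b.data.predict;
243  ec.pred.scalar = GD::finalize_prediction(b.all->sd, ec.partial_prediction);
244 }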
float finalize_prediction(shared_data *sd, float ret)
Definition: gd.cc:339
float scalar
Definition: example.h:45
float partial_prediction
Definition: example.h:68
float predict
Definition: ftrl.cc:27
shared_data * sd
Definition: global_data.h:375
vw * all
Definition: ftrl.cc:33
struct update_data data
Definition: ftrl.cc:36
polyprediction pred
Definition: example.h:60