Vowpal Wabbit
OjaNewton.cc
/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved. Released under a BSD (revised)
license as described in the file LICENSE.
*/
#include <string>
#include "gd.h"
#include "vw.h"
#include "rand48.h"
#include "reductions.h"
#include <math.h>
#include <memory>

using namespace LEARNER;
using namespace VW::config;

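// Weight layout: each strided block holds the model weight at offset 0, the m
// entries of the sketch Z at offsets 1..m, and (when normalization is on) a
// per-feature normalization accumulator at offset m + 1, addressed via NORM2.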
#define NORM2 (m + 1)

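// Per-example scratch state shared by the foreach_feature callbacks below.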
struct update_data
{
  struct OjaNewton* ON;
  float g;
  float sketch_cnt;
  float norm2_x;
  float* Zx;
  float* AZx;
  float* delta;
  float bdelta;
  float prediction;
};

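// Online Newton with Oja's Sketch: second-order information is kept implicitly
// as a sketch Z (rows stored in the weight strides), a lower-triangular
// transform A, a diagonal scaling D, the sketch Gram matrix K, and running
// eigenvalue estimates ev.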
struct OjaNewton
{
  vw* all;
  std::shared_ptr<rand_state> _random_state;
  int m;
  int epoch_size;
  float alpha;
  int cnt;
  int t;

  float* ev;
  float* b;
  float* D;
  float** A;
  float** K;

  float* zv;
  float* vv;
  float* tmp;

  example** buffer;
  float* weight_buffer;
  struct update_data data;

  float learning_rate_cnt;
  bool normalize;
  bool random_init;

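  // Initialize the sketch: either the first m coordinate directions, or an
  // orthonormal basis of a random Gaussian matrix (Box-Muller draws followed
  // by Gram-Schmidt).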
  void initialize_Z(parameters& weights)
  {
    uint32_t length = 1 << all->num_bits;
    if (normalize)  // initialize normalization part
    {
      for (uint32_t i = 0; i < length; i++) (&(weights.strided_index(i)))[NORM2] = 0.1f;
    }
    if (!random_init)
    {
      // simple initialization
      for (int i = 1; i <= m; i++) (&(weights.strided_index(i)))[i] = 1.f;
    }
    else
    {
      // more complicated initialization: orthogonal basis of a random matrix

      const double PI2 = 2.f * 3.1415927f;

      for (uint32_t i = 0; i < length; i++)
      {
        weight& w = weights.strided_index(i);
        float r1, r2;
        for (int j = 1; j <= m; j++)
        {
          // Box-Muller transform: https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
          // redraw until r1 is strictly positive
          do
          {
            r1 = _random_state->get_and_update_random();
            r2 = _random_state->get_and_update_random();
          } while (r1 == 0.f);

          (&w)[j] = std::sqrt(-2.f * log(r1)) * (float)cos(PI2 * r2);
        }
      }
    }

    // Gram-Schmidt
    for (int j = 1; j <= m; j++)
    {
      for (int k = 1; k <= j - 1; k++)
      {
        double tmp = 0;

        for (uint32_t i = 0; i < length; i++)
          tmp += ((double)(&(weights.strided_index(i)))[j]) * (&(weights.strided_index(i)))[k];
        for (uint32_t i = 0; i < length; i++)
          (&(weights.strided_index(i)))[j] -= (float)tmp * (&(weights.strided_index(i)))[k];
      }
      double norm = 0;
      for (uint32_t i = 0; i < length; i++)
        norm += ((double)(&(weights.strided_index(i)))[j]) * (&(weights.strided_index(i)))[j];
      norm = std::sqrt(norm);
      for (uint32_t i = 0; i < length; i++) (&(weights.strided_index(i)))[j] /= (float)norm;
    }
  }

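  // data.AZx = A * data.Zx; A is lower triangular, so only j <= i contributes.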
  void compute_AZx()
  {
    for (int i = 1; i <= m; i++)
    {
      data.AZx[i] = 0;
      for (int j = 1; j <= i; j++)
      {
        data.AZx[i] += A[i][j] * data.Zx[j];
      }
    }
  }

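  // Track the eigenvalue estimate along sketch direction i as a moving average
  // (rate gamma = min(learning_rate_cnt / t, 1)) of t * (AZx[i] * sketch_cnt)^2.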
  void update_eigenvalues()
  {
    for (int i = 1; i <= m; i++)
    {
      float gamma = fmin(learning_rate_cnt / t, 1.f);
      float tmp = data.AZx[i] * data.sketch_cnt;

      if (t == 1)
      {
        ev[i] = gamma * tmp * tmp;
      }
      else
      {
        ev[i] = (1 - gamma) * t * ev[i] / (t - 1) + gamma * t * tmp * tmp;
      }
    }
  }

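  // delta is the step applied to the sketch rows; bdelta = delta' * b lets
  // update_Z_and_wbar adjust w[0] to cancel the change in the prediction
  // caused by moving Z.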
  void compute_delta()
  {
    data.bdelta = 0;
    for (int i = 1; i <= m; i++)
    {
      float gamma = fmin(learning_rate_cnt / t, 1.f);

      // if different learning rates are used
      /*data.delta[i] = gamma * data.AZx[i] * data.sketch_cnt;
      for (int j = 1; j < i; j++) {
        data.delta[i] -= A[i][j] * data.delta[j];
      }
      data.delta[i] /= A[i][i];*/

      // if the same learning rate is used
      data.delta[i] = gamma * data.Zx[i] * data.sketch_cnt;

      data.bdelta += data.delta[i] * b[i];
    }
  }

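  // Keep K = Z * Z' in sync with the rank-one sketch update Z <- Z + delta * (s x)':
  // K += delta * (Z x s)' + (Z x s) * delta' + delta * delta' * s^2 * ||x||^2.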
  void update_K()
  {
    float tmp = data.norm2_x * data.sketch_cnt * data.sketch_cnt;
    for (int i = 1; i <= m; i++)
    {
      for (int j = 1; j <= m; j++)
      {
        K[i][j] += data.delta[i] * data.Zx[j] * data.sketch_cnt;
        K[i][j] += data.delta[j] * data.Zx[i] * data.sketch_cnt;
        K[i][j] += data.delta[i] * data.delta[j] * tmp;
      }
    }
  }

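  // Gram-Schmidt in the inner product defined by K: re-orthonormalize the rows
  // of A (kept lower triangular) so that A K A' stays close to the identity.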
  void update_A()
  {
    for (int i = 1; i <= m; i++)
    {
      for (int j = 1; j < i; j++)
      {
        zv[j] = 0;
        for (int k = 1; k <= i; k++)
        {
          zv[j] += A[i][k] * K[k][j];
        }
      }

      for (int j = 1; j < i; j++)
      {
        vv[j] = 0;
        for (int k = 1; k <= j; k++)
        {
          vv[j] += A[j][k] * zv[k];
        }
      }

      for (int j = 1; j < i; j++)
      {
        for (int k = j; k < i; k++)
        {
          A[i][j] -= vv[k] * A[k][j];
        }
      }

      float norm = 0;
      for (int j = 1; j <= i; j++)
      {
        float temp = 0;
        for (int k = 1; k <= i; k++)
        {
          temp += K[j][k] * A[i][k];
        }
        norm += A[i][j] * temp;
      }
      norm = sqrtf(norm);

      for (int j = 1; j <= i; j++)
      {
        A[i][j] /= norm;
      }
    }
  }

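  // b += g * A' * diag(ev[i] / (alpha * (alpha + ev[i]))) * AZx: the
  // second-order correction accumulated in sketch coordinates.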
  void update_b()
  {
    for (int j = 1; j <= m; j++)
    {
      float tmp = 0;
      for (int i = j; i <= m; i++)
      {
        tmp += ev[i] * data.AZx[i] * A[i][j] / (alpha * (alpha + ev[i]));
      }
      b[j] += tmp * data.g;
    }
  }

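  // Rescale columns of A (compensating in K, b, and D) for conditioning.
  // Currently disabled: see the commented-out call in learn().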
  void update_D()
  {
    for (int j = 1; j <= m; j++)
    {
      float scale = fabs(A[j][j]);
      for (int i = j + 1; i <= m; i++) scale = fmin(fabs(A[i][j]), scale);
      if (scale < 1e-10)
        continue;
      for (int i = 1; i <= m; i++)
      {
        A[i][j] /= scale;
        K[j][i] *= scale;
        K[i][j] *= scale;
      }
      b[j] /= scale;
      D[j] *= scale;
      // printf("D[%d] = %f\n", j, D[j]);
    }
  }

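  // Once entries of K get too large, convert the implicit representation to an
  // explicit one: K <- A K A', fold (D Z)' b into the weights and zero b, set
  // Z <- A D Z, and reset A and D to the identity.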
  void check()
  {
    double max_norm = 0;
    for (int i = 1; i <= m; i++)
      for (int j = i; j <= m; j++) max_norm = fmax(max_norm, fabs(K[i][j]));
    // printf("|K| = %f\n", max_norm);
    if (max_norm < 1e7)
      return;

    // implicit -> explicit representation
    // printf("begin conversion: t = %d, norm(K) = %f\n", t, max_norm);

    // first step: K <- AKA'

    // K <- AK
    for (int j = 1; j <= m; j++)
    {
      memset(tmp, 0, sizeof(float) * (m + 1));  // tmp is a float buffer of m + 1 entries

      for (int i = 1; i <= m; i++)
      {
        for (int h = 1; h <= m; h++)
        {
          tmp[i] += A[i][h] * K[h][j];
        }
      }

      for (int i = 1; i <= m; i++) K[i][j] = tmp[i];
    }
    // K <- KA'
    for (int i = 1; i <= m; i++)
    {
      memset(tmp, 0, sizeof(float) * (m + 1));

      for (int j = 1; j <= m; j++)
        for (int h = 1; h <= m; h++) tmp[j] += K[i][h] * A[j][h];

      for (int j = 1; j <= m; j++)
      {
        K[i][j] = tmp[j];
      }
    }

    // second step: w[0] <- w[0] + (DZ)'b, b <- 0.

    uint32_t length = 1 << all->num_bits;
    for (uint32_t i = 0; i < length; i++)
    {
      weight& w = all->weights.strided_index(i);
      for (int j = 1; j <= m; j++) w += (&w)[j] * b[j] * D[j];
    }

    memset(b, 0, sizeof(float) * (m + 1));  // b is a float buffer of m + 1 entries

    // third step: Z <- ADZ, A, D <- Identity

    // double norm = 0;
    for (uint32_t i = 0; i < length; ++i)
    {
      memset(tmp, 0, sizeof(float) * (m + 1));
      weight& w = all->weights.strided_index(i);
      for (int j = 1; j <= m; j++)
      {
        for (int h = 1; h <= m; ++h) tmp[j] += A[j][h] * D[h] * (&w)[h];
      }
      for (int j = 1; j <= m; ++j)
      {
        // norm = std::max(norm, fabs(tmp[j]));
        (&w)[j] = tmp[j];
      }
    }
    // printf("|Z| = %f\n", norm);

    for (int i = 1; i <= m; i++)
    {
      memset(A[i], 0, sizeof(float) * (m + 1));  // A[i] is a float buffer of m + 1 entries
      D[i] = 1;
      A[i][i] = 1;
    }
  }

  ~OjaNewton()
  {
    free(ev);
    free(b);
    free(D);
    free(buffer);
    free(weight_buffer);
    free(zv);
    free(vv);
    free(tmp);
    if (A)
    {
      for (int i = 1; i <= m; i++)
      {
        free(A[i]);
        free(K[i]);
      }
    }

    free(A);
    free(K);

    free(data.Zx);
    free(data.AZx);
    free(data.delta);
  }
};

void keep_example(vw& all, OjaNewton& /* ON */, example& ec) { output_and_account_example(all, ec); }

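// foreach_feature callback: under the implicit representation the effective
// weight on a feature is w[0] + sum_i w[i] * D[i] * b[i].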
void make_pred(update_data& data, float x, float& wref)
{
  int m = data.ON->m;
  float* w = &wref;

  if (data.ON->normalize)
  {
    x /= std::sqrt(w[NORM2]);
  }

  data.prediction += w[0] * x;
  for (int i = 1; i <= m; i++)
  {
    data.prediction += w[i] * x * data.ON->D[i] * data.ON->b[i];
  }
}

void predict(OjaNewton& ON, base_learner&, example& ec)
{
  ON.data.prediction = 0;
  GD::foreach_feature<update_data, make_pred>(*ON.all, ec, ON.data);
  ec.partial_prediction = (float)ON.data.prediction;
  ec.pred.scalar = GD::finalize_prediction(ON.all->sd, ec.partial_prediction);
}

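// Apply delta to the stored sketch rows (pre-divided by D) and shift w[0] by
// -s * bdelta to cancel the resulting change in the prediction.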
void update_Z_and_wbar(update_data& data, float x, float& wref)
{
  float* w = &wref;
  int m = data.ON->m;
  if (data.ON->normalize)
    x /= std::sqrt(w[NORM2]);
  float s = data.sketch_cnt * x;

  for (int i = 1; i <= m; i++)
  {
    w[i] += data.delta[i] * s / data.ON->D[i];
  }
  w[0] -= s * data.bdelta;
}

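// Accumulate Zx = Z * x (multiplying the stored rows back by D) and ||x||^2.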
void compute_Zx_and_norm(update_data& data, float x, float& wref)
{
  float* w = &wref;
  int m = data.ON->m;
  if (data.ON->normalize)
    x /= std::sqrt(w[NORM2]);

  for (int i = 1; i <= m; i++)
  {
    data.Zx[i] += w[i] * x * data.ON->D[i];
  }
  data.norm2_x += x * x;
}

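// First-order step on w[0], scaled by 1/alpha, while accumulating Zx for the
// update of b.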
void update_wbar_and_Zx(update_data& data, float x, float& wref)
{
  float* w = &wref;
  int m = data.ON->m;
  if (data.ON->normalize)
    x /= std::sqrt(w[NORM2]);

  float g = data.g * x;

  for (int i = 1; i <= m; i++)
  {
    data.Zx[i] += w[i] * x * data.ON->D[i];
  }
  w[0] -= g / data.ON->alpha;
}

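// Accumulate the per-feature normalizer (sum of x^2 * g^2), read back as w[NORM2].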
void update_normalization(update_data& data, float x, float& wref)
{
  float* w = &wref;
  int m = data.ON->m;  // used by the NORM2 macro

  w[NORM2] += x * x * data.g * data.g;
}

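// learn: predict on the incoming example, buffer it, and once epoch_size
// examples are buffered run the sketch and Newton updates over the batch.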
void learn(OjaNewton& ON, base_learner& base, example& ec)
{
  assert(ec.in_use);

  // predict
  predict(ON, base, ec);

  update_data& data = ON.data;
  data.g = ON.all->loss->first_derivative(ON.all->sd, ec.pred.scalar, ec.l.simple.label) * ec.l.simple.weight;
  data.g /= 2;  // for half squared loss

  if (ON.normalize)
    GD::foreach_feature<update_data, update_normalization>(*ON.all, ec, data);

  ON.buffer[ON.cnt] = &ec;
  ON.weight_buffer[ON.cnt++] = data.g / 2;

  if (ON.cnt == ON.epoch_size)
  {
    for (int k = 0; k < ON.epoch_size; k++, ON.t++)
    {
      example& ex = *(ON.buffer[k]);
      data.sketch_cnt = ON.weight_buffer[k];

      data.norm2_x = 0;
      memset(data.Zx, 0, sizeof(float) * (ON.m + 1));
      GD::foreach_feature<update_data, compute_Zx_and_norm>(*ON.all, ex, data);
      ON.compute_AZx();

      ON.update_eigenvalues();
      ON.compute_delta();

      ON.update_K();

      GD::foreach_feature<update_data, update_Z_and_wbar>(*ON.all, ex, data);
    }

    ON.update_A();
    // ON.update_D();
  }

  memset(data.Zx, 0, sizeof(float) * (ON.m + 1));
  GD::foreach_feature<update_data, update_wbar_and_Zx>(*ON.all, ec, data);
  ON.compute_AZx();

  ON.update_b();
  ON.check();

  if (ON.cnt == ON.epoch_size)
  {
    ON.cnt = 0;
    for (int k = 0; k < ON.epoch_size; k++)
    {
      VW::finish_example(*ON.all, *ON.buffer[k]);
    }
  }
}

void save_load(OjaNewton& ON, io_buf& model_file, bool read, bool text)
{
  vw& all = *ON.all;
  if (read)
  {
    initialize_regressor(all);
    ON.initialize_Z(all.weights);
  }

  if (model_file.files.size() > 0)
  {
    bool resume = all.save_resume;
    std::stringstream msg;
    msg << ":" << resume << "\n";
    bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, msg, text);

    double temp = 0.;
    if (resume)
      GD::save_load_online_state(all, model_file, read, text, temp);
    else
      GD::save_load_regressor(all, model_file, read, text);
  }
}

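// Reduction setup: parse options, allocate the O(m^2) state, and widen the
// weight stride to make room for the m sketch rows plus the normalization slot.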
base_learner* OjaNewton_setup(options_i& options, vw& all)
{
  auto ON = scoped_calloc_or_throw<OjaNewton>();

  bool oja_newton;
  float alpha_inverse;

  // These are the only two boolean options that default to true. The options
  // infrastructure doesn't easily support that while keeping the convenience
  // of bool switches elsewhere, so they are parsed as strings for now. The
  // switch behavior matters more because of the positional data argument.
  std::string normalize = "true";
  std::string random_init = "true";
  option_group_definition new_options("OjaNewton options");
  new_options.add(make_option("OjaNewton", oja_newton).keep().help("Online Newton with Oja's Sketch"))
      .add(make_option("sketch_size", ON->m).default_value(10).help("size of sketch"))
      .add(make_option("epoch_size", ON->epoch_size).default_value(1).help("size of epoch"))
      .add(make_option("alpha", ON->alpha).default_value(1.f).help("multiplicative constant for identity"))
      .add(make_option("alpha_inverse", alpha_inverse).help("one over alpha, similar to learning rate"))
      .add(make_option("learning_rate_cnt", ON->learning_rate_cnt)
               .default_value(2.f)
               .help("constant for the learning rate 1/t"))
      .add(make_option("normalize", normalize).help("normalize the features or not"))
      .add(make_option("random_init", random_init).help("randomize initialization of Oja or not"));
  options.add_and_parse(new_options);

  if (!options.was_supplied("OjaNewton"))
    return nullptr;

  ON->all = &all;
  ON->_random_state = all.get_random_state();

  ON->normalize = normalize == "true";
  ON->random_init = random_init == "true";

  if (options.was_supplied("alpha_inverse"))
    ON->alpha = 1.f / alpha_inverse;

  ON->cnt = 0;
  ON->t = 1;
  ON->ev = calloc_or_throw<float>(ON->m + 1);
  ON->b = calloc_or_throw<float>(ON->m + 1);
  ON->D = calloc_or_throw<float>(ON->m + 1);
  ON->A = calloc_or_throw<float*>(ON->m + 1);
  ON->K = calloc_or_throw<float*>(ON->m + 1);
  for (int i = 1; i <= ON->m; i++)
  {
    ON->A[i] = calloc_or_throw<float>(ON->m + 1);
    ON->K[i] = calloc_or_throw<float>(ON->m + 1);
    ON->A[i][i] = 1;
    ON->K[i][i] = 1;
    ON->D[i] = 1;
  }

  ON->buffer = calloc_or_throw<example*>(ON->epoch_size);
  ON->weight_buffer = calloc_or_throw<float>(ON->epoch_size);

  ON->zv = calloc_or_throw<float>(ON->m + 1);
  ON->vv = calloc_or_throw<float>(ON->m + 1);
  ON->tmp = calloc_or_throw<float>(ON->m + 1);

  ON->data.ON = ON.get();
  ON->data.Zx = calloc_or_throw<float>(ON->m + 1);
  ON->data.AZx = calloc_or_throw<float>(ON->m + 1);
  ON->data.delta = calloc_or_throw<float>(ON->m + 1);

  all.weights.stride_shift((uint32_t)ceil(log2(ON->m + 2)));

  learner<OjaNewton, example>& l = init_learner(ON, learn, predict, all.weights.stride());
  l.set_save_load(save_load);
  l.set_finish_example(keep_example);
  return make_base(l);
}
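
// Example invocation (a sketch; train.dat is a placeholder dataset path):
//   vw --OjaNewton --sketch_size 10 --epoch_size 1 train.dat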
Definition: OjaNewton.cc:24