Vowpal Wabbit
Classes | Macros | Functions
OjaNewton.cc File Reference
#include <string>
#include "gd.h"
#include "vw.h"
#include "rand48.h"
#include "reductions.h"
#include <math.h>
#include <memory>

Go to the source code of this file.

Classes

struct  update_data
 
struct  OjaNewton
 

Macros

#define NORM2   (m + 1)
 

Functions

void keep_example (vw &all, OjaNewton &, example &ec)
 
void make_pred (update_data &data, float x, float &wref)
 
void predict (OjaNewton &ON, base_learner &, example &ec)
 
void update_Z_and_wbar (update_data &data, float x, float &wref)
 
void compute_Zx_and_norm (update_data &data, float x, float &wref)
 
void update_wbar_and_Zx (update_data &data, float x, float &wref)
 
void update_normalization (update_data &data, float x, float &wref)
 
void learn (OjaNewton &ON, base_learner &base, example &ec)
 
void save_load (OjaNewton &ON, io_buf &model_file, bool read, bool text)
 
base_learner * OjaNewton_setup (options_i &options, vw &all)
 

Macro Definition Documentation

◆ NORM2

#define NORM2   (m + 1)

Function Documentation

◆ compute_Zx_and_norm()

void compute_Zx_and_norm ( update_data &  data,
float  x,
float &  wref 
)

Definition at line 415 of file OjaNewton.cc.

References OjaNewton::D, OjaNewton::m, NORM2, update_data::norm2_x, OjaNewton::normalize, update_data::ON, and update_data::Zx.

416 {
417  float* w = &wref;
418  int m = data.ON->m;
419  if (data.ON->normalize)
420  x /= std::sqrt(w[NORM2]);
421 
422  for (int i = 1; i <= m; i++)
423  {
424  data.Zx[i] += w[i] * x * data.ON->D[i];
425  }
426  data.norm2_x += x * x;
427 }
float * D
Definition: OjaNewton.cc:44
bool normalize
Definition: OjaNewton.cc:57
#define NORM2
Definition: OjaNewton.cc:17
struct OjaNewton * ON
Definition: OjaNewton.cc:21
float * Zx
Definition: OjaNewton.cc:25
float norm2_x
Definition: OjaNewton.cc:24

◆ keep_example()

void keep_example ( vw &  all,
OjaNewton & ,
example &  ec 
)

Definition at line 373 of file OjaNewton.cc.

References output_and_account_example().

Referenced by OjaNewton_setup().

373 { output_and_account_example(all, ec); }
void output_and_account_example(vw &all, active &a, example &ec)
Definition: active.cc:105

◆ learn()

void learn ( OjaNewton &  ON,
base_learner &  base,
example &  ec 
)

Definition at line 453 of file OjaNewton.cc.

References OjaNewton::all, OjaNewton::buffer, OjaNewton::check(), OjaNewton::cnt, OjaNewton::compute_AZx(), OjaNewton::compute_delta(), OjaNewton::data, OjaNewton::epoch_size, VW::finish_example(), loss_function::first_derivative(), update_data::g, example::in_use, example::l, label_data::label, vw::loss, OjaNewton::m, update_data::norm2_x, OjaNewton::normalize, example::pred, predict(), polyprediction::scalar, vw::sd, polylabel::simple, update_data::sketch_cnt, OjaNewton::t, OjaNewton::update_A(), OjaNewton::update_b(), OjaNewton::update_eigenvalues(), OjaNewton::update_K(), label_data::weight, OjaNewton::weight_buffer, and update_data::Zx.

Referenced by OjaNewton_setup().

454 {
455  assert(ec.in_use);
456 
457  // predict
458  predict(ON, base, ec);
459 
460  update_data& data = ON.data;
461  data.g = ON.all->loss->first_derivative(ON.all->sd, ec.pred.scalar, ec.l.simple.label) * ec.l.simple.weight;
462  data.g /= 2; // for half square loss
463 
464  if (ON.normalize)
465  GD::foreach_feature<update_data, update_normalization>(*ON.all, ec, data);
466 
467  ON.buffer[ON.cnt] = &ec;
468  ON.weight_buffer[ON.cnt++] = data.g / 2;
469 
470  if (ON.cnt == ON.epoch_size)
471  {
472  for (int k = 0; k < ON.epoch_size; k++, ON.t++)
473  {
474  example& ex = *(ON.buffer[k]);
475  data.sketch_cnt = ON.weight_buffer[k];
476 
477  data.norm2_x = 0;
478  memset(data.Zx, 0, sizeof(float) * (ON.m + 1));
479  GD::foreach_feature<update_data, compute_Zx_and_norm>(*ON.all, ex, data);
480  ON.compute_AZx();
481 
482  ON.update_eigenvalues();
483  ON.compute_delta();
484 
485  ON.update_K();
486 
487  GD::foreach_feature<update_data, update_Z_and_wbar>(*ON.all, ex, data);
488  }
489 
490  ON.update_A();
491  // ON.update_D();
492  }
493 
494  memset(data.Zx, 0, sizeof(float) * (ON.m + 1));
495  GD::foreach_feature<update_data, update_wbar_and_Zx>(*ON.all, ec, data);
496  ON.compute_AZx();
497 
498  ON.update_b();
499  ON.check();
500 
501  if (ON.cnt == ON.epoch_size)
502  {
503  ON.cnt = 0;
504  for (int k = 0; k < ON.epoch_size; k++)
505  {
506  VW::finish_example(*ON.all, *ON.buffer[k]);
507  }
508  }
509 }
void check()
Definition: OjaNewton.cc:264
loss_function * loss
Definition: global_data.h:523
bool normalize
Definition: OjaNewton.cc:57
float scalar
Definition: example.h:45
example ** buffer
Definition: OjaNewton.cc:52
struct update_data data
Definition: OjaNewton.cc:54
void update_A()
Definition: OjaNewton.cc:182
float weight
Definition: simple_label.h:15
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
int epoch_size
Definition: OjaNewton.cc:37
virtual float first_derivative(shared_data *, float prediction, float label)=0
void update_eigenvalues()
Definition: OjaNewton.cc:129
void update_b()
Definition: OjaNewton.cc:231
void compute_AZx()
Definition: OjaNewton.cc:117
void compute_delta()
Definition: OjaNewton.cc:147
shared_data * sd
Definition: global_data.h:375
float sketch_cnt
Definition: OjaNewton.cc:23
void finish_example(vw &, example &)
Definition: parser.cc:881
void predict(OjaNewton &ON, base_learner &, example &ec)
Definition: OjaNewton.cc:392
float * Zx
Definition: OjaNewton.cc:25
polylabel l
Definition: example.h:57
float * weight_buffer
Definition: OjaNewton.cc:53
bool in_use
Definition: example.h:79
polyprediction pred
Definition: example.h:60
vw * all
Definition: OjaNewton.cc:34
void update_K()
Definition: OjaNewton.cc:168
float norm2_x
Definition: OjaNewton.cc:24

◆ make_pred()

void make_pred ( update_data &  data,
float  x,
float &  wref 
)

Definition at line 375 of file OjaNewton.cc.

References OjaNewton::b, OjaNewton::D, OjaNewton::m, NORM2, OjaNewton::normalize, update_data::ON, and update_data::prediction.

376 {
377  int m = data.ON->m;
378  float* w = &wref;
379 
380  if (data.ON->normalize)
381  {
382  x /= std::sqrt(w[NORM2]);
383  }
384 
385  data.prediction += w[0] * x;
386  for (int i = 1; i <= m; i++)
387  {
388  data.prediction += w[i] * x * data.ON->D[i] * data.ON->b[i];
389  }
390 }
float * D
Definition: OjaNewton.cc:44
float prediction
Definition: OjaNewton.cc:29
bool normalize
Definition: OjaNewton.cc:57
#define NORM2
Definition: OjaNewton.cc:17
struct OjaNewton * ON
Definition: OjaNewton.cc:21
float * b
Definition: OjaNewton.cc:43

◆ OjaNewton_setup()

base_learner* OjaNewton_setup ( options_i &  options,
vw &  all 
)

Definition at line 535 of file OjaNewton.cc.

References OjaNewton::_random_state, OjaNewton::A, VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), OjaNewton::all, OjaNewton::alpha, update_data::AZx, OjaNewton::b, OjaNewton::buffer, OjaNewton::cnt, OjaNewton::D, OjaNewton::data, update_data::delta, OjaNewton::epoch_size, OjaNewton::ev, f, vw::get_random_state(), LEARNER::init_learner(), OjaNewton::K, keep_example(), learn(), OjaNewton::learning_rate_cnt, OjaNewton::m, LEARNER::make_base(), VW::config::make_option(), OjaNewton::normalize, update_data::ON, predict(), OjaNewton::random_init, save_load(), LEARNER::learner< T, E >::set_finish_example(), LEARNER::learner< T, E >::set_save_load(), parameters::stride(), parameters::stride_shift(), OjaNewton::t, OjaNewton::tmp, OjaNewton::vv, VW::config::options_i::was_supplied(), OjaNewton::weight_buffer, vw::weights, OjaNewton::zv, and update_data::Zx.

Referenced by parse_reductions().

536 {
537  auto ON = scoped_calloc_or_throw<OjaNewton>();
538 
539  bool oja_newton;
540  float alpha_inverse;
541 
542  // These two are the only two boolean options that default to true. For now going to do this hack
543  // as the infrastructure doesn't easily support this possibility at the same time providing the
544  // ease of bool switches elsewhere. It seems that the switch behavior is more critical because
545  // of the positional data argument.
546  std::string normalize = "true";
547  std::string random_init = "true";
548  option_group_definition new_options("OjaNewton options");
549  new_options.add(make_option("OjaNewton", oja_newton).keep().help("Online Newton with Oja's Sketch"))
550  .add(make_option("sketch_size", ON->m).default_value(10).help("size of sketch"))
551  .add(make_option("epoch_size", ON->epoch_size).default_value(1).help("size of epoch"))
552  .add(make_option("alpha", ON->alpha).default_value(1.f).help("mutiplicative constant for indentiy"))
553  .add(make_option("alpha_inverse", alpha_inverse).help("one over alpha, similar to learning rate"))
554  .add(make_option("learning_rate_cnt", ON->learning_rate_cnt)
555  .default_value(2.f)
556  .help("constant for the learning rate 1/t"))
557  .add(make_option("normalize", normalize).help("normalize the features or not"))
558  .add(make_option("random_init", random_init).help("randomize initialization of Oja or not"));
559  options.add_and_parse(new_options);
560 
561  if (!options.was_supplied("OjaNewton"))
562  return nullptr;
563 
564  ON->all = &all;
565  ON->_random_state = all.get_random_state();
566 
567  ON->normalize = normalize == "true";
568  ON->random_init = random_init == "true";
569 
570  if (options.was_supplied("alpha_inverse"))
571  ON->alpha = 1.f / alpha_inverse;
572 
573  ON->cnt = 0;
574  ON->t = 1;
575  ON->ev = calloc_or_throw<float>(ON->m + 1);
576  ON->b = calloc_or_throw<float>(ON->m + 1);
577  ON->D = calloc_or_throw<float>(ON->m + 1);
578  ON->A = calloc_or_throw<float*>(ON->m + 1);
579  ON->K = calloc_or_throw<float*>(ON->m + 1);
580  for (int i = 1; i <= ON->m; i++)
581  {
582  ON->A[i] = calloc_or_throw<float>(ON->m + 1);
583  ON->K[i] = calloc_or_throw<float>(ON->m + 1);
584  ON->A[i][i] = 1;
585  ON->K[i][i] = 1;
586  ON->D[i] = 1;
587  }
588 
589  ON->buffer = calloc_or_throw<example*>(ON->epoch_size);
590  ON->weight_buffer = calloc_or_throw<float>(ON->epoch_size);
591 
592  ON->zv = calloc_or_throw<float>(ON->m + 1);
593  ON->vv = calloc_or_throw<float>(ON->m + 1);
594  ON->tmp = calloc_or_throw<float>(ON->m + 1);
595 
596  ON->data.ON = ON.get();
597  ON->data.Zx = calloc_or_throw<float>(ON->m + 1);
598  ON->data.AZx = calloc_or_throw<float>(ON->m + 1);
599  ON->data.delta = calloc_or_throw<float>(ON->m + 1);
600 
601  all.weights.stride_shift((uint32_t)ceil(log2(ON->m + 2)));
602 
606  return make_base(l);
607 }
parameters weights
Definition: global_data.h:537
void save_load(OjaNewton &ON, io_buf &model_file, bool read, bool text)
Definition: OjaNewton.cc:511
uint32_t stride()
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void keep_example(vw &all, OjaNewton &, example &ec)
Definition: OjaNewton.cc:373
virtual bool was_supplied(const std::string &key)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
void predict(OjaNewton &ON, base_learner &, example &ec)
Definition: OjaNewton.cc:392
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
uint32_t stride_shift()
void learn(OjaNewton &ON, base_learner &base, example &ec)
Definition: OjaNewton.cc:453
float f
Definition: cache.cc:40

◆ predict()

void predict ( OjaNewton &  ON,
base_learner & ,
example &  ec 
)

Definition at line 392 of file OjaNewton.cc.

References OjaNewton::all, OjaNewton::data, GD::finalize_prediction(), example::partial_prediction, example::pred, update_data::prediction, polyprediction::scalar, and vw::sd.

Referenced by learn(), and OjaNewton_setup().

393 {
394  ON.data.prediction = 0;
395  GD::foreach_feature<update_data, make_pred>(*ON.all, ec, ON.data);
396  ec.partial_prediction = (float)ON.data.prediction;
398 }
float finalize_prediction(shared_data *sd, float ret)
Definition: gd.cc:339
float prediction
Definition: OjaNewton.cc:29
float scalar
Definition: example.h:45
struct update_data data
Definition: OjaNewton.cc:54
float partial_prediction
Definition: example.h:68
shared_data * sd
Definition: global_data.h:375
polyprediction pred
Definition: example.h:60
vw * all
Definition: OjaNewton.cc:34

◆ save_load()

void save_load ( OjaNewton &  ON,
io_buf &  model_file,
bool  read,
bool  text 
)

Definition at line 511 of file OjaNewton.cc.

References OjaNewton::all, bin_text_read_write_fixed(), io_buf::files, initialize_regressor(), OjaNewton::initialize_Z(), GD::save_load_online_state(), GD::save_load_regressor(), vw::save_resume, v_array< T >::size(), and vw::weights.

Referenced by OjaNewton_setup().

512 {
513  vw& all = *ON.all;
514  if (read)
515  {
517  ON.initialize_Z(all.weights);
518  }
519 
520  if (model_file.files.size() > 0)
521  {
522  bool resume = all.save_resume;
523  std::stringstream msg;
524  msg << ":" << resume << "\n";
525  bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, msg, text);
526 
527  double temp = 0.;
528  if (resume)
529  GD::save_load_online_state(all, model_file, read, text, temp);
530  else
531  GD::save_load_regressor(all, model_file, read, text);
532  }
533 }
parameters weights
Definition: global_data.h:537
void initialize_regressor(vw &all, T &weights)
size_t size() const
Definition: v_array.h:68
void save_load_online_state(vw &all, io_buf &model_file, bool read, bool text, gd *g, std::stringstream &msg, uint32_t ftrl_size, T &weights)
Definition: gd.cc:776
v_array< int > files
Definition: io_buf.h:64
bool save_resume
Definition: global_data.h:415
void initialize_Z(parameters &weights)
Definition: OjaNewton.cc:60
void save_load_regressor(vw &all, io_buf &model_file, bool read, bool text, T &weights)
Definition: gd.cc:707
vw * all
Definition: OjaNewton.cc:34
size_t bin_text_read_write_fixed(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
Definition: io_buf.h:326

◆ update_normalization()

void update_normalization ( update_data &  data,
float  x,
float &  wref 
)

Definition at line 445 of file OjaNewton.cc.

References update_data::g, OjaNewton::m, NORM2, and update_data::ON.

446 {
447  float* w = &wref;
448  int m = data.ON->m;
449 
450  w[NORM2] += x * x * data.g * data.g;
451 }
#define NORM2
Definition: OjaNewton.cc:17
struct OjaNewton * ON
Definition: OjaNewton.cc:21

◆ update_wbar_and_Zx()

void update_wbar_and_Zx ( update_data &  data,
float  x,
float &  wref 
)

Definition at line 429 of file OjaNewton.cc.

References OjaNewton::alpha, OjaNewton::D, update_data::g, OjaNewton::m, NORM2, OjaNewton::normalize, update_data::ON, and update_data::Zx.

430 {
431  float* w = &wref;
432  int m = data.ON->m;
433  if (data.ON->normalize)
434  x /= std::sqrt(w[NORM2]);
435 
436  float g = data.g * x;
437 
438  for (int i = 1; i <= m; i++)
439  {
440  data.Zx[i] += w[i] * x * data.ON->D[i];
441  }
442  w[0] -= g / data.ON->alpha;
443 }
float * D
Definition: OjaNewton.cc:44
bool normalize
Definition: OjaNewton.cc:57
#define NORM2
Definition: OjaNewton.cc:17
struct OjaNewton * ON
Definition: OjaNewton.cc:21
float * Zx
Definition: OjaNewton.cc:25
float alpha
Definition: OjaNewton.cc:38

◆ update_Z_and_wbar()

void update_Z_and_wbar ( update_data &  data,
float  x,
float &  wref 
)

Definition at line 400 of file OjaNewton.cc.

References update_data::bdelta, OjaNewton::D, update_data::delta, OjaNewton::m, NORM2, OjaNewton::normalize, update_data::ON, and update_data::sketch_cnt.

401 {
402  float* w = &wref;
403  int m = data.ON->m;
404  if (data.ON->normalize)
405  x /= std::sqrt(w[NORM2]);
406  float s = data.sketch_cnt * x;
407 
408  for (int i = 1; i <= m; i++)
409  {
410  w[i] += data.delta[i] * s / data.ON->D[i];
411  }
412  w[0] -= s * data.bdelta;
413 }
float * D
Definition: OjaNewton.cc:44
bool normalize
Definition: OjaNewton.cc:57
float * delta
Definition: OjaNewton.cc:27
float bdelta
Definition: OjaNewton.cc:28
#define NORM2
Definition: OjaNewton.cc:17
float sketch_cnt
Definition: OjaNewton.cc:23
struct OjaNewton * ON
Definition: OjaNewton.cc:21