Vowpal Wabbit
Functions
OjaNewton.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner * OjaNewton_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ OjaNewton_setup()

LEARNER::base_learner* OjaNewton_setup ( VW::config::options_i & options,
vw & all
)

Definition at line 535 of file OjaNewton.cc.

References OjaNewton::_random_state, OjaNewton::A, VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), OjaNewton::all, OjaNewton::alpha, update_data::AZx, OjaNewton::b, OjaNewton::buffer, OjaNewton::cnt, OjaNewton::D, OjaNewton::data, update_data::delta, OjaNewton::epoch_size, OjaNewton::ev, f, vw::get_random_state(), LEARNER::init_learner(), OjaNewton::K, keep_example(), learn(), OjaNewton::learning_rate_cnt, OjaNewton::m, LEARNER::make_base(), VW::config::make_option(), OjaNewton::normalize, update_data::ON, predict(), OjaNewton::random_init, save_load(), LEARNER::learner< T, E >::set_finish_example(), LEARNER::learner< T, E >::set_save_load(), parameters::stride(), parameters::stride_shift(), OjaNewton::t, OjaNewton::tmp, OjaNewton::vv, VW::config::options_i::was_supplied(), OjaNewton::weight_buffer, vw::weights, OjaNewton::zv, and update_data::Zx.

Referenced by parse_reductions().

536 {
537  auto ON = scoped_calloc_or_throw<OjaNewton>();
538 
539  bool oja_newton;
540  float alpha_inverse;
541 
542  // These two are the only two boolean options that default to true. For now going to do this hack
543  // as the infrastructure doesn't easily support this possibility at the same time providing the
544  // ease of bool switches elsewhere. It seems that the switch behavior is more critical because
545  // of the positional data argument.
546  std::string normalize = "true";
547  std::string random_init = "true";
548  option_group_definition new_options("OjaNewton options");
549  new_options.add(make_option("OjaNewton", oja_newton).keep().help("Online Newton with Oja's Sketch"))
550  .add(make_option("sketch_size", ON->m).default_value(10).help("size of sketch"))
551  .add(make_option("epoch_size", ON->epoch_size).default_value(1).help("size of epoch"))
552  .add(make_option("alpha", ON->alpha).default_value(1.f).help("mutiplicative constant for indentiy"))
553  .add(make_option("alpha_inverse", alpha_inverse).help("one over alpha, similar to learning rate"))
554  .add(make_option("learning_rate_cnt", ON->learning_rate_cnt)
555  .default_value(2.f)
556  .help("constant for the learning rate 1/t"))
557  .add(make_option("normalize", normalize).help("normalize the features or not"))
558  .add(make_option("random_init", random_init).help("randomize initialization of Oja or not"));
559  options.add_and_parse(new_options);
560 
561  if (!options.was_supplied("OjaNewton"))
562  return nullptr;
563 
564  ON->all = &all;
565  ON->_random_state = all.get_random_state();
566 
567  ON->normalize = normalize == "true";
568  ON->random_init = random_init == "true";
569 
570  if (options.was_supplied("alpha_inverse"))
571  ON->alpha = 1.f / alpha_inverse;
572 
573  ON->cnt = 0;
574  ON->t = 1;
575  ON->ev = calloc_or_throw<float>(ON->m + 1);
576  ON->b = calloc_or_throw<float>(ON->m + 1);
577  ON->D = calloc_or_throw<float>(ON->m + 1);
578  ON->A = calloc_or_throw<float*>(ON->m + 1);
579  ON->K = calloc_or_throw<float*>(ON->m + 1);
580  for (int i = 1; i <= ON->m; i++)
581  {
582  ON->A[i] = calloc_or_throw<float>(ON->m + 1);
583  ON->K[i] = calloc_or_throw<float>(ON->m + 1);
584  ON->A[i][i] = 1;
585  ON->K[i][i] = 1;
586  ON->D[i] = 1;
587  }
588 
589  ON->buffer = calloc_or_throw<example*>(ON->epoch_size);
590  ON->weight_buffer = calloc_or_throw<float>(ON->epoch_size);
591 
592  ON->zv = calloc_or_throw<float>(ON->m + 1);
593  ON->vv = calloc_or_throw<float>(ON->m + 1);
594  ON->tmp = calloc_or_throw<float>(ON->m + 1);
595 
596  ON->data.ON = ON.get();
597  ON->data.Zx = calloc_or_throw<float>(ON->m + 1);
598  ON->data.AZx = calloc_or_throw<float>(ON->m + 1);
599  ON->data.delta = calloc_or_throw<float>(ON->m + 1);
600 
601  all.weights.stride_shift((uint32_t)ceil(log2(ON->m + 2)));
602 
606  return make_base(l);
607 }
parameters weights
Definition: global_data.h:537
void save_load(OjaNewton &ON, io_buf &model_file, bool read, bool text)
Definition: OjaNewton.cc:511
uint32_t stride()
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void keep_example(vw &all, OjaNewton &, example &ec)
Definition: OjaNewton.cc:373
virtual bool was_supplied(const std::string &key)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
void predict(OjaNewton &ON, base_learner &, example &ec)
Definition: OjaNewton.cc:392
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
uint32_t stride_shift()
void learn(OjaNewton &ON, base_learner &base, example &ec)
Definition: OjaNewton.cc:453
float f
Definition: cache.cc:40