Vowpal Wabbit — OjaNewton.cc File Reference
#include <string>
#include "gd.h"
#include "vw.h"
#include "rand48.h"
#include "reductions.h"
#include <math.h>
#include <memory>
Go to the source code of this file.
Classes
  struct update_data
  struct OjaNewton
Macros
  #define NORM2 (m + 1)
Functions
  void keep_example (vw &all, OjaNewton &, example &ec)
  void make_pred (update_data &data, float x, float &wref)
  void predict (OjaNewton &ON, base_learner &, example &ec)
  void update_Z_and_wbar (update_data &data, float x, float &wref)
  void compute_Zx_and_norm (update_data &data, float x, float &wref)
  void update_wbar_and_Zx (update_data &data, float x, float &wref)
  void update_normalization (update_data &data, float x, float &wref)
  void learn (OjaNewton &ON, base_learner &base, example &ec)
  void save_load (OjaNewton &ON, io_buf &model_file, bool read, bool text)
  base_learner* OjaNewton_setup (options_i &options, vw &all)
#define NORM2 (m + 1)
Definition at line 17 of file OjaNewton.cc.
Referenced by compute_Zx_and_norm(), OjaNewton::initialize_Z(), make_pred(), update_normalization(), update_wbar_and_Zx(), and update_Z_and_wbar().
void compute_Zx_and_norm (update_data &data, float x, float &wref)
Definition at line 415 of file OjaNewton.cc.
References OjaNewton::D, OjaNewton::m, NORM2, update_data::norm2_x, OjaNewton::normalize, update_data::ON, and update_data::Zx.
void keep_example (vw &all, OjaNewton &, example &ec)
Definition at line 373 of file OjaNewton.cc.
References output_and_account_example().
Referenced by OjaNewton_setup().
void learn (OjaNewton &ON, base_learner &base, example &ec)
Definition at line 453 of file OjaNewton.cc.
References OjaNewton::all, OjaNewton::buffer, OjaNewton::check(), OjaNewton::cnt, OjaNewton::compute_AZx(), OjaNewton::compute_delta(), OjaNewton::data, OjaNewton::epoch_size, VW::finish_example(), loss_function::first_derivative(), update_data::g, example::in_use, example::l, label_data::label, vw::loss, OjaNewton::m, update_data::norm2_x, OjaNewton::normalize, example::pred, predict(), polyprediction::scalar, vw::sd, polylabel::simple, update_data::sketch_cnt, OjaNewton::t, OjaNewton::update_A(), OjaNewton::update_b(), OjaNewton::update_eigenvalues(), OjaNewton::update_K(), label_data::weight, OjaNewton::weight_buffer, and update_data::Zx.
Referenced by OjaNewton_setup().
void make_pred (update_data &data, float x, float &wref)
Definition at line 375 of file OjaNewton.cc.
References OjaNewton::b, OjaNewton::D, OjaNewton::m, NORM2, OjaNewton::normalize, update_data::ON, and update_data::prediction.
base_learner* OjaNewton_setup (options_i &options, vw &all)
Definition at line 535 of file OjaNewton.cc.
References OjaNewton::_random_state, OjaNewton::A, VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), OjaNewton::all, OjaNewton::alpha, update_data::AZx, OjaNewton::b, OjaNewton::buffer, OjaNewton::cnt, OjaNewton::D, OjaNewton::data, update_data::delta, OjaNewton::epoch_size, OjaNewton::ev, f, vw::get_random_state(), LEARNER::init_learner(), OjaNewton::K, keep_example(), learn(), OjaNewton::learning_rate_cnt, OjaNewton::m, LEARNER::make_base(), VW::config::make_option(), OjaNewton::normalize, update_data::ON, predict(), OjaNewton::random_init, save_load(), LEARNER::learner< T, E >::set_finish_example(), LEARNER::learner< T, E >::set_save_load(), parameters::stride(), parameters::stride_shift(), OjaNewton::t, OjaNewton::tmp, OjaNewton::vv, VW::config::options_i::was_supplied(), OjaNewton::weight_buffer, vw::weights, OjaNewton::zv, and update_data::Zx.
Referenced by parse_reductions().
void predict (OjaNewton &ON, base_learner &, example &ec)
Definition at line 392 of file OjaNewton.cc.
References OjaNewton::all, OjaNewton::data, GD::finalize_prediction(), example::partial_prediction, example::pred, update_data::prediction, polyprediction::scalar, and vw::sd.
Referenced by learn(), and OjaNewton_setup().
void save_load (OjaNewton &ON, io_buf &model_file, bool read, bool text)
Definition at line 511 of file OjaNewton.cc.
References OjaNewton::all, bin_text_read_write_fixed(), io_buf::files, initialize_regressor(), OjaNewton::initialize_Z(), GD::save_load_online_state(), GD::save_load_regressor(), vw::save_resume, v_array< T >::size(), and vw::weights.
Referenced by OjaNewton_setup().
void update_normalization (update_data &data, float x, float &wref)
Definition at line 445 of file OjaNewton.cc.
References update_data::g, OjaNewton::m, NORM2, and update_data::ON.
void update_wbar_and_Zx (update_data &data, float x, float &wref)
Definition at line 429 of file OjaNewton.cc.
References OjaNewton::alpha, OjaNewton::D, update_data::g, OjaNewton::m, NORM2, OjaNewton::normalize, update_data::ON, and update_data::Zx.
void update_Z_and_wbar (update_data &data, float x, float &wref)
Definition at line 400 of file OjaNewton.cc.
References update_data::bdelta, OjaNewton::D, update_data::delta, OjaNewton::m, NORM2, OjaNewton::normalize, update_data::ON, and update_data::sketch_cnt.