62 uint32_t length = 1 << all->
num_bits;
70 for (
int i = 1; i <= m; i++) (&(weights.
strided_index(i)))[i] = 1.f;
76 const double PI2 = 2.f * 3.1415927f;
78 for (uint32_t i = 0; i < length; i++)
82 for (
int j = 1; j <= m; j++)
88 r1 = _random_state->get_and_update_random();
89 r2 = _random_state->get_and_update_random();
92 (&w)[j] = std::sqrt(-2.
f * log(r1)) * (
float)cos(PI2 * r2);
98 for (
int j = 1; j <= m; j++)
100 for (
int k = 1; k <= j - 1; k++)
104 for (uint32_t i = 0; i < length; i++)
106 for (uint32_t i = 0; i < length; i++)
110 for (uint32_t i = 0; i < length; i++)
112 norm = std::sqrt(norm);
113 for (uint32_t i = 0; i < length; i++) (&(weights.
strided_index(i)))[j] /= (float)norm;
119 for (
int i = 1; i <= m; i++)
122 for (
int j = 1; j <= i; j++)
124 data.AZx[i] += A[i][j] * data.Zx[j];
131 for (
int i = 1; i <= m; i++)
133 float gamma = fmin(learning_rate_cnt / t, 1.
f);
134 float tmp = data.AZx[i] * data.sketch_cnt;
138 ev[i] = gamma * tmp * tmp;
142 ev[i] = (1 - gamma) * t * ev[i] / (t - 1) + gamma * t * tmp * tmp;
150 for (
int i = 1; i <= m; i++)
152 float gamma = fmin(learning_rate_cnt / t, 1.
f);
162 data.delta[i] = gamma * data.Zx[i] * data.sketch_cnt;
164 data.bdelta += data.delta[i] * b[i];
170 float tmp = data.norm2_x * data.sketch_cnt * data.sketch_cnt;
171 for (
int i = 1; i <= m; i++)
173 for (
int j = 1; j <= m; j++)
175 K[i][j] += data.delta[i] * data.Zx[j] * data.sketch_cnt;
176 K[i][j] += data.delta[j] * data.Zx[i] * data.sketch_cnt;
177 K[i][j] += data.delta[i] * data.delta[j] * tmp;
184 for (
int i = 1; i <= m; i++)
186 for (
int j = 1; j < i; j++)
189 for (
int k = 1; k <= i; k++)
191 zv[j] += A[i][k] * K[k][j];
195 for (
int j = 1; j < i; j++)
198 for (
int k = 1; k <= j; k++)
200 vv[j] += A[j][k] * zv[k];
204 for (
int j = 1; j < i; j++)
206 for (
int k = j; k < i; k++)
208 A[i][j] -= vv[k] * A[k][j];
213 for (
int j = 1; j <= i; j++)
216 for (
int k = 1; k <= i; k++)
218 temp += K[j][k] * A[i][k];
220 norm += A[i][j] * temp;
224 for (
int j = 1; j <= i; j++)
233 for (
int j = 1; j <= m; j++)
236 for (
int i = j; i <= m; i++)
238 tmp += ev[i] * data.AZx[i] * A[i][j] / (alpha * (alpha + ev[i]));
240 b[j] += tmp * data.g;
246 for (
int j = 1; j <= m; j++)
248 float scale = fabs(A[j][j]);
249 for (
int i = j + 1; i <= m; i++) scale = fmin(fabs(A[i][j]), scale);
252 for (
int i = 1; i <= m; i++)
267 for (
int i = 1; i <= m; i++)
268 for (
int j = i; j <= m; j++) max_norm = fmax(max_norm, fabs(K[i][j]));
279 for (
int j = 1; j <= m; j++)
281 memset(tmp, 0,
sizeof(
double) * (m + 1));
283 for (
int i = 1; i <= m; i++)
285 for (
int h = 1; h <= m; h++)
287 tmp[i] += A[i][h] * K[h][j];
291 for (
int i = 1; i <= m; i++) K[i][j] = tmp[i];
294 for (
int i = 1; i <= m; i++)
296 memset(tmp, 0,
sizeof(
double) * (m + 1));
298 for (
int j = 1; j <= m; j++)
299 for (
int h = 1; h <= m; h++) tmp[j] += K[i][h] * A[j][h];
301 for (
int j = 1; j <= m; j++)
309 uint32_t length = 1 << all->
num_bits;
310 for (uint32_t i = 0; i < length; i++)
313 for (
int j = 1; j <= m; j++) w += (&w)[j] * b[j] * D[j];
316 memset(b, 0,
sizeof(
double) * (m + 1));
321 for (uint32_t i = 0; i < length; ++i)
323 memset(tmp, 0,
sizeof(
float) * (m + 1));
325 for (
int j = 1; j <= m; j++)
327 for (
int h = 1; h <= m; ++h) tmp[j] += A[j][h] * D[h] * (&w)[h];
329 for (
int j = 1; j <= m; ++j)
337 for (
int i = 1; i <= m; i++)
339 memset(A[i], 0,
sizeof(
double) * (m + 1));
357 for (
int i = 1; i <= m; i++)
382 x /= std::sqrt(w[
NORM2]);
386 for (
int i = 1; i <= m; i++)
395 GD::foreach_feature<update_data, make_pred>(*ON.
all, ec, ON.
data);
405 x /= std::sqrt(w[
NORM2]);
408 for (
int i = 1; i <= m; i++)
410 w[i] += data.
delta[i] * s / data.
ON->
D[i];
420 x /= std::sqrt(w[
NORM2]);
422 for (
int i = 1; i <= m; i++)
424 data.
Zx[i] += w[i] * x * data.
ON->
D[i];
434 x /= std::sqrt(w[
NORM2]);
436 float g = data.
g * x;
438 for (
int i = 1; i <= m; i++)
440 data.
Zx[i] += w[i] * x * data.
ON->
D[i];
450 w[
NORM2] += x * x * data.
g * data.
g;
465 GD::foreach_feature<update_data, update_normalization>(*ON.
all, ec, data);
478 memset(data.
Zx, 0,
sizeof(
float) * (ON.
m + 1));
479 GD::foreach_feature<update_data, compute_Zx_and_norm>(*ON.
all, ex, data);
487 GD::foreach_feature<update_data, update_Z_and_wbar>(*ON.
all, ex, data);
494 memset(data.
Zx, 0,
sizeof(
float) * (ON.
m + 1));
495 GD::foreach_feature<update_data, update_wbar_and_Zx>(*ON.
all, ec, data);
523 std::stringstream msg;
524 msg <<
":" << resume <<
"\n";
537 auto ON = scoped_calloc_or_throw<OjaNewton>();
546 std::string normalize =
"true";
547 std::string random_init =
"true";
549 new_options.
add(
make_option(
"OjaNewton", oja_newton).keep().help(
"Online Newton with Oja's Sketch"))
553 .
add(
make_option(
"alpha_inverse", alpha_inverse).help(
"one over alpha, similar to learning rate"))
556 .help(
"constant for the learning rate 1/t"))
557 .
add(
make_option(
"normalize", normalize).help(
"normalize the features or not"))
558 .
add(
make_option(
"random_init", random_init).help(
"randomize initialization of Oja or not"));
571 ON->
alpha = 1.f / alpha_inverse;
575 ON->
ev = calloc_or_throw<float>(
ON->
m + 1);
576 ON->
b = calloc_or_throw<float>(
ON->
m + 1);
577 ON->
D = calloc_or_throw<float>(
ON->
m + 1);
578 ON->
A = calloc_or_throw<float*>(
ON->
m + 1);
579 ON->
K = calloc_or_throw<float*>(
ON->
m + 1);
580 for (
int i = 1; i <=
ON->
m; i++)
582 ON->
A[i] = calloc_or_throw<float>(
ON->
m + 1);
583 ON->
K[i] = calloc_or_throw<float>(
ON->
m + 1);
592 ON->
zv = calloc_or_throw<float>(
ON->
m + 1);
593 ON->
vv = calloc_or_throw<float>(
ON->
m + 1);
594 ON->
tmp = calloc_or_throw<float>(
ON->
m + 1);
void compute_Zx_and_norm(update_data &data, float x, float &wref)
float finalize_prediction(shared_data *sd, float ret)
void initialize_regressor(vw &all, T &weights)
base_learner * OjaNewton_setup(options_i &options, vw &all)
void save_load(OjaNewton &ON, io_buf &model_file, bool read, bool text)
void output_and_account_example(vw &all, active &a, example &ec)
void make_pred(update_data &data, float x, float &wref)
base_learner * make_base(learner< T, E > &base)
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
void update_wbar_and_Zx(update_data &data, float x, float &wref)
virtual float first_derivative(shared_data *, float prediction, float label)=0
void save_load_online_state(vw &all, io_buf &model_file, bool read, bool text, gd *g, std::stringstream &msg, uint32_t ftrl_size, T &weights)
std::shared_ptr< rand_state > get_random_state()
void update_Z_and_wbar(update_data &data, float x, float &wref)
void update_eigenvalues()
void set_finish_example(void(*f)(vw &all, T &, E &))
void update_normalization(update_data &data, float x, float &wref)
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
void keep_example(vw &all, OjaNewton &, example &ec)
virtual bool was_supplied(const std::string &key)=0
std::shared_ptr< rand_state > _random_state
weight & strided_index(size_t index)
void finish_example(vw &, example &)
option_group_definition & add(T &&op)
int add(svm_params ¶ms, svm_example *fec)
void predict(OjaNewton &ON, base_learner &, example &ec)
typed_option< T > make_option(std::string name, T &location)
void initialize_Z(parameters &weights)
void save_load_regressor(vw &all, io_buf &model_file, bool read, bool text, T &weights)
size_t bin_text_read_write_fixed(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
void learn(OjaNewton &ON, base_learner &base, example &ec)