Vowpal Wabbit
Public Member Functions | Public Attributes | List of all members
OjaNewton Struct Reference

Public Member Functions

void initialize_Z (parameters &weights)
 
void compute_AZx ()
 
void update_eigenvalues ()
 
void compute_delta ()
 
void update_K ()
 
void update_A ()
 
void update_b ()
 
void update_D ()
 
void check ()
 
 ~OjaNewton ()
 

Public Attributes

vw * all
 
std::shared_ptr< rand_state > _random_state
 
int m
 
int epoch_size
 
float alpha
 
int cnt
 
int t
 
float * ev
 
float * b
 
float * D
 
float ** A
 
float ** K
 
float * zv
 
float * vv
 
float * tmp
 
example ** buffer
 
float * weight_buffer
 
struct update_data data
 
float learning_rate_cnt
 
bool normalize
 
bool random_init
 

Detailed Description

Definition at line 32 of file OjaNewton.cc.

Constructor & Destructor Documentation

◆ ~OjaNewton()

OjaNewton::~OjaNewton ( )
inline

Definition at line 345 of file OjaNewton.cc.

346  { // Frees every heap buffer held by this struct; free(nullptr) is a no-op, so unset members are safe.
347  free(ev);
348  free(b);
349  free(D);
350  free(buffer); // frees only the pointer array, not the example objects it points to
351  free(weight_buffer);
352  free(zv);
353  free(vv);
354  free(tmp);
355  if (A) // rows A[1..m]; guarded on its own so a null K is never dereferenced here
356  {
357  for (int i = 1; i <= m; i++) free(A[i]);
358  }
359  if (K) // rows K[1..m]; previously freed under the A guard only
360  {
361  for (int i = 1; i <= m; i++) free(K[i]);
362  }
363 
364  free(A);
365  free(K);
366 
367  free(data.Zx);
368  free(data.AZx);
369  free(data.delta);
370  }
float * D
Definition: OjaNewton.cc:44
float ** K
Definition: OjaNewton.cc:46
float * zv
Definition: OjaNewton.cc:48
float * vv
Definition: OjaNewton.cc:49
example ** buffer
Definition: OjaNewton.cc:52
float * tmp
Definition: OjaNewton.cc:50
float * AZx
Definition: OjaNewton.cc:26
struct update_data data
Definition: OjaNewton.cc:54
float ** A
Definition: OjaNewton.cc:45
float * delta
Definition: OjaNewton.cc:27
float * ev
Definition: OjaNewton.cc:42
float * Zx
Definition: OjaNewton.cc:25
float * weight_buffer
Definition: OjaNewton.cc:53
float * b
Definition: OjaNewton.cc:43

Member Function Documentation

◆ check()

void OjaNewton::check ( )
inline

Definition at line 264 of file OjaNewton.cc.

References vw::num_bits, parameters::strided_index(), and vw::weights.

Referenced by learn().

265  {
266  double max_norm = 0;
267  for (int i = 1; i <= m; i++)
268  for (int j = i; j <= m; j++) max_norm = fmax(max_norm, fabs(K[i][j]));
269  // printf("|K| = %f\n", max_norm);
270  if (max_norm < 1e7)
271  return;
272 
273  // implicit -> explicit representation
274  // printf("begin conversion: t = %d, norm(K) = %f\n", t, max_norm);
275 
276  // first step: K <- AKA'
277 
278  // K <- AK
279  for (int j = 1; j <= m; j++)
280  {
281  memset(tmp, 0, sizeof(double) * (m + 1));
282 
283  for (int i = 1; i <= m; i++)
284  {
285  for (int h = 1; h <= m; h++)
286  {
287  tmp[i] += A[i][h] * K[h][j];
288  }
289  }
290 
291  for (int i = 1; i <= m; i++) K[i][j] = tmp[i];
292  }
293  // K <- KA'
294  for (int i = 1; i <= m; i++)
295  {
296  memset(tmp, 0, sizeof(double) * (m + 1));
297 
298  for (int j = 1; j <= m; j++)
299  for (int h = 1; h <= m; h++) tmp[j] += K[i][h] * A[j][h];
300 
301  for (int j = 1; j <= m; j++)
302  {
303  K[i][j] = tmp[j];
304  }
305  }
306 
307  // second step: w[0] <- w[0] + (DZ)'b, b <- 0.
308 
309  uint32_t length = 1 << all->num_bits;
310  for (uint32_t i = 0; i < length; i++)
311  {
312  weight& w = all->weights.strided_index(i);
313  for (int j = 1; j <= m; j++) w += (&w)[j] * b[j] * D[j];
314  }
315 
316  memset(b, 0, sizeof(double) * (m + 1));
317 
318  // third step: Z <- ADZ, A, D <- Identity
319 
320  // double norm = 0;
321  for (uint32_t i = 0; i < length; ++i)
322  {
323  memset(tmp, 0, sizeof(float) * (m + 1));
324  weight& w = all->weights.strided_index(i);
325  for (int j = 1; j <= m; j++)
326  {
327  for (int h = 1; h <= m; ++h) tmp[j] += A[j][h] * D[h] * (&w)[h];
328  }
329  for (int j = 1; j <= m; ++j)
330  {
331  // norm = std::max(norm, fabs(tmp[j]));
332  (&w)[j] = tmp[j];
333  }
334  }
335  // printf("|Z| = %f\n", norm);
336 
337  for (int i = 1; i <= m; i++)
338  {
339  memset(A[i], 0, sizeof(double) * (m + 1));
340  D[i] = 1;
341  A[i][i] = 1;
342  }
343  }
float * D
Definition: OjaNewton.cc:44
float ** K
Definition: OjaNewton.cc:46
parameters weights
Definition: global_data.h:537
float * tmp
Definition: OjaNewton.cc:50
float ** A
Definition: OjaNewton.cc:45
uint32_t num_bits
Definition: global_data.h:398
weight & strided_index(size_t index)
float weight
vw * all
Definition: OjaNewton.cc:34
float * b
Definition: OjaNewton.cc:43

◆ compute_AZx()

void OjaNewton::compute_AZx ( )
inline

Definition at line 117 of file OjaNewton.cc.

Referenced by learn().

118  { // Computes data.AZx[i] = sum_j A[i][j] * data.Zx[j] for i = 1..m (index 0 is not written).
119  for (int i = 1; i <= m; i++)
120  {
121  data.AZx[i] = 0;
122  for (int j = 1; j <= i; j++) // only j <= i is read: A is treated as lower triangular
123  {
124  data.AZx[i] += A[i][j] * data.Zx[j];
125  }
126  }
127  }
float * AZx
Definition: OjaNewton.cc:26
struct update_data data
Definition: OjaNewton.cc:54
float ** A
Definition: OjaNewton.cc:45
float * Zx
Definition: OjaNewton.cc:25

◆ compute_delta()

void OjaNewton::compute_delta ( )
inline

Definition at line 147 of file OjaNewton.cc.

References f.

Referenced by learn().

148  { // Sets data.delta[1..m] and accumulates data.bdelta = sum_i data.delta[i] * b[i].
149  data.bdelta = 0;
150  for (int i = 1; i <= m; i++)
151  {
152  float gamma = fmin(learning_rate_cnt / t, 1.f); // step size min(learning_rate_cnt / t, 1); does not depend on i
153 
154  // if different learning rates are used (disabled alternative, kept for reference)
155  /*data.delta[i] = gamma * data.AZx[i] * data.sketch_cnt;
156  for (int j = 1; j < i; j++) {
157  data.delta[i] -= A[i][j] * data.delta[j];
158  }
159  data.delta[i] /= A[i][i];*/
160 
161  // if the same learning rate is used (active path): delta depends on Zx, not on A
162  data.delta[i] = gamma * data.Zx[i] * data.sketch_cnt;
163 
164  data.bdelta += data.delta[i] * b[i];
165  }
166  }
struct update_data data
Definition: OjaNewton.cc:54
float * delta
Definition: OjaNewton.cc:27
float learning_rate_cnt
Definition: OjaNewton.cc:56
float bdelta
Definition: OjaNewton.cc:28
float sketch_cnt
Definition: OjaNewton.cc:23
float * Zx
Definition: OjaNewton.cc:25
float * b
Definition: OjaNewton.cc:43
float f
Definition: cache.cc:40

◆ initialize_Z()

void OjaNewton::initialize_Z ( parameters & weights )
inline

Definition at line 60 of file OjaNewton.cc.

References f, NORM2, vw::num_bits, and parameters::strided_index().

Referenced by save_load().

61  { // Fills the m sketch columns stored at offsets 1..m of every strided weight, then orthonormalizes them.
62  uint32_t length = 1 << all->num_bits; // number of strided weight rows
63  if (normalize) // initialize normalization part: NORM2 slot of every row = 0.1f
64  {
65  for (uint32_t i = 0; i < length; i++) (&(weights.strided_index(i)))[NORM2] = 0.1f;
66  }
67  if (!random_init)
68  {
69  // simple initialization: row i gets 1 in column i for i = 1..m; other entries are left untouched
70  for (int i = 1; i <= m; i++) (&(weights.strided_index(i)))[i] = 1.f;
71  }
72  else
73  {
74  // more complicated initialization: orthogonal basis of a random matrix
75 
76  const double PI2 = 2.f * 3.1415927f; // 2*pi, evaluated in float precision
77 
78  for (uint32_t i = 0; i < length; i++)
79  {
80  weight& w = weights.strided_index(i);
81  float r1, r2;
82  for (int j = 1; j <= m; j++)
83  {
84  // Box-Muller transform: https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
85  // redraw until r1 is strictly positive, because log(r1) is taken below
86  do
87  {
88  r1 = _random_state->get_and_update_random();
89  r2 = _random_state->get_and_update_random();
90  } while (r1 == 0.f);
91 
92  (&w)[j] = std::sqrt(-2.f * log(r1)) * (float)cos(PI2 * r2); // normal sample, assuming r1, r2 are uniform in [0, 1) — confirm
93  }
94  }
95  }
96 
97  // Gram-Schmidt: orthonormalize columns 1..m, each viewed as a vector over all `length` rows
98  for (int j = 1; j <= m; j++)
99  {
100  for (int k = 1; k <= j - 1; k++)
101  {
102  double tmp = 0; // dot(column j, column k); local, shadows the float* tmp member
103 
104  for (uint32_t i = 0; i < length; i++)
105  tmp += ((double)(&(weights.strided_index(i)))[j]) * (&(weights.strided_index(i)))[k];
106  for (uint32_t i = 0; i < length; i++)
107  (&(weights.strided_index(i)))[j] -= (float)tmp * (&(weights.strided_index(i)))[k];
108  }
109  double norm = 0;
110  for (uint32_t i = 0; i < length; i++)
111  norm += ((double)(&(weights.strided_index(i)))[j]) * (&(weights.strided_index(i)))[j];
112  norm = std::sqrt(norm); // NOTE(review): no guard for norm == 0; a zero column would be divided by zero below
113  for (uint32_t i = 0; i < length; i++) (&(weights.strided_index(i)))[j] /= (float)norm;
114  }
115  }
bool normalize
Definition: OjaNewton.cc:57
float * tmp
Definition: OjaNewton.cc:50
uint32_t num_bits
Definition: global_data.h:398
#define NORM2
Definition: OjaNewton.cc:17
std::shared_ptr< rand_state > _random_state
Definition: OjaNewton.cc:35
weight & strided_index(size_t index)
float weight
bool random_init
Definition: OjaNewton.cc:58
vw * all
Definition: OjaNewton.cc:34
float f
Definition: cache.cc:40

◆ update_A()

void OjaNewton::update_A ( )
inline

Definition at line 182 of file OjaNewton.cc.

Referenced by learn().

183  { // Row-by-row Gram-Schmidt of A in the K-inner product <u, v> = u K v' (rows read only up to the diagonal), each row scaled to unit K-norm.
184  for (int i = 1; i <= m; i++)
185  {
186  for (int j = 1; j < i; j++) // zv[j] = sum_{k<=i} A[i][k] * K[k][j], for j < i
187  {
188  zv[j] = 0;
189  for (int k = 1; k <= i; k++)
190  {
191  zv[j] += A[i][k] * K[k][j];
192  }
193  }
194 
195  for (int j = 1; j < i; j++) // vv[j] = sum_{k<=j} A[j][k] * zv[k]: K-inner product of rows i and j when A is lower triangular
196  {
197  vv[j] = 0;
198  for (int k = 1; k <= j; k++)
199  {
200  vv[j] += A[j][k] * zv[k];
201  }
202  }
203 
204  for (int j = 1; j < i; j++) // A[i][j] -= sum_{j<=k<i} vv[k] * A[k][j]: remove the components along rows 1..i-1
205  {
206  for (int k = j; k < i; k++)
207  {
208  A[i][j] -= vv[k] * A[k][j];
209  }
210  }
211 
212  float norm = 0; // accumulates A[i] K A[i]' over columns 1..i
213  for (int j = 1; j <= i; j++)
214  {
215  float temp = 0;
216  for (int k = 1; k <= i; k++)
217  {
218  temp += K[j][k] * A[i][k];
219  }
220  norm += A[i][j] * temp;
221  }
222  norm = sqrtf(norm); // NOTE(review): no guard if the accumulated value is <= 0 (NaN / division by zero below)
223 
224  for (int j = 1; j <= i; j++)
225  {
226  A[i][j] /= norm;
227  }
228  }
229  }
float ** K
Definition: OjaNewton.cc:46
float * zv
Definition: OjaNewton.cc:48
float * vv
Definition: OjaNewton.cc:49
float ** A
Definition: OjaNewton.cc:45

◆ update_b()

void OjaNewton::update_b ( )
inline

Definition at line 231 of file OjaNewton.cc.

Referenced by learn().

232  { // b[j] += data.g * sum_{i>=j} ev[i] * AZx[i] * A[i][j] / (alpha * (alpha + ev[i])), for j = 1..m.
233  for (int j = 1; j <= m; j++)
234  {
235  float tmp = 0; // local accumulator; shadows the float* tmp member
236  for (int i = j; i <= m; i++) // A[i][j] is read only for i >= j (lower triangle)
237  {
238  tmp += ev[i] * data.AZx[i] * A[i][j] / (alpha * (alpha + ev[i]));
239  }
240  b[j] += tmp * data.g;
241  }
242  }
float * tmp
Definition: OjaNewton.cc:50
float * AZx
Definition: OjaNewton.cc:26
struct update_data data
Definition: OjaNewton.cc:54
float ** A
Definition: OjaNewton.cc:45
float * ev
Definition: OjaNewton.cc:42
float alpha
Definition: OjaNewton.cc:38
float * b
Definition: OjaNewton.cc:43

◆ update_D()

void OjaNewton::update_D ( )
inline

Definition at line 244 of file OjaNewton.cc.

245  { // Rescales each column j of A by s = min_{i>=j} |A[i][j]|, keeping A[i][j] * D[j] and b[j] * D[j] unchanged.
246  for (int j = 1; j <= m; j++)
247  {
248  float scale = fabs(A[j][j]);
249  for (int i = j + 1; i <= m; i++) scale = fmin(fabs(A[i][j]), scale);
250  if (scale < 1e-10) // skip near-zero scales rather than divide by ~0
251  continue;
252  for (int i = 1; i <= m; i++)
253  {
254  A[i][j] /= scale;
255  K[j][i] *= scale;
256  K[i][j] *= scale; // row and column j of K are scaled; K[j][j] is hit twice (scale^2)
257  }
258  b[j] /= scale;
259  D[j] *= scale; // compensates the A[:, j] and b[j] divisions above
260  // printf("D[%d] = %f\n", j, D[j]);
261  }
262  }
float * D
Definition: OjaNewton.cc:44
float ** K
Definition: OjaNewton.cc:46
float ** A
Definition: OjaNewton.cc:45
float * b
Definition: OjaNewton.cc:43

◆ update_eigenvalues()

void OjaNewton::update_eigenvalues ( )
inline

Definition at line 129 of file OjaNewton.cc.

References f.

Referenced by learn().

130  { // Updates ev[1..m] (eigenvalue estimates, per the function name) from the projected example data.AZx.
131  for (int i = 1; i <= m; i++)
132  {
133  float gamma = fmin(learning_rate_cnt / t, 1.f); // step size min(learning_rate_cnt / t, 1)
134  float tmp = data.AZx[i] * data.sketch_cnt; // local; shadows the float* tmp member
135 
136  if (t == 1) // first update: no previous estimate to blend, and the else branch would divide by t - 1 == 0
137  {
138  ev[i] = gamma * tmp * tmp;
139  }
140  else
141  {
142  ev[i] = (1 - gamma) * t * ev[i] / (t - 1) + gamma * t * tmp * tmp;
143  }
144  }
145  }
float * tmp
Definition: OjaNewton.cc:50
float * AZx
Definition: OjaNewton.cc:26
struct update_data data
Definition: OjaNewton.cc:54
float learning_rate_cnt
Definition: OjaNewton.cc:56
float sketch_cnt
Definition: OjaNewton.cc:23
float * ev
Definition: OjaNewton.cc:42
float f
Definition: cache.cc:40

◆ update_K()

void OjaNewton::update_K ( )
inline

Definition at line 168 of file OjaNewton.cc.

Referenced by learn().

169  { // Symmetric rank update with s = data.sketch_cnt: K += s * (delta Zx' + Zx delta') + (norm2_x * s^2) * delta delta'.
170  float tmp = data.norm2_x * data.sketch_cnt * data.sketch_cnt; // local scalar; without it, `tmp` below names the float* tmp member
171  for (int i = 1; i <= m; i++)
172  {
173  for (int j = 1; j <= m; j++)
174  {
175  K[i][j] += data.delta[i] * data.Zx[j] * data.sketch_cnt;
176  K[i][j] += data.delta[j] * data.Zx[i] * data.sketch_cnt;
177  K[i][j] += data.delta[i] * data.delta[j] * tmp;
178  }
179  }
180  }
float ** K
Definition: OjaNewton.cc:46
float * tmp
Definition: OjaNewton.cc:50
struct update_data data
Definition: OjaNewton.cc:54
float * delta
Definition: OjaNewton.cc:27
float sketch_cnt
Definition: OjaNewton.cc:23
float * Zx
Definition: OjaNewton.cc:25
float norm2_x
Definition: OjaNewton.cc:24

Member Data Documentation

◆ _random_state

std::shared_ptr<rand_state> OjaNewton::_random_state

Definition at line 35 of file OjaNewton.cc.

Referenced by OjaNewton_setup().

◆ A

float** OjaNewton::A

Definition at line 45 of file OjaNewton.cc.

Referenced by OjaNewton_setup().

◆ all

vw* OjaNewton::all

Definition at line 34 of file OjaNewton.cc.

Referenced by learn(), OjaNewton_setup(), predict(), and save_load().

◆ alpha

float OjaNewton::alpha

Definition at line 38 of file OjaNewton.cc.

Referenced by OjaNewton_setup(), and update_wbar_and_Zx().

◆ b

float* OjaNewton::b

Definition at line 43 of file OjaNewton.cc.

Referenced by make_pred(), and OjaNewton_setup().

◆ buffer

example** OjaNewton::buffer

Definition at line 52 of file OjaNewton.cc.

Referenced by learn(), and OjaNewton_setup().

◆ cnt

int OjaNewton::cnt

Definition at line 39 of file OjaNewton.cc.

Referenced by learn(), and OjaNewton_setup().

◆ D

float* OjaNewton::D

Definition at line 44 of file OjaNewton.cc.

◆ data

struct update_data OjaNewton::data

Definition at line 54 of file OjaNewton.cc.

Referenced by learn(), OjaNewton_setup(), and predict().

◆ epoch_size

int OjaNewton::epoch_size

Definition at line 37 of file OjaNewton.cc.

Referenced by learn(), and OjaNewton_setup().

◆ ev

float* OjaNewton::ev

Definition at line 42 of file OjaNewton.cc.

Referenced by OjaNewton_setup().

◆ K

float** OjaNewton::K

Definition at line 46 of file OjaNewton.cc.

Referenced by OjaNewton_setup().

◆ learning_rate_cnt

float OjaNewton::learning_rate_cnt

Definition at line 56 of file OjaNewton.cc.

Referenced by OjaNewton_setup().

◆ m

int OjaNewton::m

◆ normalize

bool OjaNewton::normalize

Definition at line 57 of file OjaNewton.cc.

◆ random_init

bool OjaNewton::random_init

Definition at line 58 of file OjaNewton.cc.

Referenced by OjaNewton_setup().

◆ t

int OjaNewton::t

Definition at line 40 of file OjaNewton.cc.

Referenced by learn(), and OjaNewton_setup().

◆ tmp

float* OjaNewton::tmp

Definition at line 50 of file OjaNewton.cc.

Referenced by OjaNewton_setup().

◆ vv

float* OjaNewton::vv

Definition at line 49 of file OjaNewton.cc.

Referenced by OjaNewton_setup().

◆ weight_buffer

float* OjaNewton::weight_buffer

Definition at line 53 of file OjaNewton.cc.

Referenced by learn(), and OjaNewton_setup().

◆ zv

float* OjaNewton::zv

Definition at line 48 of file OjaNewton.cc.

Referenced by OjaNewton_setup().


The documentation for this struct was generated from the following file:
OjaNewton.cc