Vowpal Wabbit
Classes | Functions
gd_mf.cc File Reference
#include <fstream>
#include <float.h>
#include <string.h>
#include <stdio.h>
#include <netdb.h>
#include "gd.h"
#include "rand48.h"
#include "reductions.h"
#include "vw_exception.h"
#include "array_parameters.h"

Go to the source code of this file.

Classes

struct  gdmf
 
struct  pred_offset
 
class  set_rand_wrapper< T >
 

Functions

void mf_print_offset_features (gdmf &d, example &ec, size_t offset)
 
void mf_print_audit_features (gdmf &d, example &ec, size_t offset)
 
void offset_add (pred_offset &res, const float fx, float &fw)
 
template<class T >
float mf_predict (gdmf &d, example &ec, T &weights)
 
float mf_predict (gdmf &d, example &ec)
 
template<class T >
void sd_offset_update (T &weights, features &fs, uint64_t offset, float update, float regularization)
 
template<class T >
void mf_train (gdmf &d, example &ec, T &weights)
 
void mf_train (gdmf &d, example &ec)
 
void save_load (gdmf &d, io_buf &model_file, bool read, bool text)
 
void end_pass (gdmf &d)
 
void predict (gdmf &d, single_learner &, example &ec)
 
void learn (gdmf &d, single_learner &, example &ec)
 
base_learner * gd_mf_setup (options_i &options, vw &all)
 

Function Documentation

◆ end_pass()

void end_pass ( gdmf & d )

Definition at line 298 of file gd_mf.cc.

References gdmf::all, vw::check_holdout_every_n_passes, vw::current_pass, gdmf::early_stop_thres, vw::eta, vw::eta_decay_rate, vw::final_regressor_name, finalize_regressor(), vw::holdout_set_off, gdmf::no_win_counter, vw::save_per_pass, save_predictor(), set_done(), and summarize_holdout_set().

299 {
300  vw* all = d.all;
301 
302  all->eta *= all->eta_decay_rate;
303  if (all->save_per_pass)
305 
306  if (!all->holdout_set_off)
307  {
310  if ((d.early_stop_thres == d.no_win_counter) &&
311  ((all->check_holdout_every_n_passes <= 1) || ((all->current_pass % all->check_holdout_every_n_passes) == 0)))
312  set_done(*all);
313  }
314 }
void set_done(vw &all)
Definition: parser.cc:578
void finalize_regressor(vw &all, std::string reg_name)
bool holdout_set_off
Definition: global_data.h:499
size_t check_holdout_every_n_passes
Definition: global_data.h:503
bool summarize_holdout_set(vw &all, size_t &no_win_counter)
uint64_t early_stop_thres
Definition: gd_mf.cc:32
void save_predictor(vw &all, std::string reg_name, size_t current_pass)
uint64_t current_pass
Definition: global_data.h:396
float eta
Definition: global_data.h:531
bool save_per_pass
Definition: global_data.h:408
size_t no_win_counter
Definition: gd_mf.cc:31
std::string final_regressor_name
Definition: global_data.h:535
vw * all
Definition: gd_mf.cc:28
float eta_decay_rate
Definition: global_data.h:532

◆ gd_mf_setup()

base_learner* gd_mf_setup ( options_i & options,
vw & all 
)

Definition at line 327 of file gd_mf.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::end_pass(), vw::eta, f, VW::config::options_i::get_typed_option(), shared_data::holdout_best_loss, vw::holdout_set_off, LEARNER::init_learner(), vw::initial_t, learn(), LEARNER::make_base(), VW::config::make_option(), vw::power_t, ldamath::powf(), predict(), vw::random_weights, save_load(), vw::sd, LEARNER::learner< T, E >::set_end_pass(), LEARNER::learner< T, E >::set_save_load(), parameters::stride_shift(), shared_data::t, THROW, UINT64_ONE, VW::config::options_i::was_supplied(), and vw::weights.

Referenced by parse_reductions().

328 {
329  auto data = scoped_calloc_or_throw<gdmf>();
330 
331  bool bfgs = false;
332  bool conjugate_gradient = false;
333  option_group_definition gf_md_options("Gradient Descent Matrix Factorization");
334  gf_md_options.add(make_option("rank", data->rank).keep().help("rank for matrix factorization."));
335 
336  // Not supported, need to be checked to be false.
337  gf_md_options.add(make_option("bfgs", bfgs).help("Option not supported by this reduction"));
338  gf_md_options.add(
339  make_option("conjugate_gradient", conjugate_gradient).help("Option not supported by this reduction"));
340  options.add_and_parse(gf_md_options);
341 
342  if (!options.was_supplied("rank"))
343  return nullptr;
344 
345  if (options.was_supplied("adaptive"))
346  THROW("adaptive is not implemented for matrix factorization");
347  if (options.was_supplied("normalized"))
348  THROW("normalized is not implemented for matrix factorization");
349  if (options.was_supplied("exact_adaptive_norm"))
350  THROW("normalized adaptive updates is not implemented for matrix factorization");
351 
352  if (bfgs || conjugate_gradient)
353  THROW("bfgs is not implemented for matrix factorization");
354 
355  data->all = &all;
356  data->no_win_counter = 0;
357 
358  // store linear + 2*rank weights per index, round up to power of two
359  float temp = ceilf(logf((float)(data->rank * 2 + 1)) / logf(2.f));
360  all.weights.stride_shift((size_t)temp);
361  all.random_weights = true;
362 
363  if (!all.holdout_set_off)
364  {
365  all.sd->holdout_best_loss = FLT_MAX;
366  data->early_stop_thres = options.get_typed_option<size_t>("early_terminate").value();
367  }
368 
369  if (!options.was_supplied("learning_rate") && !options.was_supplied("l"))
370  all.eta = 10; // default learning rate to 10 for non default update rule
371 
372  // default initial_t to 1 instead of 0
373  if (!options.was_supplied("initial_t"))
374  {
375  all.sd->t = 1.f;
376  all.initial_t = 1.f;
377  }
378  all.eta *= powf((float)(all.sd->t), all.power_t);
379 
383 
384  return make_base(l);
385 }
parameters weights
Definition: global_data.h:537
float initial_t
Definition: global_data.h:530
float power_t
Definition: global_data.h:447
double holdout_best_loss
Definition: global_data.h:161
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
Definition: bfgs.cc:62
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
bool holdout_set_off
Definition: global_data.h:499
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
shared_data * sd
Definition: global_data.h:375
typed_option< T > & get_typed_option(const std::string &key)
Definition: options.h:120
T powf(T, T)
Definition: lda_core.cc:428
virtual bool was_supplied(const std::string &key)=0
bool random_weights
Definition: global_data.h:492
void save_load(gdmf &d, io_buf &model_file, bool read, bool text)
Definition: gd_mf.cc:248
void predict(gdmf &d, single_learner &, example &ec)
Definition: gd_mf.cc:316
float eta
Definition: global_data.h:531
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
constexpr uint64_t UINT64_ONE
void set_end_pass(void(*f)(T &))
Definition: learner.h:286
uint32_t stride_shift()
void learn(gdmf &d, single_learner &, example &ec)
Definition: gd_mf.cc:318
#define THROW(args)
Definition: vw_exception.h:181
void end_pass(gdmf &d)
Definition: gd_mf.cc:298
float f
Definition: cache.cc:40

◆ learn()

void learn ( gdmf & d,
single_learner & ,
example & ec 
)

Definition at line 318 of file gd_mf.cc.

References gdmf::all, example::l, label_data::label, mf_predict(), mf_train(), polylabel::simple, and vw::training.

Referenced by gd_mf_setup().

319 {
320  vw& all = *d.all;
321 
322  mf_predict(d, ec);
323  if (all.training && ec.l.simple.label != FLT_MAX)
324  mf_train(d, ec);
325 }
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
bool training
Definition: global_data.h:488
polylabel l
Definition: example.h:57
float mf_predict(gdmf &d, example &ec, T &weights)
Definition: gd_mf.cc:96
void mf_train(gdmf &d, example &ec, T &weights)
Definition: gd_mf.cc:187
vw * all
Definition: gd_mf.cc:28

◆ mf_predict() [1/2]

template<class T >
float mf_predict ( gdmf & d,
example & ec,
T &  weights 
)

Definition at line 96 of file gd_mf.cc.

References gdmf::all, vw::audit, v_array< T >::clear(), example_predict::feature_space, GD::finalize_prediction(), loss_function::getLoss(), label_data::initial, example::l, label_data::label, vw::loss, mf_print_audit_features(), example::num_features, pred_offset::p, vw::pairs, v_array< T >::push_back(), gdmf::rank, gdmf::scalars, vw::sd, vw::set_minmax, polylabel::simple, THROW, and vw::triples.

Referenced by learn(), mf_predict(), and predict().

97 {
98  vw& all = *d.all;
99  label_data& ld = ec.l.simple;
100  float prediction = ld.initial;
101 
102  for (std::string& i : d.all->pairs)
103  {
104  ec.num_features -= ec.feature_space[(int)i[0]].size() * ec.feature_space[(int)i[1]].size();
105  ec.num_features += ec.feature_space[(int)i[0]].size() * d.rank;
106  ec.num_features += ec.feature_space[(int)i[1]].size() * d.rank;
107  }
108 
109  // clear stored predictions
110  d.scalars.clear();
111 
112  float linear_prediction = 0.;
113  // linear terms
114 
115  for (features& fs : ec) GD::foreach_feature<float, GD::vec_add, T>(weights, fs, linear_prediction);
116 
117  // store constant + linear prediction
118  // note: constant is now automatically added
119  d.scalars.push_back(linear_prediction);
120 
121  prediction += linear_prediction;
122  // interaction terms
123  for (std::string& i : d.all->pairs)
124  {
125  if (ec.feature_space[(int)i[0]].size() > 0 && ec.feature_space[(int)i[1]].size() > 0)
126  {
127  for (uint64_t k = 1; k <= d.rank; k++)
128  {
129  // x_l * l^k
130  // l^k is from index+1 to index+d.rank
131  // float x_dot_l = sd_offset_add(weights, ec.atomics[(int)(*i)[0]].begin(), ec.atomics[(int)(*i)[0]].end(), k);
132  pred_offset x_dot_l = {0., k};
133  GD::foreach_feature<pred_offset, offset_add, T>(weights, ec.feature_space[(int)i[0]], x_dot_l);
134  // x_r * r^k
135  // r^k is from index+d.rank+1 to index+2*d.rank
136  // float x_dot_r = sd_offset_add(weights, ec.atomics[(int)(*i)[1]].begin(), ec.atomics[(int)(*i)[1]].end(),
137  // k+d.rank);
138  pred_offset x_dot_r = {0., k + d.rank};
139  GD::foreach_feature<pred_offset, offset_add, T>(weights, ec.feature_space[(int)i[1]], x_dot_r);
140 
141  prediction += x_dot_l.p * x_dot_r.p;
142 
143  // store prediction from interaction terms
144  d.scalars.push_back(x_dot_l.p);
145  d.scalars.push_back(x_dot_r.p);
146  }
147  }
148  }
149 
150  if (all.triples.begin() != all.triples.end())
151  THROW("cannot use triples in matrix factorization");
152 
153  // d.scalars has linear, x_dot_l_1, x_dot_r_1, x_dot_l_2, x_dot_r_2, ...
154 
155  ec.partial_prediction = prediction;
156 
157  all.set_minmax(all.sd, ld.label);
158 
159  ec.pred.scalar = GD::finalize_prediction(all.sd, ec.partial_prediction);
160 
161  if (ld.label != FLT_MAX)
162  ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) * ec.weight;
163 
164  if (all.audit)
165  mf_print_audit_features(d, ec, 0);
166 
167  return ec.pred.scalar;
168 }
float finalize_prediction(shared_data *sd, float ret)
Definition: gd.cc:339
loss_function * loss
Definition: global_data.h:523
std::vector< std::string > pairs
Definition: global_data.h:459
v_array< float > scalars
Definition: gd_mf.cc:29
the core definition of a set of features.
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
std::array< features, NUM_NAMESPACES > feature_space
void(* set_minmax)(shared_data *sd, float label)
Definition: global_data.h:394
virtual float getLoss(shared_data *, float prediction, float label)=0
void push_back(const T &new_ele)
Definition: v_array.h:107
shared_data * sd
Definition: global_data.h:375
void clear()
Definition: v_array.h:88
size_t num_features
Definition: example.h:67
float initial
Definition: simple_label.h:16
std::vector< std::string > triples
Definition: global_data.h:461
polylabel l
Definition: example.h:57
bool audit
Definition: global_data.h:486
void mf_print_audit_features(gdmf &d, example &ec, size_t offset)
Definition: gd_mf.cc:81
#define THROW(args)
Definition: vw_exception.h:181
uint32_t rank
Definition: gd_mf.cc:30
vw * all
Definition: gd_mf.cc:28

◆ mf_predict() [2/2]

float mf_predict ( gdmf & d,
example & ec 
)

Definition at line 170 of file gd_mf.cc.

References gdmf::all, parameters::dense_weights, mf_predict(), parameters::sparse, parameters::sparse_weights, and vw::weights.

171 {
172  vw& all = *d.all;
173  if (all.weights.sparse)
174  return mf_predict(d, ec, all.weights.sparse_weights);
175  else
176  return mf_predict(d, ec, all.weights.dense_weights);
177 }
parameters weights
Definition: global_data.h:537
dense_parameters dense_weights
sparse_parameters sparse_weights
float mf_predict(gdmf &d, example &ec, T &weights)
Definition: gd_mf.cc:96
vw * all
Definition: gd_mf.cc:28

◆ mf_print_audit_features()

void mf_print_audit_features ( gdmf & d,
example & ec,
size_t  offset 
)

Definition at line 81 of file gd_mf.cc.

References gdmf::all, mf_print_offset_features(), example::pred, print_result(), polyprediction::scalar, vw::stdout_fileno, and example::tag.

Referenced by mf_predict().

82 {
83  print_result(d.all->stdout_fileno, ec.pred.scalar, -1, ec.tag);
84  mf_print_offset_features(d, ec, offset);
85 }
v_array< char > tag
Definition: example.h:63
void mf_print_offset_features(gdmf &d, example &ec, size_t offset)
Definition: gd_mf.cc:36
float scalar
Definition: example.h:45
int stdout_fileno
Definition: global_data.h:434
void print_result(int f, float res, v_array< char > tag, float lb, float ub)
Definition: bs.cc:136
polyprediction pred
Definition: example.h:60
vw * all
Definition: gd_mf.cc:28

◆ mf_print_offset_features()

void mf_print_offset_features ( gdmf & d,
example & ec,
size_t  offset 
)

Definition at line 36 of file gd_mf.cc.

References gdmf::all, f, parameters::mask(), vw::pairs, gdmf::rank, THROW, vw::triples, and vw::weights.

Referenced by mf_print_audit_features().

37 {
38  vw& all = *d.all;
39  parameters& weights = all.weights;
40  uint64_t mask = weights.mask();
41  for (features& fs : ec)
42  {
43  bool audit = !fs.space_names.empty();
44  for (auto& f : fs.values_indices_audit())
45  {
46  std::cout << '\t';
47  if (audit)
48  std::cout << f.audit().get()->first << '^' << f.audit().get()->second << ':';
49  std::cout << f.index() << "(" << ((f.index() + offset) & mask) << ")" << ':' << f.value();
50  std::cout << ':' << (&weights[f.index()])[offset];
51  }
52  }
53  for (std::string& i : all.pairs)
54  if (ec.feature_space[(unsigned char)i[0]].size() > 0 && ec.feature_space[(unsigned char)i[1]].size() > 0)
55  {
56  /* print out nsk^feature:hash:value:weight:nsk^feature^:hash:value:weight:prod_weights */
57  for (size_t k = 1; k <= d.rank; k++)
58  {
59  for (features::iterator_all& f1 : ec.feature_space[(unsigned char)i[0]].values_indices_audit())
60  for (features::iterator_all& f2 : ec.feature_space[(unsigned char)i[1]].values_indices_audit())
61  {
62  std::cout << '\t' << f1.audit().get()->first << k << '^' << f1.audit().get()->second << ':'
63  << ((f1.index() + k) & mask) << "(" << ((f1.index() + offset + k) & mask) << ")" << ':'
64  << f1.value();
65  std::cout << ':' << (&weights[f1.index()])[offset + k];
66 
67  std::cout << ':' << f2.audit().get()->first << k << '^' << f2.audit().get()->second << ':'
68  << ((f2.index() + k + d.rank) & mask) << "(" << ((f2.index() + offset + k + d.rank) & mask) << ")"
69  << ':' << f2.value();
70  std::cout << ':' << (&weights[f2.index()])[offset + k + d.rank];
71 
72  std::cout << ':' << (&weights[f1.index()])[offset + k] * (&weights[f2.index()])[offset + k + d.rank];
73  }
74  }
75  }
76  if (all.triples.begin() != all.triples.end())
77  THROW("cannot use triples in matrix factorization");
78  std::cout << std::endl;
79 }
parameters weights
Definition: global_data.h:537
std::vector< std::string > pairs
Definition: global_data.h:459
the core definition of a set of features.
iterator over values, indices and audit space names
std::vector< std::string > triples
Definition: global_data.h:461
uint64_t mask()
#define THROW(args)
Definition: vw_exception.h:181
float f
Definition: cache.cc:40
uint32_t rank
Definition: gd_mf.cc:30
vw * all
Definition: gd_mf.cc:28

◆ mf_train() [1/2]

template<class T >
void mf_train ( gdmf & d,
example & ec,
T &  weights 
)

Definition at line 187 of file gd_mf.cc.

References gdmf::all, vw::eta, loss_function::getUpdate(), example::l, vw::l2_lambda, label_data::label, vw::loss, vw::pairs, vw::power_t, ldamath::powf(), example::pred, gdmf::rank, polyprediction::scalar, gdmf::scalars, vw::sd, polylabel::simple, shared_data::t, THROW, vw::triples, GD::update(), and example::weight.

Referenced by learn(), and mf_train().

188 {
189  vw& all = *d.all;
190  label_data& ld = ec.l.simple;
191 
192  // use final prediction to get update size
193  // update = eta_t*(y-y_hat) where eta_t = eta/(3*t^p) * importance weight
194  float eta_t = all.eta / powf((float)all.sd->t + ec.weight, (float)all.power_t) / 3.f * ec.weight;
195  float update = all.loss->getUpdate(ec.pred.scalar, ld.label, eta_t, 1.); // ec.total_sum_feat_sq);
196 
197  float regularization = eta_t * all.l2_lambda;
198 
199  // linear update
200  for (features& fs : ec) sd_offset_update<T>(weights, fs, 0, update, regularization);
201 
202  // quadratic update
203  for (std::string& i : all.pairs)
204  {
205  if (ec.feature_space[(int)i[0]].size() > 0 && ec.feature_space[(int)i[1]].size() > 0)
206  {
207  // update l^k weights
208  for (size_t k = 1; k <= d.rank; k++)
209  {
210  // r^k \cdot x_r
211  float r_dot_x = d.scalars[2 * k];
212  // l^k <- l^k + update * (r^k \cdot x_r) * x_l
213  sd_offset_update<T>(weights, ec.feature_space[(int)i[0]], k, update * r_dot_x, regularization);
214  }
215  // update r^k weights
216  for (size_t k = 1; k <= d.rank; k++)
217  {
218  // l^k \cdot x_l
219  float l_dot_x = d.scalars[2 * k - 1];
220  // r^k <- r^k + update * (l^k \cdot x_l) * x_r
221  sd_offset_update<T>(weights, ec.feature_space[(int)i[1]], k + d.rank, update * l_dot_x, regularization);
222  }
223  }
224  }
225  if (all.triples.begin() != all.triples.end())
226  THROW("cannot use triples in matrix factorization");
227 }
loss_function * loss
Definition: global_data.h:523
virtual float getUpdate(float prediction, float label, float update_scale, float pred_per_update)=0
std::vector< std::string > pairs
Definition: global_data.h:459
float scalar
Definition: example.h:45
v_array< float > scalars
Definition: gd_mf.cc:29
float power_t
Definition: global_data.h:447
the core definition of a set of features.
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
shared_data * sd
Definition: global_data.h:375
T powf(T, T)
Definition: lda_core.cc:428
float l2_lambda
Definition: global_data.h:445
std::vector< std::string > triples
Definition: global_data.h:461
float eta
Definition: global_data.h:531
polylabel l
Definition: example.h:57
void update(gd &g, base_learner &, example &ec)
Definition: gd.cc:647
polyprediction pred
Definition: example.h:60
float weight
Definition: example.h:62
#define THROW(args)
Definition: vw_exception.h:181
uint32_t rank
Definition: gd_mf.cc:30
vw * all
Definition: gd_mf.cc:28

◆ mf_train() [2/2]

void mf_train ( gdmf & d,
example & ec 
)

Definition at line 229 of file gd_mf.cc.

References gdmf::all, parameters::dense_weights, mf_train(), parameters::sparse, parameters::sparse_weights, and vw::weights.

230 {
231  if (d.all->weights.sparse)
232  mf_train(d, ec, d.all->weights.sparse_weights);
233  else
234  mf_train(d, ec, d.all->weights.dense_weights);
235 }
parameters weights
Definition: global_data.h:537
dense_parameters dense_weights
sparse_parameters sparse_weights
void mf_train(gdmf &d, example &ec, T &weights)
Definition: gd_mf.cc:187
vw * all
Definition: gd_mf.cc:28

◆ offset_add()

void offset_add ( pred_offset & res,
const float  fx,
float &  fw 
)

Definition at line 93 of file gd_mf.cc.

References pred_offset::offset, and pred_offset::p.

93 { res.p += (&fw)[res.offset] * fx; }
uint64_t offset
Definition: gd_mf.cc:90
float p
Definition: gd_mf.cc:89

◆ predict()

void predict ( gdmf & d,
single_learner & ,
example & ec 
)

Definition at line 316 of file gd_mf.cc.

References mf_predict().

Referenced by gd_mf_setup().

316 { mf_predict(d, ec); }
float mf_predict(gdmf &d, example &ec, T &weights)
Definition: gd_mf.cc:96

◆ save_load()

void save_load ( gdmf & d,
io_buf & model_file,
bool  read,
bool  text 
)

Definition at line 248 of file gd_mf.cc.

References gdmf::all, bin_text_read_write_fixed(), parameters::dense_weights, io_buf::files, initialize_regressor(), vw::num_bits, vw::random_weights, gdmf::rank, dense_parameters::set_default(), sparse_parameters::set_default(), v_array< T >::size(), parameters::sparse, parameters::sparse_weights, parameters::stride(), parameters::strided_index(), and vw::weights.

Referenced by gd_mf_setup().

249 {
250  vw& all = *d.all;
251  uint64_t length = (uint64_t)1 << all.num_bits;
252  if (read)
253  {
255  if (all.random_weights)
256  {
257  uint32_t stride = all.weights.stride();
258  if (all.weights.sparse)
260  else
262  }
263  }
264 
265  if (model_file.files.size() > 0)
266  {
267  uint64_t i = 0;
268  size_t brw = 1;
269  do
270  {
271  brw = 0;
272  size_t K = d.rank * 2 + 1;
273  std::stringstream msg;
274  msg << i << " ";
275  brw += bin_text_read_write_fixed(model_file, (char*)&i, sizeof(i), "", read, msg, text);
276  if (brw != 0)
277  {
278  weight* w_i = &(all.weights.strided_index(i));
279  for (uint64_t k = 0; k < K; k++)
280  {
281  weight* v = w_i + k;
282  msg << v << " ";
283  brw += bin_text_read_write_fixed(model_file, (char*)v, sizeof(*v), "", read, msg, text);
284  }
285  }
286  if (text)
287  {
288  msg << "\n";
289  brw += bin_text_read_write_fixed(model_file, nullptr, 0, "", read, msg, text);
290  }
291 
292  if (!read)
293  ++i;
294  } while ((!read && i < length) || (read && brw > 0));
295  }
296 }
parameters weights
Definition: global_data.h:537
void initialize_regressor(vw &all, T &weights)
uint32_t stride()
void set_default(R &info)
uint32_t num_bits
Definition: global_data.h:398
size_t size() const
Definition: v_array.h:68
v_array< int > files
Definition: io_buf.h:64
bool random_weights
Definition: global_data.h:492
dense_parameters dense_weights
weight & strided_index(size_t index)
float weight
sparse_parameters sparse_weights
size_t bin_text_read_write_fixed(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
Definition: io_buf.h:326
uint32_t rank
Definition: gd_mf.cc:30
vw * all
Definition: gd_mf.cc:28

◆ sd_offset_update()

template<class T >
void sd_offset_update ( T &  weights,
features & fs,
uint64_t  offset,
float  update,
float  regularization 
)

Definition at line 180 of file gd_mf.cc.

References features::indicies, features::size(), and features::values.

181 {
182  for (size_t i = 0; i < fs.size(); i++)
183  (&weights[fs.indicies[i]])[offset] += update * fs.values[i] - regularization * (&weights[fs.indicies[i]])[offset];
184 }
v_array< feature_index > indicies
v_array< feature_value > values
size_t size() const
void update(gd &g, base_learner &, example &ec)
Definition: gd.cc:647