Vowpal Wabbit
learner.h
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD
4 license as described in the file LICENSE.
5  */
6 #pragma once
7 // This is the interface for a learning algorithm
8 #include <iostream>
9 #include "memory.h"
10 #include "multiclass.h"
11 #include "simple_label.h"
12 #include "parser.h"
13 
14 #include <memory>
15 
16 namespace prediction_type
17 {
19 {
29 };
30 
32 } // namespace prediction_type
33 
34 namespace LEARNER
35 {
36 template <class T, class E>
37 struct learner;
38 
39 using base_learner = learner<char, char>;
42 
43 struct func_data
44 {
45  using fn = void (*)(void* data);
46  void* data;
49 };
50 
51 inline func_data tuple_dbf(void* data, base_learner* base, void (*func)(void*))
52 {
53  func_data foo;
54  foo.data = data;
55  foo.base = base;
56  foo.func = func;
57  return foo;
58 }
59 
60 struct learn_data
61 {
62  using fn = void (*)(void* data, base_learner& base, void* ex);
63  using multi_fn = void (*)(void* data, base_learner& base, void* ex, size_t count, size_t step, polyprediction* pred,
64  bool finalize_predictions);
65 
66  void* data;
72 };
73 
75 {
76  using fn = float (*)(void* data, base_learner& base, example& ex);
77  void* data;
79 };
80 
82 {
83  using fn = void (*)(void*, io_buf&, bool read, bool text);
84  void* data;
87 };
88 
90 {
91  using fn = void (*)(vw&, void* data, void* ex);
92  void* data;
95 };
96 
97 void generic_driver(vw& all);
98 void generic_driver(const std::vector<vw*>& alls);
99 void generic_driver_onethread(vw& all);
100 
// Do-nothing save/load callback: accepts (data, io_buf, read, text) and ignores all of them.
101 inline void noop_sl(void*, io_buf&, bool, bool) {}
// Do-nothing callback taking a single untyped data pointer, which it ignores.
102 inline void noop(void*) {}
103 inline float noop_sensitivity(void*, base_learner&, example&)
104 {
105  std::cout << std::endl;
106  return 0.;
107 }
108 float recur_sensitivity(void*, base_learner&, example&);
109 
110 inline void increment_offset(example& ex, const size_t increment, const size_t i)
111 {
112  ex.ft_offset += static_cast<uint32_t>(increment * i);
113 }
114 
115 inline void increment_offset(multi_ex& ec_seq, const size_t increment, const size_t i)
116 {
117  for (auto ec : ec_seq) ec->ft_offset += static_cast<uint32_t>(increment * i);
118 }
119 
120 inline void decrement_offset(example& ex, const size_t increment, const size_t i)
121 {
122  assert(ex.ft_offset >= increment * i);
123  ex.ft_offset -= static_cast<uint32_t>(increment * i);
124 }
125 
126 inline void decrement_offset(multi_ex& ec_seq, const size_t increment, const size_t i)
127 {
128  for (auto ec : ec_seq)
129  {
130  assert(ec->ft_offset >= increment * i);
131  ec->ft_offset -= static_cast<uint32_t>(increment * i);
132  }
133 }
134 
135 template <class T, class E>
136 struct learner
137 {
138  private:
147 
148  std::shared_ptr<void> learner_data;
149  learner(){}; // Should only be able to construct a learner through init_learner function
150  public:
152  size_t weights; // this stores the number of "weight vectors" required by the learner.
153  size_t increment;
154  bool is_multiline; // Is this a single-line or multi-line reduction?
155 
156  using end_fptr_type = void (*)(vw&, void*, void*);
157  using finish_fptr_type = void (*)(void*);
158 
159  // called once for each example. Must work under reduction.
160  inline void learn(E& ec, size_t i = 0)
161  {
162  assert((is_multiline && std::is_same<multi_ex, E>::value) ||
163  (!is_multiline && std::is_same<example, E>::value)); // sanity check under debug compile
164  increment_offset(ec, increment, i);
165  learn_fd.learn_f(learn_fd.data, *learn_fd.base, (void*)&ec);
166  decrement_offset(ec, increment, i);
167  }
168 
169  inline void predict(E& ec, size_t i = 0)
170  {
171  assert((is_multiline && std::is_same<multi_ex, E>::value) ||
172  (!is_multiline && std::is_same<example, E>::value)); // sanity check under debug compile
173  increment_offset(ec, increment, i);
174  learn_fd.predict_f(learn_fd.data, *learn_fd.base, (void*)&ec);
175  decrement_offset(ec, increment, i);
176  }
177 
178  inline void multipredict(E& ec, size_t lo, size_t count, polyprediction* pred, bool finalize_predictions)
179  {
180  assert((is_multiline && std::is_same<multi_ex, E>::value) ||
181  (!is_multiline && std::is_same<example, E>::value)); // sanity check under debug compile
182  if (learn_fd.multipredict_f == NULL)
183  {
184  increment_offset(ec, increment, lo);
185  for (size_t c = 0; c < count; c++)
186  {
187  learn_fd.predict_f(learn_fd.data, *learn_fd.base, (void*)&ec);
188  if (finalize_predictions)
189  pred[c] = ec.pred; // TODO: this breaks for complex labels because = doesn't do deep copy!
190  else
191  pred[c].scalar = ec.partial_prediction;
192  // pred[c].scalar = finalize_prediction ec.partial_prediction; // TODO: this breaks for complex labels because =
193  // doesn't do deep copy! // note works if ec.partial_prediction, but only if finalize_prediction is run????
194  increment_offset(ec, increment, 1);
195  }
196  decrement_offset(ec, increment, lo + count);
197  }
198  else
199  {
200  increment_offset(ec, increment, lo);
201  learn_fd.multipredict_f(learn_fd.data, *learn_fd.base, (void*)&ec, count, increment, pred, finalize_predictions);
202  decrement_offset(ec, increment, lo);
203  }
204  }
205 
// Installs `u` as the predict callback. The typed function pointer is cast to the
// type-erased learn_data::fn signature and stored in learn_fd.predict_f.
206  template <class L>
207  inline void set_predict(void (*u)(T&, L&, E&))
208  {
209  learn_fd.predict_f = (learn_data::fn)u;
210  }
// Installs `u` as the learn callback. The typed function pointer is cast to the
// type-erased learn_data::fn signature and stored in learn_fd.learn_f.
211  template <class L>
212  inline void set_learn(void (*u)(T&, L&, E&))
213  {
214  learn_fd.learn_f = (learn_data::fn)u;
215  }
// Installs `u` as the dedicated multipredict callback. The typed function pointer is cast
// to the type-erased learn_data::multi_fn signature and stored in learn_fd.multipredict_f.
// When this is left null, multipredict() falls back to repeated predict calls.
216  template <class L>
217  inline void set_multipredict(void (*u)(T&, L&, E&, size_t, size_t, polyprediction*, bool))
218  {
219  learn_fd.multipredict_f = (learn_data::multi_fn)u;
220  }
221 
222  inline void update(E& ec, size_t i = 0)
223  {
224  assert((is_multiline && std::is_same<multi_ex, E>::value) ||
225  (!is_multiline && std::is_same<example, E>::value)); // sanity check under debug compile
226  increment_offset(ec, increment, i);
227  learn_fd.update_f(learn_fd.data, *learn_fd.base, (void*)&ec);
228  decrement_offset(ec, increment, i);
229  }
// Installs `u` as the update callback. The typed function pointer is cast to the
// type-erased learn_data::fn signature and stored in learn_fd.update_f.
230  template <class L>
231  inline void set_update(void (*u)(T& data, L& base, E&))
232  {
233  learn_fd.update_f = (learn_data::fn)u;
234  }
235 
236  // used for active learning and confidence to determine how easily predictions are changed
// Installs `u` as the sensitivity callback (cast to the type-erased sensitivity_data::fn)
// and points sensitivity_fd.data at the same data object the learn callbacks use.
237  inline void set_sensitivity(float (*u)(T& data, base_learner& base, example&))
238  {
239  sensitivity_fd.data = learn_fd.data;
240  sensitivity_fd.sensitivity_f = (sensitivity_data::fn)u;
241  }
242  inline float sensitivity(example& ec, size_t i = 0)
243  {
244  increment_offset(ec, increment, i);
245  const float ret = sensitivity_fd.sensitivity_f(sensitivity_fd.data, *learn_fd.base, ec);
246  decrement_offset(ec, increment, i);
247  return ret;
248  }
249 
250  // called anytime saving or loading needs to happen. Autorecursive.
251  inline void save_load(io_buf& io, const bool read, const bool text)
252  {
253  save_load_fd.save_load_f(save_load_fd.data, io, read, text);
254  if (save_load_fd.base)
255  save_load_fd.base->save_load(io, read, text);
256  }
// Installs `sl` as the save/load callback (cast to the type-erased save_load_data::fn).
// Also copies the current learn_fd.data and learn_fd.base into save_load_fd, so the
// callback receives this learner's data and save_load() can recurse into this learner's base.
257  inline void set_save_load(void (*sl)(T&, io_buf&, bool, bool))
258  {
259  save_load_fd.save_load_f = (save_load_data::fn)sl;
260  save_load_fd.data = learn_fd.data;
261  save_load_fd.base = learn_fd.base;
262  }
263 
264  // called to clean up state. Autorecursive.
// Registers `f` as the callback that finish() runs. finisher_fd is rebuilt from the current
// learn_fd.data and learn_fd.base, with `f` cast to the type-erased finish_fptr_type.
265  void set_finish(void (*f)(T&)) { finisher_fd = tuple_dbf(learn_fd.data, learn_fd.base, (finish_fptr_type)(f)); }
// Cleans up this learner and then, recursively, the base learner beneath it.
// The steps run in a fixed order:
//   1. the registered finish callback, but only when finisher_fd.data is non-null;
//   2. an explicit destructor call on the learner_data shared_ptr member;
//   3. finish() on the base learner (if any), followed by free() of the base learner object.
// NOTE(review): the explicit destructor call is presumably needed because learner objects
// are obtained via calloc_or_throw and released with free(), so no destructor would run
// on its own - confirm. As written, calling finish() twice on the same object would
// destroy learner_data twice.
266  inline void finish()
267  {
268  if (finisher_fd.data)
269  {
270  finisher_fd.func(finisher_fd.data);
271  }
272  learner_data.~shared_ptr<void>();
273  if (finisher_fd.base)
274  {
275  finisher_fd.base->finish();
// The base object itself is released here; this learner is freed by whoever calls finish() on it.
276  free(finisher_fd.base);
277  }
278  }
279 
280  void end_pass()
281  {
282  end_pass_fd.func(end_pass_fd.data);
283  if (end_pass_fd.base)
284  end_pass_fd.base->end_pass();
285  } // autorecursive
// Registers `f` as the end-of-pass callback. end_pass_fd is rebuilt from the current
// learn_fd.data and learn_fd.base, with `f` cast to the type-erased func_data::fn.
286  void set_end_pass(void (*f)(T&)) { end_pass_fd = tuple_dbf(learn_fd.data, learn_fd.base, (func_data::fn)f); }
287 
288  // called after parsing of examples is complete. Autorecursive.
290  {
291  end_examples_fd.func(end_examples_fd.data);
292  if (end_examples_fd.base)
293  end_examples_fd.base->end_examples();
294  }
// Registers `f` as the end-of-examples callback. end_examples_fd is rebuilt from the current
// learn_fd.data and learn_fd.base, with `f` cast to the type-erased func_data::fn.
295  void set_end_examples(void (*f)(T&)) { end_examples_fd = tuple_dbf(learn_fd.data, learn_fd.base, (func_data::fn)f); }
296 
297  // Called at the beginning by the driver. Explicitly not recursive.
// Invokes only this learner's init callback on its data; the base learner is not touched.
298  void init_driver() { init_fd.func(init_fd.data); }
// Registers `f` as the callback run by init_driver(). init_fd is rebuilt from the current
// learn_fd.data and learn_fd.base, with `f` cast to the type-erased func_data::fn.
299  void set_init_driver(void (*f)(T&)) { init_fd = tuple_dbf(learn_fd.data, learn_fd.base, (func_data::fn)f); }
300 
301  // called after learn example for each example. Explicitly not recursive.
302  inline void finish_example(vw& all, E& ec)
303  {
304  finish_example_fd.finish_example_f(all, finish_example_fd.data, (void*)&ec);
305  }
306  // called after learn example for each example. Explicitly not recursive.
// Registers `f` as the callback run by finish_example() (cast to the type-erased
// end_fptr_type) and points finish_example_fd.data at this learner's data object.
307  void set_finish_example(void (*f)(vw& all, T&, E&))
308  {
309  finish_example_fd.data = learn_fd.data;
310  finish_example_fd.finish_example_f = (end_fptr_type)(f);
311  }
312 
313  template <class L>
314  static learner<T, E>& init_learner(T* dat, L* base, void (*learn)(T&, L&, E&), void (*predict)(T&, L&, E&), size_t ws,
316  {
317  learner<T, E>& ret = calloc_or_throw<learner<T, E> >();
318 
319  if (base != nullptr)
320  { // a reduction
321 
322  // This is a copy assignment into the current object. The purpose is to copy all of the
323  // function data objects so that if this reduction does not define a function such as
324  // save_load, then calling save_load on this object will essentially result in forwarding
325  // the call to the next reduction that actually implements it.
326  ret = *(learner<T, E>*)(base);
327 
328  ret.learn_fd.base = make_base(*base);
329  ret.sensitivity_fd.sensitivity_f = (sensitivity_data::fn)recur_sensitivity;
330  ret.finisher_fd.data = dat;
331  ret.finisher_fd.base = make_base(*base);
332  ret.finisher_fd.func = (func_data::fn)noop;
333  ret.weights = ws;
334  ret.increment = base->increment * ret.weights;
335  }
336  else // a base learner
337  {
338  ret.weights = 1;
339  ret.increment = ws;
340  ret.end_pass_fd.func = (func_data::fn)noop;
341  ret.end_examples_fd.func = (func_data::fn)noop;
342  ret.init_fd.func = (func_data::fn)noop;
343  ret.save_load_fd.save_load_f = (save_load_data::fn)noop_sl;
344  ret.finisher_fd.data = dat;
345  ret.finisher_fd.func = (func_data::fn)noop;
347  ret.finish_example_fd.data = dat;
349  }
350 
351  ret.learner_data = std::shared_ptr<T>(dat, [](T* ptr) {
352  ptr->~T();
353  free(ptr);
354  });
355 
356  ret.learn_fd.data = dat;
357  ret.learn_fd.learn_f = (learn_data::fn)learn;
358  ret.learn_fd.update_f = (learn_data::fn)learn;
359  ret.learn_fd.predict_f = (learn_data::fn)predict;
360  ret.learn_fd.multipredict_f = nullptr;
361  ret.pred_type = pred_type;
362  ret.is_multiline = std::is_same<multi_ex, E>::value;
363 
364  return ret;
365  }
366 };
367 
368 template <class T, class E, class L>
369 learner<T, E>& init_learner(free_ptr<T>& dat, L* base, void (*learn)(T&, L&, E&), void (*predict)(T&, L&, E&),
370  size_t ws, prediction_type::prediction_type_t pred_type)
371 {
372  auto ret = &learner<T, E>::init_learner(dat.get(), base, learn, predict, ws, pred_type);
373 
374  dat.release();
375  return *ret;
376 }
377 
378 // base learner/predictor
379 template <class T, class E, class L>
381  free_ptr<T>& dat, void (*learn)(T&, L&, E&), void (*predict)(T&, L&, E&), size_t params_per_weight)
382 {
383  auto ret =
384  &learner<T, E>::init_learner(dat.get(), (L*)nullptr, learn, predict, params_per_weight, prediction_type::scalar);
385 
386  dat.release();
387  return *ret;
388 }
389 
390 // base predictor only
391 template <class T, class E, class L>
392 learner<T, E>& init_learner(void (*predict)(T&, L&, E&), size_t params_per_weight)
393 {
395  nullptr, (L*)nullptr, predict, predict, params_per_weight, prediction_type::scalar);
396 }
397 
398 template <class T, class E, class L>
399 learner<T, E>& init_learner(free_ptr<T>& dat, void (*learn)(T&, L&, E&), void (*predict)(T&, L&, E&),
400  size_t params_per_weight, prediction_type::prediction_type_t pred_type)
401 {
402  auto ret = &learner<T, E>::init_learner(dat.get(), (L*)nullptr, learn, predict, params_per_weight, pred_type);
403  dat.release();
404  return *ret;
405 }
406 
407 // reduction with default prediction type
408 template <class T, class E, class L>
410  free_ptr<T>& dat, L* base, void (*learn)(T&, L&, E&), void (*predict)(T&, L&, E&), size_t ws)
411 {
412  auto ret = &learner<T, E>::init_learner(dat.get(), base, learn, predict, ws, base->pred_type);
413 
414  dat.release();
415  return *ret;
416 }
417 
418 // reduction with default num_params
419 template <class T, class E, class L>
420 learner<T, E>& init_learner(free_ptr<T>& dat, L* base, void (*learn)(T&, L&, E&), void (*predict)(T&, L&, E&))
421 {
422  auto ret = &learner<T, E>::init_learner(dat.get(), base, learn, predict, 1, base->pred_type);
423 
424  dat.release();
425  return *ret;
426 }
427 
428 // Reduction with no data.
429 template <class T, class E, class L>
430 learner<T, E>& init_learner(L* base, void (*learn)(T&, L&, E&), void (*predict)(T&, L&, E&))
431 {
432  return learner<T, E>::init_learner(nullptr, base, learn, predict, 1, base->pred_type);
433 }
434 
435 // multiclass reduction
436 template <class T, class E, class L>
437 learner<T, E>& init_multiclass_learner(free_ptr<T>& dat, L* base, void (*learn)(T&, L&, E&),
438  void (*predict)(T&, L&, E&), parser* p, size_t ws,
440 {
441  learner<T, E>& l = learner<T, E>::init_learner(dat.get(), base, learn, predict, ws, pred_type);
442 
443  dat.release();
444  l.set_finish_example(MULTICLASS::finish_example<T>);
446  return l;
447 }
448 
449 template <class T, class E, class L>
450 learner<T, E>& init_cost_sensitive_learner(free_ptr<T>& dat, L* base, void (*learn)(T&, L&, E&),
451  void (*predict)(T&, L&, E&), parser* p, size_t ws,
453 {
454  learner<T, E>& l = learner<T, E>::init_learner(dat.get(), base, learn, predict, ws, pred_type);
455  dat.release();
458  return l;
459 }
460 
461 template <class T, class E>
463 {
464  return (base_learner*)(&base);
465 }
466 
467 template <class T, class E>
469 {
470  if (l->is_multiline) // Tried to use a singleline reduction as a multiline reduction
471  return (multi_learner*)(l);
472  THROW("Tried to use a singleline reduction as a multiline reduction");
473 }
474 
475 template <class T, class E>
477 {
478  if (!l->is_multiline) // Tried to use a multiline reduction as a singleline reduction
479  return (single_learner*)(l);
480  THROW("Tried to use a multiline reduction as a singleline reduction");
481 }
482 
483 template <bool is_learn>
484 void multiline_learn_or_predict(multi_learner& base, multi_ex& examples, const uint64_t offset, const uint32_t id = 0)
485 {
486  std::vector<uint64_t> saved_offsets;
487  for (auto ec : examples)
488  {
489  saved_offsets.push_back(ec->ft_offset);
490  ec->ft_offset = offset;
491  }
492 
493  if (is_learn)
494  base.learn(examples, id);
495  else
496  base.predict(examples, id);
497 
498  for (size_t i = 0; i < examples.size(); i++) examples[i]->ft_offset = saved_offsets[i];
499 }
500 } // namespace LEARNER
void set_multipredict(void(*u)(T &, L &, E &, size_t, size_t, polyprediction *, bool))
Definition: learner.h:217
void finish_example(vw &all, E &ec)
Definition: learner.h:302
void save_load(io_buf &io, const bool read, const bool text)
Definition: learner.h:251
learn_data learn_fd
Definition: learner.h:140
void end_examples()
Definition: learner.h:289
void set_init_driver(void(*f)(T &))
Definition: learner.h:299
void set_update(void(*u)(T &data, L &base, E &))
Definition: learner.h:231
void predict(E &ec, size_t i=0)
Definition: learner.h:169
void increment_offset(multi_ex &ec_seq, const size_t increment, const size_t i)
Definition: learner.h:115
void set_learn(void(*u)(T &, L &, E &))
Definition: learner.h:212
void generic_driver(ready_examples_queue &examples, context_type &context)
Definition: learner.cc:253
float scalar
Definition: example.h:45
void finish()
Definition: learner.h:266
label_parser cs_label
void multiline_learn_or_predict(multi_learner &base, multi_ex &examples, const uint64_t offset, const uint32_t id=0)
Definition: learner.h:484
func_data end_examples_fd
Definition: learner.h:145
learner< T, E > & init_cost_sensitive_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:450
void(*)(void *, io_buf &, bool read, bool text) fn
Definition: learner.h:83
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
void set_predict(void(*u)(T &, L &, E &))
Definition: learner.h:207
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
save_load_data save_load_fd
Definition: learner.h:143
float(*)(void *data, base_learner &base, example &ex) fn
Definition: learner.h:76
base_learner * base
Definition: learner.h:67
void(*)(void *data, base_learner &base, void *ex, size_t count, size_t step, polyprediction *pred, bool finalize_predictions) multi_fn
Definition: learner.h:64
float recur_sensitivity(void *, base_learner &base, example &ec)
Definition: learner.cc:306
base_learner * base
Definition: learner.h:85
label_parser mc_label
Definition: multiclass.cc:93
func_data init_fd
Definition: learner.h:139
std::unique_ptr< T, free_fn > free_ptr
Definition: memory.h:34
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
multi_fn multipredict_f
Definition: learner.h:71
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &))
Definition: learner.h:430
void(*)(vw &, void *data, void *ex) fn
Definition: learner.h:91
void(*)(void *data) fn
Definition: learner.h:45
void(*)(void *data, base_learner &base, void *ex) fn
Definition: learner.h:62
void(*)(vw &, void *, void *) end_fptr_type
Definition: learner.h:156
void finish_example(vw &all, example &ec)
base_learner * base
Definition: learner.h:47
bool is_multiline
Definition: learner.h:154
float noop_sensitivity(void *, base_learner &, example &)
Definition: learner.h:103
Definition: io_buf.h:54
void generic_driver_onethread(vw &all)
Definition: learner.cc:285
float sensitivity(example &ec, size_t i=0)
Definition: learner.h:242
finish_example_data finish_example_fd
Definition: learner.h:142
std::vector< example * > multi_ex
Definition: example.h:122
base_learner * base
Definition: learner.h:93
void end_pass()
Definition: learner.h:280
sensitivity_data sensitivity_fd
Definition: learner.h:141
size_t increment
Definition: learner.h:153
void set_sensitivity(float(*u)(T &data, base_learner &base, example &))
Definition: learner.h:237
func_data finisher_fd
Definition: learner.h:146
static learner< T, E > & init_learner(T *dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:314
void set_end_pass(void(*f)(T &))
Definition: learner.h:286
void set_finish(void(*f)(T &))
Definition: learner.h:265
func_data tuple_dbf(void *data, base_learner *base, void(*func)(void *))
Definition: learner.h:51
prediction_type::prediction_type_t pred_type
Definition: learner.h:149
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
void multipredict(E &ec, size_t lo, size_t count, polyprediction *pred, bool finalize_predictions)
Definition: learner.h:178
void noop_sl(void *, io_buf &, bool, bool)
Definition: learner.h:101
std::shared_ptr< void > learner_data
Definition: learner.h:148
Definition: parser.h:38
size_t weights
Definition: learner.h:152
void predict(bfgs &b, base_learner &, example &ec)
Definition: bfgs.cc:956
void update(E &ec, size_t i=0)
Definition: learner.h:222
void learn(E &ec, size_t i=0)
Definition: learner.h:160
void learn(bfgs &b, base_learner &base, example &ec)
Definition: bfgs.cc:965
void init_driver()
Definition: learner.h:298
#define THROW(args)
Definition: vw_exception.h:181
constexpr uint64_t c
Definition: rand48.cc:12
float f
Definition: cache.cc:40
void decrement_offset(multi_ex &ec_seq, const size_t increment, const size_t i)
Definition: learner.h:126
func_data end_pass_fd
Definition: learner.h:144
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
const char * to_string(prediction_type_t prediction_type)
Definition: learner.cc:12
void noop(void *)
Definition: learner.h:102
void set_end_examples(void(*f)(T &))
Definition: learner.h:295
label_parser lp
Definition: parser.h:102
learner< char, char > base_learner
void return_simple_example(vw &all, void *, example &ec)