Vowpal Wabbit
autolink.cc
Go to the documentation of this file.
1 #include "autolink.h"
2 
3 #include "learner.h"
4 #include "global_data.h"
5 #include "parse_args.h"
6 
7 #include <cstdint>
8 
9 using namespace VW::config;
10 
11 namespace VW
12 {
13 struct autolink
14 {
15  autolink(uint32_t d, uint32_t stride_shift);
16  void predict(LEARNER::single_learner& base, example& ec);
17  void learn(LEARNER::single_learner& base, example& ec);
18 
19  private:
20  void prepare_example(LEARNER::single_learner& base, example& ec);
21  void reset_example(example& ec);
22 
23  // degree of the polynomial
24  const uint32_t _poly_degree;
25  const uint32_t _stride_shift;
26  static constexpr int AUTOCONSTANT = 524267083;
27 };
28 } // namespace VW
29 
30 VW::autolink::autolink(uint32_t poly_degree, uint32_t stride_shift)
31  : _poly_degree(poly_degree), _stride_shift(stride_shift)
32 {
33 }
34 
36 {
37  prepare_example(base, ec);
38  base.predict(ec);
39  reset_example(ec);
40 }
41 
43 {
44  prepare_example(base, ec);
45  base.learn(ec);
46  reset_example(ec);
47 }
48 
50 {
51  base.predict(ec);
52  float base_pred = ec.pred.scalar;
53 
54  // Add features of label.
57  for (size_t i = 0; i < _poly_degree; i++)
58  {
59  if (base_pred != 0.)
60  {
61  fs.push_back(base_pred, AUTOCONSTANT + (i << _stride_shift));
62  base_pred *= ec.pred.scalar;
63  }
64  }
66 }
67 
69 {
72  fs.clear();
73  ec.indices.pop();
74 }
75 
76 template <bool is_learn>
78 {
79  if (is_learn)
80  b.learn(base, ec);
81  else
82  b.predict(base, ec);
83 }
84 
86 {
87  uint32_t d;
88  option_group_definition new_options("Autolink");
89  new_options.add(make_option("autolink", d).keep().help("create link function with polynomial d"));
90  options.add_and_parse(new_options);
91 
92  if (!options.was_supplied("autolink"))
93  return nullptr;
94 
95  auto autolink_reduction = scoped_calloc_or_throw<VW::autolink>(d, all.weights.stride_shift());
96  return make_base(init_learner(
97  autolink_reduction, as_singleline(setup_base(options, all)), predict_or_learn<true>, predict_or_learn<false>));
98 }
v_array< namespace_index > indices
parameters weights
Definition: global_data.h:537
void predict(E &ec, size_t i=0)
Definition: learner.h:169
T pop()
Definition: v_array.h:58
uint64_t stride_shift(const stagewise_poly &poly, uint64_t idx)
void push_back(feature_value v, feature_index i)
float scalar
Definition: example.h:45
the core definition of a set of features.
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
std::array< features, NUM_NAMESPACES > feature_space
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void push_back(const T &new_ele)
Definition: v_array.h:107
virtual bool was_supplied(const std::string &key)=0
void clear()
constexpr unsigned char autolink_namespace
Definition: constant.h:24
option_group_definition & add(T &&op)
Definition: options.h:90
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
float total_sum_feat_sq
Definition: example.h:71
float sum_feat_sq
Definition: autolink.cc:11
uint32_t stride_shift()
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
void predict(bfgs &b, base_learner &, example &ec)
Definition: bfgs.cc:956
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
void learn(bfgs &b, base_learner &base, example &ec)
Definition: bfgs.cc:965