Vowpal Wabbit
baseline.cc
Go to the documentation of this file.
1 /*
2  Copyright (c) by respective owners including Yahoo!, Microsoft, and
3  individual contributors. All rights reserved. Released under a BSD (revised)
4  license as described in the file LICENSE.
5 */
6 #include <float.h>
7 #include <errno.h>
8 
9 #include "reductions.h"
10 #include "vw.h"
11 
12 using namespace LEARNER;
13 using namespace VW::config;
14 
namespace
{
// Upper bound applied to the learning-rate multiplier that is derived from the label range.
constexpr float max_multiplier = 1000.f;
// Feature index, inside message_namespace, whose value flags an example as "baseline enabled".
constexpr size_t baseline_enabled_idx = 1357;
}  // namespace
20 
21 namespace BASELINE
22 {
24 {
25  auto& fs = ec->feature_space[message_namespace];
26  for (auto& f : fs)
27  {
28  if (f.index() == baseline_enabled_idx)
29  {
30  f.value() = 1;
31  return;
32  }
33  }
34  // if not found, push new feature
35  fs.push_back(1, baseline_enabled_idx);
36 }
37 
39 {
40  auto& fs = ec->feature_space[message_namespace];
41  for (auto& f : fs)
42  {
43  if (f.index() == baseline_enabled_idx)
44  {
45  f.value() = 0;
46  return;
47  }
48  }
49 }
50 
52 {
53  auto& fs = ec->feature_space[message_namespace];
54  for (auto& f : fs)
55  {
56  if (f.index() == baseline_enabled_idx)
57  return f.value() == 1;
58  }
59  return false;
60 }
61 } // namespace BASELINE
62 
63 struct baseline
64 {
66  vw* all;
67  bool lr_scaling; // whether to scale baseline learning rate based on max label
69  bool global_only; // only use a global constant for the baseline
71  bool check_enabled; // only use baseline when the example contains enabled flag
72 
74  {
75  if (ec)
77  free(ec);
78  }
79 };
80 
81 void init_global(baseline& data)
82 {
83  if (!data.global_only)
84  return;
85  // use a separate global constant
87  // different index from constant to avoid conflicts
88  data.ec->feature_space[constant_namespace].push_back(
89  1, ((constant - 17) * data.all->wpp) << data.all->weights.stride_shift());
90  data.ec->total_sum_feat_sq++;
91  data.ec->num_features++;
92 }
93 
94 template <bool is_learn>
96 {
97  // no baseline if check_enabled is true and example contains flag
98  if (data.check_enabled && !BASELINE::baseline_enabled(&ec))
99  {
100  if (is_learn)
101  base.learn(ec);
102  else
103  base.predict(ec);
104  return;
105  }
106 
107  // always do a full prediction, for safety in accurate predictive validation
108  if (data.global_only)
109  {
110  if (!data.global_initialized)
111  {
112  init_global(data);
113  data.global_initialized = true;
114  }
115  VW::copy_example_metadata(/*audit=*/false, data.ec, &ec);
116  base.predict(*data.ec);
117  ec.l.simple.initial = data.ec->pred.scalar;
118  base.predict(ec);
119  }
120  else
121  base.predict(ec);
122 
123  if (is_learn)
124  {
125  const float pred = ec.pred.scalar; // save 'safe' prediction
126 
127  // now learn
128  data.ec->l.simple = ec.l.simple;
129  if (!data.global_only)
130  {
131  // move label & constant features data over to baseline example
132  VW::copy_example_metadata(/*audit=*/false, data.ec, &ec);
134  }
135 
136  // regress baseline on label
137  if (data.lr_scaling)
138  {
139  float multiplier = data.lr_multiplier;
140  if (multiplier == 0)
141  {
142  multiplier = std::max(0.0001f, std::max(std::abs(data.all->sd->min_label), std::abs(data.all->sd->max_label)));
143  if (multiplier > max_multiplier)
144  multiplier = max_multiplier;
145  }
146  data.all->eta *= multiplier;
147  base.learn(*data.ec);
148  data.all->eta /= multiplier;
149  }
150  else
151  base.learn(*data.ec);
152 
153  // regress residual
154  ec.l.simple.initial = data.ec->pred.scalar;
155  base.learn(ec);
156 
157  if (!data.global_only)
158  {
159  // move feature data back to the original example
161  }
162 
163  // return the safe prediction
164  ec.pred.scalar = pred;
165  }
166 }
167 
168 float sensitivity(baseline& data, base_learner& base, example& ec)
169 {
170  // no baseline if check_enabled is true and example contains flag
171  if (data.check_enabled && !BASELINE::baseline_enabled(&ec))
172  return base.sensitivity(ec);
173 
174  if (!data.global_only)
175  THROW("sensitivity for baseline without --global_only not implemented");
176 
177  // sensitivity of baseline term
178  VW::copy_example_metadata(/*audit=*/false, data.ec, &ec);
179  data.ec->l.simple.label = ec.l.simple.label;
180  data.ec->pred.scalar = ec.pred.scalar;
181  // std::cout << "before base" << std::endl;
182  const float baseline_sens = base.sensitivity(*data.ec);
183  // std::cout << "base sens: " << baseline_sens << std::endl;
184 
185  // sensitivity of residual
186  as_singleline(&base)->predict(*data.ec);
187  ec.l.simple.initial = data.ec->pred.scalar;
188  const float sens = base.sensitivity(ec);
189  // std::cout << " residual sens: " << sens << std::endl;
190  return baseline_sens + sens;
191 }
192 
194 {
195  auto data = scoped_calloc_or_throw<baseline>();
196  bool baseline_option = false;
197  std::string loss_function;
198 
199  option_group_definition new_options("Baseline options");
200  new_options
201  .add(make_option("baseline", baseline_option)
202  .keep()
203  .help("Learn an additive baseline (from constant features) and a residual separately in regression."))
204  .add(make_option("lr_multiplier", data->lr_multiplier).help("learning rate multiplier for baseline model"))
205  .add(make_option("global_only", data->global_only)
206  .keep()
207  .help("use separate example with only global constant for baseline predictions"))
208  .add(make_option("check_enabled", data->check_enabled)
209  .keep()
210  .help("only use baseline when the example contains enabled flag"));
211  options.add_and_parse(new_options);
212 
213  if (!baseline_option)
214  return nullptr;
215 
216  // initialize baseline example
218  data->ec->interactions = &all.interactions;
219 
220  data->ec->in_use = true;
221  data->all = &all;
222 
223  auto loss_function_type = all.loss->getType();
224  if (loss_function_type != "logistic")
225  data->lr_scaling = true;
226 
227  auto base = as_singleline(setup_base(options, all));
228 
229  learner<baseline, example>& l = init_learner(data, base, predict_or_learn<true>, predict_or_learn<false>);
230 
232 
233  return make_base(l);
234 }
v_array< namespace_index > indices
example * ec
Definition: baseline.cc:65
label_parser simple_label
parameters weights
Definition: global_data.h:537
loss_function * loss
Definition: global_data.h:523
void predict(E &ec, size_t i=0)
Definition: learner.h:169
bool lr_scaling
Definition: baseline.cc:67
void copy_example_metadata(bool, example *dst, example *src)
Definition: example.cc:48
float scalar
Definition: example.h:45
void(* delete_label)(void *)
Definition: label_parser.h:16
void dealloc_example(void(*delete_label)(void *), example &ec, void(*delete_prediction)(void *))
Definition: example.cc:219
void set_baseline_enabled(example *ec)
Definition: baseline.cc:23
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
constexpr unsigned char message_namespace
Definition: constant.h:32
bool check_enabled
Definition: baseline.cc:71
float lr_multiplier
Definition: baseline.cc:68
~baseline()
Definition: baseline.cc:73
example * alloc_examples(size_t, size_t count=1)
Definition: example.cc:204
std::array< features, NUM_NAMESPACES > feature_space
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void push_back(const T &new_ele)
Definition: v_array.h:107
shared_data * sd
Definition: global_data.h:375
size_t num_features
Definition: example.h:67
float sensitivity(baseline &data, base_learner &base, example &ec)
Definition: baseline.cc:168
virtual std::string getType()=0
constexpr uint64_t constant
Definition: constant.h:11
void reset_baseline_disabled(example *ec)
Definition: baseline.cc:38
base_learner * baseline_setup(options_i &options, vw &all)
Definition: baseline.cc:193
float initial
Definition: simple_label.h:16
vw * all
Definition: baseline.cc:66
bool global_only
Definition: baseline.cc:69
uint32_t wpp
Definition: global_data.h:432
bool global_initialized
Definition: baseline.cc:70
void init_global(baseline &data)
Definition: baseline.cc:81
float sensitivity(example &ec, size_t i=0)
Definition: learner.h:242
float eta
Definition: global_data.h:531
option_group_definition & add(T &&op)
Definition: options.h:90
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
polylabel l
Definition: example.h:57
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
float total_sum_feat_sq
Definition: example.h:71
void set_sensitivity(float(*u)(T &data, base_learner &base, example &))
Definition: learner.h:237
float min_label
Definition: global_data.h:150
size_t label_size
Definition: label_parser.h:23
std::vector< std::string > interactions
Definition: global_data.h:457
float max_label
Definition: global_data.h:151
void move_feature_namespace(example *dst, example *src, namespace_index c)
Definition: example.cc:92
uint32_t stride_shift()
void predict_or_learn(baseline &data, single_learner &base, example &ec)
Definition: baseline.cc:95
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
constexpr unsigned char constant_namespace
Definition: constant.h:22
bool baseline_enabled(example *ec)
Definition: baseline.cc:51
#define THROW(args)
Definition: vw_exception.h:181
float f
Definition: cache.cc:40