Vowpal Wabbit
Classes | Namespaces | Functions | Variables
baseline.cc File Reference
#include <float.h>
#include <errno.h>
#include "reductions.h"
#include "vw.h"

Go to the source code of this file.

Classes

struct  baseline
 

Namespaces

 anonymous_namespace{baseline.cc}
 
 BASELINE
 

Functions

void BASELINE::set_baseline_enabled (example *ec)
 
void BASELINE::reset_baseline_disabled (example *ec)
 
bool BASELINE::baseline_enabled (example *ec)
 
void init_global (baseline &data)
 
template<bool is_learn>
void predict_or_learn (baseline &data, single_learner &base, example &ec)
 
float sensitivity (baseline &data, base_learner &base, example &ec)
 
base_learner * baseline_setup (options_i &options, vw &all)
 

Variables

const float anonymous_namespace{baseline.cc}::max_multiplier = 1000.f
 
const size_t anonymous_namespace{baseline.cc}::baseline_enabled_idx = 1357
 

Function Documentation

◆ baseline_setup()

base_learner* baseline_setup ( options_i & options,
vw & all 
)

Definition at line 193 of file baseline.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), VW::alloc_examples(), LEARNER::as_singleline(), loss_function::getType(), LEARNER::init_learner(), vw::interactions, label_parser::label_size, vw::loss, LEARNER::make_base(), VW::config::make_option(), sensitivity(), LEARNER::learner< T, E >::set_sensitivity(), setup_base(), and simple_label.

Referenced by parse_reductions().

194 {
195  auto data = scoped_calloc_or_throw<baseline>();
196  bool baseline_option = false;
197  std::string loss_function;
198 
199  option_group_definition new_options("Baseline options");
200  new_options
201  .add(make_option("baseline", baseline_option)
202  .keep()
203  .help("Learn an additive baseline (from constant features) and a residual separately in regression."))
204  .add(make_option("lr_multiplier", data->lr_multiplier).help("learning rate multiplier for baseline model"))
205  .add(make_option("global_only", data->global_only)
206  .keep()
207  .help("use separate example with only global constant for baseline predictions"))
208  .add(make_option("check_enabled", data->check_enabled)
209  .keep()
210  .help("only use baseline when the example contains enabled flag"));
211  options.add_and_parse(new_options);
212 
213  if (!baseline_option)
214  return nullptr;
215 
216  // initialize baseline example
217  data->ec = VW::alloc_examples(simple_label.label_size, 1);
218  data->ec->interactions = &all.interactions;
219 
220  data->ec->in_use = true;
221  data->all = &all;
222 
223  auto loss_function_type = all.loss->getType();
224  if (loss_function_type != "logistic")
225  data->lr_scaling = true;
226 
227  auto base = as_singleline(setup_base(options, all));
228 
229  learner<baseline, example>& l = init_learner(data, base, predict_or_learn<true>, predict_or_learn<false>);
230 
231  l.set_sensitivity(sensitivity);
232 
233  return make_base(l);
234 }
label_parser simple_label
loss_function * loss
Definition: global_data.h:523
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
example * alloc_examples(size_t, size_t count=1)
Definition: example.cc:204
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
float sensitivity(baseline &data, base_learner &base, example &ec)
Definition: baseline.cc:168
virtual std::string getType()=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void set_sensitivity(float(*u)(T &data, base_learner &base, example &))
Definition: learner.h:237
size_t label_size
Definition: label_parser.h:23
std::vector< std::string > interactions
Definition: global_data.h:457
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222

◆ init_global()

void init_global ( baseline & data)

Definition at line 81 of file baseline.cc.

References baseline::all, constant, constant_namespace, baseline::ec, example_predict::feature_space, baseline::global_only, example_predict::indices, example::num_features, v_array< T >::push_back(), parameters::stride_shift(), example::total_sum_feat_sq, vw::weights, and vw::wpp.

Referenced by predict_or_learn().

82 {
83  if (!data.global_only)
84  return;
85  // use a separate global constant
86  data.ec->indices.push_back(constant_namespace);
87  // different index from constant to avoid conflicts
88  data.ec->feature_space[constant_namespace].push_back(
89  1, ((constant - 17) * data.all->wpp) << data.all->weights.stride_shift());
90  data.ec->total_sum_feat_sq++;
91  data.ec->num_features++;
92 }
v_array< namespace_index > indices
example * ec
Definition: baseline.cc:65
parameters weights
Definition: global_data.h:537
std::array< features, NUM_NAMESPACES > feature_space
void push_back(const T &new_ele)
Definition: v_array.h:107
size_t num_features
Definition: example.h:67
constexpr uint64_t constant
Definition: constant.h:11
vw * all
Definition: baseline.cc:66
bool global_only
Definition: baseline.cc:69
uint32_t wpp
Definition: global_data.h:432
float total_sum_feat_sq
Definition: example.h:71
uint32_t stride_shift()
constexpr unsigned char constant_namespace
Definition: constant.h:22

◆ predict_or_learn()

template<bool is_learn>
void predict_or_learn ( baseline & data,
single_learner & base,
example & ec 
)

Definition at line 95 of file baseline.cc.

References baseline::all, BASELINE::baseline_enabled(), baseline::check_enabled, constant_namespace, VW::copy_example_metadata(), baseline::ec, vw::eta, f, baseline::global_initialized, baseline::global_only, init_global(), label_data::initial, example::l, LEARNER::learner< T, E >::learn(), baseline::lr_multiplier, baseline::lr_scaling, shared_data::max_label, anonymous_namespace{baseline.cc}::max_multiplier, shared_data::min_label, VW::move_feature_namespace(), example::pred, LEARNER::learner< T, E >::predict(), polyprediction::scalar, vw::sd, and polylabel::simple.

96 {
97  // no baseline if check_enabled is true and example contains flag
98  if (data.check_enabled && !BASELINE::baseline_enabled(&ec))
99  {
100  if (is_learn)
101  base.learn(ec);
102  else
103  base.predict(ec);
104  return;
105  }
106 
107  // always do a full prediction, for safety in accurate predictive validation
108  if (data.global_only)
109  {
110  if (!data.global_initialized)
111  {
112  init_global(data);
113  data.global_initialized = true;
114  }
115  VW::copy_example_metadata(/*audit=*/false, data.ec, &ec);
116  base.predict(*data.ec);
117  ec.l.simple.initial = data.ec->pred.scalar;
118  base.predict(ec);
119  }
120  else
121  base.predict(ec);
122 
123  if (is_learn)
124  {
125  const float pred = ec.pred.scalar; // save 'safe' prediction
126 
127  // now learn
128  data.ec->l.simple = ec.l.simple;
129  if (!data.global_only)
130  {
131  // move label & constant features data over to baseline example
132  VW::copy_example_metadata(/*audit=*/false, data.ec, &ec);
133  VW::move_feature_namespace(data.ec, &ec, constant_namespace);
134  }
135 
136  // regress baseline on label
137  if (data.lr_scaling)
138  {
139  float multiplier = data.lr_multiplier;
140  if (multiplier == 0)
141  {
142  multiplier = std::max(0.0001f, std::max(std::abs(data.all->sd->min_label), std::abs(data.all->sd->max_label)));
143  if (multiplier > max_multiplier)
144  multiplier = max_multiplier;
145  }
146  data.all->eta *= multiplier;
147  base.learn(*data.ec);
148  data.all->eta /= multiplier;
149  }
150  else
151  base.learn(*data.ec);
152 
153  // regress residual
154  ec.l.simple.initial = data.ec->pred.scalar;
155  base.learn(ec);
156 
157  if (!data.global_only)
158  {
159  // move feature data back to the original example
160  VW::move_feature_namespace(&ec, data.ec, constant_namespace);
161  }
162 
163  // return the safe prediction
164  ec.pred.scalar = pred;
165  }
166 }
example * ec
Definition: baseline.cc:65
void predict(E &ec, size_t i=0)
Definition: learner.h:169
bool lr_scaling
Definition: baseline.cc:67
void copy_example_metadata(bool, example *dst, example *src)
Definition: example.cc:48
float scalar
Definition: example.h:45
label_data simple
Definition: example.h:28
bool check_enabled
Definition: baseline.cc:71
float lr_multiplier
Definition: baseline.cc:68
shared_data * sd
Definition: global_data.h:375
float initial
Definition: simple_label.h:16
vw * all
Definition: baseline.cc:66
bool global_only
Definition: baseline.cc:69
bool global_initialized
Definition: baseline.cc:70
void init_global(baseline &data)
Definition: baseline.cc:81
float eta
Definition: global_data.h:531
polylabel l
Definition: example.h:57
float min_label
Definition: global_data.h:150
float max_label
Definition: global_data.h:151
void move_feature_namespace(example *dst, example *src, namespace_index c)
Definition: example.cc:92
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
constexpr unsigned char constant_namespace
Definition: constant.h:22
bool baseline_enabled(example *ec)
Definition: baseline.cc:51
float f
Definition: cache.cc:40

◆ sensitivity()

float sensitivity ( baseline & data,
base_learner & base,
example & ec 
)

Definition at line 168 of file baseline.cc.

References LEARNER::as_singleline(), BASELINE::baseline_enabled(), baseline::check_enabled, VW::copy_example_metadata(), baseline::ec, baseline::global_only, label_data::initial, example::l, label_data::label, example::pred, LEARNER::learner< T, E >::predict(), polyprediction::scalar, LEARNER::learner< T, E >::sensitivity(), polylabel::simple, and THROW.

Referenced by baseline_setup(), and predict_or_learn_with_confidence().

169 {
170  // no baseline if check_enabled is true and example contains flag
171  if (data.check_enabled && !BASELINE::baseline_enabled(&ec))
172  return base.sensitivity(ec);
173 
174  if (!data.global_only)
175  THROW("sensitivity for baseline without --global_only not implemented");
176 
177  // sensitivity of baseline term
178  VW::copy_example_metadata(/*audit=*/false, data.ec, &ec);
179  data.ec->l.simple.label = ec.l.simple.label;
180  data.ec->pred.scalar = ec.pred.scalar;
181  // std::cout << "before base" << std::endl;
182  const float baseline_sens = base.sensitivity(*data.ec);
183  // std::cout << "base sens: " << baseline_sens << std::endl;
184 
185  // sensitivity of residual
186  as_singleline(&base)->predict(*data.ec);
187  ec.l.simple.initial = data.ec->pred.scalar;
188  const float sens = base.sensitivity(ec);
189  // std::cout << " residual sens: " << sens << std::endl;
190  return baseline_sens + sens;
191 }
example * ec
Definition: baseline.cc:65
void predict(E &ec, size_t i=0)
Definition: learner.h:169
void copy_example_metadata(bool, example *dst, example *src)
Definition: example.cc:48
float scalar
Definition: example.h:45
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
bool check_enabled
Definition: baseline.cc:71
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
float initial
Definition: simple_label.h:16
bool global_only
Definition: baseline.cc:69
float sensitivity(example &ec, size_t i=0)
Definition: learner.h:242
polylabel l
Definition: example.h:57
polyprediction pred
Definition: example.h:60
bool baseline_enabled(example *ec)
Definition: baseline.cc:51
#define THROW(args)
Definition: vw_exception.h:181