Vowpal Wabbit
Namespaces | Functions
baseline.h File Reference

Go to the source code of this file.

Namespaces

 BASELINE
 

Functions

LEARNER::base_learner * baseline_setup (VW::config::options_i &options, vw &all)
 
void BASELINE::set_baseline_enabled (example *ec)
 
void BASELINE::reset_baseline_disabled (example *ec)
 
bool BASELINE::baseline_enabled (example *ec)
 

Function Documentation

◆ baseline_setup()

LEARNER::base_learner* baseline_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 193 of file baseline.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), VW::alloc_examples(), LEARNER::as_singleline(), loss_function::getType(), LEARNER::init_learner(), vw::interactions, label_parser::label_size, vw::loss, LEARNER::make_base(), VW::config::make_option(), sensitivity(), LEARNER::learner< T, E >::set_sensitivity(), setup_base(), and simple_label.

Referenced by parse_reductions().

194 {
195  auto data = scoped_calloc_or_throw<baseline>();
196  bool baseline_option = false;
197  std::string loss_function;
198 
199  option_group_definition new_options("Baseline options");
200  new_options
201  .add(make_option("baseline", baseline_option)
202  .keep()
203  .help("Learn an additive baseline (from constant features) and a residual separately in regression."))
204  .add(make_option("lr_multiplier", data->lr_multiplier).help("learning rate multiplier for baseline model"))
205  .add(make_option("global_only", data->global_only)
206  .keep()
207  .help("use separate example with only global constant for baseline predictions"))
208  .add(make_option("check_enabled", data->check_enabled)
209  .keep()
210  .help("only use baseline when the example contains enabled flag"));
211  options.add_and_parse(new_options);
212 
213  if (!baseline_option)
214  return nullptr;
215 
216  // initialize baseline example
217  data->ec = VW::alloc_examples(simple_label.label_size, 1);
218  data->ec->interactions = &all.interactions;
219 
220  data->ec->in_use = true;
221  data->all = &all;
222 
223  auto loss_function_type = all.loss->getType();
224  if (loss_function_type != "logistic")
225  data->lr_scaling = true;
226 
227  auto base = as_singleline(setup_base(options, all));
228 
229  learner<baseline, example>& l = init_learner(data, base, predict_or_learn<true>, predict_or_learn<false>);
230 
231  l.set_sensitivity(sensitivity);
232 
233  return make_base(l);
234 }
label_parser simple_label
loss_function * loss
Definition: global_data.h:523
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
example * alloc_examples(size_t, size_t count=1)
Definition: example.cc:204
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
float sensitivity(baseline &data, base_learner &base, example &ec)
Definition: baseline.cc:168
virtual std::string getType()=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void set_sensitivity(float(*u)(T &data, base_learner &base, example &))
Definition: learner.h:237
size_t label_size
Definition: label_parser.h:23
std::vector< std::string > interactions
Definition: global_data.h:457
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222