Vowpal Wabbit
Functions
log_multi.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner* log_multi_setup (VW::config::options_i& options, vw& all)
 

Function Documentation

◆ log_multi_setup()

LEARNER::base_learner* log_multi_setup(VW::config::options_i& options, vw& all)

Definition at line 496 of file log_multi.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), getLossFunction(), LEARNER::init_multiclass_learner(), init_tree(), learn(), vw::loss, LEARNER::make_base(), VW::config::make_option(), vw::p, predict(), save_load_tree(), setup_base(), and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

497 {
498  auto data = scoped_calloc_or_throw<log_multi>();
499  option_group_definition new_options("Logarithmic Time Multiclass Tree");
500  new_options.add(make_option("log_multi", data->k).keep().help("Use online tree for multiclass"))
501  .add(make_option("no_progress", data->progress).help("disable progressive validation"))
502  .add(make_option("swap_resistance", data->swap_resist).default_value(4).help("disable progressive validation"))
503  .add(make_option("swap_resistance", data->swap_resist)
504  .default_value(4)
505  .help("higher = more resistance to swap, default=4"));
506  options.add_and_parse(new_options);
507 
508  if (!options.was_supplied("log_multi"))
509  return nullptr;
510 
511  data->progress = !data->progress;
512 
513  std::string loss_function = "quantile";
514  float loss_parameter = 0.5;
515  delete (all.loss);
516  all.loss = getLossFunction(all, loss_function, loss_parameter);
517 
518  data->max_predictors = data->k - 1;
519  init_tree(*data.get());
520 
522  data, as_singleline(setup_base(options, all)), learn, predict, all.p, data->max_predictors);
523  l.set_save_load(save_load_tree);
524 
525  return make_base(l);
526 }
loss_function * loss
Definition: global_data.h:523
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
parser * p
Definition: global_data.h:377
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
virtual bool was_supplied(const std::string &key)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
void save_load_tree(log_multi &b, io_buf &model_file, bool read, bool text)
Definition: log_multi.cc:397
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void init_tree(log_multi &d)
Definition: log_multi.cc:122
void learn(log_multi &b, single_learner &base, example &ec)
Definition: log_multi.cc:321
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
loss_function * getLossFunction(vw &all, std::string funcName, float function_parameter)
void predict(log_multi &b, single_learner &base, example &ec)
Definition: log_multi.cc:304