Vowpal Wabbit
Functions
recall_tree.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner * recall_tree_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ recall_tree_setup()

LEARNER::base_learner* recall_tree_setup ( VW::config::options_i & options,
vw & all 
)

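Sets up the Recall Tree multiclass reduction. It registers and parses the "Recall Tree" option group (recall_tree, max_candidates, bern_hyper, max_depth, node_only, randomized_routing) and returns nullptr unless recall_tree was supplied. Otherwise it does the following, where k is the value given to recall_tree (the number of classes):
- fills in the defaults max_candidates = min(k, 4·⌈log₂ k⌉) and max_depth = ⌈log₂ k⌉;
- initializes the tree;
- prints a configuration summary unless running quietly;
- returns a multiclass learner stacked on the base learner obtained from setup_base().

Definition at line 502 of file recall_tree.cc.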

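References add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), f, vw::get_random_state(), LEARNER::init_multiclass_learner(), init_tree(), recall_tree_ns::learn(), LEARNER::make_base(), VW::config::make_option(), vw::p, recall_tree_ns::predict(), vw::quiet, save_load_tree(), LEARNER::learner< T, E >::set_save_load(), setup_base(), vw::trace_message, vw::training, and VW::config::options_i::was_supplied().

Note: Doxygen resolved several of these names to unrelated symbols:
- The add() calls here are option_group_definition::add on new_options. They are not kernel_svm's add(svm_params&, svm_example*).
- f is only the suffix of the float literal 1.f. It is not the variable defined in cache.cc.
- init_tree() is called with the recall_tree object, so it cannot be the log_multi overload init_tree(log_multi&) shown in the tooltips below.
- save_load_tree() is registered through set_save_load on a learner<recall_tree, …>, which requires a function taking recall_tree&. It therefore cannot be the log_multi overload either.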

Referenced by parse_reductions().

503 {
504  auto tree = scoped_calloc_or_throw<recall_tree>();
505  option_group_definition new_options("Recall Tree");
506  new_options.add(make_option("recall_tree", tree->k).keep().help("Use online tree for multiclass"))
507  .add(make_option("max_candidates", tree->max_candidates)
508  .keep()
509  .help("maximum number of labels per leaf in the tree"))
510  .add(make_option("bern_hyper", tree->bern_hyper).default_value(1.f).help("recall tree depth penalty"))
511  .add(make_option("max_depth", tree->max_depth).keep().help("maximum depth of the tree, default log_2 (#classes)"))
512  .add(make_option("node_only", tree->node_only).keep().help("only use node features, not full path features"))
513  .add(make_option("randomized_routing", tree->randomized_routing).keep().help("randomized routing"));
514  options.add_and_parse(new_options);
515 
516  if (!options.was_supplied("recall_tree"))
517  return nullptr;
518 
519  tree->all = &all;
520  tree->_random_state = all.get_random_state();
521  tree->max_candidates = options.was_supplied("max_candidates")
522  ? tree->max_candidates
523  : std::min(tree->k, 4 * (uint32_t)(ceil(log(tree->k) / log(2.0))));
524  tree->max_depth =
525  options.was_supplied("max_depth") ? tree->max_depth : (uint32_t)std::ceil(std::log(tree->k) / std::log(2.0));
526 
527  init_tree(*tree.get());
528 
529  if (!all.quiet)
530  all.trace_message << "recall_tree:"
531  << " node_only = " << tree->node_only << " bern_hyper = " << tree->bern_hyper
532  << " max_depth = " << tree->max_depth << " routing = "
533  << (all.training ? (tree->randomized_routing ? "randomized" : "deterministic") : "n/a testonly")
534  << std::endl;
535 
537  tree, as_singleline(setup_base(options, all)), learn, predict, all.p, tree->max_routers + tree->k);
539 
540  return make_base(l);
541 }
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
bool quiet
Definition: global_data.h:487
void predict(recall_tree &b, single_learner &base, example &ec)
Definition: recall_tree.cc:335
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
bool training
Definition: global_data.h:488
parser * p
Definition: global_data.h:377
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void learn(recall_tree &b, single_learner &base, example &ec)
Definition: recall_tree.cc:376
vw_ostream trace_message
Definition: global_data.h:424
virtual bool was_supplied(const std::string &key)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
void save_load_tree(log_multi &b, io_buf &model_file, bool read, bool text)
Definition: log_multi.cc:397
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void init_tree(log_multi &d)
Definition: log_multi.cc:122
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
float f
Definition: cache.cc:40