Vowpal Wabbit
Functions
memory_tree.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner * memory_tree_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ memory_tree_setup()

LEARNER::base_learner* memory_tree_setup ( VW::config::options_i & options,
vw & all
)

Definition at line 1223 of file memory_tree.cc.

References add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), label_parser::delete_label, vw::delete_prediction, LEARNER::end_pass(), f, vw::get_random_state(), LEARNER::init_learner(), LEARNER::init_multiclass_learner(), init_tree(), vw::label_type, memory_tree_ns::learn(), parser::lp, LEARNER::make_base(), VW::config::make_option(), label_type::multi, MULTILABEL::multilabel, prediction_type::multilabels, vw::numpasses, vw::p, memory_tree_ns::predict(), vw::quiet, memory_tree_ns::save_load_memory_tree(), LEARNER::learner< T, E >::set_end_pass(), LEARNER::learner< T, E >::set_save_load(), setup_base(), and vw::trace_message.

Referenced by parse_reductions().

1224 {
1225  using namespace memory_tree_ns;
1226  auto tree = scoped_calloc_or_throw<memory_tree>();
1227  option_group_definition new_options("Memory Tree");
1228 
1229  new_options
1230  .add(make_option("memory_tree", tree->max_nodes)
1231  .keep()
1232  .default_value(0)
1233  .help("Make a memory tree with at most <n> nodes"))
1234  .add(make_option("max_number_of_labels", tree->max_num_labels)
1235  .default_value(10)
1236  .help("max number of unique label"))
1237  .add(make_option("leaf_example_multiplier", tree->leaf_example_multiplier)
1238  .default_value(1)
1239  .help("multiplier on examples per leaf (default = log nodes)"))
1240  .add(make_option("alpha", tree->alpha).default_value(0.1f).help("Alpha"))
1241  .add(make_option("dream_repeats", tree->dream_repeats)
1242  .default_value(1)
1243  .help("number of dream operations per example (default = 1)"))
1244  .add(make_option("top_K", tree->top_K).default_value(1).help("top K prediction error (default 1)"))
1245  .add(make_option("learn_at_leaf", tree->learn_at_leaf).help("whether or not learn at leaf (defualt = True)"))
1246  .add(make_option("oas", tree->oas).help("use oas at the leaf"))
1247  .add(make_option("dream_at_update", tree->dream_at_update)
1248  .default_value(0)
1249  .help("turn on dream operations at reward based update as well"))
1250  .add(make_option("online", tree->online).help("turn on dream operations at reward based update as well"));
1251  options.add_and_parse(new_options);
1252  if (!tree->max_nodes)
1253  {
1254  return nullptr;
1255  }
1256 
1257  tree->all = &all;
1258  tree->_random_state = all.get_random_state();
1259  tree->current_pass = 0;
1260  tree->final_pass = all.numpasses;
1261 
1262  tree->max_leaf_examples = (size_t)(tree->leaf_example_multiplier * (log(tree->max_nodes) / log(2)));
1263 
1264  init_tree(*tree);
1265 
1266  if (!all.quiet)
1267  all.trace_message << "memory_tree:"
1268  << " "
1269  << "max_nodes = " << tree->max_nodes << " "
1270  << "max_leaf_examples = " << tree->max_leaf_examples << " "
1271  << "alpha = " << tree->alpha << " "
1272  << "oas = " << tree->oas << " "
1273  << "online =" << tree->online << " " << std::endl;
1274 
1275  size_t num_learners = 0;
1276 
1277  // multi-class classification
1278  if (tree->oas == false)
1279  {
1280  num_learners = tree->max_nodes + 1;
1282  init_multiclass_learner(tree, as_singleline(setup_base(options, all)), learn, predict, all.p, num_learners);
1283  // srand(time(0));
1286 
1287  return make_base(l);
1288  } // multi-label classification
1289  else
1290  {
1291  num_learners = tree->max_nodes + 1 + tree->max_num_labels;
1293  tree, as_singleline(setup_base(options, all)), learn, predict, num_learners, prediction_type::multilabels);
1294 
1295  // all.p->lp = MULTILABEL::multilabel;
1296  // all.label_type = label_type::multi;
1297  // all.delete_prediction = MULTILABEL::multilabel.delete_label;
1298  // srand(time(0));
1301  // l.set_end_pass(end_pass);
1302 
1303  all.p->lp = MULTILABEL::multilabel;
1306 
1307  return make_base(l);
1308  }
1309 }
void(* delete_prediction)(void *)
Definition: global_data.h:485
void predict(memory_tree &b, single_learner &base, example &ec)
Definition: memory_tree.cc:666
void(* delete_label)(void *)
Definition: label_parser.h:16
label_type::label_type_t label_type
Definition: global_data.h:550
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
bool quiet
Definition: global_data.h:487
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
void end_pass(memory_tree &b)
parser * p
Definition: global_data.h:377
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void save_load_memory_tree(memory_tree &b, io_buf &model_file, bool read, bool text)
void init_tree(memory_tree &b)
Definition: memory_tree.cc:281
vw_ostream trace_message
Definition: global_data.h:424
size_t numpasses
Definition: global_data.h:451
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
label_parser multilabel
Definition: multilabel.cc:118
void set_end_pass(void(*f)(T &))
Definition: learner.h:286
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
void learn(memory_tree &b, single_learner &base, example &ec)
float f
Definition: cache.cc:40
label_parser lp
Definition: parser.h:102