Vowpal Wabbit
Classes | Namespaces | Functions
memory_tree.cc File Reference
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <float.h>
#include <time.h>
#include <sstream>
#include <ctime>
#include <memory>
#include "reductions.h"
#include "rand48.h"
#include "vw.h"
#include "v_array.h"

Go to the source code of this file.

Classes

struct  memory_tree_ns::node
 
struct  memory_tree_ns::memory_tree
 

Namespaces

 memory_tree_ns
 

Functions

template<typename T >
void memory_tree_ns::remove_at_index (v_array< T > &array, uint32_t index)
 
void memory_tree_ns::copy_example_data (example *dst, example *src, bool oas=false)
 
void memory_tree_ns::free_example (example *ec)
 
void memory_tree_ns::diag_kronecker_prod_fs_test (features &f1, features &f2, features &prod_f, float &total_sum_feat_sq, float norm_sq1, float norm_sq2)
 
int memory_tree_ns::cmpfunc (const void *a, const void *b)
 
void memory_tree_ns::diag_kronecker_product_test (example &ec1, example &ec2, example &ec, bool oas=false)
 
float memory_tree_ns::linear_kernel (const flat_example *fec1, const flat_example *fec2)
 
float memory_tree_ns::normalized_linear_prod (memory_tree &b, example *ec1, example *ec2)
 
void memory_tree_ns::init_tree (memory_tree &b)
 
uint64_t memory_tree_ns::insert_descent (node &n, const float prediction)
 
int memory_tree_ns::random_sample_example_pop (memory_tree &b, uint64_t &cn)
 
float memory_tree_ns::train_node (memory_tree &b, single_learner &base, example &ec, const uint64_t cn)
 
void memory_tree_ns::split_leaf (memory_tree &b, single_learner &base, const uint64_t cn)
 
int memory_tree_ns::compare_label (const void *a, const void *b)
 
uint32_t memory_tree_ns::over_lap (v_array< uint32_t > &array_1, v_array< uint32_t > &array_2)
 
uint32_t memory_tree_ns::hamming_loss (v_array< uint32_t > &array_1, v_array< uint32_t > &array_2)
 
void memory_tree_ns::collect_labels_from_leaf (memory_tree &b, const uint64_t cn, v_array< uint32_t > &leaf_labs)
 
void memory_tree_ns::train_one_against_some_at_leaf (memory_tree &b, single_learner &base, const uint64_t cn, example &ec)
 
uint32_t memory_tree_ns::compute_hamming_loss_via_oas (memory_tree &b, single_learner &base, const uint64_t cn, example &ec, v_array< uint32_t > &selected_labs)
 
int64_t memory_tree_ns::pick_nearest (memory_tree &b, single_learner &base, const uint64_t cn, example &ec)
 
float memory_tree_ns::get_overlap_from_two_examples (example &ec1, example &ec2)
 
float memory_tree_ns::F1_score_for_two_examples (example &ec1, example &ec2)
 
void memory_tree_ns::predict (memory_tree &b, single_learner &base, example &ec)
 
float memory_tree_ns::return_reward_from_node (memory_tree &b, single_learner &base, uint64_t cn, example &ec, float weight=1.f)
 
void memory_tree_ns::learn_at_leaf_random (memory_tree &b, single_learner &base, const uint64_t &leaf_id, example &ec, const float &weight)
 
void memory_tree_ns::route_to_leaf (memory_tree &b, single_learner &base, const uint32_t &ec_array_index, uint64_t cn, v_array< uint64_t > &path, bool insertion)
 
void memory_tree_ns::single_query_and_learn (memory_tree &b, single_learner &base, const uint32_t &ec_array_index, example &ec)
 
void memory_tree_ns::update_rew (memory_tree &b, single_learner &base, const uint32_t &ec_array_index, example &ec)
 
void memory_tree_ns::insert_example (memory_tree &b, single_learner &base, const uint32_t &ec_array_index, bool fake_insert=false)
 
void memory_tree_ns::experience_replay (memory_tree &b, single_learner &base)
 
void memory_tree_ns::learn (memory_tree &b, single_learner &base, example &ec)
 
void memory_tree_ns::end_pass (memory_tree &b)
 
void memory_tree_ns::save_load_example (example *ec, io_buf &model_file, bool &read, bool &text, std::stringstream &msg, bool &oas)
 
void memory_tree_ns::save_load_node (node &cn, io_buf &model_file, bool &read, bool &text, std::stringstream &msg)
 
void memory_tree_ns::save_load_memory_tree (memory_tree &b, io_buf &model_file, bool read, bool text)
 
base_learner * memory_tree_setup (options_i &options, vw &all)
 

Function Documentation

◆ memory_tree_setup()

base_learner* memory_tree_setup ( options_i & options,
vw & all 
)

Definition at line 1223 of file memory_tree.cc.

References add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), label_parser::delete_label, vw::delete_prediction, LEARNER::end_pass(), f, vw::get_random_state(), LEARNER::init_learner(), LEARNER::init_multiclass_learner(), init_tree(), vw::label_type, memory_tree_ns::learn(), parser::lp, LEARNER::make_base(), VW::config::make_option(), label_type::multi, MULTILABEL::multilabel, prediction_type::multilabels, vw::numpasses, vw::p, memory_tree_ns::predict(), vw::quiet, memory_tree_ns::save_load_memory_tree(), LEARNER::learner< T, E >::set_end_pass(), LEARNER::learner< T, E >::set_save_load(), setup_base(), and vw::trace_message.

Referenced by parse_reductions().

1224 {
1225  using namespace memory_tree_ns;
1226  auto tree = scoped_calloc_or_throw<memory_tree>();
1227  option_group_definition new_options("Memory Tree");
1228 
1229  new_options
1230  .add(make_option("memory_tree", tree->max_nodes)
1231  .keep()
1232  .default_value(0)
1233  .help("Make a memory tree with at most <n> nodes"))
1234  .add(make_option("max_number_of_labels", tree->max_num_labels)
1235  .default_value(10)
1236  .help("max number of unique label"))
1237  .add(make_option("leaf_example_multiplier", tree->leaf_example_multiplier)
1238  .default_value(1)
1239  .help("multiplier on examples per leaf (default = log nodes)"))
1240  .add(make_option("alpha", tree->alpha).default_value(0.1f).help("Alpha"))
1241  .add(make_option("dream_repeats", tree->dream_repeats)
1242  .default_value(1)
1243  .help("number of dream operations per example (default = 1)"))
1244  .add(make_option("top_K", tree->top_K).default_value(1).help("top K prediction error (default 1)"))
1245  .add(make_option("learn_at_leaf", tree->learn_at_leaf).help("whether or not learn at leaf (defualt = True)"))
1246  .add(make_option("oas", tree->oas).help("use oas at the leaf"))
1247  .add(make_option("dream_at_update", tree->dream_at_update)
1248  .default_value(0)
1249  .help("turn on dream operations at reward based update as well"))
1250  .add(make_option("online", tree->online).help("turn on dream operations at reward based update as well"));
1251  options.add_and_parse(new_options);
1252  if (!tree->max_nodes)
1253  {
1254  return nullptr;
1255  }
1256 
1257  tree->all = &all;
1258  tree->_random_state = all.get_random_state();
1259  tree->current_pass = 0;
1260  tree->final_pass = all.numpasses;
1261 
1262  tree->max_leaf_examples = (size_t)(tree->leaf_example_multiplier * (log(tree->max_nodes) / log(2)));
1263 
1264  init_tree(*tree);
1265 
1266  if (!all.quiet)
1267  all.trace_message << "memory_tree:"
1268  << " "
1269  << "max_nodes = " << tree->max_nodes << " "
1270  << "max_leaf_examples = " << tree->max_leaf_examples << " "
1271  << "alpha = " << tree->alpha << " "
1272  << "oas = " << tree->oas << " "
1273  << "online =" << tree->online << " " << std::endl;
1274 
1275  size_t num_learners = 0;
1276 
1277  // multi-class classification
1278  if (tree->oas == false)
1279  {
1280  num_learners = tree->max_nodes + 1;
1282  init_multiclass_learner(tree, as_singleline(setup_base(options, all)), learn, predict, all.p, num_learners);
1283  // srand(time(0));
1286 
1287  return make_base(l);
1288  } // multi-label classification
1289  else
1290  {
1291  num_learners = tree->max_nodes + 1 + tree->max_num_labels;
1293  tree, as_singleline(setup_base(options, all)), learn, predict, num_learners, prediction_type::multilabels);
1294 
1295  // all.p->lp = MULTILABEL::multilabel;
1296  // all.label_type = label_type::multi;
1297  // all.delete_prediction = MULTILABEL::multilabel.delete_label;
1298  // srand(time(0));
1301  // l.set_end_pass(end_pass);
1302 
1303  all.p->lp = MULTILABEL::multilabel;
1306 
1307  return make_base(l);
1308  }
1309 }
void(* delete_prediction)(void *)
Definition: global_data.h:485
void predict(memory_tree &b, single_learner &base, example &ec)
Definition: memory_tree.cc:666
void(* delete_label)(void *)
Definition: label_parser.h:16
label_type::label_type_t label_type
Definition: global_data.h:550
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
bool quiet
Definition: global_data.h:487
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
void end_pass(memory_tree &b)
parser * p
Definition: global_data.h:377
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void save_load_memory_tree(memory_tree &b, io_buf &model_file, bool read, bool text)
void init_tree(memory_tree &b)
Definition: memory_tree.cc:281
vw_ostream trace_message
Definition: global_data.h:424
size_t numpasses
Definition: global_data.h:451
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
label_parser multilabel
Definition: multilabel.cc:118
void set_end_pass(void(*f)(T &))
Definition: learner.h:286
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
void learn(memory_tree &b, single_learner &base, example &ec)
float f
Definition: cache.cc:40
label_parser lp
Definition: parser.h:102