Vowpal Wabbit
recall_tree.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5 */
6 #include <algorithm>
7 #include <cmath>
8 #include <cstdio>
9 #include <float.h>
10 #include <sstream>
11 #include <memory>
12 
13 #include "reductions.h"
14 #include "rand48.h"
15 
16 using namespace LEARNER;
17 using namespace VW::config;
18 
19 namespace recall_tree_ns
20 {
// Per-node statistics for a single class label: the label id plus the
// importance-weighted count of examples carrying that label seen at the node.
struct node_pred
{
  uint32_t label;          // class label this entry tracks
  double label_count = 0;  // weighted number of examples observed with `label`

  node_pred(uint32_t a) : label(a) {}
};
28 
// One vertex of the recall tree.  A leaf holds per-label statistics in
// `preds`; an internal node additionally owns a binary router (base_router).
29 struct node
30 {
31  uint32_t parent;  // index of the parent in recall_tree::nodes (root's parent stays 0)
32  float recall_lbest;  // lower bound on recall of the top candidates at this node (see compute_recall_lbest)
33 
34  bool internal;  // true once this node has been split (has children)
35  uint32_t depth;  // distance from the root; children of the root are depth 1
36 
37  uint32_t base_router;  // learner index of this node's routing classifier
38  uint32_t left;  // index of the left child in recall_tree::nodes
39  uint32_t right;  // index of the right child in recall_tree::nodes
40  double n;  // total importance weight of examples inserted at this node
41  double entropy;  // running label entropy of examples at this node
42  double passes;  // pass counter (serialized in save_load_tree; starts at 1)
43 
// NOTE(review): the declaration of `preds` (original line 44) was dropped by
// this extraction — per the constructor below it is a v_array<node_pred>;
// confirm against the repository source.
45 
46  node()
47  : parent(0)
48  , recall_lbest(0)
49  , internal(false)
50  , depth(0)
51  , base_router(0)
52  , left(0)
53  , right(0)
54  , n(0)
55  , entropy(0)
56  , passes(1)
57  , preds(v_init<node_pred>())
58  {
59  }
60 };
61 
63 {
64  vw* all;
65  std::shared_ptr<rand_state> _random_state;
66  uint32_t k;
67  bool node_only;
68 
70 
72  size_t max_routers;
73  size_t max_depth;
74  float bern_hyper;
75 
77 
79  {
80  for (auto& node : nodes) node.preds.delete_v();
81  nodes.delete_v();
82  }
83 };
84 
// Squash a raw router score into [0, 1] so it can be used as a routing
// probability: apply the affine map 0.5 * (1 + alpha * x), then clamp.
float to_prob(float x)
{
  static const float alpha = 2.0f;
  // http://stackoverflow.com/questions/2789481/problem-calling-stdmax
  float scaled = 0.5f * (1.0f + alpha * x);
  if (scaled < 0.f) return 0.f;
  if (scaled > 1.f) return 1.f;
  return scaled;
}
91 
92 void init_tree(recall_tree& b, uint32_t root, uint32_t depth, uint32_t& routers_used)
93 {
94  if (depth <= b.max_depth)
95  {
96  uint32_t left_child;
97  uint32_t right_child;
98  left_child = (uint32_t)b.nodes.size();
99  b.nodes.push_back(node());
100  right_child = (uint32_t)b.nodes.size();
101  b.nodes.push_back(node());
102  b.nodes[root].base_router = routers_used++;
103 
104  b.nodes[root].internal = true;
105  b.nodes[root].left = left_child;
106  b.nodes[left_child].parent = root;
107  b.nodes[left_child].depth = depth;
108  b.nodes[root].right = right_child;
109  b.nodes[right_child].parent = root;
110  b.nodes[right_child].depth = depth;
111 
112  init_tree(b, left_child, depth + 1, routers_used);
113  init_tree(b, right_child, depth + 1, routers_used);
114  }
115 }
116 
// Build the full tree from an empty node list: push the root node, recurse
// from it, and record how many routers (internal nodes) were allocated.
// NOTE(review): the signature line (original line 117) is missing from this
// extraction — presumably `void init_tree(recall_tree& b)`; confirm.
118 {
119  uint32_t routers_used = 0;
120 
121  b.nodes.push_back(node());  // node 0 is the root
122  init_tree(b, 0, 1, routers_used);
123  b.max_routers = routers_used;  // one binary router per internal node
124 }
125 
126 node_pred* find(recall_tree& b, uint32_t cn, example& ec)
127 {
128  node_pred* ls;
129 
130  for (ls = b.nodes[cn].preds.begin(); ls != b.nodes[cn].preds.end() && ls->label != ec.l.multi.label; ++ls)
131  ;
132 
133  return ls;
134 }
135 
// Return the stats entry for this example's label at node `cn`, creating a
// zero-count entry if the label has not been seen at this node before.
// NOTE(review): the signature line (original line 136) is missing — per the
// Doxygen index it is `node_pred* find_or_create(recall_tree&, uint32_t, example&)`.
137 {
138  node_pred* ls = find(b, cn, ec);
139 
140  if (ls == b.nodes[cn].preds.end())
141  {
// label not present yet: append a fresh entry and point at it
142  node_pred newls(ec.l.multi.label);
143  b.nodes[cn].preds.push_back(newls);
144  ls = b.nodes[cn].preds.end() - 1;
145  }
146 
147  return ls;
148 }
149 
// Recompute n->recall_lbest: a confidence lower bound on the recall achieved
// by predicting only the top b.max_candidates labels at this node, penalized
// by b.bern_hyper (appears to be a Bernstein-style bound — confirm against
// the recall tree paper).
// NOTE(review): the signature line (original line 150) is missing — per the
// Doxygen index it is `void compute_recall_lbest(recall_tree& b, node* n)`.
151 {
152  if (n->n <= 0)
153  return;
154 
155  double mass_at_k = 0;
156 
// preds is kept sorted by descending label_count, so the first
// max_candidates entries are the most frequent labels at this node
157  for (node_pred* ls = n->preds.begin(); ls != n->preds.end() && ls < n->preds.begin() + b.max_candidates; ++ls)
158  {
159  mass_at_k += ls->label_count;
160  }
161 
162  float f = (float)mass_at_k / (float)n->n;  // empirical recall of the candidate set
163  float stdf = std::sqrt(f * (1.f - f) / (float)n->n);  // std. dev. of the empirical estimate
164  float diamf = 15.f / (std::sqrt(18.f) * (float)n->n);  // finite-sample correction term
165 
166  n->recall_lbest = std::max(0.f, f - std::sqrt(b.bern_hyper) * stdf - b.bern_hyper * diamf);
167 }
168 
// Entropy contribution p * log(p) with p = c / n; defined as 0 when c == 0.
double plogp(double c, double n)
{
  if (c == 0) return 0;
  double p = c / n;
  return p * log(p);
}
170 
// Compute what node cn's label entropy would become if this example (with
// importance weight ec.weight) were inserted, WITHOUT mutating the node.
// Used both to score candidate routes (train_node) and to apply the update
// (insert_example_at_node).  The incremental formula is derived below.
171 double updated_entropy(recall_tree& b, uint32_t cn, example& ec)
172 {
173  node_pred* ls = find(b, cn, ec);
174 
175  // entropy = -\sum_k (c_k/n) Log[c_k/n]
176  // c_0 <- c_0 + 1, n <- n + 1
177  // entropy <- + (c_0/n) Log[c_0/n]
178  // - n/(n+1) \sum_{k>0} (c_k/n) Log[c_k/n]
179  // - Log[n/(n+1)] \sum_{k>0} (c_k/(n+1))
180  // - ((c_0+1)/(n+1)) Log[(c_0+1)/(n+1)]
181 
// c0: current count for this example's label (0 if unseen at this node)
182  double c0 = (ls == b.nodes[cn].preds.end()) ? 0 : ls->label_count;
183  double deltac0 = ec.weight;
184  double n = b.nodes[cn].n;
185 
186  double novernp1 = n / (deltac0 + n);
187  double lognovernp1 = (novernp1 == 0) ? 0 : log(novernp1);  // guard: n == 0 on an empty node
188  double nminusc0overnp1 = (n - c0) / (n + deltac0);
189 
190  double newentropy = b.nodes[cn].entropy;
191 
// remove the old contribution of label 0, rescale the rest to the new
// total, then add back label 0's new contribution
192  newentropy += plogp(c0, n);
193  newentropy *= novernp1;
194  newentropy -= lognovernp1 * nminusc0overnp1;
195  newentropy -= plogp(c0 + deltac0, n + deltac0);
196 
197  return newentropy;
198 }
199 
// Record this example at node cn: refresh the node entropy, bump the label's
// weighted count, keep preds sorted by descending count, bump the node total,
// and recompute the node's recall lower bound.
200 void insert_example_at_node(recall_tree& b, uint32_t cn, example& ec)
201 {
202  node_pred* ls = find_or_create(b, cn, ec);
203 
// entropy must be updated BEFORE the counts below — updated_entropy reads
// the pre-insertion label_count and n
204  b.nodes[cn].entropy = updated_entropy(b, cn, ec);
205 
206  ls->label_count += ec.weight;
207 
// bubble the entry toward the front to restore descending-count order
208  while (ls != b.nodes[cn].preds.begin() && ls[-1].label_count < ls[0].label_count)
209  {
210  std::swap(ls[-1], ls[0]);
211  --ls;
212  }
213 
214  b.nodes[cn].n += ec.weight;
215 
216  compute_recall_lbest(b, &b.nodes[cn]);
217 }
218 
219 // TODO: handle if features already in this namespace
220 
// Inject node-identity features into the example so per-class one-against-some
// predictors can condition on where the example landed in the tree.  With
// node_only, just the current node; otherwise every node on the root path.
221 void add_node_id_feature(recall_tree& b, uint32_t cn, example& ec)
222 {
223  vw* all = b.all;
224  uint64_t mask = all->weights.mask();
225  size_t ss = all->weights.stride_shift();
226 
// NOTE(review): lines 227-228 were dropped by this extraction — they
// presumably push node_id_namespace onto ec.indices and bind
// `features& fs = ec.feature_space[node_id_namespace];` (fs is used below
// and torn down in remove_node_id_feature); confirm against the repo.
229 
230  if (b.node_only)
231  {
// 868771 acts as a multiplicative hash seed mapping node id -> feature index
232  fs.push_back(1., (((uint64_t)868771 * cn) << ss) & mask);
233  }
234  else
235  {
// walk up to the root, emitting one feature per ancestor node
236  while (cn > 0)
237  {
238  fs.push_back(1., (((uint64_t)868771 * cn) << ss) & mask);
239  cn = b.nodes[cn].parent;
240  }
241  }
242 
243  // TODO: audit ?
244  // TODO: if namespace already exists ?
245 }
246 
// Undo add_node_id_feature: clear the node-id features and pop the namespace
// index that was pushed onto the example.
// NOTE(review): line 249 was dropped by this extraction — presumably
// `features& fs = ec.feature_space[node_id_namespace];`; confirm.
247 void remove_node_id_feature(recall_tree& /* b */, uint32_t /* cn */, example& ec)
248 {
250  fs.clear();
251  ec.indices.pop();
252 }
253 
// One-against-some prediction at node cn: score each of the node's top
// candidate labels with its per-class predictor (node-id features attached)
// and return the argmax label.  Restores the example's label/prediction state
// before returning.
254 uint32_t oas_predict(recall_tree& b, single_learner& base, uint32_t cn, example& ec)
255 {
// NOTE(review): line 256 was dropped by this extraction — presumably
// `MULTICLASS::label_t mc = ec.l.multi;` (mc is restored at the end); confirm.
257  uint32_t save_pred = ec.pred.multiclass;
258 
259  uint32_t amaxscore = 0;  // 0 means "no candidate scored yet"
260 
261  add_node_id_feature(b, cn, ec);
// temporarily switch the example to a simple (regression) label
262  ec.l.simple = {FLT_MAX, 0.f, 0.f};
263 
264  float maxscore = std::numeric_limits<float>::lowest();
// only the first max_candidates entries (highest counts) are scored
265  for (node_pred* ls = b.nodes[cn].preds.begin();
266  ls != b.nodes[cn].preds.end() && ls < b.nodes[cn].preds.begin() + b.max_candidates; ++ls)
267  {
// per-class predictors live at learner offsets max_routers + label - 1
268  base.predict(ec, b.max_routers + ls->label - 1);
269  if (amaxscore == 0 || ec.partial_prediction > maxscore)
270  {
271  maxscore = ec.partial_prediction;
272  amaxscore = ls->label;
273  }
274  }
275 
276  remove_node_id_feature(b, cn, ec);
277 
// restore the multiclass label and prediction we clobbered above
278  ec.l.multi = mc;
279  ec.pred.multiclass = save_pred;
280 
281  return amaxscore;
282 }
283 
284 bool is_candidate(recall_tree& b, uint32_t cn, example& ec)
285 {
286  for (node_pred* ls = b.nodes[cn].preds.begin();
287  ls != b.nodes[cn].preds.end() && ls < b.nodes[cn].preds.begin() + b.max_candidates; ++ls)
288  {
289  if (ls->label == ec.l.multi.label)
290  return true;
291  }
292 
293  return false;
294 }
295 
296 inline uint32_t descend(node& n, float prediction) { return prediction < 0 ? n.left : n.right; }
297 
// Result of a tree descent: the node reached plus the class predicted there.
// NOTE(review): the `struct predict_type` header (original line 298) and the
// `class_prediction` member declaration (line 301, implied by the constructor)
// were dropped by this extraction — confirm against the repo.
299 {
300  uint32_t node_id;  // index of the node where descent stopped
302 
303  predict_type(uint32_t a, uint32_t b) : node_id(a), class_prediction(b) {}
304 };
305 
306 bool stop_recurse_check(recall_tree& b, uint32_t parent, uint32_t child)
307 {
308  return b.bern_hyper > 0 && b.nodes[parent].recall_lbest >= b.nodes[child].recall_lbest;
309 }
310 
// Descend the tree from node cn using the routers, stopping early when the
// recall bound says a child would not improve, then run one-against-some
// prediction at the final node.  Restores the example's state afterwards.
// NOTE(review): the signature (original line 311,
// `predict_type predict_from(recall_tree&, single_learner&, example&, uint32_t)`
// per the Doxygen index) and line 313 (presumably
// `MULTICLASS::label_t mc = ec.l.multi;`) were dropped by this extraction.
312 {
314  uint32_t save_pred = ec.pred.multiclass;
315 
// routers operate on a simple label; swap it in temporarily
316  ec.l.simple = {FLT_MAX, 0.f, 0.f};
317  while (b.nodes[cn].internal)
318  {
319  base.predict(ec, b.nodes[cn].base_router);
320  uint32_t newcn = descend(b.nodes[cn], ec.partial_prediction);
321  bool cond = stop_recurse_check(b, cn, newcn);
322 
323  if (cond)
324  break;
325 
326  cn = newcn;
327  }
328 
// restore the multiclass label and prediction we clobbered above
329  ec.l.multi = mc;
330  ec.pred.multiclass = save_pred;
331 
332  return predict_type(cn, oas_predict(b, base, cn, ec));
333 }
334 
// Reduction predict entry point: descend from the root (node 0) and store the
// resulting class in the example's multiclass prediction slot.
// NOTE(review): the signature (original line 335,
// `void predict(recall_tree& b, single_learner& base, example& ec)` per the
// Doxygen index) was dropped by this extraction.
336 {
337  predict_type pred = predict_from(b, base, ec, 0);
338 
339  ec.pred.multiclass = pred.class_prediction;
340 }
341 
342 float train_node(recall_tree& b, single_learner& base, example& ec, uint32_t cn)
343 {
345  uint32_t save_pred = ec.pred.multiclass;
346 
347  // minimize entropy
348  // better than maximize expected likelihood, and the proofs go through :)
349  double new_left = updated_entropy(b, b.nodes[cn].left, ec);
350  double new_right = updated_entropy(b, b.nodes[cn].right, ec);
351  double old_left = b.nodes[b.nodes[cn].left].entropy;
352  double old_right = b.nodes[b.nodes[cn].right].entropy;
353  double nl = b.nodes[b.nodes[cn].left].n;
354  double nr = b.nodes[b.nodes[cn].right].n;
355  double delta_left = nl * (new_left - old_left) + mc.weight * new_left;
356  double delta_right = nr * (new_right - old_right) + mc.weight * new_right;
357  float route_label = delta_left < delta_right ? -1.f : 1.f;
358  float imp_weight = fabs((float)(delta_left - delta_right));
359 
360  ec.l.simple = {route_label, imp_weight, 0.};
361  base.learn(ec, b.nodes[cn].base_router);
362 
363  // TODO: using the updated routing seems to help
364  // TODO: consider faster version using updated_prediction
365  // TODO: (doesn't play well with link function)
366  base.predict(ec, b.nodes[cn].base_router);
367 
368  float save_scalar = ec.pred.scalar;
369 
370  ec.l.multi = mc;
371  ec.pred.multiclass = save_pred;
372 
373  return save_scalar;
374 }
375 
// Reduction learn entry point: predict first (so progressive loss reflects
// pre-update state), then — when training and the label is valid — train
// routers along the descent path, insert the example's statistics at each
// visited node, and finally run one-against-some updates at the stopping node.
// NOTE(review): the signature (original line 376,
// `void learn(recall_tree& b, single_learner& base, example& ec)` per the
// Doxygen index) and line 409 (presumably `MULTICLASS::label_t mc = ec.l.multi;`)
// were dropped by this extraction.
377 {
378  predict(b, base, ec);
379 
380  if (b.all->training && ec.l.multi.label != (uint32_t)-1) // if training the tree
381  {
382  uint32_t cn = 0;
383 
384  while (b.nodes[cn].internal)
385  {
386  float which = train_node(b, base, ec, cn);
387 
// optionally sample the route from the router's squashed score instead of
// following it deterministically
388  if (b.randomized_routing)
389  which = (b._random_state->get_and_update_random() > to_prob(which) ? -1.f : 1.f);
390 
391  uint32_t newcn = descend(b.nodes[cn], which);
392  bool cond = stop_recurse_check(b, cn, newcn);
393  insert_example_at_node(b, cn, ec);
394 
395  if (cond)
396  {
// stopping early: still credit the child we would have entered
397  insert_example_at_node(b, newcn, ec);
398  break;
399  }
400 
401  cn = newcn;
402  }
403 
404  if (!b.nodes[cn].internal)
405  insert_example_at_node(b, cn, ec);
406 
// one-against-some: only update per-class predictors when the true label is
// among this node's candidates
407  if (is_candidate(b, cn, ec))
408  {
410  uint32_t save_pred = ec.pred.multiclass;
411 
412  add_node_id_feature(b, cn, ec);
413 
// positive update for the true label...
414  ec.l.simple = {1.f, 1.f, 0.f};
415  base.learn(ec, b.max_routers + mc.label - 1);
// ...negative updates for every other candidate at this node
416  ec.l.simple = {-1.f, 1.f, 0.f};
417 
418  for (node_pred* ls = b.nodes[cn].preds.begin();
419  ls != b.nodes[cn].preds.end() && ls < b.nodes[cn].preds.begin() + b.max_candidates; ++ls)
420  {
421  if (ls->label != mc.label)
422  base.learn(ec, b.max_routers + ls->label - 1);
423  }
424 
425  remove_node_id_feature(b, cn, ec);
426 
// restore the multiclass label and prediction we clobbered above
427  ec.l.multi = mc;
428  ec.pred.multiclass = save_pred;
429  }
430  }
431 }
432 
// Serialize/deserialize the tree structure and per-node label statistics via
// the writeit/writeitvar macros (which both read and write depending on
// `read`).  The field order here defines the model format — do not reorder.
433 void save_load_tree(recall_tree& b, io_buf& model_file, bool read, bool text)
434 {
435  if (model_file.files.size() > 0)
436  {
437  std::stringstream msg;
438 
439  writeit(b.k, "k");
440  writeit(b.node_only, "node_only");
441  writeitvar(b.nodes.size(), "nodes", n_nodes);
442 
// on read, rebuild the node array at the serialized size before filling it
443  if (read)
444  {
445  b.nodes.clear();
446  for (uint32_t j = 0; j < n_nodes; ++j)
447  {
448  b.nodes.push_back(node());
449  }
450  }
451 
452  writeit(b.max_candidates, "max_candidates");
453  writeit(b.max_depth, "max_depth");
454 
455  for (uint32_t j = 0; j < n_nodes; ++j)
456  {
457  node* cn = &b.nodes[j];
458 
459  writeit(cn->parent, "parent");
460  writeit(cn->recall_lbest, "recall_lbest");
461  writeit(cn->internal, "internal");
462  writeit(cn->depth, "depth");
463  writeit(cn->base_router, "base_router");
464  writeit(cn->left, "left");
465  writeit(cn->right, "right");
466  writeit(cn->n, "n");
467  writeit(cn->entropy, "entropy");
468  writeit(cn->passes, "passes");
469 
470  writeitvar(cn->preds.size(), "n_preds", n_preds);
471 
// on read, pre-size this node's preds array with placeholder entries
472  if (read)
473  {
474  cn->preds.clear();
475 
476  for (uint32_t k = 0; k < n_preds; ++k)
477  {
478  cn->preds.push_back(node_pred(0));
479  }
480  }
481 
482  for (uint32_t k = 0; k < n_preds; ++k)
483  {
484  node_pred* pred = &cn->preds[k];
485 
486  writeit(pred->label, "label");
487  writeit(pred->label_count, "label_count");
488  }
489 
// recall_lbest is derived from counts, so recompute it after loading
490  if (read)
491  {
492  compute_recall_lbest(b, cn);
493  }
494  }
495  }
496 }
497 
498 } // namespace recall_tree_ns
499 
500 using namespace recall_tree_ns;
501 
// Reduction setup: parse the recall_tree options, derive defaults for
// max_candidates and max_depth from k, build the tree, and wire up the
// multiclass learner over max_routers + k weight slots.
// NOTE(review): the signature (original line 502,
// `base_learner* recall_tree_setup(options_i& options, vw& all)` per the
// Doxygen index) and lines 536-538 (the init_multiclass_learner call binding
// `l` plus set_save_load(save_load_tree)) were dropped by this extraction.
503 {
504  auto tree = scoped_calloc_or_throw<recall_tree>();
505  option_group_definition new_options("Recall Tree");
506  new_options.add(make_option("recall_tree", tree->k).keep().help("Use online tree for multiclass"))
507  .add(make_option("max_candidates", tree->max_candidates)
508  .keep()
509  .help("maximum number of labels per leaf in the tree"))
510  .add(make_option("bern_hyper", tree->bern_hyper).default_value(1.f).help("recall tree depth penalty"))
511  .add(make_option("max_depth", tree->max_depth).keep().help("maximum depth of the tree, default log_2 (#classes)"))
512  .add(make_option("node_only", tree->node_only).keep().help("only use node features, not full path features"))
513  .add(make_option("randomized_routing", tree->randomized_routing).keep().help("randomized routing"));
514  options.add_and_parse(new_options);
515 
// reduction not requested: hand control back to the stack
516  if (!options.was_supplied("recall_tree"))
517  return nullptr;
518 
519  tree->all = &all;
520  tree->_random_state = all.get_random_state();
// defaults derived from k: 4*ceil(log2 k) candidates (capped at k), log2 k depth
521  tree->max_candidates = options.was_supplied("max_candidates")
522  ? tree->max_candidates
523  : std::min(tree->k, 4 * (uint32_t)(ceil(log(tree->k) / log(2.0))));
524  tree->max_depth =
525  options.was_supplied("max_depth") ? tree->max_depth : (uint32_t)std::ceil(std::log(tree->k) / std::log(2.0));
526 
527  init_tree(*tree.get());
528 
529  if (!all.quiet)
530  all.trace_message << "recall_tree:"
531  << " node_only = " << tree->node_only << " bern_hyper = " << tree->bern_hyper
532  << " max_depth = " << tree->max_depth << " routing = "
533  << (all.training ? (tree->randomized_routing ? "randomized" : "deterministic") : "n/a testonly")
534  << std::endl;
535 
537  tree, as_singleline(setup_base(options, all)), learn, predict, all.p, tree->max_routers + tree->k);
539 
540  return make_base(l);
541 }
v_array< namespace_index > indices
uint32_t multiclass
Definition: example.h:49
parameters weights
Definition: global_data.h:537
std::shared_ptr< rand_state > _random_state
Definition: recall_tree.cc:65
void predict(E &ec, size_t i=0)
Definition: learner.h:169
uint32_t oas_predict(recall_tree &b, single_learner &base, uint32_t cn, example &ec)
Definition: recall_tree.cc:254
T pop()
Definition: v_array.h:58
void push_back(feature_value v, feature_index i)
float scalar
Definition: example.h:45
bool stop_recurse_check(recall_tree &b, uint32_t parent, uint32_t child)
Definition: recall_tree.cc:306
the core definition of a set of features.
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
float partial_prediction
Definition: example.h:68
bool quiet
Definition: global_data.h:487
void predict(recall_tree &b, single_learner &base, example &ec)
Definition: recall_tree.cc:335
predict_type(uint32_t a, uint32_t b)
Definition: recall_tree.cc:303
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
label_data simple
Definition: example.h:28
node_pred * find_or_create(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:136
void add_node_id_feature(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:221
T *& begin()
Definition: v_array.h:42
void insert_example_at_node(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:200
bool training
Definition: global_data.h:488
size_t size() const
Definition: v_array.h:68
#define writeit(what, str)
Definition: io_buf.h:349
void train_node(log_multi &b, single_learner &base, example &ec, uint32_t &current, uint32_t &class_index, uint32_t)
Definition: log_multi.cc:251
parser * p
Definition: global_data.h:377
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
std::array< features, NUM_NAMESPACES > feature_space
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
MULTICLASS::label_t multi
Definition: example.h:29
void learn(recall_tree &b, single_learner &base, example &ec)
Definition: recall_tree.cc:376
void compute_recall_lbest(recall_tree &b, node *n)
Definition: recall_tree.cc:150
v_array< T > v_init()
Definition: v_array.h:179
void push_back(const T &new_ele)
Definition: v_array.h:107
v_array< int > files
Definition: io_buf.h:64
void clear()
Definition: v_array.h:88
vw_ostream trace_message
Definition: global_data.h:424
virtual bool was_supplied(const std::string &key)=0
base_learner * recall_tree_setup(options_i &options, vw &all)
Definition: recall_tree.cc:502
void clear()
bool is_candidate(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:284
double plogp(double c, double n)
Definition: recall_tree.cc:169
Definition: io_buf.h:54
void remove_node_id_feature(recall_tree &, uint32_t, example &ec)
Definition: recall_tree.cc:247
T *& end()
Definition: v_array.h:43
float to_prob(float x)
Definition: recall_tree.cc:85
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
predict_type predict_from(recall_tree &b, single_learner &base, example &ec, uint32_t cn)
Definition: recall_tree.cc:311
polylabel l
Definition: example.h:57
constexpr uint64_t a
Definition: rand48.cc:11
node_pred * find(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:126
void save_load_tree(log_multi &b, io_buf &model_file, bool read, bool text)
Definition: log_multi.cc:397
double updated_entropy(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:171
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void init_tree(log_multi &d)
Definition: log_multi.cc:122
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
uint32_t stride_shift()
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
polyprediction pred
Definition: example.h:60
void delete_v()
Definition: v_array.h:98
void learn(E &ec, size_t i=0)
Definition: learner.h:160
constexpr unsigned char node_id_namespace
Definition: constant.h:31
float weight
Definition: example.h:62
uint64_t mask()
constexpr uint64_t c
Definition: rand48.cc:12
v_array< node_pred > preds
Definition: recall_tree.cc:44
uint32_t descend(node &n, float prediction)
Definition: log_multi.cc:296
float f
Definition: cache.cc:40
#define writeitvar(what, str, mywhat)
Definition: io_buf.h:356