Vowpal Wabbit
Classes | Functions
recall_tree_ns Namespace Reference

Classes

struct  node
 
struct  node_pred
 
struct  predict_type
 
struct  recall_tree
 

Functions

float to_prob (float x)
 
void init_tree (recall_tree &b, uint32_t root, uint32_t depth, uint32_t &routers_used)
 
void init_tree (recall_tree &b)
 
node_pred * find (recall_tree &b, uint32_t cn, example &ec)
 
node_pred * find_or_create (recall_tree &b, uint32_t cn, example &ec)
 
void compute_recall_lbest (recall_tree &b, node *n)
 
double plogp (double c, double n)
 
double updated_entropy (recall_tree &b, uint32_t cn, example &ec)
 
void insert_example_at_node (recall_tree &b, uint32_t cn, example &ec)
 
void add_node_id_feature (recall_tree &b, uint32_t cn, example &ec)
 
void remove_node_id_feature (recall_tree &, uint32_t, example &ec)
 
uint32_t oas_predict (recall_tree &b, single_learner &base, uint32_t cn, example &ec)
 
bool is_candidate (recall_tree &b, uint32_t cn, example &ec)
 
uint32_t descend (node &n, float prediction)
 
bool stop_recurse_check (recall_tree &b, uint32_t parent, uint32_t child)
 
predict_type predict_from (recall_tree &b, single_learner &base, example &ec, uint32_t cn)
 
void predict (recall_tree &b, single_learner &base, example &ec)
 
float train_node (recall_tree &b, single_learner &base, example &ec, uint32_t cn)
 
void learn (recall_tree &b, single_learner &base, example &ec)
 
void save_load_tree (recall_tree &b, io_buf &model_file, bool read, bool text)
 

Function Documentation

◆ add_node_id_feature()

void recall_tree_ns::add_node_id_feature ( recall_tree & b,
uint32_t  cn,
example & ec 
)

Definition at line 221 of file recall_tree.cc.

References recall_tree_ns::recall_tree::all, example_predict::feature_space, example_predict::indices, parameters::mask(), node_id_namespace, recall_tree_ns::recall_tree::node_only, recall_tree_ns::recall_tree::nodes, v_array< T >::push_back(), features::push_back(), parameters::stride_shift(), and vw::weights.

Referenced by learn(), and oas_predict().

222 {
223  vw* all = b.all;
224  uint64_t mask = all->weights.mask();
225  size_t ss = all->weights.stride_shift();
226 
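227  ec.indices.push_back(node_id_namespace);
228  features& fs = ec.feature_space[node_id_namespace];
229 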
230  if (b.node_only)
231  {
232  fs.push_back(1., (((uint64_t)868771 * cn) << ss) & mask);
233  }
234  else
235  {
236  while (cn > 0)
237  {
238  fs.push_back(1., (((uint64_t)868771 * cn) << ss) & mask);
239  cn = b.nodes[cn].parent;
240  }
241  }
242 
243  // TODO: audit ?
244  // TODO: if namespace already exists ?
245 }
v_array< namespace_index > indices
parameters weights
Definition: global_data.h:537
void push_back(feature_value v, feature_index i)
the core definition of a set of features.
std::array< features, NUM_NAMESPACES > feature_space
void push_back(const T &new_ele)
Definition: v_array.h:107
uint32_t stride_shift()
constexpr unsigned char node_id_namespace
Definition: constant.h:31
uint64_t mask()

◆ compute_recall_lbest()

void recall_tree_ns::compute_recall_lbest ( recall_tree & b,
node * n 
)

Definition at line 150 of file recall_tree.cc.

References v_array< T >::begin(), recall_tree_ns::recall_tree::bern_hyper, v_array< T >::end(), f, recall_tree_ns::recall_tree::max_candidates, recall_tree_ns::node::n, recall_tree_ns::node::preds, and recall_tree_ns::node::recall_lbest.

Referenced by insert_example_at_node(), and save_load_tree().

151 {
152  if (n->n <= 0)
153  return;
154 
155  double mass_at_k = 0;
156 
157  for (node_pred* ls = n->preds.begin(); ls != n->preds.end() && ls < n->preds.begin() + b.max_candidates; ++ls)
158  {
159  mass_at_k += ls->label_count;
160  }
161 
162  float f = (float)mass_at_k / (float)n->n;
163  float stdf = std::sqrt(f * (1.f - f) / (float)n->n);
164  float diamf = 15.f / (std::sqrt(18.f) * (float)n->n);
165 
166  n->recall_lbest = std::max(0.f, f - std::sqrt(b.bern_hyper) * stdf - b.bern_hyper * diamf);
167 }
T *& begin()
Definition: v_array.h:42
T *& end()
Definition: v_array.h:43
v_array< node_pred > preds
Definition: recall_tree.cc:44
float f
Definition: cache.cc:40

◆ descend()

uint32_t recall_tree_ns::descend ( node & n,
float  prediction 
)
inline

Definition at line 296 of file recall_tree.cc.

References recall_tree_ns::node::left, and recall_tree_ns::node::right.

296 { return prediction < 0 ? n.left : n.right; }

◆ find()

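node_pred* recall_tree_ns::find ( recall_tree & b,
uint32_t  cn,
example & ec 
)

Definition at line 126 of file recall_tree.cc.

Referenced by find_or_create(), and updated_entropy().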

◆ find_or_create()

node_pred* recall_tree_ns::find_or_create ( recall_tree & b,
uint32_t  cn,
example & ec 
)

Definition at line 136 of file recall_tree.cc.

References v_array< T >::end(), find(), example::l, MULTICLASS::label_t::label, polylabel::multi, recall_tree_ns::recall_tree::nodes, and v_array< T >::push_back().

Referenced by insert_example_at_node().

137 {
138  node_pred* ls = find(b, cn, ec);
139 
140  if (ls == b.nodes[cn].preds.end())
141  {
142  node_pred newls(ec.l.multi.label);
143  b.nodes[cn].preds.push_back(newls);
144  ls = b.nodes[cn].preds.end() - 1;
145  }
146 
147  return ls;
148 }
MULTICLASS::label_t multi
Definition: example.h:29
void push_back(const T &new_ele)
Definition: v_array.h:107
T *& end()
Definition: v_array.h:43
polylabel l
Definition: example.h:57
node_pred * find(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:126

◆ init_tree() [1/2]

void recall_tree_ns::init_tree ( recall_tree & b,
uint32_t  root,
uint32_t  depth,
uint32_t &  routers_used 
)

Definition at line 92 of file recall_tree.cc.

References init_tree(), recall_tree_ns::recall_tree::max_depth, recall_tree_ns::recall_tree::nodes, v_array< T >::push_back(), and v_array< T >::size().

93 {
94  if (depth <= b.max_depth)
95  {
96  uint32_t left_child;
97  uint32_t right_child;
98  left_child = (uint32_t)b.nodes.size();
99  b.nodes.push_back(node());
100  right_child = (uint32_t)b.nodes.size();
101  b.nodes.push_back(node());
102  b.nodes[root].base_router = routers_used++;
103 
104  b.nodes[root].internal = true;
105  b.nodes[root].left = left_child;
106  b.nodes[left_child].parent = root;
107  b.nodes[left_child].depth = depth;
108  b.nodes[root].right = right_child;
109  b.nodes[right_child].parent = root;
110  b.nodes[right_child].depth = depth;
111 
112  init_tree(b, left_child, depth + 1, routers_used);
113  init_tree(b, right_child, depth + 1, routers_used);
114  }
115 }
size_t size() const
Definition: v_array.h:68
void push_back(const T &new_ele)
Definition: v_array.h:107
void init_tree(recall_tree &b, uint32_t root, uint32_t depth, uint32_t &routers_used)
Definition: recall_tree.cc:92

◆ init_tree() [2/2]

void recall_tree_ns::init_tree ( recall_tree & b)

Definition at line 117 of file recall_tree.cc.

References init_tree(), recall_tree_ns::recall_tree::max_routers, recall_tree_ns::recall_tree::nodes, and v_array< T >::push_back().

118 {
119  uint32_t routers_used = 0;
120 
121  b.nodes.push_back(node());
122  init_tree(b, 0, 1, routers_used);
123  b.max_routers = routers_used;
124 }
void push_back(const T &new_ele)
Definition: v_array.h:107
void init_tree(recall_tree &b, uint32_t root, uint32_t depth, uint32_t &routers_used)
Definition: recall_tree.cc:92

◆ insert_example_at_node()

void recall_tree_ns::insert_example_at_node ( recall_tree & b,
uint32_t  cn,
example & ec 
)

Definition at line 200 of file recall_tree.cc.

References v_array< T >::begin(), compute_recall_lbest(), find_or_create(), recall_tree_ns::node_pred::label_count, recall_tree_ns::recall_tree::nodes, updated_entropy(), and example::weight.

Referenced by learn().

201 {
202  node_pred* ls = find_or_create(b, cn, ec);
203 
204  b.nodes[cn].entropy = updated_entropy(b, cn, ec);
205 
206  ls->label_count += ec.weight;
207 
208  while (ls != b.nodes[cn].preds.begin() && ls[-1].label_count < ls[0].label_count)
209  {
210  std::swap(ls[-1], ls[0]);
211  --ls;
212  }
213 
214  b.nodes[cn].n += ec.weight;
215 
216  compute_recall_lbest(b, &b.nodes[cn]);
217 }
node_pred * find_or_create(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:136
T *& begin()
Definition: v_array.h:42
void compute_recall_lbest(recall_tree &b, node *n)
Definition: recall_tree.cc:150
double updated_entropy(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:171
float weight
Definition: example.h:62

◆ is_candidate()

bool recall_tree_ns::is_candidate ( recall_tree & b,
uint32_t  cn,
example & ec 
)

Definition at line 284 of file recall_tree.cc.

References v_array< T >::begin(), v_array< T >::end(), example::l, MULTICLASS::label_t::label, recall_tree_ns::recall_tree::max_candidates, polylabel::multi, and recall_tree_ns::recall_tree::nodes.

Referenced by learn().

285 {
286  for (node_pred* ls = b.nodes[cn].preds.begin();
287  ls != b.nodes[cn].preds.end() && ls < b.nodes[cn].preds.begin() + b.max_candidates; ++ls)
288  {
289  if (ls->label == ec.l.multi.label)
290  return true;
291  }
292 
293  return false;
294 }
T *& begin()
Definition: v_array.h:42
MULTICLASS::label_t multi
Definition: example.h:29
T *& end()
Definition: v_array.h:43
polylabel l
Definition: example.h:57

◆ learn()

void recall_tree_ns::learn ( recall_tree & b,
single_learner & base,
example & ec 
)

Definition at line 376 of file recall_tree.cc.

References recall_tree_ns::recall_tree::_random_state, add_node_id_feature(), recall_tree_ns::recall_tree::all, v_array< T >::begin(), descend(), v_array< T >::end(), insert_example_at_node(), is_candidate(), example::l, MULTICLASS::label_t::label, LEARNER::learner< T, E >::learn(), recall_tree_ns::recall_tree::max_candidates, recall_tree_ns::recall_tree::max_routers, label_type::mc, polylabel::multi, polyprediction::multiclass, recall_tree_ns::recall_tree::nodes, example::pred, predict(), recall_tree_ns::recall_tree::randomized_routing, remove_node_id_feature(), polylabel::simple, stop_recurse_check(), to_prob(), train_node(), and vw::training.

Referenced by recall_tree_setup().

377 {
378  predict(b, base, ec);
379 
380  if (b.all->training && ec.l.multi.label != (uint32_t)-1) // if training the tree
381  {
382  uint32_t cn = 0;
383 
384  while (b.nodes[cn].internal)
385  {
386  float which = train_node(b, base, ec, cn);
387 
388  if (b.randomized_routing)
389  which = (b._random_state->get_and_update_random() > to_prob(which) ? -1.f : 1.f);
390 
391  uint32_t newcn = descend(b.nodes[cn], which);
392  bool cond = stop_recurse_check(b, cn, newcn);
393  insert_example_at_node(b, cn, ec);
394 
395  if (cond)
396  {
397  insert_example_at_node(b, newcn, ec);
398  break;
399  }
400 
401  cn = newcn;
402  }
403 
404  if (!b.nodes[cn].internal)
405  insert_example_at_node(b, cn, ec);
406 
407  if (is_candidate(b, cn, ec))
408  {
409  MULTICLASS::label_t mc = ec.l.multi;
410  uint32_t save_pred = ec.pred.multiclass;
411 
412  add_node_id_feature(b, cn, ec);
413 
414  ec.l.simple = {1.f, 1.f, 0.f};
415  base.learn(ec, b.max_routers + mc.label - 1);
416  ec.l.simple = {-1.f, 1.f, 0.f};
417 
418  for (node_pred* ls = b.nodes[cn].preds.begin();
419  ls != b.nodes[cn].preds.end() && ls < b.nodes[cn].preds.begin() + b.max_candidates; ++ls)
420  {
421  if (ls->label != mc.label)
422  base.learn(ec, b.max_routers + ls->label - 1);
423  }
424 
425  remove_node_id_feature(b, cn, ec);
426 
427  ec.l.multi = mc;
428  ec.pred.multiclass = save_pred;
429  }
430  }
431 }
uint32_t multiclass
Definition: example.h:49
std::shared_ptr< rand_state > _random_state
Definition: recall_tree.cc:65
bool stop_recurse_check(recall_tree &b, uint32_t parent, uint32_t child)
Definition: recall_tree.cc:306
void predict(recall_tree &b, single_learner &base, example &ec)
Definition: recall_tree.cc:335
label_data simple
Definition: example.h:28
void add_node_id_feature(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:221
T *& begin()
Definition: v_array.h:42
void insert_example_at_node(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:200
bool training
Definition: global_data.h:488
float train_node(recall_tree &b, single_learner &base, example &ec, uint32_t cn)
Definition: recall_tree.cc:342
MULTICLASS::label_t multi
Definition: example.h:29
bool is_candidate(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:284
void remove_node_id_feature(recall_tree &, uint32_t, example &ec)
Definition: recall_tree.cc:247
T *& end()
Definition: v_array.h:43
float to_prob(float x)
Definition: recall_tree.cc:85
polylabel l
Definition: example.h:57
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
uint32_t descend(node &n, float prediction)
Definition: recall_tree.cc:296

◆ oas_predict()

uint32_t recall_tree_ns::oas_predict ( recall_tree & b,
single_learner & base,
uint32_t  cn,
example & ec 
)

Definition at line 254 of file recall_tree.cc.

References add_node_id_feature(), v_array< T >::begin(), v_array< T >::end(), example::l, recall_tree_ns::recall_tree::max_candidates, recall_tree_ns::recall_tree::max_routers, label_type::mc, polylabel::multi, polyprediction::multiclass, recall_tree_ns::recall_tree::nodes, example::partial_prediction, example::pred, LEARNER::learner< T, E >::predict(), remove_node_id_feature(), and polylabel::simple.

Referenced by predict_from().

255 {
256  MULTICLASS::label_t mc = ec.l.multi;
257  uint32_t save_pred = ec.pred.multiclass;
258 
259  uint32_t amaxscore = 0;
260 
261  add_node_id_feature(b, cn, ec);
262  ec.l.simple = {FLT_MAX, 0.f, 0.f};
263 
264  float maxscore = std::numeric_limits<float>::lowest();
265  for (node_pred* ls = b.nodes[cn].preds.begin();
266  ls != b.nodes[cn].preds.end() && ls < b.nodes[cn].preds.begin() + b.max_candidates; ++ls)
267  {
268  base.predict(ec, b.max_routers + ls->label - 1);
269  if (amaxscore == 0 || ec.partial_prediction > maxscore)
270  {
271  maxscore = ec.partial_prediction;
272  amaxscore = ls->label;
273  }
274  }
275 
276  remove_node_id_feature(b, cn, ec);
277 
278  ec.l.multi = mc;
279  ec.pred.multiclass = save_pred;
280 
281  return amaxscore;
282 }
uint32_t multiclass
Definition: example.h:49
void predict(E &ec, size_t i=0)
Definition: learner.h:169
float partial_prediction
Definition: example.h:68
label_data simple
Definition: example.h:28
void add_node_id_feature(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:221
T *& begin()
Definition: v_array.h:42
MULTICLASS::label_t multi
Definition: example.h:29
void remove_node_id_feature(recall_tree &, uint32_t, example &ec)
Definition: recall_tree.cc:247
T *& end()
Definition: v_array.h:43
polylabel l
Definition: example.h:57
polyprediction pred
Definition: example.h:60

◆ plogp()

double recall_tree_ns::plogp ( double  c,
double  n 
)

Definition at line 169 of file recall_tree.cc.

Referenced by updated_entropy().

169 { return (c == 0) ? 0 : (c / n) * log(c / n); }
constexpr uint64_t c
Definition: rand48.cc:12

◆ predict()

void recall_tree_ns::predict ( recall_tree & b,
single_learner & base,
example & ec 
)

Definition at line 335 of file recall_tree.cc.

References recall_tree_ns::predict_type::class_prediction, polyprediction::multiclass, example::pred, and predict_from().

Referenced by learn(), and recall_tree_setup().

336 {
337  predict_type pred = predict_from(b, base, ec, 0);
338 
339  ec.pred.multiclass = pred.class_prediction;
340 }
uint32_t multiclass
Definition: example.h:49
predict_type predict_from(recall_tree &b, single_learner &base, example &ec, uint32_t cn)
Definition: recall_tree.cc:311
polyprediction pred
Definition: example.h:60

◆ predict_from()

predict_type recall_tree_ns::predict_from ( recall_tree & b,
single_learner & base,
example & ec,
uint32_t  cn 
)

Definition at line 311 of file recall_tree.cc.

References descend(), example::l, label_type::mc, polylabel::multi, polyprediction::multiclass, recall_tree_ns::recall_tree::nodes, oas_predict(), example::partial_prediction, example::pred, LEARNER::learner< T, E >::predict(), polylabel::simple, and stop_recurse_check().

Referenced by predict().

312 {
313  MULTICLASS::label_t mc = ec.l.multi;
314  uint32_t save_pred = ec.pred.multiclass;
315 
316  ec.l.simple = {FLT_MAX, 0.f, 0.f};
317  while (b.nodes[cn].internal)
318  {
319  base.predict(ec, b.nodes[cn].base_router);
320  uint32_t newcn = descend(b.nodes[cn], ec.partial_prediction);
321  bool cond = stop_recurse_check(b, cn, newcn);
322 
323  if (cond)
324  break;
325 
326  cn = newcn;
327  }
328 
329  ec.l.multi = mc;
330  ec.pred.multiclass = save_pred;
331 
332  return predict_type(cn, oas_predict(b, base, cn, ec));
333 }
uint32_t multiclass
Definition: example.h:49
void predict(E &ec, size_t i=0)
Definition: learner.h:169
uint32_t oas_predict(recall_tree &b, single_learner &base, uint32_t cn, example &ec)
Definition: recall_tree.cc:254
bool stop_recurse_check(recall_tree &b, uint32_t parent, uint32_t child)
Definition: recall_tree.cc:306
float partial_prediction
Definition: example.h:68
label_data simple
Definition: example.h:28
MULTICLASS::label_t multi
Definition: example.h:29
polylabel l
Definition: example.h:57
polyprediction pred
Definition: example.h:60
uint32_t descend(node &n, float prediction)
Definition: recall_tree.cc:296

◆ remove_node_id_feature()

void recall_tree_ns::remove_node_id_feature ( recall_tree & ,
uint32_t  ,
example & ec 
)

Definition at line 247 of file recall_tree.cc.

References features::clear(), example_predict::feature_space, example_predict::indices, node_id_namespace, and v_array< T >::pop().

Referenced by learn(), and oas_predict().

248 {
249  features& fs = ec.feature_space[node_id_namespace];
250  fs.clear();
251  ec.indices.pop();
252 }
v_array< namespace_index > indices
T pop()
Definition: v_array.h:58
the core definition of a set of features.
std::array< features, NUM_NAMESPACES > feature_space
void clear()
constexpr unsigned char node_id_namespace
Definition: constant.h:31

◆ save_load_tree()

void recall_tree_ns::save_load_tree ( recall_tree & b,
io_buf & model_file,
bool  read,
bool  text 
)

Definition at line 433 of file recall_tree.cc.

References recall_tree_ns::node::base_router, v_array< T >::clear(), compute_recall_lbest(), recall_tree_ns::node::depth, recall_tree_ns::node::entropy, io_buf::files, recall_tree_ns::node::internal, recall_tree_ns::recall_tree::k, recall_tree_ns::node_pred::label, recall_tree_ns::node_pred::label_count, recall_tree_ns::node::left, recall_tree_ns::recall_tree::max_candidates, recall_tree_ns::recall_tree::max_depth, recall_tree_ns::node::n, recall_tree_ns::recall_tree::node_only, recall_tree_ns::recall_tree::nodes, recall_tree_ns::node::parent, recall_tree_ns::node::passes, recall_tree_ns::node::preds, v_array< T >::push_back(), recall_tree_ns::node::recall_lbest, recall_tree_ns::node::right, v_array< T >::size(), writeit, and writeitvar.

434 {
435  if (model_file.files.size() > 0)
436  {
437  std::stringstream msg;
438 
439  writeit(b.k, "k");
440  writeit(b.node_only, "node_only");
441  writeitvar(b.nodes.size(), "nodes", n_nodes);
442 
443  if (read)
444  {
445  b.nodes.clear();
446  for (uint32_t j = 0; j < n_nodes; ++j)
447  {
448  b.nodes.push_back(node());
449  }
450  }
451 
452  writeit(b.max_candidates, "max_candidates");
453  writeit(b.max_depth, "max_depth");
454 
455  for (uint32_t j = 0; j < n_nodes; ++j)
456  {
457  node* cn = &b.nodes[j];
458 
459  writeit(cn->parent, "parent");
460  writeit(cn->recall_lbest, "recall_lbest");
461  writeit(cn->internal, "internal");
462  writeit(cn->depth, "depth");
463  writeit(cn->base_router, "base_router");
464  writeit(cn->left, "left");
465  writeit(cn->right, "right");
466  writeit(cn->n, "n");
467  writeit(cn->entropy, "entropy");
468  writeit(cn->passes, "passes");
469 
470  writeitvar(cn->preds.size(), "n_preds", n_preds);
471 
472  if (read)
473  {
474  cn->preds.clear();
475 
476  for (uint32_t k = 0; k < n_preds; ++k)
477  {
478  cn->preds.push_back(node_pred(0));
479  }
480  }
481 
482  for (uint32_t k = 0; k < n_preds; ++k)
483  {
484  node_pred* pred = &cn->preds[k];
485 
486  writeit(pred->label, "label");
487  writeit(pred->label_count, "label_count");
488  }
489 
490  if (read)
491  {
492  compute_recall_lbest(b, cn);
493  }
494  }
495  }
496 }
size_t size() const
Definition: v_array.h:68
#define writeit(what, str)
Definition: io_buf.h:349
void compute_recall_lbest(recall_tree &b, node *n)
Definition: recall_tree.cc:150
void push_back(const T &new_ele)
Definition: v_array.h:107
v_array< int > files
Definition: io_buf.h:64
void clear()
Definition: v_array.h:88
v_array< node_pred > preds
Definition: recall_tree.cc:44
#define writeitvar(what, str, mywhat)
Definition: io_buf.h:356

◆ stop_recurse_check()

bool recall_tree_ns::stop_recurse_check ( recall_tree & b,
uint32_t  parent,
uint32_t  child 
)

Definition at line 306 of file recall_tree.cc.

References recall_tree_ns::recall_tree::bern_hyper, and recall_tree_ns::recall_tree::nodes.

Referenced by learn(), and predict_from().

307 {
308  return b.bern_hyper > 0 && b.nodes[parent].recall_lbest >= b.nodes[child].recall_lbest;
309 }

◆ to_prob()

float recall_tree_ns::to_prob ( float  x)

Definition at line 85 of file recall_tree.cc.

References f.

Referenced by learn().

86 {
87  static const float alpha = 2.0f;
88  // http://stackoverflow.com/questions/2789481/problem-calling-stdmax
89  return std::max(0.f, std::min(1.f, 0.5f * (1.0f + alpha * x)));
90 }
float f
Definition: cache.cc:40

◆ train_node()

float recall_tree_ns::train_node ( recall_tree & b,
single_learner & base,
example & ec,
uint32_t  cn 
)

Definition at line 342 of file recall_tree.cc.

References example::l, LEARNER::learner< T, E >::learn(), label_type::mc, polylabel::multi, polyprediction::multiclass, recall_tree_ns::recall_tree::nodes, example::pred, LEARNER::learner< T, E >::predict(), polyprediction::scalar, polylabel::simple, updated_entropy(), and MULTICLASS::label_t::weight.

343 {
344  MULTICLASS::label_t mc = ec.l.multi;
345  uint32_t save_pred = ec.pred.multiclass;
346 
347  // minimize entropy
348  // better than maximize expected likelihood, and the proofs go through :)
349  double new_left = updated_entropy(b, b.nodes[cn].left, ec);
350  double new_right = updated_entropy(b, b.nodes[cn].right, ec);
351  double old_left = b.nodes[b.nodes[cn].left].entropy;
352  double old_right = b.nodes[b.nodes[cn].right].entropy;
353  double nl = b.nodes[b.nodes[cn].left].n;
354  double nr = b.nodes[b.nodes[cn].right].n;
355  double delta_left = nl * (new_left - old_left) + mc.weight * new_left;
356  double delta_right = nr * (new_right - old_right) + mc.weight * new_right;
357  float route_label = delta_left < delta_right ? -1.f : 1.f;
358  float imp_weight = fabs((float)(delta_left - delta_right));
359 
360  ec.l.simple = {route_label, imp_weight, 0.};
361  base.learn(ec, b.nodes[cn].base_router);
362 
363  // TODO: using the updated routing seems to help
364  // TODO: consider faster version using updated_prediction
365  // TODO: (doesn't play well with link function)
366  base.predict(ec, b.nodes[cn].base_router);
367 
368  float save_scalar = ec.pred.scalar;
369 
370  ec.l.multi = mc;
371  ec.pred.multiclass = save_pred;
372 
373  return save_scalar;
374 }
uint32_t multiclass
Definition: example.h:49
void predict(E &ec, size_t i=0)
Definition: learner.h:169
float scalar
Definition: example.h:45
label_data simple
Definition: example.h:28
MULTICLASS::label_t multi
Definition: example.h:29
polylabel l
Definition: example.h:57
double updated_entropy(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:171
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160

◆ updated_entropy()

double recall_tree_ns::updated_entropy ( recall_tree & b,
uint32_t  cn,
example & ec 
)

Definition at line 171 of file recall_tree.cc.

References v_array< T >::end(), find(), recall_tree_ns::node_pred::label_count, recall_tree_ns::recall_tree::nodes, plogp(), and example::weight.

Referenced by insert_example_at_node(), and train_node().

172 {
173  node_pred* ls = find(b, cn, ec);
174 
175  // entropy = -\sum_k (c_k/n) Log[c_k/n]
176  // c_0 <- c_0 + 1, n <- n + 1
177  // entropy <- + (c_0/n) Log[c_0/n]
178  // - n/(n+1) \sum_{k>0} (c_k/n) Log[c_k/n]
179  // - Log[n/(n+1)] \sum_{k>0} (c_k/(n+1))
180  // - ((c_0+1)/(n+1)) Log[(c_0+1)/(n+1)]
181 
182  double c0 = (ls == b.nodes[cn].preds.end()) ? 0 : ls->label_count;
183  double deltac0 = ec.weight;
184  double n = b.nodes[cn].n;
185 
186  double novernp1 = n / (deltac0 + n);
187  double lognovernp1 = (novernp1 == 0) ? 0 : log(novernp1);
188  double nminusc0overnp1 = (n - c0) / (n + deltac0);
189 
190  double newentropy = b.nodes[cn].entropy;
191 
192  newentropy += plogp(c0, n);
193  newentropy *= novernp1;
194  newentropy -= lognovernp1 * nminusc0overnp1;
195  newentropy -= plogp(c0 + deltac0, n + deltac0);
196 
197  return newentropy;
198 }
double plogp(double c, double n)
Definition: recall_tree.cc:169
T *& end()
Definition: v_array.h:43
node_pred * find(recall_tree &b, uint32_t cn, example &ec)
Definition: recall_tree.cc:126
float weight
Definition: example.h:62