Vowpal Wabbit
log_multi.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5 */
6 #include <float.h>
7 #include <math.h>
8 #include <stdio.h>
9 #include <sstream>
10 
11 #include "reductions.h"
12 
13 using namespace LEARNER;
14 using namespace VW::config;
15 
// Per-class statistics kept at each tree node. Entries are compared by `label`
// only, which lets a node's `preds` array be kept sorted and de-duplicated by
// label (children() inserts through unique_add_sorted).
class node_pred
{
 public:
  double Ehk;            // sum of base-learner partial predictions for this label at this node
  float norm_Ehk;        // Ehk / nk: average prediction for this label at this node
  uint32_t nk;           // number of training updates made for this label at this node
  uint32_t label;        // the class label these statistics belong to
  uint32_t label_count;  // number of examples with this label that passed through this node

  // Comparisons look only at `label`; the statistics are ignored. They are
  // const and take a const reference so const node_pred values can be compared.
  bool operator==(const node_pred& v) const { return label == v.label; }
  bool operator>(const node_pred& v) const { return label > v.label; }
  bool operator<(const node_pred& v) const { return label < v.label; }

  // Empty statistics for label `l`.
  node_pred(uint32_t l) : Ehk(0.), norm_Ehk(0.f), nk(0), label(l), label_count(0) {}
};
50 
51 typedef struct
52 {
53  // everyone has
54  uint32_t parent; // the parent node
55  v_array<node_pred> preds; // per-class state
56  uint32_t
57  min_count; // the number of examples reaching this node (if it's a leaf) or the minimum reaching any grandchild.
58 
59  bool internal; // internal or leaf
60 
61  // internal nodes have
62  uint32_t base_predictor; // id of the base predictor
63  uint32_t left; // left child
64  uint32_t right; // right child
65  float norm_Eh; // the average margin at the node
66  double Eh; // total margin at the node
67  uint32_t n; // total events at the node
68 
69  // leaf has
70  uint32_t max_count; // the number of samples of the most common label
71  uint32_t max_count_label; // the most common label
72 } node;
73 
74 struct log_multi
75 {
76  uint32_t k;
77 
79 
82 
83  bool progress;
84  uint32_t swap_resist;
85 
86  uint32_t nbofswaps;
87 
89  {
90  // save_node_stats(b);
91  for (auto& node : nodes) node.preds.delete_v();
92  nodes.delete_v();
93  }
94 };
95 
96 inline void init_leaf(node& n)
97 {
98  n.internal = false;
99  n.preds.clear();
100  n.base_predictor = 0;
101  n.norm_Eh = 0;
102  n.Eh = 0;
103  n.n = 0;
104  n.max_count = 0;
105  n.max_count_label = 1;
106  n.left = 0;
107  n.right = 0;
108 }
109 
110 inline node init_node()
111 {
112  node node;
113 
114  node.parent = 0;
115  node.min_count = 0;
116  node.preds = v_init<node_pred>();
117  init_leaf(node);
118 
119  return node;
120 }
121 
123 {
125  d.nbofswaps = 0;
126 }
127 
128 inline uint32_t min_left_right(log_multi& b, const node& n)
129 {
130  return std::min(b.nodes[n.left].min_count, b.nodes[n.right].min_count);
131 }
132 
133 inline uint32_t find_switch_node(log_multi& b)
134 {
135  uint32_t node = 0;
136  while (b.nodes[node].internal)
137  if (b.nodes[b.nodes[node].left].min_count < b.nodes[b.nodes[node].right].min_count)
138  node = b.nodes[node].left;
139  else
140  node = b.nodes[node].right;
141  return node;
142 }
143 
144 inline void update_min_count(log_multi& b, uint32_t node)
145 {
146  // Constant time min count update.
147  while (node != 0)
148  {
149  uint32_t prev = node;
150  node = b.nodes[node].parent;
151 
152  if (b.nodes[node].min_count == b.nodes[prev].min_count)
153  break;
154  else
155  b.nodes[node].min_count = min_left_right(b, b.nodes[node]);
156  }
157 }
158 
159 void display_tree_dfs(log_multi& b, const node& node, uint32_t depth)
160 {
161  for (uint32_t i = 0; i < depth; i++) std::cout << "\t";
162  std::cout << node.min_count << " " << node.left << " " << node.right;
163  std::cout << " label = " << node.max_count_label << " labels = ";
164  for (size_t i = 0; i < node.preds.size(); i++)
165  std::cout << node.preds[i].label << ":" << node.preds[i].label_count << "\t";
166  std::cout << std::endl;
167 
168  if (node.internal)
169  {
170  std::cout << "Left";
171  display_tree_dfs(b, b.nodes[node.left], depth + 1);
172 
173  std::cout << "Right";
174  display_tree_dfs(b, b.nodes[node.right], depth + 1);
175  }
176 }
177 
// Record one example with class `label` at node `current`, and report whether
// the example must keep descending.
// Returns true if `current` is internal (including when it was just split into
// two children here), false if it remains a leaf.
// Side effects:
// - adds `label` to current's preds if absent, sets class_index to its position and increments its label_count;
// - updates max_count / max_count_label when this label becomes the most frequent at the node;
// - may turn a leaf into an internal node with two leaf children.
// Those children are either taken from the end of b.nodes (first branch), or are a recycled
// leaf (from find_switch_node) plus that leaf's parent (second branch, a "swap").
// NOTE(review): source lines 192, 198, 201 and 203 are missing from this rendering, so the
// full split condition and the test selecting the first vs. the swap branch are not visible
// here. Lines 201/203 presumably append new nodes to b.nodes (left_child and right_child are
// both read from b.nodes.size()) -- confirm against the original file.
178 bool children(log_multi& b, uint32_t& current, uint32_t& class_index, uint32_t label)
179 {
180  class_index = (uint32_t)b.nodes[current].preds.unique_add_sorted(node_pred(label));
181  b.nodes[current].preds[class_index].label_count++;
182 
// Track the most frequent label seen at this node.
183  if (b.nodes[current].preds[class_index].label_count > b.nodes[current].max_count)
184  {
185  b.nodes[current].max_count = b.nodes[current].preds[class_index].label_count;
186  b.nodes[current].max_count_label = b.nodes[current].preds[class_index].label;
187  }
188 
// Internal nodes always route onward. A leaf is only considered for splitting once it has
// seen more than one label; the visible part of the remaining condition compares
// (min_count - max_count) with swap_resist * (root min_count + 1).
189  if (b.nodes[current].internal)
190  return true;
191  else if (b.nodes[current].preds.size() > 1 &&
193  b.nodes[current].min_count - b.nodes[current].max_count > b.swap_resist * (b.nodes[0].min_count + 1)))
194  {
195  // need children and we can make them.
196  uint32_t left_child;
197  uint32_t right_child;
// First branch: the children come from the end of b.nodes, and current gets the next unused base predictor.
199  {
200  left_child = (uint32_t)b.nodes.size();
202  right_child = (uint32_t)b.nodes.size();
204  b.nodes[current].base_predictor = (uint32_t)b.predictors_used++;
205  }
206  else
207  {
// Swap branch: pick the leaf with the smallest min_count (swap_child) and detach it
// together with its parent; the parent's other child takes the parent's place under
// the grandparent. The two detached nodes are reset to leaves and reused as current's
// children, and current inherits the detached parent's base predictor.
208  uint32_t swap_child = find_switch_node(b);
209  uint32_t swap_parent = b.nodes[swap_child].parent;
210  uint32_t swap_grandparent = b.nodes[swap_parent].parent;
// Debug print when the chosen leaf's min_count differs from the root's.
211  if (b.nodes[swap_child].min_count != b.nodes[0].min_count)
212  std::cout << "glargh " << b.nodes[swap_child].min_count << " != " << b.nodes[0].min_count << std::endl;
213  b.nbofswaps++;
214 
215  uint32_t nonswap_child;
216  if (swap_child == b.nodes[swap_parent].right)
217  nonswap_child = b.nodes[swap_parent].left;
218  else
219  nonswap_child = b.nodes[swap_parent].right;
220 
221  if (swap_parent == b.nodes[swap_grandparent].left)
222  b.nodes[swap_grandparent].left = nonswap_child;
223  else
224  b.nodes[swap_grandparent].right = nonswap_child;
225  b.nodes[nonswap_child].parent = swap_grandparent;
226  update_min_count(b, nonswap_child);
227 
228  init_leaf(b.nodes[swap_child]);
229  left_child = swap_child;
230  b.nodes[current].base_predictor = b.nodes[swap_parent].base_predictor;
231  init_leaf(b.nodes[swap_parent]);
232  right_child = swap_parent;
233  }
// Link the two children under current.
234  b.nodes[current].left = left_child;
235  b.nodes[left_child].parent = current;
236  b.nodes[current].right = right_child;
237  b.nodes[right_child].parent = current;
238 
// Split current's min_count between the children (left gets the floor of half),
// then propagate the change upward.
239  b.nodes[left_child].min_count = b.nodes[current].min_count / 2;
240  b.nodes[right_child].min_count = b.nodes[current].min_count - b.nodes[left_child].min_count;
241  update_min_count(b, left_child);
242 
// Both new leaves start out predicting current's most frequent label.
243  b.nodes[left_child].max_count_label = b.nodes[current].max_count_label;
244  b.nodes[right_child].max_count_label = b.nodes[current].max_count_label;
245 
246  b.nodes[current].internal = true;
247  }
248  return b.nodes[current].internal;
249 }
250 
252  log_multi& b, single_learner& base, example& ec, uint32_t& current, uint32_t& class_index, uint32_t /* depth */)
253 {
254  if (b.nodes[current].norm_Eh > b.nodes[current].preds[class_index].norm_Ehk)
255  ec.l.simple.label = -1.f;
256  else
257  ec.l.simple.label = 1.f;
258 
259  base.learn(ec, b.nodes[current].base_predictor); // depth
260 
261  ec.l.simple.label = FLT_MAX;
262  base.predict(ec, b.nodes[current].base_predictor); // depth
263 
264  b.nodes[current].Eh += (double)ec.partial_prediction;
265  b.nodes[current].preds[class_index].Ehk += (double)ec.partial_prediction;
266  b.nodes[current].n++;
267  b.nodes[current].preds[class_index].nk++;
268 
269  b.nodes[current].norm_Eh = (float)b.nodes[current].Eh / b.nodes[current].n;
270  b.nodes[current].preds[class_index].norm_Ehk =
271  (float)b.nodes[current].preds[class_index].Ehk / b.nodes[current].preds[class_index].nk;
272 }
273 
275 {
276  if (node.internal)
277  {
278  if (node.min_count != min_left_right(b, node))
279  {
280  std::cout << "badness! " << std::endl;
281  display_tree_dfs(b, b.nodes[0], 0);
282  }
283  verify_min_dfs(b, b.nodes[node.left]);
284  verify_min_dfs(b, b.nodes[node.right]);
285  }
286 }
287 
288 size_t sum_count_dfs(log_multi& b, const node& node)
289 {
290  if (node.internal)
291  return sum_count_dfs(b, b.nodes[node.left]) + sum_count_dfs(b, b.nodes[node.right]);
292  else
293  return node.min_count;
294 }
295 
296 inline uint32_t descend(node& n, float prediction)
297 {
298  if (prediction < 0)
299  return n.left;
300  else
301  return n.right;
302 }
303 
305 {
307 
308  ec.l.simple = {FLT_MAX, 0.f, 0.f};
309  uint32_t cn = 0;
310  uint32_t depth = 0;
311  while (b.nodes[cn].internal)
312  {
313  base.predict(ec, b.nodes[cn].base_predictor); // depth
314  cn = descend(b.nodes[cn], ec.pred.scalar);
315  depth++;
316  }
317  ec.pred.multiclass = b.nodes[cn].max_count_label;
318  ec.l.multi = mc;
319 }
320 
321 void learn(log_multi& b, single_learner& base, example& ec)
322 {
323  // verify_min_dfs(b, b.nodes[0]);
324  if (ec.l.multi.label == (uint32_t)-1 || b.progress)
325  predict(b, base, ec);
326 
327  if (ec.l.multi.label != (uint32_t)-1) // if training the tree
328  {
330  uint32_t start_pred = ec.pred.multiclass;
331 
332  uint32_t class_index = 0;
333  ec.l.simple = {FLT_MAX, 0.f, 0.f};
334  uint32_t cn = 0;
335  uint32_t depth = 0;
336  while (children(b, cn, class_index, mc.label))
337  {
338  train_node(b, base, ec, cn, class_index, depth);
339  cn = descend(b.nodes[cn], ec.pred.scalar);
340  depth++;
341  }
342 
343  b.nodes[cn].min_count++;
344  update_min_count(b, cn);
345  ec.pred.multiclass = start_pred;
346  ec.l.multi = mc;
347  }
348 }
349 
351 {
352  FILE* fp;
353  uint32_t i, j;
354  uint32_t total;
355  log_multi* b = &d;
356 
357  fp = fopen("atxm_debug.csv", "wt");
358 
359  for (i = 0; i < b->nodes.size(); i++)
360  {
361  fprintf(fp, "Node: %4d, Internal: %1d, Eh: %7.4f, n: %6d, \n", (int)i, (int)b->nodes[i].internal,
362  b->nodes[i].Eh / b->nodes[i].n, b->nodes[i].n);
363 
364  fprintf(fp, "Label:, ");
365  for (j = 0; j < b->nodes[i].preds.size(); j++)
366  {
367  fprintf(fp, "%6d,", (int)b->nodes[i].preds[j].label);
368  }
369  fprintf(fp, "\n");
370 
371  fprintf(fp, "Ehk:, ");
372  for (j = 0; j < b->nodes[i].preds.size(); j++)
373  {
374  fprintf(fp, "%7.4f,", b->nodes[i].preds[j].Ehk / b->nodes[i].preds[j].nk);
375  }
376  fprintf(fp, "\n");
377 
378  total = 0;
379 
380  fprintf(fp, "nk:, ");
381  for (j = 0; j < b->nodes[i].preds.size(); j++)
382  {
383  fprintf(fp, "%6d,", (int)b->nodes[i].preds[j].nk);
384  total += b->nodes[i].preds[j].nk;
385  }
386  fprintf(fp, "\n");
387 
388  fprintf(fp, "max(lab:cnt:tot):, %3d,%6d,%7d,\n", (int)b->nodes[i].max_count_label, (int)b->nodes[i].max_count,
389  (int)total);
390  fprintf(fp, "left: %4d, right: %4d", (int)b->nodes[i].left, (int)b->nodes[i].right);
391  fprintf(fp, "\n\n");
392  }
393 
394  fclose(fp);
395 }
396 
397 void save_load_tree(log_multi& b, io_buf& model_file, bool read, bool text)
398 {
399  if (model_file.files.size() > 0)
400  {
401  std::stringstream msg;
402  msg << "k = " << b.k;
403  bin_text_read_write_fixed(model_file, (char*)&b.max_predictors, sizeof(b.k), "", read, msg, text);
404 
405  msg << "nodes = " << b.nodes.size() << " ";
406  uint32_t temp = (uint32_t)b.nodes.size();
407  bin_text_read_write_fixed(model_file, (char*)&temp, sizeof(temp), "", read, msg, text);
408  if (read)
409  for (uint32_t j = 1; j < temp; j++) b.nodes.push_back(init_node());
410 
411  msg << "max predictors = " << b.max_predictors << " ";
412  bin_text_read_write_fixed(model_file, (char*)&b.max_predictors, sizeof(b.max_predictors), "", read, msg, text);
413 
414  msg << "predictors_used = " << b.predictors_used << " ";
415  bin_text_read_write_fixed(model_file, (char*)&b.predictors_used, sizeof(b.predictors_used), "", read, msg, text);
416 
417  msg << "progress = " << b.progress << " ";
418  bin_text_read_write_fixed(model_file, (char*)&b.progress, sizeof(b.progress), "", read, msg, text);
419 
420  msg << "swap_resist = " << b.swap_resist << "\n";
421  bin_text_read_write_fixed(model_file, (char*)&b.swap_resist, sizeof(b.swap_resist), "", read, msg, text);
422 
423  for (size_t j = 0; j < b.nodes.size(); j++)
424  {
425  // Need to read or write nodes.
426  node& n = b.nodes[j];
427 
428  msg << " parent = " << n.parent;
429  bin_text_read_write_fixed(model_file, (char*)&n.parent, sizeof(n.parent), "", read, msg, text);
430 
431  uint32_t temp = (uint32_t)n.preds.size();
432 
433  msg << " preds = " << temp;
434  bin_text_read_write_fixed(model_file, (char*)&temp, sizeof(temp), "", read, msg, text);
435  if (read)
436  for (uint32_t k = 0; k < temp; k++) n.preds.push_back(node_pred(1));
437 
438  msg << " min_count = " << n.min_count;
439  bin_text_read_write_fixed(model_file, (char*)&n.min_count, sizeof(n.min_count), "", read, msg, text);
440 
441  msg << " internal = " << n.internal;
442  bin_text_read_write_fixed(model_file, (char*)&n.internal, sizeof(n.internal), "", read, msg, text);
443 
444  if (n.internal)
445  {
446  msg << " base_predictor = " << n.base_predictor;
447  bin_text_read_write_fixed(model_file, (char*)&n.base_predictor, sizeof(n.base_predictor), "", read, msg, text);
448 
449  msg << " left = " << n.left;
450  bin_text_read_write_fixed(model_file, (char*)&n.left, sizeof(n.left), "", read, msg, text);
451 
452  msg << " right = " << n.right;
453  bin_text_read_write_fixed(model_file, (char*)&n.right, sizeof(n.right), "", read, msg, text);
454 
455  msg << " norm_Eh = " << n.norm_Eh;
456  bin_text_read_write_fixed(model_file, (char*)&n.norm_Eh, sizeof(n.norm_Eh), "", read, msg, text);
457 
458  msg << " Eh = " << n.Eh;
459  bin_text_read_write_fixed(model_file, (char*)&n.Eh, sizeof(n.Eh), "", read, msg, text);
460 
461  msg << " n = " << n.n << "\n";
462  bin_text_read_write_fixed(model_file, (char*)&n.n, sizeof(n.n), "", read, msg, text);
463  }
464  else
465  {
466  msg << " max_count = " << n.max_count;
467  bin_text_read_write_fixed(model_file, (char*)&n.max_count, sizeof(n.max_count), "", read, msg, text);
468  msg << " max_count_label = " << n.max_count_label << "\n";
470  model_file, (char*)&n.max_count_label, sizeof(n.max_count_label), "", read, msg, text);
471  }
472 
473  for (size_t k = 0; k < n.preds.size(); k++)
474  {
475  node_pred& p = n.preds[k];
476 
477  msg << " Ehk = " << p.Ehk;
478  bin_text_read_write_fixed(model_file, (char*)&p.Ehk, sizeof(p.Ehk), "", read, msg, text);
479 
480  msg << " norm_Ehk = " << p.norm_Ehk;
481  bin_text_read_write_fixed(model_file, (char*)&p.norm_Ehk, sizeof(p.norm_Ehk), "", read, msg, text);
482 
483  msg << " nk = " << p.nk;
484  bin_text_read_write_fixed(model_file, (char*)&p.nk, sizeof(p.nk), "", read, msg, text);
485 
486  msg << " label = " << p.label;
487  bin_text_read_write_fixed(model_file, (char*)&p.label, sizeof(p.label), "", read, msg, text);
488 
489  msg << " label_count = " << p.label_count << "\n";
490  bin_text_read_write_fixed(model_file, (char*)&p.label_count, sizeof(p.label_count), "", read, msg, text);
491  }
492  }
493  }
494 }
495 
496 base_learner* log_multi_setup(options_i& options, vw& all) // learner setup
497 {
498  auto data = scoped_calloc_or_throw<log_multi>();
499  option_group_definition new_options("Logarithmic Time Multiclass Tree");
500  new_options.add(make_option("log_multi", data->k).keep().help("Use online tree for multiclass"))
501  .add(make_option("no_progress", data->progress).help("disable progressive validation"))
502  .add(make_option("swap_resistance", data->swap_resist).default_value(4).help("disable progressive validation"))
503  .add(make_option("swap_resistance", data->swap_resist)
504  .default_value(4)
505  .help("higher = more resistance to swap, default=4"));
506  options.add_and_parse(new_options);
507 
508  if (!options.was_supplied("log_multi"))
509  return nullptr;
510 
511  data->progress = !data->progress;
512 
513  std::string loss_function = "quantile";
514  float loss_parameter = 0.5;
515  delete (all.loss);
516  all.loss = getLossFunction(all, loss_function, loss_parameter);
517 
518  data->max_predictors = data->k - 1;
519  init_tree(*data.get());
520 
522  data, as_singleline(setup_base(options, all)), learn, predict, all.p, data->max_predictors);
523  l.set_save_load(save_load_tree);
524 
525  return make_base(l);
526 }
v_array< node > nodes
Definition: log_multi.cc:78
uint32_t multiclass
Definition: example.h:49
loss_function * loss
Definition: global_data.h:523
float norm_Ehk
Definition: log_multi.cc:20
uint32_t nbofswaps
Definition: log_multi.cc:86
void predict(E &ec, size_t i=0)
Definition: learner.h:169
double Ehk
Definition: log_multi.cc:19
float scalar
Definition: example.h:45
void save_node_stats(log_multi &d)
Definition: log_multi.cc:350
bool operator<(node_pred v)
Definition: log_multi.cc:34
uint32_t parent
Definition: log_multi.cc:54
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
size_t predictors_used
Definition: log_multi.cc:81
float partial_prediction
Definition: example.h:68
virtual void add_and_parse(const option_group_definition &group)=0
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
uint32_t max_count
Definition: log_multi.cc:70
node_pred(uint32_t l)
Definition: log_multi.cc:41
node init_node()
Definition: log_multi.cc:110
size_t size() const
Definition: v_array.h:68
void init_leaf(node &n)
Definition: log_multi.cc:96
void train_node(log_multi &b, single_learner &base, example &ec, uint32_t &current, uint32_t &class_index, uint32_t)
Definition: log_multi.cc:251
v_array< node_pred > preds
Definition: log_multi.cc:55
bool progress
Definition: log_multi.cc:83
uint32_t left
Definition: log_multi.cc:63
parser * p
Definition: global_data.h:377
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
uint32_t max_count_label
Definition: log_multi.cc:71
MULTICLASS::label_t multi
Definition: example.h:29
uint32_t min_count
Definition: log_multi.cc:57
float norm_Eh
Definition: log_multi.cc:65
bool operator==(node_pred v)
Definition: log_multi.cc:25
uint32_t n
Definition: log_multi.cc:67
void push_back(const T &new_ele)
Definition: v_array.h:107
double Eh
Definition: log_multi.cc:66
uint32_t right
Definition: log_multi.cc:64
v_array< int > files
Definition: io_buf.h:64
void clear()
Definition: v_array.h:88
virtual bool was_supplied(const std::string &key)=0
base_learner * log_multi_setup(options_i &options, vw &all)
Definition: log_multi.cc:496
void update_min_count(log_multi &b, uint32_t node)
Definition: log_multi.cc:144
~log_multi()
Definition: log_multi.cc:88
void display_tree_dfs(log_multi &b, const node &node, uint32_t depth)
Definition: log_multi.cc:159
uint32_t min_left_right(log_multi &b, const node &n)
Definition: log_multi.cc:128
size_t max_predictors
Definition: log_multi.cc:80
Definition: io_buf.h:54
uint32_t nk
Definition: log_multi.cc:21
uint32_t label
Definition: log_multi.cc:22
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
polylabel l
Definition: example.h:57
void save_load_tree(log_multi &b, io_buf &model_file, bool read, bool text)
Definition: log_multi.cc:397
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
uint32_t label_count
Definition: log_multi.cc:23
void init_tree(log_multi &d)
Definition: log_multi.cc:122
void learn(log_multi &b, single_learner &base, example &ec)
Definition: log_multi.cc:321
uint32_t k
Definition: log_multi.cc:76
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
uint32_t swap_resist
Definition: log_multi.cc:84
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
polyprediction pred
Definition: example.h:60
bool internal
Definition: log_multi.cc:59
void delete_v()
Definition: v_array.h:98
size_t sum_count_dfs(log_multi &b, const node &node)
Definition: log_multi.cc:288
void learn(E &ec, size_t i=0)
Definition: learner.h:160
bool children(log_multi &b, uint32_t &current, uint32_t &class_index, uint32_t label)
Definition: log_multi.cc:178
loss_function * getLossFunction(vw &all, std::string funcName, float function_parameter)
size_t bin_text_read_write_fixed(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
Definition: io_buf.h:326
void verify_min_dfs(log_multi &b, const node &node)
Definition: log_multi.cc:274
uint32_t find_switch_node(log_multi &b)
Definition: log_multi.cc:133
uint32_t descend(node &n, float prediction)
Definition: log_multi.cc:296
uint32_t base_predictor
Definition: log_multi.cc:62
size_t unique_add_sorted(const T &new_ele)
Definition: v_array.h:140
void predict(log_multi &b, single_learner &base, example &ec)
Definition: log_multi.cc:304
bool operator>(node_pred v)
Definition: log_multi.cc:27