Vowpal Wabbit
ect.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 /*
7  Initial implementation by Hal Daume and John Langford. Reimplementation
8  by John Langford.
9 */
10 
#include <algorithm>
#include <ctime>
#include <fstream>
#include <iostream>
#include <numeric>

#include "reductions.h"
17 
18 using namespace LEARNER;
19 using namespace VW::config;
20 
// One node of the tournament circuit stored in ect::directions.
// create_circuit fills these by aggregate initialization, so the member order below is significant.
struct direction
{
  size_t id;          // unique id for node
  size_t tournament;  // index of the tournament this node was created in (0 for the per-label leaf nodes)
  uint32_t winner;    // up traversal, winner
  uint32_t loser;     // up traversal, loser
  uint32_t left;      // down traversal, left
  uint32_t right;     // down traversal, right
  bool last;          // set by create_circuit on the node that pairs the final two entrants of a tournament
};
31 
32 struct ect
33 {
34  uint64_t k;
35  uint64_t errors;
37 
38  v_array<direction> directions; // The nodes of the tournament datastructure
39 
41 
42  v_array<uint32_t> final_nodes; // The final nodes of each tournament.
43 
44  v_array<size_t> up_directions; // On edge e, which node n is in the up direction?
45  v_array<size_t> down_directions; // On edge e, which node n is in the down direction?
46 
47  size_t tree_height; // The height of the final tournament.
48 
49  uint32_t last_pair;
50 
52 
53  ~ect()
54  {
55  for (auto& all_level : all_levels)
56  {
57  for (auto& t : all_level) t.delete_v();
58  all_level.delete_v();
59  }
60  all_levels.delete_v();
61  final_nodes.delete_v();
62  up_directions.delete_v();
63  directions.delete_v();
64  down_directions.delete_v();
65  tournaments_won.delete_v();
66  }
67 };
68 
70 {
71  for (unsigned long i : db)
72  if (i != 0)
73  return true;
74  return false;
75 }
76 
// Returns the smallest depth d in [0, 31] such that (eliminations - 1) >> d is zero,
// i.e. the number of binary levels needed to pick one out of `eliminations` entrants.
// When no such d exists below 32, a diagnostic is written to std::cerr and 31 is returned.
size_t final_depth(size_t eliminations)
{
  const size_t remaining = eliminations - 1;

  size_t depth = 0;
  while (depth < 32 && (remaining >> depth) != 0) ++depth;

  if (depth < 32)
    return depth;

  std::cerr << "too many eliminations" << std::endl;
  return 31;
}
86 
87 bool not_empty(v_array<v_array<uint32_t>> const& tournaments)
88 {
89  auto const first_non_empty_tournament = std::find_if(
90  tournaments.cbegin(), tournaments.cend(), [](v_array<uint32_t>& tournament) { return !tournament.empty(); });
91  return first_non_empty_tournament != tournaments.cend();
92 }
93 
95 {
96  for (auto const& t : level)
97  {
98  for (auto i : t) std::cout << " " << i;
99  std::cout << " | ";
100  }
101  std::cout << std::endl;
102 }
103 
// Builds the tournament circuit inside `e` for `max_label` labels and `eliminations` tournaments.
//
// Fills e.directions (one leaf per label followed by one node per pairing), e.all_levels
// (the entrants of every tournament at every level), e.final_nodes, e.last_pair and, when
// max_label > 1, e.tree_height.
//
// Returns 0 when max_label == 1; otherwise (max_label - 1) * eliminations + (eliminations - 1),
// which ect_setup stores as `wpp`.
size_t create_circuit(ect& e, uint64_t max_label, uint64_t eliminations)
{
  if (max_label == 1)
    return 0;

  v_array<v_array<uint32_t>> tournaments = v_init<v_array<uint32_t>>();
  v_array<uint32_t> t = v_init<uint32_t>();

  // One leaf node per label; every label starts in tournament 0.
  for (uint32_t i = 0; i < max_label; i++)
  {
    t.push_back(i);
    direction d = {i, 0, 0, 0, 0, 0, false};
    e.directions.push_back(d);
  }

  tournaments.push_back(t);

  // The remaining tournaments start without entrants.
  for (size_t i = 0; i < eliminations - 1; i++) tournaments.push_back(v_array<uint32_t>());

  e.all_levels.push_back(tournaments);

  size_t level = 0;

  // Ids of pairing nodes continue after the leaf ids.
  uint32_t node = (uint32_t)e.directions.size();

  // Keep adding levels until no tournament at the current level has entrants left.
  while (not_empty(e.all_levels[level]))
  {
    v_array<v_array<uint32_t>> new_tournaments = v_init<v_array<uint32_t>>();
    tournaments = e.all_levels[level];

    // The next level starts with one empty entrant list per tournament.
    for (size_t t = 0; t < tournaments.size(); t++)
    {
      v_array<uint32_t> empty = v_init<uint32_t>();
      new_tournaments.push_back(empty);
    }

    // Note: the loop index `t` shadows the v_array `t` declared at the top of the function.
    for (size_t t = 0; t < tournaments.size(); t++)
    {
      // Pair up neighbouring entrants of tournament t.
      for (size_t j = 0; j < tournaments[t].size() / 2; j++)
      {
        uint32_t id = node++;
        uint32_t left = tournaments[t][2 * j];
        uint32_t right = tournaments[t][2 * j + 1];

        direction d = {id, t, 0, 0, left, right, false};
        e.directions.push_back(d);
        uint32_t direction_index = (uint32_t)e.directions.size() - 1;
        // An entrant created in tournament t reaches this node through its winner link;
        // an entrant created in another tournament reaches it through its loser link.
        if (e.directions[left].tournament == t)
          e.directions[left].winner = direction_index;
        else
          e.directions[left].loser = direction_index;
        if (e.directions[right].tournament == t)
          e.directions[right].winner = direction_index;
        else
          e.directions[right].loser = direction_index;
        if (e.directions[left].last)
          e.directions[left].winner = direction_index;

        // Exactly two entrants remain here and the preceding tournament (if any) is empty:
        // this node is flagged as `last` and recorded in final_nodes.
        if (tournaments[t].size() == 2 && (t == 0 || tournaments[t - 1].empty()))
        {
          e.directions[direction_index].last = true;
          if (t + 1 < tournaments.size())
            new_tournaments[t + 1].push_back(id);
          else  // winner eliminated.
            e.directions[direction_index].winner = 0;
          e.final_nodes.push_back((uint32_t)(e.directions.size() - 1));
        }
        else
          new_tournaments[t].push_back(id);
        // This if/else is not part of the branch above; it runs for every pairing node.
        if (t + 1 < tournaments.size())
          new_tournaments[t + 1].push_back(id);
        else  // loser eliminated.
          e.directions[direction_index].loser = 0;
      }
      // An unpaired trailing entrant moves to the next level unchanged.
      if (tournaments[t].size() % 2 == 1)
        new_tournaments[t].push_back(tournaments[t].last());
    }
    e.all_levels.push_back(new_tournaments);
    level++;
  }

  e.last_pair = (uint32_t)((max_label - 1) * eliminations);

  if (max_label > 1)
    e.tree_height = final_depth(eliminations);

  return e.last_pair + (eliminations - 1);
}
192 
193 uint32_t ect_predict(ect& e, single_learner& base, example& ec)
194 {
195  if (e.k == (size_t)1)
196  return 1;
197 
198  uint32_t finals_winner = 0;
199 
200  // Binary final elimination tournament first
201  ec.l.simple = {FLT_MAX, 0., 0.};
202 
203  for (size_t i = e.tree_height - 1; i != (size_t)0 - 1; i--)
204  {
205  if ((finals_winner | (((size_t)1) << i)) <= e.errors)
206  {
207  // a real choice exists
208  uint32_t problem_number = e.last_pair + (finals_winner | (((uint32_t)1) << i)) - 1; // This is unique.
209 
210  base.learn(ec, problem_number);
211 
212  if (ec.pred.scalar > e.class_boundary)
213  finals_winner = finals_winner | (((size_t)1) << i);
214  }
215  }
216 
217  uint32_t id = e.final_nodes[finals_winner];
218  while (id >= e.k)
219  {
220  base.learn(ec, id - e.k);
221 
222  if (ec.pred.scalar > e.class_boundary)
223  id = e.directions[id].right;
224  else
225  id = e.directions[id].left;
226  }
227  return id + 1;
228 }
229 
230 void ect_train(ect& e, single_learner& base, example& ec)
231 {
232  if (e.k == 1) // nothing to do
233  return;
235 
236  label_data simple_temp;
237 
238  simple_temp.initial = 0.;
239 
241 
242  uint32_t id = e.directions[mc.label - 1].winner;
243  bool left = e.directions[id].left == mc.label - 1;
244  do
245  {
246  if (left)
247  simple_temp.label = -1;
248  else
249  simple_temp.label = 1;
250 
251  ec.l.simple = simple_temp;
252  base.learn(ec, id - e.k);
253  float old_weight = ec.weight;
254  ec.weight = 0.;
255  base.learn(ec, id - e.k); // inefficient, we should extract final prediction exactly.
256  ec.weight = old_weight;
257 
258  bool won = (ec.pred.scalar - e.class_boundary) * simple_temp.label > 0;
259 
260  if (won)
261  {
262  if (!e.directions[id].last)
263  left = e.directions[e.directions[id].winner].left == id;
264  else
265  e.tournaments_won.push_back(true);
266  id = e.directions[id].winner;
267  }
268  else
269  {
270  if (!e.directions[id].last)
271  {
272  left = e.directions[e.directions[id].loser].left == id;
273  if (e.directions[id].loser == 0)
274  e.tournaments_won.push_back(false);
275  }
276  else
277  e.tournaments_won.push_back(false);
278  id = e.directions[id].loser;
279  }
280  } while (id != 0);
281 
282  if (e.tournaments_won.empty())
283  std::cout << "badness!" << std::endl;
284 
285  // tournaments_won is a bit vector determining which tournaments the label won.
286  for (size_t i = 0; i < e.tree_height; i++)
287  {
288  for (uint32_t j = 0; j < e.tournaments_won.size() / 2; j++)
289  {
290  bool left = e.tournaments_won[j * 2];
291  bool right = e.tournaments_won[j * 2 + 1];
292  if (left == right) // no query to do
293  e.tournaments_won[j] = left;
294  else // query to do
295  {
296  if (left)
297  simple_temp.label = -1;
298  else
299  simple_temp.label = 1;
300  simple_temp.weight = (float)(1 << (e.tree_height - i - 1));
301  ec.l.simple = simple_temp;
302 
303  uint32_t problem_number = e.last_pair + j * (1 << (i + 1)) + (1 << i) - 1;
304 
305  base.learn(ec, problem_number);
306 
307  if (ec.pred.scalar > e.class_boundary)
308  e.tournaments_won[j] = right;
309  else
310  e.tournaments_won[j] = left;
311  }
312  if (e.tournaments_won.size() % 2 == 1)
314  e.tournaments_won.end() = e.tournaments_won.begin() + (1 + e.tournaments_won.size()) / 2;
315  }
316  }
317 }
318 
319 void predict(ect& e, single_learner& base, example& ec)
320 {
322  if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1))
323  std::cout << "label " << mc.label << " is not in {1," << e.k << "} This won't work right." << std::endl;
324  ec.pred.multiclass = ect_predict(e, base, ec);
325  ec.l.multi = mc;
326 }
327 
328 void learn(ect& e, single_learner& base, example& ec)
329 {
331  predict(e, base, ec);
332  uint32_t pred = ec.pred.multiclass;
333 
334  if (mc.label != (uint32_t)-1)
335  ect_train(e, base, ec);
336  ec.l.multi = mc;
337  ec.pred.multiclass = pred;
338 }
339 
341 {
342  auto data = scoped_calloc_or_throw<ect>();
343  std::string link;
344  option_group_definition new_options("Error Correcting Tournament Options");
345  new_options.add(make_option("ect", data->k).keep().help("Error correcting tournament with <k> labels"))
346  .add(make_option("error", data->errors).keep().default_value(0).help("errors allowed by ECT"))
347  // Used to check value. TODO replace
348  .add(make_option("link", link)
349  .default_value("identity")
350  .keep()
351  .help("Specify the link function: identity, logistic, glf1 or poisson"));
352  options.add_and_parse(new_options);
353 
354  if (!options.was_supplied("ect"))
355  return nullptr;
356 
357  size_t wpp = create_circuit(*data.get(), data->k, data->errors + 1);
358 
359  base_learner* base = setup_base(options, all);
360  if (link == "logistic")
361  data->class_boundary = 0.5; // as --link=logistic maps predictions in [0;1]
362 
364 
365  return make_base(l);
366 }
uint32_t multiclass
Definition: example.h:49
v_array< size_t > down_directions
Definition: ect.cc:45
bool not_empty(v_array< v_array< uint32_t >> const &tournaments)
Definition: ect.cc:87
v_array< v_array< v_array< uint32_t > > > all_levels
Definition: ect.cc:40
void learn(ect &e, single_learner &base, example &ec)
Definition: ect.cc:328
float scalar
Definition: example.h:45
v_array< bool > tournaments_won
Definition: ect.cc:51
void predict(ect &e, single_learner &base, example &ec)
Definition: ect.cc:319
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
float weight
Definition: simple_label.h:15
virtual void add_and_parse(const option_group_definition &group)=0
Definition: ect.cc:32
float label
Definition: simple_label.h:14
size_t tournament
Definition: ect.cc:24
label_data simple
Definition: example.h:28
T *& begin()
Definition: v_array.h:42
size_t size() const
Definition: v_array.h:68
~ect()
Definition: ect.cc:53
parser * p
Definition: global_data.h:377
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
MULTICLASS::label_t multi
Definition: example.h:29
uint32_t right
Definition: ect.cc:28
void push_back(const T &new_ele)
Definition: v_array.h:107
uint32_t last_pair
Definition: ect.cc:49
float id(float in)
Definition: scorer.cc:51
void clear()
Definition: v_array.h:88
virtual bool was_supplied(const std::string &key)=0
void print_level(v_array< v_array< uint32_t >> const &level)
Definition: ect.cc:94
size_t id
Definition: ect.cc:23
v_array< uint32_t > final_nodes
Definition: ect.cc:42
float initial
Definition: simple_label.h:16
bool last
Definition: ect.cc:29
bool exists(v_array< size_t > db)
Definition: ect.cc:69
T *& end()
Definition: v_array.h:43
option_group_definition & add(T &&op)
Definition: options.h:90
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
size_t create_circuit(ect &e, uint64_t max_label, uint64_t eliminations)
Definition: ect.cc:104
polylabel l
Definition: example.h:57
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
float class_boundary
Definition: ect.cc:36
bool empty() const
Definition: v_array.h:59
size_t tree_height
Definition: ect.cc:47
uint32_t loser
Definition: ect.cc:26
v_array< direction > directions
Definition: ect.cc:38
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
uint64_t k
Definition: ect.cc:34
Definition: ect.cc:21
uint32_t left
Definition: ect.cc:27
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
T last() const
Definition: v_array.h:57
v_array< size_t > up_directions
Definition: ect.cc:44
polyprediction pred
Definition: example.h:60
void ect_train(ect &e, single_learner &base, example &ec)
Definition: ect.cc:230
void delete_v()
Definition: v_array.h:98
void learn(E &ec, size_t i=0)
Definition: learner.h:160
size_t final_depth(size_t eliminations)
Definition: ect.cc:77
uint32_t ect_predict(ect &e, single_learner &base, example &ec)
Definition: ect.cc:193
float weight
Definition: example.h:62
base_learner * ect_setup(options_i &options, vw &all)
Definition: ect.cc:340
uint32_t winner
Definition: ect.cc:25
uint64_t errors
Definition: ect.cc:35