Vowpal Wabbit
Classes | Functions
ect.cc File Reference
#include <iostream>
#include <fstream>
#include <ctime>
#include <numeric>
#include "reductions.h"

Go to the source code of this file.

Classes

struct  direction
 
struct  ect
 

Functions

bool exists (v_array< size_t > db)
 
size_t final_depth (size_t eliminations)
 
bool not_empty (v_array< v_array< uint32_t >> const &tournaments)
 
void print_level (v_array< v_array< uint32_t >> const &level)
 
size_t create_circuit (ect &e, uint64_t max_label, uint64_t eliminations)
 
uint32_t ect_predict (ect &e, single_learner &base, example &ec)
 
void ect_train (ect &e, single_learner &base, example &ec)
 
void predict (ect &e, single_learner &base, example &ec)
 
void learn (ect &e, single_learner &base, example &ec)
 
base_learner * ect_setup (options_i &options, vw &all)
 

Function Documentation

◆ create_circuit()

size_t create_circuit ( ect & e,
uint64_t  max_label,
uint64_t  eliminations 
)

Definition at line 104 of file ect.cc.

References ect::all_levels, ect::directions, final_depth(), ect::final_nodes, id(), v_array< T >::last(), ect::last_pair, not_empty(), v_array< T >::push_back(), v_array< T >::size(), and ect::tree_height.

Referenced by ect_setup().

105 {
106  if (max_label == 1)
107  return 0;
108 
109  v_array<v_array<uint32_t>> tournaments = v_init<v_array<uint32_t>>();
110  v_array<uint32_t> t = v_init<uint32_t>();
111 
112  for (uint32_t i = 0; i < max_label; i++)
113  {
114  t.push_back(i);
115  direction d = {i, 0, 0, 0, 0, 0, false};
116  e.directions.push_back(d);
117  }
118 
119  tournaments.push_back(t);
120 
121  for (size_t i = 0; i < eliminations - 1; i++) tournaments.push_back(v_array<uint32_t>());
122 
123  e.all_levels.push_back(tournaments);
124 
125  size_t level = 0;
126 
127  uint32_t node = (uint32_t)e.directions.size();
128 
129  while (not_empty(e.all_levels[level]))
130  {
131  v_array<v_array<uint32_t>> new_tournaments = v_init<v_array<uint32_t>>();
132  tournaments = e.all_levels[level];
133 
134  for (size_t t = 0; t < tournaments.size(); t++)
135  {
136  v_array<uint32_t> empty = v_init<uint32_t>();
137  new_tournaments.push_back(empty);
138  }
139 
140  for (size_t t = 0; t < tournaments.size(); t++)
141  {
142  for (size_t j = 0; j < tournaments[t].size() / 2; j++)
143  {
144  uint32_t id = node++;
145  uint32_t left = tournaments[t][2 * j];
146  uint32_t right = tournaments[t][2 * j + 1];
147 
148  direction d = {id, t, 0, 0, left, right, false};
149  e.directions.push_back(d);
150  uint32_t direction_index = (uint32_t)e.directions.size() - 1;
151  if (e.directions[left].tournament == t)
152  e.directions[left].winner = direction_index;
153  else
154  e.directions[left].loser = direction_index;
155  if (e.directions[right].tournament == t)
156  e.directions[right].winner = direction_index;
157  else
158  e.directions[right].loser = direction_index;
159  if (e.directions[left].last)
160  e.directions[left].winner = direction_index;
161 
162  if (tournaments[t].size() == 2 && (t == 0 || tournaments[t - 1].empty()))
163  {
164  e.directions[direction_index].last = true;
165  if (t + 1 < tournaments.size())
166  new_tournaments[t + 1].push_back(id);
167  else // winner eliminated.
168  e.directions[direction_index].winner = 0;
169  e.final_nodes.push_back((uint32_t)(e.directions.size() - 1));
170  }
171  else
172  new_tournaments[t].push_back(id);
173  if (t + 1 < tournaments.size())
174  new_tournaments[t + 1].push_back(id);
175  else // loser eliminated.
176  e.directions[direction_index].loser = 0;
177  }
178  if (tournaments[t].size() % 2 == 1)
179  new_tournaments[t].push_back(tournaments[t].last());
180  }
181  e.all_levels.push_back(new_tournaments);
182  level++;
183  }
184 
185  e.last_pair = (uint32_t)((max_label - 1) * eliminations);
186 
187  if (max_label > 1)
188  e.tree_height = final_depth(eliminations);
189 
190  return e.last_pair + (eliminations - 1);
191 }
bool not_empty(v_array< v_array< uint32_t >> const &tournaments)
Definition: ect.cc:87
v_array< v_array< v_array< uint32_t > > > all_levels
Definition: ect.cc:40
size_t size() const
Definition: v_array.h:68
void push_back(const T &new_ele)
Definition: v_array.h:107
uint32_t last_pair
Definition: ect.cc:49
float id(float in)
Definition: scorer.cc:51
v_array< uint32_t > final_nodes
Definition: ect.cc:42
size_t tree_height
Definition: ect.cc:47
v_array< direction > directions
Definition: ect.cc:38
direction
Definition: ect.cc:21
T last() const
Definition: v_array.h:57
size_t final_depth(size_t eliminations)
Definition: ect.cc:77

◆ ect_predict()

uint32_t ect_predict ( ect & e,
single_learner & base,
example & ec 
)

Definition at line 193 of file ect.cc.

References ect::class_boundary, ect::directions, ect::errors, ect::final_nodes, id(), ect::k, example::l, ect::last_pair, LEARNER::learner< T, E >::learn(), example::pred, polyprediction::scalar, polylabel::simple, and ect::tree_height.

Referenced by predict().

194 {
195  if (e.k == (size_t)1)
196  return 1;
197 
198  uint32_t finals_winner = 0;
199 
200  // Binary final elimination tournament first
201  ec.l.simple = {FLT_MAX, 0., 0.};
202 
203  for (size_t i = e.tree_height - 1; i != (size_t)0 - 1; i--)
204  {
205  if ((finals_winner | (((size_t)1) << i)) <= e.errors)
206  {
207  // a real choice exists
208  uint32_t problem_number = e.last_pair + (finals_winner | (((uint32_t)1) << i)) - 1; // This is unique.
209 
210  base.learn(ec, problem_number);
211 
212  if (ec.pred.scalar > e.class_boundary)
213  finals_winner = finals_winner | (((size_t)1) << i);
214  }
215  }
216 
217  uint32_t id = e.final_nodes[finals_winner];
218  while (id >= e.k)
219  {
220  base.learn(ec, id - e.k);
221 
222  if (ec.pred.scalar > e.class_boundary)
223  id = e.directions[id].right;
224  else
225  id = e.directions[id].left;
226  }
227  return id + 1;
228 }
float scalar
Definition: example.h:45
label_data simple
Definition: example.h:28
uint32_t last_pair
Definition: ect.cc:49
float id(float in)
Definition: scorer.cc:51
v_array< uint32_t > final_nodes
Definition: ect.cc:42
polylabel l
Definition: example.h:57
float class_boundary
Definition: ect.cc:36
size_t tree_height
Definition: ect.cc:47
v_array< direction > directions
Definition: ect.cc:38
uint64_t k
Definition: ect.cc:34
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
uint64_t errors
Definition: ect.cc:35

◆ ect_setup()

base_learner* ect_setup ( options_i & options,
vw & all 
)

Definition at line 340 of file ect.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), create_circuit(), LEARNER::init_multiclass_learner(), learn(), LEARNER::make_base(), VW::config::make_option(), vw::p, predict(), setup_base(), and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

341 {
342  auto data = scoped_calloc_or_throw<ect>();
343  std::string link;
344  option_group_definition new_options("Error Correcting Tournament Options");
345  new_options.add(make_option("ect", data->k).keep().help("Error correcting tournament with <k> labels"))
346  .add(make_option("error", data->errors).keep().default_value(0).help("errors allowed by ECT"))
347  // Used to check value. TODO replace
348  .add(make_option("link", link)
349  .default_value("identity")
350  .keep()
351  .help("Specify the link function: identity, logistic, glf1 or poisson"));
352  options.add_and_parse(new_options);
353 
354  if (!options.was_supplied("ect"))
355  return nullptr;
356 
357  size_t wpp = create_circuit(*data.get(), data->k, data->errors + 1);
358 
359  base_learner* base = setup_base(options, all);
360  if (link == "logistic")
361  data->class_boundary = 0.5; // as --link=logistic maps predictions in [0;1]
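362 
363  learner<ect, example>& l = init_multiclass_learner(data, as_singleline(base), learn, predict, all.p, wpp);
364 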
365  return make_base(l);
366 }
void learn(ect &e, single_learner &base, example &ec)
Definition: ect.cc:328
void predict(ect &e, single_learner &base, example &ec)
Definition: ect.cc:319
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
parser * p
Definition: global_data.h:377
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
virtual bool was_supplied(const std::string &key)=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
size_t create_circuit(ect &e, uint64_t max_label, uint64_t eliminations)
Definition: ect.cc:104
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222

◆ ect_train()

void ect_train ( ect & e,
single_learner & base,
example & ec 
)

Definition at line 230 of file ect.cc.

References v_array< T >::begin(), ect::class_boundary, v_array< T >::clear(), ect::directions, v_array< T >::empty(), v_array< T >::end(), id(), label_data::initial, ect::k, example::l, label_data::label, MULTICLASS::label_t::label, v_array< T >::last(), ect::last_pair, LEARNER::learner< T, E >::learn(), label_type::mc, polylabel::multi, example::pred, v_array< T >::push_back(), polyprediction::scalar, polylabel::simple, v_array< T >::size(), ect::tournaments_won, ect::tree_height, label_data::weight, and example::weight.

Referenced by learn().

231 {
232  if (e.k == 1) // nothing to do
233  return;
234  MULTICLASS::label_t mc = ec.l.multi;
235 
236  label_data simple_temp;
237 
238  simple_temp.initial = 0.;
239 
240  e.tournaments_won.clear();
241 
242  uint32_t id = e.directions[mc.label - 1].winner;
243  bool left = e.directions[id].left == mc.label - 1;
244  do
245  {
246  if (left)
247  simple_temp.label = -1;
248  else
249  simple_temp.label = 1;
250 
251  ec.l.simple = simple_temp;
252  base.learn(ec, id - e.k);
253  float old_weight = ec.weight;
254  ec.weight = 0.;
255  base.learn(ec, id - e.k); // inefficient, we should extract final prediction exactly.
256  ec.weight = old_weight;
257 
258  bool won = (ec.pred.scalar - e.class_boundary) * simple_temp.label > 0;
259 
260  if (won)
261  {
262  if (!e.directions[id].last)
263  left = e.directions[e.directions[id].winner].left == id;
264  else
265  e.tournaments_won.push_back(true);
266  id = e.directions[id].winner;
267  }
268  else
269  {
270  if (!e.directions[id].last)
271  {
272  left = e.directions[e.directions[id].loser].left == id;
273  if (e.directions[id].loser == 0)
274  e.tournaments_won.push_back(false);
275  }
276  else
277  e.tournaments_won.push_back(false);
278  id = e.directions[id].loser;
279  }
280  } while (id != 0);
281 
282  if (e.tournaments_won.empty())
283  std::cout << "badness!" << std::endl;
284 
285  // tournaments_won is a bit vector determining which tournaments the label won.
286  for (size_t i = 0; i < e.tree_height; i++)
287  {
288  for (uint32_t j = 0; j < e.tournaments_won.size() / 2; j++)
289  {
290  bool left = e.tournaments_won[j * 2];
291  bool right = e.tournaments_won[j * 2 + 1];
292  if (left == right) // no query to do
293  e.tournaments_won[j] = left;
294  else // query to do
295  {
296  if (left)
297  simple_temp.label = -1;
298  else
299  simple_temp.label = 1;
300  simple_temp.weight = (float)(1 << (e.tree_height - i - 1));
301  ec.l.simple = simple_temp;
302 
303  uint32_t problem_number = e.last_pair + j * (1 << (i + 1)) + (1 << i) - 1;
304 
305  base.learn(ec, problem_number);
306 
307  if (ec.pred.scalar > e.class_boundary)
308  e.tournaments_won[j] = right;
309  else
310  e.tournaments_won[j] = left;
311  }
312  if (e.tournaments_won.size() % 2 == 1)
313  e.tournaments_won[e.tournaments_won.size() / 2] = e.tournaments_won.last();
314  e.tournaments_won.end() = e.tournaments_won.begin() + (1 + e.tournaments_won.size()) / 2;
315  }
316  }
317 }
float scalar
Definition: example.h:45
v_array< bool > tournaments_won
Definition: ect.cc:51
float weight
Definition: simple_label.h:15
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
T *& begin()
Definition: v_array.h:42
size_t size() const
Definition: v_array.h:68
MULTICLASS::label_t multi
Definition: example.h:29
void push_back(const T &new_ele)
Definition: v_array.h:107
uint32_t last_pair
Definition: ect.cc:49
float id(float in)
Definition: scorer.cc:51
void clear()
Definition: v_array.h:88
float initial
Definition: simple_label.h:16
T *& end()
Definition: v_array.h:43
polylabel l
Definition: example.h:57
float class_boundary
Definition: ect.cc:36
bool empty() const
Definition: v_array.h:59
size_t tree_height
Definition: ect.cc:47
v_array< direction > directions
Definition: ect.cc:38
uint64_t k
Definition: ect.cc:34
T last() const
Definition: v_array.h:57
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
float weight
Definition: example.h:62

◆ exists()

bool exists ( v_array< size_t >  db)

Definition at line 69 of file ect.cc.

70 {
71  for (unsigned long i : db)
72  if (i != 0)
73  return true;
74  return false;
75 }

◆ final_depth()

size_t final_depth ( size_t  eliminations)

Definition at line 77 of file ect.cc.

Referenced by create_circuit().

78 {
79  eliminations--;
80  for (size_t i = 0; i < 32; i++)
81  if (eliminations >> i == 0)
82  return i;
83  std::cerr << "too many eliminations" << std::endl;
84  return 31;
85 }

◆ learn()

void learn ( ect & e,
single_learner & base,
example & ec 
)

Definition at line 328 of file ect.cc.

References ect_train(), example::l, MULTICLASS::label_t::label, label_type::mc, polylabel::multi, polyprediction::multiclass, example::pred, and predict().

Referenced by ect_setup().

329 {
330  MULTICLASS::label_t mc = ec.l.multi;
331  predict(e, base, ec);
332  uint32_t pred = ec.pred.multiclass;
333 
334  if (mc.label != (uint32_t)-1)
335  ect_train(e, base, ec);
336  ec.l.multi = mc;
337  ec.pred.multiclass = pred;
338 }
uint32_t multiclass
Definition: example.h:49
void predict(ect &e, single_learner &base, example &ec)
Definition: ect.cc:319
MULTICLASS::label_t multi
Definition: example.h:29
polylabel l
Definition: example.h:57
polyprediction pred
Definition: example.h:60
void ect_train(ect &e, single_learner &base, example &ec)
Definition: ect.cc:230

◆ not_empty()

bool not_empty ( v_array< v_array< uint32_t >> const &  tournaments)

Definition at line 87 of file ect.cc.

Referenced by create_circuit().

88 {
89  auto const first_non_empty_tournament = std::find_if(
90  tournaments.cbegin(), tournaments.cend(), [](v_array<uint32_t>& tournament) { return !tournament.empty(); });
91  return first_non_empty_tournament != tournaments.cend();
92 }
T * cbegin() const
Definition: v_array.h:48
T * cend() const
Definition: v_array.h:49
bool empty() const
Definition: v_array.h:59

◆ predict()

void predict ( ect & e,
single_learner & base,
example & ec 
)

Definition at line 319 of file ect.cc.

References ect_predict(), ect::k, example::l, MULTICLASS::label_t::label, label_type::mc, polylabel::multi, polyprediction::multiclass, and example::pred.

Referenced by ect_setup(), and learn().

320 {
321  MULTICLASS::label_t mc = ec.l.multi;
322  if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1))
323  std::cout << "label " << mc.label << " is not in {1," << e.k << "} This won't work right." << std::endl;
324  ec.pred.multiclass = ect_predict(e, base, ec);
325  ec.l.multi = mc;
326 }
uint32_t multiclass
Definition: example.h:49
MULTICLASS::label_t multi
Definition: example.h:29
polylabel l
Definition: example.h:57
uint64_t k
Definition: ect.cc:34
polyprediction pred
Definition: example.h:60
uint32_t ect_predict(ect &e, single_learner &base, example &ec)
Definition: ect.cc:193

◆ print_level()

void print_level ( v_array< v_array< uint32_t >> const &  level)

Definition at line 94 of file ect.cc.

95 {
96  for (auto const& t : level)
97  {
98  for (auto i : t) std::cout << " " << i;
99  std::cout << " | ";
100  }
101  std::cout << std::endl;
102 }