Vowpal Wabbit
oaa.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 #include <sstream>
7 #include <cfloat>
8 #include <cmath>
9 #include "correctedMath.h"
10 #include "reductions.h"
11 #include "rand48.h"
12 #include "vw_exception.h"
13 #include "vw.h"
14 
15 using namespace VW::config;
16 
17 struct oaa
18 {
19  uint64_t k;
20  vw* all; // for raw
21  polyprediction* pred; // for multipredict
22  uint64_t num_subsample; // for randomized subsampling, how many negatives to draw?
23  uint32_t* subsample_order; // for randomized subsampling, in what order should we touch classes
24  size_t subsample_id; // for randomized subsampling, where do we live in the list
25 
26  ~oaa()
27  {
28  free(pred);
29  free(subsample_order);
30  }
31 };
32 
34 {
35  MULTICLASS::label_t ld = ec.l.multi;
36  if (ld.label == 0 || (ld.label > o.k && ld.label != (uint32_t)-1))
37  std::cout << "label " << ld.label << " is not in {1," << o.k << "} This won't work right." << std::endl;
38 
39  ec.l.simple = {1., 0.f, 0.f}; // truth
40  base.learn(ec, ld.label - 1);
41 
42  size_t prediction = ld.label;
43  float best_partial_prediction = ec.partial_prediction;
44 
45  ec.l.simple.label = -1.;
46  float weight_temp = ec.weight;
47  ec.weight *= ((float)o.k) / (float)o.num_subsample;
48  size_t p = o.subsample_id;
49  size_t count = 0;
50  while (count < o.num_subsample)
51  {
52  uint32_t l = o.subsample_order[p];
53  p = (p + 1) % o.k;
54  if (l == ld.label - 1)
55  continue;
56  base.learn(ec, l);
57  if (ec.partial_prediction > best_partial_prediction)
58  {
59  best_partial_prediction = ec.partial_prediction;
60  prediction = l + 1;
61  }
62  count++;
63  }
64  o.subsample_id = p;
65 
66  ec.pred.multiclass = (uint32_t)prediction;
67  ec.l.multi = ld;
68  ec.weight = weight_temp;
69 }
70 
71 template <bool is_learn, bool print_all, bool scores, bool probabilities>
73 {
74  MULTICLASS::label_t mc_label_data = ec.l.multi;
75  if (mc_label_data.label == 0 || (mc_label_data.label > o.k && mc_label_data.label != (uint32_t)-1))
76  std::cout << "label " << mc_label_data.label << " is not in {1," << o.k << "} This won't work right." << std::endl;
77 
78  std::stringstream outputStringStream;
79  uint32_t prediction = 1;
80  v_array<float> scores_array;
81  if (scores)
82  scores_array = ec.pred.scalars;
83 
84  ec.l.simple = {FLT_MAX, 0.f, 0.f};
85  base.multipredict(ec, 0, o.k, o.pred, true);
86  for (uint32_t i = 2; i <= o.k; i++)
87  if (o.pred[i - 1].scalar > o.pred[prediction - 1].scalar)
88  prediction = i;
89 
90  if (ec.passthrough)
91  for (uint32_t i = 1; i <= o.k; i++) add_passthrough_feature(ec, i, o.pred[i - 1].scalar);
92 
93  if (is_learn)
94  {
95  for (uint32_t i = 1; i <= o.k; i++)
96  {
97  ec.l.simple = {(mc_label_data.label == i) ? 1.f : -1.f, 0.f, 0.f};
98  ec.pred.scalar = o.pred[i - 1].scalar;
99  base.update(ec, i - 1);
100  }
101  }
102 
103  if (print_all)
104  {
105  outputStringStream << "1:" << o.pred[0].scalar;
106  for (uint32_t i = 2; i <= o.k; i++) outputStringStream << ' ' << i << ':' << o.pred[i - 1].scalar;
107  o.all->print_text(o.all->raw_prediction, outputStringStream.str(), ec.tag);
108  }
109 
110  if (scores)
111  {
112  scores_array.clear();
113  for (uint32_t i = 0; i < o.k; i++) scores_array.push_back(o.pred[i].scalar);
114  ec.pred.scalars = scores_array;
115 
116  if (probabilities)
117  {
118  float sum_prob = 0;
119  for (uint32_t i = 0; i < o.k; i++)
120  {
121  ec.pred.scalars[i] = 1.f / (1.f + correctedExp(-o.pred[i].scalar));
122  sum_prob += ec.pred.scalars[i];
123  }
124  float inv_sum_prob = 1.f / sum_prob;
125  for (uint32_t i = 0; i < o.k; i++) ec.pred.scalars[i] *= inv_sum_prob;
126  }
127  }
128  else
129  ec.pred.multiclass = prediction;
130 
131  ec.l.multi = mc_label_data;
132 }
133 
134 // TODO: partial code duplication with multiclass.cc:finish_example
135 template <bool probabilities>
136 void finish_example_scores(vw& all, oaa& o, example& ec)
137 {
138  // === Compute multiclass_log_loss
139  // TODO:
140  // What to do if the correct label is unknown, i.e. (uint32_t)-1?
141  // Suggestion: increase all.sd->weighted_unlabeled_examples???,
142  // but not sd.example_number, so the average loss is not influenced.
143  // What to do if the correct_class_prob==0?
144  // Suggestion: have some maximal multiclass_log_loss limit, e.g. 999.
145  float multiclass_log_loss = 999; // -log(0) = plus infinity
146  float correct_class_prob = 0;
147  if (probabilities)
148  {
149  if (ec.l.multi.label <= o.k) // prevent segmentation fault if labeĺ==(uint32_t)-1
150  correct_class_prob = ec.pred.scalars[ec.l.multi.label - 1];
151  if (correct_class_prob > 0)
152  multiclass_log_loss = -log(correct_class_prob) * ec.weight;
153  if (ec.test_only)
154  all.sd->holdout_multiclass_log_loss += multiclass_log_loss;
155  else
156  all.sd->multiclass_log_loss += multiclass_log_loss;
157  }
158  // === Compute `prediction` and zero_one_loss
159  // We have already computed `prediction` in predict_or_learn,
160  // but we cannot store it in ec.pred union because we store ec.pred.probs there.
161  uint32_t prediction = 0;
162  for (uint32_t i = 1; i < o.k; i++)
163  if (ec.pred.scalars[i] > ec.pred.scalars[prediction])
164  prediction = i;
165  prediction++; // prediction is 1-based index (not 0-based)
166  float zero_one_loss = 0;
167  if (ec.l.multi.label != prediction)
168  zero_one_loss = ec.weight;
169 
170  // === Print probabilities for all classes
171  std::ostringstream outputStringStream;
172  for (uint32_t i = 0; i < o.k; i++)
173  {
174  if (i > 0)
175  outputStringStream << ' ';
176  if (all.sd->ldict)
177  {
178  substring ss = all.sd->ldict->get(i + 1);
179  outputStringStream << std::string(ss.begin, ss.end - ss.begin);
180  }
181  else
182  outputStringStream << i + 1;
183  outputStringStream << ':' << ec.pred.scalars[i];
184  }
185  for (int sink : all.final_prediction_sink) all.print_text(sink, outputStringStream.str(), ec.tag);
186 
187  // === Report updates using zero-one loss
188  all.sd->update(ec.test_only, ec.l.multi.label != (uint32_t)-1, zero_one_loss, ec.weight, ec.num_features);
189  // Alternatively, we could report multiclass_log_loss.
190  // all.sd->update(ec.test_only, multiclass_log_loss, ec.weight, ec.num_features);
191  // Even better would be to report both losses, but this would mean to increase
192  // the number of columns and this would not fit narrow screens.
193  // So let's report (average) multiclass_log_loss only in the final resume.
194 
195  // === Print progress report
196  if (probabilities)
197  MULTICLASS::print_update_with_probability(all, ec, prediction);
198  else
199  MULTICLASS::print_update_with_score(all, ec, prediction);
200  VW::finish_example(all, ec);
201 }
202 
204 {
205  auto data = scoped_calloc_or_throw<oaa>();
206  bool probabilities = false;
207  bool scores = false;
208  option_group_definition new_options("One Against All Options");
209  new_options.add(make_option("oaa", data->k).keep().help("One-against-all multiclass with <k> labels"))
210  .add(make_option("oaa_subsample", data->num_subsample)
211  .help("subsample this number of negative examples when learning"))
212  .add(make_option("probabilities", probabilities).help("predict probabilites of all classes"))
213  .add(make_option("scores", scores).help("output raw scores per class"));
214  options.add_and_parse(new_options);
215 
216  if (!options.was_supplied("oaa"))
217  return nullptr;
218 
219  if (all.sd->ldict && (data->k != all.sd->ldict->getK()))
220  THROW("error: you have " << all.sd->ldict->getK() << " named labels; use that as the argument to oaa")
221 
222  data->all = &all;
223  data->pred = calloc_or_throw<polyprediction>(data->k);
224  data->subsample_order = nullptr;
225  data->subsample_id = 0;
226  if (data->num_subsample > 0)
227  {
228  if (data->num_subsample >= data->k)
229  {
230  data->num_subsample = 0;
231  all.trace_message << "oaa is turning off subsampling because your parameter >= K" << std::endl;
232  }
233  else
234  {
235  data->subsample_order = calloc_or_throw<uint32_t>(data->k);
236  for (size_t i = 0; i < data->k; i++) data->subsample_order[i] = (uint32_t)i;
237  for (size_t i = 0; i < data->k; i++)
238  {
239  size_t j = (size_t)(all.get_random_state()->get_and_update_random() * (float)(data->k - i)) + i;
240  uint32_t tmp = data->subsample_order[i];
241  data->subsample_order[i] = data->subsample_order[j];
242  data->subsample_order[j] = tmp;
243  }
244  }
245  }
246 
247  oaa* data_ptr = data.get();
249  auto base = as_singleline(setup_base(options, all));
250  if (probabilities || scores)
251  {
253  if (probabilities)
254  {
255  auto loss_function_type = all.loss->getType();
256  if (loss_function_type != "logistic")
257  all.trace_message << "WARNING: --probabilities should be used only with --loss_function=logistic" << std::endl;
258  // the three boolean template parameters are: is_learn, print_all and scores
259  l = &LEARNER::init_multiclass_learner(data, base, predict_or_learn<true, false, true, true>,
260  predict_or_learn<false, false, true, true>, all.p, data->k, prediction_type::scalars);
261  all.sd->report_multiclass_log_loss = true;
262  l->set_finish_example(finish_example_scores<true>);
263  }
264  else
265  {
266  l = &LEARNER::init_multiclass_learner(data, base, predict_or_learn<true, false, true, false>,
267  predict_or_learn<false, false, true, false>, all.p, data->k, prediction_type::scalars);
268  l->set_finish_example(finish_example_scores<false>);
269  }
270  }
271  else if (all.raw_prediction > 0)
272  l = &LEARNER::init_multiclass_learner(data, base, predict_or_learn<true, true, false, false>,
273  predict_or_learn<false, true, false, false>, all.p, data->k, prediction_type::multiclass);
274  else
275  l = &LEARNER::init_multiclass_learner(data, base, predict_or_learn<true, false, false, false>,
276  predict_or_learn<false, false, false, false>, all.p, data->k, prediction_type::multiclass);
277 
278  if (data_ptr->num_subsample > 0)
279  {
281  l->set_finish_example(MULTICLASS::finish_example_without_loss<oaa>);
282  }
283 
284  return make_base(*l);
285 }
Definition: oaa.cc:17
bool report_multiclass_log_loss
Definition: global_data.h:166
v_array< char > tag
Definition: example.h:63
int raw_prediction
Definition: global_data.h:519
#define correctedExp
Definition: correctedMath.h:27
uint32_t multiclass
Definition: example.h:49
loss_function * loss
Definition: global_data.h:523
void(* delete_prediction)(void *)
Definition: global_data.h:485
void predict_or_learn(oaa &o, LEARNER::single_learner &base, example &ec)
Definition: oaa.cc:72
void set_learn(void(*u)(T &, L &, E &))
Definition: learner.h:212
float scalar
Definition: example.h:45
char * end
Definition: hashstring.h:10
char * begin
Definition: hashstring.h:9
LEARNER::base_learner * oaa_setup(options_i &options, vw &all)
Definition: oaa.cc:203
void finish_example_scores(vw &all, oaa &o, example &ec)
Definition: oaa.cc:136
v_array< int > final_prediction_sink
Definition: global_data.h:518
namedlabels * ldict
Definition: global_data.h:153
double holdout_multiclass_log_loss
Definition: global_data.h:168
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
float partial_prediction
Definition: example.h:68
virtual void add_and_parse(const option_group_definition &group)=0
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
void print_update_with_probability(vw &all, example &ec, uint32_t pred)
Definition: multiclass.cc:149
#define add_passthrough_feature(ec, i, x)
Definition: example.h:119
void print_update_with_score(vw &all, example &ec, uint32_t pred)
Definition: multiclass.cc:153
parser * p
Definition: global_data.h:377
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
double multiclass_log_loss
Definition: global_data.h:167
MULTICLASS::label_t multi
Definition: example.h:29
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
void push_back(const T &new_ele)
Definition: v_array.h:107
shared_data * sd
Definition: global_data.h:375
polyprediction * pred
Definition: oaa.cc:21
void delete_scalars(void *v)
Definition: example.h:37
void clear()
Definition: v_array.h:88
vw_ostream trace_message
Definition: global_data.h:424
size_t num_features
Definition: example.h:67
virtual bool was_supplied(const std::string &key)=0
virtual std::string getType()=0
uint64_t num_subsample
Definition: oaa.cc:22
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
void finish_example(vw &, example &)
Definition: parser.cc:881
uint64_t k
Definition: oaa.cc:19
vw * all
Definition: oaa.cc:20
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
option_group_definition & add(T &&op)
Definition: options.h:90
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
polylabel l
Definition: example.h:57
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
features * passthrough
Definition: example.h:74
uint32_t * subsample_order
Definition: oaa.cc:23
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
uint32_t getK()
Definition: global_data.h:106
void multipredict(E &ec, size_t lo, size_t count, polyprediction *pred, bool finalize_predictions)
Definition: learner.h:178
uint64_t get(substring &s)
Definition: global_data.h:108
~oaa()
Definition: oaa.cc:26
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
polyprediction pred
Definition: example.h:60
void update(E &ec, size_t i=0)
Definition: learner.h:222
void learn(E &ec, size_t i=0)
Definition: learner.h:160
size_t subsample_id
Definition: oaa.cc:24
void learn_randomized(oaa &o, LEARNER::single_learner &base, example &ec)
Definition: oaa.cc:33
float weight
Definition: example.h:62
v_array< float > scalars
Definition: example.h:46
#define THROW(args)
Definition: vw_exception.h:181
float f
Definition: cache.cc:40
bool test_only
Definition: example.h:76