Vowpal Wabbit
Classes | Functions
oaa.cc File Reference
#include <sstream>
#include <cfloat>
#include <cmath>
#include "correctedMath.h"
#include "reductions.h"
#include "rand48.h"
#include "vw_exception.h"
#include "vw.h"

Go to the source code of this file.

Classes

struct  oaa
 

Functions

void learn_randomized (oaa &o, LEARNER::single_learner &base, example &ec)
 
template<bool is_learn, bool print_all, bool scores, bool probabilities>
void predict_or_learn (oaa &o, LEARNER::single_learner &base, example &ec)
 
template<bool probabilities>
void finish_example_scores (vw &all, oaa &o, example &ec)
 
LEARNER::base_learner * oaa_setup (options_i &options, vw &all)
 

Function Documentation

◆ finish_example_scores()

template<bool probabilities>
void finish_example_scores ( vw & all,
oaa & o,
example & ec 
)

Definition at line 136 of file oaa.cc.

References substring::begin, substring::end, vw::final_prediction_sink, VW::finish_example(), namedlabels::get(), shared_data::holdout_multiclass_log_loss, oaa::k, example::l, MULTICLASS::label_t::label, shared_data::ldict, polylabel::multi, shared_data::multiclass_log_loss, example::num_features, example::pred, vw::print_text, MULTICLASS::print_update_with_probability(), MULTICLASS::print_update_with_score(), polyprediction::scalars, vw::sd, example::tag, example::test_only, shared_data::update(), and example::weight.

137 {
138  // === Compute multiclass_log_loss
139  // TODO:
140  // What to do if the correct label is unknown, i.e. (uint32_t)-1?
141  // Suggestion: increase all.sd->weighted_unlabeled_examples???,
142  // but not sd.example_number, so the average loss is not influenced.
143  // What to do if the correct_class_prob==0?
144  // Suggestion: have some maximal multiclass_log_loss limit, e.g. 999.
145  float multiclass_log_loss = 999; // -log(0) = plus infinity
146  float correct_class_prob = 0;
147  if (probabilities)
148  {
149  if (ec.l.multi.label <= o.k) // prevent segmentation fault if labeĺ==(uint32_t)-1
150  correct_class_prob = ec.pred.scalars[ec.l.multi.label - 1];
151  if (correct_class_prob > 0)
152  multiclass_log_loss = -log(correct_class_prob) * ec.weight;
153  if (ec.test_only)
154  all.sd->holdout_multiclass_log_loss += multiclass_log_loss;
155  else
156  all.sd->multiclass_log_loss += multiclass_log_loss;
157  }
158  // === Compute `prediction` and zero_one_loss
159  // We have already computed `prediction` in predict_or_learn,
160  // but we cannot store it in ec.pred union because we store ec.pred.probs there.
161  uint32_t prediction = 0;
162  for (uint32_t i = 1; i < o.k; i++)
163  if (ec.pred.scalars[i] > ec.pred.scalars[prediction])
164  prediction = i;
165  prediction++; // prediction is 1-based index (not 0-based)
166  float zero_one_loss = 0;
167  if (ec.l.multi.label != prediction)
168  zero_one_loss = ec.weight;
169 
170  // === Print probabilities for all classes
171  std::ostringstream outputStringStream;
172  for (uint32_t i = 0; i < o.k; i++)
173  {
174  if (i > 0)
175  outputStringStream << ' ';
176  if (all.sd->ldict)
177  {
178  substring ss = all.sd->ldict->get(i + 1);
179  outputStringStream << std::string(ss.begin, ss.end - ss.begin);
180  }
181  else
182  outputStringStream << i + 1;
183  outputStringStream << ':' << ec.pred.scalars[i];
184  }
185  for (int sink : all.final_prediction_sink) all.print_text(sink, outputStringStream.str(), ec.tag);
186 
187  // === Report updates using zero-one loss
188  all.sd->update(ec.test_only, ec.l.multi.label != (uint32_t)-1, zero_one_loss, ec.weight, ec.num_features);
189  // Alternatively, we could report multiclass_log_loss.
190  // all.sd->update(ec.test_only, multiclass_log_loss, ec.weight, ec.num_features);
191  // Even better would be to report both losses, but this would mean to increase
192  // the number of columns and this would not fit narrow screens.
193  // So let's report (average) multiclass_log_loss only in the final resume.
194 
195  // === Print progress report
196  if (probabilities)
197  MULTICLASS::print_update_with_probability(all, ec, prediction);
198  else
199  MULTICLASS::print_update_with_score(all, ec, prediction);
200  VW::finish_example(all, ec);
201 }
v_array< char > tag
Definition: example.h:63
char * end
Definition: hashstring.h:10
char * begin
Definition: hashstring.h:9
v_array< int > final_prediction_sink
Definition: global_data.h:518
namedlabels * ldict
Definition: global_data.h:153
double holdout_multiclass_log_loss
Definition: global_data.h:168
void print_update_with_probability(vw &all, example &ec, uint32_t pred)
Definition: multiclass.cc:149
void print_update_with_score(vw &all, example &ec, uint32_t pred)
Definition: multiclass.cc:153
double multiclass_log_loss
Definition: global_data.h:167
MULTICLASS::label_t multi
Definition: example.h:29
shared_data * sd
Definition: global_data.h:375
size_t num_features
Definition: example.h:67
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
void finish_example(vw &, example &)
Definition: parser.cc:881
uint64_t k
Definition: oaa.cc:19
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
polylabel l
Definition: example.h:57
uint64_t get(substring &s)
Definition: global_data.h:108
polyprediction pred
Definition: example.h:60
float weight
Definition: example.h:62
v_array< float > scalars
Definition: example.h:46
bool test_only
Definition: example.h:76

◆ learn_randomized()

void learn_randomized ( oaa & o,
LEARNER::single_learner & base,
example & ec 
)

Definition at line 33 of file oaa.cc.

References oaa::k, example::l, label_data::label, MULTICLASS::label_t::label, LEARNER::learner< T, E >::learn(), polylabel::multi, polyprediction::multiclass, oaa::num_subsample, example::partial_prediction, example::pred, polylabel::simple, oaa::subsample_id, oaa::subsample_order, and example::weight.

Referenced by oaa_setup().

34 {
    // Learn step used when --oaa_subsample is active.
    // It trains the true class's binary learner as a positive, then only
    // o.num_subsample other classes as negatives. It records the highest-scoring
    // class seen along the way in ec.pred.multiclass.
    // ec.l.multi and ec.weight are restored before returning. o.subsample_id is
    // advanced so the next call continues the cyclic walk over subsample_order.
35  MULTICLASS::label_t ld = ec.l.multi;
36  if (ld.label == 0 || (ld.label > o.k && ld.label != (uint32_t)-1))
37  std::cout << "label " << ld.label << " is not in {1," << o.k << "} This won't work right." << std::endl;
38 
    // Positive update for the true class (the base learner index is label - 1).
    // NOTE(review): for label 0 or (uint32_t)-1, `ld.label - 1` wraps around, so
    // base.learn receives an out-of-range index. Only a warning is printed above;
    // confirm such examples never reach this function.
39  ec.l.simple = {1., 0.f, 0.f}; // truth
40  base.learn(ec, ld.label - 1);
41 
    // The true class is the initial candidate. Only sampled classes can beat it,
    // so the result is an argmax over the sampled subset, not over all k classes.
42  size_t prediction = ld.label;
43  float best_partial_prediction = ec.partial_prediction;
44 
    // Negative updates: the example weight is scaled by k / num_subsample, presumably
    // to compensate for training only a subset of the negative classes.
45  ec.l.simple.label = -1.;
46  float weight_temp = ec.weight;
47  ec.weight *= ((float)o.k) / (float)o.num_subsample;
48  size_t p = o.subsample_id;
49  size_t count = 0;
    // Walk subsample_order cyclically, skipping the true class without counting it.
    // The loop terminates because oaa_setup disables subsampling whenever
    // num_subsample >= k.
50  while (count < o.num_subsample)
51  {
52  uint32_t l = o.subsample_order[p];
53  p = (p + 1) % o.k;
54  if (l == ld.label - 1)
55  continue;
56  base.learn(ec, l);
57  if (ec.partial_prediction > best_partial_prediction)
58  {
59  best_partial_prediction = ec.partial_prediction;
60  prediction = l + 1;
61  }
62  count++;
63  }
64  o.subsample_id = p;
65 
    // Publish the 1-based prediction and restore the caller's label and weight.
66  ec.pred.multiclass = (uint32_t)prediction;
67  ec.l.multi = ld;
68  ec.weight = weight_temp;
69 }
uint32_t multiclass
Definition: example.h:49
float partial_prediction
Definition: example.h:68
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
MULTICLASS::label_t multi
Definition: example.h:29
uint64_t num_subsample
Definition: oaa.cc:22
uint64_t k
Definition: oaa.cc:19
polylabel l
Definition: example.h:57
uint32_t * subsample_order
Definition: oaa.cc:23
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
size_t subsample_id
Definition: oaa.cc:24
float weight
Definition: example.h:62

◆ oaa_setup()

LEARNER::base_learner* oaa_setup ( options_i & options,
vw & all 
)

Definition at line 203 of file oaa.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), vw::delete_prediction, delete_scalars(), vw::get_random_state(), namedlabels::getK(), loss_function::getType(), LEARNER::init_multiclass_learner(), shared_data::ldict, learn_randomized(), vw::loss, LEARNER::make_base(), VW::config::make_option(), prediction_type::multiclass, oaa::num_subsample, vw::p, vw::raw_prediction, shared_data::report_multiclass_log_loss, prediction_type::scalars, vw::sd, LEARNER::learner< T, E >::set_finish_example(), LEARNER::learner< T, E >::set_learn(), setup_base(), THROW, vw::trace_message, and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

204 {
    // Sets up the one-against-all reduction:
    //  - parses --oaa, --oaa_subsample, --probabilities and --scores;
    //  - returns nullptr when --oaa was not supplied;
    //  - checks k against the number of named labels;
    //  - allocates the per-class prediction buffer and an optional shuffled
    //    subsample order;
    //  - stacks a multiclass learner on the base learner, choosing the
    //    predict_or_learn instantiation from the options.
    // NOTE(review): this listing omits source lines 248, 252 and 280. `l` is used
    // but never declared in view. The References list names vw::delete_prediction,
    // delete_scalars(), set_learn() and learn_randomized(), none of which appear
    // below. Presumably those lines declare `l`, install delete_scalars, and set
    // learn_randomized as the learn function. Confirm against oaa.cc.
205  auto data = scoped_calloc_or_throw<oaa>();
206  bool probabilities = false;
207  bool scores = false;
208  option_group_definition new_options("One Against All Options");
    // NOTE(review): the --probabilities help text below misspells "probabilities".
209  new_options.add(make_option("oaa", data->k).keep().help("One-against-all multiclass with <k> labels"))
210  .add(make_option("oaa_subsample", data->num_subsample)
211  .help("subsample this number of negative examples when learning"))
212  .add(make_option("probabilities", probabilities).help("predict probabilites of all classes"))
213  .add(make_option("scores", scores).help("output raw scores per class"));
214  options.add_and_parse(new_options);
215 
216  if (!options.was_supplied("oaa"))
217  return nullptr;
218 
219  if (all.sd->ldict && (data->k != all.sd->ldict->getK()))
220  THROW("error: you have " << all.sd->ldict->getK() << " named labels; use that as the argument to oaa")
221 
    // One polyprediction slot per class, filled by multipredict in predict_or_learn.
222  data->all = &all;
223  data->pred = calloc_or_throw<polyprediction>(data->k);
224  data->subsample_order = nullptr;
225  data->subsample_id = 0;
226  if (data->num_subsample > 0)
227  {
228  if (data->num_subsample >= data->k)
229  {
230  data->num_subsample = 0;
231  all.trace_message << "oaa is turning off subsampling because your parameter >= K" << std::endl;
232  }
233  else
234  {
    // Identity permutation 0..k-1, then an in-place shuffle. This looks like a
    // Fisher-Yates shuffle, assuming get_and_update_random() returns values in
    // [0, 1). A value of exactly 1.0 would make j == k, which is out of range.
235  data->subsample_order = calloc_or_throw<uint32_t>(data->k);
236  for (size_t i = 0; i < data->k; i++) data->subsample_order[i] = (uint32_t)i;
237  for (size_t i = 0; i < data->k; i++)
238  {
239  size_t j = (size_t)(all.get_random_state()->get_and_update_random() * (float)(data->k - i)) + i;
240  uint32_t tmp = data->subsample_order[i];
241  data->subsample_order[i] = data->subsample_order[j];
242  data->subsample_order[j] = tmp;
243  }
244  }
245  }
246 
    // A raw pointer is kept because `data` is handed to init_multiclass_learner
    // below. Presumably ownership moves there; confirm.
247  oaa* data_ptr = data.get();
249  auto base = as_singleline(setup_base(options, all));
250  if (probabilities || scores)
251  {
253  if (probabilities)
254  {
255  auto loss_function_type = all.loss->getType();
256  if (loss_function_type != "logistic")
257  all.trace_message << "WARNING: --probabilities should be used only with --loss_function=logistic" << std::endl;
258  // the four boolean template parameters are: is_learn, print_all, scores and probabilities
259  l = &LEARNER::init_multiclass_learner(data, base, predict_or_learn<true, false, true, true>,
260  predict_or_learn<false, false, true, true>, all.p, data->k, prediction_type::scalars);
261  all.sd->report_multiclass_log_loss = true;
262  l->set_finish_example(finish_example_scores<true>);
263  }
264  else
265  {
266  l = &LEARNER::init_multiclass_learner(data, base, predict_or_learn<true, false, true, false>,
267  predict_or_learn<false, false, true, false>, all.p, data->k, prediction_type::scalars);
268  l->set_finish_example(finish_example_scores<false>);
269  }
270  }
271  else if (all.raw_prediction > 0)
272  l = &LEARNER::init_multiclass_learner(data, base, predict_or_learn<true, true, false, false>,
273  predict_or_learn<false, true, false, false>, all.p, data->k, prediction_type::multiclass);
274  else
275  l = &LEARNER::init_multiclass_learner(data, base, predict_or_learn<true, false, false, false>,
276  predict_or_learn<false, false, false, false>, all.p, data->k, prediction_type::multiclass);
277 
    // NOTE(review): with subsampling enabled, this overrides any finish_example_scores
    // installed above. Under --probabilities/--scores, predict still fills
    // ec.pred.scalars. Confirm whether combining subsampling with those options is
    // intended or supported.
278  if (data_ptr->num_subsample > 0)
279  {
281  l->set_finish_example(MULTICLASS::finish_example_without_loss<oaa>);
282  }
283 
284  return make_base(*l);
285 }
Definition: oaa.cc:17
bool report_multiclass_log_loss
Definition: global_data.h:166
int raw_prediction
Definition: global_data.h:519
loss_function * loss
Definition: global_data.h:523
void(* delete_prediction)(void *)
Definition: global_data.h:485
void set_learn(void(*u)(T &, L &, E &))
Definition: learner.h:212
namedlabels * ldict
Definition: global_data.h:153
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
parser * p
Definition: global_data.h:377
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
shared_data * sd
Definition: global_data.h:375
void delete_scalars(void *v)
Definition: example.h:37
vw_ostream trace_message
Definition: global_data.h:424
virtual bool was_supplied(const std::string &key)=0
virtual std::string getType()=0
uint64_t num_subsample
Definition: oaa.cc:22
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
uint32_t getK()
Definition: global_data.h:106
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
void learn_randomized(oaa &o, LEARNER::single_learner &base, example &ec)
Definition: oaa.cc:33
#define THROW(args)
Definition: vw_exception.h:181

◆ predict_or_learn()

template<bool is_learn, bool print_all, bool scores, bool probabilities>
void predict_or_learn ( oaa & o,
LEARNER::single_learner & base,
example & ec 
)

Definition at line 72 of file oaa.cc.

References add_passthrough_feature, oaa::all, v_array< T >::clear(), correctedExp, f, oaa::k, example::l, MULTICLASS::label_t::label, polylabel::multi, polyprediction::multiclass, LEARNER::learner< T, E >::multipredict(), example::passthrough, oaa::pred, example::pred, vw::print_text, v_array< T >::push_back(), vw::raw_prediction, polyprediction::scalar, polyprediction::scalars, polylabel::simple, example::tag, and LEARNER::learner< T, E >::update().

73 {
    // Shared predict/learn step for the non-subsampled OAA reduction.
    //  1. Scores all k classes with one multipredict call (base offsets 0..k-1).
    //  2. Picks the 1-based argmax; ties keep the lowest index because the
    //     comparison is strict.
    //  3. Optionally emits passthrough features.
    //  4. When is_learn, calls base.update once per class (+1 for the true class,
    //     -1 otherwise), reusing the cached score.
    //  5. Optionally prints raw scores (print_all).
    // It then stores either per-class scores/probabilities (scores) or the class
    // index (otherwise), and restores the multiclass label.
74  MULTICLASS::label_t mc_label_data = ec.l.multi;
75  if (mc_label_data.label == 0 || (mc_label_data.label > o.k && mc_label_data.label != (uint32_t)-1))
76  std::cout << "label " << mc_label_data.label << " is not in {1," << o.k << "} This won't work right." << std::endl;
77 
    // NOTE(review): this stream is constructed even when print_all is false.
78  std::stringstream outputStringStream;
79  uint32_t prediction = 1;
80  v_array<float> scores_array;
    // Appears to take over the buffer already held in ec.pred.scalars so that it can
    // be cleared and refilled below instead of allocated anew. Confirm v_array copy
    // semantics.
81  if (scores)
82  scores_array = ec.pred.scalars;
83 
    // Presumably FLT_MAX marks "no label" for the base learner while predicting. Confirm.
84  ec.l.simple = {FLT_MAX, 0.f, 0.f};
85  base.multipredict(ec, 0, o.k, o.pred, true);
86  for (uint32_t i = 2; i <= o.k; i++)
87  if (o.pred[i - 1].scalar > o.pred[prediction - 1].scalar)
88  prediction = i;
89 
90  if (ec.passthrough)
91  for (uint32_t i = 1; i <= o.k; i++) add_passthrough_feature(ec, i, o.pred[i - 1].scalar);
92 
93  if (is_learn)
94  {
    // If the label is 0 or (uint32_t)-1, no class matches, so every class is
    // updated with -1.
95  for (uint32_t i = 1; i <= o.k; i++)
96  {
97  ec.l.simple = {(mc_label_data.label == i) ? 1.f : -1.f, 0.f, 0.f};
98  ec.pred.scalar = o.pred[i - 1].scalar;
99  base.update(ec, i - 1);
100  }
101  }
102 
103  if (print_all)
104  {
105  outputStringStream << "1:" << o.pred[0].scalar;
106  for (uint32_t i = 2; i <= o.k; i++) outputStringStream << ' ' << i << ':' << o.pred[i - 1].scalar;
107  o.all->print_text(o.all->raw_prediction, outputStringStream.str(), ec.tag);
108  }
109 
110  if (scores)
111  {
112  scores_array.clear();
113  for (uint32_t i = 0; i < o.k; i++) scores_array.push_back(o.pred[i].scalar);
114  ec.pred.scalars = scores_array;
115 
    // Apply a logistic sigmoid to each raw score, then rescale so that the k
    // values sum to 1.
116  if (probabilities)
117  {
118  float sum_prob = 0;
119  for (uint32_t i = 0; i < o.k; i++)
120  {
121  ec.pred.scalars[i] = 1.f / (1.f + correctedExp(-o.pred[i].scalar));
122  sum_prob += ec.pred.scalars[i];
123  }
124  float inv_sum_prob = 1.f / sum_prob;
125  for (uint32_t i = 0; i < o.k; i++) ec.pred.scalars[i] *= inv_sum_prob;
126  }
127  }
128  else
129  ec.pred.multiclass = prediction;
130 
131  ec.l.multi = mc_label_data;
132 }
v_array< char > tag
Definition: example.h:63
int raw_prediction
Definition: global_data.h:519
#define correctedExp
Definition: correctedMath.h:27
uint32_t multiclass
Definition: example.h:49
float scalar
Definition: example.h:45
label_data simple
Definition: example.h:28
#define add_passthrough_feature(ec, i, x)
Definition: example.h:119
MULTICLASS::label_t multi
Definition: example.h:29
void push_back(const T &new_ele)
Definition: v_array.h:107
polyprediction * pred
Definition: oaa.cc:21
void clear()
Definition: v_array.h:88
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
uint64_t k
Definition: oaa.cc:19
vw * all
Definition: oaa.cc:20
polylabel l
Definition: example.h:57
features * passthrough
Definition: example.h:74
void multipredict(E &ec, size_t lo, size_t count, polyprediction *pred, bool finalize_predictions)
Definition: learner.h:178
polyprediction pred
Definition: example.h:60
void update(E &ec, size_t i=0)
Definition: learner.h:222
v_array< float > scalars
Definition: example.h:46
float f
Definition: cache.cc:40