Vowpal Wabbit
Functions
oaa.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner * oaa_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ oaa_setup()

LEARNER::base_learner* oaa_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 203 of file oaa.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), vw::delete_prediction, delete_scalars(), vw::get_random_state(), namedlabels::getK(), loss_function::getType(), LEARNER::init_multiclass_learner(), shared_data::ldict, learn_randomized(), vw::loss, LEARNER::make_base(), VW::config::make_option(), prediction_type::multiclass, oaa::num_subsample, vw::p, vw::raw_prediction, shared_data::report_multiclass_log_loss, prediction_type::scalars, vw::sd, LEARNER::learner< T, E >::set_finish_example(), LEARNER::learner< T, E >::set_learn(), setup_base(), THROW, vw::trace_message, and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

204 {
205  auto data = scoped_calloc_or_throw<oaa>();
206  bool probabilities = false;
207  bool scores = false;
208  option_group_definition new_options("One Against All Options");
209  new_options.add(make_option("oaa", data->k).keep().help("One-against-all multiclass with <k> labels"))
210  .add(make_option("oaa_subsample", data->num_subsample)
211  .help("subsample this number of negative examples when learning"))
212  .add(make_option("probabilities", probabilities).help("predict probabilites of all classes"))
213  .add(make_option("scores", scores).help("output raw scores per class"));
214  options.add_and_parse(new_options);
215 
216  if (!options.was_supplied("oaa"))
217  return nullptr;
218 
219  if (all.sd->ldict && (data->k != all.sd->ldict->getK()))
220  THROW("error: you have " << all.sd->ldict->getK() << " named labels; use that as the argument to oaa")
221 
222  data->all = &all;
223  data->pred = calloc_or_throw<polyprediction>(data->k);
224  data->subsample_order = nullptr;
225  data->subsample_id = 0;
226  if (data->num_subsample > 0)
227  {
228  if (data->num_subsample >= data->k)
229  {
230  data->num_subsample = 0;
231  all.trace_message << "oaa is turning off subsampling because your parameter >= K" << std::endl;
232  }
233  else
234  {
235  data->subsample_order = calloc_or_throw<uint32_t>(data->k);
236  for (size_t i = 0; i < data->k; i++) data->subsample_order[i] = (uint32_t)i;
237  for (size_t i = 0; i < data->k; i++)
238  {
239  size_t j = (size_t)(all.get_random_state()->get_and_update_random() * (float)(data->k - i)) + i;
240  uint32_t tmp = data->subsample_order[i];
241  data->subsample_order[i] = data->subsample_order[j];
242  data->subsample_order[j] = tmp;
243  }
244  }
245  }
246 
247  oaa* data_ptr = data.get();
248  LEARNER::learner<oaa, example>* l;
249  auto base = as_singleline(setup_base(options, all));
250  if (probabilities || scores)
251  {
252  all.delete_prediction = delete_scalars;
253  if (probabilities)
254  {
255  auto loss_function_type = all.loss->getType();
256  if (loss_function_type != "logistic")
257  all.trace_message << "WARNING: --probabilities should be used only with --loss_function=logistic" << std::endl;
258  // the four boolean template parameters are: is_learn, print_all, scores and probabilities
259  l = &LEARNER::init_multiclass_learner(data, base, predict_or_learn<true, false, true, true>,
260  predict_or_learn<false, false, true, true>, all.p, data->k, prediction_type::scalars);
261  all.sd->report_multiclass_log_loss = true;
262  l->set_finish_example(finish_example_scores<true>);
263  }
264  else
265  {
266  l = &LEARNER::init_multiclass_learner(data, base, predict_or_learn<true, false, true, false>,
267  predict_or_learn<false, false, true, false>, all.p, data->k, prediction_type::scalars);
268  l->set_finish_example(finish_example_scores<false>);
269  }
270  }
271  else if (all.raw_prediction > 0)
272  l = &LEARNER::init_multiclass_learner(data, base, predict_or_learn<true, true, false, false>,
273  predict_or_learn<false, true, false, false>, all.p, data->k, prediction_type::multiclass);
274  else
275  l = &LEARNER::init_multiclass_learner(data, base, predict_or_learn<true, false, false, false>,
276  predict_or_learn<false, false, false, false>, all.p, data->k, prediction_type::multiclass);
277 
278  if (data_ptr->num_subsample > 0)
279  {
280  l->set_learn(learn_randomized);
281  l->set_finish_example(MULTICLASS::finish_example_without_loss<oaa>);
282  }
283 
284  return make_base(*l);
285 }
oaa
Definition: oaa.cc:17
bool report_multiclass_log_loss
Definition: global_data.h:166
int raw_prediction
Definition: global_data.h:519
loss_function * loss
Definition: global_data.h:523
void(* delete_prediction)(void *)
Definition: global_data.h:485
void set_learn(void(*u)(T &, L &, E &))
Definition: learner.h:212
namedlabels * ldict
Definition: global_data.h:153
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
parser * p
Definition: global_data.h:377
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
shared_data * sd
Definition: global_data.h:375
void delete_scalars(void *v)
Definition: example.h:37
vw_ostream trace_message
Definition: global_data.h:424
virtual bool was_supplied(const std::string &key)=0
virtual std::string getType()=0
uint64_t num_subsample
Definition: oaa.cc:22
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
Definition: learner.h:437
uint32_t getK()
Definition: global_data.h:106
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
void learn_randomized(oaa &o, LEARNER::single_learner &base, example &ec)
Definition: oaa.cc:33
#define THROW(args)
Definition: vw_exception.h:181