Vowpal Wabbit
multilabel_oaa.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 #include <sstream>
7 #include <float.h>
8 #include "reductions.h"
9 #include "vw.h"
10 
11 using namespace VW::config;
12 
/// State carried by the multilabel one-against-all reduction.
/// `k` is the label count supplied via --multilabel_oaa; prediction and
/// training walk label indices 0 .. k-1, using one base-learner slot per index.
struct multi_oaa { size_t k; };
17 
18 template <bool is_learn>
20 {
23  preds.label_v.clear();
24 
25  ec.l.simple = {FLT_MAX, 1.f, 0.f};
26  uint32_t multilabel_index = 0;
27  for (uint32_t i = 0; i < o.k; i++)
28  {
29  if (is_learn)
30  {
31  ec.l.simple.label = -1.f;
32  if (multilabels.label_v.size() > multilabel_index && multilabels.label_v[multilabel_index] == i)
33  {
34  ec.l.simple.label = 1.f;
35  multilabel_index++;
36  }
37  base.learn(ec, i);
38  }
39  else
40  base.predict(ec, i);
41  if (ec.pred.scalar > 0.)
42  preds.label_v.push_back(i);
43  }
44  if (is_learn && multilabel_index < multilabels.label_v.size())
45  std::cout << "label " << multilabels.label_v[multilabel_index] << " is not in {0," << o.k - 1
46  << "} This won't work right." << std::endl;
47 
48  ec.pred.multilabels = preds;
49  ec.l.multilabels = multilabels;
50 }
51 
53 {
55  VW::finish_example(all, ec);
56 }
57 
59 {
60  auto data = scoped_calloc_or_throw<multi_oaa>();
61  option_group_definition new_options("Multilabel One Against All");
62  new_options.add(make_option("multilabel_oaa", data->k).keep().help("One-against-all multilabel with <k> labels"));
63  options.add_and_parse(new_options);
64 
65  if (!options.was_supplied("multilabel_oaa"))
66  return nullptr;
67 
69  predict_or_learn<true>, predict_or_learn<false>, data->k, prediction_type::multilabels);
74 
75  return make_base(l);
76 }
void predict(E &ec, size_t i=0)
Definition: learner.h:169
LEARNER::base_learner * multilabel_oaa_setup(options_i &options, vw &all)
void(* delete_prediction)(void *)
Definition: global_data.h:485
float scalar
Definition: example.h:45
void output_example(vw &all, example &ec)
Definition: multilabel.cc:140
void(* delete_label)(void *)
Definition: label_parser.h:16
label_type::label_type_t label_type
Definition: global_data.h:550
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
size_t size() const
Definition: v_array.h:68
void predict_or_learn(multi_oaa &o, LEARNER::single_learner &base, example &ec)
parser * p
Definition: global_data.h:377
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void push_back(const T &new_ele)
Definition: v_array.h:107
void clear()
Definition: v_array.h:88
virtual bool was_supplied(const std::string &key)=0
void finish_example(vw &, example &)
Definition: parser.cc:881
polylabel l
Definition: example.h:57
MULTILABEL::labels multilabels
Definition: example.h:50
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
MULTILABEL::labels multilabels
Definition: example.h:34
label_parser multilabel
Definition: multilabel.cc:118
v_array< uint32_t > label_v
Definition: multilabel.h:16
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
label_parser lp
Definition: parser.h:102