Vowpal Wabbit
classweight.cc
Go to the documentation of this file.
1 #include <unordered_map>
2 #include "reductions.h"
3 
4 using namespace VW::config;
5 
6 namespace CLASSWEIGHTS
7 {
9 {
10  std::unordered_map<uint32_t, float> weights;
11 
12  void load_string(std::string const& source)
13  {
14  std::stringstream ss(source);
15  std::string item;
16  while (std::getline(ss, item, ','))
17  {
18  std::stringstream inner_ss(item);
19  std::string klass;
20  std::string weight;
21  std::getline(inner_ss, klass, ':');
22  std::getline(inner_ss, weight, ':');
23 
24  if (!klass.size() || !weight.size())
25  {
26  THROW("error: while parsing --classweight " << item);
27  }
28 
29  int klass_int = std::stoi(klass);
30  float weight_double = std::stof(weight);
31 
32  weights[klass_int] = weight_double;
33  }
34  }
35 
36  float get_class_weight(uint32_t klass)
37  {
38  auto got = weights.find(klass);
39  if (got == weights.end())
40  return 1.0f;
41  else
42  return got->second;
43  }
44 };
45 
46 template <bool is_learn, int pred_type>
48 {
49  switch (pred_type)
50  {
52  ec.weight *= cweights.get_class_weight((uint32_t)ec.l.simple.label);
53  break;
55  ec.weight *= cweights.get_class_weight(ec.l.multi.label);
56  break;
57  default:
58  // suppress the warning
59  break;
60  }
61 
62  if (is_learn)
63  base.learn(ec);
64  else
65  base.predict(ec);
66 }
67 } // namespace CLASSWEIGHTS
68 
69 using namespace CLASSWEIGHTS;
70 
72 {
73  std::vector<std::string> classweight_array;
74  auto cweights = scoped_calloc_or_throw<classweights>();
75  option_group_definition new_options("importance weight classes");
76  new_options.add(make_option("classweight", classweight_array).help("importance weight multiplier for class"));
77  options.add_and_parse(new_options);
78 
79  if (!options.was_supplied("classweight"))
80  return nullptr;
81 
82  for (auto& s : classweight_array) cweights->load_string(s);
83 
84  if (!all.quiet)
85  all.trace_message << "parsed " << cweights->weights.size() << " class weights" << std::endl;
86 
87  LEARNER::single_learner* base = as_singleline(setup_base(options, all));
88 
91  ret = &LEARNER::init_learner<classweights>(cweights, base, predict_or_learn<true, prediction_type::scalar>,
92  predict_or_learn<false, prediction_type::scalar>);
93  else if (base->pred_type == prediction_type::multiclass)
94  ret = &LEARNER::init_learner<classweights>(cweights, base, predict_or_learn<true, prediction_type::multiclass>,
95  predict_or_learn<false, prediction_type::multiclass>);
96  else
97  THROW("--classweight not implemented for this type of prediction");
98  return make_base(*ret);
99 }
void predict(E &ec, size_t i=0)
Definition: learner.h:169
LEARNER::base_learner * classweight_setup(options_i &options, vw &all)
Definition: classweight.cc:71
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
bool quiet
Definition: global_data.h:487
void load_string(std::string const &source)
Definition: classweight.cc:12
virtual void add_and_parse(const option_group_definition &group)=0
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
std::unordered_map< uint32_t, float > weights
Definition: classweight.cc:10
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
MULTICLASS::label_t multi
Definition: example.h:29
float get_class_weight(uint32_t klass)
Definition: classweight.cc:36
vw_ostream trace_message
Definition: global_data.h:424
virtual bool was_supplied(const std::string &key)=0
float weight
polylabel l
Definition: example.h:57
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
prediction_type::prediction_type_t pred_type
Definition: learner.h:149
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
void learn(E &ec, size_t i=0)
Definition: learner.h:160
static void predict_or_learn(classweights &cweights, LEARNER::single_learner &base, example &ec)
Definition: classweight.cc:47
float weight
Definition: example.h:62
#define THROW(args)
Definition: vw_exception.h:181
float f
Definition: cache.cc:40