Vowpal Wabbit
scorer.cc
Go to the documentation of this file.
1 #include <cfloat>
2 #include "correctedMath.h"
3 #include "reductions.h"
4 #include "vw_exception.h"
5 
6 using namespace VW::config;
7 
8 struct scorer
9 {
10  vw* all;
11 }; // for set_minmax, loss
12 
13 template <bool is_learn, float (*link)(float in)>
15 {
16  s.all->set_minmax(s.all->sd, ec.l.simple.label);
17  if (is_learn && ec.l.simple.label != FLT_MAX && ec.weight > 0)
18  base.learn(ec);
19  else
20  base.predict(ec);
21 
22  if (ec.weight > 0 && ec.l.simple.label != FLT_MAX)
23  ec.loss = s.all->loss->getLoss(s.all->sd, ec.pred.scalar, ec.l.simple.label) * ec.weight;
24 
25  ec.pred.scalar = link(ec.pred.scalar);
26 }
27 
28 template <float (*link)(float in)>
29 inline void multipredict(scorer&, LEARNER::single_learner& base, example& ec, size_t count, size_t,
30  polyprediction* pred, bool finalize_predictions)
31 {
32  base.multipredict(ec, 0, count, pred, finalize_predictions); // TODO: need to thread step through???
33  for (size_t c = 0; c < count; c++) pred[c].scalar = link(pred[c].scalar);
34 }
35 
37 {
38  s.all->set_minmax(s.all->sd, ec.l.simple.label);
39  base.update(ec);
40 }
41 
42 // y = f(x) -> [0, 1]
43 inline float logistic(float in) { return 1.f / (1.f + correctedExp(-in)); }
44 
45 // http://en.wikipedia.org/wiki/Generalized_logistic_curve
46 // where the lower & upper asymptotes are -1 & 1 respectively
47 // 'glf1' stands for 'Generalized Logistic Function with [-1,1] range'
48 // y = f(x) -> [-1, 1]
49 inline float glf1(float in) { return 2.f / (1.f + correctedExp(-in)) - 1.f; }
50 
// Identity link: returns its argument unchanged.
inline float id(float in)
{
  return in;
}
52 
54 {
55  auto s = scoped_calloc_or_throw<scorer>();
56  std::string link;
57  option_group_definition new_options("scorer options");
58  new_options.add(make_option("link", link)
59  .default_value("identity")
60  .keep()
61  .help("Specify the link function: identity, logistic, glf1 or poisson"));
62  options.add_and_parse(new_options);
63 
64  // This always returns a base_learner.
65 
66  s->all = &all;
67 
68  auto base = as_singleline(setup_base(options, all));
70  void (*multipredict_f)(scorer&, LEARNER::single_learner&, example&, size_t, size_t, polyprediction*, bool) =
71  multipredict<id>;
72 
73  if (link == "identity")
74  l = &init_learner(s, base, predict_or_learn<true, id>, predict_or_learn<false, id>);
75  else if (link == "logistic")
76  {
77  l = &init_learner(s, base, predict_or_learn<true, logistic>, predict_or_learn<false, logistic>);
78  multipredict_f = multipredict<logistic>;
79  }
80  else if (link == "glf1")
81  {
82  l = &init_learner(s, base, predict_or_learn<true, glf1>, predict_or_learn<false, glf1>);
83  multipredict_f = multipredict<glf1>;
84  }
85  else if (link == "poisson")
86  {
87  l = &init_learner(s, base, predict_or_learn<true, expf>, predict_or_learn<false, expf>);
88  multipredict_f = multipredict<expf>;
89  }
90  else
91  THROW("Unknown link function: " << link);
92 
93  l->set_multipredict(multipredict_f);
94  l->set_update(update);
96 
97  return make_base(*all.scorer);
98 }
void set_multipredict(void(*u)(T &, L &, E &, size_t, size_t, polyprediction *, bool))
Definition: learner.h:217
#define correctedExp
Definition: correctedMath.h:27
void set_update(void(*u)(T &data, L &base, E &))
Definition: learner.h:231
loss_function * loss
Definition: global_data.h:523
scorer
Definition: scorer.cc:8
void predict(E &ec, size_t i=0)
Definition: learner.h:169
LEARNER::base_learner * scorer_setup(options_i &options, vw &all)
Definition: scorer.cc:53
float scalar
Definition: example.h:45
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void(* set_minmax)(shared_data *sd, float label)
Definition: global_data.h:394
virtual float getLoss(shared_data *, float prediction, float label)=0
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
shared_data * sd
Definition: global_data.h:375
float id(float in)
Definition: scorer.cc:51
void multipredict(scorer &, LEARNER::single_learner &base, example &ec, size_t count, size_t, polyprediction *pred, bool finalize_predictions)
Definition: scorer.cc:29
void update(scorer &s, LEARNER::single_learner &base, example &ec)
Definition: scorer.cc:36
float glf1(float in)
Definition: scorer.cc:49
LEARNER::single_learner * scorer
Definition: global_data.h:384
float loss
Definition: example.h:70
option_group_definition & add(T &&op)
Definition: options.h:90
polylabel l
Definition: example.h:57
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
vw * all
Definition: scorer.cc:10
void multipredict(E &ec, size_t lo, size_t count, polyprediction *pred, bool finalize_predictions)
Definition: learner.h:178
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
void predict_or_learn(scorer &s, LEARNER::single_learner &base, example &ec)
Definition: scorer.cc:14
polyprediction pred
Definition: example.h:60
void update(E &ec, size_t i=0)
Definition: learner.h:222
void learn(E &ec, size_t i=0)
Definition: learner.h:160
float weight
Definition: example.h:62
#define THROW(args)
Definition: vw_exception.h:181
constexpr uint64_t c
Definition: rand48.cc:12
float f
Definition: cache.cc:40
float logistic(float in)
Definition: scorer.cc:43