Vowpal Wabbit
Classes | Functions
scorer.cc File Reference
#include <cfloat>
#include "correctedMath.h"
#include "reductions.h"
#include "vw_exception.h"

Go to the source code of this file.

Classes

struct  scorer
 

Functions

template<bool is_learn, float (*link)(float in)>
void predict_or_learn (scorer &s, LEARNER::single_learner &base, example &ec)
 
template<float (*link)(float in)>
void multipredict (scorer &, LEARNER::single_learner &base, example &ec, size_t count, size_t, polyprediction *pred, bool finalize_predictions)
 
void update (scorer &s, LEARNER::single_learner &base, example &ec)
 
float logistic (float in)
 
float glf1 (float in)
 
float id (float in)
 
LEARNER::base_learner* scorer_setup (options_i &options, vw &all)
 

Function Documentation

◆ glf1()

float glf1 ( float  in)
inline

Definition at line 49 of file scorer.cc.

References correctedExp.

49 { return 2.f / (1.f + correctedExp(-in)) - 1.f; }
#define correctedExp
Definition: correctedMath.h:27
float f
Definition: cache.cc:40

◆ id()

float id ( float  in)
inline

◆ logistic()

float logistic ( float  in)
inline

Definition at line 43 of file scorer.cc.

References correctedExp.

43 { return 1.f / (1.f + correctedExp(-in)); }
#define correctedExp
Definition: correctedMath.h:27

◆ multipredict()

template<float (*link)(float in)>
void multipredict ( scorer & ,
LEARNER::single_learner & base,
example & ec,
size_t  count,
size_t  ,
polyprediction * pred,
bool  finalize_predictions 
)
inline

Definition at line 29 of file scorer.cc.

References LEARNER::learner< T, E >::multipredict(), and polyprediction::scalar.

31 {
32  base.multipredict(ec, 0, count, pred, finalize_predictions); // TODO: need to thread step through???
33  for (size_t c = 0; c < count; c++) pred[c].scalar = link(pred[c].scalar);
34 }
void multipredict(E &ec, size_t lo, size_t count, polyprediction *pred, bool finalize_predictions)
Definition: learner.h:178
constexpr uint64_t c
Definition: rand48.cc:12

◆ predict_or_learn()

template<bool is_learn, float (*link)(float in)>
void predict_or_learn ( scorer & s,
LEARNER::single_learner & base,
example & ec 
)

Definition at line 14 of file scorer.cc.

References scorer::all, loss_function::getLoss(), example::l, label_data::label, LEARNER::learner< T, E >::learn(), example::loss, vw::loss, example::pred, LEARNER::learner< T, E >::predict(), polyprediction::scalar, vw::sd, vw::set_minmax, polylabel::simple, and example::weight.

15 {
16  s.all->set_minmax(s.all->sd, ec.l.simple.label);
17  if (is_learn && ec.l.simple.label != FLT_MAX && ec.weight > 0)
18  base.learn(ec);
19  else
20  base.predict(ec);
21 
22  if (ec.weight > 0 && ec.l.simple.label != FLT_MAX)
23  ec.loss = s.all->loss->getLoss(s.all->sd, ec.pred.scalar, ec.l.simple.label) * ec.weight;
24 
25  ec.pred.scalar = link(ec.pred.scalar);
26 }
loss_function * loss
Definition: global_data.h:523
void predict(E &ec, size_t i=0)
Definition: learner.h:169
float scalar
Definition: example.h:45
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
void(* set_minmax)(shared_data *sd, float label)
Definition: global_data.h:394
virtual float getLoss(shared_data *, float prediction, float label)=0
shared_data * sd
Definition: global_data.h:375
float loss
Definition: example.h:70
polylabel l
Definition: example.h:57
vw * all
Definition: scorer.cc:10
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
float weight
Definition: example.h:62

◆ scorer_setup()

LEARNER::base_learner* scorer_setup ( options_i & options,
vw & all 
)

Definition at line 53 of file scorer.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), LEARNER::init_learner(), LEARNER::make_base(), VW::config::make_option(), vw::scorer, LEARNER::learner< T, E >::set_multipredict(), LEARNER::learner< T, E >::set_update(), setup_base(), THROW, and update().

Referenced by parse_reductions().

54 {
55  auto s = scoped_calloc_or_throw<scorer>();
56  std::string link;
57  option_group_definition new_options("scorer options");
58  new_options.add(make_option("link", link)
59  .default_value("identity")
60  .keep()
61  .help("Specify the link function: identity, logistic, glf1 or poisson"));
62  options.add_and_parse(new_options);
63 
64  // This always returns a base_learner.
65 
66  s->all = &all;
67 
68  auto base = as_singleline(setup_base(options, all));
69  LEARNER::learner<scorer, example>* l;
70  void (*multipredict_f)(scorer&, LEARNER::single_learner&, example&, size_t, size_t, polyprediction*, bool) =
71  multipredict<id>;
72 
73  if (link == "identity")
74  l = &init_learner(s, base, predict_or_learn<true, id>, predict_or_learn<false, id>);
75  else if (link == "logistic")
76  {
77  l = &init_learner(s, base, predict_or_learn<true, logistic>, predict_or_learn<false, logistic>);
78  multipredict_f = multipredict<logistic>;
79  }
80  else if (link == "glf1")
81  {
82  l = &init_learner(s, base, predict_or_learn<true, glf1>, predict_or_learn<false, glf1>);
83  multipredict_f = multipredict<glf1>;
84  }
85  else if (link == "poisson")
86  {
87  l = &init_learner(s, base, predict_or_learn<true, expf>, predict_or_learn<false, expf>);
88  multipredict_f = multipredict<expf>;
89  }
90  else
91  THROW("Unknown link function: " << link);
92 
93  l->set_multipredict(multipredict_f);
94  l->set_update(update);
95  all.scorer = LEARNER::as_singleline(l);
96 
97  return make_base(*all.scorer);
98 }
void set_multipredict(void(*u)(T &, L &, E &, size_t, size_t, polyprediction *, bool))
Definition: learner.h:217
void set_update(void(*u)(T &data, L &base, E &))
Definition: learner.h:231
Definition: scorer.cc:8
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void update(scorer &s, LEARNER::single_learner &base, example &ec)
Definition: scorer.cc:36
LEARNER::single_learner * scorer
Definition: global_data.h:384
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define THROW(args)
Definition: vw_exception.h:181

◆ update()

void update ( scorer & s,
LEARNER::single_learner & base,
example & ec 
)

Definition at line 36 of file scorer.cc.

References scorer::all, example::l, label_data::label, vw::sd, vw::set_minmax, polylabel::simple, and LEARNER::learner< T, E >::update().

Referenced by scorer_setup().

37 {
38  s.all->set_minmax(s.all->sd, ec.l.simple.label);
39  base.update(ec);
40 }
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
void(* set_minmax)(shared_data *sd, float label)
Definition: global_data.h:394
shared_data * sd
Definition: global_data.h:375
polylabel l
Definition: example.h:57
vw * all
Definition: scorer.cc:10
void update(E &ec, size_t i=0)
Definition: learner.h:222