Vowpal Wabbit
Functions
scorer.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner* scorer_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ scorer_setup()

LEARNER::base_learner* scorer_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 53 of file scorer.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), LEARNER::init_learner(), LEARNER::make_base(), VW::config::make_option(), vw::scorer, LEARNER::learner< T, E >::set_multipredict(), LEARNER::learner< T, E >::set_update(), setup_base(), THROW, and update().

Referenced by parse_reductions().

54 {
55  auto s = scoped_calloc_or_throw<scorer>();
56  std::string link;
57  option_group_definition new_options("scorer options");
58  new_options.add(make_option("link", link)
59  .default_value("identity")
60  .keep()
61  .help("Specify the link function: identity, logistic, glf1 or poisson"));
62  options.add_and_parse(new_options);
63 
64  // This always returns a base_learner.
65 
66  s->all = &all;
67 
68  auto base = as_singleline(setup_base(options, all));
69  LEARNER::learner<scorer, example>* l;
70  void (*multipredict_f)(scorer&, LEARNER::single_learner&, example&, size_t, size_t, polyprediction*, bool) =
71  multipredict<id>;
72 
73  if (link == "identity")
74  l = &init_learner(s, base, predict_or_learn<true, id>, predict_or_learn<false, id>);
75  else if (link == "logistic")
76  {
77  l = &init_learner(s, base, predict_or_learn<true, logistic>, predict_or_learn<false, logistic>);
78  multipredict_f = multipredict<logistic>;
79  }
80  else if (link == "glf1")
81  {
82  l = &init_learner(s, base, predict_or_learn<true, glf1>, predict_or_learn<false, glf1>);
83  multipredict_f = multipredict<glf1>;
84  }
85  else if (link == "poisson")
86  {
87  l = &init_learner(s, base, predict_or_learn<true, expf>, predict_or_learn<false, expf>);
88  multipredict_f = multipredict<expf>;
89  }
90  else
91  THROW("Unknown link function: " << link);
92 
93  l->set_multipredict(multipredict_f);
94  l->set_update(update);
95  all.scorer = LEARNER::as_singleline(l);
96 
97  return make_base(*all.scorer);
98 }
void set_multipredict(void(*u)(T &, L &, E &, size_t, size_t, polyprediction *, bool))
Definition: learner.h:217
void set_update(void(*u)(T &data, L &base, E &))
Definition: learner.h:231
scorer
Definition: scorer.cc:8
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void update(scorer &s, LEARNER::single_learner &base, example &ec)
Definition: scorer.cc:36
LEARNER::single_learner * scorer
Definition: global_data.h:384
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define THROW(args)
Definition: vw_exception.h:181