Vowpal Wabbit
Classes | Functions
confidence.cc File Reference
#include "reductions.h"
#include "vw.h"
#include "math.h"

Go to the source code of this file.

Classes

struct  confidence
 

Functions

template<bool is_learn, bool is_confidence_after_training>
void predict_or_learn_with_confidence (confidence &, single_learner &base, example &ec)
 
void confidence_print_result (int f, float res, float confidence, v_array< char > tag)
 
void output_and_account_confidence_example (vw &all, example &ec)
 
void return_confidence_example (vw &all, confidence &, example &ec)
 
base_learner * confidence_setup (options_i &options, vw &all)
 

Function Documentation

◆ confidence_print_result()

void confidence_print_result ( int  f,
float  res,
float  confidence,
v_array< char >  tag 
)

Definition at line 45 of file confidence.cc.

References print_tag(), and io_buf::write_file_or_socket().

Referenced by output_and_account_confidence_example().

46 {
47  if (f >= 0)
48  {
49  std::stringstream ss;
50  ss << std::fixed << res << " " << confidence;
51  if (!print_tag(ss, tag))
52  ss << ' ';
53  ss << '\n';
54  ssize_t len = ss.str().size();
55  ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
56  if (t != len)
57  std::cerr << "write error: " << strerror(errno) << std::endl;
58  }
59 }
static ssize_t write_file_or_socket(int f, const void *buf, size_t nbytes)
Definition: io_buf.cc:140
int print_tag(std::stringstream &ss, v_array< char > tag)
Definition: global_data.cc:81
float f
Definition: cache.cc:40

◆ confidence_setup()

base_learner* confidence_setup ( options_i & options,
vw & all
)

Definition at line 86 of file confidence.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), LEARNER::init_learner(), LEARNER::make_base(), VW::config::make_option(), return_confidence_example(), LEARNER::learner< T, E >::set_finish_example(), setup_base(), and vw::training.

Referenced by parse_reductions().

87 {
88  bool confidence_arg = false;
89  bool confidence_after_training = false;
90  option_group_definition new_options("Confidence");
91  new_options.add(make_option("confidence", confidence_arg).keep().help("Get confidence for binary predictions"))
92  .add(make_option("confidence_after_training", confidence_after_training).help("Confidence after training"));
93  options.add_and_parse(new_options);
94 
95  if (!confidence_arg)
96  return nullptr;
97 
98  if (!all.training)
99  {
100  std::cout
101  << "Confidence does not work in test mode because learning algorithm state is needed. Use --save_resume when "
102  "saving the model and avoid --test_only"
103  << std::endl;
104  return nullptr;
105  }
106 
107  auto data = scoped_calloc_or_throw<confidence>();
108  data->all = &all;
109 
110  void (*learn_with_confidence_ptr)(confidence&, single_learner&, example&) = nullptr;
111  void (*predict_with_confidence_ptr)(confidence&, single_learner&, example&) = nullptr;
112 
113  if (confidence_after_training)
114  {
115  learn_with_confidence_ptr = predict_or_learn_with_confidence<true, true>;
116  predict_with_confidence_ptr = predict_or_learn_with_confidence<false, true>;
117  }
118  else
119  {
120  learn_with_confidence_ptr = predict_or_learn_with_confidence<true, false>;
121  predict_with_confidence_ptr = predict_or_learn_with_confidence<false, false>;
122  }
123 
124  // Create new learner
126  data, as_singleline(setup_base(options, all)), learn_with_confidence_ptr, predict_with_confidence_ptr);
127 
129 
130  return make_base(l);
131 }
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
bool training
Definition: global_data.h:488
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void return_confidence_example(vw &all, confidence &, example &ec)
Definition: confidence.cc:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222

◆ output_and_account_confidence_example()

void output_and_account_confidence_example ( vw & all,
example & ec
)

Definition at line 61 of file confidence.cc.

References example::confidence, confidence_print_result(), f, vw::final_prediction_sink, example::l, label_data::label, example::loss, example::num_features, example::partial_prediction, example::pred, vw::print, CB::print_update(), vw::raw_prediction, polyprediction::scalar, vw::sd, polylabel::simple, v_array< T >::size(), example::tag, example::test_only, shared_data::update(), example::weight, shared_data::weighted_labels, and shared_data::weighted_unlabeled_examples.

Referenced by return_confidence_example().

62 {
63  label_data& ld = ec.l.simple;
64 
65  all.sd->update(ec.test_only, ld.label != FLT_MAX, ec.loss, ec.weight, ec.num_features);
66  if (ld.label != FLT_MAX && !ec.test_only)
67  all.sd->weighted_labels += ld.label * ec.weight;
68  all.sd->weighted_unlabeled_examples += ld.label == FLT_MAX ? ec.weight : 0;
69 
70  all.print(all.raw_prediction, ec.partial_prediction, -1, ec.tag);
71  for (size_t i = 0; i < all.final_prediction_sink.size(); i++)
72  {
73  int f = (int)all.final_prediction_sink[i];
75  }
76 
77  print_update(all, ec);
78 }
v_array< char > tag
Definition: example.h:63
int raw_prediction
Definition: global_data.h:519
float scalar
Definition: example.h:45
double weighted_unlabeled_examples
Definition: global_data.h:143
v_array< int > final_prediction_sink
Definition: global_data.h:518
float confidence
Definition: example.h:72
float partial_prediction
Definition: example.h:68
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
size_t size() const
Definition: v_array.h:68
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
void confidence_print_result(int f, float res, float confidence, v_array< char > tag)
Definition: confidence.cc:45
shared_data * sd
Definition: global_data.h:375
size_t num_features
Definition: example.h:67
double weighted_labels
Definition: global_data.h:144
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
float loss
Definition: example.h:70
polylabel l
Definition: example.h:57
polyprediction pred
Definition: example.h:60
float weight
Definition: example.h:62
void(* print)(int, float, float, v_array< char >)
Definition: global_data.h:521
float f
Definition: cache.cc:40
bool test_only
Definition: example.h:76

◆ predict_or_learn_with_confidence()

template<bool is_learn, bool is_confidence_after_training>
void predict_or_learn_with_confidence ( confidence & ,
single_learner & base,
example & ec
)

Definition at line 15 of file confidence.cc.

References example::confidence, example::l, label_data::label, LEARNER::learner< T, E >::learn(), example::pred, LEARNER::learner< T, E >::predict(), polyprediction::scalar, sensitivity(), LEARNER::learner< T, E >::sensitivity(), and polylabel::simple.

16 {
17  float threshold = 0.f;
18  float sensitivity = 0.f;
19 
20  float existing_label = ec.l.simple.label;
21  if (existing_label == FLT_MAX)
22  {
23  base.predict(ec);
24  float opposite_label = 1.f;
25  if (ec.pred.scalar > 0)
26  opposite_label = -1.f;
27  ec.l.simple.label = opposite_label;
28  }
29 
30  if (!is_confidence_after_training)
31  sensitivity = base.sensitivity(ec);
32 
33  ec.l.simple.label = existing_label;
34  if (is_learn)
35  base.learn(ec);
36  else
37  base.predict(ec);
38 
39  if (is_confidence_after_training)
40  sensitivity = base.sensitivity(ec);
41 
42  ec.confidence = fabsf(ec.pred.scalar - threshold) / sensitivity;
43 }
void predict(E &ec, size_t i=0)
Definition: learner.h:169
float scalar
Definition: example.h:45
float confidence
Definition: example.h:72
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
float sensitivity(baseline &data, base_learner &base, example &ec)
Definition: baseline.cc:168
float sensitivity(example &ec, size_t i=0)
Definition: learner.h:242
polylabel l
Definition: example.h:57
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160

◆ return_confidence_example()

void return_confidence_example ( vw & all,
confidence & ,
example & ec
)

Definition at line 80 of file confidence.cc.

References VW::finish_example(), and output_and_account_confidence_example().

Referenced by confidence_setup().

81 {
83  VW::finish_example(all, ec);
84 }
void finish_example(vw &, example &)
Definition: parser.cc:881
void output_and_account_confidence_example(vw &all, example &ec)
Definition: confidence.cc:61