Vowpal Wabbit
confidence.cc
Go to the documentation of this file.
1 #include "reductions.h"
2 #include "vw.h"
3 #include "math.h"
4 
5 using namespace LEARNER;
6 
7 using namespace VW::config;
8 
9 struct confidence
10 {
11  vw* all;
12 };
13 
14 template <bool is_learn, bool is_confidence_after_training>
16 {
17  float threshold = 0.f;
18  float sensitivity = 0.f;
19 
20  float existing_label = ec.l.simple.label;
21  if (existing_label == FLT_MAX)
22  {
23  base.predict(ec);
24  float opposite_label = 1.f;
25  if (ec.pred.scalar > 0)
26  opposite_label = -1.f;
27  ec.l.simple.label = opposite_label;
28  }
29 
30  if (!is_confidence_after_training)
31  sensitivity = base.sensitivity(ec);
32 
33  ec.l.simple.label = existing_label;
34  if (is_learn)
35  base.learn(ec);
36  else
37  base.predict(ec);
38 
39  if (is_confidence_after_training)
40  sensitivity = base.sensitivity(ec);
41 
42  ec.confidence = fabsf(ec.pred.scalar - threshold) / sensitivity;
43 }
44 
45 void confidence_print_result(int f, float res, float confidence, v_array<char> tag)
46 {
47  if (f >= 0)
48  {
49  std::stringstream ss;
50  ss << std::fixed << res << " " << confidence;
51  if (!print_tag(ss, tag))
52  ss << ' ';
53  ss << '\n';
54  ssize_t len = ss.str().size();
55  ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
56  if (t != len)
57  std::cerr << "write error: " << strerror(errno) << std::endl;
58  }
59 }
60 
// Accounts example `ec` in the shared statistics (all.sd) and writes its predictions
// to the raw-prediction sink and to every final-prediction sink.
// NOTE(review): listing line 61 (the signature) is missing from this rendering; the
// cross-reference index gives it as
//   void output_and_account_confidence_example(vw &all, example &ec)
62 {
63  label_data& ld = ec.l.simple;
64 
// ld.label == FLT_MAX marks an unlabeled example (it is passed as the labeled_example flag below).
65  all.sd->update(ec.test_only, ld.label != FLT_MAX, ec.loss, ec.weight, ec.num_features);
66  if (ld.label != FLT_MAX && !ec.test_only)
67  all.sd->weighted_labels += ld.label * ec.weight; // only labeled, non-test examples contribute
68  all.sd->weighted_unlabeled_examples += ld.label == FLT_MAX ? ec.weight : 0;
69 
70  all.print(all.raw_prediction, ec.partial_prediction, -1, ec.tag); // third argument -1: presumably a placeholder confidence for the raw output
71  for (size_t i = 0; i < all.final_prediction_sink.size(); i++)
72  {
73  int f = (int)all.final_prediction_sink[i];
// NOTE(review): listing line 74 is missing from this rendering; presumably it writes
// ec.pred.scalar and ec.confidence to sink f (e.g. via confidence_print_result) — confirm.
75  }
76 
77  print_update(all, ec);
78 }
79 
// NOTE(review): listing lines 80 (the signature) and 82 are missing from this rendering.
// The cross-reference index gives the signature as
//   void return_confidence_example(vw &all, confidence &, example &ec)
// which matches the callback type accepted by learner::set_finish_example.
81 {
// NOTE(review): listing line 82 missing; presumably it calls
// output_and_account_confidence_example(all, ec) — confirm.
83  VW::finish_example(all, ec); // hand the example back to VW once output is done
84 }
85 
// Builds the confidence reduction from the command-line options.
// Returns nullptr (reduction not added) when --confidence is absent or when
// all.training is false; otherwise stacks the reduction on setup_base(options, all).
// NOTE(review): listing line 86 (the signature) is missing from this rendering; the
// cross-reference index gives it as
//   base_learner * confidence_setup(options_i &options, vw &all)
87 {
88  bool confidence_arg = false;
89  bool confidence_after_training = false;
90  option_group_definition new_options("Confidence");
91  new_options.add(make_option("confidence", confidence_arg).keep().help("Get confidence for binary predictions"))
92  .add(make_option("confidence_after_training", confidence_after_training).help("Confidence after training"));
93  options.add_and_parse(new_options);
94 
95  if (!confidence_arg)
96  return nullptr;
97 
// Learner state is required (per the message below), so outside training the
// reduction is silently dropped after printing a notice.
// NOTE(review): the notice goes to std::cout rather than std::cerr.
98  if (!all.training)
99  {
100  std::cout
101  << "Confidence does not work in test mode because learning algorithm state is needed. Use --save_resume when "
102  "saving the model and avoid --test_only"
103  << std::endl;
104  return nullptr;
105  }
106 
107  auto data = scoped_calloc_or_throw<confidence>();
108  data->all = &all;
109 
// Pick the predict_or_learn_with_confidence<is_learn, is_confidence_after_training>
// instantiations matching --confidence_after_training.
110  void (*learn_with_confidence_ptr)(confidence&, single_learner&, example&) = nullptr;
111  void (*predict_with_confidence_ptr)(confidence&, single_learner&, example&) = nullptr;
112 
113  if (confidence_after_training)
114  {
115  learn_with_confidence_ptr = predict_or_learn_with_confidence<true, true>;
116  predict_with_confidence_ptr = predict_or_learn_with_confidence<false, true>;
117  }
118  else
119  {
120  learn_with_confidence_ptr = predict_or_learn_with_confidence<true, false>;
121  predict_with_confidence_ptr = predict_or_learn_with_confidence<false, false>;
122  }
123 
124  // Create new learner
// NOTE(review): listing line 125 is missing from this rendering; presumably it opens the
// init_learner(...) call whose arguments continue below and declares `l` (used on line 130) — confirm.
126  data, as_singleline(setup_base(options, all)), learn_with_confidence_ptr, predict_with_confidence_ptr);
127 
// NOTE(review): listing line 128 is missing; presumably it registers return_confidence_example
// via l.set_finish_example — confirm.
129 
130  return make_base(l);
131 }
v_array< char > tag
Definition: example.h:63
int raw_prediction
Definition: global_data.h:519
void predict(E &ec, size_t i=0)
Definition: learner.h:169
float scalar
Definition: example.h:45
double weighted_unlabeled_examples
Definition: global_data.h:143
static ssize_t write_file_or_socket(int f, const void *buf, size_t nbytes)
Definition: io_buf.cc:140
v_array< int > final_prediction_sink
Definition: global_data.h:518
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
float confidence
Definition: example.h:72
float partial_prediction
Definition: example.h:68
int print_tag(std::stringstream &ss, v_array< char > tag)
Definition: global_data.cc:81
virtual void add_and_parse(const option_group_definition &group)=0
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
bool training
Definition: global_data.h:488
size_t size() const
Definition: v_array.h:68
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
void confidence_print_result(int f, float res, float confidence, v_array< char > tag)
Definition: confidence.cc:45
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void predict_or_learn_with_confidence(confidence &, single_learner &base, example &ec)
Definition: confidence.cc:15
shared_data * sd
Definition: global_data.h:375
size_t num_features
Definition: example.h:67
float sensitivity(baseline &data, base_learner &base, example &ec)
Definition: baseline.cc:168
double weighted_labels
Definition: global_data.h:144
void finish_example(vw &, example &)
Definition: parser.cc:881
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
float loss
Definition: example.h:70
float sensitivity(example &ec, size_t i=0)
Definition: learner.h:242
option_group_definition & add(T &&op)
Definition: options.h:90
polylabel l
Definition: example.h:57
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void return_confidence_example(vw &all, confidence &, example &ec)
Definition: confidence.cc:80
void output_and_account_confidence_example(vw &all, example &ec)
Definition: confidence.cc:61
base_learner * confidence_setup(options_i &options, vw &all)
Definition: confidence.cc:86
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
float weight
Definition: example.h:62
void(* print)(int, float, float, v_array< char >)
Definition: global_data.h:521
float f
Definition: cache.cc:40
bool test_only
Definition: example.h:76