Vowpal Wabbit
best_constant.cc
Go to the documentation of this file.
1 #include "best_constant.h"
2 
3 bool get_best_constant(vw& all, float& best_constant, float& best_constant_loss)
4 {
5  if (all.sd->first_observed_label == FLT_MAX || // no non-test labels observed or function was never called
6  (all.loss == nullptr) || (all.sd == nullptr))
7  return false;
8 
9  float label1 = all.sd->first_observed_label; // observed labels might be inside [sd->Min_label, sd->Max_label], so
10  // can't use Min/Max
11  float label2 = (all.sd->second_observed_label == FLT_MAX)
12  ? 0
13  : all.sd->second_observed_label; // if only one label observed, second might be 0
14  if (label1 > label2)
15  {
16  float tmp = label1;
17  label1 = label2;
18  label2 = tmp;
19  } // as don't use min/max - make sure label1 < label2
20 
21  float label1_cnt;
22  float label2_cnt;
23 
24  if (label1 != label2)
25  {
26  label1_cnt = (float)(all.sd->weighted_labels - label2 * all.sd->weighted_labeled_examples) / (label1 - label2);
27  label2_cnt = (float)all.sd->weighted_labeled_examples - label1_cnt;
28  }
29  else
30  return false;
31 
32  if ((label1_cnt + label2_cnt) <= 0.)
33  return false;
34 
35  auto funcName = all.loss->getType();
36  if (funcName.compare("squared") == 0 || funcName.compare("Huber") == 0 || funcName.compare("classic") == 0)
37  best_constant = (float)all.sd->weighted_labels / (float)(all.sd->weighted_labeled_examples);
39  {
40  // loss functions below don't have generic formuas for constant yet.
41  return false;
42  }
43  else if (funcName.compare("hinge") == 0)
44  {
45  best_constant = label2_cnt <= label1_cnt ? -1.f : 1.f;
46  }
47  else if (funcName.compare("logistic") == 0)
48  {
49  label1 = -1.; // override {-50, 50} to get proper loss
50  label2 = 1.;
51 
52  if (label1_cnt <= 0)
53  best_constant = 1.;
54  else if (label2_cnt <= 0)
55  best_constant = -1.;
56  else
57  best_constant = log(label2_cnt / label1_cnt);
58  }
59  else if (funcName.compare("quantile") == 0 || funcName.compare("pinball") == 0 || funcName.compare("absolute") == 0)
60  {
61  float tau = 0.5;
62 
63  if (all.options->was_supplied("quantile_tau"))
64  tau = all.options->get_typed_option<float>("quantile_tau").value();
65 
66  float q = tau * (label1_cnt + label2_cnt);
67  if (q < label2_cnt)
68  best_constant = label2;
69  else
70  best_constant = label1;
71  }
72  else
73  return false;
74 
76  {
77  best_constant_loss = (label1_cnt > 0) ? all.loss->getLoss(all.sd, best_constant, label1) * label1_cnt : 0.0f;
78  best_constant_loss += (label2_cnt > 0) ? all.loss->getLoss(all.sd, best_constant, label2) * label2_cnt : 0.0f;
79  best_constant_loss /= label1_cnt + label2_cnt;
80  }
81  else
82  best_constant_loss = FLT_MIN;
83 
84  return true;
85 }
loss_function * loss
Definition: global_data.h:523
VW::config::options_i * options
Definition: global_data.h:428
bool is_more_than_two_labels_observed
Definition: global_data.h:170
float first_observed_label
Definition: global_data.h:171
bool get_best_constant(vw &all, float &best_constant, float &best_constant_loss)
Definition: best_constant.cc:3
virtual float getLoss(shared_data *, float prediction, float label)=0
shared_data * sd
Definition: global_data.h:375
typed_option< T > & get_typed_option(const std::string &key)
Definition: options.h:120
virtual bool was_supplied(const std::string &key)=0
virtual std::string getType()=0
double weighted_labels
Definition: global_data.h:144
double weighted_labeled_examples
Definition: global_data.h:141
float second_observed_label
Definition: global_data.h:172