Vowpal Wabbit
Classes | Functions
bs.cc File Reference
#include <cfloat>
#include <cmath>
#include <cerrno>
#include <sstream>
#include <numeric>
#include <vector>
#include <memory>
#include "reductions.h"
#include "rand48.h"
#include "vw.h"
#include "bs.h"
#include "vw_exception.h"

Go to the source code of this file.

Classes

struct  bs
 

Functions

void bs_predict_mean (vw &all, example &ec, std::vector< double > &pred_vec)
 
void bs_predict_vote (example &ec, std::vector< double > &pred_vec)
 
void print_result (int f, float res, v_array< char > tag, float lb, float ub)
 
void output_example (vw &all, bs &d, example &ec)
 
template<bool is_learn>
void predict_or_learn (bs &d, single_learner &base, example &ec)
 
void finish_example (vw &all, bs &d, example &ec)
 
base_learner * bs_setup (options_i &options, vw &all)
 

Function Documentation

◆ bs_predict_mean()

void bs_predict_mean ( vw &  all,
example &  ec,
std::vector< double > &  pred_vec 
)

Definition at line 36 of file bs.cc.

References accumulate(), loss_function::getLoss(), example::l, label_data::label, example::loss, vw::loss, example::pred, polyprediction::scalar, vw::sd, polylabel::simple, and example::weight.

Referenced by predict_or_learn().

37 {
38  ec.pred.scalar = (float)accumulate(pred_vec.cbegin(), pred_vec.cend(), 0.0) / pred_vec.size();
39  if (ec.weight > 0 && ec.l.simple.label != FLT_MAX)
40  ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ec.l.simple.label) * ec.weight;
41 }
loss_function * loss
Definition: global_data.h:523
void accumulate(vw &all, parameters &weights, size_t offset)
Definition: accumulate.cc:20
float scalar
Definition: example.h:45
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
virtual float getLoss(shared_data *, float prediction, float label)=0
shared_data * sd
Definition: global_data.h:375
float loss
Definition: example.h:70
polylabel l
Definition: example.h:57
polyprediction pred
Definition: example.h:60
float weight
Definition: example.h:62

◆ bs_predict_vote()

void bs_predict_vote ( example &  ec,
std::vector< double > &  pred_vec 
)

Definition at line 43 of file bs.cc.

References f, example::l, label_data::label, example::loss, example::pred, polyprediction::scalar, polylabel::simple, and example::weight.

Referenced by predict_or_learn().

44 {
45  // majority vote in linear time
46  unsigned int counter = 0;
47  int current_label = 1, init_label = 1;
48  // float sum_labels = 0; // uncomment for: "avg on votes" and getLoss()
49  bool majority_found = false;
50  bool multivote_detected = false; // distinct(votes)>2: used to skip part of the algorithm
51  auto pred_vec_sz = pred_vec.size();
52  int* pred_vec_int = new int[pred_vec_sz];
53 
54  for (unsigned int i = 0; i < pred_vec_sz; i++)
55  {
56  pred_vec_int[i] = (int)floor(
57  pred_vec[i] + 0.5); // could be added: link(), min_label/max_label, cutoff between true/false for binary
58 
59  if (!multivote_detected) // distinct(votes)>2 detection block
60  {
61  if (i == 0)
62  {
63  init_label = pred_vec_int[i];
64  current_label = pred_vec_int[i];
65  }
66  else if (init_label != current_label && pred_vec_int[i] != current_label && pred_vec_int[i] != init_label)
67  multivote_detected = true; // more than 2 distinct votes detected
68  }
69 
70  if (counter == 0)
71  {
72  counter = 1;
73  current_label = pred_vec_int[i];
74  }
75  else
76  {
77  if (pred_vec_int[i] == current_label)
78  counter++;
79  else
80  {
81  counter--;
82  }
83  }
84  }
85 
86  if (counter > 0 && multivote_detected) // remove this condition for: "avg on votes" and getLoss()
87  {
88  counter = 0;
89  for (unsigned int i = 0; i < pred_vec.size(); i++)
90  if (pred_vec_int[i] == current_label)
91  {
92  counter++;
93  // sum_labels += pred_vec[i]; // uncomment for: "avg on votes" and getLoss()
94  }
95  if (counter * 2 > pred_vec.size())
96  majority_found = true;
97  }
98 
99  if (multivote_detected && !majority_found) // then find most frequent element - if tie: smallest tie label
100  {
101  std::sort(pred_vec_int, pred_vec_int + pred_vec.size());
102  int tmp_label = pred_vec_int[0];
103  counter = 1;
104  for (unsigned int i = 1, temp_count = 1; i < pred_vec.size(); i++)
105  {
106  if (tmp_label == pred_vec_int[i])
107  temp_count++;
108  else
109  {
110  if (temp_count > counter)
111  {
112  current_label = tmp_label;
113  counter = temp_count;
114  }
115  tmp_label = pred_vec_int[i];
116  temp_count = 1;
117  }
118  }
119  /* uncomment for: "avg on votes" and getLoss()
120  sum_labels = 0;
121  for(unsigned int i=0; i<pred_vec.size(); i++)
122  if(pred_vec_int[i] == current_label)
123  sum_labels += pred_vec[i]; */
124  }
125  // TODO: unique_ptr would also handle exception case
126  delete[] pred_vec_int;
127 
128  // ld.prediction = sum_labels/(float)counter; //replace line below for: "avg on votes" and getLoss()
129  ec.pred.scalar = (float)current_label;
130 
131  // ec.loss = all.loss->getLoss(all.sd, ld.prediction, ld.label) * ec.weight; //replace line below for: "avg on votes"
132  // and getLoss()
133  ec.loss = ((ec.pred.scalar == ec.l.simple.label) ? 0.f : 1.f) * ec.weight;
134 }
float scalar
Definition: example.h:45
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
float loss
Definition: example.h:70
polylabel l
Definition: example.h:57
polyprediction pred
Definition: example.h:60
float weight
Definition: example.h:62
float f
Definition: cache.cc:40

◆ bs_setup()

base_learner* bs_setup ( options_i &  options,
vw &  all 
)

Definition at line 231 of file bs.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), BS_TYPE_MEAN, BS_TYPE_VOTE, finish_example(), vw::get_random_state(), LEARNER::init_learner(), LEARNER::make_base(), VW::config::make_option(), LEARNER::learner< T, E >::set_finish_example(), setup_base(), and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

232 {
233  auto data = scoped_calloc_or_throw<bs>();
234  std::string type_string("mean");
235  option_group_definition new_options("Bootstrap");
236  new_options.add(make_option("bootstrap", data->B).keep().help("k-way bootstrap by online importance resampling"))
237  .add(make_option("bs_type", type_string).keep().help("prediction type {mean,vote}"));
238  options.add_and_parse(new_options);
239 
240  if (!options.was_supplied("bootstrap"))
241  return nullptr;
242 
243  data->ub = FLT_MAX;
244  data->lb = -FLT_MAX;
245 
246  if (options.was_supplied("bs_type"))
247  {
248  if (type_string == "mean")
249  data->bs_type = BS_TYPE_MEAN;
250  else if (type_string == "vote")
251  data->bs_type = BS_TYPE_VOTE;
252  else
253  {
254  std::cerr << "warning: bs_type must be in {'mean','vote'}; resetting to mean." << std::endl;
255  data->bs_type = BS_TYPE_MEAN;
256  }
257  }
258  else // by default use mean
259  data->bs_type = BS_TYPE_MEAN;
260 
261  data->pred_vec = new std::vector<double>();
262  data->pred_vec->reserve(data->B);
263  data->all = &all;
264  data->_random_state = all.get_random_state();
265 
266  learner<bs, example>& l = init_learner(
267  data, as_singleline(setup_base(options, all)), predict_or_learn<true>, predict_or_learn<false>, data->B);
268  l.set_finish_example(finish_example);
269 
270  return make_base(l);
271 }
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
virtual bool was_supplied(const std::string &key)=0
#define BS_TYPE_VOTE
Definition: bs.h:9
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void finish_example(vw &all, bs &d, example &ec)
Definition: bs.cc:225
#define BS_TYPE_MEAN
Definition: bs.h:8
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222

◆ finish_example()

void finish_example ( vw &  all,
bs &  d,
example &  ec 
)

Definition at line 225 of file bs.cc.

References VW::finish_example(), and output_example().

Referenced by bs_setup().

226 {
227  output_example(all, d, ec);
228  VW::finish_example(all, ec);
229 }
void output_example(vw &all, bs &d, example &ec)
Definition: bs.cc:151
void finish_example(vw &, example &)
Definition: parser.cc:881

◆ output_example()

void output_example ( vw &  all,
bs &  d,
example &  ec 
)

Definition at line 151 of file bs.cc.

References v_array< T >::empty(), vw::final_prediction_sink, example::l, label_data::label, bs::lb, example::loss, example::num_features, example::pred, bs::pred_vec, print_result(), CB::print_update(), polyprediction::scalar, vw::sd, polylabel::simple, example::tag, example::test_only, bs::ub, shared_data::update(), example::weight, and shared_data::weighted_labels.

Referenced by finish_example().

152 {
153  label_data& ld = ec.l.simple;
154 
155  all.sd->update(ec.test_only, ld.label != FLT_MAX, ec.loss, ec.weight, ec.num_features);
156  if (ld.label != FLT_MAX && !ec.test_only)
157  all.sd->weighted_labels += ((double)ld.label) * ec.weight;
158 
159  if (!all.final_prediction_sink.empty()) // get confidence interval only when printing out predictions
160  {
161  d.lb = FLT_MAX;
162  d.ub = -FLT_MAX;
163  for (double v : *d.pred_vec)
164  {
165  if (v > d.ub)
166  d.ub = (float)v;
167  if (v < d.lb)
168  d.lb = (float)v;
169  }
170  }
171 
172  for (int sink : all.final_prediction_sink) print_result(sink, ec.pred.scalar, ec.tag, d.lb, d.ub);
173 
174  print_update(all, ec);
175 }
float ub
Definition: bs.cc:28
v_array< char > tag
Definition: example.h:63
float scalar
Definition: example.h:45
v_array< int > final_prediction_sink
Definition: global_data.h:518
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
shared_data * sd
Definition: global_data.h:375
size_t num_features
Definition: example.h:67
double weighted_labels
Definition: global_data.h:144
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
float loss
Definition: example.h:70
polylabel l
Definition: example.h:57
void print_result(int f, float res, v_array< char > tag, float lb, float ub)
Definition: bs.cc:136
bool empty() const
Definition: v_array.h:59
float lb
Definition: bs.cc:27
polyprediction pred
Definition: example.h:60
float weight
Definition: example.h:62
std::vector< double > * pred_vec
Definition: bs.cc:29
bool test_only
Definition: example.h:76

◆ predict_or_learn()

template<bool is_learn>
void predict_or_learn ( bs &  d,
single_learner &  base,
example &  ec 
)

Definition at line 178 of file bs.cc.

References bs::_random_state, bs::all, bs::B, bs_predict_mean(), bs_predict_vote(), bs::bs_type, BS_TYPE_MEAN, BS_TYPE_VOTE, LEARNER::learner< T, E >::learn(), example::partial_prediction, example::pred, bs::pred_vec, LEARNER::learner< T, E >::predict(), vw::print_text, vw::raw_prediction, polyprediction::scalar, example::tag, THROW, example::weight, and BS::weight_gen().

179 {
180  vw& all = *d.all;
181  bool shouldOutput = all.raw_prediction > 0;
182 
183  float weight_temp = ec.weight;
184 
185  std::stringstream outputStringStream;
186  d.pred_vec->clear();
187 
188  for (size_t i = 1; i <= d.B; i++)
189  {
190  ec.weight = weight_temp * (float)BS::weight_gen(d._random_state);
191 
192  if (is_learn)
193  base.learn(ec, i - 1);
194  else
195  base.predict(ec, i - 1);
196 
197  d.pred_vec->push_back(ec.pred.scalar);
198 
199  if (shouldOutput)
200  {
201  if (i > 1)
202  outputStringStream << ' ';
203  outputStringStream << i << ':' << ec.partial_prediction;
204  }
205  }
206 
207  ec.weight = weight_temp;
208 
209  switch (d.bs_type)
210  {
211  case BS_TYPE_MEAN:
212  bs_predict_mean(all, ec, *d.pred_vec);
213  break;
214  case BS_TYPE_VOTE:
215  bs_predict_vote(ec, *d.pred_vec);
216  break;
217  default:
218  THROW("Unknown bs_type specified: " << d.bs_type);
219  }
220 
221  if (shouldOutput)
222  all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag);
223 }
v_array< char > tag
Definition: example.h:63
int raw_prediction
Definition: global_data.h:519
void predict(E &ec, size_t i=0)
Definition: learner.h:169
uint32_t weight_gen(std::shared_ptr< rand_state > &state)
Definition: bs.h:17
size_t bs_type
Definition: bs.cc:26
float scalar
Definition: example.h:45
vw * all
Definition: bs.cc:30
float partial_prediction
Definition: example.h:68
std::shared_ptr< rand_state > _random_state
Definition: bs.cc:31
uint32_t B
Definition: bs.cc:25
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
#define BS_TYPE_VOTE
Definition: bs.h:9
#define BS_TYPE_MEAN
Definition: bs.h:8
void bs_predict_vote(example &ec, std::vector< double > &pred_vec)
Definition: bs.cc:43
void bs_predict_mean(vw &all, example &ec, std::vector< double > &pred_vec)
Definition: bs.cc:36
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
float weight
Definition: example.h:62
#define THROW(args)
Definition: vw_exception.h:181
std::vector< double > * pred_vec
Definition: bs.cc:29

◆ print_result()

void print_result ( int  f,
float  res,
v_array< char >  tag,
float  lb,
float  ub 
)

Definition at line 136 of file bs.cc.

References print_tag(), and io_buf::write_file_or_socket().

Referenced by enable_sources(), mf_print_audit_features(), output_example(), GD::print_audit_features(), and reset_source().

137 {
138  if (f >= 0)
139  {
140  std::stringstream ss;
141  ss << std::fixed << res;
142  print_tag(ss, tag);
143  ss << std::fixed << ' ' << lb << ' ' << ub << '\n';
144  ssize_t len = ss.str().size();
145  ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
146  if (t != len)
147  std::cerr << "write error: " << strerror(errno) << std::endl;
148  }
149 }
static ssize_t write_file_or_socket(int f, const void *buf, size_t nbytes)
Definition: io_buf.cc:140
int print_tag(std::stringstream &ss, v_array< char > tag)
Definition: global_data.cc:81
float f
Definition: cache.cc:40