Vowpal Wabbit
cost_sensitive.cc
Go to the documentation of this file.
1 #include "float.h"
2 #include "gd.h"
3 #include "vw.h"
4 #include "vw_exception.h"
5 #include <cmath>
6 
7 namespace COST_SENSITIVE
8 {
9 void name_value(substring& s, v_array<substring>& name, float& v)
10 {
11  tokenize(':', s, name);
12 
13  switch (name.size())
14  {
15  case 0:
16  case 1:
17  v = 1.;
18  break;
19  case 2:
20  v = float_of_substring(name[1]);
21  if (std::isnan(v))
22  THROW("error NaN value for: " << name[0]);
23  break;
24  default:
25  std::cerr << "example with a wierd name. What is '";
26  std::cerr.write(s.begin, s.end - s.begin);
27  std::cerr << "'?\n";
28  }
29 }
30 
31 char* bufread_label(label* ld, char* c, io_buf& cache)
32 {
33  size_t num = *(size_t*)c;
34  ld->costs.clear();
35  c += sizeof(size_t);
36  size_t total = sizeof(wclass) * num;
37  if (cache.buf_read(c, (int)total) < total)
38  {
39  std::cout << "error in demarshal of cost data" << std::endl;
40  return c;
41  }
42  for (size_t i = 0; i < num; i++)
43  {
44  wclass temp = *(wclass*)c;
45  c += sizeof(wclass);
46  ld->costs.push_back(temp);
47  }
48 
49  return c;
50 }
51 
52 size_t read_cached_label(shared_data*, void* v, io_buf& cache)
53 {
54  label* ld = (label*)v;
55  ld->costs.clear();
56  char* c;
57  size_t total = sizeof(size_t);
58  if (cache.buf_read(c, (int)total) < total)
59  return 0;
60  bufread_label(ld, c, cache);
61 
62  return total;
63 }
64 
// Label-parser hook: every cost-sensitive label carries the same weight, 1.
// The label pointer argument is ignored.
float weight(void*)
{
  return 1.f;
}
66 
67 char* bufcache_label(label* ld, char* c)
68 {
69  *(size_t*)c = ld->costs.size();
70  c += sizeof(size_t);
71  for (unsigned int i = 0; i < ld->costs.size(); i++)
72  {
73  *(wclass*)c = ld->costs[i];
74  c += sizeof(wclass);
75  }
76  return c;
77 }
78 
79 void cache_label(void* v, io_buf& cache)
80 {
81  char* c;
82  label* ld = (label*)v;
83  cache.buf_write(c, sizeof(size_t) + sizeof(wclass) * ld->costs.size());
84  bufcache_label(ld, c);
85 }
86 
87 void default_label(void* v)
88 {
89  label* ld = (label*)v;
90  ld->costs.clear();
91 }
92 
93 bool test_label(void* v)
94 {
95  label* ld = (label*)v;
96  if (ld->costs.size() == 0)
97  return true;
98  for (unsigned int i = 0; i < ld->costs.size(); i++)
99  if (FLT_MAX != ld->costs[i].x)
100  return false;
101  return true;
102 }
103 
104 void delete_label(void* v)
105 {
106  label* ld = (label*)v;
107  if (ld)
108  ld->costs.delete_v();
109 }
110 
111 void copy_label(void* dst, void* src)
112 {
113  if (dst && src)
114  {
115  label* ldD = (label*)dst;
116  label* ldS = (label*)src;
117  copy_array(ldD->costs, ldS->costs);
118  }
119 }
120 
121 void parse_label(parser* p, shared_data* sd, void* v, v_array<substring>& words)
122 {
123  label* ld = (label*)v;
124  ld->costs.clear();
125 
126  // handle shared and label first
127  if (words.size() == 1)
128  {
129  float fx;
130  name_value(words[0], p->parse_name, fx);
131  bool eq_shared = substring_equal(p->parse_name[0], "***shared***");
132  bool eq_label = substring_equal(p->parse_name[0], "***label***");
133  if (!sd->ldict)
134  {
135  eq_shared |= substring_equal(p->parse_name[0], "shared");
136  eq_label |= substring_equal(p->parse_name[0], "label");
137  }
138  if (eq_shared || eq_label)
139  {
140  if (eq_shared)
141  {
142  if (p->parse_name.size() != 1)
143  std::cerr << "shared feature vectors should not have costs on: " << words[0] << std::endl;
144  else
145  {
146  wclass f = {-FLT_MAX, 0, 0., 0.};
147  ld->costs.push_back(f);
148  }
149  }
150  if (eq_label)
151  {
152  if (p->parse_name.size() != 2)
153  std::cerr << "label feature vectors should have exactly one cost on: " << words[0] << std::endl;
154  else
155  {
156  wclass f = {float_of_substring(p->parse_name[1]), 0, 0., 0.};
157  ld->costs.push_back(f);
158  }
159  }
160  return;
161  }
162  }
163 
164  // otherwise this is a "real" example
165  for (unsigned int i = 0; i < words.size(); i++)
166  {
167  wclass f = {0., 0, 0., 0.};
168  name_value(words[i], p->parse_name, f.x);
169 
170  if (p->parse_name.size() == 0)
171  THROW(" invalid cost: specification -- no names on: " << words[i]);
172 
173  if (p->parse_name.size() == 1 || p->parse_name.size() == 2 || p->parse_name.size() == 3)
174  {
175  f.class_index =
176  sd->ldict ? (uint32_t)sd->ldict->get(p->parse_name[0]) : (uint32_t)hashstring(p->parse_name[0], 0);
177  if (p->parse_name.size() == 1 && f.x >= 0) // test examples are specified just by un-valued class #s
178  f.x = FLT_MAX;
179  }
180  else
181  THROW("malformed cost specification on '" << (p->parse_name[0].begin) << "'");
182 
183  ld->costs.push_back(f);
184  }
185 }
186 
188  test_label, sizeof(label)};
189 
190 void print_update(vw& all, bool is_test, example& ec, multi_ex* ec_seq, bool action_scores, uint32_t prediction)
191 {
192  if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs)
193  {
194  size_t num_current_features = ec.num_features;
195  // for csoaa_ldf we want features from the whole (multiline example),
196  // not only from one line (the first one) represented by ec
197  if (ec_seq != nullptr)
198  {
199  num_current_features = 0;
200  // TODO: including quadratic and cubic.
201  for (auto& ecc : *ec_seq) num_current_features += ecc->num_features;
202  }
203 
204  std::string label_buf;
205  if (is_test)
206  label_buf = " unknown";
207  else
208  label_buf = " known";
209 
210  if (action_scores || all.sd->ldict)
211  {
212  std::ostringstream pred_buf;
213 
214  pred_buf << std::setw(all.sd->col_current_predict) << std::right << std::setfill(' ');
215  if (all.sd->ldict)
216  {
217  if (action_scores)
218  pred_buf << all.sd->ldict->get(ec.pred.a_s[0].action);
219  else
220  pred_buf << all.sd->ldict->get(prediction);
221  }
222  else
223  pred_buf << ec.pred.a_s[0].action;
224  if (action_scores)
225  pred_buf << ".....";
226  all.sd->print_update(all.holdout_set_off, all.current_pass, label_buf, pred_buf.str(), num_current_features,
227  all.progress_add, all.progress_arg);
228  ;
229  }
230  else
231  all.sd->print_update(all.holdout_set_off, all.current_pass, label_buf, prediction, num_current_features,
232  all.progress_add, all.progress_arg);
233  }
234 }
235 
236 void output_example(vw& all, example& ec)
237 {
238  label& ld = ec.l.cs;
239 
240  float loss = 0.;
241  if (!test_label(&ld))
242  {
243  // need to compute exact loss
244  size_t pred = (size_t)ec.pred.multiclass;
245 
246  float chosen_loss = FLT_MAX;
247  float min = FLT_MAX;
248  for (auto& cl : ld.costs)
249  {
250  if (cl.class_index == pred)
251  chosen_loss = cl.x;
252  if (cl.x < min)
253  min = cl.x;
254  }
255  if (chosen_loss == FLT_MAX)
256  std::cerr << "warning: csoaa predicted an invalid class. Are all multi-class labels in the {1..k} range?"
257  << std::endl;
258 
259  loss = (chosen_loss - min) * ec.weight;
260  // TODO(alberto): add option somewhere to allow using absolute loss instead?
261  // loss = chosen_loss;
262  }
263 
264  all.sd->update(ec.test_only, !test_label(&ld), loss, ec.weight, ec.num_features);
265 
266  for (int sink : all.final_prediction_sink)
267  if (!all.sd->ldict)
268  all.print(sink, (float)ec.pred.multiclass, 0, ec.tag);
269  else
270  {
271  substring ss_pred = all.sd->ldict->get(ec.pred.multiclass);
272  all.print_text(sink, std::string(ss_pred.begin, ss_pred.end - ss_pred.begin), ec.tag);
273  }
274 
275  if (all.raw_prediction > 0)
276  {
277  std::stringstream outputStringStream;
278  for (unsigned int i = 0; i < ld.costs.size(); i++)
279  {
280  wclass cl = ld.costs[i];
281  if (i > 0)
282  outputStringStream << ' ';
283  outputStringStream << cl.class_index << ':' << cl.partial_prediction;
284  }
285  all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag);
286  }
287 
288  print_update(all, test_label(&ec.l.cs), ec, nullptr, false, ec.pred.multiclass);
289 }
290 
291 void finish_example(vw& all, example& ec)
292 {
293  output_example(all, ec);
294  VW::finish_example(all, ec);
295 }
296 
298 {
300  if (costs.size() == 0)
301  return true;
302  for (size_t j = 0; j < costs.size(); j++)
303  if (costs[j].x != FLT_MAX)
304  return false;
305  return true;
306 }
307 
308 bool ec_is_example_header(example const& ec) // example headers look like "shared"
309 {
311  if (costs.size() != 1)
312  return false;
313  if (costs[0].class_index != 0)
314  return false;
315  if (costs[0].x != -FLT_MAX)
316  return false;
317  return true;
318 }
319 } // namespace COST_SENSITIVE
v_array< char > tag
Definition: example.h:63
int raw_prediction
Definition: global_data.h:519
uint32_t multiclass
Definition: example.h:49
ACTION_SCORE::action_scores a_s
Definition: example.h:47
label_parser cs_label
char * end
Definition: hashstring.h:10
char * begin
Definition: hashstring.h:9
void copy_array(v_array< T > &dst, const v_array< T > &src)
Definition: v_array.h:185
v_array< int > final_prediction_sink
Definition: global_data.h:518
v_array< action_score > action_scores
Definition: action_score.h:10
void cache_label(void *v, io_buf &cache)
namedlabels * ldict
Definition: global_data.h:153
bool quiet
Definition: global_data.h:487
size_t read_cached_label(shared_data *, void *v, io_buf &cache)
bool ec_is_example_header(example const &ec)
void copy_label(void *dst, void *src)
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
void delete_label(void *v)
bool holdout_set_off
Definition: global_data.h:499
v_array< substring > parse_name
Definition: parser.h:100
T *& begin()
Definition: v_array.h:42
bool progress_add
Definition: global_data.h:545
size_t size() const
Definition: v_array.h:68
bool example_is_test(example &ec)
float float_of_substring(substring s)
char * bufcache_label(label *ld, char *c)
COST_SENSITIVE::label cs
Definition: example.h:30
shared_data * sd
Definition: global_data.h:375
float progress_arg
Definition: global_data.h:546
void finish_example(vw &all, example &ec)
void tokenize(char delim, substring s, ContainerT &ret, bool allow_empty=false)
void print_update(bool holdout_set_off, size_t current_pass, float label, float prediction, size_t num_features, bool progress_add, float progress_arg)
Definition: global_data.h:225
bool bfgs
Definition: global_data.h:412
size_t num_features
Definition: example.h:67
float weight(void *)
bool test_label(void *v)
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
void default_label(void *v)
uint64_t current_pass
Definition: global_data.h:396
Definition: io_buf.h:54
void finish_example(vw &, example &)
Definition: parser.cc:881
void name_value(substring &s, v_array< substring > &name, float &v)
void parse_label(parser *p, shared_data *sd, void *v, v_array< substring > &words)
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
void buf_write(char *&pointer, size_t n)
Definition: io_buf.cc:94
std::vector< example * > multi_ex
Definition: example.h:122
polylabel l
Definition: example.h:57
char * bufread_label(label *ld, char *c, io_buf &cache)
VW_STD14_CONSTEXPR uint64_t hashstring(substring s, uint64_t h)
Definition: hashstring.h:18
bool substring_equal(const substring &a, const substring &b)
uint64_t get(substring &s)
Definition: global_data.h:108
Definition: parser.h:38
polyprediction pred
Definition: example.h:60
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores, uint32_t prediction)
static constexpr int col_current_predict
Definition: global_data.h:184
v_array< wclass > costs
float weight
Definition: example.h:62
double weighted_examples()
Definition: global_data.h:188
float dump_interval
Definition: global_data.h:147
void output_example(vw &all, example &ec)
#define THROW(args)
Definition: vw_exception.h:181
constexpr uint64_t c
Definition: rand48.cc:12
void(* print)(int, float, float, v_array< char >)
Definition: global_data.h:521
float f
Definition: cache.cc:40
bool test_only
Definition: example.h:76
size_t buf_read(char *&pointer, size_t n)
Definition: io_buf.cc:12