Vowpal Wabbit
cb.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 #include <cfloat>
7 
8 #include "example.h"
9 #include "parse_primitives.h"
10 #include "vw.h"
11 #include "vw_exception.h"
12 
13 using namespace LEARNER;
14 
15 namespace CB
16 {
17 char* bufread_label(CB::label* ld, char* c, io_buf& cache)
18 {
19  size_t num = *(size_t*)c;
20  ld->costs.clear();
21  c += sizeof(size_t);
22  size_t total = sizeof(cb_class) * num + sizeof(ld->weight);
23  if (cache.buf_read(c, total) < total)
24  {
25  std::cout << "error in demarshal of cost data" << std::endl;
26  return c;
27  }
28  for (size_t i = 0; i < num; i++)
29  {
30  cb_class temp = *(cb_class*)c;
31  c += sizeof(cb_class);
32  ld->costs.push_back(temp);
33  }
34  memcpy(&ld->weight, c, sizeof(ld->weight));
35  c += sizeof(ld->weight);
36  return c;
37 }
38 
39 size_t read_cached_label(shared_data*, void* v, io_buf& cache)
40 {
41  CB::label* ld = (CB::label*)v;
42  ld->costs.clear();
43  char* c;
44  size_t total = sizeof(size_t);
45  if (cache.buf_read(c, total) < total)
46  return 0;
47  bufread_label(ld, c, cache);
48 
49  return total;
50 }
51 
52 float weight(void* v)
53 {
54  CB::label* ld = (CB::label*)v;
55  return ld->weight;
56 }
57 
58 char* bufcache_label(CB::label* ld, char* c)
59 {
60  *(size_t*)c = ld->costs.size();
61  c += sizeof(size_t);
62  for (auto const& cost : ld->costs)
63  {
64  *(cb_class*)c = cost;
65  c += sizeof(cb_class);
66  }
67  memcpy(c, &ld->weight, sizeof(ld->weight));
68  c += sizeof(ld->weight);
69  return c;
70 }
71 
72 void cache_label(void* v, io_buf& cache)
73 {
74  char* c;
75  CB::label* ld = (CB::label*)v;
76  cache.buf_write(c, sizeof(size_t) + sizeof(cb_class) * ld->costs.size() + sizeof(ld->weight));
77  bufcache_label(ld, c);
78 }
79 
80 void default_label(void* v)
81 {
82  CB::label* ld = (CB::label*)v;
83  ld->costs.clear();
84  ld->weight = 1;
85 }
86 
87 bool test_label(void* v)
88 {
89  CB::label* ld = (CB::label*)v;
90  if (ld->costs.empty())
91  return true;
92  for (auto const& cost : ld->costs)
93  if (FLT_MAX != cost.cost && cost.probability > 0.)
94  return false;
95  return true;
96 }
97 
98 void delete_label(void* v)
99 {
100  CB::label* ld = (CB::label*)v;
101  ld->costs.delete_v();
102 }
103 
104 void copy_label(void* dst, void* src)
105 {
106  CB::label* ldD = (CB::label*)dst;
107  CB::label* ldS = (CB::label*)src;
108  copy_array(ldD->costs, ldS->costs);
109  ldD->weight = ldS->weight;
110 }
111 
113 {
114  CB::label* ld = (CB::label*)v;
115  ld->costs.clear();
116  ld->weight = 1.0;
117 
118  for (auto const& word : words)
119  {
120  cb_class f;
121  tokenize(':', word, p->parse_name);
122 
123  if (p->parse_name.empty() || p->parse_name.size() > 3)
124  THROW("malformed cost specification: " << p->parse_name);
125 
126  f.partial_prediction = 0.;
127  f.action = (uint32_t)hashstring(p->parse_name[0], 0);
128  f.cost = FLT_MAX;
129 
130  if (p->parse_name.size() > 1)
132 
133  if (std::isnan(f.cost))
134  THROW("error NaN cost (" << p->parse_name[1] << " for action: " << p->parse_name[0]);
135 
136  f.probability = .0;
137  if (p->parse_name.size() > 2)
139 
140  if (std::isnan(f.probability))
141  THROW("error NaN probability (" << p->parse_name[2] << " for action: " << p->parse_name[0]);
142 
143  if (f.probability > 1.0)
144  {
145  std::cerr << "invalid probability > 1 specified for an action, resetting to 1." << std::endl;
146  f.probability = 1.0;
147  }
148  if (f.probability < 0.0)
149  {
150  std::cerr << "invalid probability < 0 specified for an action, resetting to 0." << std::endl;
151  f.probability = .0;
152  }
153  if (substring_equal(p->parse_name[0], "shared"))
154  {
155  if (p->parse_name.size() == 1)
156  {
157  f.probability = -1.f;
158  }
159  else
160  std::cerr << "shared feature vectors should not have costs" << std::endl;
161  }
162 
163  ld->costs.push_back(f);
164  }
165 }
166 
168  test_label, sizeof(label)};
169 
170 bool ec_is_example_header(example const& ec) // example headers just have "shared"
171 {
172  v_array<CB::cb_class> costs = ec.l.cb.costs;
173  if (costs.size() != 1)
174  return false;
175  if (costs[0].probability == -1.f)
176  return true;
177  return false;
178 }
179 
180 void print_update(vw& all, bool is_test, example& ec, multi_ex* ec_seq, bool action_scores)
181 {
182  if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs)
183  {
184  size_t num_features = ec.num_features;
185 
186  size_t pred = ec.pred.multiclass;
187  if (ec_seq != nullptr)
188  {
189  num_features = 0;
190  // TODO: code duplication csoaa.cc LabelDict::ec_is_example_header
191  for (size_t i = 0; i < (*ec_seq).size(); i++)
192  if (!CB::ec_is_example_header(*(*ec_seq)[i]))
193  num_features += (*ec_seq)[i]->num_features;
194  }
195  std::string label_buf;
196  if (is_test)
197  label_buf = " unknown";
198  else
199  label_buf = " known";
200 
201  if (action_scores)
202  {
203  std::ostringstream pred_buf;
204  pred_buf << std::setw(shared_data::col_current_predict) << std::right << std::setfill(' ');
205  if (!ec.pred.a_s.empty())
206  pred_buf << ec.pred.a_s[0].action << ":" << ec.pred.a_s[0].score << "...";
207  else
208  pred_buf << "no action";
209  all.sd->print_update(all.holdout_set_off, all.current_pass, label_buf, pred_buf.str(), num_features,
210  all.progress_add, all.progress_arg);
211  }
212  else
213  all.sd->print_update(all.holdout_set_off, all.current_pass, label_buf, (uint32_t)pred, num_features,
214  all.progress_add, all.progress_arg);
215  }
216 }
217 } // namespace CB
218 
219 namespace CB_EVAL
220 {
221 float weight(void* v)
222 {
223  CB_EVAL::label* ld = (CB_EVAL::label*)v;
224  return ld->event.weight;
225 }
226 
227 size_t read_cached_label(shared_data* sd, void* v, io_buf& cache)
228 {
229  CB_EVAL::label* ld = (CB_EVAL::label*)v;
230  char* c;
231  size_t total = sizeof(uint32_t);
232  if (cache.buf_read(c, total) < total)
233  return 0;
234  ld->action = *(uint32_t*)c;
235 
236  return total + CB::read_cached_label(sd, &(ld->event), cache);
237 }
238 
239 void cache_label(void* v, io_buf& cache)
240 {
241  char* c;
242  CB_EVAL::label* ld = (CB_EVAL::label*)v;
243  cache.buf_write(c, sizeof(uint32_t));
244  *(uint32_t*)c = ld->action;
245 
246  CB::cache_label(&(ld->event), cache);
247 }
248 
249 void default_label(void* v)
250 {
251  CB_EVAL::label* ld = (CB_EVAL::label*)v;
252  CB::default_label(&(ld->event));
253  ld->action = 0;
254 }
255 
256 bool test_label(void* v)
257 {
258  CB_EVAL::label* ld = (CB_EVAL::label*)v;
259  return CB::test_label(&ld->event);
260 }
261 
262 void delete_label(void* v)
263 {
264  CB_EVAL::label* ld = (CB_EVAL::label*)v;
265  CB::delete_label(&(ld->event));
266 }
267 
268 void copy_label(void* dst, void* src)
269 {
270  CB_EVAL::label* ldD = (CB_EVAL::label*)dst;
271  CB_EVAL::label* ldS = (CB_EVAL::label*)src;
272  CB::copy_label(&(ldD->event), &(ldS)->event);
273  ldD->action = ldS->action;
274 }
275 
276 void parse_label(parser* p, shared_data* sd, void* v, v_array<substring>& words)
277 {
278  CB_EVAL::label* ld = (CB_EVAL::label*)v;
279 
280  if (words.size() < 2)
281  THROW("Evaluation can not happen without an action and an exploration");
282 
283  ld->action = (uint32_t)hashstring(words[0], 0);
284 
285  words.begin()++;
286 
287  CB::parse_label(p, sd, &(ld->event), words);
288 
289  words.begin()--;
290 }
291 
293  test_label, sizeof(CB_EVAL::label)};
294 } // namespace CB_EVAL
label_parser cb_eval
Definition: cb.cc:292
void copy_label(void *dst, void *src)
Definition: cb.cc:104
void parse_label(parser *p, shared_data *sd, void *v, v_array< substring > &words)
Definition: cb.cc:276
uint32_t multiclass
Definition: example.h:49
ACTION_SCORE::action_scores a_s
Definition: example.h:47
void copy_label(void *dst, void *src)
Definition: cb.cc:268
char * bufcache_label(CB::label *ld, char *c)
Definition: cb.cc:58
void cache_label(void *v, io_buf &cache)
Definition: cb.cc:72
bool ec_is_example_header(example const &ec)
Definition: cb.cc:170
void parse_label(parser *p, shared_data *, void *v, v_array< substring > &words)
Definition: cb.cc:112
float weight
Definition: cb.h:28
void copy_array(v_array< T > &dst, const v_array< T > &src)
Definition: v_array.h:185
CB::label cb
Definition: example.h:31
v_array< action_score > action_scores
Definition: action_score.h:10
v_array< cb_class > costs
Definition: cb.h:27
void delete_label(void *v)
Definition: cb.cc:98
bool quiet
Definition: global_data.h:487
void cache_label(void *v, io_buf &cache)
Definition: cb.cc:239
bool holdout_set_off
Definition: global_data.h:499
v_array< substring > parse_name
Definition: parser.h:100
T *& begin()
Definition: v_array.h:42
bool progress_add
Definition: global_data.h:545
uint32_t action
Definition: cb.h:41
size_t size() const
Definition: v_array.h:68
CB::label event
Definition: cb.h:42
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
float float_of_substring(substring s)
void default_label(void *v)
Definition: cb.cc:80
Definition: cb.cc:15
uint32_t action
Definition: cb.h:18
float partial_prediction
Definition: cb.h:21
shared_data * sd
Definition: global_data.h:375
float probability
Definition: cb.h:19
float progress_arg
Definition: global_data.h:546
size_t read_cached_label(shared_data *sd, void *v, io_buf &cache)
Definition: cb.cc:227
void tokenize(char delim, substring s, ContainerT &ret, bool allow_empty=false)
void print_update(bool holdout_set_off, size_t current_pass, float label, float prediction, size_t num_features, bool progress_add, float progress_arg)
Definition: global_data.h:225
bool bfgs
Definition: global_data.h:412
size_t num_features
Definition: example.h:67
uint64_t current_pass
Definition: global_data.h:396
void default_label(void *v)
Definition: cb.cc:249
Definition: io_buf.h:54
void buf_write(char *&pointer, size_t n)
Definition: io_buf.cc:94
std::vector< example * > multi_ex
Definition: example.h:122
label_parser cb_label
Definition: cb.cc:167
polylabel l
Definition: example.h:57
float weight(void *v)
Definition: cb.cc:221
Definition: cb.h:25
VW_STD14_CONSTEXPR uint64_t hashstring(substring s, uint64_t h)
Definition: hashstring.h:18
float cost
Definition: cb.h:17
bool empty() const
Definition: v_array.h:59
size_t read_cached_label(shared_data *, void *v, io_buf &cache)
Definition: cb.cc:39
bool substring_equal(const substring &a, const substring &b)
bool test_label(void *v)
Definition: simple_label.cc:70
Definition: parser.h:38
Definition: cb.cc:219
polyprediction pred
Definition: example.h:60
static constexpr int col_current_predict
Definition: global_data.h:184
void delete_label(void *v)
Definition: cb.cc:262
double weighted_examples()
Definition: global_data.h:188
float dump_interval
Definition: global_data.h:147
#define THROW(args)
Definition: vw_exception.h:181
constexpr uint64_t c
Definition: rand48.cc:12
float f
Definition: cache.cc:40
char * bufread_label(CB::label *ld, char *c, io_buf &cache)
Definition: cb.cc:17
size_t buf_read(char *&pointer, size_t n)
Definition: io_buf.cc:12
bool test_label(void *v)
Definition: cb.cc:87