Vowpal Wabbit
multilabel.cc
Go to the documentation of this file.
1 #include "float.h"
2 #include "gd.h"
3 #include "vw.h"
4 
5 namespace MULTILABEL
6 {
7 char* bufread_label(labels* ld, char* c, io_buf& cache)
8 {
9  size_t num = *(size_t*)c;
10  ld->label_v.clear();
11  c += sizeof(size_t);
12  size_t total = sizeof(uint32_t) * num;
13  if (cache.buf_read(c, (int)total) < total)
14  {
15  std::cout << "error in demarshal of cost data" << std::endl;
16  return c;
17  }
18  for (size_t i = 0; i < num; i++)
19  {
20  uint32_t temp = *(uint32_t*)c;
21  c += sizeof(uint32_t);
22  ld->label_v.push_back(temp);
23  }
24 
25  return c;
26 }
27 
28 size_t read_cached_label(shared_data*, void* v, io_buf& cache)
29 {
30  labels* ld = (labels*)v;
31  ld->label_v.clear();
32  char* c;
33  size_t total = sizeof(size_t);
34  if (cache.buf_read(c, (int)total) < total)
35  return 0;
36  bufread_label(ld, c, cache);
37 
38  return total;
39 }
40 
/// Returns the weight of a multilabel label: always 1, the argument is ignored.
float weight(void*)
{
  return 1.f;
}
42 
43 char* bufcache_label(labels* ld, char* c)
44 {
45  *(size_t*)c = ld->label_v.size();
46  c += sizeof(size_t);
47  for (unsigned int i = 0; i < ld->label_v.size(); i++)
48  {
49  *(uint32_t*)c = ld->label_v[i];
50  c += sizeof(uint32_t);
51  }
52  return c;
53 }
54 
55 void cache_label(void* v, io_buf& cache)
56 {
57  char* c;
58  labels* ld = (labels*)v;
59  cache.buf_write(c, sizeof(size_t) + sizeof(uint32_t) * ld->label_v.size());
60  bufcache_label(ld, c);
61 }
62 
63 void default_label(void* v)
64 {
65  labels* ld = (labels*)v;
66  ld->label_v.clear();
67 }
68 
69 bool test_label(void* v)
70 {
71  labels* ld = (labels*)v;
72  return ld->label_v.size() == 0;
73 }
74 
75 void delete_label(void* v)
76 {
77  labels* ld = (labels*)v;
78  if (ld)
79  ld->label_v.delete_v();
80 }
81 
82 void copy_label(void* dst, void* src)
83 {
84  if (dst && src)
85  {
86  labels* ldD = (labels*)dst;
87  labels* ldS = (labels*)src;
88  copy_array(ldD->label_v, ldS->label_v);
89  }
90 }
91 
92 void parse_label(parser* p, shared_data*, void* v, v_array<substring>& words)
93 {
94  labels* ld = (labels*)v;
95 
96  ld->label_v.clear();
97  switch (words.size())
98  {
99  case 0:
100  break;
101  case 1:
102  tokenize(',', words[0], p->parse_name);
103 
104  for (size_t i = 0; i < p->parse_name.size(); i++)
105  {
106  *(p->parse_name[i].end) = '\0';
107  uint32_t n = atoi(p->parse_name[i].begin);
108  ld->label_v.push_back(n);
109  }
110  break;
111  default:
112  std::cerr << "example with an odd label, what is ";
113  for (size_t i = 0; i < words.size(); i++) std::cerr << words[i].begin << " ";
114  std::cerr << std::endl;
115  }
116 }
117 
119  test_label, sizeof(labels)};
120 
121 void print_update(vw& all, bool is_test, example& ec)
122 {
123  if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs)
124  {
125  std::stringstream label_string;
126  if (is_test)
127  label_string << " unknown";
128  else
129  for (size_t i = 0; i < ec.l.multilabels.label_v.size(); i++) label_string << " " << ec.l.multilabels.label_v[i];
130 
131  std::stringstream pred_string;
132  for (size_t i = 0; i < ec.pred.multilabels.label_v.size(); i++)
133  pred_string << " " << ec.pred.multilabels.label_v[i];
134 
135  all.sd->print_update(all.holdout_set_off, all.current_pass, label_string.str(), pred_string.str(), ec.num_features,
136  all.progress_add, all.progress_arg);
137  }
138 }
139 
140 void output_example(vw& all, example& ec)
141 {
142  labels& ld = ec.l.multilabels;
143 
144  float loss = 0.;
145  if (!test_label(&ld))
146  {
147  // need to compute exact loss
148  labels preds = ec.pred.multilabels;
149  labels given = ec.l.multilabels;
150 
151  uint32_t preds_index = 0;
152  uint32_t given_index = 0;
153 
154  while (preds_index < preds.label_v.size() && given_index < given.label_v.size())
155  {
156  if (preds.label_v[preds_index] < given.label_v[given_index])
157  {
158  preds_index++;
159  loss++;
160  }
161  else if (preds.label_v[preds_index] > given.label_v[given_index])
162  {
163  given_index++;
164  loss++;
165  }
166  else
167  {
168  preds_index++;
169  given_index++;
170  }
171  }
172  loss += given.label_v.size() - given_index;
173  loss += preds.label_v.size() - preds_index;
174  }
175 
176  all.sd->update(ec.test_only, !test_label(&ld), loss, 1.f, ec.num_features);
177 
178  for (int sink : all.final_prediction_sink)
179  if (sink >= 0)
180  {
181  std::stringstream ss;
182 
183  for (size_t i = 0; i < ec.pred.multilabels.label_v.size(); i++)
184  {
185  if (i > 0)
186  ss << ',';
187  ss << ec.pred.multilabels.label_v[i];
188  }
189  ss << ' ';
190  all.print_text(sink, ss.str(), ec.tag);
191  }
192 
193  print_update(all, test_label(&ec.l.multilabels), ec);
194 }
195 } // namespace MULTILABEL
v_array< char > tag
Definition: example.h:63
void print_update(vw &all, bool is_test, example &ec)
Definition: multilabel.cc:121
size_t read_cached_label(shared_data *, void *v, io_buf &cache)
Definition: multilabel.cc:28
void output_example(vw &all, example &ec)
Definition: multilabel.cc:140
void copy_array(v_array< T > &dst, const v_array< T > &src)
Definition: v_array.h:185
v_array< int > final_prediction_sink
Definition: global_data.h:518
bool test_label(void *v)
Definition: multilabel.cc:69
bool quiet
Definition: global_data.h:487
void copy_label(void *dst, void *src)
Definition: multilabel.cc:82
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
char * bufread_label(labels *ld, char *c, io_buf &cache)
Definition: multilabel.cc:7
bool holdout_set_off
Definition: global_data.h:499
v_array< substring > parse_name
Definition: parser.h:100
T *& begin()
Definition: v_array.h:42
bool progress_add
Definition: global_data.h:545
size_t size() const
Definition: v_array.h:68
void parse_label(parser *p, shared_data *, void *v, v_array< substring > &words)
Definition: multilabel.cc:92
void push_back(const T &new_ele)
Definition: v_array.h:107
shared_data * sd
Definition: global_data.h:375
float progress_arg
Definition: global_data.h:546
void clear()
Definition: v_array.h:88
void tokenize(char delim, substring s, ContainerT &ret, bool allow_empty=false)
void print_update(bool holdout_set_off, size_t current_pass, float label, float prediction, size_t num_features, bool progress_add, float progress_arg)
Definition: global_data.h:225
void default_label(void *v)
Definition: multilabel.cc:63
bool bfgs
Definition: global_data.h:412
size_t num_features
Definition: example.h:67
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
uint64_t current_pass
Definition: global_data.h:396
Definition: io_buf.h:54
T *& end()
Definition: v_array.h:43
void cache_label(void *v, io_buf &cache)
Definition: multilabel.cc:55
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
void buf_write(char *&pointer, size_t n)
Definition: io_buf.cc:94
polylabel l
Definition: example.h:57
MULTILABEL::labels multilabels
Definition: example.h:50
MULTILABEL::labels multilabels
Definition: example.h:34
label_parser multilabel
Definition: multilabel.cc:118
v_array< uint32_t > label_v
Definition: multilabel.h:16
char * bufcache_label(labels *ld, char *c)
Definition: multilabel.cc:43
Definition: parser.h:38
polyprediction pred
Definition: example.h:60
void delete_v()
Definition: v_array.h:98
float weight(void *)
Definition: multilabel.cc:41
double weighted_examples()
Definition: global_data.h:188
float dump_interval
Definition: global_data.h:147
constexpr uint64_t c
Definition: rand48.cc:12
float f
Definition: cache.cc:40
void delete_label(void *v)
Definition: multilabel.cc:75
bool test_only
Definition: example.h:76
size_t buf_read(char *&pointer, size_t n)
Definition: io_buf.cc:12