Vowpal Wabbit
multiclass.cc
Go to the documentation of this file.
1 #include <cstring>
2 #include <climits>
3 #include "global_data.h"
4 #include "vw.h"
5 #include "vw_exception.h"
6 
7 namespace MULTICLASS
8 {
9 char* bufread_label(label_t* ld, char* c)
10 {
11  memcpy(&ld->label, c, sizeof(ld->label));
12  c += sizeof(ld->label);
13  memcpy(&ld->weight, c, sizeof(ld->weight));
14  c += sizeof(ld->weight);
15  return c;
16 }
17 
18 size_t read_cached_label(shared_data*, void* v, io_buf& cache)
19 {
20  label_t* ld = (label_t*)v;
21  char* c;
22  size_t total = sizeof(ld->label) + sizeof(ld->weight);
23  if (cache.buf_read(c, total) < total)
24  return 0;
25  bufread_label(ld, c);
26 
27  return total;
28 }
29 
30 float weight(void* v)
31 {
32  label_t* ld = (label_t*)v;
33  return (ld->weight > 0) ? ld->weight : 0.f;
34 }
35 
36 char* bufcache_label(label_t* ld, char* c)
37 {
38  memcpy(c, &ld->label, sizeof(ld->label));
39  c += sizeof(ld->label);
40  memcpy(c, &ld->weight, sizeof(ld->weight));
41  c += sizeof(ld->weight);
42  return c;
43 }
44 
45 void cache_label(void* v, io_buf& cache)
46 {
47  char* c;
48  label_t* ld = (label_t*)v;
49  cache.buf_write(c, sizeof(ld->label) + sizeof(ld->weight));
50  bufcache_label(ld, c);
51 }
52 
53 void default_label(void* v)
54 {
55  label_t* ld = (label_t*)v;
56  ld->label = (uint32_t)-1;
57  ld->weight = 1.;
58 }
59 
60 bool test_label(void* v)
61 {
62  label_t* ld = (label_t*)v;
63  return ld->label == (uint32_t)-1;
64 }
65 
/// Intentionally empty: this function performs no cleanup on the label.
void delete_label(void* /*unused*/)
{
}
67 
68 void parse_label(parser*, shared_data* sd, void* v, v_array<substring>& words)
69 {
70  label_t* ld = (label_t*)v;
71 
72  switch (words.size())
73  {
74  case 0:
75  break;
76  case 1:
77  ld->label = sd->ldict ? (uint32_t)sd->ldict->get(words[0]) : int_of_substring(words[0]);
78  ld->weight = 1.0;
79  break;
80  case 2:
81  ld->label = sd->ldict ? (uint32_t)sd->ldict->get(words[0]) : int_of_substring(words[0]);
82  ld->weight = float_of_substring(words[1]);
83  break;
84  default:
85  std::cerr << "malformed example!\n";
86  std::cerr << "words.size() = " << words.size() << std::endl;
87  }
88  if (ld->label == 0)
89  THROW("label 0 is not allowed for multiclass. Valid labels are {1,k}"
90  << (sd->ldict ? "\nthis likely happened because you specified an invalid label with named labels" : ""));
91 }
92 
94  test_label, sizeof(label_t)};
95 
96 void print_label_pred(vw& all, example& ec, uint32_t prediction)
97 {
98  substring ss_label = all.sd->ldict->get(ec.l.multi.label);
99  substring ss_pred = all.sd->ldict->get(prediction);
101  !ss_label.begin ? "unknown" : std::string(ss_label.begin, ss_label.end - ss_label.begin),
102  !ss_pred.begin ? "unknown" : std::string(ss_pred.begin, ss_pred.end - ss_pred.begin), ec.num_features,
103  all.progress_add, all.progress_arg);
104 }
105 
106 void print_probability(vw& all, example& ec, uint32_t prediction)
107 {
108  std::stringstream pred_ss;
109  pred_ss << prediction << "(" << std::setw(2) << std::setprecision(0) << std::fixed
110  << 100 * ec.pred.scalars[prediction - 1] << "%)";
111 
112  std::stringstream label_ss;
113  label_ss << ec.l.multi.label;
114 
115  all.sd->print_update(all.holdout_set_off, all.current_pass, label_ss.str(), pred_ss.str(), ec.num_features,
116  all.progress_add, all.progress_arg);
117 }
118 
119 void print_score(vw& all, example& ec, uint32_t prediction)
120 {
121  std::stringstream pred_ss;
122  pred_ss << prediction;
123 
124  std::stringstream label_ss;
125  label_ss << ec.l.multi.label;
126 
127  all.sd->print_update(all.holdout_set_off, all.current_pass, label_ss.str(), pred_ss.str(), ec.num_features,
128  all.progress_add, all.progress_arg);
129 }
130 
131 void direct_print_update(vw& all, example& ec, uint32_t prediction)
132 {
133  all.sd->print_update(all.holdout_set_off, all.current_pass, ec.l.multi.label, prediction, ec.num_features,
134  all.progress_add, all.progress_arg);
135 }
136 
137 template <void (*T)(vw&, example&, uint32_t)>
138 void print_update(vw& all, example& ec, uint32_t prediction)
139 {
140  if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs)
141  {
142  if (!all.sd->ldict)
143  T(all, ec, prediction);
144  else
145  print_label_pred(all, ec, ec.pred.multiclass);
146  }
147 }
148 
149 void print_update_with_probability(vw& all, example& ec, uint32_t pred)
150 {
151  print_update<print_probability>(all, ec, pred);
152 }
153 void print_update_with_score(vw& all, example& ec, uint32_t pred) { print_update<print_score>(all, ec, pred); }
154 
155 void finish_example(vw& all, example& ec, bool update_loss)
156 {
157  float loss = 0;
158  if (ec.l.multi.label != (uint32_t)ec.pred.multiclass && ec.l.multi.label != (uint32_t)-1)
159  loss = ec.weight;
160 
161  all.sd->update(ec.test_only, update_loss && (ec.l.multi.label != (uint32_t)-1), loss, ec.weight, ec.num_features);
162 
163  for (int sink : all.final_prediction_sink)
164  if (!all.sd->ldict)
165  all.print(sink, (float)ec.pred.multiclass, 0, ec.tag);
166  else
167  {
168  substring ss_pred = all.sd->ldict->get(ec.pred.multiclass);
169  all.print_text(sink, std::string(ss_pred.begin, ss_pred.end - ss_pred.begin), ec.tag);
170  }
171 
172  MULTICLASS::print_update<direct_print_update>(all, ec, ec.pred.multiclass);
173  VW::finish_example(all, ec);
174 }
175 } // namespace MULTICLASS
int int_of_substring(substring s)
v_array< char > tag
Definition: example.h:63
uint32_t multiclass
Definition: example.h:49
void delete_label(void *)
Definition: multiclass.cc:66
bool test_label(void *v)
Definition: multiclass.cc:60
char * end
Definition: hashstring.h:10
char * begin
Definition: hashstring.h:9
v_array< int > final_prediction_sink
Definition: global_data.h:518
namedlabels * ldict
Definition: global_data.h:153
bool quiet
Definition: global_data.h:487
char * bufcache_label(label_t *ld, char *c)
Definition: multiclass.cc:36
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
void cache_label(void *v, io_buf &cache)
Definition: multiclass.cc:45
void print_update_with_probability(vw &all, example &ec, uint32_t pred)
Definition: multiclass.cc:149
bool holdout_set_off
Definition: global_data.h:499
void print_update_with_score(vw &all, example &ec, uint32_t pred)
Definition: multiclass.cc:153
bool progress_add
Definition: global_data.h:545
size_t size() const
Definition: v_array.h:68
label_parser mc_label
Definition: multiclass.cc:93
MULTICLASS::label_t multi
Definition: example.h:29
float weight(void *v)
Definition: multiclass.cc:30
char * bufread_label(label_t *ld, char *c)
Definition: multiclass.cc:9
float float_of_substring(substring s)
shared_data * sd
Definition: global_data.h:375
void print_label_pred(vw &all, example &ec, uint32_t prediction)
Definition: multiclass.cc:96
float progress_arg
Definition: global_data.h:546
void print_update(bool holdout_set_off, size_t current_pass, float label, float prediction, size_t num_features, bool progress_add, float progress_arg)
Definition: global_data.h:225
bool bfgs
Definition: global_data.h:412
size_t num_features
Definition: example.h:67
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
uint64_t current_pass
Definition: global_data.h:396
Definition: io_buf.h:54
void finish_example(vw &, example &)
Definition: parser.cc:881
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
void print_score(vw &all, example &ec, uint32_t prediction)
Definition: multiclass.cc:119
void buf_write(char *&pointer, size_t n)
Definition: io_buf.cc:94
polylabel l
Definition: example.h:57
size_t read_cached_label(shared_data *, void *v, io_buf &cache)
Definition: multiclass.cc:18
void print_update(vw &all, example &ec, uint32_t prediction)
Definition: multiclass.cc:138
uint64_t get(substring &s)
Definition: global_data.h:108
void print_probability(vw &all, example &ec, uint32_t prediction)
Definition: multiclass.cc:106
Definition: parser.h:38
polyprediction pred
Definition: example.h:60
void finish_example(vw &all, example &ec, bool update_loss)
Definition: multiclass.cc:155
float weight
Definition: example.h:62
void direct_print_update(vw &all, example &ec, uint32_t prediction)
Definition: multiclass.cc:131
double weighted_examples()
Definition: global_data.h:188
float dump_interval
Definition: global_data.h:147
v_array< float > scalars
Definition: example.h:46
#define THROW(args)
Definition: vw_exception.h:181
constexpr uint64_t c
Definition: rand48.cc:12
void(* print)(int, float, float, v_array< char >)
Definition: global_data.h:521
void default_label(void *v)
Definition: multiclass.cc:53
void parse_label(parser *, shared_data *sd, void *v, v_array< substring > &words)
Definition: multiclass.cc:68
bool test_only
Definition: example.h:76
size_t buf_read(char *&pointer, size_t n)
Definition: io_buf.cc:12