Vowpal Wabbit
Classes | Functions | Variables
MULTICLASS Namespace Reference

Classes

struct  label_t
 

Functions

char * bufread_label (label_t *ld, char *c)
 
size_t read_cached_label (shared_data *, void *v, io_buf &cache)
 
float weight (void *v)
 
char * bufcache_label (label_t *ld, char *c)
 
void cache_label (void *v, io_buf &cache)
 
void default_label (void *v)
 
bool test_label (void *v)
 
void delete_label (void *)
 
void parse_label (parser *, shared_data *sd, void *v, v_array< substring > &words)
 
void print_label_pred (vw &all, example &ec, uint32_t prediction)
 
void print_probability (vw &all, example &ec, uint32_t prediction)
 
void print_score (vw &all, example &ec, uint32_t prediction)
 
void direct_print_update (vw &all, example &ec, uint32_t prediction)
 
template<void (*T)(vw &, example &, uint32_t)>
void print_update (vw &all, example &ec, uint32_t prediction)
 
void print_update_with_probability (vw &all, example &ec, uint32_t pred)
 
void print_update_with_score (vw &all, example &ec, uint32_t pred)
 
void finish_example (vw &all, example &ec, bool update_loss)
 
template<class T >
void finish_example (vw &all, T &, example &ec)
 
template<class T >
void finish_example_without_loss (vw &all, T &, example &ec)
 

Variables

label_parser mc_label
 

Function Documentation

◆ bufcache_label()

char* MULTICLASS::bufcache_label ( label_t *  ld,
char *  c 
)

Definition at line 36 of file multiclass.cc.

References c, MULTICLASS::label_t::label, and MULTICLASS::label_t::weight.

Referenced by cache_label().

37 {
38  memcpy(c, &ld->label, sizeof(ld->label));
39  c += sizeof(ld->label);
40  memcpy(c, &ld->weight, sizeof(ld->weight));
41  c += sizeof(ld->weight);
42  return c;
43 }
constexpr uint64_t c
Definition: rand48.cc:12

◆ bufread_label()

char* MULTICLASS::bufread_label ( label_t *  ld,
char *  c 
)

Definition at line 9 of file multiclass.cc.

References c, MULTICLASS::label_t::label, and MULTICLASS::label_t::weight.

Referenced by read_cached_label().

10 {
11  memcpy(&ld->label, c, sizeof(ld->label));
12  c += sizeof(ld->label);
13  memcpy(&ld->weight, c, sizeof(ld->weight));
14  c += sizeof(ld->weight);
15  return c;
16 }
constexpr uint64_t c
Definition: rand48.cc:12

◆ cache_label()

void MULTICLASS::cache_label ( void *  v,
io_buf &  cache 
)

Definition at line 45 of file multiclass.cc.

References io_buf::buf_write(), bufcache_label(), c, MULTICLASS::label_t::label, and MULTICLASS::label_t::weight.

46 {
47  char* c;
48  label_t* ld = (label_t*)v;
49  cache.buf_write(c, sizeof(ld->label) + sizeof(ld->weight));
50  bufcache_label(ld, c);
51 }
char * bufcache_label(label_t *ld, char *c)
Definition: multiclass.cc:36
void buf_write(char *&pointer, size_t n)
Definition: io_buf.cc:94
constexpr uint64_t c
Definition: rand48.cc:12

◆ default_label()

void MULTICLASS::default_label ( void *  v)

Definition at line 53 of file multiclass.cc.

References MULTICLASS::label_t::label, and MULTICLASS::label_t::weight.

54 {
55  label_t* ld = (label_t*)v;
56  ld->label = (uint32_t)-1;
57  ld->weight = 1.;
58 }

◆ delete_label()

void MULTICLASS::delete_label ( void *  )

Definition at line 66 of file multiclass.cc.

66 {}

◆ direct_print_update()

void MULTICLASS::direct_print_update ( vw &  all,
example &  ec,
uint32_t  prediction 
)

Definition at line 131 of file multiclass.cc.

References vw::current_pass, vw::holdout_set_off, example::l, MULTICLASS::label_t::label, polylabel::multi, example::num_features, shared_data::print_update(), vw::progress_add, vw::progress_arg, and vw::sd.

132 {
133  all.sd->print_update(all.holdout_set_off, all.current_pass, ec.l.multi.label, prediction, ec.num_features,
134  all.progress_add, all.progress_arg);
135 }
bool holdout_set_off
Definition: global_data.h:499
bool progress_add
Definition: global_data.h:545
MULTICLASS::label_t multi
Definition: example.h:29
shared_data * sd
Definition: global_data.h:375
float progress_arg
Definition: global_data.h:546
void print_update(bool holdout_set_off, size_t current_pass, float label, float prediction, size_t num_features, bool progress_add, float progress_arg)
Definition: global_data.h:225
size_t num_features
Definition: example.h:67
uint64_t current_pass
Definition: global_data.h:396
polylabel l
Definition: example.h:57

◆ finish_example() [1/2]

template<class T >
void MULTICLASS::finish_example ( vw &  all,
T &  ,
example &  ec 
)

Definition at line 28 of file multiclass.h.

References finish_example().

29 {
30  finish_example(all, ec, true);
31 }
void finish_example(vw &all, T &, example &ec)
Definition: multiclass.h:28

◆ finish_example() [2/2]

void MULTICLASS::finish_example ( vw &  all,
example &  ec,
bool  update_loss 
)

Definition at line 155 of file multiclass.cc.

References substring::begin, substring::end, vw::final_prediction_sink, VW::finish_example(), namedlabels::get(), example::l, MULTICLASS::label_t::label, shared_data::ldict, loss(), polylabel::multi, polyprediction::multiclass, example::num_features, example::pred, vw::print, vw::print_text, vw::sd, example::tag, example::test_only, shared_data::update(), and example::weight.

Referenced by finish_example(), and finish_example_without_loss().

156 {
157  float loss = 0;
158  if (ec.l.multi.label != (uint32_t)ec.pred.multiclass && ec.l.multi.label != (uint32_t)-1)
159  loss = ec.weight;
160 
161  all.sd->update(ec.test_only, update_loss && (ec.l.multi.label != (uint32_t)-1), loss, ec.weight, ec.num_features);
162 
163  for (int sink : all.final_prediction_sink)
164  if (!all.sd->ldict)
165  all.print(sink, (float)ec.pred.multiclass, 0, ec.tag);
166  else
167  {
168  substring ss_pred = all.sd->ldict->get(ec.pred.multiclass);
169  all.print_text(sink, std::string(ss_pred.begin, ss_pred.end - ss_pred.begin), ec.tag);
170  }
171 
172  MULTICLASS::print_update<direct_print_update>(all, ec, ec.pred.multiclass);
173  VW::finish_example(all, ec);
174 }
v_array< char > tag
Definition: example.h:63
uint32_t multiclass
Definition: example.h:49
char * end
Definition: hashstring.h:10
char * begin
Definition: hashstring.h:9
v_array< int > final_prediction_sink
Definition: global_data.h:518
namedlabels * ldict
Definition: global_data.h:153
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
MULTICLASS::label_t multi
Definition: example.h:29
shared_data * sd
Definition: global_data.h:375
size_t num_features
Definition: example.h:67
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
void finish_example(vw &, example &)
Definition: parser.cc:881
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
polylabel l
Definition: example.h:57
uint64_t get(substring &s)
Definition: global_data.h:108
polyprediction pred
Definition: example.h:60
float weight
Definition: example.h:62
void(* print)(int, float, float, v_array< char >)
Definition: global_data.h:521
bool test_only
Definition: example.h:76

◆ finish_example_without_loss()

template<class T >
void MULTICLASS::finish_example_without_loss ( vw &  all,
T &  ,
example &  ec 
)

Definition at line 34 of file multiclass.h.

References finish_example().

35 {
36  finish_example(all, ec, false);
37 }
void finish_example(vw &all, T &, example &ec)
Definition: multiclass.h:28

◆ parse_label()

void MULTICLASS::parse_label ( parser *  ,
shared_data *  sd,
void *  v,
v_array< substring > &  words 
)

Definition at line 68 of file multiclass.cc.

References float_of_substring(), namedlabels::get(), int_of_substring(), MULTICLASS::label_t::label, shared_data::ldict, v_array< T >::size(), THROW, and MULTICLASS::label_t::weight.

69 {
70  label_t* ld = (label_t*)v;
71 
72  switch (words.size())
73  {
74  case 0:
75  break;
76  case 1:
77  ld->label = sd->ldict ? (uint32_t)sd->ldict->get(words[0]) : int_of_substring(words[0]);
78  ld->weight = 1.0;
79  break;
80  case 2:
81  ld->label = sd->ldict ? (uint32_t)sd->ldict->get(words[0]) : int_of_substring(words[0]);
82  ld->weight = float_of_substring(words[1]);
83  break;
84  default:
85  std::cerr << "malformed example!\n";
86  std::cerr << "words.size() = " << words.size() << std::endl;
87  }
88  if (ld->label == 0)
89  THROW("label 0 is not allowed for multiclass. Valid labels are {1,k}"
90  << (sd->ldict ? "\nthis likely happened because you specified an invalid label with named labels" : ""));
91 }
int int_of_substring(substring s)
namedlabels * ldict
Definition: global_data.h:153
size_t size() const
Definition: v_array.h:68
float float_of_substring(substring s)
uint64_t get(substring &s)
Definition: global_data.h:108
#define THROW(args)
Definition: vw_exception.h:181

◆ print_label_pred()

void MULTICLASS::print_label_pred ( vw &  all,
example &  ec,
uint32_t  prediction 
)

Definition at line 96 of file multiclass.cc.

References substring::begin, vw::current_pass, substring::end, namedlabels::get(), vw::holdout_set_off, example::l, MULTICLASS::label_t::label, shared_data::ldict, polylabel::multi, example::num_features, shared_data::print_update(), vw::progress_add, vw::progress_arg, and vw::sd.

Referenced by print_update().

97 {
98  substring ss_label = all.sd->ldict->get(ec.l.multi.label);
99  substring ss_pred = all.sd->ldict->get(prediction);
100  all.sd->print_update(all.holdout_set_off, all.current_pass,
101  !ss_label.begin ? "unknown" : std::string(ss_label.begin, ss_label.end - ss_label.begin),
102  !ss_pred.begin ? "unknown" : std::string(ss_pred.begin, ss_pred.end - ss_pred.begin), ec.num_features,
103  all.progress_add, all.progress_arg);
104 }
char * end
Definition: hashstring.h:10
char * begin
Definition: hashstring.h:9
namedlabels * ldict
Definition: global_data.h:153
bool holdout_set_off
Definition: global_data.h:499
bool progress_add
Definition: global_data.h:545
MULTICLASS::label_t multi
Definition: example.h:29
shared_data * sd
Definition: global_data.h:375
float progress_arg
Definition: global_data.h:546
void print_update(bool holdout_set_off, size_t current_pass, float label, float prediction, size_t num_features, bool progress_add, float progress_arg)
Definition: global_data.h:225
size_t num_features
Definition: example.h:67
uint64_t current_pass
Definition: global_data.h:396
polylabel l
Definition: example.h:57
uint64_t get(substring &s)
Definition: global_data.h:108

◆ print_probability()

void MULTICLASS::print_probability ( vw &  all,
example &  ec,
uint32_t  prediction 
)

Definition at line 106 of file multiclass.cc.

References vw::current_pass, vw::holdout_set_off, example::l, MULTICLASS::label_t::label, polylabel::multi, example::num_features, example::pred, shared_data::print_update(), vw::progress_add, vw::progress_arg, polyprediction::scalars, and vw::sd.

107 {
108  std::stringstream pred_ss;
109  pred_ss << prediction << "(" << std::setw(2) << std::setprecision(0) << std::fixed
110  << 100 * ec.pred.scalars[prediction - 1] << "%)";
111 
112  std::stringstream label_ss;
113  label_ss << ec.l.multi.label;
114 
115  all.sd->print_update(all.holdout_set_off, all.current_pass, label_ss.str(), pred_ss.str(), ec.num_features,
116  all.progress_add, all.progress_arg);
117 }
bool holdout_set_off
Definition: global_data.h:499
bool progress_add
Definition: global_data.h:545
MULTICLASS::label_t multi
Definition: example.h:29
shared_data * sd
Definition: global_data.h:375
float progress_arg
Definition: global_data.h:546
void print_update(bool holdout_set_off, size_t current_pass, float label, float prediction, size_t num_features, bool progress_add, float progress_arg)
Definition: global_data.h:225
size_t num_features
Definition: example.h:67
uint64_t current_pass
Definition: global_data.h:396
polylabel l
Definition: example.h:57
polyprediction pred
Definition: example.h:60
v_array< float > scalars
Definition: example.h:46

◆ print_score()

void MULTICLASS::print_score ( vw &  all,
example &  ec,
uint32_t  prediction 
)

Definition at line 119 of file multiclass.cc.

References vw::current_pass, vw::holdout_set_off, example::l, MULTICLASS::label_t::label, polylabel::multi, example::num_features, shared_data::print_update(), vw::progress_add, vw::progress_arg, and vw::sd.

120 {
121  std::stringstream pred_ss;
122  pred_ss << prediction;
123 
124  std::stringstream label_ss;
125  label_ss << ec.l.multi.label;
126 
127  all.sd->print_update(all.holdout_set_off, all.current_pass, label_ss.str(), pred_ss.str(), ec.num_features,
128  all.progress_add, all.progress_arg);
129 }
bool holdout_set_off
Definition: global_data.h:499
bool progress_add
Definition: global_data.h:545
MULTICLASS::label_t multi
Definition: example.h:29
shared_data * sd
Definition: global_data.h:375
float progress_arg
Definition: global_data.h:546
void print_update(bool holdout_set_off, size_t current_pass, float label, float prediction, size_t num_features, bool progress_add, float progress_arg)
Definition: global_data.h:225
size_t num_features
Definition: example.h:67
uint64_t current_pass
Definition: global_data.h:396
polylabel l
Definition: example.h:57

◆ print_update()

template<void (*T)(vw &, example &, uint32_t)>
void MULTICLASS::print_update ( vw &  all,
example &  ec,
uint32_t  prediction 
)

Definition at line 138 of file multiclass.cc.

References vw::bfgs, shared_data::dump_interval, shared_data::ldict, polyprediction::multiclass, example::pred, print_label_pred(), vw::quiet, vw::sd, and shared_data::weighted_examples().

139 {
140  if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs)
141  {
142  if (!all.sd->ldict)
143  T(all, ec, prediction);
144  else
145  print_label_pred(all, ec, ec.pred.multiclass);
146  }
147 }
uint32_t multiclass
Definition: example.h:49
namedlabels * ldict
Definition: global_data.h:153
bool quiet
Definition: global_data.h:487
shared_data * sd
Definition: global_data.h:375
void print_label_pred(vw &all, example &ec, uint32_t prediction)
Definition: multiclass.cc:96
bool bfgs
Definition: global_data.h:412
polyprediction pred
Definition: example.h:60
double weighted_examples()
Definition: global_data.h:188
float dump_interval
Definition: global_data.h:147

◆ print_update_with_probability()

void MULTICLASS::print_update_with_probability ( vw &  all,
example &  ec,
uint32_t  pred 
)

Definition at line 149 of file multiclass.cc.

Referenced by finish_example_scores().

150 {
151  print_update<print_probability>(all, ec, pred);
152 }

◆ print_update_with_score()

void MULTICLASS::print_update_with_score ( vw &  all,
example &  ec,
uint32_t  pred 
)

Definition at line 153 of file multiclass.cc.

Referenced by finish_example_scores().

153 { print_update<print_score>(all, ec, pred); }

◆ read_cached_label()

size_t MULTICLASS::read_cached_label ( shared_data *  ,
void *  v,
io_buf &  cache 
)

Definition at line 18 of file multiclass.cc.

References io_buf::buf_read(), bufread_label(), c, MULTICLASS::label_t::label, and MULTICLASS::label_t::weight.

19 {
20  label_t* ld = (label_t*)v;
21  char* c;
22  size_t total = sizeof(ld->label) + sizeof(ld->weight);
23  if (cache.buf_read(c, total) < total)
24  return 0;
25  bufread_label(ld, c);
26 
27  return total;
28 }
char * bufread_label(label_t *ld, char *c)
Definition: multiclass.cc:9
constexpr uint64_t c
Definition: rand48.cc:12
size_t buf_read(char *&pointer, size_t n)
Definition: io_buf.cc:12

◆ test_label()

bool MULTICLASS::test_label ( void *  v)

Definition at line 60 of file multiclass.cc.

References MULTICLASS::label_t::label.

61 {
62  label_t* ld = (label_t*)v;
63  return ld->label == (uint32_t)-1;
64 }

◆ weight()

float MULTICLASS::weight ( void *  v)

Definition at line 30 of file multiclass.cc.

References MULTICLASS::label_t::weight.

31 {
32  label_t* ld = (label_t*)v;
33  return (ld->weight > 0) ? ld->weight : 0.f;
34 }

Variable Documentation

◆ mc_label

label_parser MULTICLASS::mc_label
Initial value:
= {default_label, parse_label, cache_label, read_cached_label, delete_label, weight, nullptr,
test_label, sizeof(label_t)}
void delete_label(void *)
Definition: multiclass.cc:66
void cache_label(void *v, io_buf &cache)
Definition: multiclass.cc:45
float weight(void *v)
Definition: multiclass.cc:30
size_t read_cached_label(shared_data *, void *v, io_buf &cache)
Definition: multiclass.cc:18
bool test_label(void *v)
Definition: multiclass.cc:60
void default_label(void *v)
Definition: multiclass.cc:53
void parse_label(parser *, shared_data *sd, void *v, v_array< substring > &words)
Definition: multiclass.cc:68

Definition at line 93 of file multiclass.cc.

Referenced by add_to_vali(), LEARNER::init_multiclass_learner(), Search::mc_label_is_test(), Search::setup(), Search::train_single_example(), Search::search::~search(), and warm_cb::~warm_cb().