Vowpal Wabbit
global_data.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 #include <stdio.h>
7 #include <float.h>
8 #include <errno.h>
9 #include <iostream>
10 #include <sstream>
11 #include <math.h>
12 #include <assert.h>
13 
14 #include "global_data.h"
15 #include "gd.h"
16 #include "vw_exception.h"
17 
19 {
20  float p;
21  float weight;
22 };
23 
24 size_t really_read(int sock, void* in, size_t count)
25 {
26  char* buf = (char*)in;
27  size_t done = 0;
28  int r = 0;
29  while (done < count)
30  {
31  if ((r =
32 #ifdef _WIN32
33  recv(sock, buf, (unsigned int)(count - done), 0)
34 #else
35  read(sock, buf, (unsigned int)(count - done))
36 #endif
37  ) == 0)
38  return 0;
39  else if (r < 0)
40  {
41  THROWERRNO("read(" << sock << "," << count << "-" << done << ")");
42  }
43  else
44  {
45  done += r;
46  buf += r;
47  }
48  }
49  return done;
50 }
51 
52 void get_prediction(int sock, float& res, float& weight)
53 {
55  really_read(sock, &p, sizeof(p));
56  res = p.p;
57  weight = p.weight;
58 }
59 
61 {
62  if (
63 #ifdef _WIN32
64  send(sock, reinterpret_cast<const char*>(&p), sizeof(p), 0)
65 #else
66  write(sock, &p, sizeof(p))
67 #endif
68  < (int)sizeof(p))
69  THROWERRNO("send_prediction write(" << sock << ")");
70 }
71 
72 void binary_print_result(int f, float res, float weight, v_array<char>)
73 {
74  if (f >= 0)
75  {
76  global_prediction ps = {res, weight};
77  send_prediction(f, ps);
78  }
79 }
80 
81 int print_tag(std::stringstream& ss, v_array<char> tag)
82 {
83  if (tag.begin() != tag.end())
84  {
85  ss << ' ';
86  ss.write(tag.begin(), sizeof(char) * tag.size());
87  }
88  return tag.begin() != tag.end();
89 }
90 
91 void print_result(int f, float res, float, v_array<char> tag)
92 {
93  if (f >= 0)
94  {
95  std::stringstream ss;
96  auto saved_precision = ss.precision();
97  if (floorf(res) == res)
98  ss << std::setprecision(0);
99  ss << std::fixed << res << std::setprecision(saved_precision);
100  print_tag(ss, tag);
101  ss << '\n';
102  ssize_t len = ss.str().size();
103  ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
104  if (t != len)
105  {
106  std::cerr << "write error: " << strerror(errno) << std::endl;
107  }
108  }
109 }
110 
111 void print_raw_text(int f, std::string s, v_array<char> tag)
112 {
113  if (f < 0)
114  return;
115 
116  std::stringstream ss;
117  ss << s;
118  print_tag(ss, tag);
119  ss << '\n';
120  ssize_t len = ss.str().size();
121  ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
122  if (t != len)
123  {
124  std::cerr << "write error: " << strerror(errno) << std::endl;
125  }
126 }
127 
128 void set_mm(shared_data* sd, float label)
129 {
130  sd->min_label = std::min(sd->min_label, label);
131  if (label != FLT_MAX)
132  sd->max_label = std::max(sd->max_label, label);
133 }
134 
135 void noop_mm(shared_data*, float) {}
136 
138 {
139  if (l->is_multiline)
140  THROW("This reduction does not support single-line examples.");
141 
142  if (ec.test_only || !training)
144  else
146 }
147 
149 {
150  if (!l->is_multiline)
151  THROW("This reduction does not support multi-line example.");
152 
153  if (!training)
155  else
157 }
158 
160 {
161  if (l->is_multiline)
162  THROW("This reduction does not support single-line examples.");
163 
165 }
166 
168 {
169  if (!l->is_multiline)
170  THROW("This reduction does not support multi-line example.");
171 
173 }
174 
176 {
177  if (l->is_multiline)
178  THROW("This reduction does not support single-line examples.");
179 
180  LEARNER::as_singleline(l)->finish_example(*this, ec);
181 }
182 
184 {
185  if (!l->is_multiline)
186  THROW("This reduction does not support multi-line example.");
187 
188  LEARNER::as_multiline(l)->finish_example(*this, ec);
189 }
190 
192  std::vector<std::string> grams, std::array<uint32_t, NUM_NAMESPACES>& dest, char* descriptor, bool quiet)
193 {
194  for (size_t i = 0; i < grams.size(); i++)
195  {
196  std::string ngram = grams[i];
197  if (isdigit(ngram[0]))
198  {
199  int n = atoi(ngram.c_str());
200  if (!quiet)
201  std::cerr << "Generating " << n << "-" << descriptor << " for all namespaces." << std::endl;
202  for (size_t j = 0; j < 256; j++) dest[j] = n;
203  }
204  else if (ngram.size() == 1)
205  std::cout << "You must specify the namespace index before the n" << std::endl;
206  else
207  {
208  int n = atoi(ngram.c_str() + 1);
209  dest[(uint32_t)(unsigned char)*ngram.c_str()] = n;
210  if (!quiet)
211  std::cerr << "Generating " << n << "-" << descriptor << " for " << ngram[0] << " namespaces." << std::endl;
212  }
213  }
214 }
215 
216 void compile_limits(std::vector<std::string> limits, std::array<uint32_t, NUM_NAMESPACES>& dest, bool quiet)
217 {
218  for (size_t i = 0; i < limits.size(); i++)
219  {
220  std::string limit = limits[i];
221  if (isdigit(limit[0]))
222  {
223  int n = atoi(limit.c_str());
224  if (!quiet)
225  std::cerr << "limiting to " << n << "features for each namespace." << std::endl;
226  for (size_t j = 0; j < 256; j++) dest[j] = n;
227  }
228  else if (limit.size() == 1)
229  std::cout << "You must specify the namespace index before the n" << std::endl;
230  else
231  {
232  int n = atoi(limit.c_str() + 1);
233  dest[(uint32_t)limit[0]] = n;
234  if (!quiet)
235  std::cerr << "limiting to " << n << " for namespaces " << limit[0] << std::endl;
236  }
237  }
238 }
239 
// Default trace sink: write the message verbatim to std::cerr and flush right away.
// The context pointer is unused.
void trace_listener_cerr(void* /*context*/, const std::string& message)
{
  std::cerr << message << std::flush;
}
245 
247 {
248  int ret = std::stringbuf::sync();
249  if (ret)
250  return ret;
251 
252  parent.trace_listener(parent.trace_context, str());
253  str("");
254  return 0; // success
255 }
256 
257 vw_ostream::vw_ostream() : std::ostream(&buf), buf(*this), trace_context(nullptr)
258 {
260 }
261 
263 {
264  sd = &calloc_or_throw<shared_data>();
265  sd->dump_interval = 1.; // next update progress dump
266  sd->contraction = 1.;
267  sd->first_observed_label = FLT_MAX;
268  sd->is_more_than_two_labels_observed = false;
269  sd->max_label = 0;
270  sd->min_label = 0;
271 
273 
274  l = nullptr;
275  scorer = nullptr;
276  cost_sensitive = nullptr;
277  loss = nullptr;
278  p = nullptr;
279 
280  reg_mode = 0;
281  current_pass = 0;
282 
283  data_filename = "";
284  delete_prediction = nullptr;
285 
286  bfgs = false;
287  no_bias = false;
288  hessian_on = false;
289  active = false;
290  num_bits = 18;
291  default_bits = true;
292  daemon = false;
293  num_children = 10;
294  save_resume = false;
295  preserve_performance_counters = false;
296 
297  random_positive_weights = false;
298 
299  weights.sparse = false;
300 
301  set_minmax = set_mm;
302 
303  power_t = 0.5;
304  eta = 0.5; // default learning rate for normalized adaptive updates, this is switched to 10 by default for the other
305  // updates (see parse_args.cc)
306  numpasses = 1;
307 
308  final_prediction_sink.begin() = final_prediction_sink.end() = final_prediction_sink.end_array = nullptr;
309  raw_prediction = -1;
311  print_text = print_raw_text;
312  lda = 0;
313  random_seed = 0;
314  random_weights = false;
315  normal_weights = false;
316  tnormal_weights = false;
317  per_feature_regularizer_input = "";
318  per_feature_regularizer_output = "";
319  per_feature_regularizer_text = "";
320 
321 #ifdef _WIN32
322  stdout_fileno = _fileno(stdout);
323 #else
324  stdout_fileno = fileno(stdout);
325 #endif
326 
327  searchstr = nullptr;
328 
329  nonormalize = false;
330  l1_lambda = 0.0;
331  l2_lambda = 0.0;
332 
333  eta_decay_rate = 1.0;
334  initial_weight = 0.0;
335  initial_constant = 0.0;
336 
337  all_reduce = nullptr;
338 
339  for (size_t i = 0; i < 256; i++)
340  {
341  ngram[i] = 0;
342  skips[i] = 0;
343  limit[i] = INT_MAX;
344  affix_features[i] = 0;
345  spelling_features[i] = 0;
346  }
347 
348  invariant_updates = true;
349  normalized_idx = 2;
350 
351  add_constant = true;
352  audit = false;
353 
354  pass_length = std::numeric_limits<size_t>::max();
355  passes_complete = 0;
356 
357  save_per_pass = false;
358 
359  stdin_off = false;
360  do_reset_source = false;
361  holdout_set_off = true;
362  holdout_after = 0;
363  check_holdout_every_n_passes = 1;
364  early_terminate = false;
365 
366  max_examples = std::numeric_limits<size_t>::max();
367 
368  hash_inv = false;
369  print_invert = false;
370 
371  // Set by the '--progress <arg>' option and affect sd->dump_interval
372  progress_add = false; // default is multiplicative progress dumps
373  progress_arg = 2.0; // next update progress dump multiplier
374 
375  sd->is_more_than_two_labels_observed = false;
376  sd->first_observed_label = FLT_MAX;
377  sd->second_observed_label = FLT_MAX;
378 
379  sd->report_multiclass_log_loss = false;
380  sd->multiclass_log_loss = 0;
381  sd->holdout_multiclass_log_loss = 0;
382 }
void finish_example(vw &all, E &ec)
Definition: learner.h:302
void learn(example &)
Definition: global_data.cc:137
Definition: scorer.cc:8
void predict(E &ec, size_t i=0)
Definition: learner.h:169
void print_result(int f, float res, float, v_array< char > tag)
Definition: global_data.cc:91
Definition: active.h:6
void get_prediction(int sock, float &res, float &weight)
Definition: global_data.cc:52
static ssize_t write_file_or_socket(int f, const void *buf, size_t nbytes)
Definition: io_buf.cc:140
void trace_listener_cerr(void *, const std::string &message)
Definition: global_data.cc:240
void compile_gram(std::vector< std::string > grams, std::array< uint32_t, NUM_NAMESPACES > &dest, char *descriptor, bool quiet)
Definition: global_data.cc:191
void binary_print_result(int f, float res, float weight, v_array< char >)
Definition: global_data.cc:72
int print_tag(std::stringstream &ss, v_array< char > tag)
Definition: global_data.cc:81
Definition: bfgs.cc:62
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
void finish_example(example &)
Definition: global_data.cc:175
T *& begin()
Definition: v_array.h:42
size_t size() const
Definition: v_array.h:68
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void noop_mm(shared_data *, float)
Definition: global_data.cc:135
void compile_limits(std::vector< std::string > limits, std::array< uint32_t, NUM_NAMESPACES > &dest, bool quiet)
Definition: global_data.cc:216
#define THROWERRNO(args)
Definition: vw_exception.h:167
T *& end()
Definition: v_array.h:43
void predict(example &)
Definition: global_data.cc:159
float weight
std::vector< example * > multi_ex
Definition: example.h:122
float min_label
Definition: global_data.h:150
void send_prediction(int sock, global_prediction p)
Definition: global_data.cc:60
void print_raw_text(int f, std::string s, v_array< char > tag)
Definition: global_data.cc:111
float max_label
Definition: global_data.h:151
void all_reduce(vw &all, T *buffer, const size_t n)
Definition: vw_allreduce.h:13
Definition: print.cc:9
trace_message_t trace_listener
size_t really_read(int sock, void *in, size_t count)
Definition: global_data.cc:24
void learn(E &ec, size_t i=0)
Definition: learner.h:160
vw()
Definition: global_data.cc:262
#define THROW(args)
Definition: vw_exception.h:181
void set_mm(shared_data *sd, float label)
Definition: global_data.cc:128
float f
Definition: cache.cc:40
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
Definition: lda_core.cc:60
bool test_only
Definition: example.h:76