Vowpal Wabbit
parse_regressor.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 #include <fstream>
7 #include <iostream>
8 
9 #include "crossplat_compat.h"
10 
11 #ifndef _WIN32
12 #include <unistd.h>
13 #endif
14 
15 #include <stdlib.h>
16 #include <stdint.h>
17 #include <math.h>
18 #include <cmath>
19 #include <algorithm>
20 #include <stdarg.h>
21 #include <numeric>
22 #include "rand48.h"
23 #include "global_data.h"
24 #include "vw_exception.h"
25 #include "vw_validate.h"
26 #include "vw_versions.h"
27 
29 
30 template <class T>
32 {
33  public:
34  static void func(weight& w, float& initial, uint64_t /* index */) { w = initial; }
35 };
36 
37 template <class T>
39 {
40  public:
41  static void func(weight& w, uint64_t index) { w = (float)(0.1 * merand48(index)); }
42 };
43 
44 template <class T>
46 {
47  public:
48  static void func(weight& w, uint64_t index) { w = (float)(merand48(index) - 0.5); }
49 };
50 // box-muller polar implementation
51 template <class T>
53 {
54  public:
55  static void func(weight& w, uint64_t index)
56  {
57  static float x1 = 0.0;
58  static float x2 = 0.0;
59  static float temp = 0.0;
60  do
61  {
62  x1 = 2.0f * merand48(index) - 1.0f;
63  x2 = 2.0f * merand48(index) - 1.0f;
64  temp = x1 * x1 + x2 * x2;
65  } while ((temp >= 1.0) || (temp == 0.0));
66  temp = sqrtf((-2.0f * logf(temp)) / temp);
67  w = x1 * temp;
68  }
69 };
70 // re-scaling to re-picking values outside the truncating boundary.
71 // note:- boundary is twice the standard deviation.
72 template <class T>
73 void truncate(vw& all, T& weights)
74 {
75  static double sd = calculate_sd(all, weights);
76  std::for_each(weights.begin(), weights.end(), [](float& v) {
77  if (std::fabs(v) > sd * 2)
78  {
79  v = (float)std::remainder(static_cast<double>(v), sd * 2);
80  }
81  });
82 }
83 
84 template <class T>
85 double calculate_sd(vw& /* all */, T& weights)
86 {
87  static int my_size = 0;
88  std::for_each(weights.begin(), weights.end(), [](float /* v */) { my_size += 1; });
89  double sum = std::accumulate(weights.begin(), weights.end(), 0.0);
90  double mean = sum / my_size;
91  std::vector<double> diff(my_size);
92  std::transform(weights.begin(), weights.end(), diff.begin(), [mean](double x) { return x - mean; });
93  double sq_sum = inner_product(diff.begin(), diff.end(), diff.begin(), 0.0);
94  return std::sqrt(sq_sum / my_size);
95 }
96 template <class T>
97 void initialize_regressor(vw& all, T& weights)
98 {
99  // Regressor is already initialized.
100 
101  if (weights.not_null())
102  return;
103  size_t length = ((size_t)1) << all.num_bits;
104  try
105  {
106  uint32_t ss = weights.stride_shift();
107  weights.~T(); // dealloc so that we can realloc, now with a known size
108  new (&weights) T(length, ss);
109  }
110  catch (const VW::vw_exception&)
111  {
112  THROW(" Failed to allocate weight array with " << all.num_bits << " bits: try decreasing -b <bits>");
113  }
114  if (weights.mask() == 0)
115  {
116  THROW(" Failed to allocate weight array with " << all.num_bits << " bits: try decreasing -b <bits>");
117  }
118  else if (all.initial_weight != 0.)
119  weights.template set_default<float, set_initial_wrapper<T> >(all.initial_weight);
120  else if (all.random_positive_weights)
121  weights.template set_default<random_positive_wrapper<T> >();
122  else if (all.random_weights)
123  weights.template set_default<random_weights_wrapper<T> >();
124  else if (all.normal_weights)
125  {
126  weights.template set_default<polar_normal_weights_wrapper<T> >();
127  }
128  else if (all.tnormal_weights)
129  {
130  weights.template set_default<polar_normal_weights_wrapper<T> >();
131  truncate(all, weights);
132  }
133 }
134 
136 {
137  if (all.weights.sparse)
139  else
141 }
142 
// Initial capacity in bytes of the scratch buffer that save_load_header mallocs.
constexpr size_t default_buf_size{512};
144 
145 bool resize_buf_if_needed(char*& __dest, size_t& __dest_size, const size_t __n)
146 {
147  char* new_dest;
148  if (__dest_size < __n)
149  {
150  if ((new_dest = (char*)realloc(__dest, __n)) == NULL)
151  THROW("Can't realloc enough memory.")
152  else
153  {
154  __dest = new_dest;
155  __dest_size = __n;
156  return true;
157  }
158  }
159  return false;
160 }
161 
162 inline void safe_memcpy(char*& __dest, size_t& __dest_size, const void* __src, size_t __n)
163 {
164  resize_buf_if_needed(__dest, __dest_size, __n);
165  memcpy(__dest, __src, __n);
166 }
167 
168 // file_options will be written to when reading
170  vw& all, io_buf& model_file, bool read, bool text, std::string& file_options, VW::config::options_i& options)
171 {
172  char* buff2 = (char*)malloc(default_buf_size);
173  size_t buf2_size = default_buf_size;
174 
175  try
176  {
177  if (model_file.files.size() > 0)
178  {
179  size_t bytes_read_write = 0;
180 
181  size_t v_length = (uint32_t)VW::version.to_string().length() + 1;
182  std::stringstream msg;
183  msg << "Version " << VW::version.to_string() << "\n";
184  memcpy(buff2, VW::version.to_string().c_str(), std::min(v_length, buf2_size));
185  if (read)
186  {
187  v_length = buf2_size;
188  buff2[std::min(v_length, default_buf_size) - 1] = '\0';
189  }
190  bytes_read_write += bin_text_read_write(model_file, buff2, v_length, "", read, msg, text);
191  all.model_file_ver = buff2; // stored in all to check save_resume fix in gd
193 
195  model_file.verify_hash(true);
196 
198  {
199  v_length = all.id.length() + 1;
200 
201  msg << "Id " << all.id << "\n";
202  memcpy(buff2, all.id.c_str(), std::min(v_length, default_buf_size));
203  if (read)
204  v_length = default_buf_size;
205  bytes_read_write += bin_text_read_write(model_file, buff2, v_length, "", read, msg, text);
206  all.id = buff2;
207 
208  if (read && !options.was_supplied("id") && !all.id.empty())
209  {
210  file_options += " --id";
211  file_options += " " + all.id;
212  }
213  }
214 
215  char model = 'm';
216 
217  bytes_read_write +=
218  bin_text_read_write_fixed_validated(model_file, &model, 1, "file is not a model file", read, msg, text);
219 
220  msg << "Min label:" << all.sd->min_label << "\n";
221  bytes_read_write += bin_text_read_write_fixed_validated(
222  model_file, (char*)&all.sd->min_label, sizeof(all.sd->min_label), "", read, msg, text);
223 
224  msg << "Max label:" << all.sd->max_label << "\n";
225  bytes_read_write += bin_text_read_write_fixed_validated(
226  model_file, (char*)&all.sd->max_label, sizeof(all.sd->max_label), "", read, msg, text);
227 
228  msg << "bits:" << all.num_bits << "\n";
229  uint32_t local_num_bits = all.num_bits;
230  bytes_read_write += bin_text_read_write_fixed_validated(
231  model_file, (char*)&local_num_bits, sizeof(local_num_bits), "", read, msg, text);
232 
233  if (read && !options.was_supplied("bit_precision"))
234  {
235  file_options += " --bit_precision";
236  std::stringstream temp;
237  temp << local_num_bits;
238  file_options += " " + temp.str();
239  }
240 
241  VW::validate_default_bits(all, local_num_bits);
242 
243  all.default_bits = false;
244  all.num_bits = local_num_bits;
245 
247 
249  {
250  // -q, --cubic and --interactions are not saved in vw::file_options
251  uint32_t pair_len = (uint32_t)all.pairs.size();
252 
253  msg << pair_len << " pairs: ";
254  bytes_read_write +=
255  bin_text_read_write_fixed_validated(model_file, (char*)&pair_len, sizeof(pair_len), "", read, msg, text);
256 
257  // TODO: validate pairs?
258  for (size_t i = 0; i < pair_len; i++)
259  {
260  char pair[3] = {0, 0, 0};
261 
262  if (!read)
263  {
264  memcpy(pair, all.pairs[i].c_str(), 2);
265  msg << all.pairs[i] << " ";
266  }
267 
268  bytes_read_write += bin_text_read_write_fixed_validated(model_file, pair, 2, "", read, msg, text);
269  if (read)
270  {
271  std::string temp(pair);
272  if (count(all.pairs.begin(), all.pairs.end(), temp) == 0)
273  all.pairs.push_back(temp);
274  }
275  }
276 
277  msg << "\n";
278  bytes_read_write += bin_text_read_write_fixed_validated(model_file, nullptr, 0, "", read, msg, text);
279 
280  uint32_t triple_len = (uint32_t)all.triples.size();
281 
282  msg << triple_len << " triples: ";
283  bytes_read_write += bin_text_read_write_fixed_validated(
284  model_file, (char*)&triple_len, sizeof(triple_len), "", read, msg, text);
285 
286  // TODO: validate triples?
287  for (size_t i = 0; i < triple_len; i++)
288  {
289  char triple[4] = {0, 0, 0, 0};
290 
291  if (!read)
292  {
293  msg << all.triples[i] << " ";
294  memcpy(triple, all.triples[i].c_str(), 3);
295  }
296  bytes_read_write += bin_text_read_write_fixed_validated(model_file, triple, 3, "", read, msg, text);
297  if (read)
298  {
299  std::string temp(triple);
300  if (count(all.triples.begin(), all.triples.end(), temp) == 0)
301  all.triples.push_back(temp);
302  }
303  }
304 
305  msg << "\n";
306  bytes_read_write += bin_text_read_write_fixed_validated(model_file, nullptr, 0, "", read, msg, text);
307 
308  if (all.model_file_ver >=
309  VERSION_FILE_WITH_INTERACTIONS) // && < VERSION_FILE_WITH_INTERACTIONS_IN_FO (previous if)
310  {
311  // the only version that saves interacions among pairs and triples
312  uint32_t len = (uint32_t)all.interactions.size();
313 
314  msg << len << " interactions: ";
315  bytes_read_write +=
316  bin_text_read_write_fixed_validated(model_file, (char*)&len, sizeof(len), "", read, msg, text);
317 
318  for (size_t i = 0; i < len; i++)
319  {
320  uint32_t inter_len = 0;
321  if (!read)
322  {
323  inter_len = (uint32_t)all.interactions[i].size();
324  msg << "len: " << inter_len << " ";
325  }
326  bytes_read_write += bin_text_read_write_fixed_validated(
327  model_file, (char*)&inter_len, sizeof(inter_len), "", read, msg, text);
328  if (!read)
329  {
330  memcpy(buff2, all.interactions[i].c_str(), inter_len);
331 
332  msg << "interaction: ";
333  msg.write(all.interactions[i].c_str(), inter_len);
334  }
335 
336  bytes_read_write += bin_text_read_write_fixed_validated(model_file, buff2, inter_len, "", read, msg, text);
337 
338  if (read)
339  {
340  std::string temp(buff2, inter_len);
341  all.interactions.push_back(temp);
342  }
343  }
344 
345  msg << "\n";
346  bytes_read_write += bin_text_read_write_fixed_validated(model_file, nullptr, 0, "", read, msg, text);
347  }
348  else // < VERSION_FILE_WITH_INTERACTIONS
349  {
350  // pairs and triples may be restored but not reflected in interactions
351  all.interactions.insert(std::end(all.interactions), std::begin(all.pairs), std::end(all.pairs));
352  all.interactions.insert(std::end(all.interactions), std::begin(all.triples), std::end(all.triples));
353  }
354  }
355 
357  {
358  // to fix compatibility that was broken in 7.9
359  uint32_t rank = 0;
360  msg << "rank:" << rank << "\n";
361  bytes_read_write +=
362  bin_text_read_write_fixed_validated(model_file, (char*)&rank, sizeof(rank), "", read, msg, text);
363  if (rank != 0)
364  {
365  if (!options.was_supplied("rank"))
366  {
367  file_options += " --rank";
368  std::stringstream temp;
369  temp << rank;
370  file_options += " " + temp.str();
371  }
372  else
373  all.trace_message << "WARNING: this model file contains 'rank: " << rank
374  << "' value but it will be ignored as another value specified via the command line."
375  << std::endl;
376  }
377  }
378 
379  msg << "lda:" << all.lda << "\n";
380  bytes_read_write +=
381  bin_text_read_write_fixed_validated(model_file, (char*)&all.lda, sizeof(all.lda), "", read, msg, text);
382 
383  // TODO: validate ngram_len?
384  uint32_t ngram_len = (uint32_t)all.ngram_strings.size();
385  msg << ngram_len << " ngram:";
386  bytes_read_write +=
387  bin_text_read_write_fixed_validated(model_file, (char*)&ngram_len, sizeof(ngram_len), "", read, msg, text);
388  for (size_t i = 0; i < ngram_len; i++)
389  {
390  // have '\0' at the end for sure
391  char ngram[4] = {0, 0, 0, 0};
392  if (!read)
393  {
394  msg << all.ngram_strings[i] << " ";
395  memcpy(ngram, all.ngram_strings[i].c_str(), std::min(static_cast<size_t>(3), all.ngram_strings[i].size()));
396  }
397  bytes_read_write += bin_text_read_write_fixed_validated(model_file, ngram, 3, "", read, msg, text);
398  if (read)
399  {
400  std::string temp(ngram);
401  all.ngram_strings.push_back(temp);
402 
403  file_options += " --ngram";
404  file_options += " " + temp;
405  }
406  }
407 
408  msg << "\n";
409  bytes_read_write += bin_text_read_write_fixed_validated(model_file, nullptr, 0, "", read, msg, text);
410 
411  // TODO: validate skips?
412  uint32_t skip_len = (uint32_t)all.skip_strings.size();
413  msg << skip_len << " skip:";
414  bytes_read_write +=
415  bin_text_read_write_fixed_validated(model_file, (char*)&skip_len, sizeof(skip_len), "", read, msg, text);
416 
417  for (size_t i = 0; i < skip_len; i++)
418  {
419  char skip[4] = {0, 0, 0, 0};
420  if (!read)
421  {
422  msg << all.skip_strings[i] << " ";
423  memcpy(skip, all.skip_strings[i].c_str(), std::min(static_cast<size_t>(3), all.skip_strings[i].size()));
424  }
425 
426  bytes_read_write += bin_text_read_write_fixed_validated(model_file, skip, 3, "", read, msg, text);
427  if (read)
428  {
429  std::string temp(skip);
430  all.skip_strings.push_back(temp);
431 
432  file_options += " --skips";
433  file_options += " " + temp;
434  }
435  }
436  msg << "\n";
437  bytes_read_write += bin_text_read_write_fixed_validated(model_file, nullptr, 0, "", read, msg, text);
438 
439  if (read)
440  {
441  uint32_t len;
442  size_t ret = model_file.bin_read_fixed((char*)&len, sizeof(len), "");
443  if (len > 104857600 /*sanity check: 100 Mb*/ || ret < sizeof(uint32_t))
444  THROW("bad model format!");
445  resize_buf_if_needed(buff2, buf2_size, len);
446  bytes_read_write += model_file.bin_read_fixed(buff2, len, "") + ret;
447 
448  // Write out file options to caller.
449  if (len > 0)
450  {
451  // There is a potential bug here if len is read out to be zero (e.g. corrupted file). If we naively
452  // append buff2 into file_options it might contain old information and thus be invalid. Before, what
453  // probably happened is boost::program_options did the right thing, but now we have to construct the
454  // input to it where we do not know whether a particular option key can have multiple values or not.
455  //
456  // In some cases we end up with a std::string like: "--bit_precision 18 <something_not_an_int>", which will
457  // cause a "bad program options value" exception, rather than the true "file is corrupted" issue. Only
458  // pushing the contents of buff2 into file_options when it is valid will prevent this false error.
459  file_options = file_options + " " + buff2;
460  }
461  }
462  else
463  {
465  for (auto const& option : options.get_all_options())
466  {
467  if (option->m_keep && options.was_supplied(option->m_name))
468  {
469  serializer.add(*option);
470  }
471  }
472 
473  auto serialized_keep_options = serializer.str();
474 
475  // We need to save our current PRG state
476  if (all.save_resume && all.get_random_state()->get_current_state() != 0)
477  {
478  serialized_keep_options += " --random_seed";
479  serialized_keep_options += " " + std::to_string(all.get_random_state()->get_current_state());
480  }
481 
482  msg << "options:" << serialized_keep_options << "\n";
483 
484  uint32_t len = (uint32_t)serialized_keep_options.length();
485  if (len > 0)
486  safe_memcpy(buff2, buf2_size, serialized_keep_options.c_str(), len + 1);
487  *(buff2 + len) = 0;
488  bytes_read_write += bin_text_read_write(model_file, buff2, len + 1, // len+1 to write a \0
489  "", read, msg, text);
490  }
491 
492  // Read/write checksum if required by version
494  {
495  uint32_t check_sum = (all.model_file_ver >= VERSION_FILE_WITH_HEADER_CHAINED_HASH)
496  ? model_file.hash()
497  : (uint32_t)uniform_hash(model_file.space.begin(), bytes_read_write, 0);
498 
499  uint32_t check_sum_saved = check_sum;
500 
501  msg << "Checksum: " << check_sum << "\n";
502  bin_text_read_write(model_file, (char*)&check_sum, sizeof(check_sum), "", read, msg, text);
503 
504  if (check_sum_saved != check_sum)
505  THROW("Checksum is inconsistent, file is possibly corrupted.");
506  }
507 
509  {
510  model_file.verify_hash(false);
511  }
512  }
513  }
514  catch (...)
515  {
516  free(buff2);
517  throw;
518  }
519 
520  free(buff2);
521 }
522 
523 void dump_regressor(vw& all, io_buf& buf, bool as_text)
524 {
525  std::string unused;
526  save_load_header(all, buf, false, as_text, unused, *all.options);
527  if (all.l != nullptr)
528  all.l->save_load(buf, false, as_text);
529 
530  buf.flush(); // close_file() should do this for me ...
531  buf.close_file();
532 }
533 
534 void dump_regressor(vw& all, std::string reg_name, bool as_text)
535 {
536  if (reg_name == std::string(""))
537  return;
538  std::string start_name = reg_name + std::string(".writing");
539  io_buf io_temp;
540 
541  io_temp.open_file(start_name.c_str(), all.stdin_off, io_buf::WRITE);
542 
543  dump_regressor(all, io_temp, as_text);
544 
545  remove(reg_name.c_str());
546 
547  if (0 != rename(start_name.c_str(), reg_name.c_str()))
548  THROW("WARN: dump_regressor(vw& all, std::string reg_name, bool as_text): cannot rename: "
549  << start_name.c_str() << " to " << reg_name.c_str());
550 }
551 
552 void save_predictor(vw& all, std::string reg_name, size_t current_pass)
553 {
554  std::stringstream filename;
555  filename << reg_name;
556  if (all.save_per_pass)
557  filename << "." << current_pass;
558  dump_regressor(all, filename.str(), false);
559 }
560 
561 void finalize_regressor(vw& all, std::string reg_name)
562 {
563  if (!all.early_terminate)
564  {
565  if (all.per_feature_regularizer_output.length() > 0)
567  else
568  dump_regressor(all, reg_name, false);
569  if (all.per_feature_regularizer_text.length() > 0)
571  else
572  {
573  dump_regressor(all, all.text_regressor_name, true);
574  all.print_invert = true;
575  dump_regressor(all, all.inv_hash_regressor_name, true);
576  all.print_invert = false;
577  }
578  }
579 }
580 
581 void read_regressor_file(vw& all, std::vector<std::string> all_intial, io_buf& io_temp)
582 {
583  if (all_intial.size() > 0)
584  {
585  io_temp.open_file(all_intial[0].c_str(), all.stdin_off, io_buf::READ);
586  if (!all.quiet)
587  {
588  // all.trace_message << "initial_regressor = " << regs[0] << std::endl;
589  if (all_intial.size() > 1)
590  {
591  all.trace_message << "warning: ignoring remaining " << (all_intial.size() - 1) << " initial regressors"
592  << std::endl;
593  }
594  }
595  }
596 }
597 
598 void parse_mask_regressor_args(vw& all, std::string feature_mask, std::vector<std::string> initial_regressors)
599 {
600  // TODO does this extra check need to be used? I think it is duplicated but there may be some logic I am missing.
601  std::string file_options;
602  if (!feature_mask.empty())
603  {
604  if (initial_regressors.size() > 0)
605  {
606  if (feature_mask == initial_regressors[0]) //-i and -mask are from same file, just generate mask
607  {
608  return;
609  }
610  }
611 
612  // all other cases, including from different file, or -i does not exist, need to read in the mask file
613  io_buf io_temp_mask;
614  io_temp_mask.open_file(feature_mask.c_str(), false, io_buf::READ);
615  save_load_header(all, io_temp_mask, true, false, file_options, *all.options);
616  all.l->save_load(io_temp_mask, true, false);
617  io_temp_mask.close_file();
618 
619  // Deal with the over-written header from initial regressor
620  if (initial_regressors.size() > 0)
621  {
622  // Load original header again.
623  io_buf io_temp;
624  io_temp.open_file(initial_regressors[0].c_str(), false, io_buf::READ);
625  save_load_header(all, io_temp, true, false, file_options, *all.options);
626  io_temp.close_file();
627 
628  // Re-zero the weights, in case weights of initial regressor use different indices
629  all.weights.set_zero(0);
630  }
631  else
632  {
633  // If no initial regressor, just clear out the options loaded from the header.
634  // TODO clear file options
635  // all.opts_n_args.file_options.str("");
636  }
637  }
638 }
639 
640 namespace VW
641 {
642 void save_predictor(vw& all, std::string reg_name) { dump_regressor(all, reg_name, false); }
643 
644 void save_predictor(vw& all, io_buf& buf) { dump_regressor(all, buf, false); }
645 } // namespace VW
std::vector< std::string > skip_strings
Definition: global_data.h:470
void save_load(io_buf &io, const bool read, const bool text)
Definition: learner.h:251
#define VERSION_FILE_WITH_HEADER_ID
Definition: vw_versions.h:19
parameters weights
Definition: global_data.h:537
#define VERSION_FILE_WITH_INTERACTIONS_IN_FO
Definition: vw_versions.h:14
void accumulate(vw &all, parameters &weights, size_t offset)
Definition: accumulate.cc:20
bool tnormal_weights
Definition: global_data.h:495
static void func(weight &w, float &initial, uint64_t)
void initialize_regressor(vw &all, T &weights)
std::vector< std::string > pairs
Definition: global_data.h:459
void truncate(vw &all, T &weights)
VW::config::options_i * options
Definition: global_data.h:428
static constexpr int WRITE
Definition: io_buf.h:72
virtual bool close_file()
Definition: io_buf.h:204
std::vector< std::string > ngram_strings
Definition: global_data.h:469
bool random_positive_weights
Definition: global_data.h:493
#define VERSION_FILE_WITH_RANK_IN_HEADER
Definition: vw_versions.h:12
float initial_weight
Definition: global_data.h:409
std::string inv_hash_regressor_name
Definition: global_data.h:511
#define VERSION_FILE_WITH_INTERACTIONS
Definition: vw_versions.h:13
size_t bin_text_read_write_fixed_validated(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
Definition: io_buf.h:335
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
Definition: hash.h:67
size_t bin_text_read_write(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
Definition: io_buf.h:304
bool quiet
Definition: global_data.h:487
void finalize_regressor(vw &all, std::string reg_name)
float merand48(uint64_t &initial)
Definition: rand48.cc:16
#define VERSION_FILE_WITH_HEADER_CHAINED_HASH
Definition: vw_versions.h:17
void validate_version(vw &all)
Definition: vw_validate.cc:12
uint32_t num_bits
Definition: global_data.h:398
T *& begin()
Definition: v_array.h:42
const version_struct version(PACKAGE_VERSION)
size_t size() const
Definition: v_array.h:68
static constexpr int READ
Definition: io_buf.h:71
uint32_t hash()
Definition: io_buf.h:83
void save_predictor(vw &all, std::string reg_name, size_t current_pass)
virtual std::vector< std::shared_ptr< base_option > > get_all_options()=0
uint32_t lda
Definition: global_data.h:508
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
virtual int open_file(const char *name, bool stdin_off)
Definition: io_buf.h:90
size_t bin_read_fixed(char *data, size_t len, const char *read_message)
Definition: io_buf.h:230
std::string id
Definition: global_data.h:417
virtual void flush()
Definition: io_buf.h:194
shared_data * sd
Definition: global_data.h:375
void safe_memcpy(char *&__dest, size_t &__dest_size, const void *__src, size_t __n)
VW::version_struct model_file_ver
Definition: global_data.h:419
v_array< int > files
Definition: io_buf.h:64
static void func(weight &w, uint64_t index)
vw_ostream trace_message
Definition: global_data.h:424
virtual bool was_supplied(const std::string &key)=0
std::string per_feature_regularizer_output
Definition: global_data.h:441
void save_load_header(vw &all, io_buf &model_file, bool read, bool text, std::string &file_options, VW::config::options_i &options)
#define VERSION_FILE_WITH_HEADER_HASH
Definition: vw_versions.h:16
bool random_weights
Definition: global_data.h:492
std::string to_string() const
Definition: version.cc:97
dense_parameters dense_weights
void verify_hash(bool verify)
Definition: io_buf.h:74
virtual void add(base_option &option) override
void parse_mask_regressor_args(vw &all, std::string feature_mask, std::vector< std::string > initial_regressors)
bool print_invert
Definition: global_data.h:542
bool default_bits
Definition: global_data.h:399
void set_zero(size_t offset)
Definition: io_buf.h:54
bool resize_buf_if_needed(char *&__dest, size_t &__dest_size, const size_t __n)
v_array< char > space
Definition: io_buf.h:62
double calculate_sd(vw &, T &weights)
std::string per_feature_regularizer_text
Definition: global_data.h:442
std::vector< std::string > triples
Definition: global_data.h:461
float weight
bool save_per_pass
Definition: global_data.h:408
float min_label
Definition: global_data.h:150
sparse_parameters sparse_weights
Definition: autolink.cc:11
void read_regressor_file(vw &all, std::vector< std::string > all_intial, io_buf &io_temp)
std::vector< std::string > interactions
Definition: global_data.h:457
LEARNER::base_learner * l
Definition: global_data.h:383
bool save_resume
Definition: global_data.h:415
float max_label
Definition: global_data.h:151
constexpr size_t default_buf_size
bool stdin_off
Definition: global_data.h:527
void dump_regressor(vw &all, io_buf &buf, bool as_text)
bool early_terminate
Definition: global_data.h:500
void validate_num_bits(vw &all)
Definition: vw_validate.cc:32
static void func(weight &w, uint64_t index)
#define THROW(args)
Definition: vw_exception.h:181
bool normal_weights
Definition: global_data.h:494
float f
Definition: cache.cc:40
const char * to_string(prediction_type_t prediction_type)
Definition: learner.cc:12
std::string text_regressor_name
Definition: global_data.h:510
void validate_default_bits(vw &all, uint32_t local_num_bits)
Definition: vw_validate.cc:26
static void func(weight &w, uint64_t index)