Vowpal Wabbit
model_parser.h
Go to the documentation of this file.
#pragma once

#include <cctype>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>

#include "vw_slim_return_codes.h"
#include "hash.h"

// #define MODEL_PARSER_DEBUG

#ifdef MODEL_PARSER_DEBUG
#include <fstream>
#include <iomanip>
#include <iostream>
#endif

18 namespace vw_slim
19 {
21 {
22  const char* _model_begin;
23  const char* _model;
24  const char* _model_end;
25  uint32_t _checksum;
26 
27  public:
28  model_parser(const char* model, size_t length);
29 
30  int read(const char* field_name, size_t field_length, const char** ret);
31 
32  int skip(size_t bytes);
33 
34  const char* position();
35 
36  uint32_t checksum();
37 
38  template <bool compute_checksum>
39  int read_string(const char* field_name, std::string& s)
40  {
41  uint32_t str_len;
42  RETURN_ON_FAIL((read<uint32_t, compute_checksum>("string.len", str_len)));
43 #ifdef MODEL_PARSER_DEBUG
44  {
45  std::fstream log("vwslim-debug.log", std::fstream::app);
46  log << std::setw(18) << field_name << " length " << str_len << std::endl;
47  }
48 #endif
49 
50  // 0 length strings are not valid, need to contain at least \0
51  if (str_len == 0)
53 
54  const char* data;
55  RETURN_ON_FAIL(read(field_name, str_len, &data));
56 
57  s = std::string(data, str_len - 1);
58 #ifdef MODEL_PARSER_DEBUG
59  {
60  std::fstream log("vwslim-debug.log", std::fstream::app);
61  log << std::setw(18) << field_name << " '" << s << '\'' << std::endl;
62  }
63 #endif
64 
65  // calculate checksum
66  if (compute_checksum && str_len > 0)
67  _checksum = (uint32_t)uniform_hash(data, str_len, _checksum);
68 
69  return S_VW_PREDICT_OK;
70  }
71 
72  template <typename T, bool compute_checksum>
73  int read(const char* field_name, T& val)
74  {
75 #ifdef MODEL_PARSER_DEBUG
76  std::fstream log("vwslim-debug.log", std::fstream::app);
77  log << std::setw(18) << field_name << " 0x" << std::hex << std::setw(8) << (uint64_t)_model << "-" << std::hex
78  << std::setw(8) << (uint64_t)_model_end << " " << std::setw(4) << (_model - _model_begin)
79  << " field: " << std::setw(8) << (uint64_t)&val << std::dec;
80 #endif
81 
82  const char* data;
83  RETURN_ON_FAIL(read(field_name, sizeof(T), &data));
84 
85  // avoid alignment issues for 32/64bit types on e.g. Android/ARM
86  memcpy(&val, data, sizeof(T));
87 
88  if (compute_checksum)
89  _checksum = (uint32_t)uniform_hash(&val, sizeof(T), _checksum);
90 
91 #ifdef MODEL_PARSER_DEBUG
92  log << " '" << val << '\'' << std::endl;
93 #endif
94 
95  return S_VW_PREDICT_OK;
96  }
97 
98  // default overload without checksum hashing
99  template <typename T>
100  int read(const char* field_name, T& val)
101  {
102  return read<T, true>(field_name, val);
103  }
104 
105  template <typename T, typename W>
106  int read_weights(std::unique_ptr<W>& weights, uint64_t weight_length)
107  {
108  // weights are excluded from checksum calculation
109  while (_model < _model_end)
110  {
111  T idx;
112  RETURN_ON_FAIL((read<T, false>("gd.weight.index", idx)));
113  if (idx > weight_length)
115 
116  float& w = (*weights)[idx];
117  RETURN_ON_FAIL((read<float, false>("gd.weight.value", w)));
118 
119 #ifdef MODEL_PARSER_DEBUG
120  std::cout << "weight. idx: " << idx << ":" << (*weights)[idx] << std::endl;
121 #endif
122  }
123 
124  return S_VW_PREDICT_OK;
125  }
126 
127  template <typename W>
128  int read_weights(std::unique_ptr<W>& weights, uint32_t num_bits, uint32_t stride_shift)
129  {
130  uint64_t weight_length = (uint64_t)1 << num_bits;
131 
132  weights = std::unique_ptr<W>(new W(weight_length));
133  weights->stride_shift(stride_shift);
134 
135  if (num_bits < 31)
136  {
137  RETURN_ON_FAIL((read_weights<uint32_t, W>(weights, weight_length)));
138  }
139  else
140  {
141  RETURN_ON_FAIL((read_weights<uint64_t, W>(weights, weight_length)));
142  }
143 
144  return S_VW_PREDICT_OK;
145  }
146 };
147 } // namespace vw_slim
#define RETURN_ON_FAIL(stmt)
int read_weights(std::unique_ptr< W > &weights, uint64_t weight_length)
Definition: model_parser.h:106
#define S_VW_PREDICT_OK
uint64_t stride_shift(const stagewise_poly &poly, uint64_t idx)
const char * _model
Definition: model_parser.h:23
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
Definition: hash.h:67
const char * _model_begin
Definition: model_parser.h:22
int read(const char *field_name, size_t field_length, const char **ret)
Definition: model_parser.cc:19
#define E_VW_PREDICT_ERR_INVALID_MODEL
int skip(size_t bytes)
Definition: model_parser.cc:37
int read(const char *field_name, T &val)
Definition: model_parser.h:73
model_parser(const char *model, size_t length)
Definition: model_parser.cc:5
int read(const char *field_name, T &val)
Definition: model_parser.h:100
int read_weights(std::unique_ptr< W > &weights, uint32_t num_bits, uint32_t stride_shift)
Definition: model_parser.h:128
const char * position()
Definition: model_parser.cc:15
int read_string(const char *field_name, std::string &s)
Definition: model_parser.h:39
const char * _model_end
Definition: model_parser.h:24
#define E_VW_PREDICT_ERR_WEIGHT_INDEX_OUT_OF_RANGE