Vowpal Wabbit
marginal.cc
#include <unordered_map>
#include "reductions.h"
#include "correctedMath.h"

using namespace VW::config;

namespace MARGINAL
{
struct expert
{
  float regret;
  float abs_regret;
  float weight;
};

typedef std::pair<double, double> marginal;
typedef std::pair<expert, expert> expert_pair;

struct data
{
  float initial_numerator;
  float initial_denominator;
  float decay;
  bool update_before_learn;
  bool unweighted_marginals;
  bool id_features[256];
  features temp[256];  // temporary storage when reducing.
  std::unordered_map<uint64_t, marginal> marginals;

  // bookkeeping variables for experts
  bool compete;
  float feature_pred;        // the prediction computed from using all the features
  float average_pred;        // the prediction of the expert
  float net_weight;          // normalizer for expert weights
  float net_feature_weight;  // the net weight on the feature-based expert
  float alg_loss;            // temporary storage for the loss of the current marginal-based predictor
  std::unordered_map<uint64_t, expert_pair>
      expert_state;  // pair of weights on marginal and feature based predictors, one per marginal feature

  vw* all;

  ~data()
  {
    for (size_t i = 0; i < 256; i++) temp[i].delete_v();
  }
};
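// Each tracked id maps to a `marginal` pair: a decayed running numerator (weighted sum of labels)
// and denominator (sum of weights), so the marginal estimate is numerator / denominator. With
// --compete, each id additionally carries an `expert_pair` holding regret statistics and an
// AdaNormalHedge-style weight for the marginal-based and the feature-based predictor.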

float get_adanormalhedge_weights(float R, float C)
{
  float Rpos = R > 0 ? R : 0.f;
  if (C == 0. || Rpos == 0.)
    return 0;
  return 2 * Rpos * correctedExp(Rpos * Rpos / (3 * C)) / (3 * C);
}
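// AdaNormalHedge-style weighting: w(R, C) = (2 [R]_+ / (3C)) * exp([R]_+^2 / (3C)), i.e. the
// derivative in R of the potential exp([R]_+^2 / (3C)). Here R is an expert's cumulative regret
// and C a cumulative regret-magnitude term (in this file, the weighted sum of squared regrets;
// see update_marginal). For example, R = 1 and C = 4 give w = (2 / 12) * exp(1 / 12), roughly
// 0.18, while any expert with non-positive regret gets weight 0.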

template <bool is_learn>
void make_marginal(data& sm, example& ec)
{
  uint64_t mask = sm.all->weights.mask();
  float label = ec.l.simple.label;
  vw& all = *sm.all;
  sm.alg_loss = 0.;
  sm.net_weight = 0.;
  sm.net_feature_weight = 0.;
  sm.average_pred = 0.;

  for (example::iterator i = ec.begin(); i != ec.end(); ++i)
  {
    namespace_index n = i.index();
    if (sm.id_features[n])
    {
      std::swap(sm.temp[n], *i);
      features& f = *i;
      f.clear();
      for (features::iterator j = sm.temp[n].begin(); j != sm.temp[n].end(); ++j)
      {
        float first_value = j.value();
        uint64_t first_index = j.index() & mask;
        if (++j == sm.temp[n].end())
        {
          std::cout << "warning: id feature namespace has " << sm.temp[n].size()
                    << " features. Should be a multiple of 2" << std::endl;
          break;
        }
        float second_value = j.value();
        uint64_t second_index = j.index() & mask;
        if (first_value != 1. || second_value != 1.)
        {
          std::cout << "warning: bad id features, must have value 1." << std::endl;
          continue;
        }
        uint64_t key = second_index + ec.ft_offset;
        if (sm.marginals.find(key) == sm.marginals.end())  // need to initialize things.
        {
          sm.marginals.insert(std::make_pair(key, std::make_pair(sm.initial_numerator, sm.initial_denominator)));
          if (sm.compete)
          {
            expert e = {0, 0, 1.};
            sm.expert_state.insert(std::make_pair(key, std::make_pair(e, e)));
          }
        }
        float marginal_pred = (float)(sm.marginals[key].first / sm.marginals[key].second);
        f.push_back(marginal_pred, first_index);
        if (!sm.temp[n].space_names.empty())
          f.space_names.push_back(sm.temp[n].space_names[2 * (f.size() - 1)]);

        if (sm.compete)  // compute the prediction from the marginals using the weights
        {
          float weight = sm.expert_state[key].first.weight;
          sm.average_pred += weight * marginal_pred;
          sm.net_weight += weight;
          sm.net_feature_weight += sm.expert_state[key].second.weight;
          if (is_learn)
            sm.alg_loss += weight * all.loss->getLoss(all.sd, marginal_pred, label);
        }
      }
    }
  }
}
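// make_marginal expects every id namespace to contain features in consecutive pairs, both with
// value 1: the first index of a pair names the slot that receives the marginal estimate as its
// feature value, and the second index (plus ft_offset) keys the marginal (and expert) tables.
// The original features are stashed in sm.temp[n] so undo_marginal can restore the example.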

void undo_marginal(data& sm, example& ec)
{
  for (example::iterator i = ec.begin(); i != ec.end(); ++i)
  {
    namespace_index n = i.index();
    if (sm.id_features[n])
      std::swap(sm.temp[n], *i);
  }
}

template <bool is_learn>
void compute_expert_loss(data& sm, example& ec)
{
  vw& all = *sm.all;
  // add in the feature-based expert and normalize.
  float label = ec.l.simple.label;

  if (sm.net_weight + sm.net_feature_weight > 0.)
    sm.average_pred += sm.net_feature_weight * sm.feature_pred;
  else
  {
    sm.net_feature_weight = 1.;
    sm.average_pred = sm.feature_pred;
  }
  float inv_weight = 1.0f / (sm.net_weight + sm.net_feature_weight);
  sm.average_pred *= inv_weight;
  ec.pred.scalar = sm.average_pred;
  ec.partial_prediction = sm.average_pred;

  if (is_learn)
  {
    sm.alg_loss += sm.net_feature_weight * all.loss->getLoss(all.sd, sm.feature_pred, label);
    sm.alg_loss *= inv_weight;
  }
}
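// The reported scalar is a weighted average over the per-id marginal predictions and the base
// learner's feature-based prediction, using the current expert weights; if no expert has positive
// weight yet, the feature-based prediction is used on its own.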

void update_marginal(data& sm, example& ec)
{
  vw& all = *sm.all;
  uint64_t mask = sm.all->weights.mask();
  float label = ec.l.simple.label;
  float weight = ec.weight;
  if (sm.unweighted_marginals)
    weight = 1.;

  for (example::iterator i = ec.begin(); i != ec.end(); ++i)
  {
    namespace_index n = i.index();
    if (sm.id_features[n])
      for (features::iterator j = sm.temp[n].begin(); j != sm.temp[n].end(); ++j)
      {
        if (++j == sm.temp[n].end())
          break;

        uint64_t second_index = j.index() & mask;
        uint64_t key = second_index + ec.ft_offset;
        marginal& m = sm.marginals[key];

        if (sm.compete)  // now update weights, before updating marginals
        {
          expert_pair& e = sm.expert_state[key];
          float regret1 = sm.alg_loss - all.loss->getLoss(all.sd, (float)(m.first / m.second), label);
          float regret2 = sm.alg_loss - all.loss->getLoss(all.sd, sm.feature_pred, label);

          e.first.regret += regret1 * weight;
          e.first.abs_regret += regret1 * regret1 * weight;  // fabs(regret1);
          e.first.weight = get_adanormalhedge_weights(e.first.regret, e.first.abs_regret);
          e.second.regret += regret2 * weight;
          e.second.abs_regret += regret2 * regret2 * weight;  // fabs(regret2);
          e.second.weight = get_adanormalhedge_weights(e.second.regret, e.second.abs_regret);
        }

        m.first = m.first * (1. - sm.decay) + ec.l.simple.label * weight;
        m.second = m.second * (1. - sm.decay) + weight;
      }
  }
}
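// The marginal statistics form an exponentially decayed weighted mean of the observed labels.
// For example, with the default --initial_numerator 0.5, --initial_denominator 1 and --decay 0,
// seeing labels 1, 1, 0 (each with weight 1) for some id leaves numerator = 2.5 and
// denominator = 4, so the next marginal prediction for that id is 0.625.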

template <bool is_learn>
void predict_or_learn(data& sm, LEARNER::single_learner& base, example& ec)
{
  make_marginal<is_learn>(sm, ec);
  if (is_learn)
    if (sm.update_before_learn)
    {
      base.predict(ec);
      float pred = ec.pred.scalar;
      if (sm.compete)
      {
        sm.feature_pred = pred;
        compute_expert_loss<is_learn>(sm, ec);
      }
      undo_marginal(sm, ec);
      update_marginal(sm, ec);  // update features before learning.
      make_marginal<is_learn>(sm, ec);
      base.learn(ec);
      ec.pred.scalar = pred;
    }
    else
    {
      base.learn(ec);
      if (sm.compete)
      {
        sm.feature_pred = ec.pred.scalar;
        compute_expert_loss<is_learn>(sm, ec);
      }
      update_marginal(sm, ec);
    }
  else
  {
    base.predict(ec);
    float pred = ec.pred.scalar;
    if (sm.compete)
    {
      sm.feature_pred = pred;
      compute_expert_loss<is_learn>(sm, ec);
    }
  }

  // undo marginalization
  undo_marginal(sm, ec);
}
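// With --update_before_learn the reduction predicts first, restores the example, folds the
// current label into the marginals, rebuilds the marginal features, and only then calls
// base.learn, reporting the pre-update prediction; otherwise it learns on the current marginal
// features and updates the marginals afterwards.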

void save_load(data& sm, io_buf& io, bool read, bool text)
{
  uint64_t stride_shift = sm.all->weights.stride_shift();

  if (io.files.size() == 0)
    return;
  std::stringstream msg;
  uint64_t total_size;
  if (!read)
  {
    total_size = (uint64_t)sm.marginals.size();
    msg << "marginals size = " << total_size << "\n";
  }
  bin_text_read_write_fixed_validated(io, (char*)&total_size, sizeof(total_size), "", read, msg, text);

  auto iter = sm.marginals.begin();
  for (size_t i = 0; i < total_size; ++i)
  {
    uint64_t index;
    if (!read)
    {
      index = iter->first >> stride_shift;
      msg << index << ":";
    }
    bin_text_read_write_fixed(io, (char*)&index, sizeof(index), "", read, msg, text);
    double numerator;
    if (!read)
    {
      numerator = iter->second.first;
      msg << numerator << ":";
    }
    bin_text_read_write_fixed(io, (char*)&numerator, sizeof(numerator), "", read, msg, text);
    double denominator;
    if (!read)
    {
      denominator = iter->second.second;
      msg << denominator << "\n";
    }
    bin_text_read_write_fixed(io, (char*)&denominator, sizeof(denominator), "", read, msg, text);
    if (read)
      sm.marginals.insert(std::make_pair(index << stride_shift, std::make_pair(numerator, denominator)));
    else
      ++iter;
  }

  if (sm.compete)
  {
    if (!read)
    {
      total_size = (uint64_t)sm.expert_state.size();
      msg << "expert_state size = " << total_size << "\n";
    }
    bin_text_read_write_fixed_validated(io, (char*)&total_size, sizeof(total_size), "", read, msg, text);

    auto exp_iter = sm.expert_state.begin();
    for (size_t i = 0; i < total_size; ++i)
    {
      uint64_t index;
      if (!read)
      {
        index = exp_iter->first >> stride_shift;
        msg << index << ":";
      }
      bin_text_read_write_fixed(io, (char*)&index, sizeof(index), "", read, msg, text);
      float r1, c1, w1, r2, c2, w2;
      if (!read)
      {
        r1 = exp_iter->second.first.regret;
        c1 = exp_iter->second.first.abs_regret;
        w1 = exp_iter->second.first.weight;
        r2 = exp_iter->second.second.regret;
        c2 = exp_iter->second.second.abs_regret;
        w2 = exp_iter->second.second.weight;
        msg << r1 << ":";
      }
      bin_text_read_write_fixed(io, (char*)&r1, sizeof(r1), "", read, msg, text);
      if (!read)
        msg << c1 << ":";
      bin_text_read_write_fixed(io, (char*)&c1, sizeof(c1), "", read, msg, text);
      if (!read)
        msg << w1 << ":";
      bin_text_read_write_fixed(io, (char*)&w1, sizeof(w1), "", read, msg, text);
      if (!read)
        msg << r2 << ":";
      bin_text_read_write_fixed(io, (char*)&r2, sizeof(r2), "", read, msg, text);
      if (!read)
        msg << c2 << ":";
      bin_text_read_write_fixed(io, (char*)&c2, sizeof(c2), "", read, msg, text);
      if (!read)
        msg << w2 << ":";
      bin_text_read_write_fixed(io, (char*)&w2, sizeof(w2), "", read, msg, text);

      if (read)
      {
        expert e1 = {r1, c1, w1};
        expert e2 = {r2, c2, w2};
        sm.expert_state.insert(std::make_pair(index << stride_shift, std::make_pair(e1, e2)));
      }
      else
        ++exp_iter;
    }
  }
}
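// Model format: a count followed by one (index, numerator, denominator) record per tracked id,
// and, when --compete is set, a second count followed by the (regret, abs_regret, weight)
// triples of both experts for each id. Indices are stored right-shifted by the weight stride
// and shifted back on load.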
}  // namespace MARGINAL

using namespace MARGINAL;

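// Usage sketch (illustrative; 'f' is a placeholder namespace letter and train.dat a placeholder
// file): vw -d train.dat --marginal f --compete --decay 1e-3
// The argument of --marginal is the set of namespace characters whose id features should be
// replaced by marginal label estimates.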
LEARNER::base_learner* marginal_setup(options_i& options, vw& all)
{
  free_ptr<MARGINAL::data> d = scoped_calloc_or_throw<MARGINAL::data>();
  std::string marginal;

  option_group_definition marginal_options("VW options");
  marginal_options.add(make_option("marginal", marginal).keep().help("substitute marginal label estimates for ids"));
  marginal_options.add(
      make_option("initial_denominator", d->initial_denominator).default_value(1.f).help("initial denominator"));
  marginal_options.add(
      make_option("initial_numerator", d->initial_numerator).default_value(0.5f).help("initial numerator"));
  marginal_options.add(make_option("compete", d->compete).help("enable competition with marginal features"));
  marginal_options.add(
      make_option("update_before_learn", d->update_before_learn).help("update marginal values before learning"));
  marginal_options.add(make_option("unweighted_marginals", d->unweighted_marginals)
                           .help("ignore importance weights when computing marginals"));
  marginal_options.add(
      make_option("decay", d->decay).default_value(0.f).help("decay multiplier per event (1e-3 for example)"));
  options.add_and_parse(marginal_options);

  if (!options.was_supplied("marginal"))
  {
    return nullptr;
  }

  d->all = &all;

  for (size_t u = 0; u < 256; u++)
    if (marginal.find((char)u) != std::string::npos)
      d->id_features[u] = true;

  LEARNER::learner<MARGINAL::data, example>& ret =
      init_learner(d, as_singleline(setup_base(options, all)), predict_or_learn<true>, predict_or_learn<false>);
  ret.set_save_load(save_load);

  return make_base(ret);
}