Vowpal Wabbit
mwt.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 #include "vw.h"
7 #include "reductions.h"
8 #include "gd.h"
9 #include "cb_algs.h"
10 #include "io_buf.h"
11 
12 using namespace LEARNER;
13 using namespace CB_ALGS;
14 using namespace VW::config;
15 
16 namespace MWT
17 {
/// Accumulator for one policy. A policy is identified by the (masked,
/// de-strided) weight index of a feature in an evaluated namespace.
/// Instances live in mwt::evals; save_load serializes them byte-wise with
/// sizeof(policy_data), so the type must stay trivially copyable.
struct policy_data
{
  double cost;      // sum of cost estimates accrued by this policy over examples with an observed cost
  uint32_t action;  // action this policy chose on the current example (the feature's value; reset to 0 after scoring)
  bool seen;        // true once this index has been registered in mwt::policies
};
24 
25 struct mwt
26 {
27  bool namespaces[256]; // the set of namespaces to evaluate.
28  v_array<policy_data> evals; // accrued losses of features.
31  double total;
32  uint32_t num_classes;
33  bool learn;
34 
35  v_array<namespace_index> indices; // excluded namespaces
36  features feature_space[256];
37  vw* all;
38 
39  ~mwt()
40  {
41  evals.delete_v();
42  policies.delete_v();
43  for (auto& i : feature_space) i.delete_v();
44  indices.delete_v();
45  }
46 };
47 
48 inline bool observed_cost(CB::cb_class* cl)
49 {
50  // cost observed for this action if it has non zero probability and cost != FLT_MAX
51  if (cl != nullptr)
52  if (cl->cost != FLT_MAX && cl->probability > .0)
53  return true;
54  return false;
55 }
56 
58 {
59  for (auto& cl : ld.costs)
60  if (observed_cost(&cl))
61  return &cl;
62  return nullptr;
63 }
64 
65 void value_policy(mwt& c, float val, uint64_t index) // estimate the value of a single feature.
66 {
67  if (val < 0 || floor(val) != val)
68  std::cout << "error " << val << " is not a valid action " << std::endl;
69 
70  uint32_t value = (uint32_t)val;
71  uint64_t new_index = (index & c.all->weights.mask()) >> c.all->weights.stride_shift();
72 
73  if (!c.evals[new_index].seen)
74  {
75  c.evals[new_index].seen = true;
76  c.policies.push_back(new_index);
77  }
78 
79  c.evals[new_index].action = value;
80 }
81 
82 template <bool learn, bool exclude, bool is_learn>
84 {
86 
87  if (c.observation != nullptr)
88  {
89  c.total++;
90  // For each nonzero feature in observed namespaces, check it's value.
91  for (unsigned char ns : ec.indices)
92  if (c.namespaces[ns])
93  GD::foreach_feature<mwt, value_policy>(c.all, ec.feature_space[ns], c);
94  for (uint64_t policy : c.policies)
95  {
96  c.evals[policy].cost += get_cost_estimate(c.observation, c.evals[policy].action);
97  c.evals[policy].action = 0;
98  }
99  }
100  if (exclude || learn)
101  {
102  c.indices.clear();
103  uint32_t stride_shift = c.all->weights.stride_shift();
104  uint64_t weight_mask = c.all->weights.mask();
105  for (unsigned char ns : ec.indices)
106  if (c.namespaces[ns])
107  {
108  c.indices.push_back(ns);
109  if (learn)
110  {
111  c.feature_space[ns].clear();
112  for (features::iterator& f : ec.feature_space[ns])
113  {
114  uint64_t new_index = ((f.index() & weight_mask) >> stride_shift) * c.num_classes + (uint64_t)f.value();
115  c.feature_space[ns].push_back(1, new_index << stride_shift);
116  }
117  }
118  std::swap(c.feature_space[ns], ec.feature_space[ns]);
119  }
120  }
121 
122  // modify the predictions to use a vector with a score for each evaluated feature.
123  v_array<float> preds = ec.pred.scalars;
124 
125  if (learn)
126  {
127  if (is_learn)
128  base.learn(ec);
129  else
130  base.predict(ec);
131  }
132 
133  if (exclude || learn)
134  while (!c.indices.empty())
135  {
136  unsigned char ns = c.indices.pop();
137  std::swap(c.feature_space[ns], ec.feature_space[ns]);
138  }
139 
140  // modify the predictions to use a vector with a score for each evaluated feature.
141  preds.clear();
142  if (learn)
143  preds.push_back((float)ec.pred.multiclass);
144  for (uint64_t index : c.policies) preds.push_back((float)c.evals[index].cost / (float)c.total);
145 
146  ec.pred.scalars = preds;
147 }
148 
150 {
151  if (f >= 0)
152  {
153  std::stringstream ss;
154 
155  for (size_t i = 0; i < scalars.size(); i++)
156  {
157  if (i > 0)
158  ss << ' ';
159  ss << scalars[i];
160  }
161  for (size_t i = 0; i < tag.size(); i++)
162  {
163  if (i == 0)
164  ss << ' ';
165  ss << tag[i];
166  }
167  ss << '\n';
168  ssize_t len = ss.str().size();
169  ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
170  if (t != len)
171  std::cerr << "write error: " << strerror(errno) << std::endl;
172  }
173 }
174 
175 void finish_example(vw& all, mwt& c, example& ec)
176 {
177  float loss = 0.;
178  if (c.learn)
179  if (c.observation != nullptr)
180  loss = get_cost_estimate(c.observation, (uint32_t)ec.pred.scalars[0]);
181  all.sd->update(ec.test_only, c.observation != nullptr, loss, 1.f, ec.num_features);
182 
183  for (int sink : all.final_prediction_sink) print_scalars(sink, ec.pred.scalars, ec.tag);
184 
185  if (c.learn)
186  {
187  v_array<float> temp = ec.pred.scalars;
188  ec.pred.multiclass = (uint32_t)temp[0];
189  CB::print_update(all, c.observation != nullptr, ec, nullptr, false);
190  ec.pred.scalars = temp;
191  }
192  VW::finish_example(all, ec);
193 }
194 
195 void save_load(mwt& c, io_buf& model_file, bool read, bool text)
196 {
197  if (model_file.files.empty())
198  return;
199 
200  std::stringstream msg;
201 
202  // total
203  msg << "total: " << c.total;
204  bin_text_read_write_fixed_validated(model_file, (char*)&c.total, sizeof(c.total), "", read, msg, text);
205 
206  // policies
207  size_t policies_size = c.policies.size();
208  bin_text_read_write_fixed_validated(model_file, (char*)&policies_size, sizeof(policies_size), "", read, msg, text);
209 
210  if (read)
211  {
212  c.policies.resize(policies_size);
213  c.policies.end() = c.policies.begin() + policies_size;
214  }
215  else
216  {
217  msg << "policies: ";
218  for (feature_index& policy : c.policies) msg << policy << " ";
219  }
220 
222  model_file, (char*)c.policies.begin(), policies_size * sizeof(feature_index), "", read, msg, text);
223 
224  // c.evals is already initialized nicely to the same size as the regressor.
225  for (feature_index& policy : c.policies)
226  {
227  policy_data& pd = c.evals[policy];
228  if (read)
229  msg << "evals: " << policy << ":" << pd.action << ":" << pd.cost << " ";
230  bin_text_read_write_fixed_validated(model_file, (char*)&c.evals[policy], sizeof(policy_data), "", read, msg, text);
231  }
232 }
233 } // namespace MWT
234 using namespace MWT;
235 
// Builds the multiworld-testing reduction from the command-line options.
// Returns nullptr unless --multiworld_test was supplied.
// NOTE(review): several original lines are absent from this listing; its
// embedded numbering skips 236, 256, 258, 272, 284 and 285. The symbol index
// suggests they hold:
//   - the signature `base_learner* mwt_setup(options_i& options, vw& all)`,
//   - the all.delete_prediction / all.label_type assignments,
//   - the declaration of `l`,
//   - the l->set_save_load / l->set_finish_example registrations.
// Confirm against the repository.
237 {
238  auto c = scoped_calloc_or_throw<mwt>();
239  std::string s;
240  bool exclude_eval = false;
241  option_group_definition new_options("Multiworld Testing Options");
// NOTE(review): the help text "Evaluate features as a policies" is ungrammatical
// (presumably "Evaluate features as policies").
242  new_options.add(make_option("multiworld_test", s).keep().help("Evaluate features as a policies"))
243  .add(make_option("learn", c->num_classes).help("Do Contextual Bandit learning on <n> classes."))
244  .add(make_option("exclude_eval", exclude_eval).help("Discard mwt policy features before learning"));
245  options.add_and_parse(new_options);
246 
247  if (!options.was_supplied("multiworld_test"))
248  return nullptr;
249 
// Each character of the --multiworld_test argument names a namespace whose
// features are evaluated as policies.
250  for (char i : s) c->namespaces[(unsigned char)i] = true;
251  c->all = &all;
252 
// One policy_data slot per entry of all.length(), marked as in use.
253  calloc_reserve(c->evals, all.length());
254  c->evals.end() = c->evals.begin() + all.length();
255 
257  all.p->lp = CB::cb_label;
259 
// --learn <n>: enable learning. Unless --cb was given explicitly, also insert
// a --cb option with n actions so a CB base learner is set up.
260  if (c->num_classes > 0)
261  {
262  c->learn = true;
263 
264  if (!options.was_supplied("cb"))
265  {
266  std::stringstream ss;
267  ss << c->num_classes;
268  options.insert("cb", ss.str());
269  }
270  }
271 
// Pick the predict_or_learn instantiation (learn, exclude, is_learn). The inner
// `else` binds to `if (exclude_eval)` and the outer one to `if (c->learn)`, so
// --exclude_eval only takes effect together with --learn.
273  if (c->learn)
274  if (exclude_eval)
275  l = &init_learner(c, as_singleline(setup_base(options, all)), predict_or_learn<true, true, true>,
276  predict_or_learn<true, true, false>, 1, prediction_type::scalars);
277  else
278  l = &init_learner(c, as_singleline(setup_base(options, all)), predict_or_learn<true, false, true>,
279  predict_or_learn<true, false, false>, 1, prediction_type::scalars);
280  else
281  l = &init_learner(c, as_singleline(setup_base(options, all)), predict_or_learn<false, false, true>,
282  predict_or_learn<false, false, false>, 1, prediction_type::scalars);
283 
286  return make_base(*l);
287 }
vw * all
Definition: mwt.cc:37
void resize(size_t length)
Definition: v_array.h:69
v_array< char > tag
Definition: example.h:63
size_t length()
Definition: global_data.h:513
v_array< namespace_index > indices
uint32_t multiclass
Definition: example.h:49
~mwt()
Definition: mwt.cc:39
parameters weights
Definition: global_data.h:537
v_array< namespace_index > indices
Definition: mwt.cc:35
void predict(E &ec, size_t i=0)
Definition: learner.h:169
void(* delete_prediction)(void *)
Definition: global_data.h:485
T pop()
Definition: v_array.h:58
uint64_t stride_shift(const stagewise_poly &poly, uint64_t idx)
uint32_t action
Definition: mwt.cc:21
void push_back(feature_value v, feature_index i)
double cost
Definition: mwt.cc:20
void print_scalars(int f, v_array< float > &scalars, v_array< char > &tag)
Definition: mwt.cc:149
static ssize_t write_file_or_socket(int f, const void *buf, size_t nbytes)
Definition: io_buf.cc:140
CB::cb_class * get_observed_cost(CB::label &ld)
Definition: mwt.cc:57
bool learn
Definition: mwt.cc:33
CB::label cb
Definition: example.h:31
features feature_space[256]
Definition: mwt.cc:36
label_type::label_type_t label_type
Definition: global_data.h:550
v_array< int > final_prediction_sink
Definition: global_data.h:518
the core definition of a set of features.
v_array< cb_class > costs
Definition: cb.h:27
v_array< uint64_t > policies
Definition: mwt.cc:30
size_t bin_text_read_write_fixed_validated(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
Definition: io_buf.h:335
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
Definition: mwt.cc:25
bool seen
Definition: mwt.cc:22
void value_policy(mwt &c, float val, uint64_t index)
Definition: mwt.cc:65
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
float get_cost_estimate(CB::cb_class *observation, uint32_t action, float offset=0.)
Definition: cb_algs.h:58
T *& begin()
Definition: v_array.h:42
size_t size() const
Definition: v_array.h:68
bool observed_cost(CB::cb_class *cl)
Definition: mwt.cc:48
void predict_or_learn(mwt &c, single_learner &base, example &ec)
Definition: mwt.cc:83
parser * p
Definition: global_data.h:377
std::array< features, NUM_NAMESPACES > feature_space
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
CB::cb_class * observation
Definition: mwt.cc:29
Definition: mwt.cc:16
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void push_back(const T &new_ele)
Definition: v_array.h:107
shared_data * sd
Definition: global_data.h:375
bool namespaces[256]
Definition: mwt.cc:27
float probability
Definition: cb.h:19
base_learner * mwt_setup(options_i &options, vw &all)
Definition: mwt.cc:236
void delete_scalars(void *v)
Definition: example.h:37
v_array< int > files
Definition: io_buf.h:64
void clear()
Definition: v_array.h:88
size_t num_features
Definition: example.h:67
virtual bool was_supplied(const std::string &key)=0
void finish_example(vw &all, mwt &c, example &ec)
Definition: mwt.cc:175
double total
Definition: mwt.cc:31
uint64_t feature_index
Definition: feature_group.h:21
void clear()
void calloc_reserve(v_array< T > &v, size_t length)
Definition: v_array.h:220
Definition: io_buf.h:54
void finish_example(vw &, example &)
Definition: parser.cc:881
T *& end()
Definition: v_array.h:43
virtual void insert(const std::string &key, const std::string &value)=0
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
option_group_definition & add(T &&op)
Definition: options.h:90
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
iterator over values and indices
label_parser cb_label
Definition: cb.cc:167
polylabel l
Definition: example.h:57
void save_load(mwt &c, io_buf &model_file, bool read, bool text)
Definition: mwt.cc:195
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
Definition: cb.h:25
float cost
Definition: cb.h:17
bool empty() const
Definition: v_array.h:59
uint32_t stride_shift()
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
polyprediction pred
Definition: example.h:60
void delete_v()
Definition: v_array.h:98
void learn(E &ec, size_t i=0)
Definition: learner.h:160
v_array< policy_data > evals
Definition: mwt.cc:28
void learn(bfgs &b, base_learner &base, example &ec)
Definition: bfgs.cc:965
v_array< float > scalars
Definition: example.h:46
uint64_t mask()
constexpr uint64_t c
Definition: rand48.cc:12
float f
Definition: cache.cc:40
label_parser lp
Definition: parser.h:102
uint32_t num_classes
Definition: mwt.cc:32
bool test_only
Definition: example.h:76