Vowpal Wabbit
Classes | Functions
MWT Namespace Reference

Classes

struct  mwt
 
struct  policy_data
 

Functions

bool observed_cost (CB::cb_class *cl)
 
CB::cb_class * get_observed_cost (CB::label &ld)
 
void value_policy (mwt &c, float val, uint64_t index)
 
template<bool learn, bool exclude, bool is_learn>
void predict_or_learn (mwt &c, single_learner &base, example &ec)
 
void print_scalars (int f, v_array< float > &scalars, v_array< char > &tag)
 
void finish_example (vw &all, mwt &c, example &ec)
 
void save_load (mwt &c, io_buf &model_file, bool read, bool text)
 
void delete_scalars (void *v)
 

Function Documentation

◆ delete_scalars()

void MWT::delete_scalars ( void *  v)
inline

Definition at line 37 of file example.h.

References v_array< T >::delete_v().

Referenced by lda_setup(), mwt_setup(), and oaa_setup().

38 {
39  v_array<float>* preds = (v_array<float>*)v;
40  preds->delete_v();
41 }
void delete_v()
Definition: v_array.h:98

◆ finish_example()

void MWT::finish_example ( vw & all,
mwt & c,
example & ec 
)

Definition at line 175 of file mwt.cc.

References vw::final_prediction_sink, VW::finish_example(), CB_ALGS::get_cost_estimate(), MWT::mwt::learn, loss(), polyprediction::multiclass, example::num_features, MWT::mwt::observation, example::pred, print_scalars(), CB::print_update(), polyprediction::scalars, vw::sd, example::tag, example::test_only, and shared_data::update(). (Note: the loss() entry is a spurious Doxygen cross-reference to cbify.cc; inside this function `loss` is a local float, not that function.)

Referenced by mwt_setup().

176 {
177  float loss = 0.;
178  if (c.learn)
179  if (c.observation != nullptr)
180  loss = get_cost_estimate(c.observation, (uint32_t)ec.pred.scalars[0]);
181  all.sd->update(ec.test_only, c.observation != nullptr, loss, 1.f, ec.num_features);
182 
183  for (int sink : all.final_prediction_sink) print_scalars(sink, ec.pred.scalars, ec.tag);
184 
185  if (c.learn)
186  {
187  v_array<float> temp = ec.pred.scalars;
188  ec.pred.multiclass = (uint32_t)temp[0];
189  CB::print_update(all, c.observation != nullptr, ec, nullptr, false);
190  ec.pred.scalars = temp;
191  }
192  VW::finish_example(all, ec);
193 }
v_array< char > tag
Definition: example.h:63
uint32_t multiclass
Definition: example.h:49
void print_scalars(int f, v_array< float > &scalars, v_array< char > &tag)
Definition: mwt.cc:149
bool learn
Definition: mwt.cc:33
v_array< int > final_prediction_sink
Definition: global_data.h:518
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
float get_cost_estimate(CB::cb_class *observation, uint32_t action, float offset=0.)
Definition: cb_algs.h:58
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
CB::cb_class * observation
Definition: mwt.cc:29
shared_data * sd
Definition: global_data.h:375
size_t num_features
Definition: example.h:67
void finish_example(vw &, example &)
Definition: parser.cc:881
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
polyprediction pred
Definition: example.h:60
v_array< float > scalars
Definition: example.h:46
bool test_only
Definition: example.h:76

◆ get_observed_cost()

CB::cb_class* MWT::get_observed_cost ( CB::label & ld)

Definition at line 57 of file mwt.cc.

References CB::label::costs, and observed_cost().

Referenced by predict_or_learn().

58 {
59  for (auto& cl : ld.costs)
60  if (observed_cost(&cl))
61  return &cl;
62  return nullptr;
63 }
v_array< cb_class > costs
Definition: cb.h:27
bool observed_cost(CB::cb_class *cl)
Definition: mwt.cc:48

◆ observed_cost()

bool MWT::observed_cost ( CB::cb_class * cl)
inline

Definition at line 48 of file mwt.cc.

References CB::cb_class::cost, and CB::cb_class::probability.

Referenced by get_observed_cost().

49 {
50  // cost observed for this action if it has non zero probability and cost != FLT_MAX
51  if (cl != nullptr)
52  if (cl->cost != FLT_MAX && cl->probability > .0)
53  return true;
54  return false;
55 }
float probability
Definition: cb.h:19
float cost
Definition: cb.h:17

◆ predict_or_learn()

template<bool learn, bool exclude, bool is_learn>
void MWT::predict_or_learn ( mwt & c,
single_learner & base,
example & ec 
)

Definition at line 83 of file mwt.cc.

References MWT::mwt::all, c, polylabel::cb, v_array< T >::clear(), features::clear(), v_array< T >::empty(), MWT::mwt::evals, f, MWT::mwt::feature_space, example_predict::feature_space, CB_ALGS::get_cost_estimate(), get_observed_cost(), MWT::mwt::indices, example_predict::indices, example::l, LEARNER::learner< T, E >::learn(), learn(), parameters::mask(), polyprediction::multiclass, MWT::mwt::namespaces, MWT::mwt::num_classes, MWT::mwt::observation, MWT::mwt::policies, v_array< T >::pop(), example::pred, LEARNER::learner< T, E >::predict(), v_array< T >::push_back(), features::push_back(), polyprediction::scalars, stride_shift(), parameters::stride_shift(), MWT::mwt::total, and vw::weights. (Note: the entries c, f, learn() and stride_shift() are spurious Doxygen cross-references to unrelated symbols in rand48.cc, cache.cc, bfgs.cc and stagewise_poly. In this function `c` is the mwt parameter, `f` is the feature iterator, `learn` is the template parameter, and `stride_shift` is a local variable.)

84 {
85  c.observation = get_observed_cost(ec.l.cb);
86 
87  if (c.observation != nullptr)
88  {
89  c.total++;
90  // For each nonzero feature in observed namespaces, check it's value.
91  for (unsigned char ns : ec.indices)
92  if (c.namespaces[ns])
93  GD::foreach_feature<mwt, value_policy>(c.all, ec.feature_space[ns], c);
94  for (uint64_t policy : c.policies)
95  {
96  c.evals[policy].cost += get_cost_estimate(c.observation, c.evals[policy].action);
97  c.evals[policy].action = 0;
98  }
99  }
100  if (exclude || learn)
101  {
102  c.indices.clear();
103  uint32_t stride_shift = c.all->weights.stride_shift();
104  uint64_t weight_mask = c.all->weights.mask();
105  for (unsigned char ns : ec.indices)
106  if (c.namespaces[ns])
107  {
108  c.indices.push_back(ns);
109  if (learn)
110  {
111  c.feature_space[ns].clear();
112  for (features::iterator& f : ec.feature_space[ns])
113  {
114  uint64_t new_index = ((f.index() & weight_mask) >> stride_shift) * c.num_classes + (uint64_t)f.value();
115  c.feature_space[ns].push_back(1, new_index << stride_shift);
116  }
117  }
118  std::swap(c.feature_space[ns], ec.feature_space[ns]);
119  }
120  }
121 
122  // modify the predictions to use a vector with a score for each evaluated feature.
123  v_array<float> preds = ec.pred.scalars;
124 
125  if (learn)
126  {
127  if (is_learn)
128  base.learn(ec);
129  else
130  base.predict(ec);
131  }
132 
133  if (exclude || learn)
134  while (!c.indices.empty())
135  {
136  unsigned char ns = c.indices.pop();
137  std::swap(c.feature_space[ns], ec.feature_space[ns]);
138  }
139 
140  // modify the predictions to use a vector with a score for each evaluated feature.
141  preds.clear();
142  if (learn)
143  preds.push_back((float)ec.pred.multiclass);
144  for (uint64_t index : c.policies) preds.push_back((float)c.evals[index].cost / (float)c.total);
145 
146  ec.pred.scalars = preds;
147 }
vw * all
Definition: mwt.cc:37
v_array< namespace_index > indices
uint32_t multiclass
Definition: example.h:49
parameters weights
Definition: global_data.h:537
v_array< namespace_index > indices
Definition: mwt.cc:35
void predict(E &ec, size_t i=0)
Definition: learner.h:169
T pop()
Definition: v_array.h:58
uint64_t stride_shift(const stagewise_poly &poly, uint64_t idx)
void push_back(feature_value v, feature_index i)
CB::cb_class * get_observed_cost(CB::label &ld)
Definition: mwt.cc:57
CB::label cb
Definition: example.h:31
features feature_space[256]
Definition: mwt.cc:36
v_array< uint64_t > policies
Definition: mwt.cc:30
float get_cost_estimate(CB::cb_class *observation, uint32_t action, float offset=0.)
Definition: cb_algs.h:58
std::array< features, NUM_NAMESPACES > feature_space
CB::cb_class * observation
Definition: mwt.cc:29
void push_back(const T &new_ele)
Definition: v_array.h:107
bool namespaces[256]
Definition: mwt.cc:27
void clear()
Definition: v_array.h:88
double total
Definition: mwt.cc:31
void clear()
iterator over values and indices
polylabel l
Definition: example.h:57
bool empty() const
Definition: v_array.h:59
uint32_t stride_shift()
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
v_array< policy_data > evals
Definition: mwt.cc:28
void learn(bfgs &b, base_learner &base, example &ec)
Definition: bfgs.cc:965
v_array< float > scalars
Definition: example.h:46
uint64_t mask()
constexpr uint64_t c
Definition: rand48.cc:12
float f
Definition: cache.cc:40
uint32_t num_classes
Definition: mwt.cc:32

◆ print_scalars()

void MWT::print_scalars ( int  f,
v_array< float > &  scalars,
v_array< char > &  tag 
)

Definition at line 149 of file mwt.cc.

References v_array< T >::size(), and io_buf::write_file_or_socket().

Referenced by finish_example(), and return_example().

150 {
151  if (f >= 0)
152  {
153  std::stringstream ss;
154 
155  for (size_t i = 0; i < scalars.size(); i++)
156  {
157  if (i > 0)
158  ss << ' ';
159  ss << scalars[i];
160  }
161  for (size_t i = 0; i < tag.size(); i++)
162  {
163  if (i == 0)
164  ss << ' ';
165  ss << tag[i];
166  }
167  ss << '\n';
168  ssize_t len = ss.str().size();
169  ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
170  if (t != len)
171  std::cerr << "write error: " << strerror(errno) << std::endl;
172  }
173 }
static ssize_t write_file_or_socket(int f, const void *buf, size_t nbytes)
Definition: io_buf.cc:140
size_t size() const
Definition: v_array.h:68
float f
Definition: cache.cc:40

◆ save_load()

void MWT::save_load ( mwt & c,
io_buf & model_file,
bool  read,
bool  text 
)

Definition at line 195 of file mwt.cc.

References MWT::policy_data::action, v_array< T >::begin(), bin_text_read_write_fixed_validated(), MWT::policy_data::cost, v_array< T >::empty(), v_array< T >::end(), MWT::mwt::evals, io_buf::files, MWT::mwt::policies, v_array< T >::resize(), v_array< T >::size(), and MWT::mwt::total.

Referenced by mwt_setup().

196 {
197  if (model_file.files.empty())
198  return;
199 
200  std::stringstream msg;
201 
202  // total
203  msg << "total: " << c.total;
204  bin_text_read_write_fixed_validated(model_file, (char*)&c.total, sizeof(c.total), "", read, msg, text);
205 
206  // policies
207  size_t policies_size = c.policies.size();
208  bin_text_read_write_fixed_validated(model_file, (char*)&policies_size, sizeof(policies_size), "", read, msg, text);
209 
210  if (read)
211  {
212  c.policies.resize(policies_size);
213  c.policies.end() = c.policies.begin() + policies_size;
214  }
215  else
216  {
217  msg << "policies: ";
218  for (feature_index& policy : c.policies) msg << policy << " ";
219  }
220 
221  bin_text_read_write_fixed_validated(
222  model_file, (char*)c.policies.begin(), policies_size * sizeof(feature_index), "", read, msg, text);
223 
224  // c.evals is already initialized nicely to the same size as the regressor.
225  for (feature_index& policy : c.policies)
226  {
227  policy_data& pd = c.evals[policy];
228  if (read)
229  msg << "evals: " << policy << ":" << pd.action << ":" << pd.cost << " ";
230  bin_text_read_write_fixed_validated(model_file, (char*)&c.evals[policy], sizeof(policy_data), "", read, msg, text);
231  }
232 }
void resize(size_t length)
Definition: v_array.h:69
uint32_t action
Definition: mwt.cc:21
double cost
Definition: mwt.cc:20
v_array< uint64_t > policies
Definition: mwt.cc:30
size_t bin_text_read_write_fixed_validated(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
Definition: io_buf.h:335
T *& begin()
Definition: v_array.h:42
size_t size() const
Definition: v_array.h:68
v_array< int > files
Definition: io_buf.h:64
double total
Definition: mwt.cc:31
uint64_t feature_index
Definition: feature_group.h:21
T *& end()
Definition: v_array.h:43
bool empty() const
Definition: v_array.h:59
v_array< policy_data > evals
Definition: mwt.cc:28

◆ value_policy()

void MWT::value_policy ( mwt & c,
float  val,
uint64_t  index 
)

Definition at line 65 of file mwt.cc.

References MWT::mwt::all, MWT::mwt::evals, parameters::mask(), MWT::mwt::policies, v_array< T >::push_back(), parameters::stride_shift(), and vw::weights.

66 {
67  if (val < 0 || floor(val) != val)
68  std::cout << "error " << val << " is not a valid action " << std::endl;
69 
70  uint32_t value = (uint32_t)val;
71  uint64_t new_index = (index & c.all->weights.mask()) >> c.all->weights.stride_shift();
72 
73  if (!c.evals[new_index].seen)
74  {
75  c.evals[new_index].seen = true;
76  c.policies.push_back(new_index);
77  }
78 
79  c.evals[new_index].action = value;
80 }
vw * all
Definition: mwt.cc:37
parameters weights
Definition: global_data.h:537
v_array< uint64_t > policies
Definition: mwt.cc:30
void push_back(const T &new_ele)
Definition: v_array.h:107
uint32_t stride_shift()
v_array< policy_data > evals
Definition: mwt.cc:28
uint64_t mask()