Vowpal Wabbit
cb_algs.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 #include <float.h>
7 
8 #include "vw.h"
9 #include "reductions.h"
10 #include "cb_algs.h"
11 #include "vw_exception.h"
12 #include "gen_cs_example.h"
13 
14 using namespace LEARNER;
15 using namespace VW::config;
16 
17 using namespace CB;
18 using namespace GEN_CS;
19 namespace CB_ALGS
20 {
// State owned by the contextual bandit reduction implemented in this file.
// NOTE(review): listing lines 23, 24 and 29 are missing from this extract. The symbol index at the end of the
// page places `cb_to_cs cbcs` at cb_algs.cc:23 and `COST_SENSITIVE::label cb_cs_ld` at cb_algs.cc:24, so the
// member declarations presumably belong in the gap below - confirm against the original file.
21 struct cb
22 {
25 
// Destructor: calls delete_v() on the costs array of the cached cost-sensitive label
// (presumably to free its storage - confirm against v_array).
26  ~cb()
27  {
28  cb_cs_ld.costs.delete_v();
30  }
31 };
32 
// Body of know_all_cost_example: returns true only when the label lists more than one action and every listed
// action has a cost other than FLT_MAX; returns false otherwise.
// NOTE(review): the signature line (listing line 33) is missing from this extract; the symbol index at the end
// of the page gives it as `bool know_all_cost_example(CB::label &ld)` - confirm against the original file.
34 {
35  if (ld.costs.size() <= 1) // at most one entry: all actions are possible, but a cost was given only for the
36  // observed action, so the costs of the other actions are unknown
37  return false;
38 
39  // More than one entry was given (either a restricted set of possible actions, or all actions listed);
40  // then check that every listed action has a specified cost. FLT_MAX marks a cost that was not specified.
41  for (auto& cl : ld.costs)
42  if (cl.cost == FLT_MAX)
43  return false;
44 
45  return true;
46 }
47 
// Learn/predict entry point of the cb reduction for a single example.
// Template parameter is_learn selects base.learn (true) or base.predict (false) and is forwarded to
// gen_cs_example.
//   data - reduction state; its cbcs and cb_cs_ld members are read and updated here
//   base - the learner this reduction delegates to
//   ec   - the example; its label is temporarily replaced by a cost-sensitive label
48 template <bool is_learn>
49 void predict_or_learn(cb& data, single_learner& base, example& ec)
50 {
// Work on a by-value copy of the cb label, because ec.l is overwritten with the cost-sensitive label below.
51  CB::label ld = ec.l.cb;
52  cb_to_cs& c = data.cbcs;
// NOTE(review): listing line 53 is missing from this extract; c.known_cost is presumably assigned there
// (the symbol index mentions get_observed_cost) - confirm against the original file.
// An observed action outside [1, num_actions] is only reported on stderr; processing continues.
54  if (c.known_cost != nullptr && (c.known_cost->action < 1 || c.known_cost->action > c.num_actions))
55  std::cerr << "invalid action: " << c.known_cost->action << std::endl;
56 
57  // generate a cost-sensitive example to update classifiers
58  gen_cs_example<is_learn>(c, ec, ld, data.cb_cs_ld);
59 
// For every cb_type except CB_TYPE_DM, the base learner is run on the generated cost-sensitive label.
// In the CB_TYPE_DM case this function does not call base at all.
60  if (c.cb_type != CB_TYPE_DM)
61  {
62  ec.l.cs = data.cb_cs_ld;
63 
64  if (is_learn)
65  base.learn(ec);
66  else
67  base.predict(ec);
68 
// Copy the per-action partial predictions into the cb label copy by position, then put the cb label back
// on the example. The loop indexes cb_cs_ld.costs with ld.costs indices, i.e. it assumes cb_cs_ld.costs has
// at least as many entries as ld.costs - TODO confirm.
69  for (size_t i = 0; i < ld.costs.size(); i++)
70  ld.costs[i].partial_prediction = data.cb_cs_ld.costs[i].partial_prediction;
71  ec.l.cb = ld;
72  }
73 }
74 
// Predict entry point used in evaluation mode. It ignores all of its arguments and always raises through the
// THROW macro, because evaluation cannot be done on an example with a test label.
75 void predict_eval(cb&, single_learner&, example&) { THROW("can not use a test label for evaluation"); }
76 
// Learn entry point used in evaluation mode. It generates the cost-sensitive label for the logged event and
// reports the action stored in the evaluation label as the prediction. The base learner (second parameter) is
// unnamed and not called here.
//   data - reduction state; its cbcs and cb_cs_ld members are used
//   ec   - the example; ec.pred.multiclass is written
77 void learn_eval(cb& data, single_learner&, example& ec)
78 {
// By-value copy of the evaluation label.
79  CB_EVAL::label ld = ec.l.cb_eval;
80 
81  cb_to_cs& c = data.cbcs;
// NOTE(review): listing line 82 is missing from this extract; c.known_cost is presumably assigned there -
// confirm against the original file.
83  gen_cs_example<true>(c, ec, ld.event, data.cb_cs_ld);
84 
// NOTE(review): these partial predictions are written into the local copy `ld`, which is not stored back
// into ec anywhere in the visible code - confirm whether that is intended.
85  for (size_t i = 0; i < ld.event.costs.size(); i++)
86  ld.event.costs[i].partial_prediction = data.cb_cs_ld.costs[i].partial_prediction;
87 
// The prediction is the action carried by the evaluation label itself.
88  ec.pred.multiclass = ec.l.cb_eval.action;
89 }
90 
// Reports one processed example: updates the shared statistics in all.sd, writes the multiclass prediction to
// every final prediction sink, optionally writes the per-action partial predictions to the raw prediction
// output, and prints the progress line.
//   all  - global state (statistics, sinks, print callbacks)
//   data - reduction state; data.cbcs is bound to `c`
//   ec   - the example whose prediction, tag and feature count are reported
//   ld   - the cb label to report on
91 void output_example(vw& all, cb& data, example& ec, CB::label& ld)
92 {
93  float loss = 0.;
94 
95  cb_to_cs& c = data.cbcs;
96  if (!CB::cb_label.test_label(&ld))
// NOTE(review): listing line 97 - the body of the `if` above - is missing from this extract. It is presumably
// where `loss` is computed for labeled examples (the symbol index mentions get_cost_estimate). As extracted,
// the `if` misleadingly appears to govern the update call below - confirm against the original file.
98 
// The example is counted as labeled when test_label reports false; it is added with weight 1.
99  all.sd->update(ec.test_only, !CB::cb_label.test_label(&ld), loss, 1.f, ec.num_features);
100 
101  for (int sink : all.final_prediction_sink) all.print(sink, (float)ec.pred.multiclass, 0, ec.tag);
102 
// When a raw prediction output is configured, emit one space-separated "action:partial_prediction" pair per
// cost entry of the label.
103  if (all.raw_prediction > 0)
104  {
105  std::stringstream outputStringStream;
106  for (unsigned int i = 0; i < ld.costs.size(); i++)
107  {
108  cb_class cl = ld.costs[i];
109  if (i > 0)
110  outputStringStream << ' ';
111  outputStringStream << cl.action << ':' << cl.partial_prediction;
112  }
113  all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag);
114  }
115 
116  print_update(all, CB::cb_label.test_label(&ld), ec, nullptr, false);
117 }
118 
// Finish hook for the non-evaluation path: reports the example using its cb label (ec.l.cb), then hands the
// example to VW::finish_example.
119 void finish_example(vw& all, cb& c, example& ec)
120 {
121  output_example(all, c, ec, ec.l.cb);
122  VW::finish_example(all, ec);
123 }
124 
// Finish hook for the evaluation path: reports the example using the cb label nested in its evaluation label
// (ec.l.cb_eval.event), then hands the example to VW::finish_example.
125 void eval_finish_example(vw& all, cb& c, example& ec)
126 {
127  output_example(all, c, ec, ec.l.cb_eval.event);
128  VW::finish_example(all, ec);
129 }
130 } // namespace CB_ALGS
131 using namespace CB_ALGS;
// Body of cb_algs_setup: parses the contextual bandit options, selects the cb_type, makes sure a "csoaa" option
// is present, builds the base learner and wraps it in this reduction. Returns nullptr when "cb" was not supplied.
// NOTE(review): the signature line (listing line 132) is missing from this extract; the symbol index at the end
// of the page gives it as `base_learner * cb_algs_setup(options_i &options, vw &all)` - confirm against the
// original file.
133 {
134  auto data = scoped_calloc_or_throw<cb>();
135  std::string type_string = "dr";
136  bool eval = false;
137 
138  option_group_definition new_options("Contextual Bandit Options");
139  new_options
140  .add(make_option("cb", data->cbcs.num_actions).keep().help("Use contextual bandit learning with <k> costs"))
141  .add(make_option("cb_type", type_string).keep().help("contextual bandit method to use in {ips,dm,dr}"))
142  .add(make_option("eval", eval).help("Evaluate a policy rather than optimizing."));
143  options.add_and_parse(new_options);
144 
// This reduction is only enabled when "cb" was given.
145  if (!options.was_supplied("cb"))
146  return nullptr;
147 
148  // Ensure serialization of this option in all cases.
// The default value ("dr") is inserted explicitly and the option group is parsed a second time.
149  if (!options.was_supplied("cb_type"))
150  {
151  options.insert("cb_type", type_string);
152  options.add_and_parse(new_options);
153  }
154 
155  cb_to_cs& c = data->cbcs;
156 
// Map the cb_type string to its constant. problem_multiplier is later passed to init_learner: 2 for "dr"
// (and for the fallback), 1 for "dm" and "ips".
157  size_t problem_multiplier = 2; // default for DR
158  if (type_string.compare("dr") == 0)
159  c.cb_type = CB_TYPE_DR;
160  else if (type_string.compare("dm") == 0)
161  {
// The direct method is rejected in evaluation mode.
162  if (eval)
163  THROW("direct method can not be used for evaluation --- it is biased.");
164  c.cb_type = CB_TYPE_DM;
165  problem_multiplier = 1;
166  }
167  else if (type_string.compare("ips") == 0)
168  {
169  c.cb_type = CB_TYPE_IPS;
170  problem_multiplier = 1;
171  }
172  else
173  {
// An unrecognised value only produces a warning and falls back to DR. type_string itself is left unchanged.
174  std::cerr << "warning: cb_type must be in {'ips','dm','dr'}; resetting to dr." << std::endl;
175  c.cb_type = CB_TYPE_DR;
176  }
177 
// If the user did not supply "csoaa", insert it with the number of actions as its value before the base
// learner is built.
178  if (!options.was_supplied("csoaa"))
179  {
180  std::stringstream ss;
181  ss << data->cbcs.num_actions;
182  options.insert("csoaa", ss.str());
183  }
184 
185  auto base = as_singleline(setup_base(options, all));
// Select the label parser matching the mode.
// NOTE(review): listing lines 189 and 194 are missing from this extract (one in each branch) - confirm against
// the original file.
186  if (eval)
187  {
188  all.p->lp = CB_EVAL::cb_eval;
190  }
191  else
192  {
193  all.p->lp = CB::cb_label;
195  }
196 
// NOTE(review): listing line 197 is missing from this extract; `l` is presumably declared there. Listing lines
// 201 and 207 are missing as well; they presumably register the finish-example hooks (the symbol index mentions
// set_finish_example) - confirm against the original file.
198  if (eval)
199  {
200  l = &init_learner(data, base, learn_eval, predict_eval, problem_multiplier, prediction_type::multiclass);
202  }
203  else
204  {
205  l = &init_learner(
206  data, base, predict_or_learn<true>, predict_or_learn<false>, problem_multiplier, prediction_type::multiclass);
208  }
209  c.scorer = all.scorer;
210 
211  return make_base(*l);
212 }
label_parser cb_eval
Definition: cb.cc:292
v_array< char > tag
Definition: example.h:63
int raw_prediction
Definition: global_data.h:519
uint32_t multiclass
Definition: example.h:49
COST_SENSITIVE::label pred_scores
void predict(E &ec, size_t i=0)
Definition: learner.h:169
#define CB_TYPE_IPS
Definition: cb_algs.h:15
void finish_example(vw &all, cb &c, example &ec)
Definition: cb_algs.cc:119
label_parser cs_label
void(* delete_label)(void *)
Definition: label_parser.h:16
bool know_all_cost_example(CB::label &ld)
Definition: cb_algs.cc:33
CB::label cb
Definition: example.h:31
label_type::label_type_t label_type
Definition: global_data.h:550
COST_SENSITIVE::label cb_cs_ld
Definition: cb_algs.cc:24
bool(* test_label)(void *)
Definition: label_parser.h:22
v_array< int > final_prediction_sink
Definition: global_data.h:518
v_array< cb_class > costs
Definition: cb.h:27
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
CB_EVAL::label cb_eval
Definition: example.h:33
#define CB_TYPE_DM
Definition: cb_algs.h:14
virtual void add_and_parse(const option_group_definition &group)=0
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
void learn_eval(cb &data, single_learner &, example &ec)
Definition: cb_algs.cc:77
float get_cost_estimate(CB::cb_class *observation, uint32_t action, float offset=0.)
Definition: cb_algs.h:58
#define CB_TYPE_DR
Definition: cb_algs.h:13
CB::label event
Definition: cb.h:42
CB::cb_class get_observed_cost(multi_ex &examples)
Definition: cb_adf.cc:99
parser * p
Definition: global_data.h:377
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
CB::cb_class * known_cost
Definition: cb.cc:15
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
uint32_t action
Definition: cb.h:18
float partial_prediction
Definition: cb.h:21
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
base_learner * cb_algs_setup(options_i &options, vw &all)
Definition: cb_algs.cc:132
shared_data * sd
Definition: global_data.h:375
size_t num_features
Definition: example.h:67
virtual bool was_supplied(const std::string &key)=0
uint32_t num_actions
void predict_eval(cb &, single_learner &, example &)
Definition: cb_algs.cc:75
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
void predict_or_learn(cb &data, single_learner &base, example &ec)
Definition: cb_algs.cc:49
LEARNER::single_learner * scorer
void finish_example(vw &, example &)
Definition: parser.cc:881
LEARNER::single_learner * scorer
Definition: global_data.h:384
virtual void insert(const std::string &key, const std::string &value)=0
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
option_group_definition & add(T &&op)
Definition: options.h:90
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
label_parser cb_label
Definition: cb.cc:167
polylabel l
Definition: example.h:57
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
Definition: cb.h:25
void eval_finish_example(vw &all, cb &c, example &ec)
Definition: cb_algs.cc:125
void output_example(vw &all, cb &data, example &ec, CB::label &ld)
Definition: cb_algs.cc:91
bool test_label(void *v)
Definition: simple_label.cc:70
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
cb_to_cs cbcs
Definition: cb_algs.cc:23
v_array< wclass > costs
#define THROW(args)
Definition: vw_exception.h:181
constexpr uint64_t c
Definition: rand48.cc:12
void(* print)(int, float, float, v_array< char >)
Definition: global_data.h:521
label_parser lp
Definition: parser.h:102
bool test_only
Definition: example.h:76