Vowpal Wabbit
search.h
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD
4 license as described in the file LICENSE.
5 */
6 #pragma once
7 #include "global_data.h"
8 
9 #define cdbg std::clog
10 #undef cdbg
11 #define cdbg \
12  if (1) \
13  { \
14  } \
15  else \
16  std::clog
17 // comment out the #undef and the second #define of cdbg above (leaving only "#define cdbg std::clog") if you want loads of debug output :)
18 
19 typedef uint32_t action;
20 typedef uint32_t ptag;
21 
22 namespace Search
23 {
24 struct search_private;
25 struct search_task;
26 
28 
29 struct search;
30 
31 class BaseTask
32 {
33  public:
34  BaseTask(search* _sch, multi_ex& _ec) : sch(_sch), ec(_ec)
35  {
36  _foreach_action = nullptr;
37  _post_prediction = nullptr;
39  _with_output_string = nullptr;
40  _final_run = false;
41  }
42  inline BaseTask& foreach_action(void (*f)(search&, size_t, float, action, bool, float))
43  {
45  return *this;
46  }
47  inline BaseTask& post_prediction(void (*f)(search&, size_t, action, float))
48  {
50  return *this;
51  }
52  inline BaseTask& maybe_override_prediction(bool (*f)(search&, size_t, action&, float&))
53  {
55  return *this;
56  }
57  inline BaseTask& with_output_string(void (*f)(search&, std::stringstream&))
58  {
60  return *this;
61  }
62  inline BaseTask& final_run()
63  {
64  _final_run = true;
65  return *this;
66  }
67 
68  void Run();
69 
70  // data
73  bool _final_run;
74  void (*_foreach_action)(search&, size_t, float, action, bool, float);
75  void (*_post_prediction)(search&, size_t, action, float);
76  bool (*_maybe_override_prediction)(search&, size_t, action&, float&);
77  void (*_with_output_string)(search&, std::stringstream&);
78 };
79 
80 struct search
81 { // INTERFACE
82  // for managing task-specific data that you want on the heap:
83  template <class T>
84  void set_task_data(T* data)
85  {
86  task_data = data;
87  }
88  template <class T>
90  {
91  return (T*)task_data;
92  }
93 
94  // for managing metatask-specific data
95  template <class T>
96  void set_metatask_data(T* data)
97  {
98  metatask_data = data;
99  }
100  template <class T>
102  {
103  return (T*)metatask_data;
104  }
105 
106  // for setting programmatic options during initialization
107  // this should be an or ("|") of AUTO_CONDITION_FEATURES, etc.
108  void set_options(uint32_t opts);
109 
110  // change the default label parser, but you _must_ tell me how
111  // to detect test examples!
112  void set_label_parser(label_parser& lp, bool (*is_test)(polylabel&));
113 
114  // for explicitly declaring a loss incrementally
115  void loss(float incr_loss);
116 
117  // make a prediction on an example. returns the predicted action.
118  // arguments:
119  // ec the example (features) on which to make a prediction
120  // my_tag a tag for this prediction, so that you can explicitly
121  // state, for future predictions, which ones depend
122  // explicitly or implicitly on this prediction
123  // oracle_actions an array of actions that the oracle would take
124  // nullptr => the oracle doesn't know (is random!)
125  // oracle_actions_cnt the length of the previous array, or 0 if it's nullptr
126  // condition_on an array of previous (or future) predictions on which
127  // this prediction depends. the semantics of conditioning
128  // is that IF the predictions for all the tags in
129  // condition_on were the same, then the prediction for
130  // _this_ example will also be the same. i.e., same
131  // features, etc. (also assuming same policy). if
132  // AUTO_CONDITION_FEATURES is on, then we will automatically
133  // add features to ec based on what you're conditioning on.
134  // nullptr => independent prediction
135  // condition_on_names a C string (const char*) containing the list of names of features you're
136  // conditioning on. used explicitly for auditing, implicitly
137  // for keeping tags separated. also, strlen(condition_on_names)
138  // tells us how long condition_on is
139  // allowed_actions an array of actions that are allowed at this step, or
140  // nullptr if everything is allowed
141  // allowed_actions_cnt the length of allowed_actions (0 if allowed_actions is null)
142  // allowed_actions_cost if you can precompute the cost-under-rollout-by-ref for each
143  // allowed action, and the underlying algorithm can use this
144  // (i.e., rollout=none or rollout=mix_per_roll and we're on
145  // a rollout-by-ref), then fill this in and rollouts will be
146  // avoided. note: if you provide allowed_actions_cost,
147  // then oracle_actions will be ignored (might as well pass
148  // nullptr). if allowed_actions
149  // is a nullptr, then allowed_actions_cost should be a vector
150  // of length equal to the total number of actions ("A"); otherwise
151  // it should be of length allowed_actions_cnt. only valid
152  // if ACTION_COSTS is specified as an option.
153  // learner_id the id for the underlying learner to use (via set_num_learners)
154  action predict(example& ec, ptag my_tag, const action* oracle_actions, size_t oracle_actions_cnt = 1,
155  const ptag* condition_on = nullptr,
156  const char* condition_on_names = nullptr // strlen(condition_on_names) should == |condition_on|
157  ,
158  const action* allowed_actions = nullptr, size_t allowed_actions_cnt = 0,
159  const float* allowed_actions_cost = nullptr, size_t learner_id = 0, float weight = 0.);
160 
161  // make an LDF prediction on a list of examples. arguments are identical to predict(...)
162  // with the following exceptions:
163  // * ecs/ec_cnt replace ec. ecs is the list of examples that make up a single
164  // LDF example, and ec_cnt is its length
165  // * there are no more "allowed_actions" because that is implicit in the LDF
166  // example structure. additionally, allowed_actions_cost should be stored
167  // in the label structure for ecs (if ACTION_COSTS is set as an option)
168  action predictLDF(example* ecs, size_t ec_cnt, ptag my_tag, const action* oracle_actions,
169  size_t oracle_actions_cnt = 1, const ptag* condition_on = nullptr, const char* condition_on_names = nullptr,
170  size_t learner_id = 0, float weight = 0.);
171 
172  // sometimes during training, a call to "predict" doesn't
173  // actually use the example you pass (*), and for efficiency you
174  // might want to forgo the construction of examples in those
175  // cases. if a call to predictNeedsExample() returns true,
176  // then any subsequent call to predict should be sure to include
177  // correctly processed examples. if it returns false, you can pass
178  // anything to the next call to predict.
179  //
180  // (*) the slight exception is for predictLDF. in this case, we
181  // always need to provide some examples so that we know which
182  // actions are possible. in LDF mode, if predictNeedsExample()
183  // returns false, then it's okay to just provide the labels in
184  // your subsequent call to predictLDF(), and skip the feature
185  // values.
186  bool predictNeedsExample();
187 
188  // get the value specified by --search_history_length
189  uint32_t get_history_length();
190 
191  // check if the user declared ldf mode
192  bool is_ldf();
193 
194  // where you should write output
195  std::stringstream& output();
196 
197  // set the number of learners
198  void set_num_learners(size_t num_learners);
199 
200  // get the action sequence from the test run (only run if test_only or -t or...)
201  void get_test_action_sequence(std::vector<action>&);
202 
203  // get feature index mask
204  uint64_t get_mask();
205 
206  // get stride_shift
207  size_t get_stride_shift();
208 
209  // pretty print a label
210  std::string pretty_label(action a);
211 
212  // for meta-tasks:
213  BaseTask base_task(multi_ex& ec) { return BaseTask(this, ec); }
214 
215  // internal data that you don't get to see!
217  void* task_data; // your task data!
218  void* metatask_data; // your metatask data!
219  const char* task_name;
220  const char* metatask_name;
221 
222  vw& get_vw_pointer_unsafe(); // although you should rarely need this, sometimes you need a pointer to the vw data
223  // structure :(
224  void set_force_oracle(bool force); // if the library wants to force search to use the oracle, set this to true
225  search();
226  ~search();
227 };
228 
229 // for defining new tasks, you must fill out a search_task
231 { // required
232  const char* task_name;
233  void (*run)(search&, multi_ex&);
234 
235  // optional
237  void (*finish)(search&);
238  void (*run_setup)(search&, multi_ex&);
240 };
241 
243 { // required
244  const char* metatask_name;
245  void (*run)(search&, multi_ex&);
246 
247  // optional
249  void (*finish)(search&);
250  void (*run_setup)(search&, multi_ex&);
252 };
253 
254 // to make calls to "predict" (and "predictLDF") cleaner when you
255 // want to use crazy combinations of arguments
257 {
258  public:
259  predictor(search& sch, ptag my_tag);
260  ~predictor();
261 
262  // tell the predictor what to use as input. a single example input
263  // means non-LDF mode; an array of inputs means LDF mode
264  predictor& set_input(example& input_example);
265  predictor& set_input(example* input_example, size_t input_length); // if you're lucky and have an array of examples
266 
267  // the following is mostly to make life manageable for the Python interface
268  void set_input_length(size_t input_length); // declare that we have an input_length-long LDF example
269  void set_input_at(size_t posn, example& input_example); // set the corresponding input (*after* set_input_length)
270 
271  // different ways of adding to the list of oracle actions. you can
272  // either add_ or set_; setting erases previous actions. these
273  // functions attempt to allocate as little memory as possible, so if
274  // you pass a v_array or an action*, unless you later add something
275  // else, we'll just store a pointer to your memory. this means that
276  // you probably shouldn't change the data there, or free that pointer,
277  // between calling add/set_oracle and calling predict()
278  predictor& erase_oracles();
279 
280  predictor& reset();
281 
282  predictor& add_oracle(action a);
283  predictor& add_oracle(action* a, size_t action_count);
284  predictor& add_oracle(v_array<action>& a);
285 
286  predictor& set_oracle(action a);
287  predictor& set_oracle(action* a, size_t action_count);
288  predictor& set_oracle(v_array<action>& a);
289 
290  predictor& set_weight(float w);
291 
292  // same as add/set_oracle but for allowed actions
293  predictor& erase_alloweds();
294 
295  predictor& add_allowed(action a);
296  predictor& add_allowed(action* a, size_t action_count);
297  predictor& add_allowed(v_array<action>& a);
298 
299  predictor& set_allowed(action a);
300  predictor& set_allowed(action* a, size_t action_count);
301  predictor& set_allowed(v_array<action>& a);
302 
303  // set/add allowed but with per-action costs specified
304  predictor& add_allowed(action a, float cost);
305  predictor& add_allowed(action* a, float* costs, size_t action_count);
306  predictor& add_allowed(v_array<std::pair<action, float> >& a);
307  predictor& add_allowed(std::vector<std::pair<action, float> >& a);
308 
309  predictor& set_allowed(action a, float cost);
310  predictor& set_allowed(action* a, float* costs, size_t action_count);
311  predictor& set_allowed(v_array<std::pair<action, float> >& a);
312  predictor& set_allowed(std::vector<std::pair<action, float> >& a);
313 
314  // add a tag to condition on with a name, or set the conditioning
315  // variables (i.e., erase previous ones)
316  predictor& add_condition(ptag tag, char name);
317  predictor& set_condition(ptag tag, char name);
318  predictor& add_condition_range(
319  ptag hi, ptag count, char name0); // add (hi,name0), (hi-1,name0+1), ..., (hi-count,name0+count)
320  predictor& set_condition_range(
321  ptag hi, ptag count, char name0); // set (hi,name0), (hi-1,name0+1), ..., (hi-count,name0+count)
322 
323  // set learner id
324  predictor& set_learner_id(size_t id);
325 
326  // change the current tag
327  predictor& set_tag(ptag tag);
328 
329  // make a prediction
330  action predict();
331 
332  private:
333  bool is_ldf;
336  size_t ec_cnt;
338  float weight;
340  bool oracle_is_pointer; // if we're pointing to your memory TRUE; if it's our own memory FALSE
344  bool allowed_is_pointer; // if we're pointing to your memory TRUE; if it's our own memory FALSE
346  bool allowed_cost_is_pointer; // if we're pointing to your memory TRUE; if it's our own memory FALSE
347  size_t learner_id;
349 
350  template <class T>
351  void make_new_pointer(v_array<T>& A, size_t new_size);
352  template <class T>
353  predictor& add_to(v_array<T>& A, bool& A_is_ptr, T a, bool clear_first);
354  template <class T>
355  predictor& add_to(v_array<T>& A, bool& A_is_ptr, T* a, size_t count, bool clear_first);
356  void free_ec();
357 
358  // prevent the user from doing something stupid :) ... ugh needed to turn this off for python :(
359  // predictor(const predictor&P);
360  // predictor&operator=(const predictor&P);
361 };
362 
363 // some helper functions you might find helpful
364 /*template<class T> void check_option(T& ret, vw&all, po::variables_map& vm, const char* opt_name, bool
365 default_to_cmdline, bool(*equal)(T,T), const char* mismatch_error_string, const char* required_error_string) { if
366 (vm.count(opt_name)) { ret = vm[opt_name].as<T>(); *all.args_n_opts.file_options << " --" << opt_name << " " << ret;
367  }
368  else if (strlen(required_error_string)>0)
369  { std::cerr << required_error_string << std::endl;
370  if (! vm.count("help"))
371  THROW(required_error_string);
372  }
373  }*/
374 
375 // void check_option(bool& ret, vw&all, po::variables_map& vm, const char* opt_name, bool default_to_cmdline, const
376 // char* mismatch_error_string);
377 bool string_equal(std::string a, std::string b);
378 bool float_equal(float a, float b);
379 bool uint32_equal(uint32_t a, uint32_t b);
380 bool size_equal(size_t a, size_t b);
381 
382 // our interface within VW
384 } // namespace Search
size_t ec_cnt
Definition: search.h:336
search * sch
Definition: search.h:71
T * get_metatask_data()
Definition: search.h:101
v_array< char > condition_on_names
Definition: search.h:342
v_array< float > allowed_actions_cost
Definition: search.h:345
BaseTask(search *_sch, multi_ex &_ec)
Definition: search.h:34
Definition: search.cc:33
void run_takedown(Search::search &sch, multi_ex &)
size_t learner_id
Definition: search.h:347
base_learner * setup(options_i &options, vw &all)
Definition: search.cc:2671
search_private * priv
Definition: search.h:216
v_array< ptag > condition_on_tags
Definition: search.h:341
uint32_t action
Definition: search.h:19
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
const char * metatask_name
Definition: search.h:244
bool allowed_is_pointer
Definition: search.h:344
bool _final_run
Definition: search.h:73
bool float_equal(float a, float b)
bool uint32_equal(uint32_t a, uint32_t b)
void set_metatask_data(T *data)
Definition: search.h:96
T * get_task_data()
Definition: search.h:89
uint32_t ACTION_COSTS
Definition: search.cc:50
BaseTask & maybe_override_prediction(bool(*f)(search &, size_t, action &, float &))
Definition: search.h:52
void(* _foreach_action)(search &, size_t, float, action, bool, float)
Definition: search.h:74
BaseTask & post_prediction(void(*f)(search &, size_t, action, float))
Definition: search.h:47
multi_ex & ec
Definition: search.h:72
bool allowed_cost_is_pointer
Definition: search.h:346
void(* _post_prediction)(search &, size_t, action, float)
Definition: search.h:75
void run_setup(Search::search &sch, multi_ex &)
uint32_t NO_CACHING
Definition: search.cc:49
uint32_t AUTO_CONDITION_FEATURES
Definition: search.cc:49
void(* _with_output_string)(search &, std::stringstream &)
Definition: search.h:77
example * ec
Definition: search.h:335
BaseTask & foreach_action(void(*f)(search &, size_t, float, action, bool, float))
Definition: search.h:42
const char * task_name
Definition: search.h:232
BaseTask & with_output_string(void(*f)(search &, std::stringstream &))
Definition: search.h:57
uint32_t IS_LDF
Definition: search.cc:49
search & sch
Definition: search.h:348
void set_weight(vw &all, uint32_t index, uint32_t offset, float value)
Definition: vw.h:182
vw * initialize(options_i &options, io_buf *model, bool skipModelLoad, trace_message_t trace_listener, void *trace_context)
Definition: parse_args.cc:1654
v_array< action > oracle_actions
Definition: search.h:339
void * task_data
Definition: search.h:217
float weight
std::vector< example * > multi_ex
Definition: example.h:122
uint32_t AUTO_HAMMING_LOSS
Definition: search.cc:49
constexpr uint64_t a
Definition: rand48.cc:11
void set_task_data(T *data)
Definition: search.h:84
BaseTask base_task(multi_ex &ec)
Definition: search.h:213
bool oracle_is_pointer
Definition: search.h:340
void finish(audit_regressor_data &dat)
void * metatask_data
Definition: search.h:218
void predict(bfgs &b, base_learner &, example &ec)
Definition: bfgs.cc:956
bool size_equal(size_t a, size_t b)
uint32_t ptag
Definition: search.h:20
uint32_t EXAMPLES_DONT_CHANGE
Definition: search.cc:49
bool(* _maybe_override_prediction)(search &, size_t, action &, float &)
Definition: search.h:76
v_array< action > allowed_actions
Definition: search.h:343
const char * metatask_name
Definition: search.h:220
float f
Definition: cache.cc:40
const char * task_name
Definition: search.h:219
bool string_equal(std::string a, std::string b)
BaseTask & final_run()
Definition: search.h:62
void run(Search::search &sch, multi_ex &ec)