Vowpal Wabbit
search.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 #include <float.h>
7 #include <string.h>
8 #include <math.h>
9 #include <memory>
10 #include "vw.h"
11 #include "rand48.h"
12 #include "reductions.h"
13 #include "gd.h" // for GD::foreach_feature
14 #include "search_sequencetask.h"
15 #include "search_multiclasstask.h"
16 #include "search_dep_parser.h"
18 #include "search_hooktask.h"
19 #include "search_graph.h"
20 #include "search_meta.h"
21 #include "csoaa.h"
22 #include "active.h"
23 #include "label_dictionary.h"
24 #include "vw_exception.h"
25 
26 using namespace LEARNER;
27 using namespace VW::config;
28 namespace CS = COST_SENSITIVE;
29 namespace MC = MULTICLASS;
30 
31 using std::endl;
32 
33 namespace Search
34 {
37  &EntityRelationTask::task, &HookTask::task, &GraphTask::task, nullptr}; // must nullptr terminate!
38 
40  &DebugMT::metatask, &SelectiveBranchingMT::metatask, nullptr}; // must nullptr terminate!
41 
// Compile-time debug switches for progress reporting.
constexpr bool PRINT_UPDATE_EVERY_EXAMPLE{false};  // force a progress line after every example
constexpr bool PRINT_UPDATE_EVERY_PASS{false};     // force a progress line whenever a new pass is hit
constexpr bool PRINT_CLOCK_TIME{false};            // append elapsed clock seconds to each progress line

// Names of the feature spaces that hold neighbor features and search-conditioning features
// (presumably used for audit output elsewhere in this file).
std::string neighbor_feature_space{"neighbor"};
std::string condition_feature_space{"search_condition"};
48 
52 {
58 };
60 {
66 };
67 
68 // a data structure to hold conditioning information
69 struct prediction
70 {
71  ptag me; // the id of the current prediction (the one being memoized)
72  size_t cnt; // how many variables are we conditioning on?
73  ptag* tags; // which variables are they?
74  action* acts; // and which actions were taken at each?
75  uint32_t hash; // a hash of the above
76 };
77 
78 // parameters for auto-conditioning
80 {
81  size_t max_bias_ngram_length; // add a "bias" feature for each ngram up to and including this length. eg., if it's 1,
82  // then you get a single feature for each conditional
83  size_t max_quad_ngram_length; // add bias *times* input features for each ngram up to and including this length
84  float feature_value; // how much weight should the conditional features get?
85  bool use_passthrough_repr; // should we ask lower-level reductions for their internal state?
86 };
87 
89 {
90  action a; // the action
91  float s; // the predicted cost of this action
92  // v_array<feature> repr;
93  scored_action(action _a = (action)-1, float _s = 0) : a(_a), s(_s) {} // , repr(v_init<feature>()) {}
94  // scored_action(action _a, float _s, v_array<feature>& _repr) : a(_a), s(_s), repr(_repr) {}
95  // scored_action() { a = (action)-1; s = 0.; }
96 };
97 std::ostream& operator<<(std::ostream& os, const scored_action& x)
98 {
99  os << x.a << ':' << x.s;
100  return os;
101 }
102 
104 {
107  action_repr(action _a, features* _repr) : a(_a)
108  {
109  if (_repr != nullptr)
110  {
111  repr = new features();
112  repr->deep_copy_from(*_repr);
113  }
114  else
115  repr = nullptr;
116  }
117  action_repr(action _a) : a(_a), repr(nullptr) {}
118 };
119 
121 {
122  float min_cost;
124  bool is_opt;
125  float cost;
126  action_cache(float _min_cost, action _k, bool _is_opt, float _cost)
127  : min_cost(_min_cost), k(_k), is_opt(_is_opt), cost(_cost)
128  {
129  }
130 };
131 std::ostream& operator<<(std::ostream& os, const action_cache& x)
132 {
133  os << x.k << ':' << x.cost;
134  if (x.is_opt)
135  os << '*';
136  return os;
137 }
138 
140 {
141  vw* all;
142  std::shared_ptr<rand_state> _random_state;
143 
144  uint64_t offset;
145  bool auto_condition_features; // do you want us to automatically add conditioning features?
146  bool auto_hamming_loss; // if you're just optimizing hamming loss, we can do it for you!
147  bool examples_dont_change; // set to true if you don't do any internal example munging
148  bool is_ldf; // user declared ldf
149  bool use_action_costs; // task promises to define per-action rollout-by-ref costs
150 
151  v_array<int32_t> neighbor_features; // ugly encoding of neighbor feature requirements
152  auto_condition_settings acset; // settings for auto-conditioning
153  size_t history_length; // value of --search_history_length, used by some tasks, default 1
154 
155  size_t A; // total number of actions, [1..A]; 0 means ldf
156  size_t num_learners; // total number of learners;
157  bool cb_learner; // do contextual bandit learning on action (was "! rollout_all_actions" which was confusing)
158  SearchState state; // current state of learning
159  size_t learn_learner_id; // we allow user to use different learners for different states
160  int mix_per_roll_policy; // for MIX_PER_ROLL, we need to choose a policy to use; this is where it's stored (-2 means
161  // "not selected yet")
162  bool no_caching; // turn off caching
163  size_t rollout_num_steps; // how many calls of "loss" before we stop really predicting on rollouts and switch to
164  // oracle (0 means "infinite")
165  bool linear_ordering; // insist that examples are generated in linear order (rather that the default hoopla
166  // permutation)
167  bool (*label_is_test)(polylabel&); // tell me if the label data from an example is test
168 
169  size_t t; // current search step
170  size_t T; // length of root trajectory
171  v_array<example> learn_ec_copy; // copy of example(s) at learn_t
172  example* learn_ec_ref; // reference to example at learn_t, when there's no example munging
173  size_t learn_ec_ref_cnt; // how many are there (for LDF mode only; otherwise 1)
174  v_array<ptag> learn_condition_on; // a copy of the tags used for conditioning at the training position
176  v_array<char> learn_condition_on_names; // the names of the actions
177  v_array<action> learn_allowed_actions; // which actions were allowed at training time?
178  v_array<action_repr> ptag_to_action; // tag to action mapping for conditioning
179  std::vector<action> test_action_sequence; // if test-mode was run, what was the corresponding action sequence; it's a
180  // vector cuz we might expose it to the library
181  action learn_oracle_action; // store an oracle action for debugging purposes
183 
185 
186  size_t loss_declared_cnt; // how many times did run declare any loss (implicitly or explicitly)?
187  v_array<scored_action> train_trajectory; // the training trajectory
188  size_t learn_t; // what time step are we learning on?
189  size_t learn_a_idx; // what action index are we trying?
190  bool done_with_all_actions; // set to true when there are no more learn_a_idx to go
191 
192  float test_loss; // loss incurred when run INIT_TEST
193  float learn_loss; // loss incurred when run LEARN
194  float train_loss; // loss incurred when run INIT_TRAIN
195 
196  bool hit_new_pass; // have we hit a new pass?
197  bool force_oracle; // insist on using the oracle to make predictions
198  float perturb_oracle; // with this probability, choose a random action instead of oracle action
199 
200  size_t num_calls_to_run, num_calls_to_run_previous, save_every_k_runs;
201 
202  // if we're printing to stderr we need to remember if we've printed the header yet
203  // (i.e., we do this if we're driving)
205 
206  // various strings for different search states
208  std::stringstream* pred_string;
209  std::stringstream* truth_string;
210  std::stringstream* bad_string_stream;
211 
212  // parameters controlling interpolation
213  float beta; // interpolation rate
214  float alpha; // parameter used to adapt beta for dagger (see above comment), should be in (0,1)
215 
218  float subsample_timesteps; // train at every time step or just a (random) subset?
219  bool xv; // train three separate policies -- two for providing examples to the other and a third training on the
220  // union (which will be used at test time -- TODO)
221 
222  bool allow_current_policy; // should the current policy be used for training? true for dagger
223  bool adaptive_beta; // used to implement dagger-like algorithms. if true, beta = 1-(1-alpha)^n after n updates, and
224  // policy is mixed with oracle as \pi' = (1-beta)\pi^* + beta \pi
225  size_t passes_per_policy; // if we're not in dagger-mode, then we need to know how many passes to train a policy
226 
227  uint32_t current_policy; // what policy are we training right now?
228 
229  // various statistics for reporting
230  size_t num_features;
238 
240 
241  // for foreach_feature temporary storage for conditioning
244  std::stringstream dat_new_feature_audit_ss;
248 
249  // to reduce memory allocation
250  std::string rawOutputString;
251  std::stringstream* rawOutputStringStream;
262 
265 
267 
268  search_task* task; // your task!
269  search_metatask* metatask; // your (optional) metatask
271  size_t meta_t; // the metatask has it's own notion of time. meta_t+t, during a single run, is the way to think about
272  // the "real" decision step but this really only matters for caching purposes
274  memo_foreach_action; // when foreach_action is on, we need to cache TRAIN trajectory actions for LEARN
275 };
276 
277 void free_key(unsigned char* mem, scored_action) { free(mem); } // sa.repr.delete_v(); }
279 {
280  priv.cache_hash_map.iter(free_key);
281  priv.cache_hash_map.clear();
282 }
283 
285 {
286  for (size_t i = 0; i < priv.memo_foreach_action.size(); i++)
287  if (priv.memo_foreach_action[i])
288  {
289  priv.memo_foreach_action[i]->delete_v();
290  delete priv.memo_foreach_action[i];
291  }
292  priv.memo_foreach_action.clear();
293 }
294 
295 search::search() { priv = &calloc_or_throw<search_private>(); }
296 
297 search::~search()
298 {
299  if (this->priv && this->priv->all)
300  {
301  search_private& priv = *this->priv;
302  clear_cache_hash_map(priv);
303 
304  priv._random_state.~shared_ptr<rand_state>();
305  delete priv.truth_string;
306  delete priv.pred_string;
307  delete priv.bad_string_stream;
308  priv.cache_hash_map.~v_hashmap<unsigned char*, scored_action>();
309  priv.rawOutputString.~basic_string();
310  priv.test_action_sequence.~vector<action>();
311  priv.dat_new_feature_audit_ss.~basic_stringstream();
313  priv.timesteps.delete_v();
314  if (priv.cb_learner)
315  priv.learn_losses.cb.costs.delete_v();
316  else
317  priv.learn_losses.cs.costs.delete_v();
318  if (priv.cb_learner)
319  priv.gte_label.cb.costs.delete_v();
320  else
321  priv.gte_label.cs.costs.delete_v();
322 
323  priv.condition_on_actions.delete_v();
325  priv.ldf_test_label.costs.delete_v();
326  priv.last_action_repr.delete_v();
328  for (size_t i = 0; i < priv.active_known.size(); i++) priv.active_known[i].delete_v();
329  priv.active_known.delete_v();
330 
331  if (priv.cb_learner)
332  priv.allowed_actions_cache->cb.costs.delete_v();
333  else
334  priv.allowed_actions_cache->cs.costs.delete_v();
335 
336  priv.train_trajectory.delete_v();
337  for (Search::action_repr& ar : priv.ptag_to_action)
338  {
339  if (ar.repr != nullptr)
340  {
341  ar.repr->delete_v();
342  delete ar.repr;
343  cdbg << "delete_v" << endl;
344  }
345  }
346  priv.ptag_to_action.delete_v();
348  priv.memo_foreach_action.delete_v();
349 
350  // destroy copied examples if we needed them
351  if (!priv.examples_dont_change)
352  {
355  priv.learn_ec_copy.delete_v();
356  }
359  priv.learn_condition_on_act.delete_v();
360 
361  free(priv.allowed_actions_cache);
362  delete priv.rawOutputStringStream;
363  }
364  free(this->priv);
365 }
366 
// Feature-space name and hashing constant associated with conditional features.
// Neither is referenced in the visible part of this file; their use is elsewhere.
std::string audit_feature_space{"conditional"};
uint64_t conditional_constant{8290743};
369 
371 {
372  return (priv.state == INIT_TRAIN) && (priv.metatask) && (priv.metaoverride); // &&
373  // (priv.metaoverride->_foreach_action || priv.metaoverride->_post_prediction);
374 }
375 
376 int random_policy(search_private& priv, bool allow_current, bool allow_optimal, bool advance_prng = true)
377 {
378  if (priv.beta >= 1)
379  {
380  if (allow_current)
381  return (int)priv.current_policy;
382  if (priv.current_policy > 0)
383  return (((int)priv.current_policy) - 1);
384  if (allow_optimal)
385  return -1;
386  std::cerr << "internal error (bug): no valid policies to choose from! defaulting to current" << endl;
387  return (int)priv.current_policy;
388  }
389 
390  int num_valid_policies = (int)priv.current_policy + allow_optimal + allow_current;
391  int pid = -1;
392 
393  if (num_valid_policies == 0)
394  {
395  std::cerr << "internal error (bug): no valid policies to choose from! defaulting to current" << endl;
396  return (int)priv.current_policy;
397  }
398  else if (num_valid_policies == 1)
399  pid = 0;
400  else if (num_valid_policies == 2)
401  pid = (advance_prng ? priv._random_state->get_and_update_random() : priv._random_state->get_random()) >= priv.beta;
402  else
403  {
404  // SPEEDUP this up in the case that beta is small!
405  float r = (advance_prng ? priv._random_state->get_and_update_random() : priv._random_state->get_random());
406  pid = 0;
407 
408  if (r > priv.beta)
409  {
410  r -= priv.beta;
411  while ((r > 0) && (pid < num_valid_policies - 1))
412  {
413  pid++;
414  r -= priv.beta * powf(1.f - priv.beta, (float)pid);
415  }
416  }
417  }
418  // figure out which policy pid refers to
419  if (allow_optimal && (pid == num_valid_policies - 1))
420  return -1; // this is the optimal policy
421 
422  pid = (int)priv.current_policy - pid;
423  if (!allow_current)
424  pid--;
425 
426  return pid;
427 }
428 
429 // for two-fold cross validation, we double the number of learners
430 // and send examples to one or the other depending on the xor of
431 // (is_training) and (example_id % 2)
432 int select_learner(search_private& priv, int policy, size_t learner_id, bool is_training, bool is_local)
433 {
434  if (policy < 0)
435  return policy; // optimal policy
436  else
437  {
438  if (priv.xv)
439  {
440  learner_id *= 3;
441  if (!is_local)
442  learner_id += 1 + (size_t)(is_training ^ (priv.all->sd->example_number % 2 == 1));
443  }
444  int p = (int)(policy * priv.num_learners + learner_id);
445  return p;
446  }
447 }
448 
449 bool should_print_update(vw& all, bool hit_new_pass = false)
450 {
451  // uncomment to print out final loss after all examples processed
452  // commented for now so that outputs matches make test
453 
454  if (PRINT_UPDATE_EVERY_EXAMPLE)
455  return true;
456  if (PRINT_UPDATE_EVERY_PASS && hit_new_pass)
457  return true;
458  return (all.sd->weighted_examples() >= all.sd->dump_interval) && !all.quiet && !all.bfgs;
459 }
460 
462 {
463  // basically do should_print_update but check me and the next
464  // example because of off-by-ones
465 
466  if (PRINT_UPDATE_EVERY_EXAMPLE)
467  return true;
468  if (PRINT_UPDATE_EVERY_PASS)
469  return true; // SPEEDUP: make this better
470  return (all.sd->weighted_examples() + 1. >= all.sd->dump_interval) && !all.quiet && !all.bfgs;
471 }
472 
473 bool must_run_test(vw& all, multi_ex& ec, bool is_test_ex)
474 {
475  return (all.final_prediction_sink.size() > 0) || // if we have to produce output, we need to run this
476  might_print_update(all) || // if we have to print and update to stderr
477  (all.raw_prediction > 0) || // we need raw predictions
478  ((!all.vw_is_main) && (is_test_ex)) || // library needs predictions
479  // or:
480  // it's not quiet AND
481  // current_pass == 0
482  // OR holdout is off
483  // OR it's a test example
484  ((!all.quiet || !all.vw_is_main) && // had to disable this because of library mode!
485  (!is_test_ex) &&
486  (all.holdout_set_off || // no holdout
487  ec[0]->test_only || (all.current_pass == 0) // we need error rates for progressive cost
488  ));
489 }
490 
// Divides a by b, yielding 0 (instead of inf/nan) when b is exactly zero.
float safediv(float a, float b)
{
  return (b != 0.f) ? (a / b) : 0.f;
}
498 
// Renders `in` into a fixed-width, NUL-terminated field of exactly max_len characters, for the
// single-line progress table:
//   - newlines and tabs become spaces, and the field is right-padded with spaces;
//   - if `in` is longer than max_len, the last (up to) two characters are replaced by "..".
// `out` must have room for max_len + 1 chars, because the terminator is written at out[max_len].
void to_short_string(const std::string& in, size_t max_len, char* out)
{
  for (size_t i = 0; i < max_len; i++)
  {
    const bool past_end = i >= in.length();
    out[i] = (past_end || (in[i] == '\n') || (in[i] == '\t')) ? ' ' : in[i];
  }

  if (in.length() > max_len)
  {
    // Clamp the start so that max_len < 2 never indexes before out[0].
    const size_t first_dot = (max_len >= 2) ? max_len - 2 : 0;
    for (size_t i = first_dot; i < max_len; i++) out[i] = '.';
  }
  out[max_len] = 0;
}
511 
// Formats a count compactly for the progress table.
// Values up to 9999 are printed as-is. Larger values are truncated to thousands ("k"), millions ("m")
// or billions ("g"), switching unit once the count reaches 10 of the next unit.
std::string number_to_natural(size_t big)
{
  std::stringstream text;
  if (big <= 9999)
    text << big;
  else if (big <= 9999999)
    text << big / 1000 << "k";
  else if (big <= 9999999999)
    text << big / 1000000 << "m";
  else
    text << big / 1000000000 << "g";
  return text.str();
}
526 
528 {
529  vw& all = *priv.all;
530  if (!priv.printed_output_header && !all.quiet)
531  {
532  const char* header_fmt = "%-10s %-10s %8s%24s %22s %5s %5s %7s %7s %7s %-8s\n";
533  fprintf(stderr, header_fmt, "average", "since", "instance", "current true", "current predicted", "cur", "cur",
534  "predic", "cache", "examples", "");
535  if (priv.active_csoaa)
536  fprintf(stderr, header_fmt, "loss", "last", "counter", "output prefix", "output prefix", "pass", "pol", "made",
537  "hits", "gener", "#run");
538  else
539  fprintf(stderr, header_fmt, "loss", "last", "counter", "output prefix", "output prefix", "pass", "pol", "made",
540  "hits", "gener", "beta");
541  std::cerr.precision(5);
542  priv.printed_output_header = true;
543  }
544 
545  if (!should_print_update(all, priv.hit_new_pass))
546  return;
547 
548  char true_label[21];
549  char pred_label[21];
550  to_short_string(priv.truth_string->str(), 20, true_label);
551  to_short_string(priv.pred_string->str(), 20, pred_label);
552 
553  float avg_loss = 0.;
554  float avg_loss_since = 0.;
555  bool use_heldout_loss = (!all.holdout_set_off && all.current_pass >= 1) && (all.sd->weighted_holdout_examples > 0);
556  if (use_heldout_loss)
557  {
558  avg_loss = safediv((float)all.sd->holdout_sum_loss, (float)all.sd->weighted_holdout_examples);
559  avg_loss_since = safediv(
561 
564  }
565  else
566  {
567  avg_loss = safediv((float)all.sd->sum_loss, (float)all.sd->weighted_labeled_examples);
568  avg_loss_since = safediv((float)all.sd->sum_loss_since_last_dump,
570  }
571 
572  auto const& inst_cntr = number_to_natural((size_t)all.sd->example_number);
573  auto const& total_pred = number_to_natural(priv.total_predictions_made);
574  auto const& total_cach = number_to_natural(priv.total_cache_hits);
575  auto const& total_exge = number_to_natural(priv.total_examples_generated);
576 
577  fprintf(stderr, "%-10.6f %-10.6f %8s [%s] [%s] %5d %5d %7s %7s %7s %-8f", avg_loss, avg_loss_since,
578  inst_cntr.c_str(), true_label, pred_label, (int)priv.read_example_last_pass, (int)priv.current_policy,
579  total_pred.c_str(), total_cach.c_str(), total_exge.c_str(),
580  priv.active_csoaa ? priv.num_calls_to_run : priv.beta);
581 
582  if (PRINT_CLOCK_TIME)
583  {
584  size_t num_sec = (size_t)(((float)(clock() - priv.start_clock_time)) / CLOCKS_PER_SEC);
585  std::cerr << " " << num_sec << "sec";
586  }
587 
588  if (use_heldout_loss)
589  fprintf(stderr, " h");
590 
591  fprintf(stderr, "\n");
592  fflush(stderr);
594 }
595 
596 void add_new_feature(search_private& priv, float val, uint64_t idx)
597 {
598  uint64_t mask = priv.all->weights.mask();
599  size_t ss = priv.all->weights.stride_shift();
600 
601  uint64_t idx2 = ((idx & mask) >> ss) & mask;
603  fs.push_back(val * priv.dat_new_feature_value, ((priv.dat_new_feature_idx + idx2) << ss));
604  cdbg << "adding: " << fs.indicies.last() << ':' << fs.values.last() << endl;
605  if (priv.all->audit)
606  {
607  std::stringstream temp;
608  temp << "fid=" << ((idx & mask) >> ss) << "_" << priv.dat_new_feature_audit_ss.str();
610  }
611 }
612 
613 void del_features_in_top_namespace(search_private& /* priv */, example& ec, size_t ns)
614 {
615  if ((ec.indices.size() == 0) || (ec.indices.last() != ns))
616  {
617  return;
618  // if (ec.indices.size() == 0)
619  //{ THROW("internal error (bug): expecting top namespace to be '" << ns << "' but it was empty"); }
620  // else
621  //{ THROW("internal error (bug): expecting top namespace to be '" << ns << "' but it was " <<
622  //(size_t)ec.indices.last()); }
623  }
624  features& fs = ec.feature_space[ns];
625  ec.indices.decr();
626  ec.num_features -= fs.size();
628  fs.clear();
629 }
630 
632 {
633  if (priv.neighbor_features.size() == 0)
634  return;
635 
636  uint32_t stride_shift = priv.all->weights.stride_shift();
637  for (size_t n = 0; n < ec_seq.size(); n++) // iterate over every example in the sequence
638  {
639  example& me = *ec_seq[n];
640  for (size_t n_id = 0; n_id < priv.neighbor_features.size(); n_id++)
641  {
642  int32_t offset = priv.neighbor_features[n_id] >> 24;
643  size_t ns = priv.neighbor_features[n_id] & 0xFF;
644 
645  priv.dat_new_feature_ec = &me;
646  priv.dat_new_feature_value = 1.;
647  priv.dat_new_feature_idx = priv.neighbor_features[n_id] * 13748127;
649  if (priv.all->audit)
650  {
652  priv.dat_new_feature_audit_ss.str("");
653  priv.dat_new_feature_audit_ss << '@' << ((offset > 0) ? '+' : '-') << (char)(abs(offset) + '0');
654  if (ns != ' ')
655  priv.dat_new_feature_audit_ss << (char)ns;
656  }
657 
658  // std::cerr << "n=" << n << " offset=" << offset << endl;
659  if ((offset < 0) && (n < (uint64_t)(-offset))) // add <s> feature
660  add_new_feature(priv, 1., (uint64_t)925871901 << stride_shift);
661  else if (n + offset >= ec_seq.size()) // add </s> feature
662  add_new_feature(priv, 1., (uint64_t)3824917 << stride_shift);
663  else // this is actually a neighbor
664  {
665  example& other = *ec_seq[n + offset];
666  GD::foreach_feature<search_private, add_new_feature>(priv.all, other.feature_space[ns], priv, me.ft_offset);
667  }
668  }
669 
671  size_t sz = fs.size();
672  if ((sz > 0) && (fs.sum_feat_sq > 0.))
673  {
675  me.total_sum_feat_sq += fs.sum_feat_sq;
676  me.num_features += sz;
677  }
678  else
679  fs.clear();
680  }
681 }
682 
684 {
685  if (priv.neighbor_features.size() == 0)
686  return;
687  for (size_t n = 0; n < ec_seq.size(); n++) del_features_in_top_namespace(priv, *ec_seq[n], neighbor_namespace);
688 }
689 
691 {
692  // NOTE: make sure do NOT reset priv.learn_a_idx
693  priv.t = 0;
694  priv.meta_t = 0;
695  priv.loss_declared_cnt = 0;
696  priv.done_with_all_actions = false;
697  priv.test_loss = 0.;
698  priv.learn_loss = 0.;
699  priv.train_loss = 0.;
700  priv.num_features = 0;
701  priv.should_produce_string = false;
702  priv.mix_per_roll_policy = -2;
703  priv.force_setup_ec_ref = false;
704  if (priv.adaptive_beta)
705  {
706  float x = -log1pf(-priv.alpha) * (float)priv.total_examples_generated;
707  static constexpr float log_of_2 = (float)0.6931471805599453;
708  priv.beta = (x <= log_of_2) ? -expm1f(-x) : (1 - expf(-x)); // numerical stability
709  // float priv_beta = 1.f - powf(1.f - priv.alpha, (float)priv.total_examples_generated);
710  // assert( fabs(priv_beta - priv.beta) < 1e-2 );
711  if (priv.beta > 1)
712  priv.beta = 1;
713  }
714  for (Search::action_repr& ar : priv.ptag_to_action)
715  {
716  if (ar.repr != nullptr)
717  {
718  ar.repr->delete_v();
719  delete ar.repr;
720  }
721  }
722  priv.ptag_to_action.clear();
723 
724  if (!priv.cb_learner) // was: if rollout_all_actions
725  {
726  priv._random_state->set_random_state((uint32_t)(priv.read_example_last_id * 147483 + 4831921) * 2147483647);
727  }
728 }
729 
731 {
732  priv.loss_declared_cnt++;
733  switch (priv.state)
734  {
735  case INIT_TEST:
736  priv.test_loss += loss;
737  break;
738  case INIT_TRAIN:
739  priv.train_loss += loss;
740  break;
741  case LEARN:
742  if ((priv.rollout_num_steps == 0) || (priv.loss_declared_cnt <= priv.rollout_num_steps))
743  {
744  priv.learn_loss += loss;
745  cdbg << "priv.learn_loss += " << loss << " (now = " << priv.learn_loss << ")" << endl;
746  }
747  break;
748  default:
749  break; // get rid of the warning about missing cases (danger!)
750  }
751 }
752 
753 template <class T>
754 void cdbg_print_array(std::string str, v_array<T>& A)
755 {
756  cdbg << str << " = [";
757  for (size_t i = 0; i < A.size(); i++) cdbg << " " << A[i];
758  cdbg << " ]" << endl;
759 }
760 template <class T>
761 void cerr_print_array(std::string str, v_array<T>& A)
762 {
763  std::cerr << str << " = [";
764  for (size_t i = 0; i < A.size(); i++) std::cerr << " " << A[i];
765  std::cerr << " ]" << endl;
766 }
767 
768 size_t random(std::shared_ptr<rand_state>& rs, size_t max)
769 {
770  return (size_t)(rs->get_and_update_random() * (float)max);
771 }
// Linear search: true iff one of the first n elements of A equals target.
// A null array contains nothing.
template <class T>
bool array_contains(T target, const T* A, size_t n)
{
  if (A == nullptr)
    return false;
  for (const T *cur = A, *stop = A + n; cur != stop; ++cur)
  {
    if (*cur == target)
      return true;
  }
  return false;
}
782 
783 // priv.learn_condition_on_act or priv.condition_on_actions
784 void add_example_conditioning(search_private& priv, example& ec, size_t condition_on_cnt,
785  const char* condition_on_names, action_repr* condition_on_actions)
786 {
787  if (condition_on_cnt == 0)
788  return;
789 
790  uint64_t extra_offset = 0;
791  if (priv.is_ldf)
792  if (ec.l.cs.costs.size() > 0)
793  extra_offset = 3849017 * ec.l.cs.costs[0].class_index;
794 
795  size_t I = condition_on_cnt;
796  size_t N = std::max(priv.acset.max_bias_ngram_length, priv.acset.max_quad_ngram_length);
797  for (size_t i = 0; i < I; i++) // position in conditioning
798  {
799  uint64_t fid = 71933 + 8491087 * extra_offset;
800  if (priv.all->audit)
801  {
802  priv.dat_new_feature_audit_ss.str("");
803  priv.dat_new_feature_audit_ss.clear();
805  }
806 
807  for (size_t n = 0; n < N; n++) // length of ngram
808  {
809  if (i + n >= I)
810  break; // no more ngrams
811  // we're going to add features for the ngram condition_on_actions[i .. i+N]
812  uint64_t name = condition_on_names[i + n];
813  fid = fid * 328901 + 71933 * ((condition_on_actions[i + n].a + 349101) * (name + 38490137));
814 
815  priv.dat_new_feature_ec = &ec;
819 
820  if (priv.all->audit)
821  {
822  if (n > 0)
823  priv.dat_new_feature_audit_ss << ',';
824  if ((33 <= name) && (name <= 126))
825  priv.dat_new_feature_audit_ss << name;
826  else
827  priv.dat_new_feature_audit_ss << '#' << (int)name;
828  priv.dat_new_feature_audit_ss << '=' << condition_on_actions[i + n].a;
829  }
830 
831  // add the single bias feature
832  if (n < priv.acset.max_bias_ngram_length)
833  add_new_feature(priv, 1., (uint64_t)4398201 << priv.all->weights.stride_shift());
834  // add the quadratic features
835  if (n < priv.acset.max_quad_ngram_length)
836  GD::foreach_feature<search_private, uint64_t, add_new_feature>(*priv.all, ec, priv);
837  }
838  }
839 
840  if (priv.acset.use_passthrough_repr)
841  {
842  cdbg << "BEGIN adding passthrough features" << endl;
843  for (size_t i = 0; i < I; i++)
844  {
845  if (condition_on_actions[i].repr == nullptr)
846  continue;
847  features& fs = *(condition_on_actions[i].repr);
848  char name = condition_on_names[i];
849  for (size_t k = 0; k < fs.size(); k++)
850  if ((fs.values[k] > 1e-10) || (fs.values[k] < -1e-10))
851  {
852  uint64_t fid = 84913 + 48371803 * (extra_offset + 8392817 * name) + 840137 * (4891 + fs.indicies[k]);
853  if (priv.all->audit)
854  {
855  priv.dat_new_feature_audit_ss.str("");
856  priv.dat_new_feature_audit_ss.clear();
857  priv.dat_new_feature_audit_ss << "passthrough_repr_" << i << '_' << k;
858  }
859 
860  priv.dat_new_feature_ec = &ec;
861  priv.dat_new_feature_idx = fid;
863  priv.dat_new_feature_value = fs.values[k];
864  add_new_feature(priv, 1., (uint64_t)4398201 << priv.all->weights.stride_shift());
865  }
866  }
867  cdbg << "END adding passthrough features" << endl;
868  }
869 
871  if ((con_fs.size() > 0) && (con_fs.sum_feat_sq > 0.))
872  {
874  ec.total_sum_feat_sq += con_fs.sum_feat_sq;
875  ec.num_features += con_fs.size();
876  }
877  else
878  con_fs.clear();
879 }
880 
882 {
883  if ((ec.indices.size() > 0) && (ec.indices.last() == conditioning_namespace))
885 }
886 
887 inline size_t cs_get_costs_size(bool isCB, polylabel& ld) { return isCB ? ld.cb.costs.size() : ld.cs.costs.size(); }
888 
889 inline uint32_t cs_get_cost_index(bool isCB, polylabel& ld, size_t k)
890 {
891  return isCB ? ld.cb.costs[k].action : ld.cs.costs[k].class_index;
892 }
893 
894 inline float cs_get_cost_partial_prediction(bool isCB, polylabel& ld, size_t k)
895 {
896  return isCB ? ld.cb.costs[k].partial_prediction : ld.cs.costs[k].partial_prediction;
897 }
898 
899 inline void cs_set_cost_loss(bool isCB, polylabel& ld, size_t k, float val)
900 {
901  if (isCB)
902  ld.cb.costs[k].cost = val;
903  else
904  ld.cs.costs[k].x = val;
905 }
906 
907 inline void cs_costs_erase(bool isCB, polylabel& ld)
908 {
909  if (isCB)
910  ld.cb.costs.clear();
911  else
912  ld.cs.costs.clear();
913 }
914 
915 inline void cs_costs_resize(bool isCB, polylabel& ld, size_t new_size)
916 {
917  if (isCB)
918  ld.cb.costs.resize(new_size);
919  else
920  ld.cs.costs.resize(new_size);
921 }
922 
923 inline void cs_cost_push_back(bool isCB, polylabel& ld, uint32_t index, float value)
924 {
925  if (isCB)
926  {
927  CB::cb_class cost = {value, index, 0., 0.};
928  ld.cb.costs.push_back(cost);
929  }
930  else
931  {
932  CS::wclass cost = {value, index, 0., 0.};
933  ld.cs.costs.push_back(cost);
934  }
935 }
936 
937 polylabel& allowed_actions_to_ld(search_private& priv, size_t ec_cnt, const action* allowed_actions,
938  size_t allowed_actions_cnt, const float* allowed_actions_cost)
939 {
940  bool isCB = priv.cb_learner;
941  polylabel& ld = *priv.allowed_actions_cache;
942  uint32_t num_costs = (uint32_t)cs_get_costs_size(isCB, ld);
943 
944  if (priv.is_ldf) // LDF version easier
945  {
946  if (num_costs > ec_cnt)
947  cs_costs_resize(isCB, ld, ec_cnt);
948  else if (num_costs < ec_cnt)
949  for (action k = num_costs; k < ec_cnt; k++) cs_cost_push_back(isCB, ld, k, FLT_MAX);
950  }
951  else if (priv.use_action_costs)
952  {
953  // TODO: Weight
954  if (allowed_actions == nullptr)
955  {
956  if (cs_get_costs_size(isCB, ld) != priv.A)
957  {
958  cs_costs_erase(isCB, ld);
959  for (action k = 0; k < priv.A; k++) cs_cost_push_back(isCB, ld, k + 1, 0.);
960  }
961  for (action k = 0; k < priv.A; k++) cs_set_cost_loss(isCB, ld, k, allowed_actions_cost[k]);
962  }
963  else // manually specified actions
964  {
965  cs_costs_erase(isCB, ld);
966  for (action k = 0; k < allowed_actions_cnt; k++)
967  cs_cost_push_back(isCB, ld, allowed_actions[k], allowed_actions_cost[k]);
968  }
969  }
970  else // non-LDF version, no action costs
971  {
972  if ((allowed_actions == nullptr) || (allowed_actions_cnt == 0)) // any action is allowed
973  {
974  if (num_costs != priv.A) // if there are already A-many actions, they must be the right ones, unless the user did
975  // something stupid like putting duplicate allowed_actions...
976  {
977  cs_costs_erase(isCB, ld);
978  for (action k = 0; k < priv.A; k++) cs_cost_push_back(isCB, ld, k + 1, FLT_MAX); //+1 because MC is 1-based
979  }
980  }
981  else // we need to peek at allowed_actions
982  {
983  cs_costs_erase(isCB, ld);
984  for (size_t i = 0; i < allowed_actions_cnt; i++) cs_cost_push_back(isCB, ld, allowed_actions[i], FLT_MAX);
985  }
986  }
987 
988  return ld;
989 }
990 
991 void allowed_actions_to_label(search_private& priv, size_t ec_cnt, const action* allowed_actions,
992  size_t allowed_actions_cnt, const float* allowed_actions_cost, const action* oracle_actions,
993  size_t oracle_actions_cnt, polylabel& lab)
994 {
995  bool isCB = priv.cb_learner;
996  if (priv.is_ldf) // LDF version easier
997  {
998  cs_costs_erase(isCB, lab);
999  for (action k = 0; k < ec_cnt; k++)
1000  cs_cost_push_back(isCB, lab, k, array_contains<action>(k, oracle_actions, oracle_actions_cnt) ? 0.f : 1.f);
1001  // std::cerr << "lab = ["; for (size_t i=0; i<lab.cs.costs.size(); i++) cdbg << ' ' << lab.cs.costs[i].class_index
1002  // << ':'
1003  // << lab.cs.costs[i].x; cdbg << " ]" << endl;
1004  }
1005  else if (priv.use_action_costs)
1006  {
1007  // TODO: Weight
1008  if (allowed_actions == nullptr)
1009  {
1010  if (cs_get_costs_size(isCB, lab) != priv.A)
1011  {
1012  cs_costs_erase(isCB, lab);
1013  for (action k = 0; k < priv.A; k++) cs_cost_push_back(isCB, lab, k + 1, 0.);
1014  }
1015  for (action k = 0; k < priv.A; k++) cs_set_cost_loss(isCB, lab, k, allowed_actions_cost[k]);
1016  }
1017  else // manually specified actions
1018  {
1019  cs_costs_erase(isCB, lab);
1020  for (action k = 0; k < allowed_actions_cnt; k++)
1021  cs_cost_push_back(isCB, lab, allowed_actions[k], allowed_actions_cost[k]);
1022  }
1023  }
1024  else // non-LDF, no action costs
1025  {
1026  if ((allowed_actions == nullptr) || (allowed_actions_cnt == 0)) // any action is allowed
1027  {
1028  bool set_to_one = false;
1029  if (cs_get_costs_size(isCB, lab) != priv.A)
1030  {
1031  cs_costs_erase(isCB, lab);
1032  for (action k = 0; k < priv.A; k++) cs_cost_push_back(isCB, lab, k + 1, 1.);
1033  set_to_one = true;
1034  }
1035  // std::cerr << "lab = ["; for (size_t i=0; i<lab.cs.costs.size(); i++) cdbg << ' ' << lab.cs.costs[i].class_index
1036  // <<
1037  // ':' << lab.cs.costs[i].x; cdbg << " ]" << endl;
1038  if (oracle_actions_cnt <= 1) // common case to speed up
1039  {
1040  if (!set_to_one)
1041  for (action k = 0; k < priv.A; k++) cs_set_cost_loss(isCB, lab, k, 1.);
1042  if (oracle_actions_cnt == 1)
1043  cs_set_cost_loss(isCB, lab, oracle_actions[0] - 1, 0.);
1044  }
1045  else
1046  {
1047  for (action k = 0; k < priv.A; k++)
1048  cs_set_cost_loss(isCB, lab, k, array_contains<action>(k + 1, oracle_actions, oracle_actions_cnt) ? 0.f : 1.f);
1049  }
1050  }
1051  else // only some actions are allowed
1052  {
1053  cs_costs_erase(isCB, lab);
1054  float w = 1.; // array_contains<action>(3, oracle_actions, oracle_actions_cnt) ? 5.f : 1.f;
1055  for (size_t i = 0; i < allowed_actions_cnt; i++)
1056  {
1057  action k = allowed_actions[i];
1059  isCB, lab, k, (array_contains<action>(k, oracle_actions, oracle_actions_cnt)) ? 0.f : w); // 1.f );
1060  }
1061  }
1062  }
1063 }
1064 
1065 template <class T>
1066 void ensure_size(v_array<T>& A, size_t sz)
1067 {
1068  if ((size_t)(A.end_array - A.begin()) < sz)
1069  A.resize(sz * 2 + 1);
1070  A.end() = A.begin() + sz;
1071 }
1072 
1073 template <class T>
1074 void push_at(v_array<T>& v, T item, size_t pos)
1075 {
1076  if (v.size() > pos)
1077  v.begin()[pos] = item;
1078  else
1079  {
1080  if (v.end_array > v.begin() + pos)
1081  {
1082  // there's enough memory, just not enough filler
1083  memset(v.end(), 0, sizeof(T) * (pos - v.size()));
1084  v.begin()[pos] = item;
1085  v.end() = v.begin() + pos + 1;
1086  }
1087  else
1088  {
1089  // there's not enough memory
1090  v.resize(2 * pos + 3);
1091  v.begin()[pos] = item;
1092  v.end() = v.begin() + pos + 1;
1093  }
1094  }
1095 }
1096 
1097 action choose_oracle_action(search_private& priv, size_t ec_cnt, const action* oracle_actions,
1098  size_t oracle_actions_cnt, const action* allowed_actions, size_t allowed_actions_cnt,
1099  const float* allowed_actions_cost)
1100 {
1101  action a = (action)-1;
1102  if (priv.use_action_costs)
1103  {
1104  size_t K = (allowed_actions == nullptr) ? priv.A : allowed_actions_cnt;
1105  cdbg << "costs = [";
1106  for (size_t k = 0; k < K; k++) cdbg << ' ' << allowed_actions_cost[k];
1107  cdbg << " ]" << endl;
1108  float min_cost = FLT_MAX;
1109  for (size_t k = 0; k < K; k++) min_cost = std::min(min_cost, allowed_actions_cost[k]);
1110  cdbg << "min_cost = " << min_cost;
1111  if (min_cost < FLT_MAX)
1112  {
1113  size_t count = 0;
1114  for (size_t k = 0; k < K; k++)
1115  if (allowed_actions_cost[k] <= min_cost)
1116  {
1117  cdbg << ", hit @ " << k;
1118  count++;
1119  if ((count == 1) || (priv._random_state->get_and_update_random() < 1. / (float)count))
1120  {
1121  a = (allowed_actions == nullptr) ? (uint32_t)(k + 1) : allowed_actions[k];
1122  cdbg << "***";
1123  }
1124  }
1125  }
1126  cdbg << endl;
1127  }
1128 
1129  if (a == (action)-1)
1130  {
1131  if ((priv.perturb_oracle > 0.) && (priv.state == INIT_TRAIN) &&
1132  (priv._random_state->get_and_update_random() < priv.perturb_oracle))
1133  oracle_actions_cnt = 0;
1134  a = (oracle_actions_cnt > 0)
1135  ? oracle_actions[random(priv._random_state, oracle_actions_cnt)]
1136  : (allowed_actions_cnt > 0) ? allowed_actions[random(priv._random_state, allowed_actions_cnt)]
1137  : priv.is_ldf ? (action)random(priv._random_state, ec_cnt)
1138  : (action)(1 + random(priv._random_state, priv.A));
1139  }
1140  cdbg << "choose_oracle_action from oracle_actions = [";
1141  for (size_t i = 0; i < oracle_actions_cnt; i++) cdbg << " " << oracle_actions[i];
1142  cdbg << " ], ret=" << a << endl;
1143  if (need_memo_foreach_action(priv) && (priv.state == INIT_TRAIN))
1144  {
1145  v_array<action_cache>* this_cache = new v_array<action_cache>();
1146  *this_cache = v_init<action_cache>();
1147  // TODO we don't really need to construct this polylabel
1148  polylabel l = allowed_actions_to_ld(priv, 1, allowed_actions, allowed_actions_cnt, allowed_actions_cost);
1149  size_t K = cs_get_costs_size(priv.cb_learner, l);
1150  for (size_t k = 0; k < K; k++)
1151  {
1152  action cl = cs_get_cost_index(priv.cb_learner, l, k);
1153  float cost = array_contains(cl, oracle_actions, oracle_actions_cnt) ? 0.f : 1.f;
1154  this_cache->push_back(action_cache(0., cl, cl == a, cost));
1155  }
1156  assert(priv.memo_foreach_action.size() == priv.meta_t + priv.t - 1);
1157  priv.memo_foreach_action.push_back(this_cache);
1158  cdbg << "memo_foreach_action[" << priv.meta_t + priv.t - 1 << "] = " << this_cache << " from oracle" << endl;
1159  }
1160  return a;
1161 }
1162 
1163 action single_prediction_notLDF(search_private& priv, example& ec, int policy, const action* allowed_actions,
1164  size_t allowed_actions_cnt, const float* allowed_actions_cost, float& a_cost,
1165  action override_action) // if override_action != -1, then we return it as the action and a_cost is set to the
1166  // appropriate cost for that action
1167 {
1168  vw& all = *priv.all;
1169  polylabel old_label = ec.l;
1170  bool need_partial_predictions = need_memo_foreach_action(priv) ||
1171  (priv.metaoverride && priv.metaoverride->_foreach_action) || (override_action != (action)-1) || priv.active_csoaa;
1172  if ((allowed_actions_cnt > 0) || need_partial_predictions)
1173  ec.l = allowed_actions_to_ld(priv, 1, allowed_actions, allowed_actions_cnt, allowed_actions_cost);
1174  else
1175  ec.l.cs = priv.empty_cs_label;
1176 
1177  cdbg << "allowed_actions_cnt=" << allowed_actions_cnt << ", ec.l = [";
1178  for (size_t i = 0; i < ec.l.cs.costs.size(); i++)
1179  cdbg << ' ' << ec.l.cs.costs[i].class_index << ':' << ec.l.cs.costs[i].x;
1180  cdbg << " ]" << endl;
1181 
1182  as_singleline(priv.base_learner)->predict(ec, policy);
1183 
1184  uint32_t act = ec.pred.multiclass;
1185  cdbg << "a=" << act << " from";
1186  if (allowed_actions)
1187  {
1188  for (size_t ii = 0; ii < allowed_actions_cnt; ii++) cdbg << ' ' << allowed_actions[ii];
1189  }
1190  cdbg << endl;
1191  a_cost = ec.partial_prediction;
1192  cdbg << "a_cost = " << a_cost << endl;
1193 
1194  if (override_action != (action)-1)
1195  act = override_action;
1196 
1197  if (need_partial_predictions)
1198  {
1199  size_t K = cs_get_costs_size(priv.cb_learner, ec.l);
1200  float min_cost = FLT_MAX;
1201  for (size_t k = 0; k < K; k++)
1202  {
1203  float cost = cs_get_cost_partial_prediction(priv.cb_learner, ec.l, k);
1204  if (cost < min_cost)
1205  min_cost = cost;
1206  }
1207  v_array<action_cache>* this_cache = nullptr;
1208  if (need_memo_foreach_action(priv) && (override_action == (action)-1))
1209  {
1210  this_cache = new v_array<action_cache>();
1211  *this_cache = v_init<action_cache>();
1212  }
1213  for (size_t k = 0; k < K; k++)
1214  {
1215  action cl = cs_get_cost_index(priv.cb_learner, ec.l, k);
1216  float cost = cs_get_cost_partial_prediction(priv.cb_learner, ec.l, k);
1217  if (priv.metaoverride && priv.metaoverride->_foreach_action)
1218  priv.metaoverride->_foreach_action(*priv.metaoverride->sch, priv.t - 1, min_cost, cl, cl == act, cost);
1219  if (override_action == cl)
1220  a_cost = cost;
1221  if (this_cache)
1222  this_cache->push_back(action_cache(min_cost, cl, cl == act, cost));
1223  }
1224  if (this_cache)
1225  {
1226  assert(priv.memo_foreach_action.size() == priv.meta_t + priv.t - 1);
1227  priv.memo_foreach_action.push_back(this_cache);
1228  cdbg << "memo_foreach_action[" << priv.meta_t + priv.t - 1 << "] = " << this_cache << endl;
1229  }
1230  }
1231 
1232  if ((priv.state == INIT_TRAIN) && (priv.subsample_timesteps <= -1)) // active learning
1233  {
1234  size_t K = cs_get_costs_size(priv.cb_learner, ec.l);
1235  float min_cost = FLT_MAX, min_cost2 = FLT_MAX;
1236  for (size_t k = 0; k < K; k++)
1237  {
1238  float cost = cs_get_cost_partial_prediction(priv.cb_learner, ec.l, k);
1239  if (cost < min_cost)
1240  {
1241  min_cost2 = min_cost;
1242  min_cost = cost;
1243  }
1244  else if (cost < min_cost2)
1245  {
1246  min_cost2 = cost;
1247  }
1248  }
1249  if (min_cost2 < FLT_MAX)
1250  priv.active_uncertainty.push_back(std::make_pair(min_cost2 - min_cost, priv.t + priv.meta_t));
1251  }
1252  if ((priv.state == INIT_TRAIN) && priv.active_csoaa)
1253  {
1254  if (priv.cb_learner)
1255  THROW("cannot use active_csoaa with cb learning");
1256  size_t cur_t = priv.t + priv.meta_t - 1;
1257  while (priv.active_known.size() <= cur_t)
1258  {
1259  priv.active_known.push_back(v_array<std::pair<CS::wclass&, bool>>());
1260  priv.active_known[priv.active_known.size() - 1] = v_init<std::pair<CS::wclass&, bool>>();
1261  cdbg << "active_known length now " << priv.active_known.size() << endl;
1262  }
1263  priv.active_known[cur_t].clear();
1264  assert(ec.l.cs.costs.size() > 0);
1265  for (size_t k = 0; k < ec.l.cs.costs.size(); k++)
1266  {
1267  /* priv.active_known[cur_t].push_back( ec.l.cs.costs[k].pred_is_certain
1268  ? ec.l.cs.costs[k].partial_prediction
1269  : FLT_MAX );
1270  cdbg << "active_known[" << cur_t << "][" << (priv.active_known[cur_t].size() -
1271  1) << "] = certain=" << ec.l.cs.costs[k].pred_is_certain << ", cost=" << ec.l.cs.costs[k].partial_prediction <<
1272  "}" << endl; */
1273  CS::wclass& wc = ec.l.cs.costs[k];
1274  // Get query_needed from pred
1275  bool query_needed = v_array_contains(ec.pred.multilabels.label_v, wc.class_index);
1276  std::pair<CS::wclass&, bool> p = {wc, query_needed};
1277  // Push into active_known[cur_t] with wc
1278  priv.active_known[cur_t].push_back(p);
1279  // cdbg << "active_known[" << cur_t << "][" << (priv.active_known[cur_t].size() - 1) << "] = " << wc.class_index
1280  // << ':' << wc.x << " pp=" << wc.partial_prediction << " query_needed=" << wc.query_needed << " max_pred=" <<
1281  // wc.max_pred << " min_pred=" << wc.min_pred << " is_range_overlapped=" << wc.is_range_overlapped << "
1282  // is_range_large=" << wc.is_range_large << endl;
1283  // query_needed=" << ec.l.cs.costs[k].query_needed << ", cost=" << ec.l.cs.costs[k].partial_prediction << "}" <<
1284  // endl;
1285  }
1286  }
1287 
1288  // generate raw predictions if necessary
1289  if ((priv.state == INIT_TEST) && (all.raw_prediction > 0))
1290  {
1291  priv.rawOutputStringStream->str("");
1292  for (size_t k = 0; k < cs_get_costs_size(priv.cb_learner, ec.l); k++)
1293  {
1294  if (k > 0)
1295  (*priv.rawOutputStringStream) << ' ';
1296  (*priv.rawOutputStringStream) << cs_get_cost_index(priv.cb_learner, ec.l, k) << ':'
1298  }
1299  all.print_text(all.raw_prediction, priv.rawOutputStringStream->str(), ec.tag);
1300  }
1301 
1302  ec.l = old_label;
1303 
1304  priv.total_predictions_made++;
1305  priv.num_features += ec.num_features;
1306 
1307  return act;
1308 }
1309 
1310 action single_prediction_LDF(search_private& priv, example* ecs, size_t ec_cnt, int policy, float& a_cost,
1311  action override_action) // if override_action != -1, then we return it as the action and a_cost is set to the
1312  // appropriate cost for that action
1313 {
1314  bool need_partial_predictions = need_memo_foreach_action(priv) ||
1315  (priv.metaoverride && priv.metaoverride->_foreach_action) || (override_action != (action)-1);
1316 
1318  CS::wclass wc = {0., 1, 0., 0.};
1319  priv.ldf_test_label.costs.push_back(wc);
1320 
1321  // keep track of best (aka chosen) action
1322  float best_prediction = 0.;
1323  action best_action = 0;
1324 
1325  size_t start_K = (priv.is_ldf && COST_SENSITIVE::ec_is_example_header(ecs[0])) ? 1 : 0;
1326 
1327  v_array<action_cache>* this_cache = nullptr;
1328  if (need_partial_predictions)
1329  {
1330  this_cache = new v_array<action_cache>();
1331  *this_cache = v_init<action_cache>();
1332  }
1333 
1334  for (action a = (uint32_t)start_K; a < ec_cnt; a++)
1335  {
1336  cdbg << "== single_prediction_LDF a=" << a << "==" << endl;
1337  if (start_K > 0)
1339 
1340  polylabel old_label = ecs[a].l;
1341  ecs[a].l.cs = priv.ldf_test_label;
1342 
1343  multi_ex tmp;
1344  uint64_t old_offset = ecs[a].ft_offset;
1345  ecs[a].ft_offset = priv.offset;
1346  tmp.push_back(&ecs[a]);
1347  as_multiline(priv.base_learner)->predict(tmp, policy);
1348 
1349  ecs[a].ft_offset = old_offset;
1350  cdbg << "partial_prediction[" << a << "] = " << ecs[a].partial_prediction << endl;
1351 
1352  if (override_action != (action)-1)
1353  {
1354  if (a == override_action)
1355  a_cost = ecs[a].partial_prediction;
1356  }
1357  else if ((a == start_K) || (ecs[a].partial_prediction < best_prediction))
1358  {
1359  best_prediction = ecs[a].partial_prediction;
1360  best_action = a;
1361  a_cost = best_prediction;
1362  }
1363  if (this_cache)
1364  this_cache->push_back(action_cache(0., a, false, ecs[a].partial_prediction));
1365 
1366  priv.num_features += ecs[a].num_features;
1367  ecs[a].l = old_label;
1368  if (start_K > 0)
1370  }
1371  if (override_action != (action)-1)
1372  best_action = override_action;
1373  else
1374  a_cost = best_prediction;
1375 
1376  if (this_cache)
1377  {
1378  for (size_t i = 0; i < this_cache->size(); i++)
1379  {
1380  action_cache& ac = (*this_cache)[i];
1381  ac.min_cost = a_cost;
1382  ac.is_opt = (ac.k == best_action);
1383  if (priv.metaoverride && priv.metaoverride->_foreach_action)
1384  priv.metaoverride->_foreach_action(*priv.metaoverride->sch, priv.t - 1, ac.min_cost, ac.k, ac.is_opt, ac.cost);
1385  }
1386  if (need_memo_foreach_action(priv) && (override_action == (action)-1))
1387  priv.memo_foreach_action.push_back(this_cache);
1388  else
1389  {
1390  this_cache->delete_v();
1391  delete this_cache;
1392  }
1393  }
1394 
1395  // TODO: generate raw predictions if necessary
1396 
1397  priv.total_predictions_made++;
1398  return best_action;
1399 }
1400 
1401 int choose_policy(search_private& priv, bool advance_prng = true)
1402 {
1403  RollMethod method = (priv.state == INIT_TEST) ? POLICY
1404  : (priv.state == LEARN)
1405  ? priv.rollout_method
1406  : (priv.state == INIT_TRAIN) ? priv.rollin_method : NO_ROLLOUT; // this should never happen
1407  switch (method)
1408  {
1409  case POLICY:
1410  return random_policy(priv, priv.allow_current_policy || (priv.state == INIT_TEST), false, advance_prng);
1411 
1412  case ORACLE:
1413  return -1;
1414 
1415  case MIX_PER_STATE:
1416  return random_policy(priv, priv.allow_current_policy, true, advance_prng);
1417 
1418  case MIX_PER_ROLL:
1419  if (priv.mix_per_roll_policy == -2) // then we have to choose one!
1420  priv.mix_per_roll_policy = random_policy(priv, priv.allow_current_policy, true, advance_prng);
1421  return priv.mix_per_roll_policy;
1422 
1423  case NO_ROLLOUT:
1424  default:
1425  THROW("internal error (bug): trying to rollin or rollout with NO_ROLLOUT");
1426  }
1427 }
1428 
// Equality predicate for prediction-cache keys. Each key's length is read from its first byte. Two keys
// are equivalent when those lengths agree and the first that-many bytes are identical.
bool cached_item_equivalent(unsigned char* const& A, unsigned char* const& B)
{
  const size_t len = static_cast<size_t>(A[0]);
  if (static_cast<size_t>(B[0]) != len)
    return false;
  return memcmp(A, B, len) == 0;
}
1437 // returns true if found and do_store is false. if do_store is true, always returns true.
1438 bool cached_action_store_or_find(search_private& priv, ptag mytag, const ptag* condition_on,
1439  const char* condition_on_names, action_repr* condition_on_actions, size_t condition_on_cnt, int policy,
1440  size_t learner_id, action& a, bool do_store, float& a_cost)
1441 {
1442  if (priv.no_caching)
1443  return do_store;
1444  if (mytag == 0)
1445  return do_store; // don't attempt to cache when tag is zero
1446 
1447  size_t sz = sizeof(size_t) + sizeof(ptag) + sizeof(int) + sizeof(size_t) + sizeof(size_t) +
1448  condition_on_cnt * (sizeof(ptag) + sizeof(action) + sizeof(char));
1449  if (sz % 4 != 0)
1450  sz += 4 - (sz % 4); // make sure sz aligns to 4 so that uniform_hash does the right thing
1451 
1452  unsigned char* item = calloc_or_throw<unsigned char>(sz);
1453  unsigned char* here = item;
1454  *here = (unsigned char)sz;
1455  here += sizeof(size_t);
1456  *here = mytag;
1457  here += sizeof(ptag);
1458  *here = policy;
1459  here += sizeof(int);
1460  *here = (unsigned char)learner_id;
1461  here += sizeof(size_t);
1462  *here = (unsigned char)condition_on_cnt;
1463  here += sizeof(size_t);
1464  for (size_t i = 0; i < condition_on_cnt; i++)
1465  {
1466  *here = condition_on[i];
1467  here += sizeof(ptag);
1468  *here = condition_on_actions[i].a;
1469  here += sizeof(action);
1470  *here = condition_on_names[i];
1471  here += sizeof(char); // SPEEDUP: should we align this at 4?
1472  }
1473  uint64_t hash = uniform_hash(item, sz, 3419);
1474 
1475  if (do_store)
1476  {
1477  priv.cache_hash_map.put(item, hash, scored_action(a, a_cost));
1478  return true;
1479  }
1480  else // its a find
1481  {
1482  scored_action sa = priv.cache_hash_map.get(item, hash);
1483  a = sa.a;
1484  a_cost = sa.s;
1485  free(item);
1486  return a != (action)-1;
1487  }
1488 }
1489 
1490 void generate_training_example(search_private& priv, polylabel& losses, float weight, bool add_conditioning = true,
1491  float min_loss = FLT_MAX) // min_loss = FLT_MAX means "please compute it for me as the actual min"; any other value
1492  // means to use this
1493 {
1494  // should we really subtract out min-loss?
1495  // float min_loss = FLT_MAX;
1496  if (priv.cb_learner)
1497  {
1498  if (min_loss == FLT_MAX)
1499  for (size_t i = 0; i < losses.cb.costs.size(); i++) min_loss = std::min(min_loss, losses.cb.costs[i].cost);
1500  for (size_t i = 0; i < losses.cb.costs.size(); i++) losses.cb.costs[i].cost = losses.cb.costs[i].cost - min_loss;
1501  }
1502  else
1503  {
1504  if (min_loss == FLT_MAX)
1505  for (size_t i = 0; i < losses.cs.costs.size(); i++) min_loss = std::min(min_loss, losses.cs.costs[i].x);
1506  for (size_t i = 0; i < losses.cs.costs.size(); i++)
1507  losses.cs.costs[i].x = (losses.cs.costs[i].x - min_loss) * weight;
1508  }
1509  // std::cerr << "losses = ["; for (size_t i=0; i<losses.cs.costs.size(); i++) std::cerr << ' ' <<
1510  // losses.cs.costs[i].class_index
1511  // << ':' << losses.cs.costs[i].x; std::cerr << " ]" << endl;
1512 
1513  if (!priv.is_ldf) // not LDF
1514  {
1515  // since we're not LDF, it should be the case that ec_ref_cnt == 1
1516  // and learn_ec_ref[0] is a pointer to a single example
1517  assert(priv.learn_ec_ref_cnt == 1);
1518  assert(priv.learn_ec_ref != nullptr);
1519 
1520  example& ec = priv.learn_ec_ref[0];
1521  polylabel old_label = ec.l;
1522  ec.l = losses; // labels;
1523  if (add_conditioning)
1525  priv.learn_condition_on_act.begin());
1526  for (size_t is_local = 0; is_local <= (size_t)priv.xv; is_local++)
1527  {
1528  int learner = select_learner(priv, priv.current_policy, priv.learn_learner_id, true, is_local > 0);
1529  ec.in_use = true;
1530  cdbg << "BEGIN base_learner->learn(ec, " << learner << ")" << endl;
1531  as_singleline(priv.base_learner)->learn(ec, learner);
1532  cdbg << "END base_learner->learn(ec, " << learner << ")" << endl;
1533  }
1534  if (add_conditioning)
1535  del_example_conditioning(priv, ec);
1536  ec.l = old_label;
1537  priv.total_examples_generated++;
1538  }
1539  else // is LDF
1540  {
1541  assert(cs_get_costs_size(priv.cb_learner, losses) == priv.learn_ec_ref_cnt);
1542  size_t start_K = (priv.is_ldf && COST_SENSITIVE::ec_is_example_header(priv.learn_ec_ref[0])) ? 1 : 0;
1543 
1544  // TODO: weight
1545  if (add_conditioning)
1546  for (action a = (uint32_t)start_K; a < priv.learn_ec_ref_cnt; a++)
1547  {
1548  example& ec = priv.learn_ec_ref[a];
1550  priv.learn_condition_on_act.begin());
1551  }
1552 
1553  for (size_t is_local = 0; is_local <= (size_t)priv.xv; is_local++)
1554  {
1555  int learner = select_learner(priv, priv.current_policy, priv.learn_learner_id, true, is_local > 0);
1556 
1557  // create an example collection for
1558 
1559  multi_ex tmp;
1560  uint64_t tmp_offset = 0;
1561  if (priv.learn_ec_ref_cnt > start_K)
1562  tmp_offset = priv.learn_ec_ref[start_K].ft_offset;
1563  for (action a = (uint32_t)start_K; a < priv.learn_ec_ref_cnt; a++)
1564  {
1565  example& ec = priv.learn_ec_ref[a];
1566  CS::label& lab = ec.l.cs;
1567  if (lab.costs.size() == 0)
1568  {
1569  CS::wclass wc = {0., a - (uint32_t)start_K, 0., 0.};
1570  lab.costs.push_back(wc);
1571  }
1572  lab.costs[0].x = losses.cs.costs[a - start_K].x;
1573  ec.in_use = true;
1574  // store the offset to restore it later
1575  ec.ft_offset = priv.offset;
1576  // create the example collection used to learn
1577  tmp.push_back(&ec);
1578  cdbg << "generate_training_example called learn on action a=" << a << ", costs.size=" << lab.costs.size()
1579  << " ec=" << &ec << endl;
1580  priv.total_examples_generated++;
1581  }
1582 
1583  // learn with the multiline example
1584  as_multiline(priv.base_learner)->learn(tmp, learner);
1585 
1586  // restore the offsets in examples
1587  int i = 0;
1588  for (action a = (uint32_t)start_K; a < priv.learn_ec_ref_cnt; a++, i++)
1589  priv.learn_ec_ref[a].ft_offset = tmp_offset;
1590  }
1591 
1592  if (add_conditioning)
1593  for (action a = (uint32_t)start_K; a < priv.learn_ec_ref_cnt; a++)
1594  {
1595  example& ec = priv.learn_ec_ref[a];
1596  del_example_conditioning(priv, ec);
1597  }
1598  }
1599 }
1600 
1602 {
1603  // this is basically copied from the logic of search_predict()
1604  switch (priv.state)
1605  {
1606  case INITIALIZE:
1607  return false;
1608  case GET_TRUTH_STRING:
1609  return false;
1610  case INIT_TEST:
1611  return true;
1612  case INIT_TRAIN:
1613  // TODO: do we need to do something here for metatasks?
1614  // if (priv.beam && (priv.t < priv.beam_actions.size()))
1615  // return false;
1616  if (priv.rollout_method == NO_ROLLOUT)
1617  return true;
1618  break;
1619  case LEARN:
1620  if (priv.t + priv.meta_t < priv.learn_t)
1621  return false; // TODO: in meta search mode with foreach feature we'll need it even here
1622  if (priv.t + priv.meta_t == priv.learn_t)
1623  return true; // SPEEDUP: we really only need it on the last learn_a, but this is hard to know...
1624  // t > priv.learn_t
1625  if ((priv.rollout_num_steps > 0) && (priv.loss_declared_cnt >= priv.rollout_num_steps))
1626  return false; // skipping
1627  break;
1628  }
1629 
1630  int pol = choose_policy(priv, false); // choose a policy but don't advance prng
1631  return (pol != -1);
1632 }
1633 
1634 void foreach_action_from_cache(search_private& priv, size_t t, action override_a = (action)-1)
1635 {
1636  cdbg << "foreach_action_from_cache: t=" << t << ", memo_foreach_action.size()=" << priv.memo_foreach_action.size()
1637  << ", override_a=" << override_a << endl;
1638  assert(t < priv.memo_foreach_action.size());
1639  v_array<action_cache>* cached = priv.memo_foreach_action[t];
1640  if (!cached)
1641  return; // the only way this can happen is if the metatask overrode this action
1642  cdbg << "memo_foreach_action size = " << cached->size() << endl;
1643  for (size_t id = 0; id < cached->size(); id++)
1644  {
1645  action_cache& ac = (*cached)[id];
1646  priv.metaoverride->_foreach_action(*priv.metaoverride->sch, t - priv.meta_t, ac.min_cost, ac.k,
1647  (override_a == (action)-1) ? ac.is_opt : (ac.k == override_a), ac.cost);
1648  }
1649 }
1650 
1651 // note: ec_cnt should be 1 if we are not LDF
1652 action search_predict(search_private& priv, example* ecs, size_t ec_cnt, ptag mytag, const action* oracle_actions,
1653  size_t oracle_actions_cnt, const ptag* condition_on, const char* condition_on_names, const action* allowed_actions,
1654  size_t allowed_actions_cnt, const float* allowed_actions_cost, size_t learner_id, float& a_cost, float /* weight */)
1655 {
1656  size_t condition_on_cnt = condition_on_names ? strlen(condition_on_names) : 0;
1657  size_t t = priv.t + priv.meta_t;
1658  priv.t++;
1659 
1660  // make sure parameters come in pairs correctly
1661  assert((oracle_actions == nullptr) == (oracle_actions_cnt == 0));
1662  assert((condition_on == nullptr) == (condition_on_names == nullptr));
1663  assert(((allowed_actions == nullptr) && (allowed_actions_cost == nullptr)) == (allowed_actions_cnt == 0));
1664  assert(priv.use_action_costs == (allowed_actions_cost != nullptr));
1665  if (allowed_actions_cost != nullptr)
1666  assert(oracle_actions == nullptr);
1667 
1668  // if we're just after the string, choose an oracle action
1669  if ((priv.state == GET_TRUTH_STRING) || priv.force_oracle)
1670  {
1672  priv, ec_cnt, oracle_actions, oracle_actions_cnt, allowed_actions, allowed_actions_cnt, allowed_actions_cost);
1673  // if (priv.metaoverride && priv.metaoverride->_post_prediction)
1674  // priv.metaoverride->_post_prediction(*priv.metaoverride->sch, t-priv.meta_t, a, 0.);
1675  a_cost = 0.;
1676  return a;
1677  }
1678 
1679  // if we're in LEARN mode and before learn_t, return the train action
1680  if ((priv.state == LEARN) && (t < priv.learn_t))
1681  {
1682  assert(t < priv.train_trajectory.size());
1683  action a = priv.train_trajectory[t].a;
1684  a_cost = priv.train_trajectory[t].s;
1685  cdbg << "LEARN " << t << " < priv.learn_t ==> a=" << a << ", a_cost=" << a_cost << endl;
1686  if (priv.metaoverride && priv.metaoverride->_foreach_action)
1687  foreach_action_from_cache(priv, t);
1688  if (priv.metaoverride && priv.metaoverride->_post_prediction)
1689  priv.metaoverride->_post_prediction(*priv.metaoverride->sch, t - priv.meta_t, a, a_cost);
1690  return a;
1691  }
1692 
1693  // for LDF, # of valid actions is ec_cnt; otherwise it's either allowed_actions_cnt or A
1694  size_t valid_action_cnt = priv.is_ldf ? ec_cnt : (allowed_actions_cnt > 0) ? allowed_actions_cnt : priv.A;
1695 
1696  // if we're in LEARN mode and _at_ learn_t, then:
1697  // - choose the next action
1698  // - decide if we're done
1699  // - if we are, then copy/mark the example ref
1700  if ((priv.state == LEARN) && (t == priv.learn_t))
1701  {
1702  action a = (action)priv.learn_a_idx;
1703  priv.loss_declared_cnt = 0;
1704 
1705  cdbg << "LEARN " << t << " = priv.learn_t ==> a=" << a << ", learn_a_idx=" << priv.learn_a_idx
1706  << " valid_action_cnt=" << valid_action_cnt << endl;
1707  priv.learn_a_idx++;
1708 
1709  // check to see if we're done with available actions
1710  if (priv.learn_a_idx >= valid_action_cnt)
1711  {
1712  priv.done_with_all_actions = true;
1713  priv.learn_learner_id = learner_id;
1714 
1715  // set reference or copy example(s)
1716  if (oracle_actions_cnt > 0)
1717  priv.learn_oracle_action = oracle_actions[0];
1718  priv.learn_ec_ref_cnt = ec_cnt;
1719  if (priv.examples_dont_change)
1720  priv.learn_ec_ref = ecs;
1721  else
1722  {
1723  size_t label_size = priv.is_ldf ? sizeof(CS::label) : sizeof(MC::label_t);
1724  void (*label_copy_fn)(void*, void*) = priv.is_ldf ? CS::cs_label.copy_label : nullptr;
1725 
1726  ensure_size(priv.learn_ec_copy, ec_cnt);
1727  for (size_t i = 0; i < ec_cnt; i++)
1728  VW::copy_example_data(priv.all->audit, priv.learn_ec_copy.begin() + i, ecs + i, label_size, label_copy_fn);
1729 
1730  priv.learn_ec_ref = priv.learn_ec_copy.begin();
1731  }
1732 
1733  // copy conditioning stuff and allowed actions
1734  if (priv.auto_condition_features)
1735  {
1736  ensure_size(priv.learn_condition_on, condition_on_cnt);
1737  ensure_size(priv.learn_condition_on_act, condition_on_cnt);
1738 
1739  priv.learn_condition_on.end() =
1740  priv.learn_condition_on.begin() + condition_on_cnt; // allow .size() to be used in lieu of _cnt
1741 
1742  memcpy(priv.learn_condition_on.begin(), condition_on, condition_on_cnt * sizeof(ptag));
1743 
1744  for (size_t i = 0; i < condition_on_cnt; i++)
1746  action_repr(((1 <= condition_on[i]) && (condition_on[i] < priv.ptag_to_action.size()))
1747  ? priv.ptag_to_action[condition_on[i]]
1748  : 0),
1749  i);
1750 
1751  if (condition_on_names == nullptr)
1752  {
1754  priv.learn_condition_on_names[0] = 0;
1755  }
1756  else
1757  {
1758  ensure_size(priv.learn_condition_on_names, strlen(condition_on_names) + 1);
1759  strcpy(priv.learn_condition_on_names.begin(), condition_on_names);
1760  }
1761  }
1762 
1763  if (allowed_actions && (allowed_actions_cnt > 0))
1764  {
1765  ensure_size(priv.learn_allowed_actions, allowed_actions_cnt);
1766  memcpy(priv.learn_allowed_actions.begin(), allowed_actions, allowed_actions_cnt * sizeof(action));
1767  cdbg_print_array("in LEARN, learn_allowed_actions", priv.learn_allowed_actions);
1768  }
1769  }
1770 
1771  assert((allowed_actions_cnt == 0) || (a < allowed_actions_cnt));
1772 
1773  a_cost = 0.;
1774  action a_name = (allowed_actions && (allowed_actions_cnt > 0)) ? allowed_actions[a] : priv.is_ldf ? a : (a + 1);
1775  if (priv.metaoverride && priv.metaoverride->_foreach_action)
1776  {
1777  foreach_action_from_cache(priv, t, a_name);
1778  if (priv.memo_foreach_action[t])
1779  {
1780  cdbg << "@ memo_foreach_action: t=" << t << ", a=" << a << ", cost=" << (*priv.memo_foreach_action[t])[a].cost
1781  << endl;
1782  a_cost = (*priv.memo_foreach_action[t])[a].cost;
1783  }
1784  }
1785 
1786  a = a_name;
1787 
1788  if (priv.metaoverride && priv.metaoverride->_post_prediction)
1789  priv.metaoverride->_post_prediction(*priv.metaoverride->sch, t - priv.meta_t, a, a_cost);
1790  return a;
1791  }
1792 
1793  if ((priv.state == LEARN) && (t > priv.learn_t) && (priv.rollout_num_steps > 0) &&
1794  (priv.loss_declared_cnt >= priv.rollout_num_steps))
1795  {
1796  cdbg << "... skipping" << endl;
1797  action a = priv.is_ldf ? 0 : ((allowed_actions && (allowed_actions_cnt > 0)) ? allowed_actions[0] : 1);
1798  if (priv.metaoverride && priv.metaoverride->_post_prediction)
1799  priv.metaoverride->_post_prediction(*priv.metaoverride->sch, t - priv.meta_t, a, 0.);
1800  if (priv.metaoverride && priv.metaoverride->_foreach_action)
1801  foreach_action_from_cache(priv, t);
1802  a_cost = 0.;
1803  return a;
1804  }
1805 
1806  if ((priv.state == INIT_TRAIN) || (priv.state == INIT_TEST) || ((priv.state == LEARN) && (t > priv.learn_t)))
1807  {
1808  // we actually need to run the policy
1809 
1810  int policy = choose_policy(priv);
1811  action a = 0;
1812 
1813  cdbg << "executing policy " << policy << endl;
1814 
1815  bool gte_here = (priv.state == INIT_TRAIN) && (priv.rollout_method == NO_ROLLOUT) &&
1816  ((oracle_actions_cnt > 0) || (priv.use_action_costs));
1817  a_cost = 0.;
1818  bool skip = false;
1819 
1821  (priv.state != LEARN)) // if LEARN and t>learn_t,then we cannot allow overrides!
1822  {
1823  skip = priv.metaoverride->_maybe_override_prediction(*priv.metaoverride->sch, t - priv.meta_t, a, a_cost);
1824  cdbg << "maybe_override_prediction --> " << skip << ", a=" << a << ", a_cost=" << a_cost << endl;
1825  if (skip && need_memo_foreach_action(priv))
1826  priv.memo_foreach_action.push_back(nullptr);
1827  }
1828 
1829  if ((!skip) && (policy == -1))
1830  a = choose_oracle_action(priv, ec_cnt, oracle_actions, oracle_actions_cnt, allowed_actions, allowed_actions_cnt,
1831  allowed_actions_cost); // TODO: we probably want to actually get costs for oracle actions???
1832 
1833  bool need_fea = (policy == -1) && priv.metaoverride && priv.metaoverride->_foreach_action;
1834 
1835  if ((policy >= 0) || gte_here || need_fea) // the last case is we need to do foreach action
1836  {
1837  int learner = select_learner(priv, policy, learner_id, false, priv.state != INIT_TEST);
1838 
1839  ensure_size(priv.condition_on_actions, condition_on_cnt);
1840  for (size_t i = 0; i < condition_on_cnt; i++)
1841  priv.condition_on_actions[i] = ((1 <= condition_on[i]) && (condition_on[i] < priv.ptag_to_action.size()))
1842  ? priv.ptag_to_action[condition_on[i]]
1843  : 0;
1844 
1845  bool not_test = priv.all->training && !ecs[0].test_only;
1846 
1847  if ((!skip) && (!need_fea) && not_test &&
1848  cached_action_store_or_find(priv, mytag, condition_on, condition_on_names, priv.condition_on_actions.begin(),
1849  condition_on_cnt, policy, learner_id, a, false, a_cost))
1850  // if this succeeded, 'a' has the right action
1851  priv.total_cache_hits++;
1852  else // we need to predict, and then cache, and maybe run foreach_action
1853  {
1854  size_t start_K = (priv.is_ldf && COST_SENSITIVE::ec_is_example_header(ecs[0])) ? 1 : 0;
1855  priv.last_action_repr.clear();
1856  if (priv.auto_condition_features)
1857  for (size_t n = start_K; n < ec_cnt; n++)
1859  priv, ecs[n], condition_on_cnt, condition_on_names, priv.condition_on_actions.begin());
1860 
1861  if (((!skip) && (policy >= 0)) || need_fea) // only make a prediction if we're going to use the output
1862  {
1864  {
1865  if (priv.is_ldf)
1866  {
1867  THROW("search cannot use state representations in ldf mode");
1868  }
1869  if (ecs[0].passthrough)
1870  {
1871  THROW("search cannot passthrough");
1872  }
1873  ecs[0].passthrough = &priv.last_action_repr;
1874  }
1875  a = priv.is_ldf ? single_prediction_LDF(priv, ecs, ec_cnt, learner, a_cost, need_fea ? a : (action)-1)
1876  : single_prediction_notLDF(priv, *ecs, learner, allowed_actions, allowed_actions_cnt,
1877  allowed_actions_cost, a_cost, need_fea ? a : (action)-1);
1878 
1879  cdbg << "passthrough = [";
1880  for (size_t kk = 0; kk < priv.last_action_repr.size(); kk++)
1881  cdbg << ' ' << priv.last_action_repr.indicies[kk] << ':' << priv.last_action_repr.values[kk];
1882  cdbg << " ]" << endl;
1883 
1884  ecs[0].passthrough = nullptr;
1885  }
1886 
1887  if (need_fea)
1888  {
1889  // TODO this
1890  }
1891 
1892  if (gte_here)
1893  {
1894  cdbg << "INIT_TRAIN, NO_ROLLOUT, at least one oracle_actions, a=" << a << endl;
1895  // we can generate a training example _NOW_ because we're not doing rollouts
1896  // allowed_actions_to_losses(priv, ec_cnt, allowed_actions, allowed_actions_cnt, oracle_actions,
1897  // oracle_actions_cnt, losses);
1898  allowed_actions_to_label(priv, ec_cnt, allowed_actions, allowed_actions_cnt, allowed_actions_cost,
1899  oracle_actions, oracle_actions_cnt, priv.gte_label);
1900  cdbg << "priv.gte_label = [";
1901  for (size_t i = 0; i < priv.gte_label.cs.costs.size(); i++)
1902  cdbg << ' ' << priv.gte_label.cs.costs[i].class_index << ':' << priv.gte_label.cs.costs[i].x;
1903  cdbg << " ]" << endl;
1904 
1905  priv.learn_ec_ref = ecs;
1906  priv.learn_ec_ref_cnt = ec_cnt;
1907  if (allowed_actions)
1908  {
1909  ensure_size(priv.learn_allowed_actions, allowed_actions_cnt); // TODO: do we really need this?
1910  memcpy(priv.learn_allowed_actions.begin(), allowed_actions, allowed_actions_cnt * sizeof(action));
1911  }
1912  size_t old_learner_id = priv.learn_learner_id;
1913  priv.learn_learner_id = learner_id;
1915  priv, priv.gte_label, 1., false); // this is false because the conditioning has already been added!
1916  priv.learn_learner_id = old_learner_id;
1917  }
1918 
1919  if (priv.auto_condition_features)
1920  for (size_t n = start_K; n < ec_cnt; n++) del_example_conditioning(priv, ecs[n]);
1921 
1922  if (not_test && (!skip))
1923  cached_action_store_or_find(priv, mytag, condition_on, condition_on_names, priv.condition_on_actions.begin(),
1924  condition_on_cnt, policy, learner_id, a, true, a_cost);
1925  }
1926  }
1927 
1928  if (priv.state == INIT_TRAIN)
1929  priv.train_trajectory.push_back(scored_action(a, a_cost)); // note the action for future reference
1930 
1931  if (priv.metaoverride && priv.metaoverride->_post_prediction)
1932  priv.metaoverride->_post_prediction(*priv.metaoverride->sch, t - priv.meta_t, a, a_cost);
1933 
1934  return a;
1935  }
1936 
1937  THROW("error: predict called in unknown state");
1938 }
1939 
// Strict "less than" on size_t values; used as the comparator when sorting timesteps.
inline bool cmp_size_t(const size_t a, const size_t b) { return b > a; }
// Lexicographic "less than" on (first, second) pairs: order by .first, break ties on .second.
inline bool cmp_size_t_pair(const std::pair<size_t, size_t>& a, const std::pair<size_t, size_t>& b)
{
  if (a.first != b.first)
    return a.first < b.first;
  return a.second < b.second;
}
1945 
// Absolute difference |a - b| of two unsigned values, computed without wrap-around.
inline size_t absdiff(size_t a, size_t b) { return (a > b) ? (a - b) : (b - a); }
1947 
// Reorder the values in [B, end) so that neighbouring entries are far apart
// (the "hoopla board" layout from Curtis, IPL 2004, "Darts and hoopla board design").
// The values are sorted ascending; the smallest is placed in the middle of a scratch
// buffer with the largest right next to it, then the remaining values are greedily
// attached to whichever end (bottom or top) gives the largest absolute difference
// with the current end value. The result overwrites [B, end).
// Ranges with fewer than two elements are left untouched; an empty range must not be
// indexed at all (B[N - 1] with N == 0 would be far out of bounds).
void hoopla_permute(size_t* B, size_t* end)
{
  const size_t N = static_cast<size_t>(end - B);
  if (N < 2)
    return;

  std::sort(B, end);

  // Scratch space: the filled region [lo, hi] grows at most N-2 slots in either
  // direction from the two middle slots N and N+1.
  std::vector<size_t> A((N + 1) * 2, 0);
  auto dist = [](size_t a, size_t b) { return (a < b) ? (b - a) : (a - b); };

  A[N] = B[0];                // smallest value in the middle
  A[N + 1] = B[N - 1];        // largest value right next to it
  size_t lo = N, hi = N + 1;  // which parts of A have we filled in? [lo,hi]
  size_t i = 0, j = N - 1;    // which parts of B have we already covered? [0,i] and [j,N-1]
  while (i + 1 < j)
  {
    // there are four options depending on where things get placed
    const size_t d1 = dist(A[lo], B[i + 1]);  // put B[i+1] at the bottom
    const size_t d2 = dist(A[lo], B[j - 1]);  // put B[j-1] at the bottom
    const size_t d3 = dist(A[hi], B[i + 1]);  // put B[i+1] at the top
    const size_t d4 = dist(A[hi], B[j - 1]);  // put B[j-1] at the top
    const size_t mx = std::max(std::max(d1, d2), std::max(d3, d4));
    if (d1 >= mx)
      A[--lo] = B[++i];
    else if (d2 >= mx)
      A[--lo] = B[--j];
    else if (d3 >= mx)
      A[++hi] = B[++i];
    else
      A[++hi] = B[--j];
  }
  // exactly N slots [lo, hi] are filled; copy them back to B
  std::copy(A.begin() + lo, A.begin() + lo + N, B);
}
1982 
1984 {
// Fill `timesteps` with the timesteps in [0, priv.T) that the LEARN phase will train on,
// chosen according to priv.subsample_timesteps:
//   <= -1    : active learning -- keep recorded steps from priv.active_uncertainty at random
//   (-1, 0]  : every step (filtered through priv.active_known when priv.active_csoaa is set)
//   (0, 1)   : each step independently with probability subsample_timesteps, at least one
//   >= 1     : that many distinct random steps (capped at priv.T), sorted ascending
// Unless priv.linear_ordering is set, the result is finally reordered by hoopla_permute.
// NOTE(review): the active-learning branches can leave `timesteps` empty, in which case
// hoopla_permute is called on an empty range -- make sure it tolerates that.
1985  timesteps.clear();
1986 
1987  // active learning: keep entry i's step (.second - 1) when a random draw exceeds its .first value
// NOTE(review): .second appears to be a 1-based timestep (hence the - 1) -- confirm with the producer.
1988  if (priv.subsample_timesteps <= -1)
1989  {
1990  for (size_t i = 0; i < priv.active_uncertainty.size(); i++)
1991  if (priv._random_state->get_and_update_random() > priv.active_uncertainty[i].first)
1992  timesteps.push_back(priv.active_uncertainty[i].second - 1);
1993  /*
1994  float k = (float)priv.total_examples_generated;
1995  priv.ec_seq[t]->revert_weight = priv.all->loss->getRevertingWeight(priv.all->sd, priv.ec_seq[t].pred.scalar,
1996  priv.all->eta / powf(k, priv.all->power_t)); float importance = query_decision(active_str, *priv.ec_seq[t], k); if
1997  (importance > 0.) timesteps.push_back(pair<size_t,size_t>(0,t));
1998  */
1999  }
2000  // if there's no subsampling to do, just return [0,T)
// (with active CSOAA, a step t that has active_known data is kept only if more than one of its
// entries has its .second flag set)
2001  else if (priv.subsample_timesteps <= 0)
2002  for (size_t t = 0; t < priv.T; t++)
2003  {
// 99 is simply "more than one": steps without active_known data are always kept
2004  uint32_t count = 99;
2005  if (priv.active_csoaa && (t < priv.active_known.size()))
2006  {
2007  count = 0;
2008  for (std::pair<CS::wclass&, bool> wcq : priv.active_known[t])
2009  if (wcq.second)
2010  {
2011  count++;
2012  if (count > 1)
2013  break;
2014  }
2015  }
2016  if (count > 1)
2017  timesteps.push_back(t);
2018  }
2019 
2020  // if subsample in (0,1) then pick steps with that probability, but ensuring there's at least one!
2021  else if (priv.subsample_timesteps < 1)
2022  {
2023  for (size_t t = 0; t < priv.T; t++)
2024  if (priv._random_state->get_and_update_random() <= priv.subsample_timesteps)
2025  timesteps.push_back(t);
2026 
2027  if (timesteps.size() == 0) // ensure at least one
2028  timesteps.push_back((size_t)(priv._random_state->get_and_update_random() * priv.T));
2029  }
2030 
2031  // finally, if subsample >= 1, then pick (int) that many uniformly at random without replacement; could use an LFSR
2032  // but why? :P
2033  else
2034  {
2035  while ((timesteps.size() < (size_t)priv.subsample_timesteps) && (timesteps.size() < priv.T))
2036  {
2037  size_t t = (size_t)(priv._random_state->get_and_update_random() * (float)priv.T);
2038  if (!v_array_contains(timesteps, t))
2039  timesteps.push_back(t);
2040  }
2041  std::sort(timesteps.begin(), timesteps.end(), cmp_size_t);
2042  }
2043 
// spread consecutive training steps apart unless a linear ordering was requested
2044  if (!priv.linear_ordering)
2045  hoopla_permute(timesteps.begin(), timesteps.end());
2046 }
2047 
2049 {
// One finished search output: its action `prefix` (a heap-allocated v_array<scored_action>*;
// the free helper right after this struct calls delete_v() on it and deletes it), the rendered
// output string `str`, and its `total_cost`.
2051  std::string str;
2052  float total_cost;
// Takes ownership of `p`; `s` is copied into `str`.
2053  final_item(v_array<scored_action>* p, std::string s, float ic) : prefix(p), str(s), total_cost(ic) {}
2054 };
2055 
2057 {
// Release a heap-allocated final_item: free the prefix v_array's storage (delete_v),
// delete the v_array object itself, then delete the item.
2058  p->prefix->delete_v();
2059  delete p->prefix;
2060  delete p;
2061 }
2062 
2063 void BaseTask::Run()
2064 {
2065  search_private& priv = *sch->priv;
2066  // make sure output is correct
2067  bool old_should_produce_string = priv.should_produce_string;
2068  if (!_final_run && !_with_output_string)
2069  priv.should_produce_string = false;
2070  // if this isn't a final run, it shouldn't count for loss
2071  float old_test_loss = priv.test_loss;
2072  // float old_learn_loss = priv.learn_loss;
2073  priv.learn_loss *= 0.5;
2074  float old_train_loss = priv.train_loss;
2075 
2076  if (priv.should_produce_string)
2077  priv.pred_string->str("");
2078 
2079  priv.t = 0;
2080  priv.metaoverride = this;
2081  priv.task->run(*sch, ec);
2082  priv.metaoverride = nullptr;
2083  priv.meta_t += priv.t;
2084 
2085  // restore
2086  if (_with_output_string && old_should_produce_string)
2087  _with_output_string(*sch, *priv.pred_string);
2088 
2089  priv.should_produce_string = old_should_produce_string;
2090  if (!_final_run)
2091  {
2092  priv.test_loss = old_test_loss;
2093  // priv.learn_loss = old_learn_loss;
2094  priv.train_loss = old_train_loss;
2095  }
2096 }
2097 
2098 void run_task(search& sch, multi_ex& ec)
2099 {
2100  search_private& priv = *sch.priv;
2101  priv.num_calls_to_run++;
2102  if (priv.metatask && (priv.state != GET_TRUTH_STRING))
2103  priv.metatask->run(sch, ec);
2104  else
2105  priv.task->run(sch, ec);
2106 }
2107 
2109  COST_SENSITIVE::label& losses, v_array<std::pair<CS::wclass&, bool>>& known, size_t t, float multiplier)
2110 {
// Debug check for active CSOAA. For every cost in `losses` whose matching entry in `known` has
// .second == false, compare the true cost wc.x with known[i].first.partial_prediction and print
// a message to stderr when the squared error exceeds multiplier / sqrt(t). Nothing is modified.
// NOTE(review): assumes known.size() >= losses.costs.size() (no bounds check), and t == 0 makes
// the threshold infinite -- confirm what callers pass.
2111  float threshold = multiplier / std::sqrt((float)t);
2112  cdbg << "verify_active_csoaa, losses = [";
2113  for (COST_SENSITIVE::wclass& wc : losses.costs) cdbg << " " << wc.class_index << ":" << wc.x;
2114  cdbg << " ]" << endl;
2115  // cdbg_print_array("verify_active_csoaa, known", known);
2116  size_t i = 0;
2117  for (COST_SENSITIVE::wclass& wc : losses.costs)
2118  {
2119  if (!known[i].second)
2120  {
// squared error between the stored prediction and the true cost
2121  float err = pow(known[i].first.partial_prediction - wc.x, 2);
2122  if (err > threshold)
2123  {
2124  std::cerr << "verify_active_csoaa failed: truth " << wc.class_index << ":" << wc.x << ", known[" << i
2125  << "]=" << known[i].first.partial_prediction << ", error=" << err << " vs threshold " << threshold
2126  << endl;
2127  }
2128  }
2129  i++;
2130  }
2131 }
2132 
2134 {
// Active-CSOAA shortcut used while learning at timestep priv.learn_t.
// If the current action index (priv.learn_a_idx) has a known cost (its active_known entry has
// .second == false), append that cost to priv.learn_losses and advance learn_a_idx, presumably so
// no rollout is needed for that action. Once learn_a_idx runs past active_known[t], sets
// priv.done_with_all_actions. Does nothing when active CSOAA is off, when verification
// (active_csoaa_verify > 0) is on, or when t has no active_known entry.
2135  size_t t = priv.learn_t;
2136  if (!priv.active_csoaa)
2137  return;
2138  if (priv.active_csoaa_verify > 0.)
2139  return;
2140  if (priv.active_known.size() > 0 && false) {}
2169 
// Process one example sequence in up to three phases:
//   1. INIT_TEST  -- when must_run_test(...) holds, run the task to produce the prediction string,
//                    accumulate loss and write to the final/raw prediction sinks;
//   2. INIT_TRAIN -- when is_learn is set and the example is trainable (not test, not holdout,
//                    training enabled), run the task with the oracle allowed and record
//                    priv.train_trajectory;
//   3. LEARN      -- unless no loss was declared, no steps were taken, or rollouts are disabled
//                    (NO_ROLLOUT), for each timestep chosen by get_training_timesteps roll out the
//                    actions, collect their losses in priv.learn_losses and call
//                    generate_training_example.
// With active CSOAA and save_every_k_runs > 1, a predictor snapshot may be saved at the end.
2170 template <bool is_learn>
2171 void train_single_example(search& sch, bool is_test_ex, bool is_holdout_ex, multi_ex& ec_seq)
2172 {
2173  search_private& priv = *sch.priv;
2174  vw& all = *priv.all;
2175  bool ran_test = false; // we must keep track so that even if we skip test, we still update # of examples seen
2176 
2177  // if (! priv.no_caching)
2178  clear_cache_hash_map(priv);
2179 
2180  cdbg << "is_test_ex=" << is_test_ex << " vw_is_main=" << all.vw_is_main << endl;
2181  cdbg << "must_run_test = " << must_run_test(all, ec_seq, is_test_ex) << endl;
2182  // do an initial test pass to compute output (and loss)
2183  if (must_run_test(all, ec_seq, is_test_ex))
2184  {
2185  cdbg << "======================================== INIT TEST (" << priv.current_policy << ","
2186  << priv.read_example_last_pass << ") ========================================" << endl;
2187 
2188  ran_test = true;
2189 
2190  // do the prediction
2191  reset_search_structure(priv);
2192  priv.state = INIT_TEST;
// only build the prediction string if something will print or store it
2193  priv.should_produce_string =
2194  might_print_update(all) || (all.final_prediction_sink.size() > 0) || (all.raw_prediction > 0);
2195  priv.pred_string->str("");
2196  priv.test_action_sequence.clear();
2197  run_task(sch, ec_seq);
2198 
2199  // accumulate loss
2200  if (!is_test_ex)
2201  all.sd->update(ec_seq[0]->test_only, !is_test_ex, priv.test_loss, 1.f, priv.num_features);
2202 
2203  // generate output
2204  for (int sink : all.final_prediction_sink) all.print_text((int)sink, priv.pred_string->str(), ec_seq[0]->tag);
2205 
2206  if (all.raw_prediction > 0)
2207  all.print_text(all.raw_prediction, "", ec_seq[0]->tag);
2208  }
2209 
2210  // if we're not training, then we're done!
2211  if ((!is_learn) || is_test_ex || is_holdout_ex || ec_seq[0]->test_only || (!priv.all->training))
2212  return;
2213 
2214  // SPEEDUP: if the oracle was never called, we can skip this!
2215 
2216  // do a pass over the data allowing oracle
2217  cdbg << "======================================== INIT TRAIN (" << priv.current_policy << ","
2218  << priv.read_example_last_pass << ") ========================================" << endl;
2219  // std::cerr << "training" << endl;
2220 
2221  clear_cache_hash_map(priv);
2222  reset_search_structure(priv);
2224  priv.state = INIT_TRAIN;
2225  priv.active_uncertainty.clear();
2226  priv.train_trajectory.clear(); // this is where we'll store the training sequence
2227  run_task(sch, ec_seq);
2228 
2229  if (!ran_test) // was && !priv.ec_seq[0]->test_only) { but we know it's not test_only
2230  all.sd->update(ec_seq[0]->test_only, true, priv.test_loss, 1.f, priv.num_features);
2231 
2232  // if there's nothing to train on, we're done!
2233  if ((priv.loss_declared_cnt == 0) || (priv.t + priv.meta_t == 0) ||
2234  (priv.rollout_method == NO_ROLLOUT)) // TODO: make sure NO_ROLLOUT works with beam!
2235  {
2236  return;
2237  }
2238 
2239  // otherwise, we have some learn'in to do!
2240  cdbg << "======================================== LEARN (" << priv.current_policy << ","
2241  << priv.read_example_last_pass << ") ========================================" << endl;
// with a metatask, the step count accumulated across its sub-runs (meta_t) is the horizon
2242  priv.T = priv.metatask ? priv.meta_t : priv.t;
2243  get_training_timesteps(priv, priv.timesteps);
2244  cdbg << "train_trajectory.size() = " << priv.train_trajectory.size() << ":\t";
2245  cdbg_print_array<scored_action>("", priv.train_trajectory);
2246  // cdbg << "memo_foreach_action = " << priv.memo_foreach_action << endl;
2247  for (size_t i = 0; i < priv.memo_foreach_action.size(); i++)
2248  {
2249  cdbg << "memo_foreach_action[" << i << "] = ";
2250  if (priv.memo_foreach_action[i])
2251  cdbg << *priv.memo_foreach_action[i];
2252  else
2253  cdbg << "null";
2254  cdbg << endl;
2255  }
2256 
2257  if (priv.cb_learner)
2258  priv.learn_losses.cb.costs.clear();
2259  else
2260  priv.learn_losses.cs.costs.clear();
2261 
2262  for (size_t tid = 0; tid < priv.timesteps.size(); tid++)
2263  {
2264  cdbg << "timestep = " << priv.timesteps[tid] << " [" << tid << "/" << priv.timesteps.size() << "]" << endl;
2265 
// NOTE(review): memo_foreach_action is indexed by `tid` (position in priv.timesteps), not by
// priv.timesteps[tid]; the two differ when timesteps are subsampled or permuted. Confirm which is
// intended and that tid < memo_foreach_action.size() here.
2266  if (priv.metatask && !priv.memo_foreach_action[tid])
2267  {
2268  cdbg << "skipping because it looks like this was overridden by metatask" << endl;
2269  continue;
2270  }
2271 
2272  priv.learn_ec_ref = nullptr;
2273  priv.learn_ec_ref_cnt = 0;
2274 
2275  reset_search_structure(priv); // TODO remove this?
2276  bool skipped_all_actions = true;
2277  priv.learn_a_idx = 0;
2278  priv.done_with_all_actions = false;
2279  // for each action, roll out to get a loss
2280  while (!priv.done_with_all_actions)
2281  {
2282  priv.learn_t = priv.timesteps[tid];
2284  if (priv.done_with_all_actions)
2285  break;
2286 
2287  skipped_all_actions = false;
2288  reset_search_structure(priv);
2289 
2290  priv.state = LEARN;
2291  priv.learn_t = priv.timesteps[tid];
2292  cdbg << "-------------------------------------------------------------------------------------" << endl;
2293  cdbg << "learn_t = " << priv.learn_t << ", learn_a_idx = " << priv.learn_a_idx << endl;
2294  // cdbg_print_array("priv.active_known[learn_t]", priv.active_known[priv.learn_t]);
2295  run_task(sch, ec_seq);
2296  // cerr_print_array("in GENER, learn_allowed_actions", priv.learn_allowed_actions);
2297  float this_loss = priv.learn_loss;
2299  priv.is_ldf ? (uint32_t)(priv.learn_a_idx - 1) : (uint32_t)priv.learn_a_idx, this_loss);
2300  // (priv.learn_allowed_actions.size() > 0) ?
2301  // priv.learn_allowed_actions[priv.learn_a_idx-1] : priv.is_ldf ? (priv.learn_a_idx-1) :
2302  // (priv.learn_a_idx),
2303  // priv.learn_loss);
2304  }
2305  if (priv.active_csoaa_verify > 0.)
2307  priv.learn_losses.cs, priv.active_known[priv.learn_t], ec_seq[0]->example_counter, priv.active_csoaa_verify);
2308 
// no rollout happened for this timestep: run the task once more in LEARN state with
// force_setup_ec_ref set
2309  if (skipped_all_actions)
2310  {
2311  reset_search_structure(priv);
2312  priv.state = LEARN;
2313  priv.learn_t = priv.timesteps[tid];
2314  priv.force_setup_ec_ref = true;
2315  cdbg << "<<<<<" << endl;
2316  cdbg << "skipped all actions; learn_t = " << priv.learn_t << ", learn_a_idx = " << priv.learn_a_idx << endl;
2317  run_task(sch, ec_seq); // TODO: i guess we can break out of this early
2318  cdbg << ">>>>>" << endl;
2319  }
2320  else
2321  cdbg << "didn't skip all actions" << endl;
2322 
2323  // now we can make a training example
// when the allowed actions were restricted, relabel each loss with the real action id
// (assumes learn_losses has at least learn_allowed_actions.size() entries -- TODO confirm)
2324  if (priv.learn_allowed_actions.size() > 0)
2325  {
2326  for (size_t i = 0; i < priv.learn_allowed_actions.size(); i++)
2327  {
2328  priv.learn_losses.cs.costs[i].class_index = priv.learn_allowed_actions[i];
2329  }
2330  }
2331  // float min_loss = 0.;
2332  // if (priv.metatask)
2333  // for (size_t aid=0; aid<priv.memo_foreach_action[tid]->size(); aid++)
2334  // min_loss = std::min(min_loss, priv.memo_foreach_action[tid]->get(aid).cost);
2335  cdbg << "priv.learn_losses = [";
2336  for (auto& wc : priv.learn_losses.cs.costs) cdbg << " " << wc.class_index << ":" << wc.x;
2337  cdbg << " ]" << endl;
2338  cdbg << "gte" << endl;
2339  generate_training_example(priv, priv.learn_losses, 1., true); // , min_loss); // TODO: weight
// delete the labels of the example copies (cs labels in LDF mode, multiclass labels otherwise)
2340  if (!priv.examples_dont_change)
2341  for (size_t n = 0; n < priv.learn_ec_copy.size(); n++)
2342  {
2343  if (sch.priv->is_ldf)
2344  CS::cs_label.delete_label(&priv.learn_ec_copy[n].l.cs);
2345  else
2346  MC::mc_label.delete_label(&priv.learn_ec_copy[n].l.multi);
2347  }
2348  if (priv.cb_learner)
2349  priv.learn_losses.cb.costs.clear();
2350  else
2351  priv.learn_losses.cs.costs.clear();
2352  }
2353 
// with active CSOAA, save a predictor snapshot each time num_calls_to_run crosses a multiple of
// save_every_k_runs
2354  if (priv.active_csoaa && (priv.save_every_k_runs > 1))
2355  {
2356  size_t prev_num = priv.num_calls_to_run_previous / priv.save_every_k_runs;
2357  size_t this_num = priv.num_calls_to_run / priv.save_every_k_runs;
2358  if (this_num > prev_num)
2359  save_predictor(all, all.final_regressor_name, this_num);
2361  }
2362 }
2363 
2365 {
// Disable auto-conditioning, with a warning on stderr, when it cannot contribute anything:
// either there is no history (history_length == 0) or the conditioning feature value is 0.
2366  if (priv.auto_condition_features)
2367  {
2368  // turn off auto-condition if it's irrelevant
2369  if ((priv.history_length == 0) || (priv.acset.feature_value == 0.f))
2370  {
2371  std::cerr << "warning: turning off AUTO_CONDITION_FEATURES because settings make it useless" << endl;
2372  priv.auto_condition_features = false;
2373  }
2374  }
2375 }
2376 
// Entry point for one multiline example `ec_seq` (predict-only when is_learn is false).
// Empty sequences are ignored. Otherwise it:
//   * records the feature offset and base learner, and disables useless auto-conditioning;
//   * determines whether any example is a test example or a holdout example;
//   * runs the task's setup hook and, when progress output may be printed, builds the truth string;
//   * runs train_single_example<is_learn> with neighbor features temporarily added;
//   * runs the task's takedown hook.
2377 template <bool is_learn>
2379 {
2380  if (ec_seq.size() == 0)
2381  return; // nothing to do :)
2382 
2383  bool is_test_ex = false;
2384  bool is_holdout_ex = false;
2385 
2386  search_private& priv = *sch.priv;
2387  priv.offset = ec_seq[0]->ft_offset;
2388  priv.base_learner = &base;
2389 
2390  adjust_auto_condition(priv);
2391  priv.read_example_last_id = ec_seq[ec_seq.size() - 1]->example_counter;
2392 
2393  // hit_new_pass true would have already triggered a printout
2394  // finish_example(multi_ex). so we can reset hit_new_pass here
2395  priv.hit_new_pass = false;
2396 
// the whole sequence counts as test (or holdout) if any single example is
2397  for (size_t i = 0; i < ec_seq.size(); i++)
2398  {
2399  is_test_ex |= priv.label_is_test(ec_seq[i]->l);
2400  is_holdout_ex |= ec_seq[i]->test_only;
2401  if (is_test_ex && is_holdout_ex)
2402  break;
2403  }
2404 
2405  if (priv.task->run_setup)
2406  priv.task->run_setup(sch, ec_seq);
2407 
2408  // if we're going to have to print to the screen, generate the "truth" string
2409  cdbg << "======================================== GET TRUTH STRING (" << priv.current_policy << ","
2410  << priv.read_example_last_pass << ") ========================================" << endl;
2411  if (might_print_update(*priv.all))
2412  {
2413  if (is_test_ex)
2414  priv.truth_string->str("**test**");
2415  else
2416  {
2418  priv.state = GET_TRUTH_STRING;
2419  priv.should_produce_string = true;
2420  priv.truth_string->str("");
2421  run_task(sch, ec_seq);
2422  }
2423  }
2424 
2425  add_neighbor_features(priv, ec_seq);
2426  train_single_example<is_learn>(sch, is_test_ex, is_holdout_ex, ec_seq);
2427  del_neighbor_features(priv, ec_seq);
2428 
2429  if (priv.task->run_takedown)
2430  priv.task->run_takedown(sch, ec_seq);
2431 }
2432 
// Called at the end of a pass over the data. It:
//   * flags the new pass and increments the pass counters;
//   * after passes_per_policy passes, moves on to the next policy (only while training);
//   * records the current policy count in the "search_trained_nb_policies" option so that it is
//     saved to the regressor file later.
2433 void end_pass(search& sch)
2434 {
2435  search_private& priv = *sch.priv;
2436  vw* all = priv.all;
2437  priv.hit_new_pass = true;
2438  priv.read_example_last_pass++;
2439  priv.passes_since_new_policy++;
2440 
2441  if (priv.passes_since_new_policy >= priv.passes_per_policy)
2442  {
2443  priv.passes_since_new_policy = 0;
2444  if (all->training)
2445  priv.current_policy++;
2446  if (priv.current_policy > priv.total_number_of_policies)
2447  {
2448  std::cerr << "internal error (bug): too many policies; not advancing" << endl;
2450  }
2451  // reset search_trained_nb_policies in options_from_file so it is saved to regressor file later
2452  // TODO work out a better system to update state that will be saved in the model.
2453  all->options->replace("search_trained_nb_policies", std::to_string(priv.current_policy));
2454  all->options->get_typed_option<uint32_t>("search_trained_nb_policies").value(priv.current_policy);
2455  }
2456 }
2457 
2458 void finish_multiline_example(vw& all, search& sch, multi_ex& ec_seq)
2459 {
2460  print_update(*sch.priv);
2461  VW::finish_example(all, ec_seq);
2462 }
2463 
2465 {
// When training, write the policy counts into the option set ("search_trained_nb_policies" and
// "search_total_nb_policies"), overriding any value loaded from an earlier predictor.
// A policy that has started but not finished its passes (passes_since_new_policy != 0) is
// counted as trained, hence current_policy + 1.
2466  search_private& priv = *sch.priv;
2467  vw* all = priv.all;
2468 
2469  if (all->training)
2470  {
2471  // TODO work out a better system to update state that will be saved in the model.
2472  // Dig out option and change it in case we already loaded a predictor which had a value stored for
2473  // --search_trained_nb_policies
2474  auto val = (priv.passes_since_new_policy == 0) ? priv.current_policy : (priv.current_policy + 1);
2475  all->options->replace("search_trained_nb_policies", std::to_string(val));
2476  all->options->get_typed_option<uint32_t>("search_trained_nb_policies").value(val);
2477  // Dig out option and change it in case we already loaded a predictor which had a value stored for
2478  // --search_total_nb_policies
2479  all->options->replace("search_total_nb_policies", std::to_string(priv.total_number_of_policies));
2480  all->options->get_typed_option<uint32_t>("search_total_nb_policies").value(priv.total_number_of_policies);
2481  }
2482 }
2483 
2485 
// Initialize a freshly allocated search_private. It sets the non-zero defaults (one learner,
// INITIALIZE state, MIX_PER_ROLL roll-in, a single policy, auto-condition settings, cache
// defaults) and allocates the prediction, truth and "bad" string streams.
// priv starts zero-initialized (see the comment on the first line), so the non-trivial C++
// members at the end are constructed with placement new instead of being assigned.
// Presumably search_private is calloc-allocated rather than constructed -- confirm.
2486 void search_initialize(vw* all, search& sch)
2487 {
2488  search_private& priv = *sch.priv; // priv is zero initialized by default
2489  priv.all = all;
2490  priv._random_state = all->get_random_state();
2491 
2492  priv.active_csoaa = false;
2494 
2495  priv.num_learners = 1;
2496  priv.state = INITIALIZE;
2497  priv.mix_per_roll_policy = -2;
2498 
2499  priv.pred_string = new std::stringstream();
2500  priv.truth_string = new std::stringstream();
2501  priv.bad_string_stream = new std::stringstream();
// badbit is set, so formatted output written to this stream is discarded
2502  priv.bad_string_stream->clear(priv.bad_string_stream->badbit);
2503 
2505  priv.rollin_method = MIX_PER_ROLL;
2506 
2507  priv.allow_current_policy = true;
2508  priv.adaptive_beta = true;
2509 
2510  priv.total_number_of_policies = 1;
2511 
2512  priv.acset.max_bias_ngram_length = 1;
2513 
2514  priv.acset.feature_value = 1.;
2515 
// cache misses return action (action)-1 with cost 0
2516  scored_action sa((action)-1, 0.);
2518  priv.cache_hash_map.set_default_value(sa);
2519  priv.cache_hash_map.set_equivalent(cached_item_equivalent);
2520 
2521  sch.task_data = nullptr;
2522 
2523  priv.active_uncertainty = v_init<std::pair<float, size_t>>();
2524  priv.active_known = v_init<v_array<std::pair<CS::wclass&, bool>>>();
2525 
2527 
// construct non-trivial members in the zeroed storage; the stringstream starts from a copy of
// rawOutputString
2528  new (&priv.rawOutputString) std::string();
2529  priv.rawOutputStringStream = new std::stringstream(priv.rawOutputString);
2530  new (&priv.test_action_sequence) std::vector<action>();
2531  new (&priv.dat_new_feature_audit_ss) std::stringstream();
2532 }
2533 
// Validate a float option in place: if `v` lies outside [lo, hi] -- or is NaN -- print the
// warning `str` to stderr and reset `v` to `def`. Values inside [lo, hi] (bounds included)
// are left unchanged.
void ensure_param(float& v, float lo, float hi, float def, const char* str)
{
  // Phrased as "not inside the interval" so that NaN, for which every comparison is false,
  // is rejected as well; the test `(v < lo) || (v > hi)` silently accepted NaN.
  if (!((v >= lo) && (v <= hi)))
  {
    std::cerr << str << std::endl;
    v = def;
  }
}
2542 
2544 {
// Register and parse the "Search Auto-conditioning Options", storing the parsed values
// directly into `acset`:
//   * max bias n-gram length (default 1);
//   * max quad n-gram length (default 0);
//   * conditioning feature value (default 1.f);
//   * the use-passthrough-representation flag.
// All four options are marked keep().
2545  option_group_definition new_options("Search Auto-conditioning Options");
2546  new_options.add(make_option("search_max_bias_ngram_length", acset.max_bias_ngram_length)
2547  .keep()
2548  .default_value(1)
2549  .help("add a \"bias\" feature for each ngram up to and including this length. eg., if it's 1 "
2550  "(default), then you get a single feature for each conditional"));
2551  new_options.add(make_option("search_max_quad_ngram_length", acset.max_quad_ngram_length)
2552  .keep()
2553  .default_value(0)
2554  .help("add bias *times* input features for each ngram up to and including this length (def: 0)"));
2555  new_options.add(make_option("search_condition_feature_value", acset.feature_value)
2556  .keep()
2557  .default_value(1.f)
2558  .help("how much weight should the conditional features get? (def: 1.)"));
2559  new_options.add(make_option("search_use_passthrough_repr", acset.use_passthrough_repr)
2560  .keep()
2561  .help("should we use lower-level reduction _internal state_ as additional features? (def: no)"));
2562  all.options->add_and_parse(new_options);
2563 }
2564 
2566 {
// Teardown hook. When active CSOAA was used, report the number of task runs on stderr.
// Then invoke the task's finish hook and, if present, the metatask's finish hook.
2567  search_private& priv = *sch.priv;
2568  cdbg << "search_finish" << endl;
2569 
2570  if (priv.active_csoaa)
2571  std::cerr << "search calls to run = " << priv.num_calls_to_run << endl;
2572 
2573  if (priv.task->finish)
2574  priv.task->finish(sch);
2575  if (priv.metatask && priv.metatask->finish)
2576  priv.metatask->finish(sch);
2577 }
2578 
2580 {
2581  FILE* f = fopen(filename, "r");
2582  if (f == nullptr)
2583  THROW("error: could not read file " << filename << " (" << strerror(errno)
2584  << "); assuming all transitions are valid");
2585 
2586  bool* bg = calloc_or_throw<bool>(((size_t)(A + 1)) * (A + 1));
2587  int rd, from, to, count = 0;
2588  while ((rd = fscanf(f, "%d:%d", &from, &to)) > 0)
2589  {
2590  if ((from < 0) || (from > (int)A))
2591  {
2592  std::cerr << "warning: ignoring transition from " << from << " because it's out of the range [0," << A << "]"
2593  << endl;
2594  }
2595  if ((to < 0) || (to > (int)A))
2596  {
2597  std::cerr << "warning: ignoring transition to " << to << " because it's out of the range [0," << A << "]" << endl;
2598  }
2599  bg[from * (A + 1) + to] = true;
2600  count++;
2601  }
2602  fclose(f);
2603 
2604  v_array<CS::label> allowed = v_init<CS::label>();
2605 
2606  for (size_t from = 0; from < A; from++)
2607  {
2608  v_array<CS::wclass> costs = v_init<CS::wclass>();
2609 
2610  for (size_t to = 0; to < A; to++)
2611  if (bg[from * (A + 1) + to])
2612  {
2613  CS::wclass c = {FLT_MAX, (action)to, 0., 0.};
2614  costs.push_back(c);
2615  }
2616 
2617  CS::label ld = {costs};
2618  allowed.push_back(ld);
2619  }
2620  free(bg);
2621 
2622  std::cerr << "read " << count << " allowed transitions from " << filename << endl;
2623 
2624  return allowed;
2625 }
2626 
2627 void parse_neighbor_features(std::string& nf_string, search& sch)
2628 {
2629  search_private& priv = *sch.priv;
2630  priv.neighbor_features.clear();
2631  size_t len = nf_string.length();
2632  if (len == 0)
2633  return;
2634 
2635  char* cstr = new char[len + 1];
2636  strcpy(cstr, nf_string.c_str());
2637 
2638  char* p = strtok(cstr, ",");
2639  std::vector<substring> cmd;
2640  while (p != 0)
2641  {
2642  cmd.clear();
2643  substring me = {p, p + strlen(p)};
2644  tokenize(':', me, cmd, true);
2645 
2646  int32_t posn = 0;
2647  char ns = ' ';
2648  if (cmd.size() == 1)
2649  {
2650  posn = int_of_substring(cmd[0]);
2651  ns = ' ';
2652  }
2653  else if (cmd.size() == 2)
2654  {
2655  posn = int_of_substring(cmd[0]);
2656  ns = (cmd[1].end > cmd[1].begin) ? cmd[1].begin[0] : ' ';
2657  }
2658  else
2659  {
2660  std::cerr << "warning: ignoring malformed neighbor specification: '" << p << "'" << endl;
2661  }
2662  int32_t enc = (posn << 24) | (ns & 0xFF);
2663  priv.neighbor_features.push_back(enc);
2664 
2665  p = strtok(nullptr, ",");
2666  }
2667 
2668  delete[] cstr;
2669 }
2670 
2671 base_learner* setup(options_i& options, vw& all)
2672 {
2673  free_ptr<search> sch = scoped_calloc_or_throw<search>();
2674  search_private& priv = *sch->priv;
2675  std::string task_string;
2676  std::string metatask_string;
2677  std::string interpolation_string = "data";
2678  std::string neighbor_features_string;
2679  std::string rollout_string = "mix_per_state";
2680  std::string rollin_string = "mix_per_state";
2681 
2682  uint32_t search_trained_nb_policies;
2683  std::string search_allowed_transitions;
2684 
2685  priv.A = 1;
2686  option_group_definition new_options("Search options");
2687  new_options.add(
2688  make_option("search", priv.A).keep().help("Use learning to search, argument=maximum action id or 0 for LDF"));
2689  new_options.add(make_option("search_task", task_string)
2690  .keep()
2691  .help("the search task (use \"--search_task list\" to get a list of available tasks)"));
2692  new_options.add(
2693  make_option("search_metatask", metatask_string)
2694  .keep()
2695  .help("the search metatask (use \"--search_metatask list\" to get a list of available metatasks)"));
2696  new_options.add(make_option("search_interpolation", interpolation_string)
2697  .keep()
2698  .help("at what level should interpolation happen? [*data|policy]"));
2699  new_options.add(
2700  make_option("search_rollout", rollout_string)
2701  .help("how should rollouts be executed? [policy|oracle|*mix_per_state|mix_per_roll|none]"));
2702  new_options.add(make_option("search_rollin", rollin_string)
2703  .help("how should past trajectories be generated? [policy|oracle|*mix_per_state|mix_per_roll]"));
2704  new_options.add(make_option("search_passes_per_policy", priv.passes_per_policy)
2705  .default_value(1)
2706  .help("number of passes per policy (only valid for search_interpolation=policy)"));
2707  new_options.add(make_option("search_beta", priv.beta)
2708  .default_value(0.5f)
2709  .help("interpolation rate for policies (only valid for search_interpolation=policy)"));
2710  new_options.add(make_option("search_alpha", priv.alpha)
2711  .default_value(1e-10f)
2712  .help("annealed beta = 1-(1-alpha)^t (only valid for search_interpolation=data)"));
2713  new_options.add(make_option("search_total_nb_policies", priv.total_number_of_policies)
2714  .help("if we are going to train the policies through multiple separate calls to vw, we need to "
2715  "specify this parameter and tell vw how many policies are eventually going to be trained"));
2716  new_options.add(make_option("search_trained_nb_policies", search_trained_nb_policies)
2717  .help("the number of trained policies in a file"));
2718  new_options.add(make_option("search_allowed_transitions", search_allowed_transitions)
2719  .help("read file of allowed transitions [def: all transitions are allowed]"));
2720  new_options.add(make_option("search_subsample_time", priv.subsample_timesteps)
2721  .help("instead of training at all timesteps, use a subset. if value in (0,1), train on a random "
2722  "v%. if v>=1, train on precisely v steps per example, if v<=-1, use active learning"));
2723  new_options.add(
2724  make_option("search_neighbor_features", neighbor_features_string)
2725  .keep()
2726  .help("copy features from neighboring lines. argument looks like: '-1:a,+2' meaning copy previous line "
2727  "namespace a and next next line from namespace _unnamed_, where ',' separates them"));
2728  new_options.add(make_option("search_rollout_num_steps", priv.rollout_num_steps)
2729  .help("how many calls of \"loss\" before we stop really predicting on rollouts and switch to "
2730  "oracle (default means \"infinite\")"));
2731  new_options.add(make_option("search_history_length", priv.history_length)
2732  .keep()
2733  .default_value(1)
2734  .help("some tasks allow you to specify how much history their depend on; specify that here"));
2735  new_options.add(make_option("search_no_caching", priv.no_caching)
2736  .help("turn off the built-in caching ability (makes things slower, but technically more safe)"));
2737  new_options.add(
2738  make_option("search_xv", priv.xv).help("train two separate policies, alternating prediction/learning"));
2739  new_options.add(make_option("search_perturb_oracle", priv.perturb_oracle)
2740  .default_value(0.f)
2741  .help("perturb the oracle on rollin with this probability"));
2742  new_options.add(make_option("search_linear_ordering", priv.linear_ordering)
2743  .help("insist on generating examples in linear order (def: hoopla permutation)"));
2744  new_options.add(make_option("search_active_verify", priv.active_csoaa_verify)
2745  .help("verify that active learning is doing the right thing (arg = multiplier, should be = "
2746  "cost_range * range_c)"));
2747  new_options.add(make_option("search_save_every_k_runs", priv.save_every_k_runs).help("save model every k runs"));
2748  options.add_and_parse(new_options);
2749 
2750  if (!options.was_supplied("search_task"))
2751  return nullptr;
2752 
2753  search_initialize(&all, *sch.get());
2754 
2755  parse_neighbor_features(neighbor_features_string, *sch.get());
2756 
2757  if (interpolation_string.compare("data") == 0) // run as dagger
2758  {
2759  priv.adaptive_beta = true;
2760  priv.allow_current_policy = true;
2761  priv.passes_per_policy = all.numpasses;
2762  if (priv.current_policy > 1)
2763  priv.current_policy = 1;
2764  }
2765  else if (interpolation_string.compare("policy") == 0)
2766  ;
2767  else
2768  THROW("error: --search_interpolation must be 'data' or 'policy'");
2769 
2770  if ((rollout_string.compare("policy") == 0) || (rollout_string.compare("learn") == 0))
2771  priv.rollout_method = POLICY;
2772  else if ((rollout_string.compare("oracle") == 0) || (rollout_string.compare("ref") == 0))
2773  priv.rollout_method = ORACLE;
2774  else if ((rollout_string.compare("mix_per_state") == 0))
2775  priv.rollout_method = MIX_PER_STATE;
2776  else if ((rollout_string.compare("mix_per_roll") == 0) || (rollout_string.compare("mix") == 0))
2777  priv.rollout_method = MIX_PER_ROLL;
2778  else if ((rollout_string.compare("none") == 0))
2779  {
2780  priv.rollout_method = NO_ROLLOUT;
2781  priv.no_caching = true;
2782  }
2783  else
2784  THROW("error: --search_rollout must be 'learn', 'ref', 'mix', 'mix_per_state' or 'none'");
2785 
2786  if ((rollin_string.compare("policy") == 0) || (rollin_string.compare("learn") == 0))
2787  priv.rollin_method = POLICY;
2788  else if ((rollin_string.compare("oracle") == 0) || (rollin_string.compare("ref") == 0))
2789  priv.rollin_method = ORACLE;
2790  else if ((rollin_string.compare("mix_per_state") == 0))
2791  priv.rollin_method = MIX_PER_STATE;
2792  else if ((rollin_string.compare("mix_per_roll") == 0) || (rollin_string.compare("mix") == 0))
2793  priv.rollin_method = MIX_PER_ROLL;
2794  else
2795  THROW("error: --search_rollin must be 'learn', 'ref', 'mix' or 'mix_per_state'");
2796 
2797  // check if the base learner is contextual bandit, in which case, we dont rollout all actions.
2798  priv.allowed_actions_cache = &calloc_or_throw<polylabel>();
2799  if (options.was_supplied("cb"))
2800  {
2801  priv.cb_learner = true;
2802  CB::cb_label.default_label(priv.allowed_actions_cache);
2803  priv.learn_losses.cb.costs = v_init<CB::cb_class>();
2804  priv.gte_label.cb.costs = v_init<CB::cb_class>();
2805  }
2806  else
2807  {
2808  priv.cb_learner = false;
2809  CS::cs_label.default_label(priv.allowed_actions_cache);
2810  priv.learn_losses.cs.costs = v_init<CS::wclass>();
2811  priv.gte_label.cs.costs = v_init<CS::wclass>();
2812  }
2813 
2814  ensure_param(priv.beta, 0.0, 1.0, 0.5, "warning: search_beta must be in (0,1); resetting to 0.5");
2815  ensure_param(priv.alpha, 0.0, 1.0, 1e-10f, "warning: search_alpha must be in (0,1); resetting to 1e-10");
2816 
2817  priv.num_calls_to_run = 0;
2818 
2819  // compute total number of policies we will have at end of training
2820  // we add current_policy for cases where we start from an initial set of policies loaded through -i option
2821  uint32_t tmp_number_of_policies = priv.current_policy;
2822  if (all.training)
2823  tmp_number_of_policies += (int)ceil(((float)all.numpasses) / ((float)priv.passes_per_policy));
2824 
2825  // the user might have specified the number of policies that will eventually be trained through multiple vw calls,
2826  // so only set total_number_of_policies to computed value if it is larger
2827  cdbg << "current_policy=" << priv.current_policy << " tmp_number_of_policies=" << tmp_number_of_policies
2828  << " total_number_of_policies=" << priv.total_number_of_policies << endl;
2829  if (tmp_number_of_policies > priv.total_number_of_policies)
2830  {
2831  priv.total_number_of_policies = tmp_number_of_policies;
2832  if (priv.current_policy >
2833  0) // we loaded a file but total number of policies didn't match what is needed for training
2834  std::cerr << "warning: you're attempting to train more classifiers than was allocated initially. Likely to cause "
2835  "bad performance."
2836  << endl;
2837  }
2838 
2839  // current policy currently points to a new policy we would train
2840  // if we are not training and loaded a bunch of policies for testing, we need to subtract 1 from current policy
2841  // so that we only use those loaded when testing (as run_prediction is called with allow_current to true)
2842  if (!all.training && priv.current_policy > 0)
2843  priv.current_policy--;
2844 
2845  all.options->replace("search_trained_nb_policies", std::to_string(priv.current_policy));
2846  all.options->get_typed_option<uint32_t>("search_trained_nb_policies").value(priv.current_policy);
2847 
2848  all.options->replace("search_total_nb_policies", std::to_string(priv.total_number_of_policies));
2849  all.options->get_typed_option<uint32_t>("search_total_nb_policies").value(priv.total_number_of_policies);
2850 
2851  cdbg << "search current_policy = " << priv.current_policy
2852  << " total_number_of_policies = " << priv.total_number_of_policies << endl;
2853 
2854  if (task_string.compare("list") == 0)
2855  {
2856  std::cerr << endl << "available search tasks:" << endl;
2857  for (search_task** mytask = all_tasks; *mytask != nullptr; mytask++)
2858  std::cerr << " " << (*mytask)->task_name << endl;
2859  std::cerr << endl;
2860  exit(0);
2861  }
2862  if (metatask_string.compare("list") == 0)
2863  {
2864  std::cerr << endl << "available search metatasks:" << endl;
2865  for (search_metatask** mytask = all_metatasks; *mytask != nullptr; mytask++)
2866  std::cerr << " " << (*mytask)->metatask_name << endl;
2867  std::cerr << endl;
2868  exit(0);
2869  }
2870  for (search_task** mytask = all_tasks; *mytask != nullptr; mytask++)
2871  if (task_string.compare((*mytask)->task_name) == 0)
2872  {
2873  priv.task = *mytask;
2874  sch->task_name = (*mytask)->task_name;
2875  break;
2876  }
2877  if (priv.task == nullptr)
2878  {
2879  if (!options.was_supplied("help"))
2880  THROW("fail: unknown task for --search_task '" << task_string << "'; use --search_task list to get a list");
2881  }
2882  priv.metatask = nullptr;
2883  for (search_metatask** mytask = all_metatasks; *mytask != nullptr; mytask++)
2884  if (metatask_string.compare((*mytask)->metatask_name) == 0)
2885  {
2886  priv.metatask = *mytask;
2887  sch->metatask_name = (*mytask)->metatask_name;
2888  break;
2889  }
2890  all.p->emptylines_separate_examples = true;
2891 
2892  if (!options.was_supplied("csoaa") && !options.was_supplied("cs_active") && !options.was_supplied("csoaa_ldf") &&
2893  !options.was_supplied("wap_ldf") && !options.was_supplied("cb"))
2894  {
2895  options.insert("csoaa", std::to_string(priv.A));
2896  }
2897 
2898  priv.active_csoaa = options.was_supplied("cs_active");
2899  priv.active_csoaa_verify = -1.;
2900  if (options.was_supplied("search_active_verify"))
2901  if (!priv.active_csoaa)
2902  THROW("cannot use --search_active_verify without using --cs_active");
2903 
2904  cdbg << "active_csoaa = " << priv.active_csoaa << ", active_csoaa_verify = " << priv.active_csoaa_verify << endl;
2905 
2906  base_learner* base = setup_base(*all.options, all);
2907 
2908  // default to OAA labels unless the task wants to override this (which they can do in initialize)
2909  all.p->lp = MC::mc_label;
2910  all.label_type = label_type::mc;
2911  if (priv.task && priv.task->initialize)
2912  priv.task->initialize(*sch.get(), priv.A, options);
2913  if (priv.metatask && priv.metatask->initialize)
2914  priv.metatask->initialize(*sch.get(), priv.A, options);
2915  priv.meta_t = 0;
2916 
2917  if (options.was_supplied("search_allowed_transitions"))
2918  read_allowed_transitions((action)priv.A, search_allowed_transitions.c_str());
2919 
2920  // set up auto-history (used to only do this if AUTO_CONDITION_FEATURES was on, but that doesn't work for hooktask)
2921  handle_condition_options(all, priv.acset);
2922 
2923  if (!priv.allow_current_policy) // if we're not dagger
2924  all.check_holdout_every_n_passes = priv.passes_per_policy;
2925 
2926  all.searchstr = sch.get();
2927 
2928  priv.start_clock_time = clock();
2929 
2930  if (priv.xv)
2931  priv.num_learners *= 3;
2932 
2933  cdbg << "num_learners = " << priv.num_learners << endl;
2934 
2935  learner<search, multi_ex>& l = init_learner(sch, make_base(*base), do_actual_learning<true>,
2936  do_actual_learning<false>, priv.total_number_of_policies * priv.num_learners);
2941  return make_base(l);
2942 }
2943 
2944 float action_hamming_loss(action a, const action* A, size_t sz)
2945 {
2946  if (sz == 0)
2947  return 0.; // latent variables have zero loss
2948  for (size_t i = 0; i < sz; i++)
2949  if (a == A[i])
2950  return 0.;
2951  return 1.;
2952 }
2953 
2954 float action_cost_loss(action a, const action* act, const float* costs, size_t sz)
2955 {
2956  if (act == nullptr)
2957  return costs[a - 1];
2958  for (size_t i = 0; i < sz; i++)
2959  if (act[i] == a)
2960  return costs[i];
2961  THROW("action_cost_loss got action that wasn't allowed: " << a);
2962 }
2963 
2964 // the interface:
2965 bool search::is_ldf() { return priv->is_ldf; }
2966 
2967 action search::predict(example& ec, ptag mytag, const action* oracle_actions, size_t oracle_actions_cnt,
2968  const ptag* condition_on, const char* condition_on_names, const action* allowed_actions, size_t allowed_actions_cnt,
2969  const float* allowed_actions_cost, size_t learner_id, float weight)
2970 {
2971  float a_cost = 0.;
2972  action a = search_predict(*priv, &ec, 1, mytag, oracle_actions, oracle_actions_cnt, condition_on, condition_on_names,
2973  allowed_actions, allowed_actions_cnt, allowed_actions_cost, learner_id, a_cost, weight);
2974  if (priv->state == INIT_TEST)
2975  priv->test_action_sequence.push_back(a);
2976  if (mytag != 0)
2977  {
2978  if (mytag < priv->ptag_to_action.size())
2979  {
2980  cdbg << "delete_v at " << mytag << endl;
2981  if (priv->ptag_to_action[mytag].repr != nullptr)
2982  {
2983  priv->ptag_to_action[mytag].repr->delete_v();
2984  delete priv->ptag_to_action[mytag].repr;
2985  }
2986  }
2987  if (priv->acset.use_passthrough_repr)
2988  {
2989  assert((mytag >= priv->ptag_to_action.size()) || (priv->ptag_to_action[mytag].repr == nullptr));
2990  push_at(priv->ptag_to_action, action_repr(a, &(priv->last_action_repr)), mytag);
2991  }
2992  else
2993  push_at(priv->ptag_to_action, action_repr(a, (features*)nullptr), mytag);
2994  cdbg << "push_at " << mytag << endl;
2995  }
2996  if (priv->auto_hamming_loss)
2997  loss(priv->use_action_costs ? action_cost_loss(a, allowed_actions, allowed_actions_cost, allowed_actions_cnt)
2998  : action_hamming_loss(a, oracle_actions, oracle_actions_cnt));
2999  cdbg << "predict returning " << a << endl;
3000  return a;
3001 }
3002 
3003 action search::predictLDF(example* ecs, size_t ec_cnt, ptag mytag, const action* oracle_actions,
3004  size_t oracle_actions_cnt, const ptag* condition_on, const char* condition_on_names, size_t learner_id,
3005  float weight)
3006 {
3007  float a_cost = 0.;
3008  // TODO: action costs for ldf
3009  action a = search_predict(*priv, ecs, ec_cnt, mytag, oracle_actions, oracle_actions_cnt, condition_on,
3010  condition_on_names, nullptr, 0, nullptr, learner_id, a_cost, weight);
3011  if (priv->state == INIT_TEST)
3012  priv->test_action_sequence.push_back(a);
3013 
3014  // If there is a shared example (example header), then action "1" is at index 1, but otherwise
3015  // action "1" is at index 0. Map action to its appropriate index. In particular, this fixes an
3016  // issue where the predicted action is the last, and there is no example header, causing an index
3017  // beyond the end of the array (usually resulting in a segfault at some point.)
3018  size_t action_index = a - COST_SENSITIVE::ec_is_example_header(ecs[0]) ? 0 : 1;
3019 
3020  if ((mytag != 0) && ecs[action_index].l.cs.costs.size() > 0)
3021  {
3022  if (mytag < priv->ptag_to_action.size())
3023  {
3024  cdbg << "delete_v at " << mytag << endl;
3025  if (priv->ptag_to_action[mytag].repr != nullptr)
3026  {
3027  priv->ptag_to_action[mytag].repr->delete_v();
3028  delete priv->ptag_to_action[mytag].repr;
3029  }
3030  }
3031  push_at(priv->ptag_to_action, action_repr(ecs[a].l.cs.costs[0].class_index, &(priv->last_action_repr)), mytag);
3032  }
3033  if (priv->auto_hamming_loss)
3034  loss(action_hamming_loss(a, oracle_actions, oracle_actions_cnt)); // TODO: action costs
3035  cdbg << "predict returning " << a << endl;
3036  return a;
3037 }
3038 
3039 void search::loss(float loss) { search_declare_loss(*this->priv, loss); }
3040 
3041 bool search::predictNeedsExample() { return search_predictNeedsExample(*this->priv); }
3042 
3043 std::stringstream& search::output()
3044 {
3045  if (!this->priv->should_produce_string)
3046  return *(this->priv->bad_string_stream);
3047  else if (this->priv->state == GET_TRUTH_STRING)
3048  return *(this->priv->truth_string);
3049  else
3050  return *(this->priv->pred_string);
3051 }
3052 
3053 void search::set_options(uint32_t opts)
3054 {
3055  if (this->priv->all->vw_is_main && (this->priv->state != INITIALIZE))
3056  std::cerr << "warning: task should not set options except in initialize function!" << endl;
3057  if ((opts & AUTO_CONDITION_FEATURES) != 0)
3058  this->priv->auto_condition_features = true;
3059  if ((opts & AUTO_HAMMING_LOSS) != 0)
3060  this->priv->auto_hamming_loss = true;
3061  if ((opts & EXAMPLES_DONT_CHANGE) != 0)
3062  this->priv->examples_dont_change = true;
3063  if ((opts & IS_LDF) != 0)
3064  this->priv->is_ldf = true;
3065  if ((opts & NO_CACHING) != 0)
3066  this->priv->no_caching = true;
3067  if ((opts & ACTION_COSTS) != 0)
3068  this->priv->use_action_costs = true;
3069 
3070  if (this->priv->is_ldf && this->priv->use_action_costs)
3071  THROW("using LDF and actions costs is not yet implemented; turn off action costs"); // TODO fix
3072 
3073  if (this->priv->use_action_costs && (this->priv->rollout_method != NO_ROLLOUT))
3074  std::cerr
3075  << "warning: task is designed to use rollout costs, but this only works when --search_rollout none is specified"
3076  << endl;
3077 }
3078 
3079 void search::set_label_parser(label_parser& lp, bool (*is_test)(polylabel&))
3080 {
3081  if (this->priv->all->vw_is_main && (this->priv->state != INITIALIZE))
3082  std::cerr << "warning: task should not set label parser except in initialize function!" << endl;
3083  this->priv->all->p->lp = lp;
3084  this->priv->all->p->lp.test_label = (bool (*)(void*))is_test;
3085  this->priv->label_is_test = is_test;
3086 }
3087 
3088 void search::get_test_action_sequence(std::vector<action>& V)
3089 {
3090  V.clear();
3091  for (size_t i = 0; i < this->priv->test_action_sequence.size(); i++) V.push_back(this->priv->test_action_sequence[i]);
3092 }
3093 
3094 void search::set_num_learners(size_t num_learners) { this->priv->num_learners = num_learners; }
3095 
3096 uint64_t search::get_mask() { return this->priv->all->weights.mask(); }
3097 size_t search::get_stride_shift() { return this->priv->all->weights.stride_shift(); }
3098 uint32_t search::get_history_length() { return (uint32_t)this->priv->history_length; }
3099 
3100 std::string search::pretty_label(action a)
3101 {
3102  if (this->priv->all->sd->ldict)
3103  {
3104  substring ss = this->priv->all->sd->ldict->get(a);
3105  return std::string(ss.begin, ss.end - ss.begin);
3106  }
3107  else
3108  {
3109  std::ostringstream os;
3110  os << a;
3111  return os.str();
3112  }
3113 }
3114 
3115 vw& search::get_vw_pointer_unsafe() { return *this->priv->all; }
3116 void search::set_force_oracle(bool force) { this->priv->force_oracle = force; }
3117 
3118 // predictor implementation
3119 predictor::predictor(search& sch, ptag my_tag)
3120  : is_ldf(false)
3121  , my_tag(my_tag)
3122  , ec(nullptr)
3123  , ec_cnt(0)
3124  , ec_alloced(false)
3125  , weight(1.)
3126  , oracle_is_pointer(false)
3127  , allowed_is_pointer(false)
3128  , allowed_cost_is_pointer(false)
3129  , learner_id(0)
3130  , sch(sch)
3131 {
3132  oracle_actions = v_init<action>();
3133  condition_on_tags = v_init<ptag>();
3134  condition_on_names = v_init<char>();
3135  allowed_actions = v_init<action>();
3136  allowed_actions_cost = v_init<float>();
3137 }
3138 
3140 {
3141  if (ec_alloced)
3142  {
3143  if (is_ldf)
3144  for (size_t i = 0; i < ec_cnt; i++) VW::dealloc_example(CS::cs_label.delete_label, ec[i]);
3145  else
3146  VW::dealloc_example(nullptr, *ec);
3147  free(ec);
3148  }
3149 }
3150 
3152 {
3153  if (!oracle_is_pointer)
3155  if (!allowed_is_pointer)
3159  free_ec();
3162 }
3164 {
3165  this->erase_oracles();
3166  this->erase_alloweds();
3169  free_ec();
3170  return *this;
3171 }
3172 
3174 {
3175  free_ec();
3176  is_ldf = false;
3177  ec = &input_example;
3178  ec_cnt = 1;
3179  ec_alloced = false;
3180  return *this;
3181 }
3182 
3183 predictor& predictor::set_input(example* input_example, size_t input_length)
3184 {
3185  free_ec();
3186  is_ldf = true;
3187  ec = input_example;
3188  ec_cnt = input_length;
3189  ec_alloced = false;
3190  return *this;
3191 }
3192 
3193 void predictor::set_input_length(size_t input_length)
3194 {
3195  is_ldf = true;
3196  if (ec_alloced)
3197  {
3198  example* temp = (example*)realloc(ec, input_length * sizeof(example));
3199  if (temp != nullptr)
3200  ec = temp;
3201  else
3202  THROW("realloc failed in search.cc");
3203  }
3204  else
3205  ec = calloc_or_throw<example>(input_length);
3206  ec_cnt = input_length;
3207  ec_alloced = true;
3208 }
3209 void predictor::set_input_at(size_t posn, example& ex)
3210 {
3211  if (!ec_alloced)
3212  THROW("call to set_input_at without previous call to set_input_length");
3213 
3214  if (posn >= ec_cnt)
3215  THROW("call to set_input_at with too large a position: posn (" << posn << ") >= ec_cnt(" << ec_cnt << ")");
3216 
3218  false, ec + posn, &ex, CS::cs_label.label_size, CS::cs_label.copy_label); // TODO: the false is "audit"
3219 }
3220 
3221 template <class T>
3222 void predictor::make_new_pointer(v_array<T>& A, size_t new_size)
3223 {
3224  size_t old_size = A.size();
3225  T* old_pointer = A.begin();
3226  A.begin() = calloc_or_throw<T>(new_size);
3227  A.end() = A.begin() + new_size;
3228  A.end_array = A.end();
3229  memcpy(A.begin(), old_pointer, old_size * sizeof(T));
3230 }
3231 
3232 template <class T>
3233 predictor& predictor::add_to(v_array<T>& A, bool& A_is_ptr, T a, bool clear_first)
3234 {
3235  if (A_is_ptr) // we need to make our own memory
3236  {
3237  if (clear_first)
3238  A.end() = A.begin();
3239  size_t new_size = clear_first ? 1 : (A.size() + 1);
3240  make_new_pointer<T>(A, new_size);
3241  A_is_ptr = false;
3242  A[new_size - 1] = a;
3243  }
3244  else // we've already allocated our own memory
3245  {
3246  if (clear_first)
3247  A.clear();
3248  A.push_back(a);
3249  }
3250  return *this;
3251 }
3252 
3253 template <class T>
3254 predictor& predictor::add_to(v_array<T>& A, bool& A_is_ptr, T* a, size_t count, bool clear_first)
3255 {
3256  size_t old_size = A.size();
3257  if (old_size > 0)
3258  {
3259  if (A_is_ptr) // we need to make our own memory
3260  {
3261  if (clear_first)
3262  {
3263  A.end() = A.begin();
3264  old_size = 0;
3265  }
3266  size_t new_size = old_size + count;
3267  make_new_pointer<T>(A, new_size);
3268  A_is_ptr = false;
3269  if (a != nullptr)
3270  memcpy(A.begin() + old_size, a, count * sizeof(T));
3271  }
3272  else // we already have our own memory
3273  {
3274  if (clear_first)
3275  A.clear();
3276  if (a != nullptr)
3277  push_many<T>(A, a, count);
3278  }
3279  }
3280  else // old_size == 0, clear_first is irrelevant
3281  {
3282  if (!A_is_ptr)
3283  A.delete_v(); // avoid memory leak
3284 
3285  A.begin() = a;
3286  if (a != nullptr) // a is not nullptr
3287  A.end() = a + count;
3288  else
3289  A.end() = a;
3290  A.end_array = A.end();
3291  A_is_ptr = true;
3292  }
3293  return *this;
3294 }
3295 
3297 {
3298  if (oracle_is_pointer)
3300  else
3302  return *this;
3303 }
3305 predictor& predictor::add_oracle(action* a, size_t action_count)
3306 {
3307  return add_to(oracle_actions, oracle_is_pointer, a, action_count, false);
3308 }
3310 {
3311  return add_to(oracle_actions, oracle_is_pointer, a.begin(), a.size(), false);
3312 }
3313 
3315 predictor& predictor::set_oracle(action* a, size_t action_count)
3316 {
3317  return add_to(oracle_actions, oracle_is_pointer, a, action_count, true);
3318 }
3320 {
3321  return add_to(oracle_actions, oracle_is_pointer, a.begin(), a.size(), true);
3322 }
3323 
3325 {
3326  weight = w;
3327  return *this;
3328 }
3329 
3331 {
3332  if (allowed_is_pointer)
3334  else
3338  else
3340  return *this;
3341 }
3343 predictor& predictor::add_allowed(action* a, size_t action_count)
3344 {
3345  return add_to(allowed_actions, allowed_is_pointer, a, action_count, false);
3346 }
3348 {
3349  return add_to(allowed_actions, allowed_is_pointer, a.begin(), a.size(), false);
3350 }
3351 
3353 predictor& predictor::set_allowed(action* a, size_t action_count)
3354 {
3355  return add_to(allowed_actions, allowed_is_pointer, a, action_count, true);
3356 }
3358 {
3359  return add_to(allowed_actions, allowed_is_pointer, a.begin(), a.size(), true);
3360 }
3361 
3363 {
3365  return add_to(allowed_actions, allowed_is_pointer, a, false);
3366 }
3367 
3368 predictor& predictor::add_allowed(action* a, float* costs, size_t action_count)
3369 {
3370  add_to(allowed_actions_cost, allowed_cost_is_pointer, costs, action_count, false);
3371  return add_to(allowed_actions, allowed_is_pointer, a, action_count, false);
3372 }
3373 predictor& predictor::add_allowed(v_array<std::pair<action, float>>& a)
3374 {
3375  for (size_t i = 0; i < a.size(); i++)
3376  {
3377  add_to(allowed_actions, allowed_is_pointer, a[i].first, false);
3379  }
3380  return *this;
3381 }
3382 predictor& predictor::add_allowed(std::vector<std::pair<action, float>>& a)
3383 {
3384  for (size_t i = 0; i < a.size(); i++)
3385  {
3386  add_to(allowed_actions, allowed_is_pointer, a[i].first, false);
3388  }
3389  return *this;
3390 }
3391 
3393 {
3395  return add_to(allowed_actions, allowed_is_pointer, a, true);
3396 }
3397 
3398 predictor& predictor::set_allowed(action* a, float* costs, size_t action_count)
3399 {
3400  add_to(allowed_actions_cost, allowed_cost_is_pointer, costs, action_count, true);
3401  return add_to(allowed_actions, allowed_is_pointer, a, action_count, true);
3402 }
3403 predictor& predictor::set_allowed(v_array<std::pair<action, float>>& a)
3404 {
3405  erase_alloweds();
3406  return add_allowed(a);
3407 }
3408 predictor& predictor::set_allowed(std::vector<std::pair<action, float>>& a)
3409 {
3410  erase_alloweds();
3411  return add_allowed(a);
3412 }
3413 
3415 {
3418  return *this;
3419 }
3421 {
3424  return add_condition(tag, name);
3425 }
3426 
3428 {
3429  if (count == 0)
3430  return *this;
3431  for (ptag i = 0; i < count; i++)
3432  {
3433  if (i > hi)
3434  break;
3435  char name = name0 + i;
3436  condition_on_tags.push_back(hi - i);
3438  }
3439  return *this;
3440 }
3442 {
3445  return add_condition_range(hi, count, name0);
3446 }
3447 
3449 {
3450  learner_id = id;
3451  return *this;
3452 }
3453 
3455 {
3456  my_tag = tag;
3457  return *this;
3458 }
3459 
3461 {
3462  const action* orA = oracle_actions.size() == 0 ? nullptr : oracle_actions.begin();
3463  const ptag* cOn = condition_on_names.size() == 0 ? nullptr : condition_on_tags.begin();
3464  const char* cNa = nullptr;
3465  if (condition_on_names.size() > 0)
3466  {
3467  condition_on_names.push_back((char)0); // null terminate
3468  cNa = condition_on_names.begin();
3469  }
3470  const action* alA = (allowed_actions.size() == 0) ? nullptr : allowed_actions.begin();
3471  const float* alAcosts = (allowed_actions_cost.size() == 0) ? nullptr : allowed_actions_cost.begin();
3472  size_t numAlA = std::max(allowed_actions.size(), allowed_actions_cost.size());
3473  action p = is_ldf
3474  ? sch.predictLDF(ec, ec_cnt, my_tag, orA, oracle_actions.size(), cOn, cNa, learner_id, weight)
3475  : sch.predict(*ec, my_tag, orA, oracle_actions.size(), cOn, cNa, alA, numAlA, alAcosts, learner_id, weight);
3476 
3477  if (condition_on_names.size() > 0)
3478  condition_on_names.pop(); // un-null-terminate
3479  return p;
3480 }
3481 } // namespace Search
3482 
3483 // TODO: valgrind --leak-check=full ./vw --search 2 -k -c --passes 1 --search_task sequence -d test_beam --holdout_off
3484 // --search_rollin policy --search_metatask selective_branching 2>&1 | less
void add_example_conditioning(search_private &priv, example &ec, size_t condition_on_cnt, const char *condition_on_names, action_repr *condition_on_actions)
Definition: search.cc:784
void hoopla_permute(size_t *B, size_t *end)
Definition: search.cc:1948
void free_final_item(final_item *p)
Definition: search.cc:2056
int int_of_substring(substring s)
double sum_loss
Definition: global_data.h:145
predictor & set_oracle(action a)
Definition: search.cc:3314
size_t ec_cnt
Definition: search.h:336
void resize(size_t length)
Definition: v_array.h:69
v_array< char > tag
Definition: example.h:63
void copy_label(void *dst, void *src)
Definition: cb.cc:104
constexpr unsigned char conditioning_namespace
Definition: constant.h:29
void to_short_string(std::string in, size_t max_len, char *out)
Definition: search.cc:499
v_array< CS::label > read_allowed_transitions(action A, const char *filename)
Definition: search.cc:2579
#define cdbg
Definition: search.h:11
bool cmp_size_t_pair(const std::pair< size_t, size_t > &a, const std::pair< size_t, size_t > &b)
Definition: search.cc:1941
int raw_prediction
Definition: global_data.h:519
v_array< namespace_index > indices
RollMethod rollin_method
Definition: search.cc:217
uint32_t multiclass
Definition: example.h:49
void * searchstr
Definition: global_data.h:430
size_t loss_declared_cnt
Definition: search.cc:186
parameters weights
Definition: global_data.h:537
void(* copy_label)(void *, void *)
Definition: label_parser.h:18
auto_condition_settings acset
Definition: search.cc:152
Search::search_task task
void predict(E &ec, size_t i=0)
Definition: learner.h:169
final_item(v_array< scored_action > *p, std::string s, float ic)
Definition: search.cc:2053
search * sch
Definition: search.h:71
std::stringstream dat_new_feature_audit_ss
Definition: search.cc:244
bool must_run_test(vw &all, multi_ex &ec, bool is_test_ex)
Definition: search.cc:473
vw * setup(options_i &options)
Definition: main.cc:27
void deep_copy_from(const features &src)
action * acts
Definition: search.cc:74
v_array< char > condition_on_names
Definition: search.h:342
T pop()
Definition: v_array.h:58
size_t read_example_last_pass
Definition: search.cc:234
uint64_t stride_shift(const stagewise_poly &poly, uint64_t idx)
std::ostream & operator<<(std::ostream &os, const action_cache &x)
Definition: search.cc:131
std::string audit_feature_space("conditional")
void train_single_example(search &sch, bool is_test_ex, bool is_holdout_ex, multi_ex &ec_seq)
Definition: search.cc:2171
void push_back(feature_value v, feature_index i)
bool mc_label_is_test(polylabel &lab)
Definition: search.cc:2484
Search::search_task task
VW::config::options_i * options
Definition: global_data.h:428
size_t rollout_num_steps
Definition: search.cc:163
std::stringstream * pred_string
Definition: search.cc:208
search_metatask * metatask
Definition: search.cc:269
label_parser cs_label
example * dat_new_feature_ec
Definition: search.cc:243
v_array< float > allowed_actions_cost
Definition: search.h:345
void(* delete_label)(void *)
Definition: label_parser.h:16
CS::label ldf_test_label
Definition: search.cc:252
void copy_example_data(bool audit, example *dst, example *src)
Definition: example.cc:72
std::shared_ptr< audit_strings > audit_strings_ptr
Definition: feature_group.h:23
float action_cost_loss(action a, const action *act, const float *costs, size_t sz)
Definition: search.cc:2954
void search_initialize(vw *all, search &sch)
Definition: search.cc:2486
void push_at(v_array< T > &v, T item, size_t pos)
Definition: search.cc:1074
bool might_print_update(vw &all)
Definition: search.cc:461
char * end
Definition: hashstring.h:10
Definition: search.cc:33
char * begin
Definition: hashstring.h:9
void finish_multiline_example(vw &all, cbify &, multi_ex &ec_seq)
Definition: cbify.cc:373
search_metatask * all_metatasks[]
Definition: search.cc:39
constexpr int quadratic_constant
Definition: constant.h:7
bool cached_action_store_or_find(search_private &priv, ptag mytag, const ptag *condition_on, const char *condition_on_names, action_repr *condition_on_actions, size_t condition_on_cnt, int policy, size_t learner_id, action &a, bool do_store, float &a_cost)
Definition: search.cc:1438
CB::label cb
Definition: example.h:31
size_t learner_id
Definition: search.h:347
v_array< feature_index > indicies
predictor & add_condition(ptag tag, char name)
Definition: search.cc:3414
void adjust_auto_condition(search_private &priv)
Definition: search.cc:2364
v_array< v_array< std::pair< CS::wclass &, bool > > > active_known
Definition: search.cc:258
void(* finish)(search &)
Definition: search.h:237
void(* default_label)(void *)
Definition: label_parser.h:12
SearchState
Definition: search.cc:51
SearchState state
Definition: search.cc:158
virtual void replace(const std::string &key, const std::string &value)=0
label_type::label_type_t label_type
Definition: global_data.h:550
void dealloc_example(void(*delete_label)(void *), example &ec, void(*delete_prediction)(void *))
Definition: example.cc:219
void set_input_at(size_t posn, example &input_example)
Definition: search.cc:3209
bool(* test_label)(void *)
Definition: label_parser.h:22
void set_input_length(size_t input_length)
Definition: search.cc:3193
constexpr unsigned char neighbor_namespace
Definition: constant.h:25
void del_neighbor_features(search_private &priv, multi_ex &ec_seq)
Definition: search.cc:683
void del_example_namespaces_from_example(example &target, example &source)
v_array< int > final_prediction_sink
Definition: global_data.h:518
v_array< scored_action > * prefix
Definition: search.cc:2050
void(* run_setup)(search &, multi_ex &)
Definition: search.h:238
void delete_v()
void run_task(search &sch, multi_ex &ec)
Definition: search.cc:2098
the core definition of a set of features.
v_array< cb_class > costs
Definition: cb.h:27
v_hashmap< unsigned char *, scored_action > cache_hash_map
Definition: search.cc:239
namedlabels * ldict
Definition: global_data.h:153
void generate_training_example(search_private &priv, polylabel &losses, float weight, bool add_conditioning=true, float min_loss=FLT_MAX)
Definition: search.cc:1490
action_repr(action _a)
Definition: search.cc:117
void parse_neighbor_features(std::string &nf_string, search &sch)
Definition: search.cc:2627
Search::search_task task
v_array< ptag > learn_condition_on
Definition: search.cc:174
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
Definition: hash.h:67
search_private * priv
Definition: search.h:216
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
v_array< ptag > condition_on_tags
Definition: search.h:341
void delete_label(void *v)
Definition: cb.cc:98
polylabel * allowed_actions_cache
Definition: search.cc:184
uint32_t action
Definition: search.h:19
float partial_prediction
Definition: example.h:68
bool quiet
Definition: global_data.h:487
v_array< feature_value > values
int random_policy(search_private &priv, bool allow_current, bool allow_optimal, bool advance_prng=true)
Definition: search.cc:376
search_task * all_tasks[]
Definition: search.cc:35
float safediv(float a, float b)
Definition: search.cc:491
virtual void add_and_parse(const option_group_definition &group)=0
void(* run)(search &, multi_ex &)
Definition: search.h:245
action single_prediction_notLDF(search_private &priv, example &ec, int policy, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost, float &a_cost, action override_action)
Definition: search.cc:1163
predictor & set_weight(float w)
Definition: search.cc:3324
action predict()
Definition: search.cc:3460
polylabel gte_label
Definition: search.cc:256
bool ec_is_example_header(example const &ec)
polylabel learn_losses
Definition: search.cc:255
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
void do_actual_learning(search &sch, base_learner &base, multi_ex &ec_seq)
Definition: search.cc:2378
Search::search_task task
void handle_condition_options(vw &all, auto_condition_settings &acset)
Definition: search.cc:2543
bool allowed_is_pointer
Definition: search.h:344
constexpr bool PRINT_UPDATE_EVERY_EXAMPLE
Definition: search.cc:42
bool holdout_set_off
Definition: global_data.h:499
size_t check_holdout_every_n_passes
Definition: global_data.h:503
void(* run)(search &, multi_ex &)
Definition: search.h:233
void advance_from_known_actions(search_private &priv)
Definition: search.cc:2133
Search::search_task task
bool auto_condition_features
Definition: search.cc:145
void(* run_takedown)(search &, multi_ex &)
Definition: search.h:239
uint64_t dat_new_feature_idx
Definition: search.cc:242
int select_learner(search_private &priv, int policy, size_t learner_id, bool is_training, bool is_local)
Definition: search.cc:432
T *& begin()
Definition: v_array.h:42
bool progress_add
Definition: global_data.h:545
v_array< size_t > timesteps
Definition: search.cc:254
std::string rawOutputString
Definition: search.cc:250
bool training
Definition: global_data.h:488
size_t size() const
Definition: v_array.h:68
uint32_t ACTION_COSTS
Definition: search.cc:50
void(* _foreach_action)(search &, size_t, float, action, bool, float)
Definition: search.h:74
LEARNER::base_learner * base_learner
Definition: search.cc:263
void cs_cost_push_back(bool isCB, polylabel &ld, uint32_t index, float value)
Definition: search.cc:923
RollMethod rollout_method
Definition: search.cc:216
void save_predictor(vw &all, std::string reg_name, size_t current_pass)
double sum_loss_since_last_dump
Definition: global_data.h:146
label_parser mc_label
Definition: multiclass.cc:93
bool allowed_cost_is_pointer
Definition: search.h:346
void(* _post_prediction)(search &, size_t, action, float)
Definition: search.h:75
parser * p
Definition: global_data.h:377
uint32_t NO_CACHING
Definition: search.cc:49
v_array< action_repr > ptag_to_action
Definition: search.cc:178
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
std::array< features, NUM_NAMESPACES > feature_space
std::unique_ptr< T, free_fn > free_ptr
Definition: memory.h:34
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
uint32_t AUTO_CONDITION_FEATURES
Definition: search.cc:49
size_t total_examples_generated
Definition: search.cc:235
void clear_cache_hash_map(search_private &priv)
Definition: search.cc:278
predictor & add_oracle(action a)
Definition: search.cc:3304
MULTICLASS::label_t multi
Definition: example.h:29
size_t size() const
predictor & add_condition_range(ptag hi, ptag count, char name0)
Definition: search.cc:3427
example * ec
Definition: search.h:335
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
CS::label empty_cs_label
Definition: search.cc:266
void clear_memo_foreach_action(search_private &priv)
Definition: search.cc:284
float dat_new_feature_value
Definition: search.cc:247
void reset_search_structure(search_private &priv)
Definition: search.cc:690
std::string str
Definition: search.cc:2051
bool(* label_is_test)(polylabel &)
Definition: search.cc:167
size_t passes_since_new_policy
Definition: search.cc:233
Search::search_metatask metatask
Definition: search_meta.cc:18
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
size_t num_calls_to_run_previous
Definition: search.cc:200
void push_back(const T &new_ele)
Definition: v_array.h:107
COST_SENSITIVE::label cs
Definition: example.h:30
predictor & erase_alloweds()
Definition: search.cc:3330
shared_data * sd
Definition: global_data.h:375
typed_option< T > & get_typed_option(const std::string &key)
Definition: options.h:120
void end_pass(example &ec, vw &all)
Definition: learner.cc:44
T powf(T, T)
Definition: lda_core.cc:428
size_t random(std::shared_ptr< rand_state > &rs, size_t max)
Definition: search.cc:768
action search_predict(search_private &priv, example *ecs, size_t ec_cnt, ptag mytag, const action *oracle_actions, size_t oracle_actions_cnt, const ptag *condition_on, const char *condition_on_names, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost, size_t learner_id, float &a_cost, float)
Definition: search.cc:1652
void print_update(search_private &priv)
Definition: search.cc:527
float id(float in)
Definition: scorer.cc:51
std::string neighbor_feature_space("neighbor")
v_array< v_array< action_cache > * > memo_foreach_action
Definition: search.cc:274
float progress_arg
Definition: global_data.h:546
bool cmp_size_t(const size_t a, const size_t b)
Definition: search.cc:1940
size_t total_predictions_made
Definition: search.cc:236
bool vw_is_main
Definition: global_data.h:421
void clear()
Definition: v_array.h:88
void tokenize(char delim, substring s, ContainerT &ret, bool allow_empty=false)
void cs_costs_erase(bool isCB, polylabel &ld)
Definition: search.cc:907
void add_neighbor_features(search_private &priv, multi_ex &ec_seq)
Definition: search.cc:631
double old_weighted_labeled_examples
Definition: global_data.h:142
void foreach_action_from_cache(search_private &priv, size_t t, action override_a=(action) -1)
Definition: search.cc:1634
bool bfgs
Definition: global_data.h:412
size_t num_features
Definition: example.h:67
virtual bool was_supplied(const std::string &key)=0
double weighted_holdout_examples
Definition: global_data.h:156
std::string * dat_new_feature_feature_space
Definition: search.cc:246
void search_declare_loss(search_private &priv, float loss)
Definition: search.cc:730
double holdout_sum_loss
Definition: global_data.h:159
size_t absdiff(size_t a, size_t b)
Definition: search.cc:1946
uint32_t IS_LDF
Definition: search.cc:49
bool need_memo_foreach_action(search_private &priv)
Definition: search.cc:370
action predictLDF(example *ecs, size_t ec_cnt, ptag my_tag, const action *oracle_actions, size_t oracle_actions_cnt=1, const ptag *condition_on=nullptr, const char *condition_on_names=nullptr, size_t learner_id=0, float weight=0.)
Definition: search.cc:3003
predictor & reset()
Definition: search.cc:3163
std::vector< action > test_action_sequence
Definition: search.cc:179
size_t read_example_last_id
Definition: search.cc:232
std::stringstream * truth_string
Definition: search.cc:209
search & sch
Definition: search.h:348
v_array< action_repr > condition_on_actions
Definition: search.cc:253
float action_hamming_loss(action a, const action *A, size_t sz)
Definition: search.cc:2944
void clear()
predictor & add_allowed(action a)
Definition: search.cc:3342
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
uint64_t current_pass
Definition: global_data.h:396
predictor & add_to(v_array< T > &A, bool &A_is_ptr, T a, bool clear_first)
Definition: search.cc:3233
predictor & set_learner_id(size_t id)
Definition: search.cc:3448
bool search_predictNeedsExample(search_private &priv)
Definition: search.cc:1601
bool cached_item_equivalent(unsigned char *const &A, unsigned char *const &B)
Definition: search.cc:1429
action predict(example &ec, ptag my_tag, const action *oracle_actions, size_t oracle_actions_cnt=1, const ptag *condition_on=nullptr, const char *condition_on_names=nullptr, const action *allowed_actions=nullptr, size_t allowed_actions_cnt=0, const float *allowed_actions_cost=nullptr, size_t learner_id=0, float weight=0.)
Definition: search.cc:2967
v_array< action > oracle_actions
Definition: search.h:339
std::string number_to_natural(size_t big)
Definition: search.cc:512
predictor & set_allowed(action a)
Definition: search.cc:3352
std::shared_ptr< rand_state > _random_state
Definition: search.cc:142
search_task * task
Definition: search.cc:268
v_array< action > learn_allowed_actions
Definition: search.cc:177
void * task_data
Definition: search.h:217
void finish_example(vw &, example &)
Definition: parser.cc:881
bool should_print_update(vw &all, bool hit_new_pass=false)
Definition: search.cc:449
void del_features_in_top_namespace(search_private &, example &ec, size_t ns)
Definition: search.cc:613
T *& end()
Definition: v_array.h:43
virtual void insert(const std::string &key, const std::string &value)=0
size_t numpasses
Definition: global_data.h:451
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
scored_action(action _a=(action) -1, float _s=0)
Definition: search.cc:93
void cdbg_print_array(std::string str, v_array< T > &A)
Definition: search.cc:754
RollMethod
Definition: search.cc:59
float weight
option_group_definition & add(T &&op)
Definition: options.h:90
Search::search_task task
uint64_t example_number
Definition: global_data.h:137
void add_new_feature(search_private &priv, float val, uint64_t idx)
Definition: search.cc:596
void search_finish(search &sch)
Definition: search.cc:2565
std::vector< example * > multi_ex
Definition: example.h:122
void end_examples(search &sch)
Definition: search.cc:2464
v_array< audit_strings_ptr > space_names
label_parser cb_label
Definition: cb.cc:167
v_array< example > learn_ec_copy
Definition: search.cc:171
action learn_oracle_action
Definition: search.cc:181
void cs_set_cost_loss(bool isCB, polylabel &ld, size_t k, float val)
Definition: search.cc:899
int choose_policy(search_private &priv, bool advance_prng=true)
Definition: search.cc:1401
uint32_t AUTO_HAMMING_LOSS
Definition: search.cc:49
void(* finish)(search &)
Definition: search.h:249
bool array_contains(T target, const T *A, size_t n)
Definition: search.cc:773
v_array< std::pair< float, size_t > > active_uncertainty
Definition: search.cc:257
void make_new_pointer(v_array< T > &A, size_t new_size)
Definition: search.cc:3222
polylabel l
Definition: example.h:57
constexpr uint64_t a
Definition: rand48.cc:11
MULTILABEL::labels multilabels
Definition: example.h:50
void ensure_size(v_array< T > &A, size_t sz)
Definition: search.cc:1066
size_t passes_per_policy
Definition: search.cc:225
predictor & set_condition_range(ptag hi, ptag count, char name0)
Definition: search.cc:3441
Search::search_task task
Definition: search_graph.cc:63
void del_example_conditioning(search_private &priv, example &ec)
Definition: search.cc:881
bool in_use
Definition: example.h:79
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
float total_sum_feat_sq
Definition: example.h:71
features * passthrough
Definition: example.h:74
float sum_feat_sq
action_repr(action _a, features *_repr)
Definition: search.cc:107
v_array< uint32_t > label_v
Definition: multilabel.h:16
void set_end_pass(void(*f)(T &))
Definition: learner.h:286
constexpr bool PRINT_CLOCK_TIME
Definition: search.cc:44
void set_finish(void(*f)(T &))
Definition: learner.h:265
example * learn_ec_ref
Definition: search.cc:172
void verify_active_csoaa(COST_SENSITIVE::label &losses, v_array< std::pair< CS::wclass &, bool >> &known, size_t t, float multiplier)
Definition: search.cc:2108
std::stringstream * bad_string_stream
Definition: search.cc:210
void allowed_actions_to_label(search_private &priv, size_t ec_cnt, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost, const action *oracle_actions, size_t oracle_actions_cnt, polylabel &lab)
Definition: search.cc:991
predictor & set_tag(ptag tag)
Definition: search.cc:3454
std::string condition_feature_space("search_condition")
void ensure_param(float &v, float lo, float hi, float def, const char *str)
Definition: search.cc:2534
action choose_oracle_action(search_private &priv, size_t ec_cnt, const action *oracle_actions, size_t oracle_actions_cnt, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost)
Definition: search.cc:1097
uint32_t stride_shift()
uint64_t get(substring &s)
Definition: global_data.h:108
features last_action_repr
Definition: search.cc:182
std::stringstream * rawOutputStringStream
Definition: search.cc:251
v_array< scored_action > train_trajectory
Definition: search.cc:187
void cs_costs_resize(bool isCB, polylabel &ld, size_t new_size)
Definition: search.cc:915
double weighted_labeled_examples
Definition: global_data.h:141
bool oracle_is_pointer
Definition: search.h:340
double weighted_holdout_examples_since_last_dump
Definition: global_data.h:157
uint32_t hash
Definition: search.cc:75
bool audit
Definition: global_data.h:486
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
Search::search_task task
uint32_t current_policy
Definition: search.cc:227
Search::search_metatask metatask
Definition: search_meta.cc:50
action_cache(float _min_cost, action _k, bool _is_opt, float _cost)
Definition: search.cc:126
T last() const
Definition: v_array.h:57
void predict(bfgs &b, base_learner &, example &ec)
Definition: bfgs.cc:956
predictor & set_condition(ptag tag, char name)
Definition: search.cc:3420
polyprediction pred
Definition: example.h:60
void cerr_print_array(std::string str, v_array< T > &A)
Definition: search.cc:761
double holdout_sum_loss_since_last_dump
Definition: global_data.h:158
features * repr
Definition: search.cc:106
uint64_t conditional_constant
Definition: search.cc:368
void delete_v()
Definition: v_array.h:98
uint32_t ptag
Definition: search.h:20
constexpr bool PRINT_UPDATE_EVERY_PASS
Definition: search.cc:43
void learn(E &ec, size_t i=0)
Definition: learner.h:160
uint32_t EXAMPLES_DONT_CHANGE
Definition: search.cc:49
bool(* _maybe_override_prediction)(search &, size_t, action &, float &)
Definition: search.h:76
v_array< action > allowed_actions
Definition: search.h:343
float cs_get_cost_partial_prediction(bool isCB, polylabel &ld, size_t k)
Definition: search.cc:894
bool v_array_contains(v_array< T > &A, T x)
Definition: v_array.h:237
std::string final_regressor_name
Definition: global_data.h:535
v_array< wclass > costs
size_t dat_new_feature_namespace
Definition: search.cc:245
void free_key(unsigned char *mem, scored_action)
Definition: search.cc:277
v_array< int32_t > neighbor_features
Definition: search.cc:151
void get_training_timesteps(search_private &priv, v_array< size_t > &timesteps)
Definition: search.cc:1983
double weighted_examples()
Definition: global_data.h:188
float dump_interval
Definition: global_data.h:147
void decr()
Definition: v_array.h:60
size_t save_every_k_runs
Definition: search.cc:200
uint64_t mask()
#define THROW(args)
Definition: vw_exception.h:181
constexpr uint64_t c
Definition: rand48.cc:12
uint32_t fid
Definition: ezexample.h:6
bool emptylines_separate_examples
Definition: parser.h:84
predictor & set_input(example &input_example)
Definition: search.cc:3173
float f
Definition: cache.cc:40
uint32_t total_number_of_policies
Definition: search.cc:231
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
v_array< action_repr > learn_condition_on_act
Definition: search.cc:175
action single_prediction_LDF(search_private &priv, example *ecs, size_t ec_cnt, int policy, float &a_cost, action override_action)
Definition: search.cc:1310
BaseTask * metaoverride
Definition: search.cc:270
const char * to_string(prediction_type_t prediction_type)
Definition: learner.cc:12
clock_t start_clock_time
Definition: search.cc:264
v_array< char > learn_condition_on_names
Definition: search.cc:176
void set_end_examples(void(*f)(T &))
Definition: learner.h:295
label_parser lp
Definition: parser.h:102
Search::search_task task
polylabel & allowed_actions_to_ld(search_private &priv, size_t ec_cnt, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost)
Definition: search.cc:937
T * end_array
Definition: v_array.h:38
void add_example_namespaces_from_example(example &target, example &source)
void update_dump_interval(bool progress_add, float progress_arg)
Definition: global_data.h:215
bool test_only
Definition: example.h:76
std::pair< std::string, std::string > audit_strings
Definition: feature_group.h:22
size_t cs_get_costs_size(bool isCB, polylabel &ld)
Definition: search.cc:887
uint32_t cs_get_cost_index(bool isCB, polylabel &ld, size_t k)
Definition: search.cc:889
predictor & erase_oracles()
Definition: search.cc:3296