Vowpal Wabbit
Classes | Enumerations | Functions | Variables
Search Namespace Reference

Classes

struct  action_cache
 
struct  action_repr
 
struct  auto_condition_settings
 
class  BaseTask
 
struct  final_item
 
struct  prediction
 
class  predictor
 
struct  scored_action
 
struct  search
 
struct  search_metatask
 
struct  search_private
 
struct  search_task
 

Enumerations

enum  SearchState {
  INITIALIZE, INIT_TEST, INIT_TRAIN, LEARN,
  GET_TRUTH_STRING
}
 
enum  RollMethod {
  POLICY, ORACLE, MIX_PER_STATE, MIX_PER_ROLL,
  NO_ROLLOUT
}
 

Functions

std::string neighbor_feature_space ("neighbor")
 
std::string condition_feature_space ("search_condition")
 
std::ostream & operator<< (std::ostream &os, const scored_action &x)
 
std::ostream & operator<< (std::ostream &os, const action_cache &x)
 
void free_key (unsigned char *mem, scored_action)
 
void clear_cache_hash_map (search_private &priv)
 
void clear_memo_foreach_action (search_private &priv)
 
std::string audit_feature_space ("conditional")
 
bool need_memo_foreach_action (search_private &priv)
 
int random_policy (search_private &priv, bool allow_current, bool allow_optimal, bool advance_prng=true)
 
int select_learner (search_private &priv, int policy, size_t learner_id, bool is_training, bool is_local)
 
bool should_print_update (vw &all, bool hit_new_pass=false)
 
bool might_print_update (vw &all)
 
bool must_run_test (vw &all, multi_ex &ec, bool is_test_ex)
 
float safediv (float a, float b)
 
void to_short_string (std::string in, size_t max_len, char *out)
 
std::string number_to_natural (size_t big)
 
void print_update (search_private &priv)
 
void add_new_feature (search_private &priv, float val, uint64_t idx)
 
void del_features_in_top_namespace (search_private &, example &ec, size_t ns)
 
void add_neighbor_features (search_private &priv, multi_ex &ec_seq)
 
void del_neighbor_features (search_private &priv, multi_ex &ec_seq)
 
void reset_search_structure (search_private &priv)
 
void search_declare_loss (search_private &priv, float loss)
 
template<class T >
void cdbg_print_array (std::string str, v_array< T > &A)
 
template<class T >
void cerr_print_array (std::string str, v_array< T > &A)
 
size_t random (std::shared_ptr< rand_state > &rs, size_t max)
 
template<class T >
bool array_contains (T target, const T *A, size_t n)
 
void add_example_conditioning (search_private &priv, example &ec, size_t condition_on_cnt, const char *condition_on_names, action_repr *condition_on_actions)
 
void del_example_conditioning (search_private &priv, example &ec)
 
size_t cs_get_costs_size (bool isCB, polylabel &ld)
 
uint32_t cs_get_cost_index (bool isCB, polylabel &ld, size_t k)
 
float cs_get_cost_partial_prediction (bool isCB, polylabel &ld, size_t k)
 
void cs_set_cost_loss (bool isCB, polylabel &ld, size_t k, float val)
 
void cs_costs_erase (bool isCB, polylabel &ld)
 
void cs_costs_resize (bool isCB, polylabel &ld, size_t new_size)
 
void cs_cost_push_back (bool isCB, polylabel &ld, uint32_t index, float value)
 
polylabel & allowed_actions_to_ld (search_private &priv, size_t ec_cnt, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost)
 
void allowed_actions_to_label (search_private &priv, size_t ec_cnt, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost, const action *oracle_actions, size_t oracle_actions_cnt, polylabel &lab)
 
template<class T >
void ensure_size (v_array< T > &A, size_t sz)
 
template<class T >
void push_at (v_array< T > &v, T item, size_t pos)
 
action choose_oracle_action (search_private &priv, size_t ec_cnt, const action *oracle_actions, size_t oracle_actions_cnt, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost)
 
action single_prediction_notLDF (search_private &priv, example &ec, int policy, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost, float &a_cost, action override_action)
 
action single_prediction_LDF (search_private &priv, example *ecs, size_t ec_cnt, int policy, float &a_cost, action override_action)
 
int choose_policy (search_private &priv, bool advance_prng=true)
 
bool cached_item_equivalent (unsigned char *const &A, unsigned char *const &B)
 
bool cached_action_store_or_find (search_private &priv, ptag mytag, const ptag *condition_on, const char *condition_on_names, action_repr *condition_on_actions, size_t condition_on_cnt, int policy, size_t learner_id, action &a, bool do_store, float &a_cost)
 
void generate_training_example (search_private &priv, polylabel &losses, float weight, bool add_conditioning=true, float min_loss=FLT_MAX)
 
bool search_predictNeedsExample (search_private &priv)
 
void foreach_action_from_cache (search_private &priv, size_t t, action override_a=(action) -1)
 
action search_predict (search_private &priv, example *ecs, size_t ec_cnt, ptag mytag, const action *oracle_actions, size_t oracle_actions_cnt, const ptag *condition_on, const char *condition_on_names, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost, size_t learner_id, float &a_cost, float)
 
bool cmp_size_t (const size_t a, const size_t b)
 
bool cmp_size_t_pair (const std::pair< size_t, size_t > &a, const std::pair< size_t, size_t > &b)
 
size_t absdiff (size_t a, size_t b)
 
void hoopla_permute (size_t *B, size_t *end)
 
void get_training_timesteps (search_private &priv, v_array< size_t > &timesteps)
 
void free_final_item (final_item *p)
 
void run_task (search &sch, multi_ex &ec)
 
void verify_active_csoaa (COST_SENSITIVE::label &losses, v_array< std::pair< CS::wclass &, bool >> &known, size_t t, float multiplier)
 
void advance_from_known_actions (search_private &priv)
 
template<bool is_learn>
void train_single_example (search &sch, bool is_test_ex, bool is_holdout_ex, multi_ex &ec_seq)
 
void adjust_auto_condition (search_private &priv)
 
template<bool is_learn>
void do_actual_learning (search &sch, base_learner &base, multi_ex &ec_seq)
 
void end_pass (search &sch)
 
void finish_multiline_example (vw &all, search &sch, multi_ex &ec_seq)
 
void end_examples (search &sch)
 
bool mc_label_is_test (polylabel &lab)
 
void search_initialize (vw *all, search &sch)
 
void ensure_param (float &v, float lo, float hi, float def, const char *str)
 
void handle_condition_options (vw &all, auto_condition_settings &acset)
 
void search_finish (search &sch)
 
v_array< CS::label > read_allowed_transitions (action A, const char *filename)
 
void parse_neighbor_features (std::string &nf_string, search &sch)
 
base_learner * setup (options_i &options, vw &all)
 
float action_hamming_loss (action a, const action *A, size_t sz)
 
float action_cost_loss (action a, const action *act, const float *costs, size_t sz)
 
bool string_equal (std::string a, std::string b)
 
bool float_equal (float a, float b)
 
bool uint32_equal (uint32_t a, uint32_t b)
 
bool size_equal (size_t a, size_t b)
 

Variables

search_task * all_tasks []
 
search_metatask * all_metatasks []
 
constexpr bool PRINT_UPDATE_EVERY_EXAMPLE = false
 
constexpr bool PRINT_UPDATE_EVERY_PASS = false
 
constexpr bool PRINT_CLOCK_TIME = false
 
uint32_t AUTO_CONDITION_FEATURES = 1
 
uint32_t AUTO_HAMMING_LOSS = 2
 
uint32_t EXAMPLES_DONT_CHANGE = 4
 
uint32_t IS_LDF = 8
 
uint32_t NO_CACHING = 16
 
uint32_t ACTION_COSTS = 32
 
uint64_t conditional_constant = 8290743
 

Enumeration Type Documentation

◆ RollMethod

Enumerator
POLICY 
ORACLE 
MIX_PER_STATE 
MIX_PER_ROLL 
NO_ROLLOUT 

Definition at line 59 of file search.cc.

◆ SearchState

Enumerator
INITIALIZE 
INIT_TEST 
INIT_TRAIN 
LEARN 
GET_TRUTH_STRING 

Definition at line 51 of file search.cc.

Function Documentation

◆ absdiff()

size_t Search::absdiff ( size_t  a,
size_t  b 
)
inline

Definition at line 1946 of file search.cc.

References a.

Referenced by average_diff(), and hoopla_permute().

1946 { return (a < b) ? (b - a) : (a - b); }
constexpr uint64_t a
Definition: rand48.cc:11

◆ action_cost_loss()

float Search::action_cost_loss ( action  a,
const action *  act,
const float *  costs,
size_t  sz 
)

Definition at line 2954 of file search.cc.

References THROW.

Referenced by Search::search::predict().

2955 {
2956  if (act == nullptr)
2957  return costs[a - 1];
2958  for (size_t i = 0; i < sz; i++)
2959  if (act[i] == a)
2960  return costs[i];
2961  THROW("action_cost_loss got action that wasn't allowed: " << a);
2962 }
constexpr uint64_t a
Definition: rand48.cc:11
#define THROW(args)
Definition: vw_exception.h:181

◆ action_hamming_loss()

float Search::action_hamming_loss ( action  a,
const action *  A,
size_t  sz 
)

Definition at line 2944 of file search.cc.

Referenced by Search::search::predict(), and Search::search::predictLDF().

2945 {
2946  if (sz == 0)
2947  return 0.; // latent variables have zero loss
2948  for (size_t i = 0; i < sz; i++)
2949  if (a == A[i])
2950  return 0.;
2951  return 1.;
2952 }
constexpr uint64_t a
Definition: rand48.cc:11

◆ add_example_conditioning()

void Search::add_example_conditioning ( search_private &  priv,
example &  ec,
size_t  condition_on_cnt,
const char *  condition_on_names,
action_repr *  condition_on_actions 
)

Definition at line 784 of file search.cc.

References Search::action_repr::a, Search::search_private::acset, add_new_feature(), Search::search_private::all, vw::audit, cdbg, features::clear(), condition_feature_space(), conditioning_namespace, COST_SENSITIVE::label::costs, polylabel::cs, Search::search_private::dat_new_feature_audit_ss, Search::search_private::dat_new_feature_ec, Search::search_private::dat_new_feature_feature_space, Search::search_private::dat_new_feature_idx, Search::search_private::dat_new_feature_namespace, Search::search_private::dat_new_feature_value, example_predict::feature_space, Search::auto_condition_settings::feature_value, example_predict::indices, features::indicies, Search::search_private::is_ldf, example::l, Search::auto_condition_settings::max_bias_ngram_length, Search::auto_condition_settings::max_quad_ngram_length, example::num_features, v_array< T >::push_back(), quadratic_constant, Search::action_repr::repr, features::size(), parameters::stride_shift(), features::sum_feat_sq, example::total_sum_feat_sq, Search::auto_condition_settings::use_passthrough_repr, features::values, and vw::weights.

Referenced by generate_training_example(), and search_predict().

786 {
787  if (condition_on_cnt == 0)
788  return;
789 
790  uint64_t extra_offset = 0;
791  if (priv.is_ldf)
792  if (ec.l.cs.costs.size() > 0)
793  extra_offset = 3849017 * ec.l.cs.costs[0].class_index;
794 
795  size_t I = condition_on_cnt;
796  size_t N = std::max(priv.acset.max_bias_ngram_length, priv.acset.max_quad_ngram_length);
797  for (size_t i = 0; i < I; i++) // position in conditioning
798  {
799  uint64_t fid = 71933 + 8491087 * extra_offset;
800  if (priv.all->audit)
801  {
802  priv.dat_new_feature_audit_ss.str("");
803  priv.dat_new_feature_audit_ss.clear();
804  priv.dat_new_feature_feature_space = &condition_feature_space;
805  }
806 
807  for (size_t n = 0; n < N; n++) // length of ngram
808  {
809  if (i + n >= I)
810  break; // no more ngrams
811  // we're going to add features for the ngram condition_on_actions[i .. i+N]
812  uint64_t name = condition_on_names[i + n];
813  fid = fid * 328901 + 71933 * ((condition_on_actions[i + n].a + 349101) * (name + 38490137));
814 
815  priv.dat_new_feature_ec = &ec;
816  priv.dat_new_feature_idx = fid * quadratic_constant;
817  priv.dat_new_feature_namespace = conditioning_namespace;
818  priv.dat_new_feature_value = priv.acset.feature_value;
819 
820  if (priv.all->audit)
821  {
822  if (n > 0)
823  priv.dat_new_feature_audit_ss << ',';
824  if ((33 <= name) && (name <= 126))
825  priv.dat_new_feature_audit_ss << name;
826  else
827  priv.dat_new_feature_audit_ss << '#' << (int)name;
828  priv.dat_new_feature_audit_ss << '=' << condition_on_actions[i + n].a;
829  }
830 
831  // add the single bias feature
832  if (n < priv.acset.max_bias_ngram_length)
833  add_new_feature(priv, 1., (uint64_t)4398201 << priv.all->weights.stride_shift());
834  // add the quadratic features
835  if (n < priv.acset.max_quad_ngram_length)
836  GD::foreach_feature<search_private, uint64_t, add_new_feature>(*priv.all, ec, priv);
837  }
838  }
839 
840  if (priv.acset.use_passthrough_repr)
841  {
842  cdbg << "BEGIN adding passthrough features" << endl;
843  for (size_t i = 0; i < I; i++)
844  {
845  if (condition_on_actions[i].repr == nullptr)
846  continue;
847  features& fs = *(condition_on_actions[i].repr);
848  char name = condition_on_names[i];
849  for (size_t k = 0; k < fs.size(); k++)
850  if ((fs.values[k] > 1e-10) || (fs.values[k] < -1e-10))
851  {
852  uint64_t fid = 84913 + 48371803 * (extra_offset + 8392817 * name) + 840137 * (4891 + fs.indicies[k]);
853  if (priv.all->audit)
854  {
855  priv.dat_new_feature_audit_ss.str("");
856  priv.dat_new_feature_audit_ss.clear();
857  priv.dat_new_feature_audit_ss << "passthrough_repr_" << i << '_' << k;
858  }
859 
860  priv.dat_new_feature_ec = &ec;
861  priv.dat_new_feature_idx = fid;
862  priv.dat_new_feature_namespace = conditioning_namespace;
863  priv.dat_new_feature_value = fs.values[k];
864  add_new_feature(priv, 1., (uint64_t)4398201 << priv.all->weights.stride_shift());
865  }
866  }
867  cdbg << "END adding passthrough features" << endl;
868  }
869 
871  if ((con_fs.size() > 0) && (con_fs.sum_feat_sq > 0.))
872  {
874  ec.total_sum_feat_sq += con_fs.sum_feat_sq;
875  ec.num_features += con_fs.size();
876  }
877  else
878  con_fs.clear();
879 }
constexpr unsigned char conditioning_namespace
Definition: constant.h:29
#define cdbg
Definition: search.h:11
v_array< namespace_index > indices
constexpr int quadratic_constant
Definition: constant.h:7
v_array< feature_index > indicies
the core definition of a set of features.
v_array< feature_value > values
std::array< features, NUM_NAMESPACES > feature_space
size_t size() const
void push_back(const T &new_ele)
Definition: v_array.h:107
COST_SENSITIVE::label cs
Definition: example.h:30
size_t num_features
Definition: example.h:67
void clear()
void add_new_feature(search_private &priv, float val, uint64_t idx)
Definition: search.cc:596
polylabel l
Definition: example.h:57
float total_sum_feat_sq
Definition: example.h:71
float sum_feat_sq
std::string condition_feature_space("search_condition")
v_array< wclass > costs
uint32_t fid
Definition: ezexample.h:6

◆ add_neighbor_features()

void Search::add_neighbor_features ( search_private &  priv,
multi_ex &  ec_seq 
)

Definition at line 631 of file search.cc.

References add_new_feature(), Search::search_private::all, vw::audit, Search::search_private::dat_new_feature_audit_ss, Search::search_private::dat_new_feature_ec, Search::search_private::dat_new_feature_feature_space, Search::search_private::dat_new_feature_idx, Search::search_private::dat_new_feature_namespace, Search::search_private::dat_new_feature_value, example_predict::feature_space, example_predict::ft_offset, example_predict::indices, neighbor_feature_space(), Search::search_private::neighbor_features, neighbor_namespace, example::num_features, v_array< T >::push_back(), v_array< T >::size(), stride_shift(), parameters::stride_shift(), example::total_sum_feat_sq, and vw::weights.

Referenced by do_actual_learning().

632 {
633  if (priv.neighbor_features.size() == 0)
634  return;
635 
636  uint32_t stride_shift = priv.all->weights.stride_shift();
637  for (size_t n = 0; n < ec_seq.size(); n++) // iterate over every example in the sequence
638  {
639  example& me = *ec_seq[n];
640  for (size_t n_id = 0; n_id < priv.neighbor_features.size(); n_id++)
641  {
642  int32_t offset = priv.neighbor_features[n_id] >> 24;
643  size_t ns = priv.neighbor_features[n_id] & 0xFF;
644 
645  priv.dat_new_feature_ec = &me;
646  priv.dat_new_feature_value = 1.;
647  priv.dat_new_feature_idx = priv.neighbor_features[n_id] * 13748127;
648  priv.dat_new_feature_namespace = neighbor_namespace;
649  if (priv.all->audit)
650  {
651  priv.dat_new_feature_feature_space = &neighbor_feature_space;
652  priv.dat_new_feature_audit_ss.str("");
653  priv.dat_new_feature_audit_ss << '@' << ((offset > 0) ? '+' : '-') << (char)(abs(offset) + '0');
654  if (ns != ' ')
655  priv.dat_new_feature_audit_ss << (char)ns;
656  }
657 
658  // std::cerr << "n=" << n << " offset=" << offset << endl;
659  if ((offset < 0) && (n < (uint64_t)(-offset))) // add <s> feature
660  add_new_feature(priv, 1., (uint64_t)925871901 << stride_shift);
661  else if (n + offset >= ec_seq.size()) // add </s> feature
662  add_new_feature(priv, 1., (uint64_t)3824917 << stride_shift);
663  else // this is actually a neighbor
664  {
665  example& other = *ec_seq[n + offset];
666  GD::foreach_feature<search_private, add_new_feature>(priv.all, other.feature_space[ns], priv, me.ft_offset);
667  }
668  }
669 
671  size_t sz = fs.size();
672  if ((sz > 0) && (fs.sum_feat_sq > 0.))
673  {
675  me.total_sum_feat_sq += fs.sum_feat_sq;
676  me.num_features += sz;
677  }
678  else
679  fs.clear();
680  }
681 }
v_array< namespace_index > indices
uint64_t stride_shift(const stagewise_poly &poly, uint64_t idx)
constexpr unsigned char neighbor_namespace
Definition: constant.h:25
the core definition of a set of features.
std::array< features, NUM_NAMESPACES > feature_space
void push_back(const T &new_ele)
Definition: v_array.h:107
std::string neighbor_feature_space("neighbor")
size_t num_features
Definition: example.h:67
void add_new_feature(search_private &priv, float val, uint64_t idx)
Definition: search.cc:596
float total_sum_feat_sq
Definition: example.h:71

◆ add_new_feature()

void Search::add_new_feature ( search_private &  priv,
float  val,
uint64_t  idx 
)

Definition at line 596 of file search.cc.

References Search::search_private::all, vw::audit, cdbg, Search::search_private::dat_new_feature_audit_ss, Search::search_private::dat_new_feature_ec, Search::search_private::dat_new_feature_feature_space, Search::search_private::dat_new_feature_idx, Search::search_private::dat_new_feature_namespace, Search::search_private::dat_new_feature_value, example_predict::feature_space, features::indicies, v_array< T >::last(), parameters::mask(), v_array< T >::push_back(), features::push_back(), features::space_names, parameters::stride_shift(), features::values, and vw::weights.

Referenced by add_example_conditioning(), and add_neighbor_features().

597 {
598  uint64_t mask = priv.all->weights.mask();
599  size_t ss = priv.all->weights.stride_shift();
600 
601  uint64_t idx2 = ((idx & mask) >> ss) & mask;
602  features& fs = priv.dat_new_feature_ec->feature_space[priv.dat_new_feature_namespace];
603  fs.push_back(val * priv.dat_new_feature_value, ((priv.dat_new_feature_idx + idx2) << ss));
604  cdbg << "adding: " << fs.indicies.last() << ':' << fs.values.last() << endl;
605  if (priv.all->audit)
606  {
607  std::stringstream temp;
608  temp << "fid=" << ((idx & mask) >> ss) << "_" << priv.dat_new_feature_audit_ss.str();
609  fs.space_names.push_back(audit_strings_ptr(new audit_strings(*priv.dat_new_feature_feature_space, temp.str())));
610  }
611 }
#define cdbg
Definition: search.h:11
void push_back(feature_value v, feature_index i)
std::shared_ptr< audit_strings > audit_strings_ptr
Definition: feature_group.h:23
v_array< feature_index > indicies
the core definition of a set of features.
v_array< feature_value > values
void push_back(const T &new_ele)
Definition: v_array.h:107
v_array< audit_strings_ptr > space_names
T last() const
Definition: v_array.h:57
std::pair< std::string, std::string > audit_strings
Definition: feature_group.h:22

◆ adjust_auto_condition()

void Search::adjust_auto_condition ( search_private &  priv)
inline

Definition at line 2364 of file search.cc.

References Search::search_private::acset, Search::search_private::auto_condition_features, Search::auto_condition_settings::feature_value, and Search::search_private::history_length.

Referenced by do_actual_learning().

2365 {
2366  if (priv.auto_condition_features)
2367  {
2368  // turn off auto-condition if it's irrelevant
2369  if ((priv.history_length == 0) || (priv.acset.feature_value == 0.f))
2370  {
2371  std::cerr << "warning: turning off AUTO_CONDITION_FEATURES because settings make it useless" << endl;
2372  priv.auto_condition_features = false;
2373  }
2374  }
2375 }

◆ advance_from_known_actions()

void Search::advance_from_known_actions ( search_private &  priv)

Definition at line 2133 of file search.cc.

References Search::search_private::active_csoaa, Search::search_private::active_csoaa_verify, Search::search_private::active_known, cdbg, COST_SENSITIVE::label::costs, polylabel::cs, Search::search_private::done_with_all_actions, Search::search_private::learn_a_idx, Search::search_private::learn_losses, and Search::search_private::learn_t.

Referenced by train_single_example().

2134 {
2135  size_t t = priv.learn_t;
2136  if (!priv.active_csoaa)
2137  return;
2138  if (priv.active_csoaa_verify > 0.)
2139  return;
2140  if (t >= priv.active_known.size())
2141  return;
2142  cdbg << "advance_from_known_actions t=" << t << " active_known.size()=" << priv.active_known.size()
2143  << " learn_a_idx=" << priv.learn_a_idx << endl;
2144  // cdbg_print_array(" active_known[t]", priv.active_known[t]);
2145  if (priv.learn_a_idx >= priv.active_known[t].size())
2146  {
2147  cdbg << "advance_from_known_actions setting done_with_all_actions=true (active_known[t].size()="
2148  << priv.active_known[t].size() << ")" << endl;
2149  priv.done_with_all_actions = true;
2150  return;
2151  }
2152  // if (priv.active_known[t][priv.learn_a_idx] >= FLT_MAX) return;
2153  if (priv.active_known[t][priv.learn_a_idx].second)
2154  return;
2155  // return;
2156  // wow, we actually found something we were confident about!
2157  /*
2158  cs_cost_push_back(priv.cb_learner,
2159  priv.learn_losses,
2160  priv.is_ldf ? (uint32_t)(priv.learn_a_idx - 1) : (uint32_t)priv.learn_a_idx,
2161  priv.active_known[t][priv.learn_a_idx],
2162  true);
2163  */
2164  priv.learn_losses.cs.costs.push_back(priv.active_known[t][priv.learn_a_idx].first);
2165  cdbg << " --> adding " << priv.learn_a_idx << ":" << priv.active_known[t][priv.learn_a_idx].first.x << endl;
2166  priv.learn_a_idx++;
2168 }
#define cdbg
Definition: search.h:11
void advance_from_known_actions(search_private &priv)
Definition: search.cc:2133

◆ allowed_actions_to_label()

void Search::allowed_actions_to_label ( search_private &  priv,
size_t  ec_cnt,
const action *  allowed_actions,
size_t  allowed_actions_cnt,
const float *  allowed_actions_cost,
const action *  oracle_actions,
size_t  oracle_actions_cnt,
polylabel &  lab 
)

Definition at line 991 of file search.cc.

References Search::search_private::A, Search::search_private::cb_learner, cs_cost_push_back(), cs_costs_erase(), cs_get_costs_size(), cs_set_cost_loss(), f, Search::search_private::is_ldf, and Search::search_private::use_action_costs.

Referenced by search_predict().

994 {
995  bool isCB = priv.cb_learner;
996  if (priv.is_ldf) // LDF version easier
997  {
998  cs_costs_erase(isCB, lab);
999  for (action k = 0; k < ec_cnt; k++)
1000  cs_cost_push_back(isCB, lab, k, array_contains<action>(k, oracle_actions, oracle_actions_cnt) ? 0.f : 1.f);
1001  // std::cerr << "lab = ["; for (size_t i=0; i<lab.cs.costs.size(); i++) cdbg << ' ' << lab.cs.costs[i].class_index
1002  // << ':'
1003  // << lab.cs.costs[i].x; cdbg << " ]" << endl;
1004  }
1005  else if (priv.use_action_costs)
1006  {
1007  // TODO: Weight
1008  if (allowed_actions == nullptr)
1009  {
1010  if (cs_get_costs_size(isCB, lab) != priv.A)
1011  {
1012  cs_costs_erase(isCB, lab);
1013  for (action k = 0; k < priv.A; k++) cs_cost_push_back(isCB, lab, k + 1, 0.);
1014  }
1015  for (action k = 0; k < priv.A; k++) cs_set_cost_loss(isCB, lab, k, allowed_actions_cost[k]);
1016  }
1017  else // manually specified actions
1018  {
1019  cs_costs_erase(isCB, lab);
1020  for (action k = 0; k < allowed_actions_cnt; k++)
1021  cs_cost_push_back(isCB, lab, allowed_actions[k], allowed_actions_cost[k]);
1022  }
1023  }
1024  else // non-LDF, no action costs
1025  {
1026  if ((allowed_actions == nullptr) || (allowed_actions_cnt == 0)) // any action is allowed
1027  {
1028  bool set_to_one = false;
1029  if (cs_get_costs_size(isCB, lab) != priv.A)
1030  {
1031  cs_costs_erase(isCB, lab);
1032  for (action k = 0; k < priv.A; k++) cs_cost_push_back(isCB, lab, k + 1, 1.);
1033  set_to_one = true;
1034  }
1035  // std::cerr << "lab = ["; for (size_t i=0; i<lab.cs.costs.size(); i++) cdbg << ' ' << lab.cs.costs[i].class_index
1036  // <<
1037  // ':' << lab.cs.costs[i].x; cdbg << " ]" << endl;
1038  if (oracle_actions_cnt <= 1) // common case to speed up
1039  {
1040  if (!set_to_one)
1041  for (action k = 0; k < priv.A; k++) cs_set_cost_loss(isCB, lab, k, 1.);
1042  if (oracle_actions_cnt == 1)
1043  cs_set_cost_loss(isCB, lab, oracle_actions[0] - 1, 0.);
1044  }
1045  else
1046  {
1047  for (action k = 0; k < priv.A; k++)
1048  cs_set_cost_loss(isCB, lab, k, array_contains<action>(k + 1, oracle_actions, oracle_actions_cnt) ? 0.f : 1.f);
1049  }
1050  }
1051  else // only some actions are allowed
1052  {
1053  cs_costs_erase(isCB, lab);
1054  float w = 1.; // array_contains<action>(3, oracle_actions, oracle_actions_cnt) ? 5.f : 1.f;
1055  for (size_t i = 0; i < allowed_actions_cnt; i++)
1056  {
1057  action k = allowed_actions[i];
1059  isCB, lab, k, (array_contains<action>(k, oracle_actions, oracle_actions_cnt)) ? 0.f : w); // 1.f );
1060  }
1061  }
1062  }
1063 }
uint32_t action
Definition: search.h:19
void cs_cost_push_back(bool isCB, polylabel &ld, uint32_t index, float value)
Definition: search.cc:923
void cs_costs_erase(bool isCB, polylabel &ld)
Definition: search.cc:907
void cs_set_cost_loss(bool isCB, polylabel &ld, size_t k, float val)
Definition: search.cc:899
float f
Definition: cache.cc:40
size_t cs_get_costs_size(bool isCB, polylabel &ld)
Definition: search.cc:887

◆ allowed_actions_to_ld()

polylabel& Search::allowed_actions_to_ld ( search_private &  priv,
size_t  ec_cnt,
const action *  allowed_actions,
size_t  allowed_actions_cnt,
const float *  allowed_actions_cost 
)

Definition at line 937 of file search.cc.

References Search::search_private::A, Search::search_private::allowed_actions_cache, Search::search_private::cb_learner, cs_cost_push_back(), cs_costs_erase(), cs_costs_resize(), cs_get_costs_size(), cs_set_cost_loss(), Search::search_private::is_ldf, and Search::search_private::use_action_costs.

Referenced by choose_oracle_action(), and single_prediction_notLDF().

939 {
940  bool isCB = priv.cb_learner;
941  polylabel& ld = *priv.allowed_actions_cache;
942  uint32_t num_costs = (uint32_t)cs_get_costs_size(isCB, ld);
943 
944  if (priv.is_ldf) // LDF version easier
945  {
946  if (num_costs > ec_cnt)
947  cs_costs_resize(isCB, ld, ec_cnt);
948  else if (num_costs < ec_cnt)
949  for (action k = num_costs; k < ec_cnt; k++) cs_cost_push_back(isCB, ld, k, FLT_MAX);
950  }
951  else if (priv.use_action_costs)
952  {
953  // TODO: Weight
954  if (allowed_actions == nullptr)
955  {
956  if (cs_get_costs_size(isCB, ld) != priv.A)
957  {
958  cs_costs_erase(isCB, ld);
959  for (action k = 0; k < priv.A; k++) cs_cost_push_back(isCB, ld, k + 1, 0.);
960  }
961  for (action k = 0; k < priv.A; k++) cs_set_cost_loss(isCB, ld, k, allowed_actions_cost[k]);
962  }
963  else // manually specified actions
964  {
965  cs_costs_erase(isCB, ld);
966  for (action k = 0; k < allowed_actions_cnt; k++)
967  cs_cost_push_back(isCB, ld, allowed_actions[k], allowed_actions_cost[k]);
968  }
969  }
970  else // non-LDF version, no action costs
971  {
972  if ((allowed_actions == nullptr) || (allowed_actions_cnt == 0)) // any action is allowed
973  {
974  if (num_costs != priv.A) // if there are already A-many actions, they must be the right ones, unless the user did
975  // something stupid like putting duplicate allowed_actions...
976  {
977  cs_costs_erase(isCB, ld);
978  for (action k = 0; k < priv.A; k++) cs_cost_push_back(isCB, ld, k + 1, FLT_MAX); //+1 because MC is 1-based
979  }
980  }
981  else // we need to peek at allowed_actions
982  {
983  cs_costs_erase(isCB, ld);
984  for (size_t i = 0; i < allowed_actions_cnt; i++) cs_cost_push_back(isCB, ld, allowed_actions[i], FLT_MAX);
985  }
986  }
987 
988  return ld;
989 }
uint32_t action
Definition: search.h:19
void cs_cost_push_back(bool isCB, polylabel &ld, uint32_t index, float value)
Definition: search.cc:923
void cs_costs_erase(bool isCB, polylabel &ld)
Definition: search.cc:907
void cs_set_cost_loss(bool isCB, polylabel &ld, size_t k, float val)
Definition: search.cc:899
void cs_costs_resize(bool isCB, polylabel &ld, size_t new_size)
Definition: search.cc:915
size_t cs_get_costs_size(bool isCB, polylabel &ld)
Definition: search.cc:887

◆ array_contains()

template<class T >
bool Search::array_contains ( T  target,
const T *  A,
size_t  n 
)

Definition at line 773 of file search.cc.

Referenced by choose_oracle_action().

774 {
775  if (A == nullptr)
776  return false;
777  for (size_t i = 0; i < n; i++)
778  if (A[i] == target)
779  return true;
780  return false;
781 }

◆ audit_feature_space()

std::string Search::audit_feature_space ( "conditional"  )

Referenced by Search::search::~search().

◆ cached_action_store_or_find()

bool Search::cached_action_store_or_find ( search_private &  priv,
ptag  mytag,
const ptag *  condition_on,
const char *  condition_on_names,
action_repr *  condition_on_actions,
size_t  condition_on_cnt,
int  policy,
size_t  learner_id,
action &  a,
bool  do_store,
float &  a_cost 
)

Definition at line 1438 of file search.cc.

References Search::scored_action::a, Search::action_repr::a, Search::search_private::cache_hash_map, Search::search_private::no_caching, Search::scored_action::s, and uniform_hash().

Referenced by search_predict().

1441 {
1442  if (priv.no_caching)
1443  return do_store;
1444  if (mytag == 0)
1445  return do_store; // don't attempt to cache when tag is zero
1446 
1447  size_t sz = sizeof(size_t) + sizeof(ptag) + sizeof(int) + sizeof(size_t) + sizeof(size_t) +
1448  condition_on_cnt * (sizeof(ptag) + sizeof(action) + sizeof(char));
1449  if (sz % 4 != 0)
1450  sz += 4 - (sz % 4); // make sure sz aligns to 4 so that uniform_hash does the right thing
1451 
1452  unsigned char* item = calloc_or_throw<unsigned char>(sz);
1453  unsigned char* here = item;
1454  *here = (unsigned char)sz;
1455  here += sizeof(size_t);
1456  *here = mytag;
1457  here += sizeof(ptag);
1458  *here = policy;
1459  here += sizeof(int);
1460  *here = (unsigned char)learner_id;
1461  here += sizeof(size_t);
1462  *here = (unsigned char)condition_on_cnt;
1463  here += sizeof(size_t);
1464  for (size_t i = 0; i < condition_on_cnt; i++)
1465  {
1466  *here = condition_on[i];
1467  here += sizeof(ptag);
1468  *here = condition_on_actions[i].a;
1469  here += sizeof(action);
1470  *here = condition_on_names[i];
1471  here += sizeof(char); // SPEEDUP: should we align this at 4?
1472  }
1473  uint64_t hash = uniform_hash(item, sz, 3419);
1474 
1475  if (do_store)
1476  {
1477  priv.cache_hash_map.put(item, hash, scored_action(a, a_cost));
1478  return true;
1479  }
1480  else // its a find
1481  {
1482  scored_action sa = priv.cache_hash_map.get(item, hash);
1483  a = sa.a;
1484  a_cost = sa.s;
1485  free(item);
1486  return a != (action)-1;
1487  }
1488 }
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
Definition: hash.h:67
uint32_t action
Definition: search.h:19
constexpr uint64_t a
Definition: rand48.cc:11
uint32_t ptag
Definition: search.h:20

◆ cached_item_equivalent()

bool Search::cached_item_equivalent ( unsigned char *const &  A,
unsigned char *const &  B 
)

Definition at line 1429 of file search.cc.

Referenced by search_initialize().

1430 {
1431  size_t sz_A = *A;
1432  size_t sz_B = *B;
1433  if (sz_A != sz_B)
1434  return false;
1435  return memcmp(A, B, sz_A) == 0;
1436 }

◆ cdbg_print_array()

template<class T >
void Search::cdbg_print_array ( std::string  str,
v_array< T > &  A 
)

Definition at line 754 of file search.cc.

References cdbg, and v_array< T >::size().

Referenced by search_predict().

755 {
756  cdbg << str << " = [";
757  for (size_t i = 0; i < A.size(); i++) cdbg << " " << A[i];
758  cdbg << " ]" << endl;
759 }
#define cdbg
Definition: search.h:11
size_t size() const
Definition: v_array.h:68

◆ cerr_print_array()

template<class T >
void Search::cerr_print_array ( std::string  str,
v_array< T > &  A 
)

Definition at line 761 of file search.cc.

References v_array< T >::size().

762 {
763  std::cerr << str << " = [";
764  for (size_t i = 0; i < A.size(); i++) std::cerr << " " << A[i];
765  std::cerr << " ]" << endl;
766 }
size_t size() const
Definition: v_array.h:68

◆ choose_oracle_action()

action Search::choose_oracle_action ( search_private &  priv,
size_t  ec_cnt,
const action *  oracle_actions,
size_t  oracle_actions_cnt,
const action *  allowed_actions,
size_t  allowed_actions_cnt,
const float *  allowed_actions_cost 
)

Definition at line 1097 of file search.cc.

References Search::search_private::_random_state, a, Search::search_private::A, allowed_actions_to_ld(), array_contains(), Search::search_private::cb_learner, cdbg, cs_get_cost_index(), cs_get_costs_size(), INIT_TRAIN, Search::search_private::is_ldf, Search::search_private::memo_foreach_action, Search::search_private::meta_t, need_memo_foreach_action(), Search::search_private::perturb_oracle, v_array< T >::push_back(), random(), Search::search_private::state, Search::search_private::t, and Search::search_private::use_action_costs.

Referenced by search_predict().

1100 {
1101  action a = (action)-1;
1102  if (priv.use_action_costs)
1103  {
1104  size_t K = (allowed_actions == nullptr) ? priv.A : allowed_actions_cnt;
1105  cdbg << "costs = [";
1106  for (size_t k = 0; k < K; k++) cdbg << ' ' << allowed_actions_cost[k];
1107  cdbg << " ]" << endl;
1108  float min_cost = FLT_MAX;
1109  for (size_t k = 0; k < K; k++) min_cost = std::min(min_cost, allowed_actions_cost[k]);
1110  cdbg << "min_cost = " << min_cost;
1111  if (min_cost < FLT_MAX)
1112  {
1113  size_t count = 0;
1114  for (size_t k = 0; k < K; k++)
1115  if (allowed_actions_cost[k] <= min_cost)
1116  {
1117  cdbg << ", hit @ " << k;
1118  count++;
1119  if ((count == 1) || (priv._random_state->get_and_update_random() < 1. / (float)count))
1120  {
1121  a = (allowed_actions == nullptr) ? (uint32_t)(k + 1) : allowed_actions[k];
1122  cdbg << "***";
1123  }
1124  }
1125  }
1126  cdbg << endl;
1127  }
1128 
1129  if (a == (action)-1)
1130  {
1131  if ((priv.perturb_oracle > 0.) && (priv.state == INIT_TRAIN) &&
1132  (priv._random_state->get_and_update_random() < priv.perturb_oracle))
1133  oracle_actions_cnt = 0;
1134  a = (oracle_actions_cnt > 0)
1135  ? oracle_actions[random(priv._random_state, oracle_actions_cnt)]
1136  : (allowed_actions_cnt > 0) ? allowed_actions[random(priv._random_state, allowed_actions_cnt)]
1137  : priv.is_ldf ? (action)random(priv._random_state, ec_cnt)
1138  : (action)(1 + random(priv._random_state, priv.A));
1139  }
1140  cdbg << "choose_oracle_action from oracle_actions = [";
1141  for (size_t i = 0; i < oracle_actions_cnt; i++) cdbg << " " << oracle_actions[i];
1142  cdbg << " ], ret=" << a << endl;
1143  if (need_memo_foreach_action(priv) && (priv.state == INIT_TRAIN))
1144  {
1145  v_array<action_cache>* this_cache = new v_array<action_cache>();
1146  *this_cache = v_init<action_cache>();
1147  // TODO we don't really need to construct this polylabel
1148  polylabel l = allowed_actions_to_ld(priv, 1, allowed_actions, allowed_actions_cnt, allowed_actions_cost);
1149  size_t K = cs_get_costs_size(priv.cb_learner, l);
1150  for (size_t k = 0; k < K; k++)
1151  {
1152  action cl = cs_get_cost_index(priv.cb_learner, l, k);
1153  float cost = array_contains(cl, oracle_actions, oracle_actions_cnt) ? 0.f : 1.f;
1154  this_cache->push_back(action_cache(0., cl, cl == a, cost));
1155  }
1156  assert(priv.memo_foreach_action.size() == priv.meta_t + priv.t - 1);
1157  priv.memo_foreach_action.push_back(this_cache);
1158  cdbg << "memo_foreach_action[" << priv.meta_t + priv.t - 1 << "] = " << this_cache << " from oracle" << endl;
1159  }
1160  return a;
1161 }
#define cdbg
Definition: search.h:11
uint32_t action
Definition: search.h:19
void push_back(const T &new_ele)
Definition: v_array.h:107
size_t random(std::shared_ptr< rand_state > &rs, size_t max)
Definition: search.cc:768
bool need_memo_foreach_action(search_private &priv)
Definition: search.cc:370
bool array_contains(T target, const T *A, size_t n)
Definition: search.cc:773
constexpr uint64_t a
Definition: rand48.cc:11
polylabel & allowed_actions_to_ld(search_private &priv, size_t ec_cnt, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost)
Definition: search.cc:937
size_t cs_get_costs_size(bool isCB, polylabel &ld)
Definition: search.cc:887
uint32_t cs_get_cost_index(bool isCB, polylabel &ld, size_t k)
Definition: search.cc:889

◆ choose_policy()

int Search::choose_policy ( search_private &  priv,
bool  advance_prng = true 
)

Definition at line 1401 of file search.cc.

References Search::search_private::allow_current_policy, INIT_TEST, INIT_TRAIN, LEARN, MIX_PER_ROLL, Search::search_private::mix_per_roll_policy, MIX_PER_STATE, NO_ROLLOUT, ORACLE, POLICY, random_policy(), Search::search_private::rollin_method, Search::search_private::rollout_method, Search::search_private::state, and THROW.

Referenced by search_predict(), and search_predictNeedsExample().

1402 {
1403  RollMethod method = (priv.state == INIT_TEST) ? POLICY
1404  : (priv.state == LEARN)
1405  ? priv.rollout_method
1406  : (priv.state == INIT_TRAIN) ? priv.rollin_method : NO_ROLLOUT; // this should never happen
1407  switch (method)
1408  {
1409  case POLICY:
1410  return random_policy(priv, priv.allow_current_policy || (priv.state == INIT_TEST), false, advance_prng);
1411 
1412  case ORACLE:
1413  return -1;
1414 
1415  case MIX_PER_STATE:
1416  return random_policy(priv, priv.allow_current_policy, true, advance_prng);
1417 
1418  case MIX_PER_ROLL:
1419  if (priv.mix_per_roll_policy == -2) // then we have to choose one!
1420  priv.mix_per_roll_policy = random_policy(priv, priv.allow_current_policy, true, advance_prng);
1421  return priv.mix_per_roll_policy;
1422 
1423  case NO_ROLLOUT:
1424  default:
1425  THROW("internal error (bug): trying to rollin or rollout with NO_ROLLOUT");
1426  }
1427 }
int random_policy(search_private &priv, bool allow_current, bool allow_optimal, bool advance_prng=true)
Definition: search.cc:376
RollMethod
Definition: search.cc:59
#define THROW(args)
Definition: vw_exception.h:181

◆ clear_cache_hash_map()

void Search::clear_cache_hash_map ( search_private &  priv)

Definition at line 278 of file search.cc.

References Search::search_private::cache_hash_map, and free_key().

Referenced by train_single_example(), and Search::search::~search().

279 {
280  priv.cache_hash_map.iter(free_key);
281  priv.cache_hash_map.clear();
282 }
void free_key(unsigned char *mem, scored_action)
Definition: search.cc:277

◆ clear_memo_foreach_action()

void Search::clear_memo_foreach_action ( search_private &  priv)

Definition at line 284 of file search.cc.

References Search::search_private::memo_foreach_action.

Referenced by train_single_example(), and Search::search::~search().

285 {
286  for (size_t i = 0; i < priv.memo_foreach_action.size(); i++)
287  if (priv.memo_foreach_action[i])
288  {
289  priv.memo_foreach_action[i]->delete_v();
290  delete priv.memo_foreach_action[i];
291  }
292  priv.memo_foreach_action.clear();
293 }

◆ cmp_size_t()

bool Search::cmp_size_t ( const size_t  a,
const size_t  b 
)
inline

Definition at line 1940 of file search.cc.

Referenced by get_training_timesteps(), and hoopla_permute().

1940 { return a < b; }
constexpr uint64_t a
Definition: rand48.cc:11

◆ cmp_size_t_pair()

bool Search::cmp_size_t_pair ( const std::pair< size_t, size_t > &  a,
const std::pair< size_t, size_t > &  b 
)
inline

Definition at line 1941 of file search.cc.

1942 {
1943  return ((a.first == b.first) && (a.second < b.second)) || (a.first < b.first);
1944 }
constexpr uint64_t a
Definition: rand48.cc:11

◆ condition_feature_space()

std::string Search::condition_feature_space ( "search_condition"  )

◆ cs_cost_push_back()

void Search::cs_cost_push_back ( bool  isCB,
polylabel &  ld,
uint32_t  index,
float  value 
)
inline

Definition at line 923 of file search.cc.

References polylabel::cb, CB::label::costs, COST_SENSITIVE::label::costs, and polylabel::cs.

Referenced by allowed_actions_to_label(), allowed_actions_to_ld(), and train_single_example().

924 {
925  if (isCB)
926  {
927  CB::cb_class cost = {value, index, 0., 0.};
928  ld.cb.costs.push_back(cost);
929  }
930  else
931  {
932  CS::wclass cost = {value, index, 0., 0.};
933  ld.cs.costs.push_back(cost);
934  }
935 }
CB::label cb
Definition: example.h:31
v_array< cb_class > costs
Definition: cb.h:27
COST_SENSITIVE::label cs
Definition: example.h:30
v_array< wclass > costs

◆ cs_costs_erase()

void Search::cs_costs_erase ( bool  isCB,
polylabel &  ld 
)
inline

Definition at line 907 of file search.cc.

References polylabel::cb, CB::label::costs, COST_SENSITIVE::label::costs, and polylabel::cs.

Referenced by allowed_actions_to_label(), and allowed_actions_to_ld().

908 {
909  if (isCB)
910  ld.cb.costs.clear();
911  else
912  ld.cs.costs.clear();
913 }
CB::label cb
Definition: example.h:31
v_array< cb_class > costs
Definition: cb.h:27
COST_SENSITIVE::label cs
Definition: example.h:30
v_array< wclass > costs

◆ cs_costs_resize()

void Search::cs_costs_resize ( bool  isCB,
polylabel &  ld,
size_t  new_size 
)
inline

Definition at line 915 of file search.cc.

References polylabel::cb, CB::label::costs, COST_SENSITIVE::label::costs, and polylabel::cs.

Referenced by allowed_actions_to_ld().

916 {
917  if (isCB)
918  ld.cb.costs.resize(new_size);
919  else
920  ld.cs.costs.resize(new_size);
921 }
CB::label cb
Definition: example.h:31
v_array< cb_class > costs
Definition: cb.h:27
COST_SENSITIVE::label cs
Definition: example.h:30
v_array< wclass > costs

◆ cs_get_cost_index()

uint32_t Search::cs_get_cost_index ( bool  isCB,
polylabel &  ld,
size_t  k 
)
inline

Definition at line 889 of file search.cc.

References polylabel::cb, CB::label::costs, COST_SENSITIVE::label::costs, and polylabel::cs.

Referenced by choose_oracle_action(), and single_prediction_notLDF().

890 {
891  return isCB ? ld.cb.costs[k].action : ld.cs.costs[k].class_index;
892 }
CB::label cb
Definition: example.h:31
v_array< cb_class > costs
Definition: cb.h:27
COST_SENSITIVE::label cs
Definition: example.h:30
v_array< wclass > costs

◆ cs_get_cost_partial_prediction()

float Search::cs_get_cost_partial_prediction ( bool  isCB,
polylabel &  ld,
size_t  k 
)
inline

Definition at line 894 of file search.cc.

References polylabel::cb, CB::label::costs, COST_SENSITIVE::label::costs, and polylabel::cs.

Referenced by single_prediction_notLDF().

895 {
896  return isCB ? ld.cb.costs[k].partial_prediction : ld.cs.costs[k].partial_prediction;
897 }
CB::label cb
Definition: example.h:31
v_array< cb_class > costs
Definition: cb.h:27
COST_SENSITIVE::label cs
Definition: example.h:30
v_array< wclass > costs

◆ cs_get_costs_size()

size_t Search::cs_get_costs_size ( bool  isCB,
polylabel &  ld 
)
inline

Definition at line 887 of file search.cc.

References polylabel::cb, CB::label::costs, COST_SENSITIVE::label::costs, and polylabel::cs.

Referenced by allowed_actions_to_label(), allowed_actions_to_ld(), choose_oracle_action(), generate_training_example(), and single_prediction_notLDF().

887 { return isCB ? ld.cb.costs.size() : ld.cs.costs.size(); }
CB::label cb
Definition: example.h:31
v_array< cb_class > costs
Definition: cb.h:27
COST_SENSITIVE::label cs
Definition: example.h:30
v_array< wclass > costs

◆ cs_set_cost_loss()

void Search::cs_set_cost_loss ( bool  isCB,
polylabel &  ld,
size_t  k,
float  val 
)
inline

Definition at line 899 of file search.cc.

References polylabel::cb, CB::label::costs, COST_SENSITIVE::label::costs, and polylabel::cs.

Referenced by allowed_actions_to_label(), and allowed_actions_to_ld().

900 {
901  if (isCB)
902  ld.cb.costs[k].cost = val;
903  else
904  ld.cs.costs[k].x = val;
905 }
CB::label cb
Definition: example.h:31
v_array< cb_class > costs
Definition: cb.h:27
COST_SENSITIVE::label cs
Definition: example.h:30
v_array< wclass > costs

◆ del_example_conditioning()

void Search::del_example_conditioning ( search_private &  priv,
example &  ec 
)

Definition at line 881 of file search.cc.

References conditioning_namespace, del_features_in_top_namespace(), example_predict::indices, v_array< T >::last(), and v_array< T >::size().

Referenced by generate_training_example(), and search_predict().

882 {
883  if ((ec.indices.size() > 0) && (ec.indices.last() == conditioning_namespace))
884  del_features_in_top_namespace(priv, ec, conditioning_namespace);
885 }
constexpr unsigned char conditioning_namespace
Definition: constant.h:29
v_array< namespace_index > indices
size_t size() const
Definition: v_array.h:68
void del_features_in_top_namespace(search_private &, example &ec, size_t ns)
Definition: search.cc:613
T last() const
Definition: v_array.h:57

◆ del_features_in_top_namespace()

void Search::del_features_in_top_namespace ( search_private &  ,
example &  ec,
size_t  ns 
)

Definition at line 613 of file search.cc.

References features::clear(), v_array< T >::decr(), example_predict::feature_space, example_predict::indices, v_array< T >::last(), example::num_features, v_array< T >::size(), features::size(), features::sum_feat_sq, and example::total_sum_feat_sq.

Referenced by del_example_conditioning(), and del_neighbor_features().

614 {
615  if ((ec.indices.size() == 0) || (ec.indices.last() != ns))
616  {
617  return;
618  // if (ec.indices.size() == 0)
619  //{ THROW("internal error (bug): expecting top namespace to be '" << ns << "' but it was empty"); }
620  // else
621  //{ THROW("internal error (bug): expecting top namespace to be '" << ns << "' but it was " <<
622  //(size_t)ec.indices.last()); }
623  }
624  features& fs = ec.feature_space[ns];
625  ec.indices.decr();
626  ec.num_features -= fs.size();
627  ec.total_sum_feat_sq -= fs.sum_feat_sq;
628  fs.clear();
629 }
v_array< namespace_index > indices
the core definition of a set of features.
size_t size() const
Definition: v_array.h:68
std::array< features, NUM_NAMESPACES > feature_space
size_t size() const
size_t num_features
Definition: example.h:67
void clear()
float total_sum_feat_sq
Definition: example.h:71
float sum_feat_sq
T last() const
Definition: v_array.h:57
void decr()
Definition: v_array.h:60

◆ del_neighbor_features()

void Search::del_neighbor_features ( search_private &  priv,
multi_ex &  ec_seq 
)

Definition at line 683 of file search.cc.

References del_features_in_top_namespace(), Search::search_private::neighbor_features, neighbor_namespace, and v_array< T >::size().

Referenced by do_actual_learning().

684 {
685  if (priv.neighbor_features.size() == 0)
686  return;
687  for (size_t n = 0; n < ec_seq.size(); n++) del_features_in_top_namespace(priv, *ec_seq[n], neighbor_namespace);
688 }
constexpr unsigned char neighbor_namespace
Definition: constant.h:25
void del_features_in_top_namespace(search_private &, example &ec, size_t ns)
Definition: search.cc:613

◆ do_actual_learning()

template<bool is_learn>
void Search::do_actual_learning ( search &  sch,
base_learner &  base,
multi_ex &  ec_seq 
)

Definition at line 2378 of file search.cc.

References add_neighbor_features(), adjust_auto_condition(), Search::search_private::all, Search::search_private::base_learner, cdbg, Search::search_private::current_policy, del_neighbor_features(), GET_TRUTH_STRING, Search::search_private::hit_new_pass, Search::search_private::label_is_test, might_print_update(), Search::search_private::offset, Search::search::priv, Search::search_private::read_example_last_id, Search::search_private::read_example_last_pass, reset_search_structure(), Search::search_task::run_setup, Search::search_task::run_takedown, run_task(), Search::search_private::should_produce_string, Search::search_private::state, Search::search_private::task, and Search::search_private::truth_string.

2379 {
2380  if (ec_seq.size() == 0)
2381  return; // nothing to do :)
2382 
2383  bool is_test_ex = false;
2384  bool is_holdout_ex = false;
2385 
2386  search_private& priv = *sch.priv;
2387  priv.offset = ec_seq[0]->ft_offset;
2388  priv.base_learner = &base;
2389 
2390  adjust_auto_condition(priv);
2391  priv.read_example_last_id = ec_seq[ec_seq.size() - 1]->example_counter;
2392 
2393  // hit_new_pass true would have already triggered a printout
2394  // finish_example(multi_ex). so we can reset hit_new_pass here
2395  priv.hit_new_pass = false;
2396 
2397  for (size_t i = 0; i < ec_seq.size(); i++)
2398  {
2399  is_test_ex |= priv.label_is_test(ec_seq[i]->l);
2400  is_holdout_ex |= ec_seq[i]->test_only;
2401  if (is_test_ex && is_holdout_ex)
2402  break;
2403  }
2404 
2405  if (priv.task->run_setup)
2406  priv.task->run_setup(sch, ec_seq);
2407 
2408  // if we're going to have to print to the screen, generate the "truth" std::string
2409  cdbg << "======================================== GET TRUTH STRING (" << priv.current_policy << ","
2410  << priv.read_example_last_pass << ") ========================================" << endl;
2411  if (might_print_update(*priv.all))
2412  {
2413  if (is_test_ex)
2414  priv.truth_string->str("**test**");
2415  else
2416  {
2417  reset_search_structure(*sch.priv);
2418  priv.state = GET_TRUTH_STRING;
2419  priv.should_produce_string = true;
2420  priv.truth_string->str("");
2421  run_task(sch, ec_seq);
2422  }
2423  }
2424 
2425  add_neighbor_features(priv, ec_seq);
2426  train_single_example<is_learn>(sch, is_test_ex, is_holdout_ex, ec_seq);
2427  del_neighbor_features(priv, ec_seq);
2428 
2429  if (priv.task->run_takedown)
2430  priv.task->run_takedown(sch, ec_seq);
2431 }
#define cdbg
Definition: search.h:11
bool might_print_update(vw &all)
Definition: search.cc:461
void adjust_auto_condition(search_private &priv)
Definition: search.cc:2364
void del_neighbor_features(search_private &priv, multi_ex &ec_seq)
Definition: search.cc:683
void run_task(search &sch, multi_ex &ec)
Definition: search.cc:2098
void reset_search_structure(search_private &priv)
Definition: search.cc:690
void add_neighbor_features(search_private &priv, multi_ex &ec_seq)
Definition: search.cc:631

◆ end_examples()

void Search::end_examples ( search &  sch)

Definition at line 2464 of file search.cc.

References Search::search_private::all, Search::search_private::current_policy, VW::config::options_i::get_typed_option(), vw::options, Search::search_private::passes_since_new_policy, Search::search::priv, VW::config::options_i::replace(), prediction_type::to_string(), Search::search_private::total_number_of_policies, and vw::training.

Referenced by setup().

2465 {
2466  search_private& priv = *sch.priv;
2467  vw* all = priv.all;
2468 
2469  if (all->training)
2470  {
2471  // TODO work out a better system to update state that will be saved in the model.
2472  // Dig out option and change it in case we already loaded a predictor which had a value stored for
2473  // --search_trained_nb_policies
2474  auto val = (priv.passes_since_new_policy == 0) ? priv.current_policy : (priv.current_policy + 1);
2475  all->options->replace("search_trained_nb_policies", std::to_string(val));
2476  all->options->get_typed_option<uint32_t>("search_trained_nb_policies").value(val);
2477  // Dig out option and change it in case we already loaded a predictor which had a value stored for
2478  // --search_total_nb_policies
2479  all->options->replace("search_total_nb_policies", std::to_string(priv.total_number_of_policies));
2480  all->options->get_typed_option<uint32_t>("search_total_nb_policies").value(priv.total_number_of_policies);
2481  }
2482 }
VW::config::options_i * options
Definition: global_data.h:428
virtual void replace(const std::string &key, const std::string &value)=0
bool training
Definition: global_data.h:488
typed_option< T > & get_typed_option(const std::string &key)
Definition: options.h:120
const char * to_string(prediction_type_t prediction_type)
Definition: learner.cc:12

◆ end_pass()

void Search::end_pass ( search &  sch)

Definition at line 2433 of file search.cc.

References Search::search_private::all, Search::search_private::current_policy, VW::config::options_i::get_typed_option(), Search::search_private::hit_new_pass, vw::options, Search::search_private::passes_per_policy, Search::search_private::passes_since_new_policy, Search::search::priv, Search::search_private::read_example_last_pass, VW::config::options_i::replace(), prediction_type::to_string(), Search::search_private::total_number_of_policies, and vw::training.

2434 {
2435  search_private& priv = *sch.priv;
2436  vw* all = priv.all;
2437  priv.hit_new_pass = true;
2438  priv.read_example_last_pass++;
2439  priv.passes_since_new_policy++;
2440 
2441  if (priv.passes_since_new_policy >= priv.passes_per_policy)
2442  {
2443  priv.passes_since_new_policy = 0;
2444  if (all->training)
2445  priv.current_policy++;
2446  if (priv.current_policy > priv.total_number_of_policies)
2447  {
2448  std::cerr << "internal error (bug): too many policies; not advancing" << endl;
2449  priv.current_policy = priv.total_number_of_policies;
2450  }
2451  // reset search_trained_nb_policies in options_from_file so it is saved to regressor file later
2452  // TODO work out a better system to update state that will be saved in the model.
2453  all->options->replace("search_trained_nb_policies", std::to_string(priv.current_policy));
2454  all->options->get_typed_option<uint32_t>("search_trained_nb_policies").value(priv.current_policy);
2455  }
2456 }
VW::config::options_i * options
Definition: global_data.h:428
virtual void replace(const std::string &key, const std::string &value)=0
bool training
Definition: global_data.h:488
typed_option< T > & get_typed_option(const std::string &key)
Definition: options.h:120
const char * to_string(prediction_type_t prediction_type)
Definition: learner.cc:12

◆ ensure_param()

void Search::ensure_param ( float &  v,
float  lo,
float  hi,
float  def,
const char *  str 
)

Definition at line 2534 of file search.cc.

Referenced by setup().

2535 {
2536  if ((v < lo) || (v > hi))
2537  {
2538  std::cerr << str << endl;
2539  v = def;
2540  }
2541 }

◆ ensure_size()

template<class T >
void Search::ensure_size ( v_array< T > &  A,
size_t  sz 
)

Definition at line 1066 of file search.cc.

References v_array< T >::begin(), v_array< T >::end(), v_array< T >::end_array, and v_array< T >::resize().

Referenced by search_predict().

1067 {
1068  if ((size_t)(A.end_array - A.begin()) < sz)
1069  A.resize(sz * 2 + 1);
1070  A.end() = A.begin() + sz;
1071 }
void resize(size_t length)
Definition: v_array.h:69
T *& begin()
Definition: v_array.h:42
T *& end()
Definition: v_array.h:43
T * end_array
Definition: v_array.h:38

◆ finish_multiline_example()

void Search::finish_multiline_example ( vw &  all,
search &  sch,
multi_ex &  ec_seq 
)

Definition at line 2458 of file search.cc.

References VW::finish_example(), print_update(), and Search::search::priv.

2459 {
2460  print_update(*sch.priv);
2461  VW::finish_example(all, ec_seq);
2462 }
void print_update(search_private &priv)
Definition: search.cc:527
void finish_example(vw &, example &)
Definition: parser.cc:881

◆ float_equal()

bool Search::float_equal ( float  a,
float  b 
)

◆ foreach_action_from_cache()

void Search::foreach_action_from_cache ( search_private &  priv,
size_t  t,
action  override_a = (action)-1 
)

Definition at line 1634 of file search.cc.

References Search::BaseTask::_foreach_action, cdbg, Search::action_cache::cost, id(), Search::action_cache::is_opt, Search::action_cache::k, Search::search_private::memo_foreach_action, Search::search_private::meta_t, Search::search_private::metaoverride, Search::action_cache::min_cost, and Search::BaseTask::sch.

Referenced by search_predict().

1635 {
1636  cdbg << "foreach_action_from_cache: t=" << t << ", memo_foreach_action.size()=" << priv.memo_foreach_action.size()
1637  << ", override_a=" << override_a << endl;
1638  assert(t < priv.memo_foreach_action.size());
1639  v_array<action_cache>* cached = priv.memo_foreach_action[t];
1640  if (!cached)
1641  return; // the only way this can happen is if the metatask overrode this action
1642  cdbg << "memo_foreach_action size = " << cached->size() << endl;
1643  for (size_t id = 0; id < cached->size(); id++)
1644  {
1645  action_cache& ac = (*cached)[id];
1646  priv.metaoverride->_foreach_action(*priv.metaoverride->sch, t - priv.meta_t, ac.min_cost, ac.k,
1647  (override_a == (action)-1) ? ac.is_opt : (ac.k == override_a), ac.cost);
1648  }
1649 }
#define cdbg
Definition: search.h:11
uint32_t action
Definition: search.h:19
float id(float in)
Definition: scorer.cc:51

◆ free_final_item()

void Search::free_final_item ( final_item *  p)

Definition at line 2056 of file search.cc.

References Search::final_item::prefix.

2057 {
2058  p->prefix->delete_v();
2059  delete p->prefix;
2060  delete p;
2061 }

◆ free_key()

void Search::free_key ( unsigned char *  mem,
scored_action   
)

Definition at line 277 of file search.cc.

Referenced by clear_cache_hash_map().

277 { free(mem); } // sa.repr.delete_v(); }

◆ generate_training_example()

void Search::generate_training_example ( search_private &  priv,
polylabel &  losses,
float  weight,
bool  add_conditioning = true,
float  min_loss = FLT_MAX 
)

Definition at line 1490 of file search.cc.

References a, add_example_conditioning(), LEARNER::as_multiline(), LEARNER::as_singleline(), Search::search_private::base_learner, v_array< T >::begin(), polylabel::cb, Search::search_private::cb_learner, cdbg, CB::label::costs, COST_SENSITIVE::label::costs, polylabel::cs, cs_get_costs_size(), Search::search_private::current_policy, del_example_conditioning(), COST_SENSITIVE::ec_is_example_header(), example_predict::ft_offset, example::in_use, Search::search_private::is_ldf, example::l, LEARNER::learner< T, E >::learn(), Search::search_private::learn_condition_on, Search::search_private::learn_condition_on_act, Search::search_private::learn_condition_on_names, Search::search_private::learn_ec_ref, Search::search_private::learn_ec_ref_cnt, Search::search_private::learn_learner_id, Search::search_private::offset, select_learner(), v_array< T >::size(), Search::search_private::total_examples_generated, and Search::search_private::xv.

Referenced by search_predict(), and train_single_example().

1493 {
1494  // should we really subtract out min-loss?
1495  // float min_loss = FLT_MAX;
1496  if (priv.cb_learner)
1497  {
1498  if (min_loss == FLT_MAX)
1499  for (size_t i = 0; i < losses.cb.costs.size(); i++) min_loss = std::min(min_loss, losses.cb.costs[i].cost);
1500  for (size_t i = 0; i < losses.cb.costs.size(); i++) losses.cb.costs[i].cost = losses.cb.costs[i].cost - min_loss;
1501  }
1502  else
1503  {
1504  if (min_loss == FLT_MAX)
1505  for (size_t i = 0; i < losses.cs.costs.size(); i++) min_loss = std::min(min_loss, losses.cs.costs[i].x);
1506  for (size_t i = 0; i < losses.cs.costs.size(); i++)
1507  losses.cs.costs[i].x = (losses.cs.costs[i].x - min_loss) * weight;
1508  }
1509  // std::cerr << "losses = ["; for (size_t i=0; i<losses.cs.costs.size(); i++) std::cerr << ' ' <<
1510  // losses.cs.costs[i].class_index
1511  // << ':' << losses.cs.costs[i].x; std::cerr << " ]" << endl;
1512 
1513  if (!priv.is_ldf) // not LDF
1514  {
1515  // since we're not LDF, it should be the case that ec_ref_cnt == 1
1516  // and learn_ec_ref[0] is a pointer to a single example
1517  assert(priv.learn_ec_ref_cnt == 1);
1518  assert(priv.learn_ec_ref != nullptr);
1519 
1520  example& ec = priv.learn_ec_ref[0];
1521  polylabel old_label = ec.l;
1522  ec.l = losses; // labels;
1523  if (add_conditioning)
1524  add_example_conditioning(priv, ec, priv.learn_condition_on.size(), priv.learn_condition_on_names.begin(),
1525  priv.learn_condition_on_act.begin());
1526  for (size_t is_local = 0; is_local <= (size_t)priv.xv; is_local++)
1527  {
1528  int learner = select_learner(priv, priv.current_policy, priv.learn_learner_id, true, is_local > 0);
1529  ec.in_use = true;
1530  cdbg << "BEGIN base_learner->learn(ec, " << learner << ")" << endl;
1531  as_singleline(priv.base_learner)->learn(ec, learner);
1532  cdbg << "END base_learner->learn(ec, " << learner << ")" << endl;
1533  }
1534  if (add_conditioning)
1535  del_example_conditioning(priv, ec);
1536  ec.l = old_label;
1537  priv.total_examples_generated++;
1538  }
1539  else // is LDF
1540  {
1541  assert(cs_get_costs_size(priv.cb_learner, losses) == priv.learn_ec_ref_cnt);
1542  size_t start_K = (priv.is_ldf && COST_SENSITIVE::ec_is_example_header(priv.learn_ec_ref[0])) ? 1 : 0;
1543 
1544  // TODO: weight
1545  if (add_conditioning)
1546  for (action a = (uint32_t)start_K; a < priv.learn_ec_ref_cnt; a++)
1547  {
1548  example& ec = priv.learn_ec_ref[a];
1549  add_example_conditioning(priv, ec, priv.learn_condition_on.size(), priv.learn_condition_on_names.begin(),
1550  priv.learn_condition_on_act.begin());
1551  }
1552 
1553  for (size_t is_local = 0; is_local <= (size_t)priv.xv; is_local++)
1554  {
1555  int learner = select_learner(priv, priv.current_policy, priv.learn_learner_id, true, is_local > 0);
1556 
1557  // create an example collection for
1558 
1559  multi_ex tmp;
1560  uint64_t tmp_offset = 0;
1561  if (priv.learn_ec_ref_cnt > start_K)
1562  tmp_offset = priv.learn_ec_ref[start_K].ft_offset;
1563  for (action a = (uint32_t)start_K; a < priv.learn_ec_ref_cnt; a++)
1564  {
1565  example& ec = priv.learn_ec_ref[a];
1566  CS::label& lab = ec.l.cs;
1567  if (lab.costs.size() == 0)
1568  {
1569  CS::wclass wc = {0., a - (uint32_t)start_K, 0., 0.};
1570  lab.costs.push_back(wc);
1571  }
1572  lab.costs[0].x = losses.cs.costs[a - start_K].x;
1573  ec.in_use = true;
1574  // store the offset to restore it later
1575  ec.ft_offset = priv.offset;
1576  // create the example collection used to learn
1577  tmp.push_back(&ec);
1578  cdbg << "generate_training_example called learn on action a=" << a << ", costs.size=" << lab.costs.size()
1579  << " ec=" << &ec << endl;
1580  priv.total_examples_generated++;
1581  }
1582 
1583  // learn with the multiline example
1584  as_multiline(priv.base_learner)->learn(tmp, learner);
1585 
1586  // restore the offsets in examples
1587  int i = 0;
1588  for (action a = (uint32_t)start_K; a < priv.learn_ec_ref_cnt; a++, i++)
1589  priv.learn_ec_ref[a].ft_offset = tmp_offset;
1590  }
1591 
1592  if (add_conditioning)
1593  for (action a = (uint32_t)start_K; a < priv.learn_ec_ref_cnt; a++)
1594  {
1595  example& ec = priv.learn_ec_ref[a];
1596  del_example_conditioning(priv, ec);
1597  }
1598  }
1599 }
void add_example_conditioning(search_private &priv, example &ec, size_t condition_on_cnt, const char *condition_on_names, action_repr *condition_on_actions)
Definition: search.cc:784
#define cdbg
Definition: search.h:11
CB::label cb
Definition: example.h:31
v_array< cb_class > costs
Definition: cb.h:27
uint32_t action
Definition: search.h:19
bool ec_is_example_header(example const &ec)
int select_learner(search_private &priv, int policy, size_t learner_id, bool is_training, bool is_local)
Definition: search.cc:432
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
COST_SENSITIVE::label cs
Definition: example.h:30
float weight
std::vector< example * > multi_ex
Definition: example.h:122
polylabel l
Definition: example.h:57
constexpr uint64_t a
Definition: rand48.cc:11
void del_example_conditioning(search_private &priv, example &ec)
Definition: search.cc:881
bool in_use
Definition: example.h:79
iterator begin()
void learn(E &ec, size_t i=0)
Definition: learner.h:160
v_array< wclass > costs
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
size_t cs_get_costs_size(bool isCB, polylabel &ld)
Definition: search.cc:887

◆ get_training_timesteps()

void Search::get_training_timesteps ( search_private priv,
v_array< size_t > &  timesteps 
)

Definition at line 1983 of file search.cc.

References Search::search_private::_random_state, Search::search_private::active_csoaa, Search::search_private::active_known, Search::search_private::active_uncertainty, v_array< T >::begin(), v_array< T >::clear(), cmp_size_t(), v_array< T >::end(), hoopla_permute(), Search::search_private::linear_ordering, v_array< T >::push_back(), v_array< T >::size(), Search::search_private::subsample_timesteps, Search::search_private::T, and v_array_contains().

Referenced by train_single_example().

1984 {
1985  timesteps.clear();
1986 
1987  // if there's active learning, we need to
1988  if (priv.subsample_timesteps <= -1)
1989  {
1990  for (size_t i = 0; i < priv.active_uncertainty.size(); i++)
1991  if (priv._random_state->get_and_update_random() > priv.active_uncertainty[i].first)
1992  timesteps.push_back(priv.active_uncertainty[i].second - 1);
1993  /*
1994  float k = (float)priv.total_examples_generated;
1995  priv.ec_seq[t]->revert_weight = priv.all->loss->getRevertingWeight(priv.all->sd, priv.ec_seq[t].pred.scalar,
1996  priv.all->eta / powf(k, priv.all->power_t)); float importance = query_decision(active_str, *priv.ec_seq[t], k); if
1997  (importance > 0.) timesteps.push_back(pair<size_t,size_t>(0,t));
1998  */
1999  }
2000  // if there's no subsampling to do, just return [0,T)
2001  else if (priv.subsample_timesteps <= 0)
2002  for (size_t t = 0; t < priv.T; t++)
2003  {
2004  uint32_t count = 99;
2005  if (priv.active_csoaa && (t < priv.active_known.size()))
2006  {
2007  count = 0;
2008  for (std::pair<CS::wclass&, bool> wcq : priv.active_known[t])
2009  if (wcq.second)
2010  {
2011  count++;
2012  if (count > 1)
2013  break;
2014  }
2015  }
2016  if (count > 1)
2017  timesteps.push_back(t);
2018  }
2019 
2020  // if subsample in (0,1) then pick steps with that probability, but ensuring there's at least one!
2021  else if (priv.subsample_timesteps < 1)
2022  {
2023  for (size_t t = 0; t < priv.T; t++)
2024  if (priv._random_state->get_and_update_random() <= priv.subsample_timesteps)
2025  timesteps.push_back(t);
2026 
2027  if (timesteps.size() == 0) // ensure at least one
2028  timesteps.push_back((size_t)(priv._random_state->get_and_update_random() * priv.T));
2029  }
2030 
2031  // finally, if subsample >= 1, then pick (int) that many uniformly at random without replacement; could use an LFSR
2032  // but why? :P
2033  else
2034  {
2035  while ((timesteps.size() < (size_t)priv.subsample_timesteps) && (timesteps.size() < priv.T))
2036  {
2037  size_t t = (size_t)(priv._random_state->get_and_update_random() * (float)priv.T);
2038  if (!v_array_contains(timesteps, t))
2039  timesteps.push_back(t);
2040  }
2041  std::sort(timesteps.begin(), timesteps.end(), cmp_size_t);
2042  }
2043 
2044  if (!priv.linear_ordering)
2045  hoopla_permute(timesteps.begin(), timesteps.end());
2046 }
void hoopla_permute(size_t *B, size_t *end)
Definition: search.cc:1948
T *& begin()
Definition: v_array.h:42
size_t size() const
Definition: v_array.h:68
void push_back(const T &new_ele)
Definition: v_array.h:107
bool cmp_size_t(const size_t a, const size_t b)
Definition: search.cc:1940
void clear()
Definition: v_array.h:88
T *& end()
Definition: v_array.h:43
bool v_array_contains(v_array< T > &A, T x)
Definition: v_array.h:237

◆ handle_condition_options()

void Search::handle_condition_options ( vw all,
auto_condition_settings acset 
)

Definition at line 2543 of file search.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), f, Search::auto_condition_settings::feature_value, VW::config::make_option(), Search::auto_condition_settings::max_bias_ngram_length, Search::auto_condition_settings::max_quad_ngram_length, vw::options, and Search::auto_condition_settings::use_passthrough_repr.

Referenced by setup().

2544 {
2545  option_group_definition new_options("Search Auto-conditioning Options");
2546  new_options.add(make_option("search_max_bias_ngram_length", acset.max_bias_ngram_length)
2547  .keep()
2548  .default_value(1)
2549  .help("add a \"bias\" feature for each ngram up to and including this length. eg., if it's 1 "
2550  "(default), then you get a single feature for each conditional"));
2551  new_options.add(make_option("search_max_quad_ngram_length", acset.max_quad_ngram_length)
2552  .keep()
2553  .default_value(0)
2554  .help("add bias *times* input features for each ngram up to and including this length (def: 0)"));
2555  new_options.add(make_option("search_condition_feature_value", acset.feature_value)
2556  .keep()
2557  .default_value(1.f)
2558  .help("how much weight should the conditional features get? (def: 1.)"));
2559  new_options.add(make_option("search_use_passthrough_repr", acset.use_passthrough_repr)
2560  .keep()
2561  .help("should we use lower-level reduction _internal state_ as additional features? (def: no)"));
2562  all.options->add_and_parse(new_options);
2563 }
VW::config::options_i * options
Definition: global_data.h:428
virtual void add_and_parse(const option_group_definition &group)=0
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
float f
Definition: cache.cc:40

◆ hoopla_permute()

void Search::hoopla_permute ( size_t *  B,
size_t *  end 
)

Definition at line 1948 of file search.cc.

References absdiff(), and cmp_size_t().

Referenced by get_training_timesteps().

1949 {
1950  // from Curtis IPL 2004, "Darts and hoopla board design"
1951  // first sort
1952  size_t N = end - B;
1953  std::sort(B, end, cmp_size_t);
1954  // make some temporary space
1955  size_t* A = calloc_or_throw<size_t>((N + 1) * 2);
1956  A[N] = B[0]; // arbitrarily choose the maximum in the middle
1957  A[N + 1] = B[N - 1]; // so the maximum goes next to it
1958  size_t lo = N, hi = N + 1; // which parts of A have we filled in? [lo,hi]
1959  size_t i = 0, j = N - 1; // which parts of B have we already covered? [0,i] and [j,N-1]
1960  while (i + 1 < j)
1961  {
1962  // there are four options depending on where things get placed
1963  size_t d1 = absdiff(A[lo], B[i + 1]); // put B[i+1] at the bottom
1964  size_t d2 = absdiff(A[lo], B[j - 1]); // put B[j-1] at the bottom
1965  size_t d3 = absdiff(A[hi], B[i + 1]); // put B[i+1] at the top
1966  size_t d4 = absdiff(A[hi], B[j - 1]); // put B[j-1] at the top
1967  size_t mx = std::max(std::max(d1, d2), std::max(d3, d4));
1968  if (d1 >= mx)
1969  A[--lo] = B[++i];
1970  else if (d2 >= mx)
1971  A[--lo] = B[--j];
1972  else if (d3 >= mx)
1973  A[++hi] = B[++i];
1974  else
1975  A[++hi] = B[--j];
1976  }
1977  // copy it back to B
1978  memcpy(B, A + lo, N * sizeof(size_t));
1979  // clean up
1980  free(A);
1981 }
bool cmp_size_t(const size_t a, const size_t b)
Definition: search.cc:1940
size_t absdiff(size_t a, size_t b)
Definition: search.cc:1946

◆ mc_label_is_test()

bool Search::mc_label_is_test ( polylabel lab)

Definition at line 2484 of file search.cc.

References MULTICLASS::mc_label, polylabel::multi, and label_parser::test_label.

Referenced by search_initialize().

2484 { return MC::mc_label.test_label(&lab.multi); }
bool(* test_label)(void *)
Definition: label_parser.h:22
label_parser mc_label
Definition: multiclass.cc:93
MULTICLASS::label_t multi
Definition: example.h:29

◆ might_print_update()

bool Search::might_print_update ( vw all)

Definition at line 461 of file search.cc.

References vw::bfgs, shared_data::dump_interval, vw::quiet, vw::sd, and shared_data::weighted_examples().

Referenced by do_actual_learning(), must_run_test(), and train_single_example().

462 {
463  // basically do should_print_update but check me and the next
464  // example because of off-by-ones
465 
467  return true;
469  return true; // SPEEDUP: make this better
470  return (all.sd->weighted_examples() + 1. >= all.sd->dump_interval) && !all.quiet && !all.bfgs;
471 }
bool quiet
Definition: global_data.h:487
constexpr bool PRINT_UPDATE_EVERY_EXAMPLE
Definition: search.cc:42
shared_data * sd
Definition: global_data.h:375
bool bfgs
Definition: global_data.h:412
constexpr bool PRINT_UPDATE_EVERY_PASS
Definition: search.cc:43
double weighted_examples()
Definition: global_data.h:188
float dump_interval
Definition: global_data.h:147

◆ must_run_test()

bool Search::must_run_test ( vw all,
multi_ex ec,
bool  is_test_ex 
)

Definition at line 473 of file search.cc.

References vw::current_pass, vw::final_prediction_sink, vw::holdout_set_off, might_print_update(), vw::quiet, vw::raw_prediction, v_array< T >::size(), and vw::vw_is_main.

Referenced by train_single_example().

474 {
475  return (all.final_prediction_sink.size() > 0) || // if we have to produce output, we need to run this
476  might_print_update(all) || // if we have to print and update to stderr
477  (all.raw_prediction > 0) || // we need raw predictions
478  ((!all.vw_is_main) && (is_test_ex)) || // library needs predictions
479  // or:
480  // it's not quiet AND
481  // current_pass == 0
482  // OR holdout is off
483  // OR it's a test example
484  ((!all.quiet || !all.vw_is_main) && // had to disable this because of library mode!
485  (!is_test_ex) &&
486  (all.holdout_set_off || // no holdout
487  ec[0]->test_only || (all.current_pass == 0) // we need error rates for progressive cost
488  ));
489 }
int raw_prediction
Definition: global_data.h:519
bool might_print_update(vw &all)
Definition: search.cc:461
v_array< int > final_prediction_sink
Definition: global_data.h:518
bool quiet
Definition: global_data.h:487
bool holdout_set_off
Definition: global_data.h:499
size_t size() const
Definition: v_array.h:68
bool vw_is_main
Definition: global_data.h:421
uint64_t current_pass
Definition: global_data.h:396

◆ need_memo_foreach_action()

bool Search::need_memo_foreach_action ( search_private priv)
inline

Definition at line 370 of file search.cc.

References INIT_TRAIN, Search::search_private::metaoverride, Search::search_private::metatask, and Search::search_private::state.

Referenced by choose_oracle_action(), search_predict(), single_prediction_LDF(), and single_prediction_notLDF().

371 {
372  return (priv.state == INIT_TRAIN) && (priv.metatask) && (priv.metaoverride); // &&
373  // (priv.metaoverride->_foreach_action || priv.metaoverride->_post_prediction);
374 }

◆ neighbor_feature_space()

std::string Search::neighbor_feature_space ( "neighbor"  )

Referenced by add_neighbor_features().

◆ number_to_natural()

std::string Search::number_to_natural ( size_t  big)

Definition at line 512 of file search.cc.

Referenced by print_update().

/// Format `big` compactly for the progress table. Values above 9999 are
/// truncated to thousands ("k"), above 9999999 to millions ("m"), and above
/// 9999999999 to billions ("g"). Smaller values are printed unchanged.
std::string number_to_natural(size_t big)
{
  struct unit
  {
    unsigned long long above;    // values strictly greater than this use the unit
    unsigned long long divisor;  // integer division, i.e. truncation
    const char* suffix;
  };
  static const unit units[] = {
      {9999999999ULL, 1000000000ULL, "g"},
      {9999999ULL, 1000000ULL, "m"},
      {9999ULL, 1000ULL, "k"},
  };

  std::stringstream ss;
  for (const unit& u : units)
  {
    if (big > u.above)
    {
      ss << big / u.divisor << u.suffix;
      return ss.str();
    }
  }
  ss << big;
  return ss.str();
}

◆ operator<<() [1/2]

std::ostream& Search::operator<< ( std::ostream &  os,
const scored_action x 
)

Definition at line 97 of file search.cc.

References Search::scored_action::a, and Search::scored_action::s.

98 {
99  os << x.a << ':' << x.s;
100  return os;
101 }

◆ operator<<() [2/2]

std::ostream& Search::operator<< ( std::ostream &  os,
const action_cache x 
)

Definition at line 131 of file search.cc.

References Search::action_cache::cost, Search::action_cache::is_opt, and Search::action_cache::k.

132 {
133  os << x.k << ':' << x.cost;
134  if (x.is_opt)
135  os << '*';
136  return os;
137 }

◆ parse_neighbor_features()

void Search::parse_neighbor_features ( std::string &  nf_string,
search sch 
)

Definition at line 2627 of file search.cc.

References v_array< T >::clear(), int_of_substring(), Search::search_private::neighbor_features, Search::search::priv, v_array< T >::push_back(), and tokenize().

Referenced by setup().

2628 {
2629  search_private& priv = *sch.priv;
2630  priv.neighbor_features.clear();
2631  size_t len = nf_string.length();
2632  if (len == 0)
2633  return;
2634 
2635  char* cstr = new char[len + 1];
2636  strcpy(cstr, nf_string.c_str());
2637 
2638  char* p = strtok(cstr, ",");
2639  std::vector<substring> cmd;
2640  while (p != 0)
2641  {
2642  cmd.clear();
2643  substring me = {p, p + strlen(p)};
2644  tokenize(':', me, cmd, true);
2645 
2646  int32_t posn = 0;
2647  char ns = ' ';
2648  if (cmd.size() == 1)
2649  {
2650  posn = int_of_substring(cmd[0]);
2651  ns = ' ';
2652  }
2653  else if (cmd.size() == 2)
2654  {
2655  posn = int_of_substring(cmd[0]);
2656  ns = (cmd[1].end > cmd[1].begin) ? cmd[1].begin[0] : ' ';
2657  }
2658  else
2659  {
2660  std::cerr << "warning: ignoring malformed neighbor specification: '" << p << "'" << endl;
2661  }
2662  int32_t enc = (posn << 24) | (ns & 0xFF);
2663  priv.neighbor_features.push_back(enc);
2664 
2665  p = strtok(nullptr, ",");
2666  }
2667 
2668  delete[] cstr;
2669 }
int int_of_substring(substring s)
void tokenize(char delim, substring s, ContainerT &ret, bool allow_empty=false)

◆ print_update()

void Search::print_update ( search_private priv)

Definition at line 527 of file search.cc.

References Search::search_private::active_csoaa, Search::search_private::all, Search::search_private::beta, vw::current_pass, Search::search_private::current_policy, shared_data::example_number, Search::search_private::hit_new_pass, vw::holdout_set_off, shared_data::holdout_sum_loss, shared_data::holdout_sum_loss_since_last_dump, Search::search_private::num_calls_to_run, number_to_natural(), shared_data::old_weighted_labeled_examples, Search::search_private::pred_string, Search::search_private::printed_output_header, vw::progress_add, vw::progress_arg, vw::quiet, Search::search_private::read_example_last_pass, safediv(), vw::sd, should_print_update(), Search::search_private::start_clock_time, shared_data::sum_loss, shared_data::sum_loss_since_last_dump, to_short_string(), Search::search_private::total_cache_hits, Search::search_private::total_examples_generated, Search::search_private::total_predictions_made, Search::search_private::truth_string, shared_data::update_dump_interval(), shared_data::weighted_holdout_examples, shared_data::weighted_holdout_examples_since_last_dump, and shared_data::weighted_labeled_examples.

Referenced by finish_multiline_example().

528 {
529  vw& all = *priv.all;
530  if (!priv.printed_output_header && !all.quiet)
531  {
532  const char* header_fmt = "%-10s %-10s %8s%24s %22s %5s %5s %7s %7s %7s %-8s\n";
533  fprintf(stderr, header_fmt, "average", "since", "instance", "current true", "current predicted", "cur", "cur",
534  "predic", "cache", "examples", "");
535  if (priv.active_csoaa)
536  fprintf(stderr, header_fmt, "loss", "last", "counter", "output prefix", "output prefix", "pass", "pol", "made",
537  "hits", "gener", "#run");
538  else
539  fprintf(stderr, header_fmt, "loss", "last", "counter", "output prefix", "output prefix", "pass", "pol", "made",
540  "hits", "gener", "beta");
541  std::cerr.precision(5);
542  priv.printed_output_header = true;
543  }
544 
545  if (!should_print_update(all, priv.hit_new_pass))
546  return;
547 
548  char true_label[21];
549  char pred_label[21];
550  to_short_string(priv.truth_string->str(), 20, true_label);
551  to_short_string(priv.pred_string->str(), 20, pred_label);
552 
553  float avg_loss = 0.;
554  float avg_loss_since = 0.;
555  bool use_heldout_loss = (!all.holdout_set_off && all.current_pass >= 1) && (all.sd->weighted_holdout_examples > 0);
556  if (use_heldout_loss)
557  {
558  avg_loss = safediv((float)all.sd->holdout_sum_loss, (float)all.sd->weighted_holdout_examples);
559  avg_loss_since = safediv(
561 
564  }
565  else
566  {
567  avg_loss = safediv((float)all.sd->sum_loss, (float)all.sd->weighted_labeled_examples);
568  avg_loss_since = safediv((float)all.sd->sum_loss_since_last_dump,
570  }
571 
572  auto const& inst_cntr = number_to_natural((size_t)all.sd->example_number);
573  auto const& total_pred = number_to_natural(priv.total_predictions_made);
574  auto const& total_cach = number_to_natural(priv.total_cache_hits);
575  auto const& total_exge = number_to_natural(priv.total_examples_generated);
576 
577  fprintf(stderr, "%-10.6f %-10.6f %8s [%s] [%s] %5d %5d %7s %7s %7s %-8f", avg_loss, avg_loss_since,
578  inst_cntr.c_str(), true_label, pred_label, (int)priv.read_example_last_pass, (int)priv.current_policy,
579  total_pred.c_str(), total_cach.c_str(), total_exge.c_str(),
580  priv.active_csoaa ? priv.num_calls_to_run : priv.beta);
581 
582  if (PRINT_CLOCK_TIME)
583  {
584  size_t num_sec = (size_t)(((float)(clock() - priv.start_clock_time)) / CLOCKS_PER_SEC);
585  std::cerr << " " << num_sec << "sec";
586  }
587 
588  if (use_heldout_loss)
589  fprintf(stderr, " h");
590 
591  fprintf(stderr, "\n");
592  fflush(stderr);
594 }
double sum_loss
Definition: global_data.h:145
void to_short_string(std::string in, size_t max_len, char *out)
Definition: search.cc:499
bool quiet
Definition: global_data.h:487
float safediv(float a, float b)
Definition: search.cc:491
bool holdout_set_off
Definition: global_data.h:499
bool progress_add
Definition: global_data.h:545
double sum_loss_since_last_dump
Definition: global_data.h:146
shared_data * sd
Definition: global_data.h:375
float progress_arg
Definition: global_data.h:546
double old_weighted_labeled_examples
Definition: global_data.h:142
double weighted_holdout_examples
Definition: global_data.h:156
double holdout_sum_loss
Definition: global_data.h:159
uint64_t current_pass
Definition: global_data.h:396
std::string number_to_natural(size_t big)
Definition: search.cc:512
bool should_print_update(vw &all, bool hit_new_pass=false)
Definition: search.cc:449
uint64_t example_number
Definition: global_data.h:137
constexpr bool PRINT_CLOCK_TIME
Definition: search.cc:44
double weighted_labeled_examples
Definition: global_data.h:141
double weighted_holdout_examples_since_last_dump
Definition: global_data.h:157
double holdout_sum_loss_since_last_dump
Definition: global_data.h:158
void update_dump_interval(bool progress_add, float progress_arg)
Definition: global_data.h:215

◆ push_at()

template<class T >
void Search::push_at ( v_array< T > &  v,
item,
size_t  pos 
)

Definition at line 1074 of file search.cc.

References v_array< T >::begin(), v_array< T >::end(), v_array< T >::end_array, v_array< T >::resize(), and v_array< T >::size().

Referenced by Search::search::predict(), Search::search::predictLDF(), and search_predict().

1075 {
1076  if (v.size() > pos)
1077  v.begin()[pos] = item;
1078  else
1079  {
1080  if (v.end_array > v.begin() + pos)
1081  {
1082  // there's enough memory, just not enough filler
1083  memset(v.end(), 0, sizeof(T) * (pos - v.size()));
1084  v.begin()[pos] = item;
1085  v.end() = v.begin() + pos + 1;
1086  }
1087  else
1088  {
1089  // there's not enough memory
1090  v.resize(2 * pos + 3);
1091  v.begin()[pos] = item;
1092  v.end() = v.begin() + pos + 1;
1093  }
1094  }
1095 }
void resize(size_t length)
Definition: v_array.h:69
T *& begin()
Definition: v_array.h:42
size_t size() const
Definition: v_array.h:68
T *& end()
Definition: v_array.h:43
T * end_array
Definition: v_array.h:38

◆ random()

size_t Search::random ( std::shared_ptr< rand_state > &  rs,
size_t  max 
)

Definition at line 768 of file search.cc.

Referenced by choose_oracle_action().

769 {
770  return (size_t)(rs->get_and_update_random() * (float)max);
771 }

◆ random_policy()

int Search::random_policy ( search_private priv,
bool  allow_current,
bool  allow_optimal,
bool  advance_prng = true 
)

Definition at line 376 of file search.cc.

References Search::search_private::_random_state, Search::search_private::beta, Search::search_private::current_policy, f, and ldamath::powf().

Referenced by choose_policy().

377 {
378  if (priv.beta >= 1)
379  {
380  if (allow_current)
381  return (int)priv.current_policy;
382  if (priv.current_policy > 0)
383  return (((int)priv.current_policy) - 1);
384  if (allow_optimal)
385  return -1;
386  std::cerr << "internal error (bug): no valid policies to choose from! defaulting to current" << endl;
387  return (int)priv.current_policy;
388  }
389 
390  int num_valid_policies = (int)priv.current_policy + allow_optimal + allow_current;
391  int pid = -1;
392 
393  if (num_valid_policies == 0)
394  {
395  std::cerr << "internal error (bug): no valid policies to choose from! defaulting to current" << endl;
396  return (int)priv.current_policy;
397  }
398  else if (num_valid_policies == 1)
399  pid = 0;
400  else if (num_valid_policies == 2)
401  pid = (advance_prng ? priv._random_state->get_and_update_random() : priv._random_state->get_random()) >= priv.beta;
402  else
403  {
404  // SPEEDUP this up in the case that beta is small!
405  float r = (advance_prng ? priv._random_state->get_and_update_random() : priv._random_state->get_random());
406  pid = 0;
407 
408  if (r > priv.beta)
409  {
410  r -= priv.beta;
411  while ((r > 0) && (pid < num_valid_policies - 1))
412  {
413  pid++;
414  r -= priv.beta * powf(1.f - priv.beta, (float)pid);
415  }
416  }
417  }
418  // figure out which policy pid refers to
419  if (allow_optimal && (pid == num_valid_policies - 1))
420  return -1; // this is the optimal policy
421 
422  pid = (int)priv.current_policy - pid;
423  if (!allow_current)
424  pid--;
425 
426  return pid;
427 }
T powf(T, T)
Definition: lda_core.cc:428
float f
Definition: cache.cc:40

◆ read_allowed_transitions()

v_array<CS::label> Search::read_allowed_transitions ( action  A,
const char *  filename 
)

Definition at line 2579 of file search.cc.

References c, f, v_array< T >::push_back(), and THROW.

Referenced by setup().

2580 {
2581  FILE* f = fopen(filename, "r");
2582  if (f == nullptr)
2583  THROW("error: could not read file " << filename << " (" << strerror(errno)
2584  << "); assuming all transitions are valid");
2585 
2586  bool* bg = calloc_or_throw<bool>(((size_t)(A + 1)) * (A + 1));
2587  int rd, from, to, count = 0;
2588  while ((rd = fscanf(f, "%d:%d", &from, &to)) > 0)
2589  {
2590  if ((from < 0) || (from > (int)A))
2591  {
2592  std::cerr << "warning: ignoring transition from " << from << " because it's out of the range [0," << A << "]"
2593  << endl;
2594  }
2595  if ((to < 0) || (to > (int)A))
2596  {
2597  std::cerr << "warning: ignoring transition to " << to << " because it's out of the range [0," << A << "]" << endl;
2598  }
2599  bg[from * (A + 1) + to] = true;
2600  count++;
2601  }
2602  fclose(f);
2603 
2604  v_array<CS::label> allowed = v_init<CS::label>();
2605 
2606  for (size_t from = 0; from < A; from++)
2607  {
2608  v_array<CS::wclass> costs = v_init<CS::wclass>();
2609 
2610  for (size_t to = 0; to < A; to++)
2611  if (bg[from * (A + 1) + to])
2612  {
2613  CS::wclass c = {FLT_MAX, (action)to, 0., 0.};
2614  costs.push_back(c);
2615  }
2616 
2617  CS::label ld = {costs};
2618  allowed.push_back(ld);
2619  }
2620  free(bg);
2621 
2622  std::cerr << "read " << count << " allowed transitions from " << filename << endl;
2623 
2624  return allowed;
2625 }
uint32_t action
Definition: search.h:19
void push_back(const T &new_ele)
Definition: v_array.h:107
#define THROW(args)
Definition: vw_exception.h:181
constexpr uint64_t c
Definition: rand48.cc:12
float f
Definition: cache.cc:40

◆ reset_search_structure()

void Search::reset_search_structure ( search_private priv)

Definition at line 690 of file search.cc.

References Search::search_private::_random_state, Search::search_private::adaptive_beta, Search::search_private::alpha, Search::search_private::beta, Search::search_private::cb_learner, features::delete_v(), Search::search_private::done_with_all_actions, Search::search_private::force_setup_ec_ref, Search::search_private::learn_loss, Search::search_private::loss_declared_cnt, Search::search_private::meta_t, Search::search_private::mix_per_roll_policy, Search::search_private::num_features, Search::search_private::ptag_to_action, Search::search_private::read_example_last_id, Search::action_repr::repr, Search::search_private::should_produce_string, Search::search_private::t, Search::search_private::test_loss, Search::search_private::total_examples_generated, and Search::search_private::train_loss.

Referenced by do_actual_learning(), and train_single_example().

691 {
692  // NOTE: make sure do NOT reset priv.learn_a_idx
693  priv.t = 0;
694  priv.meta_t = 0;
695  priv.loss_declared_cnt = 0;
696  priv.done_with_all_actions = false;
697  priv.test_loss = 0.;
698  priv.learn_loss = 0.;
699  priv.train_loss = 0.;
700  priv.num_features = 0;
701  priv.should_produce_string = false;
702  priv.mix_per_roll_policy = -2;
703  priv.force_setup_ec_ref = false;
704  if (priv.adaptive_beta)
705  {
706  float x = -log1pf(-priv.alpha) * (float)priv.total_examples_generated;
707  static constexpr float log_of_2 = (float)0.6931471805599453;
708  priv.beta = (x <= log_of_2) ? -expm1f(-x) : (1 - expf(-x)); // numerical stability
709  // float priv_beta = 1.f - powf(1.f - priv.alpha, (float)priv.total_examples_generated);
710  // assert( fabs(priv_beta - priv.beta) < 1e-2 );
711  if (priv.beta > 1)
712  priv.beta = 1;
713  }
714  for (Search::action_repr& ar : priv.ptag_to_action)
715  {
716  if (ar.repr != nullptr)
717  {
718  ar.repr->delete_v();
719  delete ar.repr;
720  }
721  }
722  priv.ptag_to_action.clear();
723 
724  if (!priv.cb_learner) // was: if rollout_all_actions
725  {
726  priv._random_state->set_random_state((uint32_t)(priv.read_example_last_id * 147483 + 4831921) * 2147483647);
727  }
728 }
void delete_v()
void clear()
features * repr
Definition: search.cc:106

◆ run_task()

void Search::run_task ( search sch,
multi_ex ec 
)

Definition at line 2098 of file search.cc.

References GET_TRUTH_STRING, Search::search_private::metatask, Search::search_private::num_calls_to_run, Search::search::priv, Search::search_task::run, Search::search_metatask::run, Search::search_private::state, and Search::search_private::task.

Referenced by do_actual_learning(), and train_single_example().

2099 {
2100  search_private& priv = *sch.priv;
2101  priv.num_calls_to_run++;
2102  if (priv.metatask && (priv.state != GET_TRUTH_STRING))
2103  priv.metatask->run(sch, ec);
2104  else
2105  priv.task->run(sch, ec);
2106 }

◆ safediv()

float Search::safediv ( float  a,
float  b 
)

Definition at line 491 of file search.cc.

References f.

Referenced by print_update().

/// Return a / b, or 0 when the divisor is zero (so the progress table never
/// shows inf or NaN).
float safediv(float a, float b)
{
  const bool zero_divisor = (b == 0.f);
  return zero_divisor ? 0.f : (a / b);
}
constexpr uint64_t a
Definition: rand48.cc:11
float f
Definition: cache.cc:40

◆ search_declare_loss()

void Search::search_declare_loss ( search_private priv,
float  loss 
)

Definition at line 730 of file search.cc.

References cdbg, INIT_TEST, INIT_TRAIN, LEARN, Search::search_private::learn_loss, loss(), Search::search_private::loss_declared_cnt, Search::search_private::rollout_num_steps, Search::search_private::state, Search::search_private::test_loss, and Search::search_private::train_loss.

Referenced by Search::search::loss().

731 {
732  priv.loss_declared_cnt++;
733  switch (priv.state)
734  {
735  case INIT_TEST:
736  priv.test_loss += loss;
737  break;
738  case INIT_TRAIN:
739  priv.train_loss += loss;
740  break;
741  case LEARN:
742  if ((priv.rollout_num_steps == 0) || (priv.loss_declared_cnt <= priv.rollout_num_steps))
743  {
744  priv.learn_loss += loss;
745  cdbg << "priv.learn_loss += " << loss << " (now = " << priv.learn_loss << ")" << endl;
746  }
747  break;
748  default:
749  break; // get rid of the warning about missing cases (danger!)
750  }
751 }
#define cdbg
Definition: search.h:11
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60

◆ search_finish()

void Search::search_finish ( search sch)

Definition at line 2565 of file search.cc.

References Search::search_private::active_csoaa, cdbg, Search::search_task::finish, Search::search_metatask::finish, Search::search_private::metatask, Search::search_private::num_calls_to_run, Search::search::priv, and Search::search_private::task.

Referenced by setup().

2566 {
2567  search_private& priv = *sch.priv;
2568  cdbg << "search_finish" << endl;
2569 
2570  if (priv.active_csoaa)
2571  std::cerr << "search calls to run = " << priv.num_calls_to_run << endl;
2572 
2573  if (priv.task->finish)
2574  priv.task->finish(sch);
2575  if (priv.metatask && priv.metatask->finish)
2576  priv.metatask->finish(sch);
2577 }
#define cdbg
Definition: search.h:11

◆ search_initialize()

void Search::search_initialize ( vw all,
search sch 
)

Definition at line 2486 of file search.cc.

References Search::search_private::_random_state, Search::search_private::acset, Search::search_private::active_csoaa, Search::search_private::active_known, Search::search_private::active_uncertainty, Search::search_private::adaptive_beta, Search::search_private::all, Search::search_private::allow_current_policy, Search::search_private::bad_string_stream, Search::search_private::cache_hash_map, cached_item_equivalent(), COST_SENSITIVE::cs_label, Search::search_private::dat_new_feature_audit_ss, label_parser::default_label, Search::search_private::empty_cs_label, Search::auto_condition_settings::feature_value, vw::get_random_state(), INITIALIZE, Search::search_private::label_is_test, Search::auto_condition_settings::max_bias_ngram_length, mc_label_is_test(), MIX_PER_ROLL, Search::search_private::mix_per_roll_policy, Search::search_private::num_learners, Search::search_private::pred_string, Search::search::priv, Search::search_private::rawOutputString, Search::search_private::rawOutputStringStream, Search::search_private::rollin_method, Search::search_private::rollout_method, Search::search_private::state, Search::search::task_data, Search::search_private::test_action_sequence, Search::search_private::total_number_of_policies, and Search::search_private::truth_string.

Referenced by setup().

2487 {
2488  search_private& priv = *sch.priv; // priv is zero initialized by default
2489  priv.all = all;
2490  priv._random_state = all->get_random_state();
2491 
2492  priv.active_csoaa = false;
2493  priv.label_is_test = mc_label_is_test;
2494 
2495  priv.num_learners = 1;
2496  priv.state = INITIALIZE;
2497  priv.mix_per_roll_policy = -2;
2498 
2499  priv.pred_string = new std::stringstream();
2500  priv.truth_string = new std::stringstream();
2501  priv.bad_string_stream = new std::stringstream();
2502  priv.bad_string_stream->clear(priv.bad_string_stream->badbit);
2503 
2504  priv.rollout_method = MIX_PER_ROLL;
2505  priv.rollin_method = MIX_PER_ROLL;
2506 
2507  priv.allow_current_policy = true;
2508  priv.adaptive_beta = true;
2509 
2510  priv.total_number_of_policies = 1;
2511 
2512  priv.acset.max_bias_ngram_length = 1;
2513 
2514  priv.acset.feature_value = 1.;
2515 
2516  scored_action sa((action)-1, 0.);
2517  new (&priv.cache_hash_map) v_hashmap<unsigned char*, scored_action>();
2518  priv.cache_hash_map.set_default_value(sa);
2519  priv.cache_hash_map.set_equivalent(cached_item_equivalent);
2520 
2521  sch.task_data = nullptr;
2522 
2523  priv.active_uncertainty = v_init<std::pair<float, size_t>>();
2524  priv.active_known = v_init<v_array<std::pair<CS::wclass&, bool>>>();
2525 
2526  CS::cs_label.default_label(&priv.empty_cs_label);
2527 
2528  new (&priv.rawOutputString) std::string();
2529  priv.rawOutputStringStream = new std::stringstream(priv.rawOutputString);
2530  new (&priv.test_action_sequence) std::vector<action>();
2531  new (&priv.dat_new_feature_audit_ss) std::stringstream();
2532 }
bool mc_label_is_test(polylabel &lab)
Definition: search.cc:2484
label_parser cs_label
void(* default_label)(void *)
Definition: label_parser.h:12
uint32_t action
Definition: search.h:19
void set_default_value(const V &def)
Definition: v_hashmap.h:37
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
bool cached_item_equivalent(unsigned char *const &A, unsigned char *const &B)
Definition: search.cc:1429

◆ search_predict()

action Search::search_predict ( search_private & priv,
example * ecs,
size_t  ec_cnt,
ptag  mytag,
const action * oracle_actions,
size_t  oracle_actions_cnt,
const ptag * condition_on,
const char *  condition_on_names,
const action * allowed_actions,
size_t  allowed_actions_cnt,
const float *  allowed_actions_cost,
size_t  learner_id,
float &  a_cost,
float   
)

Definition at line 1652 of file search.cc.

References Search::BaseTask::_foreach_action, Search::BaseTask::_maybe_override_prediction, Search::BaseTask::_post_prediction, a, Search::search_private::A, Search::search_private::acset, add_example_conditioning(), Search::search_private::all, allowed_actions_to_label(), vw::audit, Search::search_private::auto_condition_features, v_array< T >::begin(), cached_action_store_or_find(), cdbg, cdbg_print_array(), choose_oracle_action(), choose_policy(), features::clear(), Search::search_private::condition_on_actions, VW::copy_example_data(), label_parser::copy_label, COST_SENSITIVE::label::costs, polylabel::cs, COST_SENSITIVE::cs_label, del_example_conditioning(), Search::search_private::done_with_all_actions, COST_SENSITIVE::ec_is_example_header(), v_array< T >::end(), ensure_size(), Search::search_private::examples_dont_change, Search::search_private::force_oracle, foreach_action_from_cache(), generate_training_example(), GET_TRUTH_STRING, Search::search_private::gte_label, features::indicies, INIT_TEST, INIT_TRAIN, Search::search_private::is_ldf, Search::search_private::last_action_repr, LEARN, Search::search_private::learn_a_idx, Search::search_private::learn_allowed_actions, Search::search_private::learn_condition_on, Search::search_private::learn_condition_on_act, Search::search_private::learn_condition_on_names, Search::search_private::learn_ec_copy, Search::search_private::learn_ec_ref, Search::search_private::learn_ec_ref_cnt, Search::search_private::learn_learner_id, Search::search_private::learn_oracle_action, Search::search_private::learn_t, Search::search_private::loss_declared_cnt, Search::search_private::memo_foreach_action, Search::search_private::meta_t, Search::search_private::metaoverride, need_memo_foreach_action(), NO_ROLLOUT, example::passthrough, Search::search_private::ptag_to_action, push_at(), Search::search_private::rollout_method, Search::search_private::rollout_num_steps, Search::BaseTask::sch, select_learner(), single_prediction_LDF(), 
single_prediction_notLDF(), features::size(), Search::search_private::state, Search::search_private::t, example::test_only, THROW, Search::search_private::total_cache_hits, Search::search_private::train_trajectory, vw::training, Search::search_private::use_action_costs, Search::auto_condition_settings::use_passthrough_repr, and features::values.

Referenced by Search::search::predict(), and Search::search::predictLDF().

1655 {
1656  size_t condition_on_cnt = condition_on_names ? strlen(condition_on_names) : 0;
1657  size_t t = priv.t + priv.meta_t;
1658  priv.t++;
1659 
1660  // make sure parameters come in pairs correctly
1661  assert((oracle_actions == nullptr) == (oracle_actions_cnt == 0));
1662  assert((condition_on == nullptr) == (condition_on_names == nullptr));
1663  assert(((allowed_actions == nullptr) && (allowed_actions_cost == nullptr)) == (allowed_actions_cnt == 0));
1664  assert(priv.use_action_costs == (allowed_actions_cost != nullptr));
1665  if (allowed_actions_cost != nullptr)
1666  assert(oracle_actions == nullptr);
1667 
1668  // if we're just after the string, choose an oracle action
1669  if ((priv.state == GET_TRUTH_STRING) || priv.force_oracle)
1670  {
1672  priv, ec_cnt, oracle_actions, oracle_actions_cnt, allowed_actions, allowed_actions_cnt, allowed_actions_cost);
1673  // if (priv.metaoverride && priv.metaoverride->_post_prediction)
1674  // priv.metaoverride->_post_prediction(*priv.metaoverride->sch, t-priv.meta_t, a, 0.);
1675  a_cost = 0.;
1676  return a;
1677  }
1678 
1679  // if we're in LEARN mode and before learn_t, return the train action
1680  if ((priv.state == LEARN) && (t < priv.learn_t))
1681  {
1682  assert(t < priv.train_trajectory.size());
1683  action a = priv.train_trajectory[t].a;
1684  a_cost = priv.train_trajectory[t].s;
1685  cdbg << "LEARN " << t << " < priv.learn_t ==> a=" << a << ", a_cost=" << a_cost << endl;
1686  if (priv.metaoverride && priv.metaoverride->_foreach_action)
1687  foreach_action_from_cache(priv, t);
1688  if (priv.metaoverride && priv.metaoverride->_post_prediction)
1689  priv.metaoverride->_post_prediction(*priv.metaoverride->sch, t - priv.meta_t, a, a_cost);
1690  return a;
1691  }
1692 
1693  // for LDF, # of valid actions is ec_cnt; otherwise it's either allowed_actions_cnt or A
1694  size_t valid_action_cnt = priv.is_ldf ? ec_cnt : (allowed_actions_cnt > 0) ? allowed_actions_cnt : priv.A;
1695 
1696  // if we're in LEARN mode and _at_ learn_t, then:
1697  // - choose the next action
1698  // - decide if we're done
1699  // - if we are, then copy/mark the example ref
1700  if ((priv.state == LEARN) && (t == priv.learn_t))
1701  {
1702  action a = (action)priv.learn_a_idx;
1703  priv.loss_declared_cnt = 0;
1704 
1705  cdbg << "LEARN " << t << " = priv.learn_t ==> a=" << a << ", learn_a_idx=" << priv.learn_a_idx
1706  << " valid_action_cnt=" << valid_action_cnt << endl;
1707  priv.learn_a_idx++;
1708 
1709  // check to see if we're done with available actions
1710  if (priv.learn_a_idx >= valid_action_cnt)
1711  {
1712  priv.done_with_all_actions = true;
1713  priv.learn_learner_id = learner_id;
1714 
1715  // set reference or copy example(s)
1716  if (oracle_actions_cnt > 0)
1717  priv.learn_oracle_action = oracle_actions[0];
1718  priv.learn_ec_ref_cnt = ec_cnt;
1719  if (priv.examples_dont_change)
1720  priv.learn_ec_ref = ecs;
1721  else
1722  {
1723  size_t label_size = priv.is_ldf ? sizeof(CS::label) : sizeof(MC::label_t);
1724  void (*label_copy_fn)(void*, void*) = priv.is_ldf ? CS::cs_label.copy_label : nullptr;
1725 
1726  ensure_size(priv.learn_ec_copy, ec_cnt);
1727  for (size_t i = 0; i < ec_cnt; i++)
1728  VW::copy_example_data(priv.all->audit, priv.learn_ec_copy.begin() + i, ecs + i, label_size, label_copy_fn);
1729 
1730  priv.learn_ec_ref = priv.learn_ec_copy.begin();
1731  }
1732 
1733  // copy conditioning stuff and allowed actions
1734  if (priv.auto_condition_features)
1735  {
1736  ensure_size(priv.learn_condition_on, condition_on_cnt);
1737  ensure_size(priv.learn_condition_on_act, condition_on_cnt);
1738 
1739  priv.learn_condition_on.end() =
1740  priv.learn_condition_on.begin() + condition_on_cnt; // allow .size() to be used in lieu of _cnt
1741 
1742  memcpy(priv.learn_condition_on.begin(), condition_on, condition_on_cnt * sizeof(ptag));
1743 
1744  for (size_t i = 0; i < condition_on_cnt; i++)
1745  push_at(priv.learn_condition_on_act,
1746  action_repr(((1 <= condition_on[i]) && (condition_on[i] < priv.ptag_to_action.size()))
1747  ? priv.ptag_to_action[condition_on[i]]
1748  : 0),
1749  i);
1750 
1751  if (condition_on_names == nullptr)
1752  {
1753  ensure_size(priv.learn_condition_on_names, 1);
1754  priv.learn_condition_on_names[0] = 0;
1755  }
1756  else
1757  {
1758  ensure_size(priv.learn_condition_on_names, strlen(condition_on_names) + 1);
1759  strcpy(priv.learn_condition_on_names.begin(), condition_on_names);
1760  }
1761  }
1762 
1763  if (allowed_actions && (allowed_actions_cnt > 0))
1764  {
1765  ensure_size(priv.learn_allowed_actions, allowed_actions_cnt);
1766  memcpy(priv.learn_allowed_actions.begin(), allowed_actions, allowed_actions_cnt * sizeof(action));
1767  cdbg_print_array("in LEARN, learn_allowed_actions", priv.learn_allowed_actions);
1768  }
1769  }
1770 
1771  assert((allowed_actions_cnt == 0) || (a < allowed_actions_cnt));
1772 
1773  a_cost = 0.;
1774  action a_name = (allowed_actions && (allowed_actions_cnt > 0)) ? allowed_actions[a] : priv.is_ldf ? a : (a + 1);
1775  if (priv.metaoverride && priv.metaoverride->_foreach_action)
1776  {
1777  foreach_action_from_cache(priv, t, a_name);
1778  if (priv.memo_foreach_action[t])
1779  {
1780  cdbg << "@ memo_foreach_action: t=" << t << ", a=" << a << ", cost=" << (*priv.memo_foreach_action[t])[a].cost
1781  << endl;
1782  a_cost = (*priv.memo_foreach_action[t])[a].cost;
1783  }
1784  }
1785 
1786  a = a_name;
1787 
1788  if (priv.metaoverride && priv.metaoverride->_post_prediction)
1789  priv.metaoverride->_post_prediction(*priv.metaoverride->sch, t - priv.meta_t, a, a_cost);
1790  return a;
1791  }
1792 
1793  if ((priv.state == LEARN) && (t > priv.learn_t) && (priv.rollout_num_steps > 0) &&
1794  (priv.loss_declared_cnt >= priv.rollout_num_steps))
1795  {
1796  cdbg << "... skipping" << endl;
1797  action a = priv.is_ldf ? 0 : ((allowed_actions && (allowed_actions_cnt > 0)) ? allowed_actions[0] : 1);
1798  if (priv.metaoverride && priv.metaoverride->_post_prediction)
1799  priv.metaoverride->_post_prediction(*priv.metaoverride->sch, t - priv.meta_t, a, 0.);
1800  if (priv.metaoverride && priv.metaoverride->_foreach_action)
1801  foreach_action_from_cache(priv, t);
1802  a_cost = 0.;
1803  return a;
1804  }
1805 
1806  if ((priv.state == INIT_TRAIN) || (priv.state == INIT_TEST) || ((priv.state == LEARN) && (t > priv.learn_t)))
1807  {
1808  // we actually need to run the policy
1809 
1810  int policy = choose_policy(priv);
1811  action a = 0;
1812 
1813  cdbg << "executing policy " << policy << endl;
1814 
1815  bool gte_here = (priv.state == INIT_TRAIN) && (priv.rollout_method == NO_ROLLOUT) &&
1816  ((oracle_actions_cnt > 0) || (priv.use_action_costs));
1817  a_cost = 0.;
1818  bool skip = false;
1819 
1820  if (priv.metaoverride && priv.metaoverride->_maybe_override_prediction &&
1821  (priv.state != LEARN)) // if LEARN and t>learn_t,then we cannot allow overrides!
1822  {
1823  skip = priv.metaoverride->_maybe_override_prediction(*priv.metaoverride->sch, t - priv.meta_t, a, a_cost);
1824  cdbg << "maybe_override_prediction --> " << skip << ", a=" << a << ", a_cost=" << a_cost << endl;
1825  if (skip && need_memo_foreach_action(priv))
1826  priv.memo_foreach_action.push_back(nullptr);
1827  }
1828 
1829  if ((!skip) && (policy == -1))
1830  a = choose_oracle_action(priv, ec_cnt, oracle_actions, oracle_actions_cnt, allowed_actions, allowed_actions_cnt,
1831  allowed_actions_cost); // TODO: we probably want to actually get costs for oracle actions???
1832 
1833  bool need_fea = (policy == -1) && priv.metaoverride && priv.metaoverride->_foreach_action;
1834 
1835  if ((policy >= 0) || gte_here || need_fea) // the last case is we need to do foreach action
1836  {
1837  int learner = select_learner(priv, policy, learner_id, false, priv.state != INIT_TEST);
1838 
1839  ensure_size(priv.condition_on_actions, condition_on_cnt);
1840  for (size_t i = 0; i < condition_on_cnt; i++)
1841  priv.condition_on_actions[i] = ((1 <= condition_on[i]) && (condition_on[i] < priv.ptag_to_action.size()))
1842  ? priv.ptag_to_action[condition_on[i]]
1843  : 0;
1844 
1845  bool not_test = priv.all->training && !ecs[0].test_only;
1846 
1847  if ((!skip) && (!need_fea) && not_test &&
1848  cached_action_store_or_find(priv, mytag, condition_on, condition_on_names, priv.condition_on_actions.begin(),
1849  condition_on_cnt, policy, learner_id, a, false, a_cost))
1850  // if this succeeded, 'a' has the right action
1851  priv.total_cache_hits++;
1852  else // we need to predict, and then cache, and maybe run foreach_action
1853  {
1854  size_t start_K = (priv.is_ldf && COST_SENSITIVE::ec_is_example_header(ecs[0])) ? 1 : 0;
1855  priv.last_action_repr.clear();
1856  if (priv.auto_condition_features)
1857  for (size_t n = start_K; n < ec_cnt; n++)
1859  priv, ecs[n], condition_on_cnt, condition_on_names, priv.condition_on_actions.begin());
1860 
1861  if (((!skip) && (policy >= 0)) || need_fea) // only make a prediction if we're going to use the output
1862  {
1863  if (priv.auto_condition_features && priv.acset.use_passthrough_repr)
1864  {
1865  if (priv.is_ldf)
1866  {
1867  THROW("search cannot use state representations in ldf mode");
1868  }
1869  if (ecs[0].passthrough)
1870  {
1871  THROW("search cannot passthrough");
1872  }
1873  ecs[0].passthrough = &priv.last_action_repr;
1874  }
1875  a = priv.is_ldf ? single_prediction_LDF(priv, ecs, ec_cnt, learner, a_cost, need_fea ? a : (action)-1)
1876  : single_prediction_notLDF(priv, *ecs, learner, allowed_actions, allowed_actions_cnt,
1877  allowed_actions_cost, a_cost, need_fea ? a : (action)-1);
1878 
1879  cdbg << "passthrough = [";
1880  for (size_t kk = 0; kk < priv.last_action_repr.size(); kk++)
1881  cdbg << ' ' << priv.last_action_repr.indicies[kk] << ':' << priv.last_action_repr.values[kk];
1882  cdbg << " ]" << endl;
1883 
1884  ecs[0].passthrough = nullptr;
1885  }
1886 
1887  if (need_fea)
1888  {
1889  // TODO this
1890  }
1891 
1892  if (gte_here)
1893  {
1894  cdbg << "INIT_TRAIN, NO_ROLLOUT, at least one oracle_actions, a=" << a << endl;
1895  // we can generate a training example _NOW_ because we're not doing rollouts
1896  // allowed_actions_to_losses(priv, ec_cnt, allowed_actions, allowed_actions_cnt, oracle_actions,
1897  // oracle_actions_cnt, losses);
1898  allowed_actions_to_label(priv, ec_cnt, allowed_actions, allowed_actions_cnt, allowed_actions_cost,
1899  oracle_actions, oracle_actions_cnt, priv.gte_label);
1900  cdbg << "priv.gte_label = [";
1901  for (size_t i = 0; i < priv.gte_label.cs.costs.size(); i++)
1902  cdbg << ' ' << priv.gte_label.cs.costs[i].class_index << ':' << priv.gte_label.cs.costs[i].x;
1903  cdbg << " ]" << endl;
1904 
1905  priv.learn_ec_ref = ecs;
1906  priv.learn_ec_ref_cnt = ec_cnt;
1907  if (allowed_actions)
1908  {
1909  ensure_size(priv.learn_allowed_actions, allowed_actions_cnt); // TODO: do we really need this?
1910  memcpy(priv.learn_allowed_actions.begin(), allowed_actions, allowed_actions_cnt * sizeof(action));
1911  }
1912  size_t old_learner_id = priv.learn_learner_id;
1913  priv.learn_learner_id = learner_id;
1915  priv, priv.gte_label, 1., false); // this is false because the conditioning has already been added!
1916  priv.learn_learner_id = old_learner_id;
1917  }
1918 
1919  if (priv.auto_condition_features)
1920  for (size_t n = start_K; n < ec_cnt; n++) del_example_conditioning(priv, ecs[n]);
1921 
1922  if (not_test && (!skip))
1923  cached_action_store_or_find(priv, mytag, condition_on, condition_on_names, priv.condition_on_actions.begin(),
1924  condition_on_cnt, policy, learner_id, a, true, a_cost);
1925  }
1926  }
1927 
1928  if (priv.state == INIT_TRAIN)
1929  priv.train_trajectory.push_back(scored_action(a, a_cost)); // note the action for future reference
1930 
1931  if (priv.metaoverride && priv.metaoverride->_post_prediction)
1932  priv.metaoverride->_post_prediction(*priv.metaoverride->sch, t - priv.meta_t, a, a_cost);
1933 
1934  return a;
1935  }
1936 
1937  THROW("error: predict called in unknown state");
1938 }
void add_example_conditioning(search_private &priv, example &ec, size_t condition_on_cnt, const char *condition_on_names, action_repr *condition_on_actions)
Definition: search.cc:784
#define cdbg
Definition: search.h:11
void(* copy_label)(void *, void *)
Definition: label_parser.h:18
label_parser cs_label
void copy_example_data(bool audit, example *dst, example *src)
Definition: example.cc:72
void push_at(v_array< T > &v, T item, size_t pos)
Definition: search.cc:1074
bool cached_action_store_or_find(search_private &priv, ptag mytag, const ptag *condition_on, const char *condition_on_names, action_repr *condition_on_actions, size_t condition_on_cnt, int policy, size_t learner_id, action &a, bool do_store, float &a_cost)
Definition: search.cc:1438
void generate_training_example(search_private &priv, polylabel &losses, float weight, bool add_conditioning=true, float min_loss=FLT_MAX)
Definition: search.cc:1490
uint32_t action
Definition: search.h:19
action single_prediction_notLDF(search_private &priv, example &ec, int policy, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost, float &a_cost, action override_action)
Definition: search.cc:1163
bool ec_is_example_header(example const &ec)
int select_learner(search_private &priv, int policy, size_t learner_id, bool is_training, bool is_local)
Definition: search.cc:432
void foreach_action_from_cache(search_private &priv, size_t t, action override_a=(action) -1)
Definition: search.cc:1634
bool need_memo_foreach_action(search_private &priv)
Definition: search.cc:370
void cdbg_print_array(std::string str, v_array< T > &A)
Definition: search.cc:754
int choose_policy(search_private &priv, bool advance_prng=true)
Definition: search.cc:1401
constexpr uint64_t a
Definition: rand48.cc:11
void ensure_size(v_array< T > &A, size_t sz)
Definition: search.cc:1066
void del_example_conditioning(search_private &priv, example &ec)
Definition: search.cc:881
features * passthrough
Definition: example.h:74
void allowed_actions_to_label(search_private &priv, size_t ec_cnt, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost, const action *oracle_actions, size_t oracle_actions_cnt, polylabel &lab)
Definition: search.cc:991
action choose_oracle_action(search_private &priv, size_t ec_cnt, const action *oracle_actions, size_t oracle_actions_cnt, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost)
Definition: search.cc:1097
iterator begin()
uint32_t ptag
Definition: search.h:20
#define THROW(args)
Definition: vw_exception.h:181
action single_prediction_LDF(search_private &priv, example *ecs, size_t ec_cnt, int policy, float &a_cost, action override_action)
Definition: search.cc:1310
bool test_only
Definition: example.h:76

◆ search_predictNeedsExample()

bool Search::search_predictNeedsExample ( search_private & priv)

Definition at line 1601 of file search.cc.

References choose_policy(), GET_TRUTH_STRING, INIT_TEST, INIT_TRAIN, INITIALIZE, LEARN, Search::search_private::learn_t, Search::search_private::loss_declared_cnt, Search::search_private::meta_t, NO_ROLLOUT, Search::search_private::rollout_method, Search::search_private::rollout_num_steps, Search::search_private::state, and Search::search_private::t.

Referenced by Search::search::predictNeedsExample().

1602 {
1603  // this is basically copied from the logic of search_predict()
1604  switch (priv.state)
1605  {
1606  case INITIALIZE:
1607  return false;
1608  case GET_TRUTH_STRING:
1609  return false;
1610  case INIT_TEST:
1611  return true;
1612  case INIT_TRAIN:
1613  // TODO: do we need to do something here for metatasks?
1614  // if (priv.beam && (priv.t < priv.beam_actions.size()))
1615  // return false;
1616  if (priv.rollout_method == NO_ROLLOUT)
1617  return true;
1618  break;
1619  case LEARN:
1620  if (priv.t + priv.meta_t < priv.learn_t)
1621  return false; // TODO: in meta search mode with foreach feature we'll need it even here
1622  if (priv.t + priv.meta_t == priv.learn_t)
1623  return true; // SPEEDUP: we really only need it on the last learn_a, but this is hard to know...
1624  // t > priv.learn_t
1625  if ((priv.rollout_num_steps > 0) && (priv.loss_declared_cnt >= priv.rollout_num_steps))
1626  return false; // skipping
1627  break;
1628  }
1629 
1630  int pol = choose_policy(priv, false); // choose a policy but don't advance prng
1631  return (pol != -1);
1632 }
int choose_policy(search_private &priv, bool advance_prng=true)
Definition: search.cc:1401

◆ select_learner()

int Search::select_learner ( search_private & priv,
int  policy,
size_t  learner_id,
bool  is_training,
bool  is_local 
)

Definition at line 432 of file search.cc.

References Search::search_private::all, shared_data::example_number, Search::search_private::num_learners, vw::sd, and Search::search_private::xv.

Referenced by generate_training_example(), and search_predict().

433 {
434  if (policy < 0)
435  return policy; // optimal policy
436  else
437  {
438  if (priv.xv)
439  {
440  learner_id *= 3;
441  if (!is_local)
442  learner_id += 1 + (size_t)(is_training ^ (priv.all->sd->example_number % 2 == 1));
443  }
444  int p = (int)(policy * priv.num_learners + learner_id);
445  return p;
446  }
447 }

◆ setup()

LEARNER::base_learner * Search::setup ( options_i & options,
vw & all 
)

Definition at line 2671 of file search.cc.

References Search::search_private::A, VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), CB::cb_label, cdbg, vw::check_holdout_every_n_passes, COST_SENSITIVE::cs_label, label_parser::default_label, parser::emptylines_separate_examples, end_examples(), LEARNER::end_pass(), ensure_param(), f, finish_multiline_example(), VW::config::options_i::get_typed_option(), handle_condition_options(), LEARNER::init_learner(), VW::config::options_i::insert(), vw::label_type, parser::lp, LEARNER::make_base(), VW::config::make_option(), label_type::mc, MULTICLASS::mc_label, MIX_PER_ROLL, MIX_PER_STATE, NO_ROLLOUT, vw::numpasses, vw::options, ORACLE, vw::p, parse_neighbor_features(), POLICY, read_allowed_transitions(), VW::config::options_i::replace(), search_finish(), search_initialize(), vw::searchstr, LEARNER::learner< T, E >::set_end_examples(), LEARNER::learner< T, E >::set_end_pass(), LEARNER::learner< T, E >::set_finish(), LEARNER::learner< T, E >::set_finish_example(), setup_base(), THROW, prediction_type::to_string(), vw::training, and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

2672 {
2673  free_ptr<search> sch = scoped_calloc_or_throw<search>();
2674  search_private& priv = *sch->priv;
2675  std::string task_string;
2676  std::string metatask_string;
2677  std::string interpolation_string = "data";
2678  std::string neighbor_features_string;
2679  std::string rollout_string = "mix_per_state";
2680  std::string rollin_string = "mix_per_state";
2681 
2682  uint32_t search_trained_nb_policies;
2683  std::string search_allowed_transitions;
2684 
2685  priv.A = 1;
2686  option_group_definition new_options("Search options");
2687  new_options.add(
2688  make_option("search", priv.A).keep().help("Use learning to search, argument=maximum action id or 0 for LDF"));
2689  new_options.add(make_option("search_task", task_string)
2690  .keep()
2691  .help("the search task (use \"--search_task list\" to get a list of available tasks)"));
2692  new_options.add(
2693  make_option("search_metatask", metatask_string)
2694  .keep()
2695  .help("the search metatask (use \"--search_metatask list\" to get a list of available metatasks)"));
2696  new_options.add(make_option("search_interpolation", interpolation_string)
2697  .keep()
2698  .help("at what level should interpolation happen? [*data|policy]"));
2699  new_options.add(
2700  make_option("search_rollout", rollout_string)
2701  .help("how should rollouts be executed? [policy|oracle|*mix_per_state|mix_per_roll|none]"));
2702  new_options.add(make_option("search_rollin", rollin_string)
2703  .help("how should past trajectories be generated? [policy|oracle|*mix_per_state|mix_per_roll]"));
2704  new_options.add(make_option("search_passes_per_policy", priv.passes_per_policy)
2705  .default_value(1)
2706  .help("number of passes per policy (only valid for search_interpolation=policy)"));
2707  new_options.add(make_option("search_beta", priv.beta)
2708  .default_value(0.5f)
2709  .help("interpolation rate for policies (only valid for search_interpolation=policy)"));
2710  new_options.add(make_option("search_alpha", priv.alpha)
2711  .default_value(1e-10f)
2712  .help("annealed beta = 1-(1-alpha)^t (only valid for search_interpolation=data)"));
2713  new_options.add(make_option("search_total_nb_policies", priv.total_number_of_policies)
2714  .help("if we are going to train the policies through multiple separate calls to vw, we need to "
2715  "specify this parameter and tell vw how many policies are eventually going to be trained"));
2716  new_options.add(make_option("search_trained_nb_policies", search_trained_nb_policies)
2717  .help("the number of trained policies in a file"));
2718  new_options.add(make_option("search_allowed_transitions", search_allowed_transitions)
2719  .help("read file of allowed transitions [def: all transitions are allowed]"));
2720  new_options.add(make_option("search_subsample_time", priv.subsample_timesteps)
2721  .help("instead of training at all timesteps, use a subset. if value in (0,1), train on a random "
2722  "v%. if v>=1, train on precisely v steps per example, if v<=-1, use active learning"));
2723  new_options.add(
2724  make_option("search_neighbor_features", neighbor_features_string)
2725  .keep()
2726  .help("copy features from neighboring lines. argument looks like: '-1:a,+2' meaning copy previous line "
2727  "namespace a and next next line from namespace _unnamed_, where ',' separates them"));
2728  new_options.add(make_option("search_rollout_num_steps", priv.rollout_num_steps)
2729  .help("how many calls of \"loss\" before we stop really predicting on rollouts and switch to "
2730  "oracle (default means \"infinite\")"));
2731  new_options.add(make_option("search_history_length", priv.history_length)
2732  .keep()
2733  .default_value(1)
2734  .help("some tasks allow you to specify how much history their depend on; specify that here"));
2735  new_options.add(make_option("search_no_caching", priv.no_caching)
2736  .help("turn off the built-in caching ability (makes things slower, but technically more safe)"));
2737  new_options.add(
2738  make_option("search_xv", priv.xv).help("train two separate policies, alternating prediction/learning"));
2739  new_options.add(make_option("search_perturb_oracle", priv.perturb_oracle)
2740  .default_value(0.f)
2741  .help("perturb the oracle on rollin with this probability"));
2742  new_options.add(make_option("search_linear_ordering", priv.linear_ordering)
2743  .help("insist on generating examples in linear order (def: hoopla permutation)"));
2744  new_options.add(make_option("search_active_verify", priv.active_csoaa_verify)
2745  .help("verify that active learning is doing the right thing (arg = multiplier, should be = "
2746  "cost_range * range_c)"));
2747  new_options.add(make_option("search_save_every_k_runs", priv.save_every_k_runs).help("save model every k runs"));
2748  options.add_and_parse(new_options);
2749 
2750  if (!options.was_supplied("search_task"))
2751  return nullptr;
2752 
2753  search_initialize(&all, *sch.get());
2754 
2755  parse_neighbor_features(neighbor_features_string, *sch.get());
2756 
2757  if (interpolation_string.compare("data") == 0) // run as dagger
2758  {
2759  priv.adaptive_beta = true;
2760  priv.allow_current_policy = true;
2761  priv.passes_per_policy = all.numpasses;
2762  if (priv.current_policy > 1)
2763  priv.current_policy = 1;
2764  }
2765  else if (interpolation_string.compare("policy") == 0)
2766  ;
2767  else
2768  THROW("error: --search_interpolation must be 'data' or 'policy'");
2769 
2770  if ((rollout_string.compare("policy") == 0) || (rollout_string.compare("learn") == 0))
2771  priv.rollout_method = POLICY;
2772  else if ((rollout_string.compare("oracle") == 0) || (rollout_string.compare("ref") == 0))
2773  priv.rollout_method = ORACLE;
2774  else if ((rollout_string.compare("mix_per_state") == 0))
2775  priv.rollout_method = MIX_PER_STATE;
2776  else if ((rollout_string.compare("mix_per_roll") == 0) || (rollout_string.compare("mix") == 0))
2777  priv.rollout_method = MIX_PER_ROLL;
2778  else if ((rollout_string.compare("none") == 0))
2779  {
2780  priv.rollout_method = NO_ROLLOUT;
2781  priv.no_caching = true;
2782  }
2783  else
2784  THROW("error: --search_rollout must be 'learn', 'ref', 'mix', 'mix_per_state' or 'none'");
2785 
2786  if ((rollin_string.compare("policy") == 0) || (rollin_string.compare("learn") == 0))
2787  priv.rollin_method = POLICY;
2788  else if ((rollin_string.compare("oracle") == 0) || (rollin_string.compare("ref") == 0))
2789  priv.rollin_method = ORACLE;
2790  else if ((rollin_string.compare("mix_per_state") == 0))
2791  priv.rollin_method = MIX_PER_STATE;
2792  else if ((rollin_string.compare("mix_per_roll") == 0) || (rollin_string.compare("mix") == 0))
2793  priv.rollin_method = MIX_PER_ROLL;
2794  else
2795  THROW("error: --search_rollin must be 'learn', 'ref', 'mix' or 'mix_per_state'");
2796 
2797  // check if the base learner is contextual bandit, in which case, we dont rollout all actions.
2798  priv.allowed_actions_cache = &calloc_or_throw<polylabel>();
2799  if (options.was_supplied("cb"))
2800  {
2801  priv.cb_learner = true;
2802  CB::cb_label.default_label(priv.allowed_actions_cache);
2803  priv.learn_losses.cb.costs = v_init<CB::cb_class>();
2804  priv.gte_label.cb.costs = v_init<CB::cb_class>();
2805  }
2806  else
2807  {
2808  priv.cb_learner = false;
2809  CS::cs_label.default_label(priv.allowed_actions_cache);
2810  priv.learn_losses.cs.costs = v_init<CS::wclass>();
2811  priv.gte_label.cs.costs = v_init<CS::wclass>();
2812  }
2813 
2814  ensure_param(priv.beta, 0.0, 1.0, 0.5, "warning: search_beta must be in (0,1); resetting to 0.5");
2815  ensure_param(priv.alpha, 0.0, 1.0, 1e-10f, "warning: search_alpha must be in (0,1); resetting to 1e-10");
2816 
2817  priv.num_calls_to_run = 0;
2818 
2819  // compute total number of policies we will have at end of training
2820  // we add current_policy for cases where we start from an initial set of policies loaded through -i option
2821  uint32_t tmp_number_of_policies = priv.current_policy;
2822  if (all.training)
2823  tmp_number_of_policies += (int)ceil(((float)all.numpasses) / ((float)priv.passes_per_policy));
2824 
2825  // the user might have specified the number of policies that will eventually be trained through multiple vw calls,
2826  // so only set total_number_of_policies to computed value if it is larger
2827  cdbg << "current_policy=" << priv.current_policy << " tmp_number_of_policies=" << tmp_number_of_policies
2828  << " total_number_of_policies=" << priv.total_number_of_policies << endl;
2829  if (tmp_number_of_policies > priv.total_number_of_policies)
2830  {
2831  priv.total_number_of_policies = tmp_number_of_policies;
2832  if (priv.current_policy >
2833  0) // we loaded a file but total number of policies didn't match what is needed for training
2834  std::cerr << "warning: you're attempting to train more classifiers than was allocated initially. Likely to cause "
2835  "bad performance."
2836  << endl;
2837  }
2838 
2839  // current policy currently points to a new policy we would train
2840  // if we are not training and loaded a bunch of policies for testing, we need to subtract 1 from current policy
2841  // so that we only use those loaded when testing (as run_prediction is called with allow_current to true)
2842  if (!all.training && priv.current_policy > 0)
2843  priv.current_policy--;
2844 
2845  all.options->replace("search_trained_nb_policies", std::to_string(priv.current_policy));
2846  all.options->get_typed_option<uint32_t>("search_trained_nb_policies").value(priv.current_policy);
2847 
2848  all.options->replace("search_total_nb_policies", std::to_string(priv.total_number_of_policies));
2849  all.options->get_typed_option<uint32_t>("search_total_nb_policies").value(priv.total_number_of_policies);
2850 
2851  cdbg << "search current_policy = " << priv.current_policy
2852  << " total_number_of_policies = " << priv.total_number_of_policies << endl;
2853 
2854  if (task_string.compare("list") == 0)
2855  {
2856  std::cerr << endl << "available search tasks:" << endl;
2857  for (search_task** mytask = all_tasks; *mytask != nullptr; mytask++)
2858  std::cerr << " " << (*mytask)->task_name << endl;
2859  std::cerr << endl;
2860  exit(0);
2861  }
2862  if (metatask_string.compare("list") == 0)
2863  {
2864  std::cerr << endl << "available search metatasks:" << endl;
2865  for (search_metatask** mytask = all_metatasks; *mytask != nullptr; mytask++)
2866  std::cerr << " " << (*mytask)->metatask_name << endl;
2867  std::cerr << endl;
2868  exit(0);
2869  }
2870  for (search_task** mytask = all_tasks; *mytask != nullptr; mytask++)
2871  if (task_string.compare((*mytask)->task_name) == 0)
2872  {
2873  priv.task = *mytask;
2874  sch->task_name = (*mytask)->task_name;
2875  break;
2876  }
2877  if (priv.task == nullptr)
2878  {
2879  if (!options.was_supplied("help"))
2880  THROW("fail: unknown task for --search_task '" << task_string << "'; use --search_task list to get a list");
2881  }
2882  priv.metatask = nullptr;
2883  for (search_metatask** mytask = all_metatasks; *mytask != nullptr; mytask++)
2884  if (metatask_string.compare((*mytask)->metatask_name) == 0)
2885  {
2886  priv.metatask = *mytask;
2887  sch->metatask_name = (*mytask)->metatask_name;
2888  break;
2889  }
2890  all.p->emptylines_separate_examples = true;
2891 
2892  if (!options.was_supplied("csoaa") && !options.was_supplied("cs_active") && !options.was_supplied("csoaa_ldf") &&
2893  !options.was_supplied("wap_ldf") && !options.was_supplied("cb"))
2894  {
2895  options.insert("csoaa", std::to_string(priv.A));
2896  }
2897 
2898  priv.active_csoaa = options.was_supplied("cs_active");
2899  priv.active_csoaa_verify = -1.;
2900  if (options.was_supplied("search_active_verify"))
2901  if (!priv.active_csoaa)
2902  THROW("cannot use --search_active_verify without using --cs_active");
2903 
2904  cdbg << "active_csoaa = " << priv.active_csoaa << ", active_csoaa_verify = " << priv.active_csoaa_verify << endl;
2905 
2906  base_learner* base = setup_base(*all.options, all);
2907 
2908  // default to OAA labels unless the task wants to override this (which they can do in initialize)
2909  all.p->lp = MC::mc_label;
2910  all.label_type = label_type::mc;
2911  if (priv.task && priv.task->initialize)
2912  priv.task->initialize(*sch.get(), priv.A, options);
2913  if (priv.metatask && priv.metatask->initialize)
2914  priv.metatask->initialize(*sch.get(), priv.A, options);
2915  priv.meta_t = 0;
2916 
2917  if (options.was_supplied("search_allowed_transitions"))
2918  read_allowed_transitions((action)priv.A, search_allowed_transitions.c_str());
2919 
2920  // set up auto-history (used to only do this if AUTO_CONDITION_FEATURES was on, but that doesn't work for hooktask)
2921  handle_condition_options(all, priv.acset);
2922 
2923  if (!priv.allow_current_policy) // if we're not dagger
2924  all.check_holdout_every_n_passes = priv.passes_per_policy;
2925 
2926  all.searchstr = sch.get();
2927 
2928  priv.start_clock_time = clock();
2929 
2930  if (priv.xv)
2931  priv.num_learners *= 3;
2932 
2933  cdbg << "num_learners = " << priv.num_learners << endl;
2934 
2935  learner<search, multi_ex>& l = init_learner(sch, make_base(*base), do_actual_learning<true>,
2936  do_actual_learning<false>, priv.total_number_of_policies * priv.num_learners);
2941  return make_base(l);
2942 }
v_array< CS::label > read_allowed_transitions(action A, const char *filename)
Definition: search.cc:2579
#define cdbg
Definition: search.h:11
void * searchstr
Definition: global_data.h:430
VW::config::options_i * options
Definition: global_data.h:428
label_parser cs_label
void search_initialize(vw *all, search &sch)
Definition: search.cc:2486
void finish_multiline_example(vw &all, cbify &, multi_ex &ec_seq)
Definition: cbify.cc:373
search_metatask * all_metatasks[]
Definition: search.cc:39
void(* default_label)(void *)
Definition: label_parser.h:12
virtual void replace(const std::string &key, const std::string &value)=0
label_type::label_type_t label_type
Definition: global_data.h:550
void parse_neighbor_features(std::string &nf_string, search &sch)
Definition: search.cc:2627
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
uint32_t action
Definition: search.h:19
search_task * all_tasks[]
Definition: search.cc:35
virtual void add_and_parse(const option_group_definition &group)=0
void handle_condition_options(vw &all, auto_condition_settings &acset)
Definition: search.cc:2543
size_t check_holdout_every_n_passes
Definition: global_data.h:503
bool training
Definition: global_data.h:488
label_parser mc_label
Definition: multiclass.cc:93
parser * p
Definition: global_data.h:377
std::unique_ptr< T, free_fn > free_ptr
Definition: memory.h:34
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
typed_option< T > & get_typed_option(const std::string &key)
Definition: options.h:120
virtual bool was_supplied(const std::string &key)=0
virtual void insert(const std::string &key, const std::string &value)=0
size_t numpasses
Definition: global_data.h:451
void search_finish(search &sch)
Definition: search.cc:2565
void end_examples(search &sch)
Definition: search.cc:2464
label_parser cb_label
Definition: cb.cc:167
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void set_end_pass(void(*f)(T &))
Definition: learner.h:286
void set_finish(void(*f)(T &))
Definition: learner.h:265
void end_pass(search &sch)
Definition: search.cc:2433
void ensure_param(float &v, float lo, float hi, float def, const char *str)
Definition: search.cc:2534
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define THROW(args)
Definition: vw_exception.h:181
bool emptylines_separate_examples
Definition: parser.h:84
float f
Definition: cache.cc:40
const char * to_string(prediction_type_t prediction_type)
Definition: learner.cc:12
void set_end_examples(void(*f)(T &))
Definition: learner.h:295
label_parser lp
Definition: parser.h:102

◆ should_print_update()

bool Search::should_print_update ( vw & all,
bool  hit_new_pass = false 
)

Definition at line 449 of file search.cc.

References vw::bfgs, shared_data::dump_interval, vw::quiet, vw::sd, and shared_data::weighted_examples().

Referenced by print_update().

450 {
451  // uncomment to print out final loss after all examples processed
452  // commented for now so that outputs matches make test
453 
454  if (PRINT_UPDATE_EVERY_EXAMPLE)
455  return true;
456  if (PRINT_UPDATE_EVERY_PASS && hit_new_pass)
457  return true;
458  return (all.sd->weighted_examples() >= all.sd->dump_interval) && !all.quiet && !all.bfgs;
459 }
bool quiet
Definition: global_data.h:487
constexpr bool PRINT_UPDATE_EVERY_EXAMPLE
Definition: search.cc:42
shared_data * sd
Definition: global_data.h:375
bool bfgs
Definition: global_data.h:412
constexpr bool PRINT_UPDATE_EVERY_PASS
Definition: search.cc:43
double weighted_examples()
Definition: global_data.h:188
float dump_interval
Definition: global_data.h:147

◆ single_prediction_LDF()

action Search::single_prediction_LDF ( search_private & priv,
example * ecs,
size_t  ec_cnt,
int  policy,
float &  a_cost,
action  override_action 
)

Definition at line 1310 of file search.cc.

References Search::BaseTask::_foreach_action, a, LabelDict::add_example_namespaces_from_example(), LEARNER::as_multiline(), Search::search_private::base_learner, cdbg, Search::action_cache::cost, COST_SENSITIVE::label::costs, polylabel::cs, COST_SENSITIVE::cs_label, label_parser::default_label, LabelDict::del_example_namespaces_from_example(), v_array< T >::delete_v(), COST_SENSITIVE::ec_is_example_header(), example_predict::ft_offset, Search::search_private::is_ldf, Search::action_cache::is_opt, Search::action_cache::k, example::l, Search::search_private::ldf_test_label, Search::search_private::memo_foreach_action, Search::search_private::metaoverride, Search::action_cache::min_cost, need_memo_foreach_action(), example::num_features, Search::search_private::num_features, Search::search_private::offset, example::partial_prediction, LEARNER::learner< T, E >::predict(), v_array< T >::push_back(), Search::BaseTask::sch, v_array< T >::size(), Search::search_private::t, and Search::search_private::total_predictions_made.

Referenced by search_predict().

1313 {
1314  bool need_partial_predictions = need_memo_foreach_action(priv) ||
1315  (priv.metaoverride && priv.metaoverride->_foreach_action) || (override_action != (action)-1);
1316 
1317  CS::cs_label.default_label(&priv.ldf_test_label);
1318  CS::wclass wc = {0., 1, 0., 0.};
1319  priv.ldf_test_label.costs.push_back(wc);
1320 
1321  // keep track of best (aka chosen) action
1322  float best_prediction = 0.;
1323  action best_action = 0;
1324 
1325  size_t start_K = (priv.is_ldf && COST_SENSITIVE::ec_is_example_header(ecs[0])) ? 1 : 0;
1326 
1327  v_array<action_cache>* this_cache = nullptr;
1328  if (need_partial_predictions)
1329  {
1330  this_cache = new v_array<action_cache>();
1331  *this_cache = v_init<action_cache>();
1332  }
1333 
1334  for (action a = (uint32_t)start_K; a < ec_cnt; a++)
1335  {
1336  cdbg << "== single_prediction_LDF a=" << a << "==" << endl;
1337  if (start_K > 0)
1338  LabelDict::add_example_namespaces_from_example(ecs[a], ecs[0]);
1339 
1340  polylabel old_label = ecs[a].l;
1341  ecs[a].l.cs = priv.ldf_test_label;
1342 
1343  multi_ex tmp;
1344  uint64_t old_offset = ecs[a].ft_offset;
1345  ecs[a].ft_offset = priv.offset;
1346  tmp.push_back(&ecs[a]);
1347  as_multiline(priv.base_learner)->predict(tmp, policy);
1348 
1349  ecs[a].ft_offset = old_offset;
1350  cdbg << "partial_prediction[" << a << "] = " << ecs[a].partial_prediction << endl;
1351 
1352  if (override_action != (action)-1)
1353  {
1354  if (a == override_action)
1355  a_cost = ecs[a].partial_prediction;
1356  }
1357  else if ((a == start_K) || (ecs[a].partial_prediction < best_prediction))
1358  {
1359  best_prediction = ecs[a].partial_prediction;
1360  best_action = a;
1361  a_cost = best_prediction;
1362  }
1363  if (this_cache)
1364  this_cache->push_back(action_cache(0., a, false, ecs[a].partial_prediction));
1365 
1366  priv.num_features += ecs[a].num_features;
1367  ecs[a].l = old_label;
1368  if (start_K > 0)
1369  LabelDict::del_example_namespaces_from_example(ecs[a], ecs[0]);
1370  }
1371  if (override_action != (action)-1)
1372  best_action = override_action;
1373  else
1374  a_cost = best_prediction;
1375 
1376  if (this_cache)
1377  {
1378  for (size_t i = 0; i < this_cache->size(); i++)
1379  {
1380  action_cache& ac = (*this_cache)[i];
1381  ac.min_cost = a_cost;
1382  ac.is_opt = (ac.k == best_action);
1383  if (priv.metaoverride && priv.metaoverride->_foreach_action)
1384  priv.metaoverride->_foreach_action(*priv.metaoverride->sch, priv.t - 1, ac.min_cost, ac.k, ac.is_opt, ac.cost);
1385  }
1386  if (need_memo_foreach_action(priv) && (override_action == (action)-1))
1387  priv.memo_foreach_action.push_back(this_cache);
1388  else
1389  {
1390  this_cache->delete_v();
1391  delete this_cache;
1392  }
1393  }
1394 
1395  // TODO: generate raw predictions if necessary
1396 
1397  priv.total_predictions_made++;
1398  return best_action;
1399 }
#define cdbg
Definition: search.h:11
void predict(E &ec, size_t i=0)
Definition: learner.h:169
label_parser cs_label
void(* default_label)(void *)
Definition: label_parser.h:12
void del_example_namespaces_from_example(example &target, example &source)
uint32_t action
Definition: search.h:19
float partial_prediction
Definition: example.h:68
bool ec_is_example_header(example const &ec)
size_t size() const
Definition: v_array.h:68
void push_back(const T &new_ele)
Definition: v_array.h:107
COST_SENSITIVE::label cs
Definition: example.h:30
size_t num_features
Definition: example.h:67
bool need_memo_foreach_action(search_private &priv)
Definition: search.cc:370
std::vector< example * > multi_ex
Definition: example.h:122
polylabel l
Definition: example.h:57
constexpr uint64_t a
Definition: rand48.cc:11
void delete_v()
Definition: v_array.h:98
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
void add_example_namespaces_from_example(example &target, example &source)

◆ single_prediction_notLDF()

action Search::single_prediction_notLDF ( search_private & priv,
example & ec,
int  policy,
const action * allowed_actions,
size_t  allowed_actions_cnt,
const float *  allowed_actions_cost,
float &  a_cost,
action  override_action 
)

Definition at line 1163 of file search.cc.

References Search::BaseTask::_foreach_action, Search::search_private::active_csoaa, Search::search_private::active_known, Search::search_private::active_uncertainty, Search::search_private::all, allowed_actions_to_ld(), LEARNER::as_singleline(), Search::search_private::base_learner, Search::search_private::cb_learner, cdbg, COST_SENSITIVE::wclass::class_index, COST_SENSITIVE::label::costs, polylabel::cs, cs_get_cost_index(), cs_get_cost_partial_prediction(), cs_get_costs_size(), Search::search_private::empty_cs_label, INIT_TEST, INIT_TRAIN, example::l, MULTILABEL::labels::label_v, Search::search_private::memo_foreach_action, Search::search_private::meta_t, Search::search_private::metaoverride, polyprediction::multiclass, polyprediction::multilabels, need_memo_foreach_action(), example::num_features, Search::search_private::num_features, example::partial_prediction, example::pred, LEARNER::learner< T, E >::predict(), vw::print_text, v_array< T >::push_back(), vw::raw_prediction, Search::search_private::rawOutputStringStream, Search::BaseTask::sch, Search::search_private::state, Search::search_private::subsample_timesteps, Search::search_private::t, example::tag, THROW, Search::search_private::total_predictions_made, and v_array_contains().

Referenced by search_predict().

1167 {
1168  vw& all = *priv.all;
1169  polylabel old_label = ec.l;
1170  bool need_partial_predictions = need_memo_foreach_action(priv) ||
1171  (priv.metaoverride && priv.metaoverride->_foreach_action) || (override_action != (action)-1) || priv.active_csoaa;
1172  if ((allowed_actions_cnt > 0) || need_partial_predictions)
1173  ec.l = allowed_actions_to_ld(priv, 1, allowed_actions, allowed_actions_cnt, allowed_actions_cost);
1174  else
1175  ec.l.cs = priv.empty_cs_label;
1176 
1177  cdbg << "allowed_actions_cnt=" << allowed_actions_cnt << ", ec.l = [";
1178  for (size_t i = 0; i < ec.l.cs.costs.size(); i++)
1179  cdbg << ' ' << ec.l.cs.costs[i].class_index << ':' << ec.l.cs.costs[i].x;
1180  cdbg << " ]" << endl;
1181 
1182  as_singleline(priv.base_learner)->predict(ec, policy);
1183 
1184  uint32_t act = ec.pred.multiclass;
1185  cdbg << "a=" << act << " from";
1186  if (allowed_actions)
1187  {
1188  for (size_t ii = 0; ii < allowed_actions_cnt; ii++) cdbg << ' ' << allowed_actions[ii];
1189  }
1190  cdbg << endl;
1191  a_cost = ec.partial_prediction;
1192  cdbg << "a_cost = " << a_cost << endl;
1193 
1194  if (override_action != (action)-1)
1195  act = override_action;
1196 
1197  if (need_partial_predictions)
1198  {
1199  size_t K = cs_get_costs_size(priv.cb_learner, ec.l);
1200  float min_cost = FLT_MAX;
1201  for (size_t k = 0; k < K; k++)
1202  {
1203  float cost = cs_get_cost_partial_prediction(priv.cb_learner, ec.l, k);
1204  if (cost < min_cost)
1205  min_cost = cost;
1206  }
1207  v_array<action_cache>* this_cache = nullptr;
1208  if (need_memo_foreach_action(priv) && (override_action == (action)-1))
1209  {
1210  this_cache = new v_array<action_cache>();
1211  *this_cache = v_init<action_cache>();
1212  }
1213  for (size_t k = 0; k < K; k++)
1214  {
1215  action cl = cs_get_cost_index(priv.cb_learner, ec.l, k);
1216  float cost = cs_get_cost_partial_prediction(priv.cb_learner, ec.l, k);
1217  if (priv.metaoverride && priv.metaoverride->_foreach_action)
1218  priv.metaoverride->_foreach_action(*priv.metaoverride->sch, priv.t - 1, min_cost, cl, cl == act, cost);
1219  if (override_action == cl)
1220  a_cost = cost;
1221  if (this_cache)
1222  this_cache->push_back(action_cache(min_cost, cl, cl == act, cost));
1223  }
1224  if (this_cache)
1225  {
1226  assert(priv.memo_foreach_action.size() == priv.meta_t + priv.t - 1);
1227  priv.memo_foreach_action.push_back(this_cache);
1228  cdbg << "memo_foreach_action[" << priv.meta_t + priv.t - 1 << "] = " << this_cache << endl;
1229  }
1230  }
1231 
1232  if ((priv.state == INIT_TRAIN) && (priv.subsample_timesteps <= -1)) // active learning
1233  {
1234  size_t K = cs_get_costs_size(priv.cb_learner, ec.l);
1235  float min_cost = FLT_MAX, min_cost2 = FLT_MAX;
1236  for (size_t k = 0; k < K; k++)
1237  {
1238  float cost = cs_get_cost_partial_prediction(priv.cb_learner, ec.l, k);
1239  if (cost < min_cost)
1240  {
1241  min_cost2 = min_cost;
1242  min_cost = cost;
1243  }
1244  else if (cost < min_cost2)
1245  {
1246  min_cost2 = cost;
1247  }
1248  }
1249  if (min_cost2 < FLT_MAX)
1250  priv.active_uncertainty.push_back(std::make_pair(min_cost2 - min_cost, priv.t + priv.meta_t));
1251  }
1252  if ((priv.state == INIT_TRAIN) && priv.active_csoaa)
1253  {
1254  if (priv.cb_learner)
1255  THROW("cannot use active_csoaa with cb learning");
1256  size_t cur_t = priv.t + priv.meta_t - 1;
1257  while (priv.active_known.size() <= cur_t)
1258  {
1259  priv.active_known.push_back(v_array<std::pair<CS::wclass&, bool>>());
1260  priv.active_known[priv.active_known.size() - 1] = v_init<std::pair<CS::wclass&, bool>>();
1261  cdbg << "active_known length now " << priv.active_known.size() << endl;
1262  }
1263  priv.active_known[cur_t].clear();
1264  assert(ec.l.cs.costs.size() > 0);
1265  for (size_t k = 0; k < ec.l.cs.costs.size(); k++)
1266  {
1267  /* priv.active_known[cur_t].push_back( ec.l.cs.costs[k].pred_is_certain
1268  ? ec.l.cs.costs[k].partial_prediction
1269  : FLT_MAX );
1270  cdbg << "active_known[" << cur_t << "][" << (priv.active_known[cur_t].size() -
1271  1) << "] = certain=" << ec.l.cs.costs[k].pred_is_certain << ", cost=" << ec.l.cs.costs[k].partial_prediction <<
1272  "}" << endl; */
1273  CS::wclass& wc = ec.l.cs.costs[k];
1274  // Get query_needed from pred
1275  bool query_needed = v_array_contains(ec.pred.multilabels.label_v, wc.class_index);
1276  std::pair<CS::wclass&, bool> p = {wc, query_needed};
1277  // Push into active_known[cur_t] with wc
1278  priv.active_known[cur_t].push_back(p);
1279  // cdbg << "active_known[" << cur_t << "][" << (priv.active_known[cur_t].size() - 1) << "] = " << wc.class_index
1280  // << ':' << wc.x << " pp=" << wc.partial_prediction << " query_needed=" << wc.query_needed << " max_pred=" <<
1281  // wc.max_pred << " min_pred=" << wc.min_pred << " is_range_overlapped=" << wc.is_range_overlapped << "
1282  // is_range_large=" << wc.is_range_large << endl;
1283  // query_needed=" << ec.l.cs.costs[k].query_needed << ", cost=" << ec.l.cs.costs[k].partial_prediction << "}" <<
1284  // endl;
1285  }
1286  }
1287 
1288  // generate raw predictions if necessary
1289  if ((priv.state == INIT_TEST) && (all.raw_prediction > 0))
1290  {
1291  priv.rawOutputStringStream->str("");
1292  for (size_t k = 0; k < cs_get_costs_size(priv.cb_learner, ec.l); k++)
1293  {
1294  if (k > 0)
1295  (*priv.rawOutputStringStream) << ' ';
1296  (*priv.rawOutputStringStream) << cs_get_cost_index(priv.cb_learner, ec.l, k) << ':'
1297  << cs_get_cost_partial_prediction(priv.cb_learner, ec.l, k);
1298  }
1299  all.print_text(all.raw_prediction, priv.rawOutputStringStream->str(), ec.tag);
1300  }
1301 
1302  ec.l = old_label;
1303 
1304  priv.total_predictions_made++;
1305  priv.num_features += ec.num_features;
1306 
1307  return act;
1308 }
v_array< char > tag
Definition: example.h:63
#define cdbg
Definition: search.h:11
int raw_prediction
Definition: global_data.h:519
uint32_t multiclass
Definition: example.h:49
void predict(E &ec, size_t i=0)
Definition: learner.h:169
uint32_t action
Definition: search.h:19
float partial_prediction
Definition: example.h:68
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void push_back(const T &new_ele)
Definition: v_array.h:107
COST_SENSITIVE::label cs
Definition: example.h:30
size_t num_features
Definition: example.h:67
bool need_memo_foreach_action(search_private &priv)
Definition: search.cc:370
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
polylabel l
Definition: example.h:57
MULTILABEL::labels multilabels
Definition: example.h:50
v_array< uint32_t > label_v
Definition: multilabel.h:16
polyprediction pred
Definition: example.h:60
float cs_get_cost_partial_prediction(bool isCB, polylabel &ld, size_t k)
Definition: search.cc:894
bool v_array_contains(v_array< T > &A, T x)
Definition: v_array.h:237
v_array< wclass > costs
#define THROW(args)
Definition: vw_exception.h:181
polylabel & allowed_actions_to_ld(search_private &priv, size_t ec_cnt, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost)
Definition: search.cc:937
size_t cs_get_costs_size(bool isCB, polylabel &ld)
Definition: search.cc:887
uint32_t cs_get_cost_index(bool isCB, polylabel &ld, size_t k)
Definition: search.cc:889

◆ size_equal()

bool Search::size_equal ( size_t  a,
size_t  b 
)

◆ string_equal()

bool Search::string_equal ( std::string  a,
std::string  b 
)

◆ to_short_string()

void Search::to_short_string ( std::string  in,
size_t  max_len,
char *  out 
)

Definition at line 499 of file search.cc.

Referenced by print_update().

500 {
501  for (size_t i = 0; i < max_len; i++)
502  out[i] = ((i >= in.length()) || (in[i] == '\n') || (in[i] == '\t')) ? ' ' : in[i];
503 
504  if (in.length() > max_len)
505  {
506  out[max_len - 2] = '.';
507  out[max_len - 1] = '.';
508  }
509  out[max_len] = 0;
510 }

◆ train_single_example()

template<bool is_learn>
void Search::train_single_example ( search & sch,
bool  is_test_ex,
bool  is_holdout_ex,
multi_ex & ec_seq 
)

Definition at line 2171 of file search.cc.

References Search::search_private::active_csoaa, Search::search_private::active_csoaa_verify, Search::search_private::active_known, Search::search_private::active_uncertainty, advance_from_known_actions(), Search::search_private::all, polylabel::cb, Search::search_private::cb_learner, cdbg, v_array< T >::clear(), clear_cache_hash_map(), clear_memo_foreach_action(), CB::label::costs, COST_SENSITIVE::label::costs, polylabel::cs, cs_cost_push_back(), COST_SENSITIVE::cs_label, Search::search_private::current_policy, label_parser::delete_label, Search::search_private::done_with_all_actions, Search::search_private::examples_dont_change, vw::final_prediction_sink, vw::final_regressor_name, Search::search_private::force_setup_ec_ref, generate_training_example(), get_training_timesteps(), INIT_TEST, INIT_TRAIN, Search::search_private::is_ldf, LEARN, Search::search_private::learn_a_idx, Search::search_private::learn_allowed_actions, Search::search_private::learn_ec_copy, Search::search_private::learn_ec_ref, Search::search_private::learn_ec_ref_cnt, Search::search_private::learn_loss, Search::search_private::learn_losses, Search::search_private::learn_t, Search::search_private::loss_declared_cnt, MULTICLASS::mc_label, Search::search_private::memo_foreach_action, Search::search_private::meta_t, Search::search_private::metatask, might_print_update(), must_run_test(), NO_ROLLOUT, Search::search_private::num_calls_to_run, Search::search_private::num_calls_to_run_previous, Search::search_private::num_features, Search::search_private::pred_string, vw::print_text, Search::search::priv, vw::raw_prediction, Search::search_private::read_example_last_pass, reset_search_structure(), Search::search_private::rollout_method, run_task(), Search::search_private::save_every_k_runs, save_predictor(), vw::sd, Search::search_private::should_produce_string, v_array< T >::size(), Search::search_private::state, Search::search_private::t, Search::search_private::T, 
Search::search_private::test_action_sequence, Search::search_private::test_loss, Search::search_private::timesteps, Search::search_private::train_trajectory, vw::training, shared_data::update(), verify_active_csoaa(), and vw::vw_is_main.

2172 {
2173  search_private& priv = *sch.priv;
2174  vw& all = *priv.all;
2175  bool ran_test = false; // we must keep track so that even if we skip test, we still update # of examples seen
2176 
2177  // if (! priv.no_caching)
2178  clear_cache_hash_map(priv);
2179 
2180  cdbg << "is_test_ex=" << is_test_ex << " vw_is_main=" << all.vw_is_main << endl;
2181  cdbg << "must_run_test = " << must_run_test(all, ec_seq, is_test_ex) << endl;
2182  // do an initial test pass to compute output (and loss)
2183  if (must_run_test(all, ec_seq, is_test_ex))
2184  {
2185  cdbg << "======================================== INIT TEST (" << priv.current_policy << ","
2186  << priv.read_example_last_pass << ") ========================================" << endl;
2187 
2188  ran_test = true;
2189 
2190  // do the prediction
2191  reset_search_structure(priv);
2192  priv.state = INIT_TEST;
2193  priv.should_produce_string =
2194  might_print_update(all) || (all.final_prediction_sink.size() > 0) || (all.raw_prediction > 0);
2195  priv.pred_string->str("");
2196  priv.test_action_sequence.clear();
2197  run_task(sch, ec_seq);
2198 
2199  // accumulate loss
2200  if (!is_test_ex)
2201  all.sd->update(ec_seq[0]->test_only, !is_test_ex, priv.test_loss, 1.f, priv.num_features);
2202 
2203  // generate output
2204  for (int sink : all.final_prediction_sink) all.print_text((int)sink, priv.pred_string->str(), ec_seq[0]->tag);
2205 
2206  if (all.raw_prediction > 0)
2207  all.print_text(all.raw_prediction, "", ec_seq[0]->tag);
2208  }
2209 
2210  // if we're not training, then we're done!
2211  if ((!is_learn) || is_test_ex || is_holdout_ex || ec_seq[0]->test_only || (!priv.all->training))
2212  return;
2213 
2214  // SPEEDUP: if the oracle was never called, we can skip this!
2215 
2216  // do a pass over the data allowing oracle
2217  cdbg << "======================================== INIT TRAIN (" << priv.current_policy << ","
2218  << priv.read_example_last_pass << ") ========================================" << endl;
2219  // std::cerr << "training" << endl;
2220 
2221  clear_cache_hash_map(priv);
2222  reset_search_structure(priv);
2223  clear_memo_foreach_action(priv);
2224  priv.state = INIT_TRAIN;
2225  priv.active_uncertainty.clear();
2226  priv.train_trajectory.clear(); // this is where we'll store the training sequence
2227  run_task(sch, ec_seq);
2228 
2229  if (!ran_test) // was && !priv.ec_seq[0]->test_only) { but we know it's not test_only
2230  all.sd->update(ec_seq[0]->test_only, true, priv.test_loss, 1.f, priv.num_features);
2231 
2232  // if there's nothing to train on, we're done!
2233  if ((priv.loss_declared_cnt == 0) || (priv.t + priv.meta_t == 0) ||
2234  (priv.rollout_method == NO_ROLLOUT)) // TODO: make sure NO_ROLLOUT works with beam!
2235  {
2236  return;
2237  }
2238 
2239  // otherwise, we have some learn'in to do!
2240  cdbg << "======================================== LEARN (" << priv.current_policy << ","
2241  << priv.read_example_last_pass << ") ========================================" << endl;
2242  priv.T = priv.metatask ? priv.meta_t : priv.t;
2243  get_training_timesteps(priv, priv.timesteps);
2244  cdbg << "train_trajectory.size() = " << priv.train_trajectory.size() << ":\t";
2245  cdbg_print_array<scored_action>("", priv.train_trajectory);
2246  // cdbg << "memo_foreach_action = " << priv.memo_foreach_action << endl;
2247  for (size_t i = 0; i < priv.memo_foreach_action.size(); i++)
2248  {
2249  cdbg << "memo_foreach_action[" << i << "] = ";
2250  if (priv.memo_foreach_action[i])
2251  cdbg << *priv.memo_foreach_action[i];
2252  else
2253  cdbg << "null";
2254  cdbg << endl;
2255  }
2256 
2257  if (priv.cb_learner)
2258  priv.learn_losses.cb.costs.clear();
2259  else
2260  priv.learn_losses.cs.costs.clear();
2261 
2262  for (size_t tid = 0; tid < priv.timesteps.size(); tid++)
2263  {
2264  cdbg << "timestep = " << priv.timesteps[tid] << " [" << tid << "/" << priv.timesteps.size() << "]" << endl;
2265 
2266  if (priv.metatask && !priv.memo_foreach_action[tid])
2267  {
2268  cdbg << "skipping because it looks like this was overridden by metatask" << endl;
2269  continue;
2270  }
2271 
2272  priv.learn_ec_ref = nullptr;
2273  priv.learn_ec_ref_cnt = 0;
2274 
2275  reset_search_structure(priv); // TODO remove this?
2276  bool skipped_all_actions = true;
2277  priv.learn_a_idx = 0;
2278  priv.done_with_all_actions = false;
2279  // for each action, roll out to get a loss
2280  while (!priv.done_with_all_actions)
2281  {
2282  priv.learn_t = priv.timesteps[tid];
2283  advance_from_known_actions(priv);
2284  if (priv.done_with_all_actions)
2285  break;
2286 
2287  skipped_all_actions = false;
2288  reset_search_structure(priv);
2289 
2290  priv.state = LEARN;
2291  priv.learn_t = priv.timesteps[tid];
2292  cdbg << "-------------------------------------------------------------------------------------" << endl;
2293  cdbg << "learn_t = " << priv.learn_t << ", learn_a_idx = " << priv.learn_a_idx << endl;
2294  // cdbg_print_array("priv.active_known[learn_t]", priv.active_known[priv.learn_t]);
2295  run_task(sch, ec_seq);
2296  // cerr_print_array("in GENER, learn_allowed_actions", priv.learn_allowed_actions);
2297  float this_loss = priv.learn_loss;
2298  cs_cost_push_back(priv.cb_learner, priv.learn_losses,
2299  priv.is_ldf ? (uint32_t)(priv.learn_a_idx - 1) : (uint32_t)priv.learn_a_idx, this_loss);
2300  // (priv.learn_allowed_actions.size() > 0) ?
2301  // priv.learn_allowed_actions[priv.learn_a_idx-1] : priv.is_ldf ? (priv.learn_a_idx-1) :
2302  // (priv.learn_a_idx),
2303  // priv.learn_loss);
2304  }
2305  if (priv.active_csoaa_verify > 0.)
2306  verify_active_csoaa(
2307  priv.learn_losses.cs, priv.active_known[priv.learn_t], ec_seq[0]->example_counter, priv.active_csoaa_verify);
2308 
2309  if (skipped_all_actions)
2310  {
2311  reset_search_structure(priv);
2312  priv.state = LEARN;
2313  priv.learn_t = priv.timesteps[tid];
2314  priv.force_setup_ec_ref = true;
2315  cdbg << "<<<<<" << endl;
2316  cdbg << "skipped all actions; learn_t = " << priv.learn_t << ", learn_a_idx = " << priv.learn_a_idx << endl;
2317  run_task(sch, ec_seq); // TODO: i guess we can break out of this early
2318  cdbg << ">>>>>" << endl;
2319  }
2320  else
2321  cdbg << "didn't skip all actions" << endl;
2322 
2323  // now we can make a training example
2324  if (priv.learn_allowed_actions.size() > 0)
2325  {
2326  for (size_t i = 0; i < priv.learn_allowed_actions.size(); i++)
2327  {
2328  priv.learn_losses.cs.costs[i].class_index = priv.learn_allowed_actions[i];
2329  }
2330  }
2331  // float min_loss = 0.;
2332  // if (priv.metatask)
2333  // for (size_t aid=0; aid<priv.memo_foreach_action[tid]->size(); aid++)
2334  // min_loss = std::min(min_loss, priv.memo_foreach_action[tid]->get(aid).cost);
2335  cdbg << "priv.learn_losses = [";
2336  for (auto& wc : priv.learn_losses.cs.costs) cdbg << " " << wc.class_index << ":" << wc.x;
2337  cdbg << " ]" << endl;
2338  cdbg << "gte" << endl;
2339  generate_training_example(priv, priv.learn_losses, 1., true); // , min_loss); // TODO: weight
2340  if (!priv.examples_dont_change)
2341  for (size_t n = 0; n < priv.learn_ec_copy.size(); n++)
2342  {
2343  if (sch.priv->is_ldf)
2344  CS::cs_label.delete_label(&priv.learn_ec_copy[n].l.cs);
2345  else
2346  MC::mc_label.delete_label(&priv.learn_ec_copy[n].l.multi);
2347  }
2348  if (priv.cb_learner)
2349  priv.learn_losses.cb.costs.clear();
2350  else
2351  priv.learn_losses.cs.costs.clear();
2352  }
2353 
2354  if (priv.active_csoaa && (priv.save_every_k_runs > 1))
2355  {
2356  size_t prev_num = priv.num_calls_to_run_previous / priv.save_every_k_runs;
2357  size_t this_num = priv.num_calls_to_run / priv.save_every_k_runs;
2358  if (this_num > prev_num)
2359  save_predictor(all, all.final_regressor_name, this_num);
2360  priv.num_calls_to_run_previous = priv.num_calls_to_run;
2361  }
2362 }
#define cdbg
Definition: search.h:11
int raw_prediction
Definition: global_data.h:519
bool must_run_test(vw &all, multi_ex &ec, bool is_test_ex)
Definition: search.cc:473
label_parser cs_label
void(* delete_label)(void *)
Definition: label_parser.h:16
bool might_print_update(vw &all)
Definition: search.cc:461
v_array< int > final_prediction_sink
Definition: global_data.h:518
void run_task(search &sch, multi_ex &ec)
Definition: search.cc:2098
void generate_training_example(search_private &priv, polylabel &losses, float weight, bool add_conditioning=true, float min_loss=FLT_MAX)
Definition: search.cc:1490
void advance_from_known_actions(search_private &priv)
Definition: search.cc:2133
size_t size() const
Definition: v_array.h:68
void cs_cost_push_back(bool isCB, polylabel &ld, uint32_t index, float value)
Definition: search.cc:923
void save_predictor(vw &all, std::string reg_name, size_t current_pass)
label_parser mc_label
Definition: multiclass.cc:93
void clear_cache_hash_map(search_private &priv)
Definition: search.cc:278
void clear_memo_foreach_action(search_private &priv)
Definition: search.cc:284
void reset_search_structure(search_private &priv)
Definition: search.cc:690
shared_data * sd
Definition: global_data.h:375
bool vw_is_main
Definition: global_data.h:421
void(* print_text)(int, std::string, v_array< char >)
Definition: global_data.h:522
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
void verify_active_csoaa(COST_SENSITIVE::label &losses, v_array< std::pair< CS::wclass &, bool >> &known, size_t t, float multiplier)
Definition: search.cc:2108
std::string final_regressor_name
Definition: global_data.h:535
void get_training_timesteps(search_private &priv, v_array< size_t > &timesteps)
Definition: search.cc:1983

◆ uint32_equal()

bool Search::uint32_equal ( uint32_t  a,
uint32_t  b 
)

◆ verify_active_csoaa()

void Search::verify_active_csoaa ( COST_SENSITIVE::label & losses,
v_array< std::pair< CS::wclass &, bool >> &  known,
size_t  t,
float  multiplier 
)

Definition at line 2108 of file search.cc.

References cdbg, COST_SENSITIVE::wclass::class_index, COST_SENSITIVE::label::costs, and COST_SENSITIVE::wclass::x.

Referenced by train_single_example().

2110 {
2111  float threshold = multiplier / std::sqrt((float)t);
2112  cdbg << "verify_active_csoaa, losses = [";
2113  for (COST_SENSITIVE::wclass& wc : losses.costs) cdbg << " " << wc.class_index << ":" << wc.x;
2114  cdbg << " ]" << endl;
2115  // cdbg_print_array("verify_active_csoaa, known", known);
2116  size_t i = 0;
2117  for (COST_SENSITIVE::wclass& wc : losses.costs)
2118  {
2119  if (!known[i].second)
2120  {
2121  float err = pow(known[i].first.partial_prediction - wc.x, 2);
2122  if (err > threshold)
2123  {
2124  std::cerr << "verify_active_csoaa failed: truth " << wc.class_index << ":" << wc.x << ", known[" << i
2125  << "]=" << known[i].first.partial_prediction << ", error=" << err << " vs threshold " << threshold
2126  << endl;
2127  }
2128  }
2129  i++;
2130  }
2131 }
#define cdbg
Definition: search.h:11
v_array< wclass > costs

Variable Documentation

◆ ACTION_COSTS

uint32_t Search::ACTION_COSTS = 32

◆ all_metatasks

search_metatask* Search::all_metatasks[]
Initial value:
= {
Search::search_metatask metatask
Definition: search_meta.cc:18
Search::search_metatask metatask
Definition: search_meta.cc:50

Definition at line 39 of file search.cc.

◆ all_tasks

search_task* Search::all_tasks[]
Initial value:

Definition at line 35 of file search.cc.

◆ AUTO_CONDITION_FEATURES

uint32_t Search::AUTO_CONDITION_FEATURES = 1

◆ AUTO_HAMMING_LOSS

uint32_t Search::AUTO_HAMMING_LOSS = 2

◆ conditional_constant

uint64_t Search::conditional_constant = 8290743

Definition at line 368 of file search.cc.

◆ EXAMPLES_DONT_CHANGE

uint32_t Search::EXAMPLES_DONT_CHANGE = 4

◆ IS_LDF

uint32_t Search::IS_LDF = 8

◆ NO_CACHING

uint32_t Search::NO_CACHING = 16

Definition at line 49 of file search.cc.

Referenced by DepParserTask::initialize(), and Search::search::set_options().

◆ PRINT_CLOCK_TIME

constexpr bool Search::PRINT_CLOCK_TIME = false

Definition at line 44 of file search.cc.

◆ PRINT_UPDATE_EVERY_EXAMPLE

constexpr bool Search::PRINT_UPDATE_EVERY_EXAMPLE = false

Definition at line 42 of file search.cc.

◆ PRINT_UPDATE_EVERY_PASS

constexpr bool Search::PRINT_UPDATE_EVERY_PASS = false

Definition at line 43 of file search.cc.