Vowpal Wabbit
Public Member Functions | Public Attributes | List of all members
Search::search Struct Reference

#include <search.h>

Public Member Functions

template<class T >
void set_task_data (T *data)
 
template<class T >
T * get_task_data ()
 
template<class T >
void set_metatask_data (T *data)
 
template<class T >
T * get_metatask_data ()
 
void set_options (uint32_t opts)
 
void set_label_parser (label_parser &lp, bool(*is_test)(polylabel &))
 
void loss (float incr_loss)
 
action predict (example &ec, ptag my_tag, const action *oracle_actions, size_t oracle_actions_cnt=1, const ptag *condition_on=nullptr, const char *condition_on_names=nullptr, const action *allowed_actions=nullptr, size_t allowed_actions_cnt=0, const float *allowed_actions_cost=nullptr, size_t learner_id=0, float weight=0.)
 
action predictLDF (example *ecs, size_t ec_cnt, ptag my_tag, const action *oracle_actions, size_t oracle_actions_cnt=1, const ptag *condition_on=nullptr, const char *condition_on_names=nullptr, size_t learner_id=0, float weight=0.)
 
bool predictNeedsExample ()
 
uint32_t get_history_length ()
 
bool is_ldf ()
 
std::stringstream & output ()
 
void set_num_learners (size_t num_learners)
 
void get_test_action_sequence (std::vector< action > &)
 
uint64_t get_mask ()
 
size_t get_stride_shift ()
 
std::string pretty_label (action a)
 
BaseTask base_task (multi_ex &ec)
 
vw & get_vw_pointer_unsafe ()
 
void set_force_oracle (bool force)
 
 search ()
 
 ~search ()
 

Public Attributes

search_private * priv
 
void * task_data
 
void * metatask_data
 
const char * task_name
 
const char * metatask_name
 

Detailed Description

Definition at line 80 of file search.h.

Constructor & Destructor Documentation

◆ search()

Search::search::search ( )

Definition at line 295 of file search.cc.

295 { priv = &calloc_or_throw<search_private>(); }
search_private * priv
Definition: search.h:216

◆ ~search()

Search::search::~search ( )

Definition at line 297 of file search.cc.

References Search::search_private::_random_state, Search::search_private::active_known, Search::search_private::active_uncertainty, Search::search_private::allowed_actions_cache, Search::audit_feature_space(), Search::search_private::bad_string_stream, Search::search_private::cache_hash_map, polylabel::cb, Search::search_private::cb_learner, cdbg, Search::clear_cache_hash_map(), Search::clear_memo_foreach_action(), Search::search_private::condition_on_actions, CB::label::costs, COST_SENSITIVE::label::costs, polylabel::cs, COST_SENSITIVE::cs_label, Search::search_private::dat_new_feature_audit_ss, VW::dealloc_example(), label_parser::delete_label, CB::delete_label(), v_array< T >::delete_v(), features::delete_v(), Search::search_private::examples_dont_change, Search::search_private::gte_label, Search::search_private::is_ldf, Search::search_private::last_action_repr, Search::search_private::ldf_test_label, Search::search_private::learn_allowed_actions, Search::search_private::learn_condition_on, Search::search_private::learn_condition_on_act, Search::search_private::learn_condition_on_names, Search::search_private::learn_ec_copy, Search::search_private::learn_losses, MULTICLASS::mc_label, Search::search_private::memo_foreach_action, Search::search_private::neighbor_features, Search::search_private::pred_string, Search::search_private::ptag_to_action, Search::search_private::rawOutputString, Search::search_private::rawOutputStringStream, Search::action_repr::repr, Search::search_private::test_action_sequence, Search::search_private::timesteps, Search::search_private::train_trajectory, and Search::search_private::truth_string.

298 {
299  if (this->priv && this->priv->all)
300  {
301  search_private& priv = *this->priv;
302  clear_cache_hash_map(priv);
303 
304  priv._random_state.~shared_ptr<rand_state>();
305  delete priv.truth_string;
306  delete priv.pred_string;
307  delete priv.bad_string_stream;
308  priv.cache_hash_map.~v_hashmap<unsigned char*, scored_action>();
309  priv.rawOutputString.~basic_string();
310  priv.test_action_sequence.~vector<action>();
311  priv.dat_new_feature_audit_ss.~basic_stringstream();
312  priv.neighbor_features.delete_v();
313  priv.timesteps.delete_v();
314  if (priv.cb_learner)
315  priv.learn_losses.cb.costs.delete_v();
316  else
317  priv.learn_losses.cs.costs.delete_v();
318  if (priv.cb_learner)
319  priv.gte_label.cb.costs.delete_v();
320  else
321  priv.gte_label.cs.costs.delete_v();
322 
323  priv.condition_on_actions.delete_v();
324  priv.learn_allowed_actions.delete_v();
325  priv.ldf_test_label.costs.delete_v();
326  priv.last_action_repr.delete_v();
327  priv.active_uncertainty.delete_v();
328  for (size_t i = 0; i < priv.active_known.size(); i++) priv.active_known[i].delete_v();
329  priv.active_known.delete_v();
330 
331  if (priv.cb_learner)
332  priv.allowed_actions_cache->cb.costs.delete_v();
333  else
334  priv.allowed_actions_cache->cs.costs.delete_v();
335 
336  priv.train_trajectory.delete_v();
337  for (Search::action_repr& ar : priv.ptag_to_action)
338  {
339  if (ar.repr != nullptr)
340  {
341  ar.repr->delete_v();
342  delete ar.repr;
343  cdbg << "delete_v" << endl;
344  }
345  }
346  priv.ptag_to_action.delete_v();
347  clear_memo_foreach_action(priv);
348  priv.memo_foreach_action.delete_v();
349 
350  // destroy copied examples if we needed them
351  if (!priv.examples_dont_change)
352  {
353  void (*delete_label)(void*) = priv.is_ldf ? CS::cs_label.delete_label : MC::mc_label.delete_label;
354  for (example& ec : priv.learn_ec_copy) VW::dealloc_example(delete_label, ec);
355  priv.learn_ec_copy.delete_v();
356  }
357  priv.learn_condition_on_names.delete_v();
358  priv.learn_condition_on.delete_v();
359  priv.learn_condition_on_act.delete_v();
360 
361  free(priv.allowed_actions_cache);
362  delete priv.rawOutputStringStream;
363  }
364  free(this->priv);
365 }
#define cdbg
Definition: search.h:11
label_parser cs_label
void(* delete_label)(void *)
Definition: label_parser.h:16
void dealloc_example(void(*delete_label)(void *), example &ec, void(*delete_prediction)(void *))
Definition: example.cc:219
void delete_v()
search_private * priv
Definition: search.h:216
void delete_label(void *v)
Definition: cb.cc:98
uint32_t action
Definition: search.h:19
label_parser mc_label
Definition: multiclass.cc:93
void clear_cache_hash_map(search_private &priv)
Definition: search.cc:278
void clear_memo_foreach_action(search_private &priv)
Definition: search.cc:284
features * repr
Definition: search.cc:106

Member Function Documentation

◆ base_task()

BaseTask Search::search::base_task ( multi_ex &  ec)
inline

Definition at line 213 of file search.h.

References Search::BaseTask::BaseTask().

Referenced by DebugMT::run(), and SelectiveBranchingMT::run().

213 { return BaseTask(this, ec); }

◆ get_history_length()

uint32_t Search::search::get_history_length ( )

◆ get_mask()

uint64_t Search::search::get_mask ( )

Definition at line 3096 of file search.cc.

References Search::search_private::all, parameters::mask(), and vw::weights.

Referenced by DepParserTask::extract_features().

3096 { return this->priv->all->weights.mask(); }
parameters weights
Definition: global_data.h:537
search_private * priv
Definition: search.h:216
uint64_t mask()

◆ get_metatask_data()

template<class T >
T* Search::search::get_metatask_data ( )
inline

Definition at line 101 of file search.h.

References a, Search::BaseTask::ec, loss(), and predict().

Referenced by SelectiveBranchingMT::finish(), and SelectiveBranchingMT::run().

102  {
103  return (T*)metatask_data;
104  }
void * metatask_data
Definition: search.h:218

◆ get_stride_shift()

size_t Search::search::get_stride_shift ( )

Definition at line 3097 of file search.cc.

References Search::search_private::all, parameters::stride_shift(), and vw::weights.

Referenced by SequenceTask_DemoLDF::my_update_example_indicies().

3097 { return this->priv->all->weights.stride_shift(); }
parameters weights
Definition: global_data.h:537
search_private * priv
Definition: search.h:216
uint32_t stride_shift()

◆ get_task_data()

template<class T >
T* Search::search::get_task_data ( )
inline

◆ get_test_action_sequence()

void Search::search::get_test_action_sequence ( std::vector< action > &  V)

Definition at line 3088 of file search.cc.

References Search::search_private::test_action_sequence.

3089 {
3090  V.clear();
3091  for (size_t i = 0; i < this->priv->test_action_sequence.size(); i++) V.push_back(this->priv->test_action_sequence[i]);
3092 }
search_private * priv
Definition: search.h:216
std::vector< action > test_action_sequence
Definition: search.cc:179

◆ get_vw_pointer_unsafe()

vw & Search::search::get_vw_pointer_unsafe ( )

◆ is_ldf()

bool Search::search::is_ldf ( )

Definition at line 2965 of file search.cc.

References Search::search_private::is_ldf.

2965 { return priv->is_ldf; }
search_private * priv
Definition: search.h:216

◆ loss()

void Search::search::loss ( float  incr_loss)

◆ output()

std::stringstream & Search::search::output ( )

◆ predict()

action Search::search::predict ( example &  ec,
ptag  my_tag,
const action *  oracle_actions,
size_t  oracle_actions_cnt = 1,
const ptag *  condition_on = nullptr,
const char *  condition_on_names = nullptr,
const action *  allowed_actions = nullptr,
size_t  allowed_actions_cnt = 0,
const float *  allowed_actions_cost = nullptr,
size_t  learner_id = 0,
float  weight = 0. 
)

Definition at line 2967 of file search.cc.

References a, Search::search_private::acset, Search::action_cost_loss(), Search::action_hamming_loss(), Search::search_private::auto_hamming_loss, cdbg, Search::INIT_TEST, Search::search_private::last_action_repr, loss(), Search::search_private::ptag_to_action, Search::push_at(), Search::search_predict(), Search::search_private::state, Search::search_private::test_action_sequence, Search::search_private::use_action_costs, and Search::auto_condition_settings::use_passthrough_repr.

Referenced by Search::predictor::predict(), MulticlassTask::run(), and ArgmaxTask::run().

2970 {
2971  float a_cost = 0.;
2972  action a = search_predict(*priv, &ec, 1, mytag, oracle_actions, oracle_actions_cnt, condition_on, condition_on_names,
2973  allowed_actions, allowed_actions_cnt, allowed_actions_cost, learner_id, a_cost, weight);
2974  if (priv->state == INIT_TEST)
2975  priv->test_action_sequence.push_back(a);
2976  if (mytag != 0)
2977  {
2978  if (mytag < priv->ptag_to_action.size())
2979  {
2980  cdbg << "delete_v at " << mytag << endl;
2981  if (priv->ptag_to_action[mytag].repr != nullptr)
2982  {
2983  priv->ptag_to_action[mytag].repr->delete_v();
2984  delete priv->ptag_to_action[mytag].repr;
2985  }
2986  }
2987  if (priv->acset.use_passthrough_repr)
2988  {
2989  assert((mytag >= priv->ptag_to_action.size()) || (priv->ptag_to_action[mytag].repr == nullptr));
2990  push_at(priv->ptag_to_action, action_repr(a, &(priv->last_action_repr)), mytag);
2991  }
2992  else
2993  push_at(priv->ptag_to_action, action_repr(a, (features*)nullptr), mytag);
2994  cdbg << "push_at " << mytag << endl;
2995  }
2996  if (priv->auto_hamming_loss)
2997  loss(priv->use_action_costs ? action_cost_loss(a, allowed_actions, allowed_actions_cost, allowed_actions_cnt)
2998  : action_hamming_loss(a, oracle_actions, oracle_actions_cnt));
2999  cdbg << "predict returning " << a << endl;
3000  return a;
3001 }
#define cdbg
Definition: search.h:11
auto_condition_settings acset
Definition: search.cc:152
float action_cost_loss(action a, const action *act, const float *costs, size_t sz)
Definition: search.cc:2954
void push_at(v_array< T > &v, T item, size_t pos)
Definition: search.cc:1074
SearchState state
Definition: search.cc:158
the core definition of a set of features.
search_private * priv
Definition: search.h:216
uint32_t action
Definition: search.h:19
v_array< action_repr > ptag_to_action
Definition: search.cc:178
action search_predict(search_private &priv, example *ecs, size_t ec_cnt, ptag mytag, const action *oracle_actions, size_t oracle_actions_cnt, const ptag *condition_on, const char *condition_on_names, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost, size_t learner_id, float &a_cost, float)
Definition: search.cc:1652
std::vector< action > test_action_sequence
Definition: search.cc:179
float action_hamming_loss(action a, const action *A, size_t sz)
Definition: search.cc:2944
float weight
constexpr uint64_t a
Definition: rand48.cc:11
features last_action_repr
Definition: search.cc:182
void loss(float incr_loss)
Definition: search.cc:3039

◆ predictLDF()

action Search::search::predictLDF ( example *  ecs,
size_t  ec_cnt,
ptag  my_tag,
const action *  oracle_actions,
size_t  oracle_actions_cnt = 1,
const ptag *  condition_on = nullptr,
const char *  condition_on_names = nullptr,
size_t  learner_id = 0,
float  weight = 0. 
)

Definition at line 3003 of file search.cc.

References a, Search::action_hamming_loss(), Search::search_private::auto_hamming_loss, cdbg, COST_SENSITIVE::ec_is_example_header(), Search::INIT_TEST, Search::search_private::last_action_repr, loss(), Search::search_private::ptag_to_action, Search::push_at(), Search::search_predict(), Search::search_private::state, and Search::search_private::test_action_sequence.

Referenced by Search::predictor::predict().

3006 {
3007  float a_cost = 0.;
3008  // TODO: action costs for ldf
3009  action a = search_predict(*priv, ecs, ec_cnt, mytag, oracle_actions, oracle_actions_cnt, condition_on,
3010  condition_on_names, nullptr, 0, nullptr, learner_id, a_cost, weight);
3011  if (priv->state == INIT_TEST)
3012  priv->test_action_sequence.push_back(a);
3013 
3014  // If there is a shared example (example header), then action "1" is at index 1, but otherwise
3015  // action "1" is at index 0. Map action to its appropriate index. In particular, this fixes an
3016  // issue where the predicted action is the last, and there is no example header, causing an index
3017  // beyond the end of the array (usually resulting in a segfault at some point.)
3018  size_t action_index = a - COST_SENSITIVE::ec_is_example_header(ecs[0]) ? 0 : 1;
3019 
3020  if ((mytag != 0) && ecs[action_index].l.cs.costs.size() > 0)
3021  {
3022  if (mytag < priv->ptag_to_action.size())
3023  {
3024  cdbg << "delete_v at " << mytag << endl;
3025  if (priv->ptag_to_action[mytag].repr != nullptr)
3026  {
3027  priv->ptag_to_action[mytag].repr->delete_v();
3028  delete priv->ptag_to_action[mytag].repr;
3029  }
3030  }
3031  push_at(priv->ptag_to_action, action_repr(ecs[a].l.cs.costs[0].class_index, &(priv->last_action_repr)), mytag);
3032  }
3033  if (priv->auto_hamming_loss)
3034  loss(action_hamming_loss(a, oracle_actions, oracle_actions_cnt)); // TODO: action costs
3035  cdbg << "predict returning " << a << endl;
3036  return a;
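3037 }
Note (review): line 3018 has an operator-precedence bug. In C++ the conditional operator binds more loosely than binary `-`, so the line parses as `(a - COST_SENSITIVE::ec_is_example_header(ecs[0])) ? 0 : 1`. As a result, `action_index` is always 0 or 1, not the mapped index described by the comment on lines 3014-3017; the intended form is `a - (COST_SENSITIVE::ec_is_example_header(ecs[0]) ? 0 : 1)`. Line 3031 also still reads `ecs[a]` rather than `ecs[action_index]`. By the comment's own indexing, the out-of-range access it says was fixed can therefore still occur when there is no example header.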
#define cdbg
Definition: search.h:11
void push_at(v_array< T > &v, T item, size_t pos)
Definition: search.cc:1074
SearchState state
Definition: search.cc:158
search_private * priv
Definition: search.h:216
uint32_t action
Definition: search.h:19
bool ec_is_example_header(example const &ec)
v_array< action_repr > ptag_to_action
Definition: search.cc:178
action search_predict(search_private &priv, example *ecs, size_t ec_cnt, ptag mytag, const action *oracle_actions, size_t oracle_actions_cnt, const ptag *condition_on, const char *condition_on_names, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost, size_t learner_id, float &a_cost, float)
Definition: search.cc:1652
std::vector< action > test_action_sequence
Definition: search.cc:179
float action_hamming_loss(action a, const action *A, size_t sz)
Definition: search.cc:2944
float weight
constexpr uint64_t a
Definition: rand48.cc:11
features last_action_repr
Definition: search.cc:182
void loss(float incr_loss)
Definition: search.cc:3039

◆ predictNeedsExample()

bool Search::search::predictNeedsExample ( )

Definition at line 3041 of file search.cc.

References Search::search_predictNeedsExample().

Referenced by GraphTask::run(), SequenceTask_DemoLDF::run(), and DepParserTask::run().

3041 { return search_predictNeedsExample(*this->priv); }
search_private * priv
Definition: search.h:216
bool search_predictNeedsExample(search_private &priv)
Definition: search.cc:1601

◆ pretty_label()

std::string Search::search::pretty_label ( action  a)

Definition at line 3100 of file search.cc.

References a, Search::search_private::all, substring::begin, substring::end, namedlabels::get(), shared_data::ldict, and vw::sd.

Referenced by SequenceTask::run(), and SequenceTaskCostToGo::run().

3101 {
3102  if (this->priv->all->sd->ldict)
3103  {
3104  substring ss = this->priv->all->sd->ldict->get(a);
3105  return std::string(ss.begin, ss.end - ss.begin);
3106  }
3107  else
3108  {
3109  std::ostringstream os;
3110  os << a;
3111  return os.str();
3112  }
3113 }
char * end
Definition: hashstring.h:10
char * begin
Definition: hashstring.h:9
namedlabels * ldict
Definition: global_data.h:153
search_private * priv
Definition: search.h:216
shared_data * sd
Definition: global_data.h:375
constexpr uint64_t a
Definition: rand48.cc:11
uint64_t get(substring &s)
Definition: global_data.h:108

◆ set_force_oracle()

void Search::search::set_force_oracle ( bool  force)

Definition at line 3116 of file search.cc.

References Search::search_private::force_oracle.

3116 { this->priv->force_oracle = force; }
search_private * priv
Definition: search.h:216

◆ set_label_parser()

void Search::search::set_label_parser ( label_parser &  lp,
bool(*)(polylabel &)  is_test 
)

Definition at line 3079 of file search.cc.

References Search::search_private::all, Search::INITIALIZE, Search::search_private::label_is_test, parser::lp, vw::p, label_parser::test_label, and vw::vw_is_main.

Referenced by DepParserTask::initialize(), and GraphTask::initialize().

3080 {
3081  if (this->priv->all->vw_is_main && (this->priv->state != INITIALIZE))
3082  std::cerr << "warning: task should not set label parser except in initialize function!" << endl;
3083  this->priv->all->p->lp = lp;
3084  this->priv->all->p->lp.test_label = (bool (*)(void*))is_test;
3085  this->priv->label_is_test = is_test;
3086 }
bool(* test_label)(void *)
Definition: label_parser.h:22
search_private * priv
Definition: search.h:216
parser * p
Definition: global_data.h:377
bool(* label_is_test)(polylabel &)
Definition: search.cc:167
bool vw_is_main
Definition: global_data.h:421
label_parser lp
Definition: parser.h:102

◆ set_metatask_data()

template<class T >
void Search::search::set_metatask_data ( T *  data)
inline

Definition at line 96 of file search.h.

Referenced by SelectiveBranchingMT::initialize().

97  {
98  metatask_data = data;
99  }
void * metatask_data
Definition: search.h:218

◆ set_num_learners()

void Search::search::set_num_learners ( size_t  num_learners)

◆ set_options()

void Search::search::set_options ( uint32_t  opts)

Definition at line 3053 of file search.cc.

References Search::ACTION_COSTS, Search::search_private::all, Search::search_private::auto_condition_features, Search::AUTO_HAMMING_LOSS, Search::search_private::auto_hamming_loss, Search::EXAMPLES_DONT_CHANGE, Search::search_private::examples_dont_change, Search::INITIALIZE, Search::IS_LDF, Search::search_private::is_ldf, Search::NO_CACHING, Search::search_private::no_caching, Search::NO_ROLLOUT, THROW, Search::search_private::use_action_costs, and vw::vw_is_main.

Referenced by MulticlassTask::initialize(), SequenceTask::initialize(), EntityRelationTask::initialize(), DepParserTask::initialize(), GraphTask::initialize(), SequenceSpanTask::initialize(), SequenceTaskCostToGo::initialize(), ArgmaxTask::initialize(), and SequenceTask_DemoLDF::initialize().

3054 {
3055  if (this->priv->all->vw_is_main && (this->priv->state != INITIALIZE))
3056  std::cerr << "warning: task should not set options except in initialize function!" << endl;
3057  if ((opts & AUTO_CONDITION_FEATURES) != 0)
3058  this->priv->auto_condition_features = true;
3059  if ((opts & AUTO_HAMMING_LOSS) != 0)
3060  this->priv->auto_hamming_loss = true;
3061  if ((opts & EXAMPLES_DONT_CHANGE) != 0)
3062  this->priv->examples_dont_change = true;
3063  if ((opts & IS_LDF) != 0)
3064  this->priv->is_ldf = true;
3065  if ((opts & NO_CACHING) != 0)
3066  this->priv->no_caching = true;
3067  if ((opts & ACTION_COSTS) != 0)
3068  this->priv->use_action_costs = true;
3069 
3070  if (this->priv->is_ldf && this->priv->use_action_costs)
3071  THROW("using LDF and actions costs is not yet implemented; turn off action costs"); // TODO fix
3072 
3073  if (this->priv->use_action_costs && (this->priv->rollout_method != NO_ROLLOUT))
3074  std::cerr
3075  << "warning: task is designed to use rollout costs, but this only works when --search_rollout none is specified"
3076  << endl;
3077 }
search_private * priv
Definition: search.h:216
bool auto_condition_features
Definition: search.cc:145
uint32_t ACTION_COSTS
Definition: search.cc:50
uint32_t NO_CACHING
Definition: search.cc:49
uint32_t AUTO_CONDITION_FEATURES
Definition: search.cc:49
bool vw_is_main
Definition: global_data.h:421
uint32_t IS_LDF
Definition: search.cc:49
uint32_t AUTO_HAMMING_LOSS
Definition: search.cc:49
uint32_t EXAMPLES_DONT_CHANGE
Definition: search.cc:49
#define THROW(args)
Definition: vw_exception.h:181

◆ set_task_data()

template<class T >
void Search::search::set_task_data ( T *  data)
inline

Member Data Documentation

◆ metatask_data

void* Search::search::metatask_data

Definition at line 218 of file search.h.

◆ metatask_name

const char* Search::search::metatask_name

Definition at line 220 of file search.h.

◆ priv

search_private* Search::search::priv

◆ task_data

void* Search::search::task_data

Definition at line 217 of file search.h.

Referenced by Search::search_initialize().

◆ task_name

const char* Search::search::task_name

Definition at line 219 of file search.h.


The documentation for this struct was generated from the following files: