Vowpal Wabbit
Functions | Variables
SequenceTaskCostToGo Namespace Reference

Functions

void initialize (Search::search &sch, size_t &num_actions, options_i &)
 
void run (Search::search &sch, multi_ex &ec)
 

Variables

Search::search_task task = {"sequence_ctg", run, initialize, nullptr, nullptr, nullptr}
 

Function Documentation

◆ initialize()

void SequenceTaskCostToGo::initialize ( Search::search & sch,
size_t & num_actions,
options_i &
)

Definition at line 274 of file search_sequencetask.cc.

References Search::ACTION_COSTS, Search::AUTO_CONDITION_FEATURES, Search::AUTO_HAMMING_LOSS, Search::EXAMPLES_DONT_CHANGE, Search::search::set_options(), and Search::search::set_task_data().

275 {
276  sch.set_options(Search::AUTO_CONDITION_FEATURES | // automatically add history features to our examples, please
277  Search::AUTO_HAMMING_LOSS | // please just use hamming loss on individual predictions -- we won't declare loss
278  Search::EXAMPLES_DONT_CHANGE | // we don't do any internal example munging
279  Search::ACTION_COSTS | // we'll provide cost-per-action (rather than oracle)
280  0);
281  sch.set_task_data<size_t>(&num_actions);
282 }
uint32_t ACTION_COSTS
Definition: search.cc:50
uint32_t AUTO_CONDITION_FEATURES
Definition: search.cc:49
void set_options(uint32_t opts)
Definition: search.cc:3053
uint32_t AUTO_HAMMING_LOSS
Definition: search.cc:49
void set_task_data(T *data)
Definition: search.h:84
uint32_t EXAMPLES_DONT_CHANGE
Definition: search.cc:49

◆ run()

void SequenceTaskCostToGo::run ( Search::search & sch,
multi_ex & ec
)

Definition at line 284 of file search_sequencetask.cc.

References Search::search::get_history_length(), Search::search::get_task_data(), Search::search::output(), Search::predictor::predict(), and Search::search::pretty_label().

285 {
286  size_t K = *sch.get_task_data<size_t>();
287  float* costs = calloc_or_throw<float>(K);
288  Search::predictor P(sch, (ptag)0);
289  for (size_t i = 0; i < ec.size(); i++)
290  {
291  action oracle = ec[i]->l.multi.label;
292  for (size_t k = 0; k < K; k++) costs[k] = 1.;
293  costs[oracle - 1] = 0.;
294  size_t prediction = P.set_tag((ptag)i + 1)
295  .set_input(*ec[i])
296  .set_allowed(nullptr, costs, K)
297  .set_condition_range((ptag)i, sch.get_history_length(), 'p')
298  .predict();
299  if (sch.output().good())
300  sch.output() << sch.pretty_label((uint32_t)prediction) << ' ';
301  }
302  free(costs);
303 }
uint32_t get_history_length()
Definition: search.cc:3098
std::stringstream & output()
Definition: search.cc:3043
std::string pretty_label(action a)
Definition: search.cc:3100
uint32_t action
Definition: search.h:19
T * get_task_data()
Definition: search.h:89
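void predict(bfgs &b, base_learner &, example &ec)
Definition: bfgs.cc:956
Note: Doxygen resolved this cross-reference to an unrelated function in bfgs.cc. The call at line 298 of run() is the member predict() of the Search::predictor object P, not this free function.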
uint32_t ptag
Definition: search.h:20

Variable Documentation

◆ task

Search::search_task SequenceTaskCostToGo::task = {"sequence_ctg", run, initialize, nullptr, nullptr, nullptr}

Definition at line 21 of file search_sequencetask.cc.