Vowpal Wabbit
Classes | Functions | Variables
ArgmaxTask Namespace Reference

Classes

struct  task_data
 

Functions

void initialize (Search::search &sch, size_t &, options_i &options)
 
void finish (Search::search &sch)
 
void run (Search::search &sch, multi_ex &ec)
 

Variables

Search::search_task task = {"argmax", run, initialize, finish, nullptr, nullptr}
 

Function Documentation

◆ finish()

void ArgmaxTask::finish ( Search::search & sch )

Definition at line 336 of file search_sequencetask.cc.

References Search::search::get_task_data().

337 {
338  task_data* D = sch.get_task_data<task_data>();
339  delete D;
340 }
T * get_task_data()
Definition: search.h:89

◆ initialize()

void ArgmaxTask::initialize ( Search::search & sch,
size_t & ,
options_i & options 
)

Definition at line 315 of file search_sequencetask.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), Search::AUTO_CONDITION_FEATURES, Search::EXAMPLES_DONT_CHANGE, f, ArgmaxTask::task_data::false_negative_cost, VW::config::make_option(), ArgmaxTask::task_data::negative_weight, ArgmaxTask::task_data::predict_max, Search::search::set_options(), and Search::search::set_task_data().

316 {
317  task_data* D = new task_data();
318 
319  option_group_definition new_options("argmax options");
320  new_options.add(make_option("cost", D->false_negative_cost).default_value(10.0f).help("False Negative Cost"))
321  .add(make_option("negative_weight", D->negative_weight)
322  .default_value(1.f)
323  .help("Relative weight of negative examples"))
324  .add(make_option("max", D->predict_max).help("Disable structure: just predict the max"));
325  options.add_and_parse(new_options);
326 
327  sch.set_task_data(D);
328 
329  if (D->predict_max)
330  sch.set_options(Search::EXAMPLES_DONT_CHANGE); // we don't do any internal example munging
331  else
332  sch.set_options(Search::AUTO_CONDITION_FEATURES | // automatically add history features to our examples, please
333  Search::EXAMPLES_DONT_CHANGE); // we don't do any internal example munging
334 }
virtual void add_and_parse(const option_group_definition &group)=0
uint32_t AUTO_CONDITION_FEATURES
Definition: search.cc:49
void set_options(uint32_t opts)
Definition: search.cc:3053
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
void set_task_data(T *data)
Definition: search.h:84
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
uint32_t EXAMPLES_DONT_CHANGE
Definition: search.cc:49
float f
Definition: cache.cc:40

◆ run()

void ArgmaxTask::run ( Search::search & sch,
multi_ex & ec 
)

Definition at line 342 of file search_sequencetask.cc.

References ArgmaxTask::task_data::false_negative_cost, Search::search::get_task_data(), loss(), Search::search::loss(), ArgmaxTask::task_data::negative_weight, Search::search::output(), Search::search::predict(), and ArgmaxTask::task_data::predict_max.

343 {
344  task_data& D = *sch.get_task_data<task_data>();
345  uint32_t max_prediction = 1;
346  uint32_t max_label = 1;
347 
348  for (size_t i = 0; i < ec.size(); i++) max_label = std::max(ec[i]->l.multi.label, max_label);
349 
350  for (ptag i = 0; i < ec.size(); i++)
351  {
352  // labels should be 1 or 2, and our output is MAX of all predicted values
353  uint32_t oracle = D.predict_max ? max_label : ec[i]->l.multi.label;
354  uint32_t prediction = sch.predict(*ec[i], i + 1, &oracle, 1, &i, "p");
355 
356  max_prediction = std::max(prediction, max_prediction);
357  }
358  float loss = 0.;
359  if (max_label > max_prediction)
360  loss = D.false_negative_cost / D.negative_weight;
361  else if (max_prediction > max_label)
362  loss = 1.;
363  sch.loss(loss);
364 
365  if (sch.output().good())
366  sch.output() << max_prediction;
367 }
std::stringstream & output()
Definition: search.cc:3043
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
T * get_task_data()
Definition: search.h:89
action predict(example &ec, ptag my_tag, const action *oracle_actions, size_t oracle_actions_cnt=1, const ptag *condition_on=nullptr, const char *condition_on_names=nullptr, const action *allowed_actions=nullptr, size_t allowed_actions_cnt=0, const float *allowed_actions_cost=nullptr, size_t learner_id=0, float weight=0.)
Definition: search.cc:2967
uint32_t ptag
Definition: search.h:20
void loss(float incr_loss)
Definition: search.cc:3039

Variable Documentation

◆ task

Search::search_task ArgmaxTask::task = {"argmax", run, initialize, finish, nullptr, nullptr}

Definition at line 25 of file search_sequencetask.cc.