Vowpal Wabbit
Classes | Functions | Variables
MulticlassTask Namespace Reference

Classes

struct  task_data
 

Functions

void initialize (Search::search &sch, size_t &num_actions, VW::config::options_i &)
 
void finish (Search::search &sch)
 
void run (Search::search &sch, multi_ex &ec)
 

Variables

Search::search_task task = {"multiclasstask", run, initialize, finish, nullptr, nullptr}
 

Function Documentation

◆ finish()

void MulticlassTask::finish ( Search::search & sch)

Definition at line 34 of file search_multiclasstask.cc.

References v_array< T >::delete_v(), Search::search::get_task_data(), and MulticlassTask::task_data::y_allowed.

35 {
36  task_data* my_task_data = sch.get_task_data<task_data>();
37  my_task_data->y_allowed.delete_v();
38  delete my_task_data;
39 }
T * get_task_data()
Definition: search.h:89

◆ initialize()

void MulticlassTask::initialize ( Search::search & sch,
size_t &  num_actions,
VW::config::options_i & 
)

Definition at line 22 of file search_multiclasstask.cc.

References MulticlassTask::task_data::max_label, MulticlassTask::task_data::num_level, v_array< T >::push_back(), Search::search::set_num_learners(), Search::search::set_options(), Search::search::set_task_data(), and MulticlassTask::task_data::y_allowed.

23 {
24  task_data* my_task_data = new task_data();
25  sch.set_options(0);
26  sch.set_num_learners(num_actions);
27  my_task_data->max_label = num_actions;
28  my_task_data->num_level = (size_t)ceil(log(num_actions) / log(2));
29  my_task_data->y_allowed.push_back(1);
30  my_task_data->y_allowed.push_back(2);
31  sch.set_task_data(my_task_data);
32 }
void set_options(uint32_t opts)
Definition: search.cc:3053
void set_task_data(T *data)
Definition: search.h:84
void set_num_learners(size_t num_learners)
Definition: search.cc:3094

◆ run()

void MulticlassTask::run ( Search::search & sch,
multi_ex & ec 
)

Definition at line 41 of file search_multiclasstask.cc.

References v_array< T >::begin(), Search::search::get_task_data(), Search::search::loss(), MulticlassTask::task_data::max_label, MulticlassTask::task_data::num_level, Search::search::output(), Search::search::predict(), UINT64_ONE, and MulticlassTask::task_data::y_allowed.

42 {
43  task_data* my_task_data = sch.get_task_data<task_data>();
44  size_t gold_label = ec[0]->l.multi.label;
45  size_t label = 0;
46  size_t learner_id = 0;
47 
48  for (size_t i = 0; i < my_task_data->num_level; i++)
49  {
50  size_t mask = UINT64_ONE << (my_task_data->num_level - i - 1);
51  size_t y_allowed_size = (label + mask + 1 <= my_task_data->max_label) ? 2 : 1;
52  action oracle = (((gold_label - 1) & mask) > 0) + 1;
53  size_t prediction = sch.predict(*ec[0], 0, &oracle, 1, nullptr, nullptr, my_task_data->y_allowed.begin(),
54  y_allowed_size, nullptr, learner_id); // TODO: do we really need y_allowed?
55  learner_id = (learner_id << 1) + prediction;
56  if (prediction == 2)
57  label += mask;
58  }
59  label += 1;
60  sch.loss(!(label == gold_label));
61  if (sch.output().good())
62  sch.output() << label << ' ';
63 }
std::stringstream & output()
Definition: search.cc:3043
uint32_t action
Definition: search.h:19
T * get_task_data()
Definition: search.h:89
action predict(example &ec, ptag my_tag, const action *oracle_actions, size_t oracle_actions_cnt=1, const ptag *condition_on=nullptr, const char *condition_on_names=nullptr, const action *allowed_actions=nullptr, size_t allowed_actions_cnt=0, const float *allowed_actions_cost=nullptr, size_t learner_id=0, float weight=0.)
Definition: search.cc:2967
constexpr uint64_t UINT64_ONE
void loss(float incr_loss)
Definition: search.cc:3039

Variable Documentation

◆ task

Search::search_task MulticlassTask::task = {"multiclasstask", run, initialize, finish, nullptr, nullptr}

Definition at line 10 of file search_multiclasstask.cc.