Vowpal Wabbit
search_multiclasstask.cc
Go to the documentation of this file.
1 /*
2  Copyright (c) by respective owners including Yahoo!, Microsoft, and
3  individual contributors. All rights reserved. Released under a BSD (revised)
4  license as described in the file LICENSE.
5 */
7 
8 namespace MulticlassTask
9 {
10 Search::search_task task = {"multiclasstask", run, initialize, finish, nullptr, nullptr};
11 }
12 
13 namespace MulticlassTask
14 {
15 struct task_data
16 {
17  size_t max_label;
18  size_t num_level;
20 };
21 
22 void initialize(Search::search& sch, size_t& num_actions, VW::config::options_i& /*vm*/)
23 {
24  task_data* my_task_data = new task_data();
25  sch.set_options(0);
26  sch.set_num_learners(num_actions);
27  my_task_data->max_label = num_actions;
28  my_task_data->num_level = (size_t)ceil(log(num_actions) / log(2));
29  my_task_data->y_allowed.push_back(1);
30  my_task_data->y_allowed.push_back(2);
31  sch.set_task_data(my_task_data);
32 }
33 
35 {
36  task_data* my_task_data = sch.get_task_data<task_data>();
37  my_task_data->y_allowed.delete_v();
38  delete my_task_data;
39 }
40 
41 void run(Search::search& sch, multi_ex& ec)
42 {
43  task_data* my_task_data = sch.get_task_data<task_data>();
44  size_t gold_label = ec[0]->l.multi.label;
45  size_t label = 0;
46  size_t learner_id = 0;
47 
48  for (size_t i = 0; i < my_task_data->num_level; i++)
49  {
50  size_t mask = UINT64_ONE << (my_task_data->num_level - i - 1);
51  size_t y_allowed_size = (label + mask + 1 <= my_task_data->max_label) ? 2 : 1;
52  action oracle = (((gold_label - 1) & mask) > 0) + 1;
53  size_t prediction = sch.predict(*ec[0], 0, &oracle, 1, nullptr, nullptr, my_task_data->y_allowed.begin(),
54  y_allowed_size, nullptr, learner_id); // TODO: do we really need y_allowed?
55  learner_id = (learner_id << 1) + prediction;
56  if (prediction == 2)
57  label += mask;
58  }
59  label += 1;
60  sch.loss(!(label == gold_label));
61  if (sch.output().good())
62  sch.output() << label << ' ';
63 }
64 } // namespace MulticlassTask
void run(Search::search &sch, multi_ex &ec)
std::stringstream & output()
Definition: search.cc:3043
uint32_t action
Definition: search.h:19
T *& begin()
Definition: v_array.h:42
T * get_task_data()
Definition: search.h:89
void push_back(const T &new_ele)
Definition: v_array.h:107
void set_options(uint32_t opts)
Definition: search.cc:3053
action predict(example &ec, ptag my_tag, const action *oracle_actions, size_t oracle_actions_cnt=1, const ptag *condition_on=nullptr, const char *condition_on_names=nullptr, const action *allowed_actions=nullptr, size_t allowed_actions_cnt=0, const float *allowed_actions_cost=nullptr, size_t learner_id=0, float weight=0.)
Definition: search.cc:2967
std::vector< example * > multi_ex
Definition: example.h:122
void set_task_data(T *data)
Definition: search.h:84
constexpr uint64_t UINT64_ONE
Search::search_task task
void delete_v()
Definition: v_array.h:98
void initialize(Search::search &sch, size_t &num_actions, VW::config::options_i &)
void finish(Search::search &sch)
void set_num_learners(size_t num_learners)
Definition: search.cc:3094
void loss(float incr_loss)
Definition: search.cc:3039