28 my_task_data->
num_level = (size_t)ceil(log(num_actions) / log(2));
44 size_t gold_label = ec[0]->l.multi.label;
46 size_t learner_id = 0;
48 for (
size_t i = 0; i < my_task_data->
num_level; i++)
51 size_t y_allowed_size = (label + mask + 1 <= my_task_data->
max_label) ? 2 : 1;
52 action oracle = (((gold_label - 1) & mask) > 0) + 1;
53 size_t prediction = sch.
predict(*ec[0], 0, &oracle, 1,
nullptr,
nullptr, my_task_data->
y_allowed.
begin(),
54 y_allowed_size,
nullptr, learner_id);
55 learner_id = (learner_id << 1) + prediction;
60 sch.
loss(!(label == gold_label));
62 sch.
output() << label <<
' ';
void run(Search::search &sch, multi_ex &ec)
std::stringstream & output()
void push_back(const T &new_ele)
void set_options(uint32_t opts)
action predict(example &ec, ptag my_tag, const action *oracle_actions, size_t oracle_actions_cnt=1, const ptag *condition_on=nullptr, const char *condition_on_names=nullptr, const action *allowed_actions=nullptr, size_t allowed_actions_cnt=0, const float *allowed_actions_cost=nullptr, size_t learner_id=0, float weight=0.)
v_array< uint32_t > y_allowed
std::vector< example * > multi_ex
void set_task_data(T *data)
constexpr uint64_t UINT64_ONE
void initialize(Search::search &sch, size_t &num_actions, VW::config::options_i &)
void finish(Search::search &sch)
void set_num_learners(size_t num_learners)
void loss(float incr_loss)