554 uint32_t num_actions = 0;
555 auto data = scoped_calloc_or_throw<warm_cb>();
563 .help(
"Convert multiclass on <k> classes into a contextual bandit problem"))
565 .help(
"consume cost-sensitive classification examples instead of multiclass"))
566 .add(
make_option(
"loss0", data->loss0).default_value(0.f).help(
"loss for correct label"))
567 .
add(
make_option(
"loss1", data->loss1).default_value(1.
f).help(
"loss for incorrect label"))
570 .help(
"number of training examples for warm start phase"))
571 .
add(
make_option(
"epsilon", data->epsilon).keep().help(
"epsilon-greedy exploration"))
573 .default_value(UINT32_MAX)
574 .help(
"number of examples for the interactive contextual bandit learning phase"))
575 .
add(
make_option(
"warm_start_update", data->upd_ws).help(
"indicator of warm start updates"))
576 .
add(
make_option(
"interaction_update", data->upd_inter).help(
"indicator of interaction updates"))
579 .help(
"type of label corruption in the warm start phase (1: uniformly at random, 2: circular, 3: " 580 "replacing with overwriting label)"))
583 .help(
"probability of label corruption in the warm start phase"))
586 .help(
"the number of candidate lambdas to aggregate (lambda is the importance weight parameter between " 590 .help(
"The scheme for generating candidate lambda set (1: center lambda=0.5, 2: center lambda=0.5, min " 591 "lambda=0, max lambda=1, 3: center lambda=epsilon/(1+epsilon), 4: center " 592 "lambda=epsilon/(1+epsilon), min lambda=0, max lambda=1); the rest of candidate lambda values are " 593 "generated using a doubling scheme"))
596 .help(
"the label used by type 3 corruptions (overwriting)"))
598 .help(
"simulate contextual bandit updates on warm start examples"));
604 THROW(
"label corruption on cost-sensitive examples not currently supported");
613 data->a_s = v_init<action_score>();
616 data->use_cs = use_cs;
625 std::stringstream ss;
626 ss << std::max(std::abs(data->loss0), std::abs(data->loss1)) / (data->loss1 - data->loss0);
627 options.
insert(
"lr_multiplier", ss.str());
639 std::cerr <<
"Warning: no epsilon (greedy parameter) specified; resetting to 0.05" << std::endl;
640 data->epsilon = 0.05f;
645 data, base, predict_or_learn_adf<true, true>, predict_or_learn_adf<false, true>, all.
p, data->choices_lambda);
648 data, base, predict_or_learn_adf<true, false>, predict_or_learn_adf<false, false>, all.
p, data->choices_lambda);
void(* delete_prediction)(void *)
learner< T, E > & init_cost_sensitive_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
base_learner * make_base(learner< T, E > &base)
void finish(warm_cb &data)
virtual void add_and_parse(const option_group_definition &group)=0
std::shared_ptr< rand_state > get_random_state()
virtual bool was_supplied(const std::string &key)=0
virtual void insert(const std::string &key, const std::string &value)=0
int add(svm_params ¶ms, svm_example *fec)
typed_option< T > make_option(std::string name, T &location)
void set_finish(void(*f)(T &))
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
LEARNER::base_learner * setup_base(options_i &options, vw &all)
void init_adf_data(warm_cb &data, const uint32_t num_actions)
multi_learner * as_multiline(learner< T, E > *l)
const char * to_string(prediction_type_t prediction_type)