55 template <
bool is_learn>
61 : _sd(sd), _model_file_ver(model_file_ver), _no_predict(no_predict), _rank_all(rank_all), _clip_p(clip_p)
77 for (
auto& prepped_cs_label : _prepped_cs_labels) prepped_cs_label.costs.delete_v();
78 _prepped_cs_labels.delete_v();
79 _cs_labels.
costs.delete_v();
95 template <
bool predict>
102 ld.
costs = v_init<cb_class>();
109 if (ec->l.cb.costs.size() == 1 && ec->l.cb.costs[0].cost != FLT_MAX && ec->l.cb.costs[0].probability > 0)
126 known_cost = ld.
costs[0];
127 known_cost.
action = index;
134 call_cs_ldf<true>(base, examples, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
140 call_cs_ldf<false>(base, examples, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
152 for (uint32_t i = 0; i < examples[0]->pred.a_s.size(); i++)
154 _a_s.push_back({examples[0]->pred.a_s[i].action, examples[0]->pred.a_s[i].score});
155 _prob_s.push_back({examples[0]->pred.a_s[i].action, 0.0});
158 float sign_offset = 1.0;
159 uint32_t chosen_action = 0;
160 float example_weight = 1.0;
162 for (uint32_t i = 0; i < examples.size(); i++)
165 if (ld.
costs.size() == 1 && ld.
costs[0].cost != FLT_MAX)
172 if (ld.
costs[0].cost < 0.0)
175 example_weight = -example_weight;
198 _backup_weights.clear();
203 _backup_weights.push_back(examples[current_action]->
weight);
204 _backup_nf.push_back((uint32_t)examples[current_action]->num_features);
206 if (current_action == chosen_action)
211 if (examples[current_action]->
weight <= 1e-15)
212 examples[current_action]->weight = 0;
216 call_cs_ldf<true>(base, examples, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
219 for (
size_t i = 0; i < _prob_s.size(); i++)
221 uint32_t current_action = _prob_s[i].
action;
222 examples[current_action]->weight = _backup_weights[i];
223 examples[current_action]->num_features = _backup_nf[i];
229 gen_cs_example_dr<true>(_gen_cs, examples, _cs_labels, _clip_p);
230 call_cs_ldf<true>(base, examples, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
236 call_cs_ldf<true>(base, examples, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
239 template <
bool predict>
246 call_cs_ldf<false>(base, examples, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
247 std::swap(examples[0]->pred.a_s, _a_s);
253 uint32_t nf = (uint32_t)examples[_gen_cs.mtr_example]->num_features;
254 float old_weight = examples[_gen_cs.mtr_example]->weight;
255 const float clipped_p = std::max(examples[_gen_cs.mtr_example]->l.cb.costs[0].probability, _clip_p);
256 examples[_gen_cs.mtr_example]->weight *= 1.f / clipped_p * ((float)_gen_cs.event_sum / (
float)_gen_cs.action_sum);
258 std::swap(_gen_cs.mtr_ec_seq[0]->pred.a_s, _a_s_mtr_cs);
260 GEN_CS::call_cs_ldf<true>(base, _gen_cs.mtr_ec_seq, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
261 examples[_gen_cs.mtr_example]->num_features = nf;
262 examples[_gen_cs.mtr_example]->weight = old_weight;
263 std::swap(_gen_cs.mtr_ec_seq[0]->pred.a_s, _a_s_mtr_cs);
264 std::swap(examples[0]->pred.a_s, _a_s);
271 THROW(
"cb_adf: At least one action must be provided for an example to be valid.");
275 for (
auto* ec : ec_seq)
278 if (ec->l.cb.costs.size() > 1)
279 THROW(
"cb_adf: badly formatted example, only one cost can be known.");
282 if (ec->l.cb.costs.size() == 1 && ec->l.cb.costs[0].cost != FLT_MAX)
287 THROW(
"cb_adf: badly formatted example, only one line can have a cost");
294 template <
bool is_learn>
297 _offset = ec_seq[0]->ft_offset;
306 switch (_gen_cs.cb_type)
309 learn_IPS(base, ec_seq);
312 learn_DR(base, ec_seq);
315 learn_DM(base, ec_seq);
319 learn_MTR<false>(base, ec_seq);
321 learn_MTR<true>(base, ec_seq);
324 learn_SM(base, ec_seq);
327 THROW(
"Unknown cb_type specified for contextual bandit learning: " << _gen_cs.cb_type);
338 call_cs_ldf<false>(base, ec_seq, _cb_labels, _cs_labels, _prepped_cs_labels, _offset);
346 for (
auto f : final_prediction_sink)
351 std::cerr <<
"write error: " << strerror(errno) << std::endl;
359 size_t num_features = 0;
366 bool labeled_example =
true;
367 if (_gen_cs.known_cost.probability > 0)
370 labeled_example =
false;
372 bool holdout_example = labeled_example;
373 for (
auto const& i : *ec_seq) holdout_example &= i->test_only;
375 _sd->update(holdout_example, labeled_example, loss, ec.
weight, num_features);
376 return labeled_example;
391 std::string outputString;
392 std::stringstream outputStringStream(outputString);
395 for (
size_t i = 0; i < costs.
size(); i++)
398 outputStringStream <<
' ';
399 outputStringStream << costs[i].action <<
':' << costs[i].partial_prediction;
421 std::string outputString;
422 std::stringstream outputStringStream(outputString);
423 for (
size_t i = 0; i < costs.
size(); i++)
426 outputStringStream <<
' ';
427 outputStringStream << costs[i].action <<
':' << costs[i].partial_prediction;
465 std::stringstream msg;
483 bool cb_adf_option =
false;
484 std::string type_string =
"mtr";
495 .help(
"Do Contextual Bandit learning with multiline action dependent features."))
496 .
add(
make_option(
"rank_all", rank_all).keep().help(
"Return actions sorted by score order"))
497 .
add(
make_option(
"no_predict", no_predict).help(
"Do not do a prediction when training"))
501 .help(
"Clipping probability in importance weight. Default: 0.f (no clipping)."))
504 .help(
"contextual bandit method to use in {ips, dm, dr, mtr, sm}. Default: mtr"));
513 options.
insert(
"cb_type", type_string);
518 size_t problem_multiplier = 1;
519 bool check_baseline_enabled =
false;
521 if (type_string ==
"dr")
524 problem_multiplier = 2;
526 check_baseline_enabled =
true;
528 else if (type_string ==
"ips")
530 else if (type_string ==
"mtr")
532 else if (type_string ==
"dm")
534 else if (type_string ==
"sm")
538 all.
trace_message <<
"warning: cb_type must be in {'ips','dr','mtr','dm','sm'}; resetting to mtr." << std::endl;
543 all.
trace_message <<
"warning: clipping probability not yet implemented for cb_type sm; p will not be clipped." 554 options.
insert(
"csoaa_ldf",
"multiline");
559 options.
insert(
"csoaa_rank",
"");
563 if (options.
was_supplied(
"baseline") && check_baseline_enabled)
565 options.
insert(
"check_enabled",
"");
568 auto ld = scoped_calloc_or_throw<cb_adf>(all.
sd, cb_type, &all.
model_file_ver, rank_all, clip_p, no_predict);
VW::version_struct * _model_file_ver
void gen_cs_example_sm(multi_ex &, uint32_t chosen_action, float sign_offset, ACTION_SCORE::action_scores action_vals, COST_SENSITIVE::label &cs_labels)
v_array< float > _backup_weights
ACTION_SCORE::action_scores a_s
#define VERSION_FILE_WITH_CB_ADF_SAVE
void(* delete_prediction)(void *)
static ssize_t write_file_or_socket(int f, const void *buf, size_t nbytes)
void output_example(vw &all, cb_adf &c, example &ec, multi_ex *ec_seq)
void output_example_seq(vw &all, multi_ex &ec_seq)
void finish_multiline_example(vw &all, cbify &, multi_ex &ec_seq)
COST_SENSITIVE::label _cs_labels
v_array< uint32_t > _backup_nf
label_type::label_type_t label_type
v_array< int > final_prediction_sink
v_array< cb_class > costs
base_learner * make_base(learner< T, E > &base)
example * test_adf_sequence(multi_ex &ec_seq)
void finish_multiline_example(vw &all, cb_adf &data, multi_ex &ec_seq)
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
int generate_softmax(float lambda, InputIt scores_first, InputIt scores_last, OutputIt pdf_first, OutputIt pdf_last)
Generates softmax style exploration distribution.
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
bool get_rank_all() const
float get_cost_estimate(CB::cb_class *observation, uint32_t action, float offset=0.)
void gen_cs_example_dm(multi_ex &examples, COST_SENSITIVE::label &cs_labels)
score_iterator begin_scores(action_scores &a_s)
const cb_to_cs_adf & get_gen_cs() const
CB::cb_class get_observed_cost(multi_ex &examples)
score_iterator end_scores(action_scores &a_s)
void do_actual_learning(ldf &data, single_learner &base, multi_ex &ec_seq_all)
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
void set_scorer(LEARNER::single_learner *scorer)
base_learner * cb_adf_setup(options_i &options, vw &all)
void set_finish_example(void(*f)(vw &all, T &, E &))
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
void delete_action_scores(void *v)
VW::version_struct model_file_ver
virtual bool was_supplied(const std::string &key)=0
void global_print_newline(const v_array< int > &final_prediction_sink)
void output_rank_example(vw &all, cb_adf &c, example &ec, multi_ex *ec_seq)
v_array< CB::label > _cb_labels
void(* print_text)(int, std::string, v_array< char >)
void gen_cs_example_ips(multi_ex &examples, COST_SENSITIVE::label &cs_labels, float clip_p)
void finish_example(vw &, example &)
LEARNER::single_learner * scorer
virtual void insert(const std::string &key, const std::string &value)=0
action_scores _a_s_mtr_cs
option_group_definition & add(T &&op)
std::vector< example * > multi_ex
float safe_probability(float prob)
typed_option< T > make_option(std::string name, T &location)
bool example_is_newline_not_header(example &ec, vw &all)
void gen_cs_example_mtr(cb_to_cs_adf &c, multi_ex &ec_seq, COST_SENSITIVE::label &cs_labels)
const VW::version_struct * get_model_file_ver() const
bool update_statistics(example &ec, multi_ex *ec_seq)
void save_load(cb_adf &c, io_buf &model_file, bool read, bool text)
void predict(cb_adf &c, multi_learner &base, multi_ex &ec_seq)
LEARNER::single_learner * scorer
LEARNER::base_learner * setup_base(options_i &options, vw &all)
v_array< COST_SENSITIVE::label > _prepped_cs_labels
cb_adf(shared_data *sd, size_t cb_type, VW::version_struct *model_file_ver, bool rank_all, float clip_p, bool no_predict)
void print_action_score(int f, v_array< action_score > &a_s, v_array< char > &tag)
size_t bin_text_read_write_fixed(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
void gen_cs_test_example(multi_ex &examples, COST_SENSITIVE::label &cs_labels)
void(* print)(int, float, float, v_array< char >)
multi_learner * as_multiline(learner< T, E > *l)
COST_SENSITIVE::label pred_scores
void learn(cb_adf &c, multi_learner &base, multi_ex &ec_seq)
void do_actual_learning(LEARNER::multi_learner &base, multi_ex &ec_seq)