Vowpal Wabbit
Namespaces | Functions
cb_adf.h File Reference
#include <vector>
#include "reductions_fwd.h"

Go to the source code of this file.

Namespaces

 CB_ADF
 

Functions

LEARNER::base_learner * cb_adf_setup (VW::config::options_i &options, vw &all)
 
CB::cb_class CB_ADF::get_observed_cost (multi_ex &examples)
 
void CB_ADF::global_print_newline (const v_array< int > &final_prediction_sink)
 
example * CB_ADF::test_adf_sequence (multi_ex &ec_seq)
 

Function Documentation

◆ cb_adf_setup()

LEARNER::base_learner* cb_adf_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 481 of file cb_adf.cc.

References prediction_type::action_scores, VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::as_multiline(), label_type::cb, CB::cb_label, CB_TYPE_DM, CB_TYPE_DR, CB_TYPE_IPS, CB_TYPE_MTR, CB_TYPE_SM, ACTION_SCORE::delete_action_scores(), vw::delete_prediction, f, CB_ADF::finish_multiline_example(), LEARNER::init_learner(), VW::config::options_i::insert(), vw::label_type, CB_ADF::learn(), parser::lp, LEARNER::make_base(), VW::config::make_option(), vw::model_file_ver, vw::p, CB_ADF::predict(), CB_ADF::save_load(), vw::scorer, vw::sd, LEARNER::learner< T, E >::set_finish_example(), LEARNER::learner< T, E >::set_save_load(), CB_ADF::cb_adf::set_scorer(), setup_base(), vw::trace_message, and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

482 {
483  bool cb_adf_option = false;
484  std::string type_string = "mtr";
485 
486  size_t cb_type;
487  bool rank_all;
488  float clip_p;
489  bool no_predict;
490 
491  option_group_definition new_options("Contextual Bandit with Action Dependent Features");
492  new_options
493  .add(make_option("cb_adf", cb_adf_option)
494  .keep()
495  .help("Do Contextual Bandit learning with multiline action dependent features."))
496  .add(make_option("rank_all", rank_all).keep().help("Return actions sorted by score order"))
497  .add(make_option("no_predict", no_predict).help("Do not do a prediction when training"))
498  .add(make_option("clip_p", clip_p)
499  .keep()
500  .default_value(0.f)
501  .help("Clipping probability in importance weight. Default: 0.f (no clipping)."))
502  .add(make_option("cb_type", type_string)
503  .keep()
504  .help("contextual bandit method to use in {ips, dm, dr, mtr, sm}. Default: mtr"));
505  options.add_and_parse(new_options);
506 
507  if (!cb_adf_option)
508  return nullptr;
509 
510  // Ensure serialization of this option in all cases.
511  if (!options.was_supplied("cb_type"))
512  {
513  options.insert("cb_type", type_string);
514  options.add_and_parse(new_options);
515  }
516 
517  // number of weight vectors needed
518  size_t problem_multiplier = 1; // default for IPS
519  bool check_baseline_enabled = false;
520 
521  if (type_string == "dr")
522  {
523  cb_type = CB_TYPE_DR;
524  problem_multiplier = 2;
525  // only use baseline when manually enabled for loss estimation
526  check_baseline_enabled = true;
527  }
528  else if (type_string == "ips")
529  cb_type = CB_TYPE_IPS;
530  else if (type_string == "mtr")
531  cb_type = CB_TYPE_MTR;
532  else if (type_string == "dm")
533  cb_type = CB_TYPE_DM;
534  else if (type_string == "sm")
535  cb_type = CB_TYPE_SM;
536  else
537  {
538  all.trace_message << "warning: cb_type must be in {'ips','dr','mtr','dm','sm'}; resetting to mtr." << std::endl;
539  cb_type = CB_TYPE_MTR;
540  }
541 
542  if (clip_p > 0.f && cb_type == CB_TYPE_SM)
543  all.trace_message << "warning: clipping probability not yet implemented for cb_type sm; p will not be clipped."
544  << std::endl;
545 
547 
548  // Push necessary flags.
549  if ((!options.was_supplied("csoaa_ldf") && !options.was_supplied("wap_ldf")) || rank_all ||
550  !options.was_supplied("csoaa_rank"))
551  {
552  if (!options.was_supplied("csoaa_ldf"))
553  {
554  options.insert("csoaa_ldf", "multiline");
555  }
556 
557  if (!options.was_supplied("csoaa_rank"))
558  {
559  options.insert("csoaa_rank", "");
560  }
561  }
562 
563  if (options.was_supplied("baseline") && check_baseline_enabled)
564  {
565  options.insert("check_enabled", "");
566  }
567 
568  auto ld = scoped_calloc_or_throw<cb_adf>(all.sd, cb_type, &all.model_file_ver, rank_all, clip_p, no_predict);
569 
570  auto base = as_multiline(setup_base(options, all));
571  all.p->lp = CB::cb_label;
573 
574  cb_adf* bare = ld.get();
576  init_learner(ld, base, learn, predict, problem_multiplier, prediction_type::action_scores);
578 
579  bare->set_scorer(all.scorer);
580 
582  return make_base(l);
583 }
#define CB_TYPE_IPS
Definition: cb_algs.h:15
void(* delete_prediction)(void *)
Definition: global_data.h:485
label_type::label_type_t label_type
Definition: global_data.h:550
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
void finish_multiline_example(vw &all, cb_adf &data, multi_ex &ec_seq)
Definition: cb_adf.cc:451
#define CB_TYPE_DM
Definition: cb_algs.h:14
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
#define CB_TYPE_DR
Definition: cb_algs.h:13
parser * p
Definition: global_data.h:377
void set_scorer(LEARNER::single_learner *scorer)
Definition: cb_adf.cc:66
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void delete_action_scores(void *v)
Definition: action_score.cc:29
shared_data * sd
Definition: global_data.h:375
VW::version_struct model_file_ver
Definition: global_data.h:419
vw_ostream trace_message
Definition: global_data.h:424
virtual bool was_supplied(const std::string &key)=0
LEARNER::single_learner * scorer
Definition: global_data.h:384
virtual void insert(const std::string &key, const std::string &value)=0
label_parser cb_label
Definition: cb.cc:167
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void save_load(cb_adf &c, io_buf &model_file, bool read, bool text)
Definition: cb_adf.cc:461
void predict(cb_adf &c, multi_learner &base, multi_ex &ec_seq)
Definition: cb_adf.cc:477
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define CB_TYPE_SM
Definition: cb_algs.h:17
float f
Definition: cache.cc:40
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
void learn(cb_adf &c, multi_learner &base, multi_ex &ec_seq)
Definition: cb_adf.cc:475
label_parser lp
Definition: parser.h:102
#define CB_TYPE_MTR
Definition: cb_algs.h:16