Vowpal Wabbit
cb_explore_adf_first.cc
Go to the documentation of this file.
1 #include "cb_explore_adf_first.h"
2 
3 #include "reductions.h"
4 #include "cb_adf.h"
5 #include "rand48.h"
6 #include "bs.h"
7 #include "gen_cs_example.h"
8 #include "cb_explore.h"
9 #include "explore.h"
10 #include "cb_explore_adf_common.h"
11 #include <vector>
12 #include <algorithm>
13 #include <cmath>
14 
15 namespace VW
16 {
17 namespace cb_explore_adf
18 {
19 namespace first
20 {
// Body of the tau-first exploration type (named cb_explore_adf_first by its constructor and
// destructor below).
// NOTE(review): the embedded listing numbers jump from 20 to 22, so the line that opens this
// type (its struct/class header) is missing from this extract.
22 {
23  private:
// Count of exploration rounds still to run: predict_or_learn_impl emits uniform scores while
// this is non-zero and decrements it only on learn calls.
24  size_t _tau;
// Stored by the constructor. Its use is not visible in this listing -- presumably consumed on
// the missing listing line 67 of predict_or_learn_impl; TODO confirm.
25  float _epsilon;
26 
27  public:
// Stores tau and epsilon in the two members above.
28  cb_explore_adf_first(size_t tau, float epsilon);
29  ~cb_explore_adf_first() = default;
30 
31  // Should be called through cb_explore_adf_base for pre/post-processing
// Both entry points forward to predict_or_learn_impl; the template argument selects learn (true)
// versus predict (false).
32  void predict(LEARNER::multi_learner& base, multi_ex& examples) { predict_or_learn_impl<false>(base, examples); }
33  void learn(LEARNER::multi_learner& base, multi_ex& examples) { predict_or_learn_impl<true>(base, examples); }
34 
35  private:
// NOTE(review): listing line 37 is missing here; it presumably declares
// predict_or_learn_impl(LEARNER::multi_learner&, multi_ex&), which the template header below
// would otherwise be left without -- confirm against the real file.
36  template <bool is_learn>
38 };
39 
// Constructor: copies tau into _tau and epsilon into _epsilon via the member-initialiser list;
// the body is empty.
40 cb_explore_adf_first::cb_explore_adf_first(size_t tau, float epsilon) : _tau(tau), _epsilon(epsilon) {}
41 
// Shared implementation behind predict() and learn(): runs the base learner on the multi-line
// example, then overwrites the action scores in examples[0]->pred.a_s with either a uniform
// distribution (while _tau is non-zero) or all mass on the first entry.
// NOTE(review): listing line 43 (the function signature) is missing from this extract; the body
// uses parameters named `base` and `examples`.
42 template <bool is_learn>
44 {
45  // Explore tau times, then act according to optimal.
// The base learner is invoked first in both modes, using the feature offset of the first example.
// NOTE(review): examples[0] is dereferenced with no visible emptiness check -- confirm callers
// never pass an empty multi_ex.
46  if (is_learn)
47  LEARNER::multiline_learn_or_predict<true>(base, examples, examples[0]->ft_offset);
48  else
49  LEARNER::multiline_learn_or_predict<false>(base, examples, examples[0]->ft_offset);
50 
// Scores are edited in place through this reference.
51  v_array<ACTION_SCORE::action_score>& preds = examples[0]->pred.a_s;
52  uint32_t num_actions = (uint32_t)preds.size();
53 
54  if (_tau)
55  {
// Exploration phase: every action gets score 1 / num_actions.
56  float prob = 1.f / (float)num_actions;
57  for (size_t i = 0; i < num_actions; i++) preds[i].score = prob;
// Only learn calls consume an exploration round; predict calls leave _tau unchanged.
58  if (is_learn)
59  _tau--;
60  }
61  else
62  {
// Exploitation phase: zero every entry after the first and give the first entry score 1.0.
// Presumably preds[0] is the base learner's top-ranked action -- TODO confirm ordering.
// NOTE(review): preds[0] is written even when num_actions is 0.
63  for (size_t i = 1; i < num_actions; i++) preds[i].score = 0.;
64  preds[0].score = 1.0;
65  }
66 
// NOTE(review): listing line 67 is missing here; presumably this is where _epsilon is applied to
// the scores (e.g. via exploration::enforce_minimum_probability) -- confirm against the real file.
68 }
69 
// Reduction setup for tau-first exploration: registers the command-line options, returns nullptr
// unless both --cb_explore_adf and --first were given, and otherwise builds the learner data and
// returns the learner wrapped by make_base.
// NOTE(review): listing line 70 (the function signature) is missing from this extract; the body
// uses parameters named `options` and `all`.
71 {
72  using config::make_option;
73  bool cb_explore_adf_option = false;
74  size_t tau = 0;
75  float epsilon = 0.;
76  config::option_group_definition new_options("Contextual Bandit Exploration with Action Dependent Features");
77  new_options
78  .add(make_option("cb_explore_adf", cb_explore_adf_option)
79  .keep()
80  .help("Online explore-exploit for a contextual bandit problem with multiline action dependent features"))
81  .add(make_option("first", tau).keep().help("tau-first exploration"))
82  .add(make_option("epsilon", epsilon).keep().help("epsilon-greedy exploration"));
83  options.add_and_parse(new_options);
84 
// This reduction is only enabled when cb_explore_adf is set AND "first" was explicitly supplied.
85  if (!cb_explore_adf_option || !options.was_supplied("first"))
86  return nullptr;
87 
88  // Ensure serialization of cb_adf in all cases.
89  if (!options.was_supplied("cb_adf"))
90  {
91  options.insert("cb_adf", "");
92  }
93 
// NOTE(review): listing line 94 is missing here.
95 
// Declared as 1; its use is not visible in this listing (presumably on one of the missing lines).
96  size_t problem_multiplier = 1;
97 
// NOTE(review): listing line 98 is missing here.
// Installs the CB label parser on the global parser object.
99  all.p->lp = CB::cb_label;
// NOTE(review): listing line 100 is missing here.
101 
// The learner data is cb_explore_adf_base wrapping cb_explore_adf_first, constructed with the
// parsed tau and epsilon.
102  using explore_type = cb_explore_adf_base<cb_explore_adf_first>;
103  auto data = scoped_calloc_or_throw<explore_type>(tau, epsilon);
104 
// NOTE(review): listing lines 105-106 are missing here. The learner `l` used below is not
// declared anywhere in the visible lines, so it is presumably created there (e.g. via
// init_learner) -- confirm against the real file.
107 
108  l.set_finish_example(explore_type::finish_multiline_example);
109  return make_base(l);
110 }
111 } // namespace first
112 } // namespace cb_explore_adf
113 } // namespace VW
void(* delete_prediction)(void *)
Definition: global_data.h:485
void finish_multiline_example(vw &all, cbify &, multi_ex &ec_seq)
Definition: cbify.cc:373
void predict(LEARNER::multi_learner &base, multi_ex &examples)
label_type::label_type_t label_type
Definition: global_data.h:550
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
void predict_or_learn_impl(LEARNER::multi_learner &base, multi_ex &examples)
size_t size() const
Definition: v_array.h:68
score_iterator begin_scores(action_scores &a_s)
Definition: action_score.h:43
parser * p
Definition: global_data.h:377
score_iterator end_scores(action_scores &a_s)
Definition: action_score.h:45
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void delete_action_scores(void *v)
Definition: action_score.cc:29
virtual bool was_supplied(const std::string &key)=0
LEARNER::base_learner * setup(config::options_i &options, vw &all)
void learn(LEARNER::multi_learner &base, multi_ex &examples)
int enforce_minimum_probability(float minimum_uniform, bool update_zero_elements, It pdf_first, It pdf_last)
Updates the pdf to ensure each action is explored with at least minimum_uniform/num_actions.
virtual void insert(const std::string &key, const std::string &value)=0
option_group_definition & add(T &&op)
Definition: options.h:90
std::vector< example * > multi_ex
Definition: example.h:122
label_parser cb_label
Definition: cb.cc:167
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
Definition: autolink.cc:11
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
void predict(bfgs &b, base_learner &, example &ec)
Definition: bfgs.cc:956
void learn(bfgs &b, base_learner &base, example &ec)
Definition: bfgs.cc:965
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
label_parser lp
Definition: parser.h:102