Vowpal Wabbit
cb_sample.cc
Go to the documentation of this file.
1 #include "reductions.h"
2 #include "cb_sample.h"
3 #include "explore.h"
4 
5 #include "rand48.h"
6 
7 using namespace LEARNER;
8 using namespace VW;
9 using namespace VW::config;
10 
11 namespace VW
12 {
13 // cb_sample is used to automatically sample and swap from a cb explore pdf.
15 {
16  explicit cb_sample_data(std::shared_ptr<rand_state> &random_state) : _random_state(random_state) {}
17  explicit cb_sample_data(std::shared_ptr<rand_state> &&random_state) : _random_state(random_state) {}
18 
19  template <bool is_learn>
20  inline void learn_or_predict(multi_learner &base, multi_ex &examples)
21  {
22  multiline_learn_or_predict<is_learn>(base, examples, examples[0]->ft_offset);
23 
24  auto action_scores = examples[0]->pred.a_s;
25  uint32_t chosen_action = -1;
26 
27  int labelled_action = -1;
28  // Find that chosen action in the learning case, skip the shared example.
29  auto it = std::find_if(examples.begin(), examples.end(), [](example *item) { return !item->l.cb.costs.empty(); });
30  if (it != examples.end())
31  {
32  labelled_action = std::distance(examples.begin(), it);
33  }
34 
35  // If we are learning and have a label, then take that action as the chosen action. Otherwise sample the
36  // distribution.
37  if (is_learn && labelled_action != -1)
38  {
39  // Find where the labelled action is in the final prediction to determine if swapping needs to occur.
40  // This only matters if the prediction decided to explore, but the same output should happen for the learn case.
41  for (size_t i = 0; i < action_scores.size(); i++)
42  {
43  auto &a_s = action_scores[i];
44  if (a_s.action == static_cast<uint32_t>(labelled_action))
45  {
46  chosen_action = static_cast<uint32_t>(i);
47  break;
48  }
49  }
50  }
51  else
52  {
53  bool tag_provided_seed = false;
54  uint64_t seed = _random_state->get_current_state();
55  if (!examples[0]->tag.empty())
56  {
57  const std::string SEED_IDENTIFIER = "seed=";
58  if (strncmp(examples[0]->tag.begin(), SEED_IDENTIFIER.c_str(), SEED_IDENTIFIER.size()) == 0 &&
59  examples[0]->tag.size() > SEED_IDENTIFIER.size())
60  {
61  substring tag_seed{examples[0]->tag.begin() + 5, examples[0]->tag.begin() + examples[0]->tag.size()};
62  seed = uniform_hash(tag_seed.begin, substring_len(tag_seed), 0);
63  tag_provided_seed = true;
64  }
65  }
66 
67  // Sampling is done after the base learner has generated a pdf.
69  seed, ACTION_SCORE::begin_scores(action_scores), ACTION_SCORE::end_scores(action_scores), chosen_action);
70  assert(result == S_EXPLORATION_OK);
71  _UNUSED(result);
72 
73  // Update the seed state in place if it was used for this example.
74  if (!tag_provided_seed)
75  {
76  _random_state->get_and_update_random();
77  }
78  }
79 
80  auto result = exploration::swap_chosen(action_scores.begin(), action_scores.end(), chosen_action);
81  assert(result == S_EXPLORATION_OK);
82 
83  _UNUSED(result);
84  }
85 
86  private:
87  std::shared_ptr<rand_state> _random_state;
88 };
89 } // namespace VW
90 
91 template <bool is_learn>
93 {
94  data.learn_or_predict<is_learn>(base, examples);
95 }
96 
98 {
99  bool cb_sample_option = false;
100 
101  option_group_definition new_options("CB Sample");
102  new_options.add(make_option("cb_sample", cb_sample_option).keep().help("Sample from CB pdf and swap top action."));
103  options.add_and_parse(new_options);
104 
105  if (!cb_sample_option)
106  return nullptr;
107 
108  if (options.was_supplied("no_predict"))
109  {
110  THROW("cb_sample cannot be used with no_predict, as there would be no predictions to sample.");
111  }
112 
113  auto data = scoped_calloc_or_throw<cb_sample_data>(all.get_random_state());
114  return make_base(init_learner(data, as_multiline(setup_base(options, all)), learn_or_predict<true>,
115  learn_or_predict<false>, 1 /* weights */, prediction_type::action_probs));
116 }
size_t substring_len(substring &s)
int sample_after_normalizing(uint64_t seed, It pdf_first, It pdf_last, uint32_t &chosen_index)
Sample an index from the provided pdf. If the pdf is not normalized it will be updated in-place...
cb_sample_data(std::shared_ptr< rand_state > &random_state)
Definition: cb_sample.cc:16
char * begin
Definition: hashstring.h:9
v_array< action_score > action_scores
Definition: action_score.h:10
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
Definition: hash.h:67
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
void learn_or_predict(cb_sample_data &data, multi_learner &base, multi_ex &examples)
Definition: cb_sample.cc:92
void learn_or_predict(multi_learner &base, multi_ex &examples)
Definition: cb_sample.cc:20
score_iterator begin_scores(action_scores &a_s)
Definition: action_score.h:43
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
score_iterator end_scores(action_scores &a_s)
Definition: action_score.h:45
cb_sample_data(std::shared_ptr< rand_state > &&random_state)
Definition: cb_sample.cc:17
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
std::shared_ptr< rand_state > _random_state
Definition: cb_sample.cc:87
int swap_chosen(ActionIt action_first, ActionIt action_last, uint32_t chosen_index)
Swap the first value with the chosen index.
virtual bool was_supplied(const std::string &key)=0
base_learner * cb_sample_setup(options_i &options, vw &all)
Definition: cb_sample.cc:97
option_group_definition & add(T &&op)
Definition: options.h:90
std::vector< example * > multi_ex
Definition: example.h:122
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
Definition: autolink.cc:11
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define S_EXPLORATION_OK
Definition: explore.h:3
#define THROW(args)
Definition: vw_exception.h:181
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
#define _UNUSED(x)
Definition: vw_exception.h:244