Vowpal Wabbit
Public Member Functions | Private Attributes | List of all members
VW::cb_sample_data Struct Reference

Public Member Functions

 cb_sample_data (std::shared_ptr< rand_state > &random_state)
 
 cb_sample_data (std::shared_ptr< rand_state > &&random_state)
 
template<bool is_learn>
void learn_or_predict (multi_learner &base, multi_ex &examples)
 

Private Attributes

std::shared_ptr< rand_state > _random_state
 

Detailed Description

Definition at line 14 of file cb_sample.cc.

Constructor & Destructor Documentation

◆ cb_sample_data() [1/2]

VW::cb_sample_data::cb_sample_data ( std::shared_ptr< rand_state > &  random_state)
inline explicit

Definition at line 16 of file cb_sample.cc.

16 : _random_state(random_state) {}
std::shared_ptr< rand_state > _random_state
Definition: cb_sample.cc:87

◆ cb_sample_data() [2/2]

VW::cb_sample_data::cb_sample_data ( std::shared_ptr< rand_state > &&  random_state)
inline explicit

Definition at line 17 of file cb_sample.cc.

17 : _random_state(random_state) {}
std::shared_ptr< rand_state > _random_state
Definition: cb_sample.cc:87

Member Function Documentation

◆ learn_or_predict()

template<bool is_learn>
void VW::cb_sample_data::learn_or_predict ( multi_learner &  base,
multi_ex &  examples 
)
inline

Definition at line 20 of file cb_sample.cc.

References _UNUSED, substring::begin, ACTION_SCORE::begin_scores(), ACTION_SCORE::end_scores(), S_EXPLORATION_OK, exploration::sample_after_normalizing(), substring_len(), exploration::swap_chosen(), and uniform_hash().

Referenced by learn_or_predict().

21  {
22  multiline_learn_or_predict<is_learn>(base, examples, examples[0]->ft_offset);
23 
24  auto action_scores = examples[0]->pred.a_s;
25  uint32_t chosen_action = -1;
26 
27  int labelled_action = -1;
28  // Find that chosen action in the learning case, skip the shared example.
29  auto it = std::find_if(examples.begin(), examples.end(), [](example *item) { return !item->l.cb.costs.empty(); });
30  if (it != examples.end())
31  {
32  labelled_action = std::distance(examples.begin(), it);
33  }
34 
35  // If we are learning and have a label, then take that action as the chosen action. Otherwise sample the
36  // distribution.
37  if (is_learn && labelled_action != -1)
38  {
39  // Find where the labelled action is in the final prediction to determine if swapping needs to occur.
40  // This only matters if the prediction decided to explore, but the same output should happen for the learn case.
41  for (size_t i = 0; i < action_scores.size(); i++)
42  {
43  auto &a_s = action_scores[i];
44  if (a_s.action == static_cast<uint32_t>(labelled_action))
45  {
46  chosen_action = static_cast<uint32_t>(i);
47  break;
48  }
49  }
50  }
51  else
52  {
53  bool tag_provided_seed = false;
54  uint64_t seed = _random_state->get_current_state();
55  if (!examples[0]->tag.empty())
56  {
57  const std::string SEED_IDENTIFIER = "seed=";
58  if (strncmp(examples[0]->tag.begin(), SEED_IDENTIFIER.c_str(), SEED_IDENTIFIER.size()) == 0 &&
59  examples[0]->tag.size() > SEED_IDENTIFIER.size())
60  {
61  substring tag_seed{examples[0]->tag.begin() + 5, examples[0]->tag.begin() + examples[0]->tag.size()};
62  seed = uniform_hash(tag_seed.begin, substring_len(tag_seed), 0);
63  tag_provided_seed = true;
64  }
65  }
66 
67  // Sampling is done after the base learner has generated a pdf.
68  auto result = exploration::sample_after_normalizing(
69  seed, ACTION_SCORE::begin_scores(action_scores), ACTION_SCORE::end_scores(action_scores), chosen_action);
70  assert(result == S_EXPLORATION_OK);
71  _UNUSED(result);
72 
73  // Update the seed state in place if it was used for this example.
74  if (!tag_provided_seed)
75  {
76  _random_state->get_and_update_random();
77  }
78  }
79 
80  auto result = exploration::swap_chosen(action_scores.begin(), action_scores.end(), chosen_action);
81  assert(result == S_EXPLORATION_OK);
82 
83  _UNUSED(result);
84  }
size_t substring_len(substring &s)
int sample_after_normalizing(uint64_t seed, It pdf_first, It pdf_last, uint32_t &chosen_index)
Sample an index from the provided pdf. If the pdf is not normalized it will be updated in-place...
char * begin
Definition: hashstring.h:9
v_array< action_score > action_scores
Definition: action_score.h:10
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
Definition: hash.h:67
score_iterator begin_scores(action_scores &a_s)
Definition: action_score.h:43
score_iterator end_scores(action_scores &a_s)
Definition: action_score.h:45
std::shared_ptr< rand_state > _random_state
Definition: cb_sample.cc:87
int swap_chosen(ActionIt action_first, ActionIt action_last, uint32_t chosen_index)
Swap the first value with the chosen index.
#define S_EXPLORATION_OK
Definition: explore.h:3
#define _UNUSED(x)
Definition: vw_exception.h:244

Member Data Documentation

◆ _random_state

std::shared_ptr<rand_state> VW::cb_sample_data::_random_state
private

Definition at line 87 of file cb_sample.cc.


The documentation for this struct was generated from the following file: