1 #include "gtest/gtest.h" 2 #include "gmock/gmock.h" 8 using namespace ::testing;
// Test fragment (epsilon-greedy). Only the declaration of a 4-element pdf and
// the final expectation are visible in this listing; the statement that fills
// pdf (original line 13) is not shown.
// Expectation: pdf matches {0.1, 0.1, 0.7, 0.1} element by element, using
// FloatNearPointwise(1e-6f) (matcher not defined in the visible lines;
// presumably an absolute-tolerance comparison -- confirm).
10 TEST(ExploreTestSuite, EpsilonGreedy)
12 std::vector<float> pdf(4);
14 EXPECT_THAT(pdf, Pointwise(FloatNearPointwise(1e-6
f), std::vector<float>{0.1f, 0.1f, 0.7f, 0.1f}));
// Test fragment (epsilon-greedy, top action out of bounds). The call that
// fills pdf (original line 20) is not visible here.
// Expectation: the largest mass (0.7) lands on the last of the four slots and
// the remaining slots hold 0.1 each, within FloatNearPointwise(1e-6f).
// NOTE(review): the test name suggests an out-of-range top_action is mapped
// to the last index -- confirm against the missing call.
17 TEST(ExploreTestSuite, EpsilonGreedyTopActionOutOfBounds)
19 std::vector<float> pdf(4);
21 EXPECT_THAT(pdf, Pointwise(FloatNearPointwise(1e-6
f), std::vector<float>{0.1f, 0.1f, 0.1f, 0.7f}));
// Test fragment (epsilon-greedy, bad range). Starts from an empty pdf; the
// statements on original lines 27-29 are not visible in this listing.
// The only visible check is that pdf.size() matches 0, i.e. the vector is
// still empty at the end of the test.
24 TEST(ExploreTestSuite, EpsilonGreedy_bad_range)
26 std::vector<float> pdf;
30 EXPECT_THAT(pdf.size(), 0);
// Test fragment (softmax). Inputs: four scores {1, 2, 3, 8} and a 4-element
// pdf. The call that fills pdf (original line 37) is not visible here.
// Expectation: pdf matches {0.128, 0.157, 0.192, 0.522} element by element,
// using the looser FloatNearPointwise(1e-3f).
33 TEST(ExploreTestSuite, Softmax)
35 std::vector<float> scores = {1, 2, 3, 8};
36 std::vector<float> pdf(4);
38 EXPECT_THAT(pdf, Pointwise(FloatNearPointwise(1e-3
f), std::vector<float>{0.128f, 0.157f, 0.192f, 0.522f}));
// Test fragment (softmax, fewer scores than pdf slots): three scores but a
// 4-element pdf. The call that fills pdf (original line 45) is not visible.
// Expectation: the first three slots hold {0.269, 0.328, 0.401} and the
// fourth slot is 0, within FloatNearPointwise(1e-3f).
41 TEST(ExploreTestSuite, SoftmaxInBalanced)
43 std::vector<float> scores = {1, 2, 3};
44 std::vector<float> pdf(4);
46 EXPECT_THAT(pdf, Pointwise(FloatNearPointwise(1e-3
f), std::vector<float>{0.269f, 0.328f, 0.401f, 0}));
// Test fragment (softmax, more scores than pdf slots): five scores but a
// 4-element pdf. The call that fills pdf (original line 53) is not visible.
// Expectation: identical to the four-score Softmax case,
// {0.128, 0.157, 0.192, 0.522} within FloatNearPointwise(1e-3f).
// NOTE(review): this implies the fifth score (4) is ignored -- confirm
// against the missing call.
49 TEST(ExploreTestSuite, SoftmaxInBalanced2)
51 std::vector<float> scores = {1, 2, 3, 8, 4};
52 std::vector<float> pdf(4);
54 EXPECT_THAT(pdf, Pointwise(FloatNearPointwise(1e-3
f), std::vector<float>{0.128f, 0.157f, 0.192f, 0.522f}));
// Test fragment (softmax, bad range). Both scores and pdf start empty.
// Original lines 61-65, which would hold the call and its checks, are not
// visible in this listing, so no expectation can be documented here.
57 TEST(ExploreTestSuite, Softmax_bad_range)
59 std::vector<float> scores;
60 std::vector<float> pdf;
// Test fragment (bag). Inputs: top_actions {0, 0, 1, 1} and a 4-element pdf.
// The call that fills pdf (original line 71) is not visible here.
// Expectation: pdf matches {0, 0, 0.5, 0.5} within FloatNearPointwise(1e-3f).
// NOTE(review): the non-zero mass sits on indices 2 and 3, so top_actions[i]
// reads as a per-action count for action i rather than a list of chosen
// action ids -- confirm against the generate_bag documentation.
67 TEST(ExploreTestSuite, Bag)
69 std::vector<uint16_t> top_actions = {0, 0, 1, 1};
70 std::vector<float> pdf(4);
72 EXPECT_THAT(pdf, Pointwise(FloatNearPointwise(1e-3
f), std::vector<float>{0, 0, 0.5, 0.5f}));
// Test fragment (bag, single entry). top_actions holds one value (10) while
// pdf has four slots. The call that fills pdf (original line 79) is not
// visible here.
// Expectation: all mass on slot 0, {1, 0, 0, 0}, within
// FloatNearPointwise(1e-3f).
75 TEST(ExploreTestSuite, Bag10)
77 std::vector<uint16_t> top_actions = {10};
78 std::vector<float> pdf(4);
80 EXPECT_THAT(pdf, Pointwise(FloatNearPointwise(1e-3
f), std::vector<float>{1.f, 0, 0, 0}));
// Test fragment (bag, no votes). top_actions is empty while pdf has four
// slots. The call that fills pdf (original line 87) is not visible here.
// Expectation: all mass on slot 0, {1, 0, 0, 0}, within
// FloatNearPointwise(1e-3f).
83 TEST(ExploreTestSuite, BagEmpty)
85 std::vector<uint16_t> top_actions;
86 std::vector<float> pdf(4);
88 EXPECT_THAT(pdf, Pointwise(FloatNearPointwise(1e-3
f), std::vector<float>{1.f, 0, 0, 0}));
// Test fragment (bag, bad range). Both top_actions and pdf start empty.
// The remainder of the test (original lines 95 onward) is not visible in
// this listing, so no expectation can be documented here.
91 TEST(ExploreTestSuite, Bag_bad_range)
93 std::vector<uint16_t> top_actions;
94 std::vector<float> pdf;
// NOTE(review): the TEST(...) header for this fragment is missing from the
// listing (embedded numbering jumps from 94 to 104), so these lines do NOT
// belong to Bag_bad_range above. Given the neighbouring tests this is
// presumably the basic enforce_minimum_probability case -- confirm.
// Visible behaviour: pdf starts as {1, 0, 0}; after the call on original
// line 105 (not shown) it must match {0.8, 0.1, 0.1} within
// FloatNearPointwise(1e-3f).
104 std::vector<float> pdf = {1.f, 0, 0};
106 EXPECT_THAT(pdf, Pointwise(FloatNearPointwise(1e-3
f), std::vector<float>{.8f, .1f, .1f}));
// Test fragment (minimum probability, zero entries left alone). pdf starts
// as {0.9, 0.1, 0}; the call on original line 112 is not visible here.
// Expectation: {0.8, 0.2, 0.0} within FloatNearPointwise(1e-3f) -- the
// zero-probability slot stays at zero while the non-zero slots change.
109 TEST(ExploreTestSuite, enforce_minimum_probability_no_zeros)
111 std::vector<float> pdf = {0.9f, 0.1f, 0};
113 EXPECT_THAT(pdf, Pointwise(FloatNearPointwise(1e-3
f), std::vector<float>{.8f, .2f, .0f}));
// Test fragment (minimum probability, uniform outcome). pdf starts as
// {0.9, 0.1, 0, 0}; the call on original line 119 is not visible here.
// Expectation: every one of the four slots ends at 0.25, within
// FloatNearPointwise(1e-3f).
116 TEST(ExploreTestSuite, enforce_minimum_probability_uniform)
118 std::vector<float> pdf = {0.9f, 0.1f, 0, 0};
120 EXPECT_THAT(pdf, Pointwise(FloatNearPointwise(1e-3
f), std::vector<float>{.25f, .25f, .25f, .25f}));
// Test fragment (minimum probability, uniform over non-zero entries only).
// pdf starts as {0.9, 0.1, 0}; the call on original line 126 is not visible.
// Expectation: {0.5, 0.5, 0.0} within FloatNearPointwise(1e-3f) -- the two
// non-zero slots are equalised and the zero slot stays at zero.
123 TEST(ExploreTestSuite, enforce_minimum_probability_uniform_no_zeros)
125 std::vector<float> pdf = {0.9f, 0.1f, 0};
127 EXPECT_THAT(pdf, Pointwise(FloatNearPointwise(1e-3
f), std::vector<float>{.5f, .5f, .0f}));
// Test fragment (minimum probability, bad range). Starts from an empty pdf.
// The remainder of the test (original lines 133 onward) is not visible in
// this listing, so no expectation can be documented here.
130 TEST(ExploreTestSuite, enforce_minimum_probability_bad_range)
132 std::vector<float> pdf;
// Test fragment (sampling). Builds an empirical histogram of sampled indices
// and compares it with the source pdf {0.8, 0.1, 0.1}.
// Not visible in this listing: the declaration of `rep` (repetition count)
// and the statements inside the loop that set chosen_index
// (original lines 142-143 and 146-150).
138 TEST(ExploreTestSuite, sampling)
140 std::vector<float> pdf = {0.8f, 0.1f, 0.1f};
141 std::vector<float> histogram(3);
144 uint32_t chosen_index;
145 for (
size_t i = 0; i < rep; i++)
// Count one hit for the index chosen in this repetition.
151 histogram[chosen_index]++;
// Turn the raw counts into relative frequencies.
153 for (
auto& d : histogram) d /= rep;
// Frequencies must agree with pdf to within 1e-2 per element.
155 EXPECT_THAT(pdf, Pointwise(FloatNearPointwise(1e-2
f), histogram));
// Test fragment (pair iterator). Sorts two parallel arrays (actions and pdf)
// in lock-step through a paired iterator, ordering by an external scores
// table, and checks that both arrays were permuted together.
// Not visible in this listing: the declaration of `n` (original line 165),
// and the definitions of `iter` and `loc` (original lines 175-178).
158 TEST(PairIteratorTestSuite, simple_test)
160 using ActionType = size_t;
162 const int num_actions = 3;
163 ActionType actions[num_actions];
164 float pdf[num_actions];
// Fill pdf with successive values of n (post-incremented on each call).
// The final expectation {1, 2, 0} is consistent with n starting at 0.
166 std::generate(std::begin(pdf), std::end(pdf), [&n]() {
return n++; });
// actions = {0, 1, 2}
167 std::iota(std::begin(actions), std::end(actions), 0);
168 float scores[] = {.4f, .1f, .2f};
// Element and iterator types for the two zipped sequences.
171 using FirstVal = ActionType;
172 using SecondVal = float;
173 using FirstIt = FirstVal*;
174 using SecondIt = SecondVal*;
179 const iter begin_coll(std::begin(actions), std::begin(pdf));
180 const iter end_coll(std::end(actions), std::end(pdf));
// Distance between the paired iterators; not used in the visible lines
// (original line 183 is missing and may check it -- confirm).
181 size_t diff = end_coll - begin_coll;
// Order the pairs by ascending score of the action id held in _val1.
182 std::sort(begin_coll, end_coll, [&scores](
const loc& l,
const loc& r) {
return scores[l._val1] < scores[r._val1]; });
// Scores .1 < .2 < .4 correspond to actions 1, 2, 0; pdf must follow.
184 EXPECT_THAT(actions, ElementsAre(1, 2, 0));
185 EXPECT_THAT(pdf, ElementsAre(1.0
f, 2.0
f, 0.0
f));
// Test fragment (ranked sampling). Repeats a sample-then-swap procedure and
// accumulates, for every position in the ranking, how often each action ends
// up there; the resulting frequency table is compared with ranking_pdf.
// Not visible in this listing: the declaration of `rep`, most of the call
// that ends on original line 206, the call that sets chosen_action_idx, and
// the contents of ranking_pdf (original lines 229-246).
188 TEST(ExploreTestSuite, sampling_rank)
190 std::vector<float> scores = {0.f, 3.f, 1.f};
// Flattened (position x action) table: scores.size() * scores.size() cells.
191 std::vector<float> histogram(scores.size() * scores.size());
192 std::vector<int> ranking(3);
197 for (
size_t i = 0; i < rep; i++)
// A fresh pdf is created in every repetition.
199 std::vector<float> pdf = {0.8f, 0.1f, 0.1f};
// Tail of a call whose beginning is missing; it receives the ranking range.
206 std::begin(ranking), std::end(ranking)));
209 uint32_t chosen_action_idx;
// Move the chosen entry to the front of both ranking and pdf.
214 if (chosen_action_idx != 0)
216 std::iter_swap(std::begin(ranking), std::begin(ranking) + chosen_action_idx);
217 std::iter_swap(std::begin(pdf), std::begin(pdf) + chosen_action_idx);
// Row = position i, column = action id found at that position.
220 for (
size_t i = 0; i < ranking.size(); i++) histogram[i * ranking.size() + ranking[i]]++;
// Turn the raw counts into relative frequencies.
223 for (
auto& d : histogram) d /= rep;
228 std::vector<float> ranking_pdf = {
// Frequencies must agree with ranking_pdf to within 1e-2 per cell.
247 EXPECT_THAT(histogram, Pointwise(FloatNearPointwise(1e-2
f), ranking_pdf));
// Test fragment (ranked sampling, bad range). pdf and scores start empty,
// ranking has three default-initialised ints, and chosen_index is declared
// without an initial value. The calls and checks (original lines 257
// onward) are not visible in this listing.
250 TEST(ExploreTestSuite, sampling_rank_bad_range)
252 std::vector<float> pdf;
253 std::vector<float> scores;
254 std::vector<int> ranking(3);
256 uint32_t chosen_index;
// Test fragment (sampling from an all-zero pdf). pdf starts as {0, 0, 0};
// the call on original lines 269-272 is not visible here.
// Expectation: afterwards pdf equals {1, 0, 0} to within 1e-2 per element,
// i.e. the pdf has been modified in place.
// Note the argument order: expected_pdf is the value under test and pdf
// supplies the reference for the Pointwise matcher.
263 TEST(ExploreTestSuite, sampling_rank_zero_pdf)
265 std::vector<float> pdf = {0.f, 0.f, 0.f};
266 std::vector<float> expected_pdf = {1.f, 0.f, 0.f};
268 uint32_t chosen_index;
273 EXPECT_THAT(expected_pdf, Pointwise(FloatNearPointwise(1e-2
f), pdf));
// Test fragment (sampling from a pdf with negative entries). pdf starts as
// {1, -2, -3}; the call on original lines 282-284 is not visible here.
// Expectations: afterwards pdf equals {1, 0, 0} to within 1e-2 per element
// (negative entries end up as zero) and chosen_index equals 0.
// `ranking` is declared but not referenced in the visible lines.
276 TEST(ExploreTestSuite, sampling_rank_negative_pdf)
278 std::vector<float> pdf = {1.0f, -2.f, -3.f};
279 std::vector<float> expected_pdf = {1.f, 0.f, 0.f};
280 std::vector<int> ranking(3);
281 uint32_t chosen_index;
285 EXPECT_THAT(expected_pdf, Pointwise(FloatNearPointwise(1e-2
f), pdf));
// Value and matcher are swapped relative to the usual order; the check is
// still equality between 0 and chosen_index.
286 EXPECT_THAT(0, chosen_index);
int generate_bag(InputIt top_actions_first, InputIt top_actions_last, OutputIt pdf_first, OutputIt pdf_last)
Generates an exploration distribution according to votes on actions.
int sample_after_normalizing(uint64_t seed, It pdf_first, It pdf_last, uint32_t &chosen_index)
Sample an index from the provided pdf. If the pdf is not normalized it will be updated in-place...
Vowpal Wabbit slim predictor. Supports: regression, multi-class classification and contextual bandits...
int generate_softmax(float lambda, InputIt scores_first, InputIt scores_last, OutputIt pdf_first, OutputIt pdf_last)
Generates softmax style exploration distribution.
TEST(ExploreTestSuite, EpsilonGreedy)
int generate_epsilon_greedy(float epsilon, uint32_t top_action, It pdf_first, It pdf_last)
Generates epsilon-greedy style exploration distribution.
#define E_EXPLORATION_BAD_RANGE
int enforce_minimum_probability(float minimum_uniform, bool update_zero_elements, It pdf_first, It pdf_last)
Updates the pdf to ensure each action is explored with at least minimum_uniform/num_actions.