15 #include <unordered_set> 38 size_t read_count = 0;
41 size_t next_read_size =
sizeof(ld->
type);
42 if (cache.
buf_read(read_ptr, next_read_size) < next_read_size)
45 read_count +=
sizeof(ld->
type);
47 bool is_outcome_present;
48 next_read_size =
sizeof(bool);
49 if (cache.
buf_read(read_ptr, next_read_size) < next_read_size)
51 is_outcome_present = *(
bool*)read_ptr;
52 read_count +=
sizeof(is_outcome_present);
54 if (is_outcome_present)
60 if (cache.
buf_read(read_ptr, next_read_size) < next_read_size)
66 next_read_size =
sizeof(size_probs);
67 if (cache.
buf_read(read_ptr, next_read_size) < next_read_size)
69 size_probs = *(uint32_t*)read_ptr;
70 read_count +=
sizeof(size_probs);
72 for (uint32_t i = 0; i < size_probs; i++)
75 next_read_size =
sizeof(a_s);
76 if (cache.
buf_read(read_ptr, next_read_size) < next_read_size)
79 read_count +=
sizeof(a_s);
85 uint32_t size_includes;
86 next_read_size =
sizeof(size_includes);
87 if (cache.
buf_read(read_ptr, next_read_size) < next_read_size)
89 size_includes = *(uint32_t*)read_ptr;
90 read_count +=
sizeof(size_includes);
92 for (uint32_t i = 0; i < size_includes; i++)
95 next_read_size =
sizeof(include);
96 if (cache.
buf_read(read_ptr, next_read_size) < next_read_size)
98 include = *(uint32_t*)read_ptr;
99 read_count +=
sizeof(include);
103 next_read_size =
sizeof(ld->
weight);
104 if (cache.
buf_read(read_ptr, next_read_size) < next_read_size)
106 ld->
weight = *(
float*)read_ptr;
120 size_t size =
sizeof(uint8_t)
131 *(uint8_t*)c = static_cast<uint8_t>(ld->
type);
132 c +=
sizeof(ld->
type);
134 *(
bool*)c = ld->
outcome !=
nullptr;
143 c +=
sizeof(uint32_t);
153 c +=
sizeof(uint32_t);
157 *(uint32_t*)c = included_action;
158 c +=
sizeof(included_action);
223 if (std::isnan(probability))
224 THROW(
"error NaN probability: " << probability_str);
226 if (probability > 1.0)
228 std::cerr <<
"invalid probability > 1 specified for an outcome, resetting to 1.\n";
231 if (probability < 0.0)
233 std::cerr <<
"invalid probability < 0 specified for an outcome, resetting to 0.\n";
237 return {action_id, probability};
245 auto split_commas = v_init<substring>();
246 tokenize(
',', outcome, split_commas);
248 auto split_colons = v_init<substring>();
249 tokenize(
':', split_commas[0], split_colons);
251 if (split_colons.size() != 3)
252 THROW(
"Malformed ccb label");
254 ccb_outcome.probabilities = v_init<ACTION_SCORE::action_score>();
255 ccb_outcome.probabilities.push_back(
convert_to_score(split_colons[0], split_colons[2]));
258 if (std::isnan(ccb_outcome.cost))
259 THROW(
"error NaN cost: " << split_colons[1]);
261 split_colons.clear();
263 for (
size_t i = 1; i < split_commas.size(); i++)
265 tokenize(
':', split_commas[i], split_colons);
266 if (split_colons.size() != 2)
267 THROW(
"Must be action probability pairs");
268 ccb_outcome.probabilities.push_back(
convert_to_score(split_colons[0], split_colons[1]));
271 split_colons.delete_v();
272 split_commas.delete_v();
279 for (
const auto& inclusion : split_inclusions)
290 if (words.
size() < 2)
291 THROW(
"ccb labels may not be empty");
294 THROW(
"ccb labels require the first word to be ccb");
297 auto type = words[1];
300 if (words.
size() > 2)
301 THROW(
"shared labels may not have a cost");
306 if (words.
size() > 2)
307 THROW(
"action labels may not have a cost");
312 if (words.
size() > 4)
313 THROW(
"ccb slot label can only have a type cost and exclude list");
317 for (
size_t i = 2; i < words.
size(); i++)
319 auto is_outcome =
std::find(words[i].begin, words[i].end,
':');
320 if (is_outcome != words[i].end)
324 THROW(
"There may be only 1 outcome associated with a slot.")
341 return result_so_far + action_pred.score;
345 if (total_pred > 1.1
f || total_pred < 0.9
f)
347 THROW(
"When providing all predicition probabilties they must add up to 1.f");
353 THROW(
"unknown label type: " << type);
void delete_label(void *v)
int int_of_substring(substring s)
void parse_label(parser *p, shared_data *, void *v, v_array< substring > &words)
void accumulate(vw &all, parameters &weights, size_t offset)
void copy_array(v_array< T > &dst, const v_array< T > &src)
CCB::conditional_contextual_bandit_outcome * parse_outcome(substring &outcome)
label_parser ccb_label_parser
v_array< substring > parse_name
ACTION_SCORE::action_scores probabilities
void parse_explicit_inclusions(CCB::label *ld, v_array< substring > &split_inclusions)
size_t read_cached_label(shared_data *, void *v, io_buf &cache)
void copy_label(void *dst, void *src)
float float_of_substring(substring s)
void push_back(const T &new_ele)
void tokenize(char delim, substring s, ContainerT &ret, bool allow_empty=false)
void default_label(void *v)
void buf_write(char *&pointer, size_t n)
void cache_label(void *v, io_buf &cache)
ACTION_SCORE::action_score convert_to_score(const substring &action_id_str, const substring &probability_str)
node_pred * find(recall_tree &b, uint32_t cn, example &ec)
v_array< uint32_t > explicit_included_actions
bool substring_equal(const substring &a, const substring &b)
uint32_t convert(size_t number)
float ccb_weight(void *v)
conditional_contextual_bandit_outcome * outcome
size_t buf_read(char *&pointer, size_t n)