99 os << x.
a <<
':' << x.
s;
109 if (_repr !=
nullptr)
127 : min_cost(_min_cost), k(_k), is_opt(_is_opt), cost(_cost)
133 os << x.
k <<
':' << x.
cost;
295 search::search() { priv = &calloc_or_throw<search_private>(); }
299 if (this->priv && this->priv->all)
339 if (ar.
repr !=
nullptr)
343 cdbg <<
"delete_v" << endl;
386 std::cerr <<
"internal error (bug): no valid policies to choose from! defaulting to current" << endl;
390 int num_valid_policies = (int)priv.
current_policy + allow_optimal + allow_current;
393 if (num_valid_policies == 0)
395 std::cerr <<
"internal error (bug): no valid policies to choose from! defaulting to current" << endl;
398 else if (num_valid_policies == 1)
400 else if (num_valid_policies == 2)
411 while ((r > 0) && (pid < num_valid_policies - 1))
419 if (allow_optimal && (pid == num_valid_policies - 1))
422 pid = (int)priv.current_policy - pid;
454 if (PRINT_UPDATE_EVERY_EXAMPLE)
456 if (PRINT_UPDATE_EVERY_PASS && hit_new_pass)
466 if (PRINT_UPDATE_EVERY_EXAMPLE)
468 if (PRINT_UPDATE_EVERY_PASS)
501 for (
size_t i = 0; i < max_len; i++)
502 out[i] = ((i >= in.length()) || (in[i] ==
'\n') || (in[i] ==
'\t')) ?
' ' : in[i];
504 if (in.length() > max_len)
506 out[max_len - 2] =
'.';
507 out[max_len - 1] =
'.';
514 std::stringstream ss;
515 if (big > 9999999999)
516 ss << big / 1000000000 <<
"g";
517 else if (big > 9999999)
518 ss << big / 1000000 <<
"m";
520 ss << big / 1000 <<
"k";
532 const char* header_fmt =
"%-10s %-10s %8s%24s %22s %5s %5s %7s %7s %7s %-8s\n";
533 fprintf(stderr, header_fmt,
"average",
"since",
"instance",
"current true",
"current predicted",
"cur",
"cur",
534 "predic",
"cache",
"examples",
"");
536 fprintf(stderr, header_fmt,
"loss",
"last",
"counter",
"output prefix",
"output prefix",
"pass",
"pol",
"made",
537 "hits",
"gener",
"#run");
539 fprintf(stderr, header_fmt,
"loss",
"last",
"counter",
"output prefix",
"output prefix",
"pass",
"pol",
"made",
540 "hits",
"gener",
"beta");
541 std::cerr.precision(5);
554 float avg_loss_since = 0.;
556 if (use_heldout_loss)
577 fprintf(stderr,
"%-10.6f %-10.6f %8s [%s] [%s] %5d %5d %7s %7s %7s %-8f", avg_loss, avg_loss_since,
579 total_pred.c_str(), total_cach.c_str(), total_exge.c_str(),
582 if (PRINT_CLOCK_TIME)
584 size_t num_sec = (size_t)(((
float)(clock() - priv.
start_clock_time)) / CLOCKS_PER_SEC);
585 std::cerr <<
" " << num_sec <<
"sec";
588 if (use_heldout_loss)
589 fprintf(stderr,
" h");
591 fprintf(stderr,
"\n");
601 uint64_t idx2 = ((idx & mask) >> ss) & mask;
607 std::stringstream temp;
637 for (
size_t n = 0; n < ec_seq.size(); n++)
659 if ((offset < 0) && (n < (uint64_t)(-offset)))
661 else if (n + offset >= ec_seq.size())
665 example& other = *ec_seq[n + offset];
671 size_t sz = fs.size();
672 if ((sz > 0) && (fs.sum_feat_sq > 0.))
707 static constexpr
float log_of_2 = (
float)0.6931471805599453;
708 priv.
beta = (x <= log_of_2) ? -expm1f(-x) : (1 - expf(-x));
716 if (ar.
repr !=
nullptr)
745 cdbg <<
"priv.learn_loss += " << loss <<
" (now = " << priv.
learn_loss <<
")" << endl;
756 cdbg << str <<
" = [";
757 for (
size_t i = 0; i < A.
size(); i++)
cdbg <<
" " << A[i];
758 cdbg <<
" ]" << endl;
763 std::cerr << str <<
" = [";
764 for (
size_t i = 0; i < A.
size(); i++) std::cerr <<
" " << A[i];
765 std::cerr <<
" ]" << endl;
768 size_t random(std::shared_ptr<rand_state>& rs,
size_t max)
770 return (
size_t)(rs->get_and_update_random() * (float)max);
777 for (
size_t i = 0; i < n; i++)
785 const char* condition_on_names,
action_repr* condition_on_actions)
787 if (condition_on_cnt == 0)
790 uint64_t extra_offset = 0;
793 extra_offset = 3849017 * ec.
l.
cs.
costs[0].class_index;
795 size_t I = condition_on_cnt;
797 for (
size_t i = 0; i < I; i++)
799 uint64_t
fid = 71933 + 8491087 * extra_offset;
807 for (
size_t n = 0; n < N; n++)
812 uint64_t name = condition_on_names[i + n];
813 fid = fid * 328901 + 71933 * ((condition_on_actions[i + n].
a + 349101) * (name + 38490137));
824 if ((33 <= name) && (name <= 126))
836 GD::foreach_feature<search_private, uint64_t, add_new_feature>(*priv.
all, ec, priv);
842 cdbg <<
"BEGIN adding passthrough features" << endl;
843 for (
size_t i = 0; i < I; i++)
845 if (condition_on_actions[i].repr ==
nullptr)
848 char name = condition_on_names[i];
849 for (
size_t k = 0; k < fs.
size(); k++)
852 uint64_t
fid = 84913 + 48371803 * (extra_offset + 8392817 * name) + 840137 * (4891 + fs.
indicies[k]);
867 cdbg <<
"END adding passthrough features" << endl;
896 return isCB ? ld.
cb.
costs[k].partial_prediction : ld.
cs.
costs[k].partial_prediction;
938 size_t allowed_actions_cnt,
const float* allowed_actions_cost)
946 if (num_costs > ec_cnt)
948 else if (num_costs < ec_cnt)
954 if (allowed_actions ==
nullptr)
966 for (
action k = 0; k < allowed_actions_cnt; k++)
972 if ((allowed_actions ==
nullptr) || (allowed_actions_cnt == 0))
974 if (num_costs != priv.
A)
984 for (
size_t i = 0; i < allowed_actions_cnt; i++)
cs_cost_push_back(isCB, ld, allowed_actions[i], FLT_MAX);
992 size_t allowed_actions_cnt,
const float* allowed_actions_cost,
const action* oracle_actions,
993 size_t oracle_actions_cnt,
polylabel& lab)
999 for (
action k = 0; k < ec_cnt; k++)
1000 cs_cost_push_back(isCB, lab, k, array_contains<action>(k, oracle_actions, oracle_actions_cnt) ? 0.
f : 1.
f);
1008 if (allowed_actions ==
nullptr)
1020 for (
action k = 0; k < allowed_actions_cnt; k++)
1026 if ((allowed_actions ==
nullptr) || (allowed_actions_cnt == 0))
1028 bool set_to_one =
false;
1038 if (oracle_actions_cnt <= 1)
1042 if (oracle_actions_cnt == 1)
1047 for (
action k = 0; k < priv.
A; k++)
1048 cs_set_cost_loss(isCB, lab, k, array_contains<action>(k + 1, oracle_actions, oracle_actions_cnt) ? 0.
f : 1.
f);
1055 for (
size_t i = 0; i < allowed_actions_cnt; i++)
1057 action k = allowed_actions[i];
1059 isCB, lab, k, (array_contains<action>(k, oracle_actions, oracle_actions_cnt)) ? 0.
f : w);
1077 v.
begin()[pos] = item;
1083 memset(v.
end(), 0,
sizeof(T) * (pos - v.
size()));
1084 v.
begin()[pos] = item;
1091 v.
begin()[pos] = item;
1098 size_t oracle_actions_cnt,
const action* allowed_actions,
size_t allowed_actions_cnt,
1099 const float* allowed_actions_cost)
1104 size_t K = (allowed_actions ==
nullptr) ? priv.
A : allowed_actions_cnt;
1105 cdbg <<
"costs = [";
1106 for (
size_t k = 0; k < K; k++)
cdbg <<
' ' << allowed_actions_cost[k];
1107 cdbg <<
" ]" << endl;
1108 float min_cost = FLT_MAX;
1109 for (
size_t k = 0; k < K; k++) min_cost = std::min(min_cost, allowed_actions_cost[k]);
1110 cdbg <<
"min_cost = " << min_cost;
1111 if (min_cost < FLT_MAX)
1114 for (
size_t k = 0; k < K; k++)
1115 if (allowed_actions_cost[k] <= min_cost)
1117 cdbg <<
", hit @ " << k;
1119 if ((count == 1) || (priv.
_random_state->get_and_update_random() < 1. / (float)count))
1121 a = (allowed_actions ==
nullptr) ? (uint32_t)(k + 1) : allowed_actions[k];
1133 oracle_actions_cnt = 0;
1134 a = (oracle_actions_cnt > 0)
1136 : (allowed_actions_cnt > 0) ? allowed_actions[
random(priv.
_random_state, allowed_actions_cnt)]
1140 cdbg <<
"choose_oracle_action from oracle_actions = [";
1141 for (
size_t i = 0; i < oracle_actions_cnt; i++)
cdbg <<
" " << oracle_actions[i];
1142 cdbg <<
" ], ret=" << a << endl;
1146 *this_cache = v_init<action_cache>();
1150 for (
size_t k = 0; k < K; k++)
1153 float cost =
array_contains(cl, oracle_actions, oracle_actions_cnt) ? 0.f : 1.f;
1158 cdbg <<
"memo_foreach_action[" << priv.
meta_t + priv.
t - 1 <<
"] = " << this_cache <<
" from oracle" << endl;
1164 size_t allowed_actions_cnt,
const float* allowed_actions_cost,
float& a_cost,
1168 vw& all = *priv.
all;
1172 if ((allowed_actions_cnt > 0) || need_partial_predictions)
1177 cdbg <<
"allowed_actions_cnt=" << allowed_actions_cnt <<
", ec.l = [";
1178 for (
size_t i = 0; i < ec.
l.
cs.
costs.size(); i++)
1180 cdbg <<
" ]" << endl;
1185 cdbg <<
"a=" << act <<
" from";
1186 if (allowed_actions)
1188 for (
size_t ii = 0; ii < allowed_actions_cnt; ii++)
cdbg <<
' ' << allowed_actions[ii];
1192 cdbg <<
"a_cost = " << a_cost << endl;
1194 if (override_action != (
action)-1)
1195 act = override_action;
1197 if (need_partial_predictions)
1200 float min_cost = FLT_MAX;
1201 for (
size_t k = 0; k < K; k++)
1204 if (cost < min_cost)
1211 *this_cache = v_init<action_cache>();
1213 for (
size_t k = 0; k < K; k++)
1219 if (override_action == cl)
1228 cdbg <<
"memo_foreach_action[" << priv.
meta_t + priv.
t - 1 <<
"] = " << this_cache << endl;
1235 float min_cost = FLT_MAX, min_cost2 = FLT_MAX;
1236 for (
size_t k = 0; k < K; k++)
1239 if (cost < min_cost)
1241 min_cost2 = min_cost;
1244 else if (cost < min_cost2)
1249 if (min_cost2 < FLT_MAX)
1255 THROW(
"cannot use active_csoaa with cb learning");
1256 size_t cur_t = priv.
t + priv.
meta_t - 1;
1265 for (
size_t k = 0; k < ec.
l.
cs.
costs.size(); k++)
1276 std::pair<CS::wclass&, bool> p = {wc, query_needed};
1322 float best_prediction = 0.;
1328 if (need_partial_predictions)
1331 *this_cache = v_init<action_cache>();
1334 for (
action a = (uint32_t)start_K;
a < ec_cnt;
a++)
1336 cdbg <<
"== single_prediction_LDF a=" <<
a <<
"==" << endl;
1346 tmp.push_back(&ecs[a]);
1352 if (override_action != (
action)-1)
1354 if (a == override_action)
1357 else if ((a == start_K) || (ecs[a].partial_prediction < best_prediction))
1361 a_cost = best_prediction;
1367 ecs[
a].
l = old_label;
1371 if (override_action != (
action)-1)
1372 best_action = override_action;
1374 a_cost = best_prediction;
1378 for (
size_t i = 0; i < this_cache->
size(); i++)
1382 ac.
is_opt = (ac.
k == best_action);
1425 THROW(
"internal error (bug): trying to rollin or rollout with NO_ROLLOUT");
1435 return memcmp(A, B, sz_A) == 0;
1439 const char* condition_on_names,
action_repr* condition_on_actions,
size_t condition_on_cnt,
int policy,
1440 size_t learner_id,
action&
a,
bool do_store,
float& a_cost)
1447 size_t sz =
sizeof(size_t) +
sizeof(
ptag) +
sizeof(int) +
sizeof(
size_t) +
sizeof(size_t) +
1448 condition_on_cnt * (
sizeof(
ptag) +
sizeof(
action) +
sizeof(
char));
1452 unsigned char* item = calloc_or_throw<unsigned char>(sz);
1453 unsigned char* here = item;
1454 *here = (
unsigned char)sz;
1455 here +=
sizeof(size_t);
1457 here +=
sizeof(
ptag);
1459 here +=
sizeof(int);
1460 *here = (
unsigned char)learner_id;
1461 here +=
sizeof(size_t);
1462 *here = (
unsigned char)condition_on_cnt;
1463 here +=
sizeof(size_t);
1464 for (
size_t i = 0; i < condition_on_cnt; i++)
1466 *here = condition_on[i];
1467 here +=
sizeof(
ptag);
1468 *here = condition_on_actions[i].
a;
1470 *here = condition_on_names[i];
1471 here +=
sizeof(char);
1491 float min_loss = FLT_MAX)
1498 if (min_loss == FLT_MAX)
1499 for (
size_t i = 0; i < losses.
cb.
costs.size(); i++) min_loss = std::min(min_loss, losses.
cb.
costs[i].cost);
1500 for (
size_t i = 0; i < losses.
cb.
costs.size(); i++) losses.
cb.
costs[i].cost = losses.
cb.
costs[i].cost - min_loss;
1504 if (min_loss == FLT_MAX)
1505 for (
size_t i = 0; i < losses.
cs.
costs.size(); i++) min_loss = std::min(min_loss, losses.
cs.
costs[i].x);
1506 for (
size_t i = 0; i < losses.
cs.
costs.size(); i++)
1523 if (add_conditioning)
1526 for (
size_t is_local = 0; is_local <= (size_t)priv.
xv; is_local++)
1530 cdbg <<
"BEGIN base_learner->learn(ec, " << learner <<
")" << endl;
1532 cdbg <<
"END base_learner->learn(ec, " << learner <<
")" << endl;
1534 if (add_conditioning)
1545 if (add_conditioning)
1553 for (
size_t is_local = 0; is_local <= (size_t)priv.
xv; is_local++)
1560 uint64_t tmp_offset = 0;
1567 if (lab.
costs.size() == 0)
1569 CS::wclass wc = {0.,
a - (uint32_t)start_K, 0., 0.};
1570 lab.
costs.push_back(wc);
1578 cdbg <<
"generate_training_example called learn on action a=" <<
a <<
", costs.size=" << lab.
costs.size()
1579 <<
" ec=" << &ec << endl;
1592 if (add_conditioning)
1636 cdbg <<
"foreach_action_from_cache: t=" << t <<
", memo_foreach_action.size()=" << priv.
memo_foreach_action.size()
1637 <<
", override_a=" << override_a << endl;
1642 cdbg <<
"memo_foreach_action size = " << cached->size() << endl;
1643 for (
size_t id = 0;
id < cached->size();
id++)
1653 size_t oracle_actions_cnt,
const ptag* condition_on,
const char* condition_on_names,
const action* allowed_actions,
1654 size_t allowed_actions_cnt,
const float* allowed_actions_cost,
size_t learner_id,
float& a_cost,
float )
1656 size_t condition_on_cnt = condition_on_names ? strlen(condition_on_names) : 0;
1657 size_t t = priv.
t + priv.
meta_t;
1661 assert((oracle_actions ==
nullptr) == (oracle_actions_cnt == 0));
1662 assert((condition_on ==
nullptr) == (condition_on_names ==
nullptr));
1663 assert(((allowed_actions ==
nullptr) && (allowed_actions_cost ==
nullptr)) == (allowed_actions_cnt == 0));
1665 if (allowed_actions_cost !=
nullptr)
1666 assert(oracle_actions ==
nullptr);
1672 priv, ec_cnt, oracle_actions, oracle_actions_cnt, allowed_actions, allowed_actions_cnt, allowed_actions_cost);
1685 cdbg <<
"LEARN " << t <<
" < priv.learn_t ==> a=" << a <<
", a_cost=" << a_cost << endl;
1694 size_t valid_action_cnt = priv.
is_ldf ? ec_cnt : (allowed_actions_cnt > 0) ? allowed_actions_cnt : priv.
A;
1705 cdbg <<
"LEARN " << t <<
" = priv.learn_t ==> a=" << a <<
", learn_a_idx=" << priv.
learn_a_idx 1706 <<
" valid_action_cnt=" << valid_action_cnt << endl;
1716 if (oracle_actions_cnt > 0)
1727 for (
size_t i = 0; i < ec_cnt; i++)
1744 for (
size_t i = 0; i < condition_on_cnt; i++)
1751 if (condition_on_names ==
nullptr)
1763 if (allowed_actions && (allowed_actions_cnt > 0))
1771 assert((allowed_actions_cnt == 0) || (a < allowed_actions_cnt));
1774 action a_name = (allowed_actions && (allowed_actions_cnt > 0)) ? allowed_actions[
a] : priv.
is_ldf ?
a : (a + 1);
1780 cdbg <<
"@ memo_foreach_action: t=" << t <<
", a=" << a <<
", cost=" << (*priv.
memo_foreach_action[t])[a].cost
1796 cdbg <<
"... skipping" << endl;
1797 action a = priv.
is_ldf ? 0 : ((allowed_actions && (allowed_actions_cnt > 0)) ? allowed_actions[0] : 1);
1813 cdbg <<
"executing policy " << policy << endl;
1824 cdbg <<
"maybe_override_prediction --> " << skip <<
", a=" << a <<
", a_cost=" << a_cost << endl;
1829 if ((!skip) && (policy == -1))
1830 a =
choose_oracle_action(priv, ec_cnt, oracle_actions, oracle_actions_cnt, allowed_actions, allowed_actions_cnt,
1831 allowed_actions_cost);
1835 if ((policy >= 0) || gte_here || need_fea)
1840 for (
size_t i = 0; i < condition_on_cnt; i++)
1847 if ((!skip) && (!need_fea) && not_test &&
1849 condition_on_cnt, policy, learner_id,
a,
false, a_cost))
1857 for (
size_t n = start_K; n < ec_cnt; n++)
1861 if (((!skip) && (policy >= 0)) || need_fea)
1867 THROW(
"search cannot use state representations in ldf mode");
1869 if (ecs[0].passthrough)
1871 THROW(
"search cannot passthrough");
1877 allowed_actions_cost, a_cost, need_fea ? a : (
action)-1);
1879 cdbg <<
"passthrough = [";
1882 cdbg <<
" ]" << endl;
1894 cdbg <<
"INIT_TRAIN, NO_ROLLOUT, at least one oracle_actions, a=" << a << endl;
1899 oracle_actions, oracle_actions_cnt, priv.
gte_label);
1900 cdbg <<
"priv.gte_label = [";
1903 cdbg <<
" ]" << endl;
1907 if (allowed_actions)
1922 if (not_test && (!skip))
1924 condition_on_cnt, policy, learner_id,
a,
true, a_cost);
1937 THROW(
"error: predict called in unknown state");
// Strict-weak-ordering predicate for size_t values: true exactly when a sorts before b.
inline bool cmp_size_t(const size_t a, const size_t b)
{
  return b > a;
}
1941 inline bool cmp_size_t_pair(
const std::pair<size_t, size_t>&
a,
const std::pair<size_t, size_t>& b)
1943 return ((a.first == b.first) && (a.second < b.second)) || (a.first < b.first);
// Distance |a - b| between two unsigned values, computed without wrap-around:
// the smaller operand is always subtracted from the larger one.
inline size_t absdiff(size_t a, size_t b)
{
  if (a > b)
    return a - b;
  return b - a;
}
1955 size_t* A = calloc_or_throw<size_t>((N + 1) * 2);
1957 A[N + 1] = B[N - 1];
1958 size_t lo = N, hi = N + 1;
1959 size_t i = 0, j = N - 1;
1963 size_t d1 =
absdiff(A[lo], B[i + 1]);
1964 size_t d2 =
absdiff(A[lo], B[j - 1]);
1965 size_t d3 =
absdiff(A[hi], B[i + 1]);
1966 size_t d4 =
absdiff(A[hi], B[j - 1]);
1967 size_t mx = std::max(std::max(d1, d2), std::max(d3, d4));
1978 memcpy(B, A + lo, N *
sizeof(
size_t));
2002 for (
size_t t = 0; t < priv.
T; t++)
2004 uint32_t count = 99;
2008 for (std::pair<CS::wclass&, bool> wcq : priv.
active_known[t])
2023 for (
size_t t = 0; t < priv.
T; t++)
2027 if (timesteps.
size() == 0)
2037 size_t t = (size_t)(priv.
_random_state->get_and_update_random() * (float)priv.
T);
2063 void BaseTask::Run()
2068 if (!_final_run && !_with_output_string)
2086 if (_with_output_string && old_should_produce_string)
2111 float threshold = multiplier / std::sqrt((
float)t);
2112 cdbg <<
"verify_active_csoaa, losses = [";
2114 cdbg <<
" ]" << endl;
2119 if (!known[i].second)
2121 float err = pow(known[i].first.partial_prediction - wc.
x, 2);
2122 if (err > threshold)
2124 std::cerr <<
"verify_active_csoaa failed: truth " << wc.
class_index <<
":" << wc.
x <<
", known[" << i
2125 <<
"]=" << known[i].first.partial_prediction <<
", error=" << err <<
" vs threshold " << threshold
2142 cdbg <<
"advance_from_known_actions t=" << t <<
" active_known.size()=" << priv.
active_known.size()
2147 cdbg <<
"advance_from_known_actions setting done_with_all_actions=true (active_known[t].size()=" 2170 template <
bool is_learn>
2174 vw& all = *priv.
all;
2175 bool ran_test =
false;
2180 cdbg <<
"is_test_ex=" << is_test_ex <<
" vw_is_main=" << all.
vw_is_main << endl;
2185 cdbg <<
"======================================== INIT TEST (" << priv.
current_policy <<
"," 2211 if ((!is_learn) || is_test_ex || is_holdout_ex || ec_seq[0]->test_only || (!priv.
all->
training))
2217 cdbg <<
"======================================== INIT TRAIN (" << priv.
current_policy <<
"," 2240 cdbg <<
"======================================== LEARN (" << priv.
current_policy <<
"," 2249 cdbg <<
"memo_foreach_action[" << i <<
"] = ";
2268 cdbg <<
"skipping because it looks like this was overridden by metatask" << endl;
2276 bool skipped_all_actions =
true;
2287 skipped_all_actions =
false;
2292 cdbg <<
"-------------------------------------------------------------------------------------" << endl;
2309 if (skipped_all_actions)
2315 cdbg <<
"<<<<<" << endl;
2316 cdbg <<
"skipped all actions; learn_t = " << priv.
learn_t <<
", learn_a_idx = " << priv.
learn_a_idx << endl;
2318 cdbg <<
">>>>>" << endl;
2321 cdbg <<
"didn't skip all actions" << endl;
2335 cdbg <<
"priv.learn_losses = [";
2337 cdbg <<
" ]" << endl;
2338 cdbg <<
"gte" << endl;
2358 if (this_num > prev_num)
2371 std::cerr <<
"warning: turning off AUTO_CONDITION_FEATURES because settings make it useless" << endl;
2377 template <
bool is_learn>
2380 if (ec_seq.size() == 0)
2383 bool is_test_ex =
false;
2384 bool is_holdout_ex =
false;
2387 priv.
offset = ec_seq[0]->ft_offset;
2397 for (
size_t i = 0; i < ec_seq.size(); i++)
2400 is_holdout_ex |= ec_seq[i]->test_only;
2401 if (is_test_ex && is_holdout_ex)
2409 cdbg <<
"======================================== GET TRUTH STRING (" << priv.
current_policy <<
"," 2426 train_single_example<is_learn>(sch, is_test_ex, is_holdout_ex, ec_seq);
2448 std::cerr <<
"internal error (bug): too many policies; not advancing" << endl;
2524 priv.
active_known = v_init<v_array<std::pair<CS::wclass&, bool>>>();
2534 void ensure_param(
float& v,
float lo,
float hi,
float def,
const char* str)
2536 if ((v < lo) || (v > hi))
2538 std::cerr << str << endl;
2549 .help(
"add a \"bias\" feature for each ngram up to and including this length. eg., if it's 1 " 2550 "(default), then you get a single feature for each conditional"));
2554 .help(
"add bias *times* input features for each ngram up to and including this length (def: 0)"));
2558 .help(
"how much weight should the conditional features get? (def: 1.)"));
2561 .help(
"should we use lower-level reduction _internal state_ as additional features? (def: no)"));
2568 cdbg <<
"search_finish" << endl;
2581 FILE*
f = fopen(filename,
"r");
2583 THROW(
"error: could not read file " << filename <<
" (" << strerror(errno)
2584 <<
"); assuming all transitions are valid");
2586 bool* bg = calloc_or_throw<bool>(((size_t)(A + 1)) * (A + 1));
2587 int rd, from, to, count = 0;
2588 while ((rd = fscanf(f,
"%d:%d", &from, &to)) > 0)
2590 if ((from < 0) || (from > (int)A))
2592 std::cerr <<
"warning: ignoring transition from " << from <<
" because it's out of the range [0," << A <<
"]" 2595 if ((to < 0) || (to > (int)A))
2597 std::cerr <<
"warning: ignoring transition to " << to <<
" because it's out of the range [0," << A <<
"]" << endl;
2599 bg[from * (A + 1) + to] =
true;
2606 for (
size_t from = 0; from < A; from++)
2610 for (
size_t to = 0; to < A; to++)
2611 if (bg[from * (A + 1) + to])
2622 std::cerr <<
"read " << count <<
" allowed transitions from " << filename << endl;
2631 size_t len = nf_string.length();
2635 char* cstr =
new char[len + 1];
2636 strcpy(cstr, nf_string.c_str());
2638 char* p = strtok(cstr,
",");
2639 std::vector<substring> cmd;
2648 if (cmd.size() == 1)
2653 else if (cmd.size() == 2)
2656 ns = (cmd[1].end > cmd[1].begin) ? cmd[1].begin[0] :
' ';
2660 std::cerr <<
"warning: ignoring malformed neighbor specification: '" << p <<
"'" << endl;
2662 int32_t enc = (posn << 24) | (ns & 0xFF);
2665 p = strtok(
nullptr,
",");
2675 std::string task_string;
2676 std::string metatask_string;
2677 std::string interpolation_string =
"data";
2678 std::string neighbor_features_string;
2679 std::string rollout_string =
"mix_per_state";
2680 std::string rollin_string =
"mix_per_state";
2682 uint32_t search_trained_nb_policies;
2683 std::string search_allowed_transitions;
2688 make_option(
"search", priv.A).keep().help(
"Use learning to search, argument=maximum action id or 0 for LDF"));
2691 .help(
"the search task (use \"--search_task list\" to get a list of available tasks)"));
2695 .help(
"the search metatask (use \"--search_metatask list\" to get a list of available metatasks)"));
2696 new_options.
add(
make_option(
"search_interpolation", interpolation_string)
2698 .help(
"at what level should interpolation happen? [*data|policy]"));
2701 .help(
"how should rollouts be executed? [policy|oracle|*mix_per_state|mix_per_roll|none]"));
2703 .help(
"how should past trajectories be generated? [policy|oracle|*mix_per_state|mix_per_roll]"));
2704 new_options.
add(
make_option(
"search_passes_per_policy", priv.passes_per_policy)
2706 .help(
"number of passes per policy (only valid for search_interpolation=policy)"));
2708 .default_value(0.5f)
2709 .help(
"interpolation rate for policies (only valid for search_interpolation=policy)"));
2711 .default_value(1e-10
f)
2712 .help(
"annealed beta = 1-(1-alpha)^t (only valid for search_interpolation=data)"));
2713 new_options.
add(
make_option(
"search_total_nb_policies", priv.total_number_of_policies)
2714 .help(
"if we are going to train the policies through multiple separate calls to vw, we need to " 2715 "specify this parameter and tell vw how many policies are eventually going to be trained"));
2716 new_options.
add(
make_option(
"search_trained_nb_policies", search_trained_nb_policies)
2717 .help(
"the number of trained policies in a file"));
2718 new_options.
add(
make_option(
"search_allowed_transitions", search_allowed_transitions)
2719 .help(
"read file of allowed transitions [def: all transitions are allowed]"));
2720 new_options.
add(
make_option(
"search_subsample_time", priv.subsample_timesteps)
2721 .help(
"instead of training at all timesteps, use a subset. if value in (0,1), train on a random " 2722 "v%. if v>=1, train on precisely v steps per example, if v<=-1, use active learning"));
2724 make_option(
"search_neighbor_features", neighbor_features_string)
2726 .help(
"copy features from neighboring lines. argument looks like: '-1:a,+2' meaning copy previous line " 2727 "namespace a and next next line from namespace _unnamed_, where ',' separates them"));
2728 new_options.
add(
make_option(
"search_rollout_num_steps", priv.rollout_num_steps)
2729 .help(
"how many calls of \"loss\" before we stop really predicting on rollouts and switch to " 2730 "oracle (default means \"infinite\")"));
2731 new_options.
add(
make_option(
"search_history_length", priv.history_length)
2734 .help(
"some tasks allow you to specify how much history their depend on; specify that here"));
2735 new_options.
add(
make_option(
"search_no_caching", priv.no_caching)
2736 .help(
"turn off the built-in caching ability (makes things slower, but technically more safe)"));
2738 make_option(
"search_xv", priv.xv).help(
"train two separate policies, alternating prediction/learning"));
2739 new_options.
add(
make_option(
"search_perturb_oracle", priv.perturb_oracle)
2741 .help(
"perturb the oracle on rollin with this probability"));
2742 new_options.
add(
make_option(
"search_linear_ordering", priv.linear_ordering)
2743 .help(
"insist on generating examples in linear order (def: hoopla permutation)"));
2744 new_options.
add(
make_option(
"search_active_verify", priv.active_csoaa_verify)
2745 .help(
"verify that active learning is doing the right thing (arg = multiplier, should be = " 2746 "cost_range * range_c)"));
2747 new_options.
add(
make_option(
"search_save_every_k_runs", priv.save_every_k_runs).help(
"save model every k runs"));
2757 if (interpolation_string.compare(
"data") == 0)
2759 priv.adaptive_beta =
true;
2760 priv.allow_current_policy =
true;
2762 if (priv.current_policy > 1)
2763 priv.current_policy = 1;
2765 else if (interpolation_string.compare(
"policy") == 0)
2768 THROW(
"error: --search_interpolation must be 'data' or 'policy'");
2770 if ((rollout_string.compare(
"policy") == 0) || (rollout_string.compare(
"learn") == 0))
2771 priv.rollout_method =
POLICY;
2772 else if ((rollout_string.compare(
"oracle") == 0) || (rollout_string.compare(
"ref") == 0))
2773 priv.rollout_method =
ORACLE;
2774 else if ((rollout_string.compare(
"mix_per_state") == 0))
2776 else if ((rollout_string.compare(
"mix_per_roll") == 0) || (rollout_string.compare(
"mix") == 0))
2778 else if ((rollout_string.compare(
"none") == 0))
2781 priv.no_caching =
true;
2784 THROW(
"error: --search_rollout must be 'learn', 'ref', 'mix', 'mix_per_state' or 'none'");
2786 if ((rollin_string.compare(
"policy") == 0) || (rollin_string.compare(
"learn") == 0))
2787 priv.rollin_method =
POLICY;
2788 else if ((rollin_string.compare(
"oracle") == 0) || (rollin_string.compare(
"ref") == 0))
2789 priv.rollin_method =
ORACLE;
2790 else if ((rollin_string.compare(
"mix_per_state") == 0))
2792 else if ((rollin_string.compare(
"mix_per_roll") == 0) || (rollin_string.compare(
"mix") == 0))
2795 THROW(
"error: --search_rollin must be 'learn', 'ref', 'mix' or 'mix_per_state'");
2798 priv.allowed_actions_cache = &calloc_or_throw<polylabel>();
2801 priv.cb_learner =
true;
2803 priv.learn_losses.cb.costs = v_init<CB::cb_class>();
2804 priv.gte_label.cb.costs = v_init<CB::cb_class>();
2808 priv.cb_learner =
false;
2810 priv.learn_losses.cs.costs = v_init<CS::wclass>();
2811 priv.gte_label.cs.costs = v_init<CS::wclass>();
2814 ensure_param(priv.beta, 0.0, 1.0, 0.5,
"warning: search_beta must be in (0,1); resetting to 0.5");
2815 ensure_param(priv.alpha, 0.0, 1.0, 1e-10
f,
"warning: search_alpha must be in (0,1); resetting to 1e-10");
2817 priv.num_calls_to_run = 0;
2821 uint32_t tmp_number_of_policies = priv.current_policy;
2823 tmp_number_of_policies += (int)ceil(((
float)all.
numpasses) / ((float)priv.passes_per_policy));
2827 cdbg <<
"current_policy=" << priv.current_policy <<
" tmp_number_of_policies=" << tmp_number_of_policies
2828 <<
" total_number_of_policies=" << priv.total_number_of_policies << endl;
2829 if (tmp_number_of_policies > priv.total_number_of_policies)
2831 priv.total_number_of_policies = tmp_number_of_policies;
2832 if (priv.current_policy >
2834 std::cerr <<
"warning: you're attempting to train more classifiers than was allocated initially. Likely to cause " 2842 if (!all.
training && priv.current_policy > 0)
2843 priv.current_policy--;
2851 cdbg <<
"search current_policy = " << priv.current_policy
2852 <<
" total_number_of_policies = " << priv.total_number_of_policies << endl;
2854 if (task_string.compare(
"list") == 0)
2856 std::cerr << endl <<
"available search tasks:" << endl;
2857 for (
search_task** mytask = all_tasks; *mytask !=
nullptr; mytask++)
2858 std::cerr <<
" " << (*mytask)->task_name << endl;
2862 if (metatask_string.compare(
"list") == 0)
2864 std::cerr << endl <<
"available search metatasks:" << endl;
2865 for (
search_metatask** mytask = all_metatasks; *mytask !=
nullptr; mytask++)
2866 std::cerr <<
" " << (*mytask)->metatask_name << endl;
2870 for (
search_task** mytask = all_tasks; *mytask !=
nullptr; mytask++)
2871 if (task_string.compare((*mytask)->task_name) == 0)
2873 priv.task = *mytask;
2874 sch->task_name = (*mytask)->task_name;
2877 if (priv.task ==
nullptr)
2880 THROW(
"fail: unknown task for --search_task '" << task_string <<
"'; use --search_task list to get a list");
2882 priv.metatask =
nullptr;
2883 for (
search_metatask** mytask = all_metatasks; *mytask !=
nullptr; mytask++)
2884 if (metatask_string.compare((*mytask)->metatask_name) == 0)
2886 priv.metatask = *mytask;
2887 sch->metatask_name = (*mytask)->metatask_name;
2899 priv.active_csoaa_verify = -1.;
2901 if (!priv.active_csoaa)
2902 THROW(
"cannot use --search_active_verify without using --cs_active");
2904 cdbg <<
"active_csoaa = " << priv.active_csoaa <<
", active_csoaa_verify = " << priv.active_csoaa_verify << endl;
2911 if (priv.task && priv.task->initialize)
2912 priv.task->initialize(*sch.get(), priv.A, options);
2913 if (priv.metatask && priv.metatask->initialize)
2914 priv.metatask->initialize(*sch.get(), priv.A, options);
2917 if (options.
was_supplied(
"search_allowed_transitions"))
2923 if (!priv.allow_current_policy)
2928 priv.start_clock_time = clock();
2931 priv.num_learners *= 3;
2933 cdbg <<
"num_learners = " << priv.num_learners << endl;
2936 do_actual_learning<false>, priv.total_number_of_policies * priv.num_learners);
2948 for (
size_t i = 0; i < sz; i++)
2957 return costs[a - 1];
2958 for (
size_t i = 0; i < sz; i++)
2961 THROW(
"action_cost_loss got action that wasn't allowed: " << a);
2965 bool search::is_ldf() {
return priv->
is_ldf; }
2968 const ptag* condition_on,
const char* condition_on_names,
const action* allowed_actions,
size_t allowed_actions_cnt,
2969 const float* allowed_actions_cost,
size_t learner_id,
float weight)
2972 action a =
search_predict(*priv, &ec, 1, mytag, oracle_actions, oracle_actions_cnt, condition_on, condition_on_names,
2973 allowed_actions, allowed_actions_cnt, allowed_actions_cost, learner_id, a_cost, weight);
2978 if (mytag < priv->ptag_to_action.size())
2980 cdbg <<
"delete_v at " << mytag << endl;
2994 cdbg <<
"push_at " << mytag << endl;
2999 cdbg <<
"predict returning " << a << endl;
3004 size_t oracle_actions_cnt,
const ptag* condition_on,
const char* condition_on_names,
size_t learner_id,
3009 action a =
search_predict(*priv, ecs, ec_cnt, mytag, oracle_actions, oracle_actions_cnt, condition_on,
3010 condition_on_names,
nullptr, 0,
nullptr, learner_id, a_cost, weight);
3020 if ((mytag != 0) && ecs[action_index].l.cs.costs.size() > 0)
3022 if (mytag < priv->ptag_to_action.size())
3024 cdbg <<
"delete_v at " << mytag << endl;
3035 cdbg <<
"predict returning " << a << endl;
3043 std::stringstream& search::output()
3053 void search::set_options(uint32_t opts)
3056 std::cerr <<
"warning: task should not set options except in initialize function!" << endl;
3057 if ((opts & AUTO_CONDITION_FEATURES) != 0)
3063 if ((opts &
IS_LDF) != 0)
3064 this->priv->
is_ldf =
true;
3070 if (this->priv->
is_ldf && this->priv->use_action_costs)
3071 THROW(
"using LDF and actions costs is not yet implemented; turn off action costs");
3075 <<
"warning: task is designed to use rollout costs, but this only works when --search_rollout none is specified" 3082 std::cerr <<
"warning: task should not set label parser except in initialize function!" << endl;
3083 this->priv->
all->
p->
lp = lp;
3088 void search::get_test_action_sequence(std::vector<action>& V)
3091 for (
size_t i = 0; i < this->priv->
test_action_sequence.size(); i++) V.push_back(this->priv->test_action_sequence[i]);
3094 void search::set_num_learners(
size_t num_learners) { this->priv->
num_learners = num_learners; }
3098 uint32_t search::get_history_length() {
return (uint32_t)this->priv->
history_length; }
3109 std::ostringstream os;
3115 vw& search::get_vw_pointer_unsafe() {
return *this->priv->
all; }
3116 void search::set_force_oracle(
bool force) { this->priv->
force_oracle = force; }
3126 , oracle_is_pointer(false)
3127 , allowed_is_pointer(false)
3128 , allowed_cost_is_pointer(false)
3177 ec = &input_example;
3199 if (temp !=
nullptr)
3202 THROW(
"realloc failed in search.cc");
3205 ec = calloc_or_throw<example>(input_length);
3212 THROW(
"call to set_input_at without previous call to set_input_length");
3215 THROW(
"call to set_input_at with too large a position: posn (" << posn <<
") >= ec_cnt(" <<
ec_cnt <<
")");
3224 size_t old_size = A.
size();
3225 T* old_pointer = A.
begin();
3226 A.
begin() = calloc_or_throw<T>(new_size);
3229 memcpy(A.
begin(), old_pointer, old_size *
sizeof(T));
3239 size_t new_size = clear_first ? 1 : (A.
size() + 1);
3240 make_new_pointer<T>(A, new_size);
3242 A[new_size - 1] =
a;
3256 size_t old_size = A.
size();
3266 size_t new_size = old_size + count;
3267 make_new_pointer<T>(A, new_size);
3270 memcpy(A.begin() + old_size,
a, count *
sizeof(T));
3277 push_many<T>(A,
a, count);
3287 A.
end() = a + count;
3375 for (
size_t i = 0; i <
a.size(); i++)
3384 for (
size_t i = 0; i <
a.size(); i++)
3431 for (
ptag i = 0; i < count; i++)
3435 char name = name0 + i;
3464 const char* cNa =
nullptr;
3475 :
sch.
predict(*
ec,
my_tag, orA,
oracle_actions.
size(), cOn, cNa, alA, numAlA, alAcosts,
learner_id,
weight);
void add_example_conditioning(search_private &priv, example &ec, size_t condition_on_cnt, const char *condition_on_names, action_repr *condition_on_actions)
void hoopla_permute(size_t *B, size_t *end)
void free_final_item(final_item *p)
int int_of_substring(substring s)
predictor & set_oracle(action a)
void resize(size_t length)
void copy_label(void *dst, void *src)
constexpr unsigned char conditioning_namespace
void to_short_string(std::string in, size_t max_len, char *out)
v_array< CS::label > read_allowed_transitions(action A, const char *filename)
bool cmp_size_t_pair(const std::pair< size_t, size_t > &a, const std::pair< size_t, size_t > &b)
v_array< namespace_index > indices
void(* copy_label)(void *, void *)
auto_condition_settings acset
void predict(E &ec, size_t i=0)
final_item(v_array< scored_action > *p, std::string s, float ic)
std::stringstream dat_new_feature_audit_ss
bool must_run_test(vw &all, multi_ex &ec, bool is_test_ex)
vw * setup(options_i &options)
void deep_copy_from(const features &src)
v_array< char > condition_on_names
size_t read_example_last_pass
uint64_t stride_shift(const stagewise_poly &poly, uint64_t idx)
std::ostream & operator<<(std::ostream &os, const action_cache &x)
std::string audit_feature_space("conditional")
void train_single_example(search &sch, bool is_test_ex, bool is_holdout_ex, multi_ex &ec_seq)
void push_back(feature_value v, feature_index i)
bool mc_label_is_test(polylabel &lab)
VW::config::options_i * options
std::stringstream * pred_string
search_metatask * metatask
example * dat_new_feature_ec
v_array< float > allowed_actions_cost
void(* delete_label)(void *)
void copy_example_data(bool audit, example *dst, example *src)
std::shared_ptr< audit_strings > audit_strings_ptr
float action_cost_loss(action a, const action *act, const float *costs, size_t sz)
void search_initialize(vw *all, search &sch)
void push_at(v_array< T > &v, T item, size_t pos)
bool might_print_update(vw &all)
void finish_multiline_example(vw &all, cbify &, multi_ex &ec_seq)
search_metatask * all_metatasks[]
size_t max_bias_ngram_length
constexpr int quadratic_constant
bool cached_action_store_or_find(search_private &priv, ptag mytag, const ptag *condition_on, const char *condition_on_names, action_repr *condition_on_actions, size_t condition_on_cnt, int policy, size_t learner_id, action &a, bool do_store, float &a_cost)
v_array< feature_index > indicies
predictor & add_condition(ptag tag, char name)
void adjust_auto_condition(search_private &priv)
v_array< v_array< std::pair< CS::wclass &, bool > > > active_known
void(* default_label)(void *)
virtual void replace(const std::string &key, const std::string &value)=0
label_type::label_type_t label_type
void dealloc_example(void(*delete_label)(void *), example &ec, void(*delete_prediction)(void *))
void set_input_at(size_t posn, example &input_example)
bool(* test_label)(void *)
void set_input_length(size_t input_length)
constexpr unsigned char neighbor_namespace
void del_neighbor_features(search_private &priv, multi_ex &ec_seq)
void del_example_namespaces_from_example(example &target, example &source)
v_array< int > final_prediction_sink
v_array< scored_action > * prefix
void(* run_setup)(search &, multi_ex &)
void run_task(search &sch, multi_ex &ec)
the core definition of a set of features.
v_array< cb_class > costs
v_hashmap< unsigned char *, scored_action > cache_hash_map
void generate_training_example(search_private &priv, polylabel &losses, float weight, bool add_conditioning=true, float min_loss=FLT_MAX)
void parse_neighbor_features(std::string &nf_string, search &sch)
v_array< ptag > learn_condition_on
VW_STD14_CONSTEXPR uint64_t uniform_hash(const void *key, size_t len, uint64_t seed)
base_learner * make_base(learner< T, E > &base)
v_array< ptag > condition_on_tags
void delete_label(void *v)
polylabel * allowed_actions_cache
v_array< feature_value > values
int random_policy(search_private &priv, bool allow_current, bool allow_optimal, bool advance_prng=true)
search_task * all_tasks[]
float safediv(float a, float b)
virtual void add_and_parse(const option_group_definition &group)=0
action single_prediction_notLDF(search_private &priv, example &ec, int policy, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost, float &a_cost, action override_action)
predictor & set_weight(float w)
bool ec_is_example_header(example const &ec)
size_t max_quad_ngram_length
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
void do_actual_learning(search &sch, base_learner &base, multi_ex &ec_seq)
void handle_condition_options(vw &all, auto_condition_settings &acset)
constexpr bool PRINT_UPDATE_EVERY_EXAMPLE
size_t check_holdout_every_n_passes
void(* run)(search &, multi_ex &)
void advance_from_known_actions(search_private &priv)
bool auto_condition_features
void(* run_takedown)(search &, multi_ex &)
uint64_t dat_new_feature_idx
int select_learner(search_private &priv, int policy, size_t learner_id, bool is_training, bool is_local)
v_array< size_t > timesteps
std::string rawOutputString
void(* _foreach_action)(search &, size_t, float, action, bool, float)
LEARNER::base_learner * base_learner
void cs_cost_push_back(bool isCB, polylabel &ld, uint32_t index, float value)
RollMethod rollout_method
void save_predictor(vw &all, std::string reg_name, size_t current_pass)
double sum_loss_since_last_dump
bool allowed_cost_is_pointer
void(* _post_prediction)(search &, size_t, action, float)
v_array< action_repr > ptag_to_action
std::shared_ptr< rand_state > get_random_state()
std::array< features, NUM_NAMESPACES > feature_space
std::unique_ptr< T, free_fn > free_ptr
single_learner * as_singleline(learner< T, E > *l)
uint32_t AUTO_CONDITION_FEATURES
size_t total_examples_generated
void clear_cache_hash_map(search_private &priv)
predictor & add_oracle(action a)
MULTICLASS::label_t multi
predictor & add_condition_range(ptag hi, ptag count, char name0)
bool use_passthrough_repr
void set_finish_example(void(*f)(vw &all, T &, E &))
void clear_memo_foreach_action(search_private &priv)
float dat_new_feature_value
void reset_search_structure(search_private &priv)
bool(* label_is_test)(polylabel &)
size_t passes_since_new_policy
Search::search_metatask metatask
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
size_t num_calls_to_run_previous
void push_back(const T &new_ele)
predictor & erase_alloweds()
typed_option< T > & get_typed_option(const std::string &key)
void end_pass(example &ec, vw &all)
size_t random(std::shared_ptr< rand_state > &rs, size_t max)
action search_predict(search_private &priv, example *ecs, size_t ec_cnt, ptag mytag, const action *oracle_actions, size_t oracle_actions_cnt, const ptag *condition_on, const char *condition_on_names, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost, size_t learner_id, float &a_cost, float)
void print_update(search_private &priv)
std::string neighbor_feature_space("neighbor")
v_array< v_array< action_cache > * > memo_foreach_action
bool cmp_size_t(const size_t a, const size_t b)
size_t total_predictions_made
float subsample_timesteps
void tokenize(char delim, substring s, ContainerT &ret, bool allow_empty=false)
void cs_costs_erase(bool isCB, polylabel &ld)
void add_neighbor_features(search_private &priv, multi_ex &ec_seq)
double old_weighted_labeled_examples
void foreach_action_from_cache(search_private &priv, size_t t, action override_a=(action) -1)
virtual bool was_supplied(const std::string &key)=0
double weighted_holdout_examples
std::string * dat_new_feature_feature_space
void search_declare_loss(search_private &priv, float loss)
size_t absdiff(size_t a, size_t b)
bool need_memo_foreach_action(search_private &priv)
action predictLDF(example *ecs, size_t ec_cnt, ptag my_tag, const action *oracle_actions, size_t oracle_actions_cnt=1, const ptag *condition_on=nullptr, const char *condition_on_names=nullptr, size_t learner_id=0, float weight=0.)
std::vector< action > test_action_sequence
size_t read_example_last_id
std::stringstream * truth_string
v_array< action_repr > condition_on_actions
float action_hamming_loss(action a, const action *A, size_t sz)
bool should_produce_string
predictor & add_allowed(action a)
void(* print_text)(int, std::string, v_array< char >)
bool allow_current_policy
predictor & add_to(v_array< T > &A, bool &A_is_ptr, T a, bool clear_first)
predictor & set_learner_id(size_t id)
bool search_predictNeedsExample(search_private &priv)
bool cached_item_equivalent(unsigned char *const &A, unsigned char *const &B)
action predict(example &ec, ptag my_tag, const action *oracle_actions, size_t oracle_actions_cnt=1, const ptag *condition_on=nullptr, const char *condition_on_names=nullptr, const action *allowed_actions=nullptr, size_t allowed_actions_cnt=0, const float *allowed_actions_cost=nullptr, size_t learner_id=0, float weight=0.)
v_array< action > oracle_actions
std::string number_to_natural(size_t big)
predictor & set_allowed(action a)
std::shared_ptr< rand_state > _random_state
v_array< action > learn_allowed_actions
void finish_example(vw &, example &)
bool should_print_update(vw &all, bool hit_new_pass=false)
void del_features_in_top_namespace(search_private &, example &ec, size_t ns)
virtual void insert(const std::string &key, const std::string &value)=0
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
scored_action(action _a=(action) -1, float _s=0)
void cdbg_print_array(std::string str, v_array< T > &A)
option_group_definition & add(T &&op)
void add_new_feature(search_private &priv, float val, uint64_t idx)
void search_finish(search &sch)
std::vector< example * > multi_ex
void end_examples(search &sch)
v_array< audit_strings_ptr > space_names
v_array< example > learn_ec_copy
action learn_oracle_action
void cs_set_cost_loss(bool isCB, polylabel &ld, size_t k, float val)
int choose_policy(search_private &priv, bool advance_prng=true)
uint32_t AUTO_HAMMING_LOSS
bool array_contains(T target, const T *A, size_t n)
v_array< std::pair< float, size_t > > active_uncertainty
void make_new_pointer(v_array< T > &A, size_t new_size)
MULTILABEL::labels multilabels
void ensure_size(v_array< T > &A, size_t sz)
predictor & set_condition_range(ptag hi, ptag count, char name0)
void del_example_conditioning(search_private &priv, example &ec)
typed_option< T > make_option(std::string name, T &location)
action_repr(action _a, features *_repr)
v_array< uint32_t > label_v
void set_end_pass(void(*f)(T &))
constexpr bool PRINT_CLOCK_TIME
void set_finish(void(*f)(T &))
void verify_active_csoaa(COST_SENSITIVE::label &losses, v_array< std::pair< CS::wclass &, bool >> &known, size_t t, float multiplier)
std::stringstream * bad_string_stream
void allowed_actions_to_label(search_private &priv, size_t ec_cnt, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost, const action *oracle_actions, size_t oracle_actions_cnt, polylabel &lab)
predictor & set_tag(ptag tag)
std::string condition_feature_space("search_condition")
void ensure_param(float &v, float lo, float hi, float def, const char *str)
action choose_oracle_action(search_private &priv, size_t ec_cnt, const action *oracle_actions, size_t oracle_actions_cnt, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost)
uint64_t get(substring &s)
features last_action_repr
std::stringstream * rawOutputStringStream
v_array< scored_action > train_trajectory
void cs_costs_resize(bool isCB, polylabel &ld, size_t new_size)
double weighted_labeled_examples
double weighted_holdout_examples_since_last_dump
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Search::search_metatask metatask
action_cache(float _min_cost, action _k, bool _is_opt, float _cost)
void predict(bfgs &b, base_learner &, example &ec)
predictor & set_condition(ptag tag, char name)
void cerr_print_array(std::string str, v_array< T > &A)
double holdout_sum_loss_since_last_dump
uint64_t conditional_constant
constexpr bool PRINT_UPDATE_EVERY_PASS
void learn(E &ec, size_t i=0)
uint32_t EXAMPLES_DONT_CHANGE
bool(* _maybe_override_prediction)(search &, size_t, action &, float &)
v_array< action > allowed_actions
float cs_get_cost_partial_prediction(bool isCB, polylabel &ld, size_t k)
bool v_array_contains(v_array< T > &A, T x)
std::string final_regressor_name
bool done_with_all_actions
size_t dat_new_feature_namespace
void free_key(unsigned char *mem, scored_action)
float active_csoaa_verify
v_array< int32_t > neighbor_features
void get_training_timesteps(search_private &priv, v_array< size_t > ×teps)
double weighted_examples()
bool emptylines_separate_examples
predictor & set_input(example &input_example)
uint32_t total_number_of_policies
multi_learner * as_multiline(learner< T, E > *l)
v_array< action_repr > learn_condition_on_act
action single_prediction_LDF(search_private &priv, example *ecs, size_t ec_cnt, int policy, float &a_cost, action override_action)
const char * to_string(prediction_type_t prediction_type)
v_array< char > learn_condition_on_names
bool printed_output_header
void set_end_examples(void(*f)(T &))
polylabel & allowed_actions_to_ld(search_private &priv, size_t ec_cnt, const action *allowed_actions, size_t allowed_actions_cnt, const float *allowed_actions_cost)
bool examples_dont_change
void add_example_namespaces_from_example(example &target, example &source)
void update_dump_interval(bool progress_add, float progress_arg)
std::pair< std::string, std::string > audit_strings
size_t cs_get_costs_size(bool isCB, polylabel &ld)
uint32_t cs_get_cost_index(bool isCB, polylabel &ld, size_t k)
predictor & erase_oracles()