25 std::cerr <<
"==DebugMT== foreach_action(t=" << t <<
", min_cost=" << min_cost <<
", a=" << a
26 <<
", taken=" << taken <<
", a_cost=" << a_cost <<
")" << std::endl;
30 std::cerr <<
"==DebugMT== post_prediction(t=" << t <<
", a=" << a <<
", a_cost=" << a_cost <<
")" << std::endl;
34 std::cerr <<
"==DebugMT== maybe_override_prediction(t=" << t <<
", a=" << a <<
", a_cost=" << a_cost <<
")" 54 typedef std::pair<float, path>
branch;
56 std::ostream& operator<<(std::ostream& os, const std::pair<unsigned int, float>& v)
58 os << v.first <<
'_' << v.second;
72 task_data(
size_t mb,
size_t kb) : max_branches(mb), kbest(kb)
74 branches = v_init<branch>();
75 final = v_init<std::pair<branch, std::string*> >();
76 trajectory = v_init<act_score>();
77 output_string =
nullptr;
92 size_t max_branches = 2;
98 .help(
"maximum number of branches to consider"))
101 .help(
"number of best items to output (0=just like non-selectional-branching, default)"));
121 cdbg <<
"*** INITIAL PASS ***" << std::endl;
124 cdbg <<
"==DebugMT== foreach_action(t=" << t <<
", min_cost=" << min_cost <<
", a=" << a <<
", taken=" << taken
125 <<
", a_cost=" << a_cost <<
")" << std::endl;
129 float delta = a_cost - min_cost;
130 path branch = v_init<act_score>();
132 branch.push_back(std::make_pair(a, a_cost));
134 cdbg <<
"adding branch: " << delta <<
" -> " << branch << std::endl;
141 .with_output_string([](
Search::search& sch, std::stringstream& output) ->
void {
152 path original_final = v_init<act_score>();
169 cdbg <<
"*** BRANCH " << i <<
" *** " << d.
branches[i].first <<
" : " << d.
branches[i].second << std::endl;
173 .maybe_override_prediction([](
Search::search& sch,
size_t t,
action& a,
float& a_cost) ->
bool {
176 if (t >= path.size())
179 a_cost = path[t].second;
187 .with_output_string([](
Search::search& sch, std::stringstream& output) ->
void {
194 path this_final = v_init<act_score>();
202 [](
const std::pair<branch, std::string*>&
a,
const std::pair<branch, std::string*>& b) ->
bool {
203 return a.first.first < b.first.first;
215 cdbg <<
"*** FINAL ***" << std::endl;
221 .maybe_override_prediction([](
Search::search& sch,
size_t t,
action& a,
float& a_cost) ->
bool {
224 if ((t >= path.size()) || (path[t].first == (
action)-1))
227 a_cost = path[t].second;
230 .with_output_string([](
Search::search& sch, std::stringstream& output) ->
void {
244 for (
size_t i = 0; i < d.
final.
size(); i++)
247 delete d.
final[i].second;
std::string * output_string
std::pair< float, path > branch
std::pair< action, float > act_score
void copy_array(v_array< T > &dst, const v_array< T > &src)
void run(Search::search &sch, multi_ex &ec)
v_array< act_score > path
virtual void add_and_parse(const option_group_definition &group)=0
void finish(vw &all, bool delete_all)
void set_metatask_data(T *data)
BaseTask & foreach_action(void(*f)(search &, size_t, float, action, bool, float))
void push_back(const T &new_ele)
vw * initialize(options_i &options, io_buf *model, bool skipModelLoad, trace_message_t trace_listener, void *trace_context)
v_array< branch > branches
option_group_definition & add(T &&op)
std::vector< example * > multi_ex
v_array< std::pair< branch, std::string * > > final
typed_option< T > make_option(std::string name, T &location)
BaseTask base_task(multi_ex &ec)
task_data(size_t mb, size_t kb)
Search::search_metatask metatask
std::stringstream * kbest_out