117 d.trajectory.clear();
119 d.output_string =
nullptr;
121 cdbg <<
"*** INITIAL PASS ***" << std::endl;
124 cdbg <<
"==DebugMT== foreach_action(t=" << t <<
", min_cost=" << min_cost <<
", a=" << a <<
", taken=" << taken
125 <<
", a_cost=" << a_cost <<
")" << std::endl;
129 float delta = a_cost - min_cost;
131 push_many<act_score>(
branch, d.trajectory.begin(), d.trajectory.size());
132 branch.push_back(std::make_pair(a, a_cost));
133 d.branches.push_back(std::make_pair(delta, branch));
134 cdbg <<
"adding branch: " << delta <<
" -> " << branch << std::endl;
138 d.trajectory.push_back(std::make_pair(a, a_cost));
139 d.total_cost += a_cost;
141 .with_output_string([](
Search::search& sch, std::stringstream& output) ->
void {
152 path original_final = v_init<act_score>();
154 d.final.push_back(std::make_pair(std::make_pair(d.total_cost, original_final), d.output_string));
159 d.branches.begin(), d.branches.end(), [](
const branch&
a,
const branch& b) ->
bool {
return a.first < b.first; });
162 for (
size_t i = 0; i < std::min(d.max_branches, d.branches.size()); i++)
165 d.trajectory.clear();
167 d.output_string =
nullptr;
169 cdbg <<
"*** BRANCH " << i <<
" *** " << d.branches[i].first <<
" : " << d.branches[i].second << std::endl;
173 .maybe_override_prediction([](
Search::search& sch,
size_t t,
action& a,
float& a_cost) ->
bool {
175 path&
path = d.branches[d.cur_branch].second;
176 if (t >= path.size())
179 a_cost = path[t].second;
184 d.trajectory.push_back(std::make_pair(a, a_cost));
185 d.total_cost += a_cost;
187 .with_output_string([](
Search::search& sch, std::stringstream& output) ->
void {
194 path this_final = v_init<act_score>();
196 d.final.push_back(std::make_pair(std::make_pair(d.total_cost, this_final), d.output_string));
201 stable_sort(d.final.begin(), d.final.end(),
202 [](
const std::pair<branch, std::string*>&
a,
const std::pair<branch, std::string*>& b) ->
bool {
203 return a.first.first < b.first.first;
206 d.kbest_out =
nullptr;
207 if (d.output_string && (d.kbest > 0))
209 d.kbest_out =
new std::stringstream();
210 for (
size_t i = 0; i < std::min(d.final.size(), d.kbest); i++)
211 (*d.kbest_out) << *d.final[i].second <<
"\t" << d.final[i].first.first << std::endl;
215 cdbg <<
"*** FINAL ***" << std::endl;
217 d.output_string =
nullptr;
221 .maybe_override_prediction([](
Search::search& sch,
size_t t,
action& a,
float& a_cost) ->
bool {
223 path& path = d.final[d.cur_branch].first.second;
224 if ((t >= path.size()) || (path[t].first == (
action)-1))
227 a_cost = path[t].second;
230 .with_output_string([](
Search::search& sch, std::stringstream& output) ->
void {
235 output << d.kbest_out->str();
242 for (
size_t i = 0; i < d.branches.size(); i++) d.branches[i].second.delete_v();
244 for (
size_t i = 0; i < d.final.size(); i++)
246 d.final[i].first.second.delete_v();
247 delete d.final[i].second;
251 d.kbest_out =
nullptr;
std::pair< float, path > branch
void copy_array(v_array< T > &dst, const v_array< T > &src)
v_array< act_score > path
BaseTask & foreach_action(void(*f)(search &, size_t, float, action, bool, float))
BaseTask base_task(multi_ex &ec)