Vowpal Wabbit
search_meta.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 #include <float.h>
7 #include <errno.h>
8 
9 #include "reductions.h"
10 #include "vw.h"
11 #include "search.h"
12 
13 using namespace VW::config;
14 
15 namespace DebugMT
16 {
17 void run(Search::search& sch, multi_ex& ec);
18 Search::search_metatask metatask = {"debug", run, nullptr, nullptr, nullptr, nullptr};
19 
20 void run(Search::search& sch, multi_ex& ec)
21 {
22  sch.base_task(ec)
24  [](Search::search& /*sch*/, size_t t, float min_cost, action a, bool taken, float a_cost) -> void {
25  std::cerr << "==DebugMT== foreach_action(t=" << t << ", min_cost=" << min_cost << ", a=" << a
26  << ", taken=" << taken << ", a_cost=" << a_cost << ")" << std::endl;
27  })
28 
29  .post_prediction([](Search::search& /*sch*/, size_t t, action a, float a_cost) -> void {
30  std::cerr << "==DebugMT== post_prediction(t=" << t << ", a=" << a << ", a_cost=" << a_cost << ")" << std::endl;
31  })
32 
33  .maybe_override_prediction([](Search::search& /*sch*/, size_t t, action& a, float& a_cost) -> bool {
34  std::cerr << "==DebugMT== maybe_override_prediction(t=" << t << ", a=" << a << ", a_cost=" << a_cost << ")"
35  << std::endl;
36  return false;
37  })
38 
39  .final_run()
40 
41  .Run();
42 }
43 } // namespace DebugMT
44 
46 {
47 void run(Search::search& sch, multi_ex& ec);
48 void initialize(Search::search& sch, size_t& num_actions, options_i& options);
49 void finish(Search::search& sch);
50 Search::search_metatask metatask = {"selective_branching", run, initialize, finish, nullptr, nullptr};
51 
52 typedef std::pair<action, float> act_score;
54 typedef std::pair<float, path> branch;
55 
// Streams a (unsigned int, float) pair as "<first>_<second>" and returns the
// same stream so that insertions can be chained.
std::ostream& operator<<(std::ostream& os, const std::pair<unsigned int, float>& v)
{
  return os << v.first << '_' << v.second;
}
61 
62 struct task_data
63 {
64  size_t max_branches, kbest;
67  path trajectory;
68  float total_cost;
69  size_t cur_branch;
70  std::string* output_string;
71  std::stringstream* kbest_out;
72  task_data(size_t mb, size_t kb) : max_branches(mb), kbest(kb)
73  {
74  branches = v_init<branch>();
75  final = v_init<std::pair<branch, std::string*> >();
76  trajectory = v_init<act_score>();
77  output_string = nullptr;
78  kbest_out = nullptr;
79  }
81  {
82  branches.delete_v();
83  final.delete_v();
84  trajectory.delete_v();
85  delete output_string;
86  delete kbest_out;
87  }
88 };
89 
90 void initialize(Search::search& sch, size_t& /*num_actions*/, options_i& options)
91 {
92  size_t max_branches = 2;
93  size_t kbest = 0;
94  option_group_definition new_options("selective branching options");
95  new_options
96  .add(make_option("search_max_branch", max_branches)
97  .default_value(2)
98  .help("maximum number of branches to consider"))
99  .add(make_option("search_kbest", kbest)
100  .default_value(0)
101  .help("number of best items to output (0=just like non-selectional-branching, default)"));
102  options.add_and_parse(new_options);
103 
104  task_data* d = new task_data(max_branches, kbest);
105  sch.set_metatask_data(d);
106 }
107 
108 void finish(Search::search& sch) { delete sch.get_metatask_data<task_data>(); }
109 
110 void run(Search::search& sch, multi_ex& ec)
111 {
113 
114  // generate an initial trajectory, but record possible branches
115  d.branches.clear();
116  d.final.clear();
117  d.trajectory.clear();
118  d.total_cost = 0.;
119  d.output_string = nullptr;
120 
121  cdbg << "*** INITIAL PASS ***" << std::endl;
122  sch.base_task(ec)
123  .foreach_action([](Search::search& sch, size_t t, float min_cost, action a, bool taken, float a_cost) -> void {
124  cdbg << "==DebugMT== foreach_action(t=" << t << ", min_cost=" << min_cost << ", a=" << a << ", taken=" << taken
125  << ", a_cost=" << a_cost << ")" << std::endl;
126  if (taken)
127  return; // ignore the taken action
129  float delta = a_cost - min_cost;
130  path branch = v_init<act_score>();
131  push_many<act_score>(branch, d.trajectory.begin(), d.trajectory.size());
132  branch.push_back(std::make_pair(a, a_cost));
133  d.branches.push_back(std::make_pair(delta, branch));
134  cdbg << "adding branch: " << delta << " -> " << branch << std::endl;
135  })
136  .post_prediction([](Search::search& sch, size_t /*t*/, action a, float a_cost) -> void {
138  d.trajectory.push_back(std::make_pair(a, a_cost));
139  d.total_cost += a_cost;
140  })
141  .with_output_string([](Search::search& sch, std::stringstream& output) -> void {
142  sch.get_metatask_data<task_data>()->output_string = new std::string(output.str());
143  })
144  .Run();
145 
146  // the last item the trajectory stack is complete and therefore not a branch
147  // if (! d.branches.empty())
148  // d.branches.pop().second.delete_v();
149 
150  {
151  // construct the final trajectory
152  path original_final = v_init<act_score>();
153  copy_array(original_final, d.trajectory);
154  d.final.push_back(std::make_pair(std::make_pair(d.total_cost, original_final), d.output_string));
155  }
156 
157  // sort the branches by cost
158  stable_sort(
159  d.branches.begin(), d.branches.end(), [](const branch& a, const branch& b) -> bool { return a.first < b.first; });
160 
161  // make new predictions
162  for (size_t i = 0; i < std::min(d.max_branches, d.branches.size()); i++)
163  {
164  d.cur_branch = i;
165  d.trajectory.clear();
166  d.total_cost = 0.;
167  d.output_string = nullptr;
168 
169  cdbg << "*** BRANCH " << i << " *** " << d.branches[i].first << " : " << d.branches[i].second << std::endl;
170  sch.base_task(ec)
171  .foreach_action([](Search::search& /*sch*/, size_t /*t*/, float /*min_cost*/, action /*a*/, bool /*taken*/,
172  float /*a_cost*/) -> void {})
173  .maybe_override_prediction([](Search::search& sch, size_t t, action& a, float& a_cost) -> bool {
175  path& path = d.branches[d.cur_branch].second;
176  if (t >= path.size())
177  return false;
178  a = path[t].first;
179  a_cost = path[t].second;
180  return true;
181  })
182  .post_prediction([](Search::search& sch, size_t /*t*/, action a, float a_cost) -> void {
184  d.trajectory.push_back(std::make_pair(a, a_cost));
185  d.total_cost += a_cost;
186  })
187  .with_output_string([](Search::search& sch, std::stringstream& output) -> void {
188  sch.get_metatask_data<task_data>()->output_string = new std::string(output.str());
189  })
190  .Run();
191 
192  {
193  // construct the final trajectory
194  path this_final = v_init<act_score>();
195  copy_array(this_final, d.trajectory);
196  d.final.push_back(std::make_pair(std::make_pair(d.total_cost, this_final), d.output_string));
197  }
198  }
199 
200  // sort the finals by cost
201  stable_sort(d.final.begin(), d.final.end(),
202  [](const std::pair<branch, std::string*>& a, const std::pair<branch, std::string*>& b) -> bool {
203  return a.first.first < b.first.first;
204  });
205 
206  d.kbest_out = nullptr;
207  if (d.output_string && (d.kbest > 0))
208  {
209  d.kbest_out = new std::stringstream();
210  for (size_t i = 0; i < std::min(d.final.size(), d.kbest); i++)
211  (*d.kbest_out) << *d.final[i].second << "\t" << d.final[i].first.first << std::endl;
212  }
213 
214  // run the final selected trajectory
215  cdbg << "*** FINAL ***" << std::endl;
216  d.cur_branch = 0;
217  d.output_string = nullptr;
218  sch.base_task(ec)
219  .foreach_action([](Search::search& /*sch*/, size_t /*t*/, float /*min_cost*/, action /*a*/, bool /*taken*/,
220  float /*a_cost*/) -> void {})
221  .maybe_override_prediction([](Search::search& sch, size_t t, action& a, float& a_cost) -> bool {
223  path& path = d.final[d.cur_branch].first.second;
224  if ((t >= path.size()) || (path[t].first == (action)-1))
225  return false;
226  a = path[t].first;
227  a_cost = path[t].second;
228  return true;
229  })
230  .with_output_string([](Search::search& sch, std::stringstream& output) -> void {
232  if (d.kbest_out)
233  {
234  output.str("");
235  output << d.kbest_out->str();
236  }
237  })
238  .final_run()
239  .Run();
240 
241  // clean up memory
242  for (size_t i = 0; i < d.branches.size(); i++) d.branches[i].second.delete_v();
243  d.branches.clear();
244  for (size_t i = 0; i < d.final.size(); i++)
245  {
246  d.final[i].first.second.delete_v();
247  delete d.final[i].second;
248  }
249  d.final.clear();
250  delete d.kbest_out;
251  d.kbest_out = nullptr;
252 }
253 } // namespace SelectiveBranchingMT
#define cdbg
Definition: search.h:11
T * get_metatask_data()
Definition: search.h:101
std::pair< float, path > branch
Definition: search_meta.cc:54
std::pair< action, float > act_score
Definition: search_meta.cc:52
void copy_array(v_array< T > &dst, const v_array< T > &src)
Definition: v_array.h:185
void run(Search::search &sch, multi_ex &ec)
Definition: search_meta.cc:110
uint32_t action
Definition: search.h:19
v_array< act_score > path
Definition: search_meta.cc:53
virtual void add_and_parse(const option_group_definition &group)=0
void finish(vw &all, bool delete_all)
Definition: parse_args.cc:1823
T *& begin()
Definition: v_array.h:42
void set_metatask_data(T *data)
Definition: search.h:96
size_t size() const
Definition: v_array.h:68
BaseTask & foreach_action(void(*f)(search &, size_t, float, action, bool, float))
Definition: search.h:42
void push_back(const T &new_ele)
Definition: v_array.h:107
void clear()
Definition: v_array.h:88
vw * initialize(options_i &options, io_buf *model, bool skipModelLoad, trace_message_t trace_listener, void *trace_context)
Definition: parse_args.cc:1654
T *& end()
Definition: v_array.h:43
option_group_definition & add(T &&op)
Definition: options.h:90
std::vector< example * > multi_ex
Definition: example.h:122
constexpr uint64_t a
Definition: rand48.cc:11
v_array< std::pair< branch, std::string * > > final
Definition: search_meta.cc:66
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
BaseTask base_task(multi_ex &ec)
Definition: search.h:213
task_data(size_t mb, size_t kb)
Definition: search_meta.cc:72
Search::search_metatask metatask
Definition: search_meta.cc:50
void delete_v()
Definition: v_array.h:98
std::stringstream * kbest_out
Definition: search_meta.cc:71