Vowpal Wabbit
Classes | Typedefs | Functions | Variables
SelectiveBranchingMT Namespace Reference

Classes

struct  task_data
 

Typedefs

typedef std::pair< action, float > act_score
 
typedef v_array< act_score > path
 
typedef std::pair< float, path > branch
 

Functions

void run (Search::search &sch, multi_ex &ec)
 
void initialize (Search::search &sch, size_t &num_actions, options_i &options)
 
void finish (Search::search &sch)
 
std::ostream & operator<< (std::ostream &os, const std::pair< unsigned int, float > &v)
 

Variables

Search::search_metatask metatask = {"selective_branching", run, initialize, finish, nullptr, nullptr}
 

Typedef Documentation

◆ act_score

typedef std::pair<action, float> SelectiveBranchingMT::act_score

Definition at line 52 of file search_meta.cc.

◆ branch

typedef std::pair<float, path> SelectiveBranchingMT::branch

Definition at line 54 of file search_meta.cc.

◆ path

typedef v_array<act_score> SelectiveBranchingMT::path

Definition at line 53 of file search_meta.cc.

Function Documentation

◆ finish()

void SelectiveBranchingMT::finish ( Search::search &  sch)

Definition at line 108 of file search_meta.cc.

References Search::search::get_metatask_data().

108 { delete sch.get_metatask_data<task_data>(); }
T * get_metatask_data()
Definition: search.h:101

◆ initialize()

void SelectiveBranchingMT::initialize ( Search::search &  sch,
size_t &  num_actions,
options_i &  options 
)

Definition at line 90 of file search_meta.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), VW::config::make_option(), and Search::search::set_metatask_data().

91 {
92  size_t max_branches = 2;
93  size_t kbest = 0;
94  option_group_definition new_options("selective branching options");
95  new_options
96  .add(make_option("search_max_branch", max_branches)
97  .default_value(2)
98  .help("maximum number of branches to consider"))
99  .add(make_option("search_kbest", kbest)
100  .default_value(0)
101  .help("number of best items to output (0=just like non-selectional-branching, default)"));
102  options.add_and_parse(new_options);
103 
104  task_data* d = new task_data(max_branches, kbest);
105  sch.set_metatask_data(d);
106 }
virtual void add_and_parse(const option_group_definition &group)=0
void set_metatask_data(T *data)
Definition: search.h:96
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80

◆ operator<<()

std::ostream& SelectiveBranchingMT::operator<< ( std::ostream &  os,
const std::pair< unsigned int, float > &  v 
)

Definition at line 56 of file search_meta.cc.

57 {
58  os << v.first << '_' << v.second;
59  return os;
60 }

◆ run()

void SelectiveBranchingMT::run ( Search::search &  sch,
multi_ex &  ec 
)

Definition at line 110 of file search_meta.cc.

References a, Search::search::base_task(), v_array< T >::begin(), SelectiveBranchingMT::task_data::branches, cdbg, v_array< T >::clear(), copy_array(), SelectiveBranchingMT::task_data::cur_branch, v_array< T >::delete_v(), v_array< T >::end(), SelectiveBranchingMT::task_data::final, Search::BaseTask::foreach_action(), Search::search::get_metatask_data(), SelectiveBranchingMT::task_data::kbest, SelectiveBranchingMT::task_data::kbest_out, SelectiveBranchingMT::task_data::max_branches, SelectiveBranchingMT::task_data::output_string, v_array< T >::push_back(), Search::BaseTask::Run(), v_array< T >::size(), SelectiveBranchingMT::task_data::total_cost, and SelectiveBranchingMT::task_data::trajectory.

111 {
112  task_data& d = *sch.get_metatask_data<task_data>();
113 
114  // generate an initial trajectory, but record possible branches
115  d.branches.clear();
116  d.final.clear();
117  d.trajectory.clear();
118  d.total_cost = 0.;
119  d.output_string = nullptr;
120 
121  cdbg << "*** INITIAL PASS ***" << std::endl;
122  sch.base_task(ec)
123  .foreach_action([](Search::search& sch, size_t t, float min_cost, action a, bool taken, float a_cost) -> void {
124  cdbg << "==DebugMT== foreach_action(t=" << t << ", min_cost=" << min_cost << ", a=" << a << ", taken=" << taken
125  << ", a_cost=" << a_cost << ")" << std::endl;
126  if (taken)
127  return; // ignore the taken action
128  task_data& d = *sch.get_metatask_data<task_data>();
129  float delta = a_cost - min_cost;
130  path branch = v_init<act_score>();
131  push_many<act_score>(branch, d.trajectory.begin(), d.trajectory.size());
132  branch.push_back(std::make_pair(a, a_cost));
133  d.branches.push_back(std::make_pair(delta, branch));
134  cdbg << "adding branch: " << delta << " -> " << branch << std::endl;
135  })
136  .post_prediction([](Search::search& sch, size_t /*t*/, action a, float a_cost) -> void {
137  task_data& d = *sch.get_metatask_data<task_data>();
138  d.trajectory.push_back(std::make_pair(a, a_cost));
139  d.total_cost += a_cost;
140  })
141  .with_output_string([](Search::search& sch, std::stringstream& output) -> void {
142  sch.get_metatask_data<task_data>()->output_string = new std::string(output.str());
143  })
144  .Run();
145 
146  // the last item the trajectory stack is complete and therefore not a branch
147  // if (! d.branches.empty())
148  // d.branches.pop().second.delete_v();
149 
150  {
151  // construct the final trajectory
152  path original_final = v_init<act_score>();
153  copy_array(original_final, d.trajectory);
154  d.final.push_back(std::make_pair(std::make_pair(d.total_cost, original_final), d.output_string));
155  }
156 
157  // sort the branches by cost
158  stable_sort(
159  d.branches.begin(), d.branches.end(), [](const branch& a, const branch& b) -> bool { return a.first < b.first; });
160 
161  // make new predictions
162  for (size_t i = 0; i < std::min(d.max_branches, d.branches.size()); i++)
163  {
164  d.cur_branch = i;
165  d.trajectory.clear();
166  d.total_cost = 0.;
167  d.output_string = nullptr;
168 
169  cdbg << "*** BRANCH " << i << " *** " << d.branches[i].first << " : " << d.branches[i].second << std::endl;
170  sch.base_task(ec)
171  .foreach_action([](Search::search& /*sch*/, size_t /*t*/, float /*min_cost*/, action /*a*/, bool /*taken*/,
172  float /*a_cost*/) -> void {})
173  .maybe_override_prediction([](Search::search& sch, size_t t, action& a, float& a_cost) -> bool {
174  task_data& d = *sch.get_metatask_data<task_data>();
175  path& path = d.branches[d.cur_branch].second;
176  if (t >= path.size())
177  return false;
178  a = path[t].first;
179  a_cost = path[t].second;
180  return true;
181  })
182  .post_prediction([](Search::search& sch, size_t /*t*/, action a, float a_cost) -> void {
183  task_data& d = *sch.get_metatask_data<task_data>();
184  d.trajectory.push_back(std::make_pair(a, a_cost));
185  d.total_cost += a_cost;
186  })
187  .with_output_string([](Search::search& sch, std::stringstream& output) -> void {
188  sch.get_metatask_data<task_data>()->output_string = new std::string(output.str());
189  })
190  .Run();
191 
192  {
193  // construct the final trajectory
194  path this_final = v_init<act_score>();
195  copy_array(this_final, d.trajectory);
196  d.final.push_back(std::make_pair(std::make_pair(d.total_cost, this_final), d.output_string));
197  }
198  }
199 
200  // sort the finals by cost
201  stable_sort(d.final.begin(), d.final.end(),
202  [](const std::pair<branch, std::string*>& a, const std::pair<branch, std::string*>& b) -> bool {
203  return a.first.first < b.first.first;
204  });
205 
206  d.kbest_out = nullptr;
207  if (d.output_string && (d.kbest > 0))
208  {
209  d.kbest_out = new std::stringstream();
210  for (size_t i = 0; i < std::min(d.final.size(), d.kbest); i++)
211  (*d.kbest_out) << *d.final[i].second << "\t" << d.final[i].first.first << std::endl;
212  }
213 
214  // run the final selected trajectory
215  cdbg << "*** FINAL ***" << std::endl;
216  d.cur_branch = 0;
217  d.output_string = nullptr;
218  sch.base_task(ec)
219  .foreach_action([](Search::search& /*sch*/, size_t /*t*/, float /*min_cost*/, action /*a*/, bool /*taken*/,
220  float /*a_cost*/) -> void {})
221  .maybe_override_prediction([](Search::search& sch, size_t t, action& a, float& a_cost) -> bool {
222  task_data& d = *sch.get_metatask_data<task_data>();
223  path& path = d.final[d.cur_branch].first.second;
224  if ((t >= path.size()) || (path[t].first == (action)-1))
225  return false;
226  a = path[t].first;
227  a_cost = path[t].second;
228  return true;
229  })
230  .with_output_string([](Search::search& sch, std::stringstream& output) -> void {
231  task_data& d = *sch.get_metatask_data<task_data>();
232  if (d.kbest_out)
233  {
234  output.str("");
235  output << d.kbest_out->str();
236  }
237  })
238  .final_run()
239  .Run();
240 
241  // clean up memory
242  for (size_t i = 0; i < d.branches.size(); i++) d.branches[i].second.delete_v();
243  d.branches.clear();
244  for (size_t i = 0; i < d.final.size(); i++)
245  {
246  d.final[i].first.second.delete_v();
247  delete d.final[i].second;
248  }
249  d.final.clear();
250  delete d.kbest_out;
251  d.kbest_out = nullptr;
252 }
#define cdbg
Definition: search.h:11
T * get_metatask_data()
Definition: search.h:101
std::pair< float, path > branch
Definition: search_meta.cc:54
void copy_array(v_array< T > &dst, const v_array< T > &src)
Definition: v_array.h:185
uint32_t action
Definition: search.h:19
v_array< act_score > path
Definition: search_meta.cc:53
BaseTask & foreach_action(void(*f)(search &, size_t, float, action, bool, float))
Definition: search.h:42
constexpr uint64_t a
Definition: rand48.cc:11
BaseTask base_task(multi_ex &ec)
Definition: search.h:213

Variable Documentation

◆ metatask

Search::search_metatask SelectiveBranchingMT::metatask = {"selective_branching", run, initialize, finish, nullptr, nullptr}

Definition at line 50 of file search_meta.cc.