Vowpal Wabbit
search_dep_parser.cc
Go to the documentation of this file.
1 /*
2  Copyright (c) by respective owners including Yahoo!, Microsoft, and
3  individual contributors. All rights reserved. Released under a BSD (revised)
4  license as described in the file LICENSE.
5 */
6 #include "search_dep_parser.h"
7 #include "gd.h"
8 #include "cost_sensitive.h"
9 #include "label_dictionary.h" // for add_example_namespaces_from_example
10 #include "vw.h"
11 #include "vw_exception.h"
12 
13 using namespace VW::config;
14 
// Namespace that receives the valency/distance features built in extract_features().
// 100 is the character 'd', which is why the interaction strings in initialize() use 'd'.
#define val_namespace 100
// Per-slot hash offset: slot i of the feature template adds i * offset_const to its indices.
#define offset_const 344429
// Accepted values of --transition_system.
#define arc_hybrid 1
#define arc_eager 2
19 
20 namespace DepParserTask
21 {
22 Search::search_task task = {"dep_parser", run, initialize, finish, setup, nullptr};
23 }
24 
25 struct task_data
26 {
28  size_t root_label;
29  uint32_t num_label;
30  v_array<uint32_t> valid_actions, action_loss, gold_heads, gold_tags, stack, heads, tags, temp, valid_action_temp;
31  v_array<action> gold_actions, gold_action_temp;
33  v_array<uint32_t> children[6]; // [0]:num_left_arcs, [1]:num_right_arcs; [2]: leftmost_arc, [3]: second_leftmost_arc,
34  // [4]:rightmost_arc, [5]: second_rightmost_arc
35  example *ec_buf[13];
37  bool cost_to_go, one_learner;
39 };
40 
41 namespace DepParserTask
42 {
43 using namespace Search;
44 
45 constexpr action SHIFT = 1;
46 constexpr action REDUCE_RIGHT = 2;
47 constexpr action REDUCE_LEFT = 3;
48 constexpr action REDUCE = 4;
49 constexpr uint32_t my_null = 9999999; /*representing_default*/
50 
51 void initialize(Search::search &sch, size_t & /*num_actions*/, options_i &options)
52 {
53  vw &all = sch.get_vw_pointer_unsafe();
54  task_data *data = new task_data();
55  data->action_loss.resize(5);
56  data->ex = NULL;
57  sch.set_task_data<task_data>(data);
58 
59  option_group_definition new_options("Dependency Parser Options");
60  new_options.add(make_option("root_label", data->root_label)
61  .keep()
62  .default_value(8)
63  .help("Ensure that there is only one root in each sentence"));
64  new_options.add(make_option("num_label", data->num_label).keep().default_value(12).help("Number of arc labels"));
65  new_options.add(make_option("transition_system", data->transition_system)
66  .keep()
67  .default_value(1)
68  .help("1: arc-hybrid 2: arc-eager"));
69  new_options.add(make_option("one_learner", data->one_learner)
70  .keep()
71  .help("Using one learner instead of three learners for labeled parser"));
72  new_options.add(make_option("cost_to_go", data->cost_to_go)
73  .keep()
74  .help("Estimating cost-to-go matrix based on dynamic oracle rathan than rolling-out"));
75  new_options.add(
76  make_option("old_style_labels", data->old_style_labels).keep().help("Use old hack of label information"));
77  options.add_and_parse(new_options);
78 
79  data->ex = VW::alloc_examples(sizeof(polylabel), 1);
81  for (size_t i = 1; i < 14; i++) data->ex->indices.push_back((unsigned char)i + 'A');
84 
85  if (data->one_learner)
86  sch.set_num_learners(1);
87  else
88  sch.set_num_learners(3);
89 
90  const char *pair[] = {
91  "BC", "BE", "BB", "CC", "DD", "EE", "FF", "GG", "EF", "BH", "BJ", "EL", "dB", "dC", "dD", "dE", "dF", "dG", "dd"};
92  const char *triple[] = {"EFG", "BEF", "BCE", "BCD", "BEL", "ELM", "BHI", "BCC", "BEJ", "BEH", "BJK", "BEN"};
93  std::vector<std::string> newpairs(pair, pair + 19);
94  std::vector<std::string> newtriples(triple, triple + 12);
95  all.pairs.swap(newpairs);
96  all.triples.swap(newtriples);
97 
98  all.interactions.clear();
99  all.interactions.insert(std::end(all.interactions), std::begin(all.pairs), std::end(all.pairs));
100  all.interactions.insert(std::end(all.interactions), std::begin(all.triples), std::end(all.triples));
101  if (data->cost_to_go)
103  else
105 
106  sch.set_label_parser(COST_SENSITIVE::cs_label, [](polylabel &l) -> bool { return l.cs.costs.size() == 0; });
107 }
108 
110 {
111  task_data *data = sch.get_task_data<task_data>();
112  data->valid_actions.delete_v();
113  data->valid_action_temp.delete_v();
114  data->gold_heads.delete_v();
115  data->gold_tags.delete_v();
116  data->stack.delete_v();
117  data->heads.delete_v();
118  data->tags.delete_v();
119  data->temp.delete_v();
120  data->action_loss.delete_v();
121  data->gold_actions.delete_v();
123  data->gold_action_temp.delete_v();
125  free(data->ex);
126  for (size_t i = 0; i < 6; i++) data->children[i].delete_v();
127  delete data;
128 }
129 
130 void inline add_feature(
131  example &ex, uint64_t idx, unsigned char ns, uint64_t mask, uint64_t multiplier, bool /* audit */ = false)
132 {
133  ex.feature_space[(int)ns].push_back(1.0f, (idx * multiplier) & mask);
134 }
135 
136 void add_all_features(example &ex, example &src, unsigned char tgt_ns, uint64_t mask, uint64_t multiplier,
137  uint64_t offset, bool /* audit */ = false)
138 {
139  features &tgt_fs = ex.feature_space[tgt_ns];
140  for (namespace_index ns : src.indices)
141  if (ns != constant_namespace) // ignore constant_namespace
142  for (feature_index i : src.feature_space[ns].indicies)
143  tgt_fs.push_back(1.0f, ((i / multiplier + offset) * multiplier) & mask);
144 }
145 
146 void inline reset_ex(example *ex)
147 {
148  ex->num_features = 0;
149  ex->total_sum_feat_sq = 0;
150  for (features &fs : *ex) fs.clear();
151 }
152 
153 // arc-hybrid System.
154 size_t transition_hybrid(Search::search &sch, uint64_t a_id, uint32_t idx, uint32_t t_id, uint32_t /* n */)
155 {
156  task_data *data = sch.get_task_data<task_data>();
157  v_array<uint32_t> &heads = data->heads, &stack = data->stack, &gold_heads = data->gold_heads,
158  &gold_tags = data->gold_tags, &tags = data->tags;
160  if (a_id == SHIFT)
161  {
162  stack.push_back(idx);
163  return idx + 1;
164  }
165  else if (a_id == REDUCE_RIGHT)
166  {
167  uint32_t last = stack.last();
168  uint32_t hd = stack[stack.size() - 2];
169  heads[last] = hd;
170  children[5][hd] = children[4][hd];
171  children[4][hd] = last;
172  children[1][hd]++;
173  tags[last] = t_id;
174  sch.loss(gold_heads[last] != heads[last] ? 2 : (gold_tags[last] != t_id) ? 1.f : 0.f);
175  assert(!stack.empty());
176  stack.pop();
177  return idx;
178  }
179  else if (a_id == REDUCE_LEFT)
180  {
181  size_t last = stack.last();
182  uint32_t hd = idx;
183  heads[last] = hd;
184  children[3][hd] = children[2][hd];
185  children[2][hd] = (uint32_t)last;
186  children[0][hd]++;
187  tags[last] = t_id;
188  sch.loss(gold_heads[last] != heads[last] ? 2 : (gold_tags[last] != t_id) ? 1.f : 0.f);
189  assert(!stack.empty());
190  stack.pop();
191  return idx;
192  }
193  THROW("transition_hybrid failed");
194 }
195 
196 // arc-eager system
197 size_t transition_eager(Search::search &sch, uint64_t a_id, uint32_t idx, uint32_t t_id, uint32_t n)
198 {
199  task_data *data = sch.get_task_data<task_data>();
200  v_array<uint32_t> &heads = data->heads, &stack = data->stack, &gold_heads = data->gold_heads,
201  &gold_tags = data->gold_tags, &tags = data->tags;
203  if (a_id == SHIFT)
204  {
205  stack.push_back(idx);
206  return idx + 1;
207  }
208  else if (a_id == REDUCE_RIGHT)
209  {
210  uint32_t hd = stack.last();
211  stack.push_back(idx);
212  uint32_t last = idx;
213  heads[last] = hd;
214  children[5][hd] = children[4][hd];
215  children[4][hd] = last;
216  children[1][hd]++;
217  tags[last] = t_id;
218  sch.loss(gold_heads[last] != heads[last] ? 2 : (gold_tags[last] != t_id) ? 1.f : 0.f);
219  return idx + 1;
220  }
221  else if (a_id == REDUCE_LEFT)
222  {
223  size_t last = stack.last();
224  uint32_t hd = (idx > n) ? 0 : idx;
225  heads[last] = hd;
226  children[3][hd] = children[2][hd];
227  children[2][hd] = (uint32_t)last;
228  children[0][hd]++;
229  tags[last] = t_id;
230  sch.loss(gold_heads[last] != heads[last] ? 2 : (gold_tags[last] != t_id) ? 1.f : 0.f);
231  assert(!stack.empty());
232  stack.pop();
233  return idx;
234  }
235  else if (a_id == REDUCE)
236  {
237  assert(!stack.empty());
238  stack.pop();
239  return idx;
240  }
241  THROW("transition_eager failed");
242 }
243 
244 void extract_features(Search::search &sch, uint32_t idx, multi_ex &ec)
245 {
246  vw &all = sch.get_vw_pointer_unsafe();
247  task_data *data = sch.get_task_data<task_data>();
248  reset_ex(data->ex);
249  uint64_t mask = sch.get_mask();
250  uint64_t multiplier = (uint64_t)all.wpp << all.weights.stride_shift();
251 
252  v_array<uint32_t> &stack = data->stack, &tags = data->tags, *children = data->children, &temp = data->temp;
253  example **ec_buf = data->ec_buf;
254  example &ex = *(data->ex);
255 
256  size_t n = ec.size();
257  bool empty = stack.empty();
258  size_t last = empty ? 0 : stack.last();
259 
260  for (size_t i = 0; i < 13; i++) ec_buf[i] = nullptr;
261 
262  // feature based on the top three examples in stack ec_buf[0]: s1, ec_buf[1]: s2, ec_buf[2]: s3
263  for (size_t i = 0; i < 3; i++)
264  ec_buf[i] = (stack.size() > i && *(stack.end() - (i + 1)) != 0) ? ec[*(stack.end() - (i + 1)) - 1] : 0;
265 
266  // features based on examples in string buffer ec_buf[3]: b1, ec_buf[4]: b2, ec_buf[5]: b3
267  for (size_t i = 3; i < 6; i++) ec_buf[i] = (idx + (i - 3) - 1 < n) ? ec[idx + i - 3 - 1] : 0;
268 
269  // features based on the leftmost and the rightmost children of the top element stack ec_buf[6]: sl1, ec_buf[7]: sl2,
270  // ec_buf[8]: sr1, ec_buf[9]: sr2;
271  for (size_t i = 6; i < 10; i++)
272  if (!empty && last != 0 && children[i - 4][last] != 0)
273  ec_buf[i] = ec[children[i - 4][last] - 1];
274 
275  // features based on leftmost children of the top element in bufer ec_buf[10]: bl1, ec_buf[11]: bl2
276  for (size_t i = 10; i < 12; i++)
277  ec_buf[i] = (idx <= n && children[i - 8][idx] != 0) ? ec[children[i - 8][idx] - 1] : 0;
278  ec_buf[12] = (stack.size() > 1 && *(stack.end() - 2) != 0 && children[2][*(stack.end() - 2)] != 0)
279  ? ec[children[2][*(stack.end() - 2)] - 1]
280  : 0;
281 
282  // unigram features
283  for (size_t i = 0; i < 13; i++)
284  {
285  uint64_t additional_offset = (uint64_t)(i * offset_const);
286  if (!ec_buf[i])
287  add_feature(ex, (uint64_t)438129041 + additional_offset, (unsigned char)((i + 1) + 'A'), mask, multiplier);
288  else
289  add_all_features(ex, *ec_buf[i], 'A' + (unsigned char)(i + 1), mask, multiplier, additional_offset, false);
290  }
291 
292  // Other features
293  temp.resize(10);
294  temp[0] = empty ? 0 : (idx > n ? 1 : 2 + std::min(static_cast<uint32_t>(5), idx - (uint32_t)last));
295  temp[1] = empty ? 1 : 1 + std::min(static_cast<uint32_t>(5), children[0][last]);
296  temp[2] = empty ? 1 : 1 + std::min(static_cast<uint32_t>(5), children[1][last]);
297  temp[3] = idx > n ? 1 : 1 + std::min(static_cast<uint32_t>(5), children[0][idx]);
298  for (size_t i = 4; i < 8; i++) temp[i] = (!empty && children[i - 2][last] != 0) ? tags[children[i - 2][last]] : 15;
299  for (size_t i = 8; i < 10; i++) temp[i] = (idx <= n && children[i - 6][idx] != 0) ? tags[children[i - 6][idx]] : 15;
300 
301  uint64_t additional_offset = val_namespace * offset_const;
302  for (size_t j = 0; j < 10; j++)
303  {
304  additional_offset += j * 1023;
305  add_feature(ex, temp[j] + additional_offset, val_namespace, mask, multiplier);
306  }
307  size_t count = 0;
308  for (features fs : *data->ex)
309  {
310  fs.sum_feat_sq = (float)fs.size();
311  count += fs.size();
312  }
313 
314  size_t new_count;
315  float new_weight;
316  INTERACTIONS::eval_count_of_generated_ft(all, *data->ex, new_count, new_weight);
317 
318  data->ex->num_features = count + new_count;
319  data->ex->total_sum_feat_sq = (float)count + new_weight;
320 }
321 
322 void get_valid_actions(Search::search &sch, v_array<uint32_t> &valid_action, uint64_t idx, uint64_t n,
323  uint64_t stack_depth, uint64_t state)
324 {
325  task_data *data = sch.get_task_data<task_data>();
326  uint32_t &sys = data->transition_system;
327  v_array<uint32_t> &stack = data->stack, &heads = data->heads, &temp = data->temp;
328  valid_action.clear();
329  if (sys == arc_hybrid)
330  {
331  if (idx <= n) // SHIFT
332  valid_action.push_back(SHIFT);
333  if (stack_depth >= 2) // RIGHT
334  valid_action.push_back(REDUCE_RIGHT);
335  if (stack_depth >= 1 && state != 0 && idx <= n) // LEFT
336  valid_action.push_back(REDUCE_LEFT);
337  }
338  else if (sys == arc_eager) // assume root is in N+1
339  {
340  temp.clear();
341  for (size_t i = 0; i <= 4; i++) temp.push_back(1);
342  if (idx > n)
343  {
344  temp[SHIFT] = 0;
345  temp[REDUCE_RIGHT] = 0;
346  }
347 
348  if (stack_depth == 0)
349  temp[REDUCE] = 0;
350  else if (idx <= n + 1 && heads[stack.last()] == my_null)
351  temp[REDUCE] = 0;
352 
353  if (stack_depth == 0)
354  {
355  temp[REDUCE_LEFT] = 0;
356  temp[REDUCE_RIGHT] = 0;
357  }
358  else
359  {
360  if (heads[stack.last()] != my_null)
361  temp[REDUCE_LEFT] = 0;
362  if (idx <= n && heads[idx] != my_null)
363  temp[REDUCE_RIGHT] = 0;
364  }
365  for (uint32_t i = 1; i <= 4; i++)
366  {
367  if (temp[i])
368  valid_action.push_back(i);
369  }
370  }
371 }
372 
373 bool is_valid(uint64_t action, v_array<uint32_t> valid_actions)
374 {
375  for (size_t i = 0; i < valid_actions.size(); i++)
376  if (valid_actions[i] == action)
377  return true;
378  return false;
379 }
380 
381 void get_eager_action_cost(Search::search &sch, uint32_t idx, uint64_t n)
382 {
383  task_data *data = sch.get_task_data<task_data>();
384  v_array<uint32_t> &action_loss = data->action_loss, &stack = data->stack, &gold_heads = data->gold_heads,
385  heads = data->heads;
386  size_t size = stack.size();
387  size_t last = (size == 0) ? 0 : stack.last();
388  for (size_t i = 1; i <= 4; i++) action_loss[i] = 0;
389  if (!stack.empty())
390  for (size_t i = 0; i < size; i++)
391  {
392  if (gold_heads[stack[i]] == idx && heads[stack[i]] == my_null)
393  {
394  action_loss[SHIFT] += 1;
395  action_loss[REDUCE_RIGHT] += 1;
396  }
397  if (idx <= n && (gold_heads[idx] == stack[i]))
398  {
399  if (stack[i] != 0)
400  action_loss[SHIFT] += 1;
401  if (stack[i] != last)
402  action_loss[REDUCE_RIGHT] += 1;
403  }
404  }
405  for (size_t i = idx; i <= n + 1; i++)
406  {
407  if (i <= n && gold_heads[i] == last)
408  {
409  action_loss[REDUCE] += 1;
410  action_loss[REDUCE_LEFT] += 1;
411  }
412  if (i != idx && gold_heads[last] == i)
413  action_loss[REDUCE_LEFT] += 1;
414  }
415  // if(size>0 && idx <=n && gold_heads[last] == 0 && stack[0] ==0) //should not fire
416  // action_loss[REDUCE_LEFT] +=1;
417 
418  if (gold_heads[idx] > idx || (gold_heads[idx] == 0 && size > 0 && stack[0] != 0))
419  action_loss[REDUCE_RIGHT] += 1;
420 }
421 
422 void get_hybrid_action_cost(Search::search &sch, size_t idx, uint64_t n)
423 {
424  task_data *data = sch.get_task_data<task_data>();
425  v_array<uint32_t> &action_loss = data->action_loss, &stack = data->stack, &gold_heads = data->gold_heads;
426  size_t size = stack.size();
427  size_t last = (size == 0) ? 0 : stack.last();
428 
429  for (size_t i = 1; i <= 3; i++) action_loss[i] = 0;
430  if (!stack.empty())
431  for (size_t i = 0; i < size - 1; i++)
432  if (idx <= n && (gold_heads[stack[i]] == idx || gold_heads[idx] == stack[i]))
433  action_loss[SHIFT] += 1;
434 
435  if (size > 0 && gold_heads[last] == idx)
436  action_loss[SHIFT] += 1;
437 
438  for (size_t i = idx + 1; i <= n; i++)
439  if (gold_heads[i] == last || gold_heads[last] == i)
440  action_loss[REDUCE_LEFT] += 1;
441  if (size > 0 && idx <= n && gold_heads[idx] == last)
442  action_loss[REDUCE_LEFT] += 1;
443  if (size >= 2 && gold_heads[last] == stack[size - 2])
444  action_loss[REDUCE_LEFT] += 1;
445 
446  if (gold_heads[last] >= idx)
447  action_loss[REDUCE_RIGHT] += 1;
448 
449  for (size_t i = idx; i <= n; i++)
450  if (gold_heads[i] == (uint32_t)last)
451  action_loss[REDUCE_RIGHT] += 1;
452 }
453 
454 void get_cost_to_go_losses(Search::search &sch, v_array<std::pair<action, float>> &gold_action_losses,
455  uint32_t left_label, uint32_t right_label)
456 {
457  task_data *data = sch.get_task_data<task_data>();
458  bool &one_learner = data->one_learner;
459  uint32_t &sys = data->transition_system;
460  v_array<uint32_t> &action_loss = data->action_loss, &valid_actions = data->valid_actions;
461  uint32_t &num_label = data->num_label;
462  gold_action_losses.clear();
463 
464  if (one_learner)
465  {
466  if (is_valid(SHIFT, valid_actions))
467  gold_action_losses.push_back(std::make_pair(SHIFT, (float)action_loss[SHIFT]));
468  for (uint32_t i = 2; i <= 3; i++)
469  if (is_valid(i, valid_actions))
470  {
471  for (uint32_t j = 1; j <= num_label; j++)
472  if (sys == arc_eager || j != data->root_label)
473  gold_action_losses.push_back(std::make_pair((1 + j + (i - 2) * num_label),
474  action_loss[i] + (float)(j != (i == REDUCE_LEFT ? left_label : right_label))));
475  }
476  if (sys == arc_eager && is_valid(REDUCE, valid_actions))
477  gold_action_losses.push_back(std::make_pair(2 + num_label * 2, (float)action_loss[REDUCE]));
478  }
479  else
480  {
481  for (action i = 1; i <= 3; i++)
482  if (is_valid(i, valid_actions))
483  gold_action_losses.push_back(std::make_pair(i, (float)action_loss[i]));
484  if (sys == arc_eager && is_valid(REDUCE, valid_actions))
485  gold_action_losses.push_back(std::make_pair(REDUCE, (float)action_loss[REDUCE]));
486  }
487 }
488 
489 void get_gold_actions(Search::search &sch, uint32_t idx, uint64_t /* n */, v_array<action> &gold_actions)
490 {
491  task_data *data = sch.get_task_data<task_data>();
492  v_array<uint32_t> &action_loss = data->action_loss, &stack = data->stack, &gold_heads = data->gold_heads,
493  &valid_actions = data->valid_actions;
494  gold_actions.clear();
495  size_t size = stack.size();
496  size_t last = (size == 0) ? 0 : stack.last();
497  uint32_t &sys = data->transition_system;
498 
499  if (sys == arc_hybrid && is_valid(SHIFT, valid_actions) && (stack.empty() || gold_heads[idx] == last))
500  {
501  gold_actions.push_back(SHIFT);
502  return;
503  }
504 
505  if (sys == arc_hybrid && is_valid(REDUCE_LEFT, valid_actions) && gold_heads[last] == idx)
506  {
507  gold_actions.push_back(REDUCE_LEFT);
508  return;
509  }
510  size_t best_action = 1;
511  size_t count = 0;
512  for (uint32_t i = 1; i <= 4; i++)
513  {
514  if (i == 4 && sys == arc_hybrid)
515  continue;
516  if (action_loss[i] < action_loss[best_action] && is_valid(i, valid_actions))
517  {
518  best_action = i;
519  count = 1;
520  gold_actions.clear();
521  gold_actions.push_back(i);
522  }
523  else if (action_loss[i] == action_loss[best_action] && is_valid(i, valid_actions))
524  {
525  count++;
526  gold_actions.push_back(i);
527  }
528  }
529 }
530 
532  uint32_t left_label, uint32_t right_label)
533 {
534  task_data *data = sch.get_task_data<task_data>();
535  uint32_t &sys = data->transition_system;
536  uint32_t &num_label = data->num_label;
537  actions_onelearner.clear();
538  if (is_valid(SHIFT, actions))
539  actions_onelearner.push_back(SHIFT);
540  if (sys == arc_eager && is_valid(REDUCE, actions))
541  actions_onelearner.push_back(2 + 2 * num_label);
542  if (left_label != my_null && is_valid(REDUCE_RIGHT, actions))
543  actions_onelearner.push_back(1 + right_label);
544  if (left_label != my_null && is_valid(REDUCE_LEFT, actions))
545  actions_onelearner.push_back(1 + left_label + num_label);
546  if (left_label == my_null && is_valid(REDUCE_RIGHT, actions))
547  for (uint32_t i = 0; i < num_label; i++)
548  if (i != data->root_label - 1)
549  actions_onelearner.push_back(i + 2);
550  if (left_label == my_null && is_valid(REDUCE_LEFT, actions))
551  for (uint32_t i = 0; i < num_label; i++)
552  if (sys == arc_eager || i != data->root_label - 1)
553  actions_onelearner.push_back((uint32_t)(i + 2 + num_label));
554 }
555 
557 {
558  task_data *data = sch.get_task_data<task_data>();
559  v_array<uint32_t> &gold_heads = data->gold_heads, &heads = data->heads, &gold_tags = data->gold_tags,
560  &tags = data->tags;
561  size_t n = ec.size();
562  heads.resize(n + 1);
563  tags.resize(n + 1);
564  gold_heads.clear();
565  gold_heads.push_back(0);
566  gold_tags.clear();
567  gold_tags.push_back(0);
568  for (size_t i = 0; i < n; i++)
569  {
570  v_array<COST_SENSITIVE::wclass> &costs = ec[i]->l.cs.costs;
571  uint32_t head, tag;
572  if (data->old_style_labels)
573  {
574  uint32_t label = costs[0].class_index;
575  head = (label & 255) - 1;
576  tag = label >> 8;
577  }
578  else
579  {
580  head = (costs.size() == 0) ? 0 : costs[0].class_index;
581  tag = (costs.size() <= 1) ? (uint32_t)data->root_label : costs[1].class_index;
582  }
583  if (tag > data->num_label)
584  THROW("invalid label " << tag << " which is > num actions=" << data->num_label);
585 
586  gold_heads.push_back(head);
587  gold_tags.push_back(tag);
588  heads[i + 1] = my_null;
589  tags[i + 1] = my_null;
590  }
591  for (size_t i = 0; i < 6; i++) data->children[i].resize(n + (size_t)1);
592 }
593 
594 void run(Search::search &sch, multi_ex &ec)
595 {
596  task_data *data = sch.get_task_data<task_data>();
597  v_array<uint32_t> &stack = data->stack, &gold_heads = data->gold_heads, &valid_actions = data->valid_actions,
598  &heads = data->heads, &gold_tags = data->gold_tags, &tags = data->tags,
599  &valid_action_temp = data->valid_action_temp;
600  v_array<uint32_t> &gold_action_temp = data->gold_action_temp;
601  v_array<std::pair<action, float>> &gold_action_losses = data->gold_action_losses;
602  v_array<action> &gold_actions = data->gold_actions;
603  bool &cost_to_go = data->cost_to_go, &one_learner = data->one_learner;
604  uint32_t &num_label = data->num_label;
605  uint32_t &sys = data->transition_system;
606  uint32_t n = (uint32_t)ec.size();
607  uint32_t left_label, right_label;
608  stack.clear();
609  stack.push_back((data->root_label == 0 && sys == arc_hybrid) ? 0 : 1);
610  for (size_t i = 0; i < 6; i++)
611  for (size_t j = 0; j < n + 1; j++) data->children[i][j] = 0;
612  for (size_t i = 0; i < n; i++)
613  {
614  heads[i + 1] = my_null;
615  tags[i + 1] = my_null;
616  }
617  ptag count = 1;
618  uint32_t idx = ((data->root_label == 0 && sys == arc_hybrid) ? 1 : 2);
619  Search::predictor P(sch, (ptag)0);
620  while (true)
621  {
622  if (sys == arc_hybrid && stack.size() <= 1 && idx > n)
623  break;
624  else if (sys == arc_eager && stack.size() == 0 && idx > n)
625  break;
626  bool computedFeatures = false;
627  if (sch.predictNeedsExample())
628  {
629  extract_features(sch, idx, ec);
630  computedFeatures = true;
631  }
632  get_valid_actions(sch, valid_actions, idx, n, (uint64_t)stack.size(), stack.empty() ? 0 : stack.last());
633  if (sys == arc_hybrid)
634  get_hybrid_action_cost(sch, idx, n);
635  else if (sys == arc_eager)
636  get_eager_action_cost(sch, idx, n);
637 
638  // get gold tag labels
639  left_label = stack.empty() ? my_null : gold_tags[stack.last()];
640  if (sys == arc_hybrid)
641  right_label = stack.empty() ? my_null : gold_tags[stack.last()];
642  else if (sys == arc_eager)
643  right_label = idx <= n ? gold_tags[idx] : (uint32_t)data->root_label;
644  else
645  THROW("unknown transition system");
646 
647  uint32_t a_id = 0, t_id = 0;
648  if (one_learner)
649  {
650  if (cost_to_go)
651  {
652  get_cost_to_go_losses(sch, gold_action_losses, left_label, right_label);
653  a_id = P.set_tag((ptag)count)
654  .set_input(*(data->ex))
655  .set_allowed(gold_action_losses)
656  .set_condition_range(count - 1, sch.get_history_length(), 'p')
657  .set_learner_id(0)
658  .predict();
659  }
660  else
661  {
662  get_gold_actions(sch, idx, n, gold_actions);
663  convert_to_onelearner_actions(sch, gold_actions, gold_action_temp, left_label, right_label);
664  convert_to_onelearner_actions(sch, valid_actions, valid_action_temp, my_null, my_null);
665  a_id = P.set_tag((ptag)count)
666  .set_input(*(data->ex))
667  .set_oracle(gold_action_temp)
668  .set_allowed(valid_action_temp)
669  .set_condition_range(count - 1, sch.get_history_length(), 'p')
670  .set_learner_id(0)
671  .predict();
672  }
673  if (a_id == SHIFT)
674  t_id = 0;
675  else if (a_id == 2 * num_label + 2)
676  {
677  t_id = 0;
678  a_id = REDUCE;
679  }
680  else if (a_id > 1 && a_id - 1 <= num_label)
681  {
682  t_id = a_id - 1;
683  a_id = REDUCE_RIGHT;
684  }
685  else
686  {
687  t_id = (uint64_t)a_id - num_label - 1;
688  a_id = REDUCE_LEFT;
689  }
690  }
691  else
692  {
693  if (cost_to_go)
694  {
695  get_cost_to_go_losses(sch, gold_action_losses, left_label, right_label);
696  a_id = P.set_tag((ptag)count)
697  .set_input(*(data->ex))
698  .set_allowed(gold_action_losses)
699  .set_condition_range(count - 1, sch.get_history_length(), 'p')
700  .set_learner_id(0)
701  .predict();
702  }
703  else
704  {
705  get_gold_actions(sch, idx, n, gold_actions);
706  a_id = P.set_tag((ptag)count)
707  .set_input(*(data->ex))
708  .set_oracle(gold_actions)
709  .set_allowed(valid_actions)
710  .set_condition_range(count - 1, sch.get_history_length(), 'p')
711  .set_learner_id(0)
712  .predict();
713  }
714 
715  // Predict the next action {SHIFT, REDUCE_LEFT, REDUCE_RIGHT}
716  count++;
717 
718  if (a_id != SHIFT && a_id != REDUCE)
719  {
720  if ((!computedFeatures) && sch.predictNeedsExample())
721  extract_features(sch, idx, ec);
722 
723  if (cost_to_go)
724  {
725  gold_action_losses.clear();
726  for (size_t i = 1; i <= data->num_label; i++)
727  gold_action_losses.push_back(
728  std::make_pair((action)i, i != (a_id == REDUCE_LEFT ? left_label : right_label)));
729  t_id = P.set_tag((ptag)count)
730  .set_input(*(data->ex))
731  .set_allowed(gold_action_losses)
732  .set_condition_range(count - 1, sch.get_history_length(), 'p')
733  .set_learner_id(a_id - 1)
734  .predict();
735  }
736  else
737  {
738  t_id = P.set_tag((ptag)count)
739  .set_input(*(data->ex))
740  .set_oracle(a_id == REDUCE_LEFT ? left_label : right_label)
741  .erase_alloweds()
742  .set_condition_range(count - 1, sch.get_history_length(), 'p')
743  .set_learner_id(a_id - 1)
744  .predict();
745  }
746  }
747  }
748  count++;
749  if (sys == arc_hybrid)
750  idx = (uint32_t)transition_hybrid(sch, a_id, idx, t_id, n);
751  else if (sys == arc_eager)
752  idx = (uint32_t)transition_eager(sch, a_id, idx, t_id, n);
753  }
754  if (sys == arc_hybrid)
755  {
756  heads[stack.last()] = 0;
757  tags[stack.last()] = (uint32_t)data->root_label;
758  sch.loss((gold_heads[stack.last()] != heads[stack.last()]));
759  }
760  if (sch.output().good())
761  for (size_t i = 1; i <= n; i++) sch.output() << (heads[i]) << ":" << tags[i] << std::endl;
762 }
763 } // namespace DepParserTask
predictor & set_oracle(action a)
Definition: search.cc:3314
void resize(size_t length)
Definition: v_array.h:69
v_array< namespace_index > indices
uint32_t get_history_length()
Definition: search.cc:3098
parameters weights
Definition: global_data.h:537
Search::search_task task
v_array< uint32_t > heads
vw * setup(options_i &options)
Definition: main.cc:27
std::vector< std::string > pairs
Definition: global_data.h:459
void push_back(feature_value v, feature_index i)
void add_all_features(example &ex, example &src, unsigned char tgt_ns, uint64_t mask, uint64_t multiplier, uint64_t offset, bool=false)
void set_label_parser(label_parser &lp, bool(*is_test)(polylabel &))
Definition: search.cc:3079
label_parser cs_label
std::stringstream & output()
Definition: search.cc:3043
v_array< std::pair< action, float > > gold_action_losses
void get_gold_actions(Search::search &sch, uint32_t idx, uint64_t, v_array< action > &gold_actions)
constexpr action SHIFT
Definition: search.cc:33
void eval_count_of_generated_ft(vw &all, example &ec, size_t &new_features_cnt, float &new_features_value)
std::vector< std::string > * interactions
void dealloc_example(void(*delete_label)(void *), example &ec, void(*delete_prediction)(void *))
Definition: example.cc:219
uint32_t num_label
the core definition of a set of features.
uint32_t transition_system
void delete_label(void *v)
Definition: cb.cc:98
constexpr action REDUCE_LEFT
uint32_t action
Definition: search.h:19
size_t transition_hybrid(Search::search &sch, uint64_t a_id, uint32_t idx, uint32_t t_id, uint32_t)
virtual void add_and_parse(const option_group_definition &group)=0
action predict()
Definition: search.cc:3460
void finish(vw &all, bool delete_all)
Definition: parse_args.cc:1823
uint64_t get_mask()
Definition: search.cc:3096
void get_valid_actions(Search::search &sch, v_array< uint32_t > &valid_action, uint64_t idx, uint64_t n, uint64_t stack_depth, uint64_t state)
v_array< uint32_t > action_loss
size_t size() const
Definition: v_array.h:68
T * get_task_data()
Definition: search.h:89
uint32_t ACTION_COSTS
Definition: search.cc:50
#define offset_const
example * alloc_examples(size_t, size_t count=1)
Definition: example.cc:204
void reset_ex(example *ex)
uint32_t NO_CACHING
Definition: search.cc:49
v_array< uint32_t > children[6]
std::array< features, NUM_NAMESPACES > feature_space
uint32_t AUTO_CONDITION_FEATURES
Definition: search.cc:49
void add_feature(example &ex, uint64_t idx, unsigned char ns, uint64_t mask, uint64_t multiplier, bool=false)
size_t size() const
constexpr action REDUCE
vw & get_vw_pointer_unsafe()
Definition: search.cc:3115
void push_back(const T &new_ele)
Definition: v_array.h:107
COST_SENSITIVE::label cs
Definition: example.h:30
predictor & erase_alloweds()
Definition: search.cc:3330
v_array< uint32_t > gold_tags
#define val_namespace
void get_cost_to_go_losses(Search::search &sch, v_array< std::pair< action, float >> &gold_action_losses, uint32_t left_label, uint32_t right_label)
void clear()
Definition: v_array.h:88
v_array< uint32_t > valid_action_temp
#define arc_hybrid
size_t num_features
Definition: example.h:67
unsigned char namespace_index
void extract_features(Search::search &sch, uint32_t idx, multi_ex &ec)
uint64_t feature_index
Definition: feature_group.h:21
vw * initialize(options_i &options, io_buf *model, bool skipModelLoad, trace_message_t trace_listener, void *trace_context)
Definition: parse_args.cc:1654
v_array< uint32_t > stack
v_array< action > gold_actions
void set_options(uint32_t opts)
Definition: search.cc:3053
size_t transition_eager(Search::search &sch, uint64_t a_id, uint32_t idx, uint32_t t_id, uint32_t n)
constexpr uint32_t my_null
bool is_valid(uint64_t action, v_array< uint32_t > valid_actions)
predictor & set_allowed(action a)
Definition: search.cc:3352
uint32_t wpp
Definition: global_data.h:432
std::vector< std::string > triples
Definition: global_data.h:461
T *& end()
Definition: v_array.h:43
example * ex
std::vector< example * > multi_ex
Definition: example.h:122
void set_task_data(T *data)
Definition: search.h:84
predictor & set_condition_range(ptag hi, ptag count, char name0)
Definition: search.cc:3441
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
float total_sum_feat_sq
Definition: example.h:71
float sum_feat_sq
v_array< uint32_t > temp
bool empty() const
Definition: v_array.h:59
v_array< uint32_t > valid_actions
#define arc_eager
predictor & set_tag(ptag tag)
Definition: search.cc:3454
std::vector< std::string > interactions
Definition: global_data.h:457
uint32_t stride_shift()
constexpr action REDUCE_RIGHT
v_array< uint32_t > gold_heads
void get_eager_action_cost(Search::search &sch, uint32_t idx, uint64_t n)
T last() const
Definition: v_array.h:57
v_array< uint32_t > tags
void delete_v()
Definition: v_array.h:98
uint32_t ptag
Definition: search.h:20
void get_hybrid_action_cost(Search::search &sch, size_t idx, uint64_t n)
bool predictNeedsExample()
Definition: search.cc:3041
v_array< wclass > costs
constexpr unsigned char constant_namespace
Definition: constant.h:22
bool children(log_multi &b, uint32_t &current, uint32_t &class_index, uint32_t label)
Definition: log_multi.cc:178
#define THROW(args)
Definition: vw_exception.h:181
void set_num_learners(size_t num_learners)
Definition: search.cc:3094
void convert_to_onelearner_actions(Search::search &sch, v_array< action > &actions, v_array< action > &actions_onelearner, uint32_t left_label, uint32_t right_label)
predictor & set_input(example &input_example)
Definition: search.cc:3173
float f
Definition: cache.cc:40
v_array< action > gold_action_temp
void loss(float incr_loss)
Definition: search.cc:3039
example * ec_buf[13]
void run(Search::search &sch, multi_ex &ec)