Vowpal Wabbit
search_entityrelationtask.cc
Go to the documentation of this file.
1 /*
2  Copyright (c) by respective owners including Yahoo!, Microsoft, and
3  individual contributors. All rights reserved. Released under a BSD (revised)
4  license as described in the file LICENSE.
5 */
7 #include "vw.h"
8 
9 using namespace VW::config;
10 
// Label ids: entities use 1-4 and relations 5-10 (see the label lists built in initialize).
// R_NONE is the last relation label; LABEL_SKIP lies outside both ranges and is only offered
// as an extra allowed action while skipping is enabled.
11 #define R_NONE 10 // label for NONE relation
12 #define LABEL_SKIP 11 // label for SKIP
13 
15 {
// Search task descriptor for this file: the task name "entity_relation" followed by the
// run / initialize / finish callbacks defined in this file; the two trailing slots are left null.
16 Search::search_task task = {"entity_relation", run, initialize, finish, nullptr, nullptr};
17 }
18 
19 namespace EntityRelationTask
20 {
21 using namespace Search;
// Short alias for the cost-sensitive label namespace; CS::label / CS::wclass back the LDF scratch examples.
22 namespace CS = COST_SENSITIVE;
23 
// Forward declaration: defined at the end of this namespace and called by the LDF branches of
// predict_entity / predict_relation to offset the feature indices of each per-label scratch example.
24 void update_example_indicies(bool audit, example* ec, uint64_t mult_amount, uint64_t plus_amount);
25 
// Per-run state of the entity-relation task: allocated in initialize() and deleted in finish().
26 struct task_data
27 {
// Loss charged when an entity prediction differs from its gold label (--entity_cost, default 1).
29  float entity_cost;
// Loss charged whenever the policy answers LABEL_SKIP (--skip_cost, default 0.01).
31  float skip_cost;
// When true, predict_entity / predict_relation add LABEL_SKIP to the allowed actions;
// initialize() clears it and the skip decoder toggles it.
33  bool allow_skip;
// Decoding strategy chosen by --search_order (0: entity first, 1: mixed, 2: skip,
// 3: entity first with LDF); orders 3 and 4 also allocate LDF scratch examples in initialize().
36  size_t search_order;
39 };
40 
41 void initialize(Search::search& sch, size_t& /*num_actions*/, options_i& options)
42 {
43  task_data* my_task_data = new task_data();
44  sch.set_task_data<task_data>(my_task_data);
45 
46  option_group_definition new_options("Entity Relation Options");
47  new_options
48  .add(make_option("relation_cost", my_task_data->relation_cost).keep().default_value(1.f).help("Relation Cost"))
49  .add(make_option("entity_cost", my_task_data->entity_cost).keep().default_value(1.f).help("Entity Cost"))
50  .add(make_option("constraints", my_task_data->constraints).keep().help("Use Constraints"))
51  .add(make_option("relation_none_cost", my_task_data->relation_none_cost)
52  .keep()
53  .default_value(0.5f)
54  .help("None Relation Cost"))
55  .add(make_option("skip_cost", my_task_data->skip_cost)
56  .keep()
57  .default_value(0.01f)
58  .help("Skip Cost (only used when search_order = skip"))
59  .add(make_option("search_order", my_task_data->search_order)
60  .keep()
61  .default_value(0)
62  .help("Search Order 0: EntityFirst 1: Mix 2: Skip 3: EntityFirst(LDF)"));
63  options.add_and_parse(new_options);
64 
65  // setup entity and relation labels
66  // Entity label 1:E_Other 2:E_Peop 3:E_Org 4:E_Loc
67  // Relation label 5:R_Live_in 6:R_OrgBased_in 7:R_Located_in 8:R_Work_For 9:R_Kill 10:R_None
68  for (int i = 1; i < 5; i++) my_task_data->y_allowed_entity.push_back(i);
69 
70  for (int i = 5; i < 11; i++) my_task_data->y_allowed_relation.push_back(i);
71 
72  my_task_data->allow_skip = false;
73 
74  if (my_task_data->search_order != 3 && my_task_data->search_order != 4)
75  {
76  sch.set_options(0);
77  }
78  else
79  {
80  example* ldf_examples = VW::alloc_examples(sizeof(CS::label), 10);
81  CS::wclass default_wclass = {0., 0, 0., 0.};
82  for (size_t a = 0; a < 10; a++)
83  {
84  ldf_examples[a].l.cs.costs.push_back(default_wclass);
85  ldf_examples[a].interactions = &sch.get_vw_pointer_unsafe().interactions;
86  }
87  my_task_data->ldf_entity = ldf_examples;
88  my_task_data->ldf_relation = ldf_examples + 4;
90  }
91 
92  sch.set_num_learners(2);
93  if (my_task_data->search_order == 4)
94  sch.set_num_learners(3);
95 }
96 
98 {
99  task_data* my_task_data = sch.get_task_data<task_data>();
100  my_task_data->y_allowed_entity.delete_v();
101  my_task_data->y_allowed_relation.delete_v();
102  if (my_task_data->search_order == 3)
103  {
104  for (size_t a = 0; a < 10; a++) VW::dealloc_example(CS::cs_label.delete_label, my_task_data->ldf_entity[a]);
105  free(my_task_data->ldf_entity);
106  }
107  delete my_task_data;
108 } // if we had task data, we'd want to free it here
109 
// Return whether relation label `rel_id` is compatible with the entity labels of its two arguments.
// Label ids (see initialize): entities 1:E_Other 2:E_Peop 3:E_Org 4:E_Loc; relations
// 5:R_Live_in 6:R_OrgBased_in 7:R_Located_in 8:R_Work_For 9:R_Kill 10:R_None.
// R_None is always allowed; ids outside the relation range 5..10 are rejected instead of
// indexing past the tables.
bool check_constraints(size_t ent1_id, size_t ent2_id, size_t rel_id)
{
  constexpr size_t first_relation = 5;  // R_Live_in
  constexpr size_t none_relation = 10;  // R_None
  // Required (first, second) entity label for relations 5..9, in order:
  // Live_in (Peop, Loc), OrgBased_in (Org, Loc), Located_in (Loc, Loc), Work_For (Peop, Org), Kill (Peop, Peop).
  static constexpr size_t valid_ent1_id[] = {2, 3, 4, 2, 2};
  static constexpr size_t valid_ent2_id[] = {4, 4, 4, 3, 2};

  if (rel_id == none_relation)
    return true;
  if (rel_id < first_relation || rel_id > none_relation)
    return false;
  const size_t k = rel_id - first_relation;
  return valid_ent1_id[k] == ent1_id && valid_ent2_id[k] == ent2_id;
}
120 
121 void decode_tag(v_array<char> tag, char& type, int& id1, int& id2)
122 {
123  std::string s1;
124  std::string s2;
125  type = tag[0];
126  size_t idx = 2;
127  while (idx < tag.size() && tag[idx] != '_' && tag[idx] != '\0')
128  {
129  s1.push_back(tag[idx]);
130  idx++;
131  }
132  id1 = atoi(s1.c_str());
133  idx++;
134  assert(type == 'R');
135  while (idx < tag.size() && tag[idx] != '_' && tag[idx] != '\0')
136  {
137  s2.push_back(tag[idx]);
138  idx++;
139  }
140  id2 = atoi(s2.c_str());
141 }
142 
144  Search::search& sch, example* ex, v_array<size_t>& /*predictions*/, ptag my_tag, bool isLdf = false)
145 {
146  task_data* my_task_data = sch.get_task_data<task_data>();
147  size_t prediction;
148  if (my_task_data->allow_skip)
149  {
150  v_array<uint32_t> star_labels = v_init<uint32_t>();
151  star_labels.push_back(ex->l.multi.label);
152  star_labels.push_back(LABEL_SKIP);
153  my_task_data->y_allowed_entity.push_back(LABEL_SKIP);
154  prediction = Search::predictor(sch, my_tag)
155  .set_input(*ex)
156  .set_oracle(star_labels)
157  .set_allowed(my_task_data->y_allowed_entity)
158  .set_learner_id(1)
159  .predict();
160  my_task_data->y_allowed_entity.pop();
161  }
162  else
163  {
164  if (isLdf)
165  {
166  for (uint32_t a = 0; a < 4; a++)
167  {
168  VW::copy_example_data(false, &my_task_data->ldf_entity[a], ex);
169  update_example_indicies(true, &my_task_data->ldf_entity[a], 28904713, 4832917 * (uint64_t)(a + 1));
170  CS::label& lab = my_task_data->ldf_entity[a].l.cs;
171  lab.costs[0].x = 0.f;
172  lab.costs[0].class_index = a;
173  lab.costs[0].partial_prediction = 0.f;
174  lab.costs[0].wap_value = 0.f;
175  }
176  prediction = Search::predictor(sch, my_tag)
177  .set_input(my_task_data->ldf_entity, 4)
178  .set_oracle(ex->l.multi.label - 1)
179  .set_learner_id(1)
180  .predict() +
181  1;
182  }
183  else
184  {
185  prediction = Search::predictor(sch, my_tag)
186  .set_input(*ex)
187  .set_oracle(ex->l.multi.label)
188  .set_allowed(my_task_data->y_allowed_entity)
189  .set_learner_id(0)
190  .predict();
191  }
192  }
193 
194  // record loss
195  float loss = 0.0;
196  if (prediction == LABEL_SKIP)
197  {
198  loss = my_task_data->skip_cost;
199  }
200  else if (prediction != ex->l.multi.label)
201  loss = my_task_data->entity_cost;
202  sch.loss(loss);
203  return prediction;
204 }
205 size_t predict_relation(Search::search& sch, example* ex, v_array<size_t>& predictions, ptag my_tag, bool isLdf = false)
206 {
207  char type;
208  int id1, id2;
209  task_data* my_task_data = sch.get_task_data<task_data>();
210  size_t hist[2];
211  decode_tag(ex->tag, type, id1, id2);
212  v_array<uint32_t> constrained_relation_labels = v_init<uint32_t>();
213  if (my_task_data->constraints && predictions[id1] != 0 && predictions[id2] != 0)
214  {
215  hist[0] = predictions[id1];
216  hist[1] = predictions[id2];
217  }
218  else
219  {
220  hist[0] = 0;
221  hist[1] = 0;
222  }
223  for (size_t j = 0; j < my_task_data->y_allowed_relation.size(); j++)
224  {
225  if (!my_task_data->constraints || hist[0] == (size_t)0 ||
226  check_constraints(hist[0], hist[1], my_task_data->y_allowed_relation[j]))
227  constrained_relation_labels.push_back(my_task_data->y_allowed_relation[j]);
228  }
229 
230  size_t prediction;
231  if (my_task_data->allow_skip)
232  {
233  v_array<uint32_t> star_labels = v_init<uint32_t>();
234  star_labels.push_back(ex->l.multi.label);
235  star_labels.push_back(LABEL_SKIP);
236  constrained_relation_labels.push_back(LABEL_SKIP);
237  prediction = Search::predictor(sch, my_tag)
238  .set_input(*ex)
239  .set_oracle(star_labels)
240  .set_allowed(constrained_relation_labels)
241  .set_learner_id(2)
242  .add_condition(id1, 'a')
243  .add_condition(id2, 'b')
244  .predict();
245  constrained_relation_labels.pop();
246  }
247  else
248  {
249  if (isLdf)
250  {
251  int correct_label = 0; // if correct label is not in the set, use the first one
252  for (size_t a = 0; a < constrained_relation_labels.size(); a++)
253  {
254  VW::copy_example_data(false, &my_task_data->ldf_relation[a], ex);
256  true, &my_task_data->ldf_relation[a], 28904713, 4832917 * (uint64_t)(constrained_relation_labels[a]));
257  CS::label& lab = my_task_data->ldf_relation[a].l.cs;
258  lab.costs[0].x = 0.f;
259  lab.costs[0].class_index = constrained_relation_labels[a];
260  lab.costs[0].partial_prediction = 0.f;
261  lab.costs[0].wap_value = 0.f;
262  if (constrained_relation_labels[a] == ex->l.multi.label)
263  {
264  correct_label = (int)a;
265  }
266  }
267  size_t pred_pos = Search::predictor(sch, my_tag)
268  .set_input(my_task_data->ldf_relation, constrained_relation_labels.size())
269  .set_oracle(correct_label)
270  .set_learner_id(2)
271  .predict();
272  prediction = constrained_relation_labels[pred_pos];
273  }
274  else
275  {
276  prediction = Search::predictor(sch, my_tag)
277  .set_input(*ex)
278  .set_oracle(ex->l.multi.label)
279  .set_allowed(constrained_relation_labels)
280  .set_learner_id(1)
281  .predict();
282  }
283  }
284 
285  float loss = 0.0;
286  if (prediction == LABEL_SKIP)
287  {
288  loss = my_task_data->skip_cost;
289  }
290  else if (prediction != ex->l.multi.label)
291  {
292  if (ex->l.multi.label == R_NONE)
293  {
294  loss = my_task_data->relation_none_cost;
295  }
296  else
297  {
298  loss = my_task_data->relation_cost;
299  }
300  }
301  sch.loss(loss);
302  constrained_relation_labels.delete_v();
303  return prediction;
304 }
305 
306 void entity_first_decoding(Search::search& sch, multi_ex& ec, v_array<size_t>& predictions, bool isLdf = false)
307 {
308  // ec.size = #entity + #entity*(#entity-1)/2
309  size_t n_ent = (size_t)(std::sqrt(ec.size() * 8 + 1) - 1) / 2;
310  // Do entity recognition first
311  for (size_t i = 0; i < ec.size(); i++)
312  {
313  if (i < n_ent)
314  predictions[i] = predict_entity(sch, ec[i], predictions, (ptag)i, isLdf);
315  else
316  predictions[i] = predict_relation(sch, ec[i], predictions, (ptag)i, isLdf);
317  }
318 }
319 
321 {
322  // ec.size = #entity + #entity*(#entity-1)/2
323  uint32_t n_ent = (uint32_t)((std::sqrt(ec.size() * 8 + 1) - 1) / 2);
324  for (uint32_t t = 0; t < ec.size(); t++)
325  {
326  // Do entity recognition first
327  uint32_t count = 0;
328  for (ptag i = 0; i < n_ent; i++)
329  {
330  if (count == t)
331  {
332  predictions[i] = predict_entity(sch, ec[i], predictions, i);
333  break;
334  }
335  count++;
336  for (uint32_t j = 0; j < i; j++)
337  {
338  if (count == t)
339  {
340  ptag rel_index = (ptag)(n_ent + (2 * n_ent - j - 1) * j / 2 + i - j - 1);
341  predictions[rel_index] = predict_relation(sch, ec[rel_index], predictions, rel_index);
342  break;
343  }
344  count++;
345  }
346  }
347  }
348 }
349 
351 {
352  task_data* my_task_data = sch.get_task_data<task_data>();
353  // ec.size = #entity + #entity*(#entity-1)/2
354  size_t n_ent = (size_t)(std::sqrt(ec.size() * 8 + 1) - 1) / 2;
355 
356  bool must_predict = false;
357  size_t n_predicts = 0;
358  size_t p_n_predicts = 0;
359  my_task_data->allow_skip = true;
360 
361  // loop until all the entity and relation types are predicted
362  for (ptag t = 0;; t++)
363  {
364  ptag i = t % (uint32_t)ec.size();
365  if (n_predicts == ec.size())
366  break;
367 
368  if (predictions[i] == 0)
369  {
370  if (must_predict)
371  {
372  my_task_data->allow_skip = false;
373  }
374  size_t prediction = 0;
375  if (i < n_ent) // do entity recognition
376  {
377  prediction = predict_entity(sch, ec[i], predictions, i);
378  }
379  else // do relation recognition
380  {
381  prediction = predict_relation(sch, ec[i], predictions, i);
382  }
383 
384  if (prediction != LABEL_SKIP)
385  {
386  predictions[i] = prediction;
387  n_predicts++;
388  }
389 
390  if (must_predict)
391  {
392  my_task_data->allow_skip = true;
393  must_predict = false;
394  }
395  }
396 
397  if (i == ec.size() - 1)
398  {
399  if (n_predicts == p_n_predicts)
400  {
401  must_predict = true;
402  }
403  p_n_predicts = n_predicts;
404  }
405  }
406 }
407 
408 void run(Search::search& sch, multi_ex& ec)
409 {
410  task_data* my_task_data = sch.get_task_data<task_data>();
411 
412  v_array<size_t> predictions = v_init<size_t>();
413  for (size_t i = 0; i < ec.size(); i++)
414  {
415  predictions.push_back(0);
416  }
417 
418  switch (my_task_data->search_order)
419  {
420  case 0:
421  entity_first_decoding(sch, ec, predictions, false);
422  break;
423  case 1:
424  er_mixed_decoding(sch, ec, predictions);
425  break;
426  case 2:
427  er_allow_skip_decoding(sch, ec, predictions);
428  break;
429  case 3:
430  entity_first_decoding(sch, ec, predictions, true); // LDF = true
431  break;
432  default:
433  std::cerr << "search order " << my_task_data->search_order << "is undefined." << std::endl;
434  }
435 
436  for (size_t i = 0; i < ec.size(); i++)
437  {
438  if (sch.output().good())
439  sch.output() << predictions[i] << ' ';
440  }
441  predictions.delete_v();
442 }
443 // this is totally bogus for the example -- you'd never actually do this!
444 void update_example_indicies(bool /* audit */, example* ec, uint64_t mult_amount, uint64_t plus_amount)
445 {
446  for (features& fs : *ec)
447  for (feature_index& idx : fs.indicies) idx = ((idx * mult_amount) + plus_amount);
448 }
449 } // namespace EntityRelationTask
predictor & set_oracle(action a)
Definition: search.cc:3314
v_array< char > tag
Definition: example.h:63
#define R_NONE
T pop()
Definition: v_array.h:58
void entity_first_decoding(Search::search &sch, multi_ex &ec, v_array< size_t > &predictions, bool isLdf=false)
label_parser cs_label
std::stringstream & output()
Definition: search.cc:3043
void copy_example_data(bool audit, example *dst, example *src)
Definition: example.cc:72
void update_example_indicies(bool audit, example *ec, uint64_t mult_amount, uint64_t plus_amount)
Definition: search.cc:33
predictor & add_condition(ptag tag, char name)
Definition: search.cc:3414
std::vector< std::string > * interactions
void dealloc_example(void(*delete_label)(void *), example &ec, void(*delete_prediction)(void *))
Definition: example.cc:219
the core definition of a set of features.
void delete_label(void *v)
Definition: cb.cc:98
virtual void add_and_parse(const option_group_definition &group)=0
action predict()
Definition: search.cc:3460
void finish(vw &all, bool delete_all)
Definition: parse_args.cc:1823
float loss(cbify &data, uint32_t label, uint32_t final_prediction)
Definition: cbify.cc:60
size_t size() const
Definition: v_array.h:68
T * get_task_data()
Definition: search.h:89
example * alloc_examples(size_t, size_t count=1)
Definition: example.cc:204
MULTICLASS::label_t multi
Definition: example.h:29
void run(Search::search &sch, multi_ex &ec)
vw & get_vw_pointer_unsafe()
Definition: search.cc:3115
void push_back(const T &new_ele)
Definition: v_array.h:107
COST_SENSITIVE::label cs
Definition: example.h:30
#define LABEL_SKIP
void er_mixed_decoding(Search::search &sch, multi_ex &ec, v_array< size_t > &predictions)
size_t predict_entity(Search::search &sch, example *ex, v_array< size_t > &, ptag my_tag, bool isLdf=false)
uint32_t IS_LDF
Definition: search.cc:49
void decode_tag(v_array< char > tag, char &type, int &id1, int &id2)
uint64_t feature_index
Definition: feature_group.h:21
vw * initialize(options_i &options, io_buf *model, bool skipModelLoad, trace_message_t trace_listener, void *trace_context)
Definition: parse_args.cc:1654
void set_options(uint32_t opts)
Definition: search.cc:3053
predictor & set_learner_id(size_t id)
Definition: search.cc:3448
predictor & set_allowed(action a)
Definition: search.cc:3352
size_t predict_relation(Search::search &sch, example *ex, v_array< size_t > &predictions, ptag my_tag, bool isLdf=false)
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
std::vector< example * > multi_ex
Definition: example.h:122
polylabel l
Definition: example.h:57
constexpr uint64_t a
Definition: rand48.cc:11
void set_task_data(T *data)
Definition: search.h:84
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
std::vector< std::string > interactions
Definition: global_data.h:457
void delete_v()
Definition: v_array.h:98
uint32_t ptag
Definition: search.h:20
v_array< wclass > costs
bool check_constraints(size_t ent1_id, size_t ent2_id, size_t rel_id)
void set_num_learners(size_t num_learners)
Definition: search.cc:3094
predictor & set_input(example &input_example)
Definition: search.cc:3173
float f
Definition: cache.cc:40
void er_allow_skip_decoding(Search::search &sch, multi_ex &ec, v_array< size_t > &predictions)
void loss(float incr_loss)
Definition: search.cc:3039