Vowpal Wabbit
beam.h
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD
4 license as described in the file LICENSE.
5 */
6 #pragma once
7 
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include "v_array.h"
12 
13 // TODO: special case the version where beam_size == 1
14 // TODO: *maybe* special case the version where beam_size <= 10
15 
16 #define BEAM_CONSTANT_SIZE 0
17 
18 namespace Beam
19 {
// A single entry stored in a beam: a cost, a cached hash used to bucket
// potentially-equivalent entries, a pointer to the caller's payload, and a flag
// saying whether the entry is still live.
//
// The comparators in this file reinterpret any beam_element<T> as
// beam_element<void>, so the member order and types must not depend on T
// beyond the pointee type of `data`.
template <class T>
struct beam_element
{
  uint32_t hash;  // a cached hash value -- if a ~= b then h(a) must == h(b)
  float cost;     // cost of this element
  T *data;        // pointer to element data -- rarely accessed!
  bool active;    // is this currently active
  // bool recombined; // if we're not the BEST then we've been recombined
  // v_array<T*> * recomb_friends; // if we're the BEST (among ~= elements), then recomb_friends is everything that's
  // equivalent to us but worse... NOT USED if we're not doing k-best predictions
};
31 
32 inline int compare_on_cost(const void *void_a, const void *void_b)
33 {
34  if (void_a == void_b)
35  return 0;
36  const beam_element<void> *a = (const beam_element<void> *)void_a;
37  const beam_element<void> *b = (const beam_element<void> *)void_b;
38  if (a->active && !b->active)
39  return -1; // active things come before inactive things
40  else if (!a->active && b->active)
41  return 1;
42  else if (!a->active && !b->active)
43  return 0;
44  else if (a->cost < b->cost)
45  return -1; // otherwise sort by cost
46  else if (a->cost > b->cost)
47  return 1;
48  else
49  return 0;
50 }
51 
52 inline int compare_on_hash_then_cost(const void *void_a, const void *void_b)
53 {
54  if (void_a == void_b)
55  return 0;
56  const beam_element<void> *a = (const beam_element<void> *)void_a;
57  const beam_element<void> *b = (const beam_element<void> *)void_b;
58  if (a->active && !b->active)
59  return -1; // active things come before inactive things
60  else if (!a->active && b->active)
61  return 1;
62  else if (!a->active && !b->active)
63  return 0;
64  else if (a->hash < b->hash)
65  return -1; // if the hashes are different, sort by hash
66  else if (a->hash > b->hash)
67  return 1;
68  else if (a->cost < b->cost)
69  return -1; // otherwise sort by cost
70  else if (a->cost > b->cost)
71  return 1;
72  else
73  return 0;
74 }
75 
76 template <class T>
77 class beam
78 {
79  private:
80  size_t beam_size; // the beam size -- how many active elements can we have
81  size_t count; // how many elements do we have currently -- should be == to A.size()
82  float pruning_coefficient; // prune anything with cost >= pruning_coefficient * best, set to FLT_MAX to not do
83  // coefficient-based pruning
84  float worst_cost; // what is the cost of the worst (highest cost) item in the beam
85  float best_cost; // what is the cost of the best (lowest cost) item in the beam
86  float prune_if_gt; // prune any element with cost greater than this
87  T *best_cost_data; // easy access to best-cost item
88  bool do_kbest;
89  v_array<beam_element<T>> A; // the actual data
90  // v_array<v_array<beam_element<T>*>> recomb_buckets;
91 
92  // static size_t NUM_RECOMB_BUCKETS = 10231;
93 
94  bool (*is_equivalent)(T *, T *); // test if two items are equivalent; nullptr means don't do hypothesis recombination
95 
96  public:
97  beam(size_t beam_size, float prune_coeff = FLT_MAX, bool (*test_equiv)(T *, T *) = nullptr, bool kbest = false)
98  : beam_size(beam_size), pruning_coefficient(prune_coeff), do_kbest(kbest), is_equivalent(test_equiv)
99  {
100  count = 0;
101  worst_cost = -FLT_MAX;
102  best_cost = FLT_MAX;
103  prune_if_gt = FLT_MAX;
104  best_cost_data = nullptr;
105  A = v_init<beam_element<T>>();
106  if (beam_size <= BEAM_CONSTANT_SIZE)
107  A.resize(beam_size, true);
108  else
109  A.resize((beam_size + 1) * 4, true);
110  if (beam_size == 1)
111  do_kbest = false; // automatically turn of kbest
112  }
113 
114  inline bool might_insert(float cost) { return (cost <= prune_if_gt) && ((count < beam_size) || (cost < worst_cost)); }
115 
116  bool insert(T *data, float cost, uint32_t hash) // returns TRUE iff element was actually added
117  {
118  if (!might_insert(cost))
119  return false;
120 
121  // bool we_were_worse = false;
122  // if (is_equivalent) {
123  // size_t mod = recomb_buckets.size();
124  // size_t id = hash % mod;
125  // size_t equiv_pos = bucket_contains_equiv(recomb_buckets[i], data, hash);
126  // if (equiv_pos != (size_t) -1) { // we can recombing at equiv_pos
127  // if (cost >= recomb_buckets[i][equiv_pos].cost) {
128  // // we are more expensive, so ignore
129  // we_were_worse = true;
130  // beam_element<T> * be = new beam_element<T>;
131  // be->hash = hash; be->cost = cost; be->data = data; be->active = true; be->recombined = false;
132  // be->recomb_friends = nullptr; add_recomb_friend(recomb_buckets[i][equiv_pos], be);
133  // }
134  // }
135 
136  if (beam_size < BEAM_CONSTANT_SIZE)
137  { // find the worst item and directly replace it
138  size_t worst_idx = 0;
139  float worst_idx_cost = A[0].cost;
140  for (size_t i = 1; i < beam_size; i++)
141  if (A[i].cost > worst_idx_cost)
142  {
143  worst_idx = i;
144  worst_idx_cost = A[i].cost;
145  if (worst_idx_cost <= worst_cost)
146  break;
147  }
148  if (cost >= worst_idx_cost)
149  return false;
150 
151  A[worst_idx].hash = hash;
152  A[worst_idx].cost = cost;
153  A[worst_idx].data = data;
154  A[worst_idx].active = true;
155  // A[worst_idx].recombined = false;
156  // A[worst_idx].recomb_friends = nullptr; // TODO: free it if it isn't nullptr
157  worst_cost = cost;
158  }
159  else
160  {
161  beam_element<T> be;
162  be.hash = hash;
163  be.cost = cost;
164  be.data = data;
165  be.active = true;
166  // be.recombined = false;
167  // be.recomb_friends = nullptr;
168 
169  A.push_back(be);
170  count++;
171  }
172 
173  if (cost < best_cost)
174  {
175  best_cost = cost;
176  best_cost_data = data;
177  }
178  if (cost > worst_cost)
179  {
180  worst_cost = cost;
181  prune_if_gt = std::max(1.f, best_cost) * pruning_coefficient;
182  }
183  return true;
184  }
185 
187  {
188  if (count == 0)
189  return nullptr;
190  beam_element<T> *ret = A.begin;
191  while ((ret != A.end) && (!ret->active)) ++ret;
192  return (ret == A.end) ? nullptr : ret;
193  }
194 
196  {
197  if (count == 0)
198  return nullptr;
199 
200  beam_element<T> *ret = nullptr;
201  float next_best_cost = FLT_MAX;
202  for (beam_element<T> *el = A.begin; el != A.end; el++)
203  if ((ret == nullptr) && el->active && (el->cost <= best_cost))
204  ret = el;
205  else if (el->active && (el->cost < next_best_cost))
206  {
207  next_best_cost = el->cost;
208  best_cost_data = el->data;
209  }
210 
211  if (ret != nullptr)
212  {
213  best_cost = next_best_cost;
214  prune_if_gt = std::max(1.f, best_cost) * pruning_coefficient;
215  ret->active = false;
216  count--;
217  }
218  else
219  {
220  best_cost = FLT_MAX;
221  prune_if_gt = FLT_MAX;
222  best_cost_data = nullptr;
223  }
224 
225  return ret;
226  }
227 
229  {
230  qsort(A.begin, A.size(), sizeof(beam_element<T>), compare_on_hash_then_cost);
231  size_t start = 0;
232  while (start < A.size() - 1)
233  {
234  size_t end = start + 1;
235  for (; (end < A.size()) && (A[start].hash == A[end].hash); end++)
236  ;
237  assert(start < A.size());
238  assert(end <= A.size());
239  // std::cerr << "start=" << start << " end=" << end << std::endl;
240  // go over all pairs
241  for (size_t i = start; i < end; i++)
242  {
243  if (!A[i].active)
244  continue;
245  assert(i < A.size());
246  for (size_t j = i + 1; j < end; j++)
247  {
248  if (!A[j].active)
249  continue;
250  assert(j < A.size());
251  // std::cerr << "te " << i << "," << j << std::endl;
252  if (is_equivalent(A[i].data, A[j].data))
253  {
254  A[j].active = false; // TODO: if kbest is on, do recomb_friends
255  // std::cerr << "equivalent " << i << "," << j << ": " << ((size_t)A[i].data) << " and " <<
256  // ((size_t)A[j].data)
257  // << std::endl;
258  }
259  }
260  }
261  start = end;
262  }
263  }
264 
265  void compact(void (*free_data)(T *) = nullptr)
266  {
267  if (is_equivalent)
268  do_recombination();
269  qsort(A.begin, A.size(), sizeof(beam_element<T>), compare_on_cost); // TODO: quick select
270 
271  if (count <= beam_size)
272  return;
273 
274  count = beam_size;
275  if (is_equivalent) // we might be able to get rid of even more
276  while ((count > 1) && !A[count - 1].active) count--;
277 
278  if (free_data)
279  for (beam_element<T> *be = A.begin + count; be != A.end; ++be) free_data(be->data);
280 
281  A.end = A.begin + count;
282 
283  best_cost = A[0].cost;
284  worst_cost = A[count - 1].cost;
285  prune_if_gt = std::max(1.f, best_cost) * pruning_coefficient;
286  best_cost_data = A[0].data;
287  }
288 
289  void maybe_compact(void (*free_data)(T *) = nullptr)
290  {
291  if (count >= beam_size * 10)
292  compact(free_data);
293  }
294 
295  void erase(void (*free_data)(T *) = nullptr)
296  {
297  if (free_data)
298  for (beam_element<T> *be = A.begin; be != A.end; ++be) free_data(be->data);
299  A.erase();
300  count = 0;
301  worst_cost = -FLT_MAX;
302  best_cost = FLT_MAX;
303  prune_if_gt = FLT_MAX;
304  best_cost_data = nullptr;
305  }
306 
308  {
309  assert(A.size() == 0);
310  A.delete_v();
311  }
312 
313  beam_element<T> *begin() { return A.begin; }
314  beam_element<T> *end() { return A.end; }
315  size_t size() { return count; }
316  bool empty() { return A.empty(); }
317  size_t get_beam_size() { return beam_size; }
318 
319  private:
320  // void add_recomb_friend(beam_element<T> *better, beam_element<T> *worse) {
321  // assert( better->cost <= worse->cost );
322  // if (better->recomb_friends == nullptr) {
323  // if (worse->recomb_friends != nullptr) {
324  // better->recomb_friends = worse->recomb_friends;
325  // worse->recomb_friends = nullptr;
326  // } else
327  // better->recomb_friends = new std::vector<beam_element<T>*>;
328  // } else {
329  // assert(worse->recomb_friends == nullptr);
330  // }
331  // }
332 };
333 
334 } // namespace Beam
void compact(void(*free_data)(T *)=nullptr)
Definition: beam.h:265
void resize(size_t length)
Definition: v_array.h:69
~beam()
Definition: beam.h:307
uint32_t hash
Definition: beam.h:23
Definition: beam.h:18
void maybe_compact(void(*free_data)(T *)=nullptr)
Definition: beam.h:289
int compare_on_cost(const void *void_a, const void *void_b)
Definition: beam.h:32
#define BEAM_CONSTANT_SIZE
Definition: beam.h:16
v_array< beam_element< T > > A
Definition: beam.h:89
size_t beam_size
Definition: beam.h:80
float pruning_coefficient
Definition: beam.h:82
Definition: active.h:6
beam_element< T > * begin()
Definition: beam.h:313
size_t count
Definition: beam.h:81
bool do_kbest
Definition: beam.h:88
beam(size_t beam_size, float prune_coeff=FLT_MAX, bool(*test_equiv)(T *, T *)=nullptr, bool kbest=false)
Definition: beam.h:97
beam_element< T > * pop_best_item()
Definition: beam.h:195
void do_recombination()
Definition: beam.h:228
float prune_if_gt
Definition: beam.h:86
bool might_insert(float cost)
Definition: beam.h:114
float best_cost
Definition: beam.h:85
T *& begin()
Definition: v_array.h:42
size_t size() const
Definition: v_array.h:68
size_t size()
Definition: beam.h:315
void erase(void(*free_data)(T *)=nullptr)
Definition: beam.h:295
T * best_cost_data
Definition: beam.h:87
int compare_on_hash_then_cost(const void *void_a, const void *void_b)
Definition: beam.h:52
void push_back(const T &new_ele)
Definition: v_array.h:107
bool empty()
Definition: beam.h:316
T *& end()
Definition: v_array.h:43
constexpr uint64_t a
Definition: rand48.cc:11
bool empty() const
Definition: v_array.h:59
size_t get_beam_size()
Definition: beam.h:317
void delete_v()
Definition: v_array.h:98
beam_element< T > * end()
Definition: beam.h:314
float cost
Definition: beam.h:24
bool insert(T *data, float cost, uint32_t hash)
Definition: beam.h:116
beam_element< T > * get_best_item()
Definition: beam.h:186
float worst_cost
Definition: beam.h:84
float f
Definition: cache.cc:40