Vowpal Wabbit
lrqfa.cc
Go to the documentation of this file.
1 #include <string>
2 #include "reductions.h"
3 #include "rand48.h"
4 #include "parse_args.h" // for spoof_hex_encoded_namespaces
5 
6 using namespace LEARNER;
7 using namespace VW::config;
8 
9 struct LRQFAstate
10 {
11  vw* all;
12  std::string field_name;
13  int k;
14  int field_id[256];
15  size_t orig_size[256];
16 };
17 
18 inline float cheesyrand(uint64_t x)
19 {
20  uint64_t seed = x;
21 
22  return merand48(seed);
23 }
24 
25 constexpr inline bool example_is_test(example& ec) { return ec.l.simple.label == FLT_MAX; }
26 
27 template <bool is_learn>
29 {
30  vw& all = *lrq.all;
31 
32  memset(lrq.orig_size, 0, sizeof(lrq.orig_size));
33  for (namespace_index i : ec.indices) lrq.orig_size[i] = ec.feature_space[i].size();
34 
35  size_t which = ec.example_counter;
36  float first_prediction = 0;
37  float first_loss = 0;
38  unsigned int maxiter = (is_learn && !example_is_test(ec)) ? 2 : 1;
39  unsigned int k = lrq.k;
40  float sqrtk = (float)std::sqrt(k);
41 
42  uint32_t stride_shift = lrq.all->weights.stride_shift();
43  uint64_t weight_mask = lrq.all->weights.mask();
44  for (unsigned int iter = 0; iter < maxiter; ++iter, ++which)
45  {
46  // Add left LRQ features, holding right LRQ features fixed
47  // and vice versa
48 
49  for (std::string::const_iterator i1 = lrq.field_name.begin(); i1 != lrq.field_name.end(); ++i1)
50  {
51  for (std::string::const_iterator i2 = i1 + 1; i2 != lrq.field_name.end(); ++i2)
52  {
53  unsigned char left = (which % 2) ? *i1 : *i2;
54  unsigned char right = ((which + 1) % 2) ? *i1 : *i2;
55  unsigned int lfd_id = lrq.field_id[left];
56  unsigned int rfd_id = lrq.field_id[right];
57  for (unsigned int lfn = 0; lfn < lrq.orig_size[left]; ++lfn)
58  {
59  features& fs = ec.feature_space[left];
60  float lfx = fs.values[lfn];
61  uint64_t lindex = fs.indicies[lfn];
62  for (unsigned int n = 1; n <= k; ++n)
63  {
64  uint64_t lwindex =
65  (lindex + ((uint64_t)(rfd_id * k + n) << stride_shift)); // a feature has k weights in each field
66  float* lw = &all.weights[lwindex & weight_mask];
67  // perturb away from saddle point at (0, 0)
68  if (is_learn && !example_is_test(ec) && *lw == 0)
69  *lw = cheesyrand(lwindex) * 0.5f / sqrtk;
70 
71  for (unsigned int rfn = 0; rfn < lrq.orig_size[right]; ++rfn)
72  {
73  features& rfs = ec.feature_space[right];
74  // feature* rf = ec.atomics[right].begin + rfn;
75  // NB: ec.ft_offset added by base learner
76  float rfx = rfs.values[rfn];
77  uint64_t rindex = rfs.indicies[rfn];
78  uint64_t rwindex = (rindex + ((uint64_t)(lfd_id * k + n) << stride_shift));
79 
80  rfs.push_back(*lw * lfx * rfx, rwindex);
81  if (all.audit || all.hash_inv)
82  {
83  std::stringstream new_feature_buffer;
84  new_feature_buffer << right << '^' << rfs.space_names[rfn].get()->second << '^' << n;
85 #ifdef _WIN32
86  char* new_space = _strdup("lrqfa");
87  char* new_feature = _strdup(new_feature_buffer.str().c_str());
88 #else
89  char* new_space = strdup("lrqfa");
90  char* new_feature = strdup(new_feature_buffer.str().c_str());
91 #endif
92  rfs.space_names.push_back(audit_strings_ptr(new audit_strings(new_space, new_feature)));
93  }
94  }
95  }
96  }
97  }
98  }
99 
100  if (is_learn)
101  base.learn(ec);
102  else
103  base.predict(ec);
104 
105  // Restore example
106  if (iter == 0)
107  {
108  first_prediction = ec.pred.scalar;
109  first_loss = ec.loss;
110  }
111  else
112  {
113  ec.pred.scalar = first_prediction;
114  ec.loss = first_loss;
115  }
116 
117  for (char i : lrq.field_name)
118  {
119  namespace_index right = i;
120  features& rfs = ec.feature_space[right];
121  rfs.values.end() = rfs.values.begin() + lrq.orig_size[right];
122 
123  if (all.audit || all.hash_inv)
124  {
125  for (size_t j = lrq.orig_size[right]; j < rfs.space_names.size(); ++j) rfs.space_names[j].~audit_strings_ptr();
126 
127  rfs.space_names.end() = rfs.space_names.begin() + lrq.orig_size[right];
128  }
129  }
130  }
131 }
132 
134 {
135  std::string lrqfa;
136  option_group_definition new_options("Low Rank Quadratics FA");
137  new_options.add(make_option("lrqfa", lrqfa).keep().help("use low rank quadratic features with field aware weights"));
138  options.add_and_parse(new_options);
139 
140  if (!options.was_supplied("lrqfa"))
141  return nullptr;
142 
143  auto lrq = scoped_calloc_or_throw<LRQFAstate>();
144  lrq->all = &all;
145 
146  std::string lrqopt = spoof_hex_encoded_namespaces(lrqfa);
147  size_t last_index = lrqopt.find_last_not_of("0123456789");
148  new (&lrq->field_name) std::string(lrqopt.substr(0, last_index + 1)); // make sure there is no duplicates
149  lrq->k = atoi(lrqopt.substr(last_index + 1).c_str());
150 
151  int fd_id = 0;
152  for (char i : lrq->field_name) lrq->field_id[(int)i] = fd_id++;
153 
154  all.wpp = all.wpp * (uint64_t)(1 + lrq->k);
155  learner<LRQFAstate, example>& l = init_learner(lrq, as_singleline(setup_base(options, all)), predict_or_learn<true>,
156  predict_or_learn<false>, 1 + lrq->field_name.size() * lrq->k);
157 
158  return make_base(l);
159 }
v_array< namespace_index > indices
size_t example_counter
Definition: example.h:64
parameters weights
Definition: global_data.h:537
void predict(E &ec, size_t i=0)
Definition: learner.h:169
uint64_t stride_shift(const stagewise_poly &poly, uint64_t idx)
void push_back(feature_value v, feature_index i)
float scalar
Definition: example.h:45
std::shared_ptr< audit_strings > audit_strings_ptr
Definition: feature_group.h:23
bool hash_inv
Definition: global_data.h:541
v_array< feature_index > indicies
vw * all
Definition: lrqfa.cc:11
void predict_or_learn(LRQFAstate &lrq, single_learner &base, example &ec)
Definition: lrqfa.cc:28
LEARNER::base_learner * lrqfa_setup(options_i &options, vw &all)
Definition: lrqfa.cc:133
the core definition of a set of features.
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
constexpr bool example_is_test(example &ec)
Definition: lrqfa.cc:25
v_array< feature_value > values
virtual void add_and_parse(const option_group_definition &group)=0
std::string spoof_hex_encoded_namespaces(const std::string &arg)
Definition: parse_args.cc:568
float label
Definition: simple_label.h:14
float merand48(uint64_t &initial)
Definition: rand48.cc:16
label_data simple
Definition: example.h:28
T *& begin()
Definition: v_array.h:42
int field_id[256]
Definition: lrqfa.cc:14
size_t size() const
Definition: v_array.h:68
size_t orig_size[256]
Definition: lrqfa.cc:15
std::array< features, NUM_NAMESPACES > feature_space
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void push_back(const T &new_ele)
Definition: v_array.h:107
virtual bool was_supplied(const std::string &key)=0
unsigned char namespace_index
uint32_t wpp
Definition: global_data.h:432
T *& end()
Definition: v_array.h:43
float loss
Definition: example.h:70
option_group_definition & add(T &&op)
Definition: options.h:90
std::string field_name
Definition: lrqfa.cc:12
v_array< audit_strings_ptr > space_names
polylabel l
Definition: example.h:57
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
uint32_t stride_shift()
float cheesyrand(uint64_t x)
Definition: lrqfa.cc:18
bool audit
Definition: global_data.h:486
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
uint64_t mask()
int k
Definition: lrqfa.cc:13
std::pair< std::string, std::string > audit_strings
Definition: feature_group.h:22