Vowpal Wabbit
Classes | Functions
lrqfa.cc File Reference
#include <string>
#include "reductions.h"
#include "rand48.h"
#include "parse_args.h"

Go to the source code of this file.

Classes

struct  LRQFAstate
 

Functions

float cheesyrand (uint64_t x)
 
constexpr bool example_is_test (example &ec)
 
template<bool is_learn>
void predict_or_learn (LRQFAstate &lrq, single_learner &base, example &ec)
 
LEARNER::base_learner * lrqfa_setup (options_i &options, vw &all)
 

Function Documentation

◆ cheesyrand()

float cheesyrand ( uint64_t x)
inline

Definition at line 18 of file lrqfa.cc.

References merand48().

Referenced by predict_or_learn().

19 {
20  uint64_t seed = x;
21 
22  return merand48(seed);
23 }
float merand48(uint64_t &initial)
Definition: rand48.cc:16

◆ example_is_test()

constexpr bool example_is_test ( example & ec)
inline

Definition at line 25 of file lrqfa.cc.

References example::l, label_data::label, and polylabel::simple.

Referenced by predict_or_learn().

25 { return ec.l.simple.label == FLT_MAX; }
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
polylabel l
Definition: example.h:57

◆ lrqfa_setup()

LEARNER::base_learner* lrqfa_setup ( options_i & options,
vw & all 
)

Definition at line 133 of file lrqfa.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), LEARNER::init_learner(), LEARNER::make_base(), VW::config::make_option(), setup_base(), spoof_hex_encoded_namespaces(), VW::config::options_i::was_supplied(), and vw::wpp.

Referenced by parse_reductions().

134 {
135  std::string lrqfa;
136  option_group_definition new_options("Low Rank Quadratics FA");
137  new_options.add(make_option("lrqfa", lrqfa).keep().help("use low rank quadratic features with field aware weights"));
138  options.add_and_parse(new_options);
139 
140  if (!options.was_supplied("lrqfa"))
141  return nullptr;
142 
143  auto lrq = scoped_calloc_or_throw<LRQFAstate>();
144  lrq->all = &all;
145 
146  std::string lrqopt = spoof_hex_encoded_namespaces(lrqfa);
147  size_t last_index = lrqopt.find_last_not_of("0123456789");
148  new (&lrq->field_name) std::string(lrqopt.substr(0, last_index + 1)); // make sure there is no duplicates
149  lrq->k = atoi(lrqopt.substr(last_index + 1).c_str());
150 
151  int fd_id = 0;
152  for (char i : lrq->field_name) lrq->field_id[(int)i] = fd_id++;
153 
154  all.wpp = all.wpp * (uint64_t)(1 + lrq->k);
155  learner<LRQFAstate, example>& l = init_learner(lrq, as_singleline(setup_base(options, all)), predict_or_learn<true>,
156  predict_or_learn<false>, 1 + lrq->field_name.size() * lrq->k);
157 
158  return make_base(l);
159 }
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
std::string spoof_hex_encoded_namespaces(const std::string &arg)
Definition: parse_args.cc:568
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
virtual bool was_supplied(const std::string &key)=0
uint32_t wpp
Definition: global_data.h:432
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222

◆ predict_or_learn()

template<bool is_learn>
void predict_or_learn ( LRQFAstate & lrq,
single_learner & base,
example & ec 
)

Definition at line 28 of file lrqfa.cc.

References LRQFAstate::all, vw::audit, v_array< T >::begin(), cheesyrand(), v_array< T >::end(), example::example_counter, example_is_test(), example_predict::feature_space, LRQFAstate::field_id, LRQFAstate::field_name, vw::hash_inv, example_predict::indices, features::indicies, LRQFAstate::k, LEARNER::learner< T, E >::learn(), example::loss, parameters::mask(), LRQFAstate::orig_size, example::pred, LEARNER::learner< T, E >::predict(), v_array< T >::push_back(), features::push_back(), polyprediction::scalar, v_array< T >::size(), features::space_names, stride_shift(), parameters::stride_shift(), features::values, and vw::weights.

29 {
30  vw& all = *lrq.all;
31 
32  memset(lrq.orig_size, 0, sizeof(lrq.orig_size));
33  for (namespace_index i : ec.indices) lrq.orig_size[i] = ec.feature_space[i].size();
34 
35  size_t which = ec.example_counter;
36  float first_prediction = 0;
37  float first_loss = 0;
38  unsigned int maxiter = (is_learn && !example_is_test(ec)) ? 2 : 1;
39  unsigned int k = lrq.k;
40  float sqrtk = (float)std::sqrt(k);
41 
42  uint32_t stride_shift = lrq.all->weights.stride_shift();
43  uint64_t weight_mask = lrq.all->weights.mask();
44  for (unsigned int iter = 0; iter < maxiter; ++iter, ++which)
45  {
46  // Add left LRQ features, holding right LRQ features fixed
47  // and vice versa
48 
49  for (std::string::const_iterator i1 = lrq.field_name.begin(); i1 != lrq.field_name.end(); ++i1)
50  {
51  for (std::string::const_iterator i2 = i1 + 1; i2 != lrq.field_name.end(); ++i2)
52  {
53  unsigned char left = (which % 2) ? *i1 : *i2;
54  unsigned char right = ((which + 1) % 2) ? *i1 : *i2;
55  unsigned int lfd_id = lrq.field_id[left];
56  unsigned int rfd_id = lrq.field_id[right];
57  for (unsigned int lfn = 0; lfn < lrq.orig_size[left]; ++lfn)
58  {
59  features& fs = ec.feature_space[left];
60  float lfx = fs.values[lfn];
61  uint64_t lindex = fs.indicies[lfn];
62  for (unsigned int n = 1; n <= k; ++n)
63  {
64  uint64_t lwindex =
65  (lindex + ((uint64_t)(rfd_id * k + n) << stride_shift)); // a feature has k weights in each field
66  float* lw = &all.weights[lwindex & weight_mask];
67  // perturb away from saddle point at (0, 0)
68  if (is_learn && !example_is_test(ec) && *lw == 0)
69  *lw = cheesyrand(lwindex) * 0.5f / sqrtk;
70 
71  for (unsigned int rfn = 0; rfn < lrq.orig_size[right]; ++rfn)
72  {
73  features& rfs = ec.feature_space[right];
74  // feature* rf = ec.atomics[right].begin + rfn;
75  // NB: ec.ft_offset added by base learner
76  float rfx = rfs.values[rfn];
77  uint64_t rindex = rfs.indicies[rfn];
78  uint64_t rwindex = (rindex + ((uint64_t)(lfd_id * k + n) << stride_shift));
79 
80  rfs.push_back(*lw * lfx * rfx, rwindex);
81  if (all.audit || all.hash_inv)
82  {
83  std::stringstream new_feature_buffer;
84  new_feature_buffer << right << '^' << rfs.space_names[rfn].get()->second << '^' << n;
85 #ifdef _WIN32
86  char* new_space = _strdup("lrqfa");
87  char* new_feature = _strdup(new_feature_buffer.str().c_str());
88 #else
89  char* new_space = strdup("lrqfa");
90  char* new_feature = strdup(new_feature_buffer.str().c_str());
91 #endif
92  rfs.space_names.push_back(audit_strings_ptr(new audit_strings(new_space, new_feature)));
93  }
94  }
95  }
96  }
97  }
98  }
99 
100  if (is_learn)
101  base.learn(ec);
102  else
103  base.predict(ec);
104 
105  // Restore example
106  if (iter == 0)
107  {
108  first_prediction = ec.pred.scalar;
109  first_loss = ec.loss;
110  }
111  else
112  {
113  ec.pred.scalar = first_prediction;
114  ec.loss = first_loss;
115  }
116 
117  for (char i : lrq.field_name)
118  {
119  namespace_index right = i;
120  features& rfs = ec.feature_space[right];
121  rfs.values.end() = rfs.values.begin() + lrq.orig_size[right];
122 
123  if (all.audit || all.hash_inv)
124  {
125  for (size_t j = lrq.orig_size[right]; j < rfs.space_names.size(); ++j) rfs.space_names[j].~audit_strings_ptr();
126 
127  rfs.space_names.end() = rfs.space_names.begin() + lrq.orig_size[right];
128  }
129  }
130  }
131 }
v_array< namespace_index > indices
size_t example_counter
Definition: example.h:64
parameters weights
Definition: global_data.h:537
void predict(E &ec, size_t i=0)
Definition: learner.h:169
uint64_t stride_shift(const stagewise_poly &poly, uint64_t idx)
void push_back(feature_value v, feature_index i)
float scalar
Definition: example.h:45
std::shared_ptr< audit_strings > audit_strings_ptr
Definition: feature_group.h:23
bool hash_inv
Definition: global_data.h:541
v_array< feature_index > indicies
vw * all
Definition: lrqfa.cc:11
the core definition of a set of features.
constexpr bool example_is_test(example &ec)
Definition: lrqfa.cc:25
v_array< feature_value > values
T *& begin()
Definition: v_array.h:42
int field_id[256]
Definition: lrqfa.cc:14
size_t size() const
Definition: v_array.h:68
size_t orig_size[256]
Definition: lrqfa.cc:15
std::array< features, NUM_NAMESPACES > feature_space
void push_back(const T &new_ele)
Definition: v_array.h:107
unsigned char namespace_index
T *& end()
Definition: v_array.h:43
float loss
Definition: example.h:70
std::string field_name
Definition: lrqfa.cc:12
v_array< audit_strings_ptr > space_names
uint32_t stride_shift()
float cheesyrand(uint64_t x)
Definition: lrqfa.cc:18
bool audit
Definition: global_data.h:486
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
uint64_t mask()
int k
Definition: lrqfa.cc:13
std::pair< std::string, std::string > audit_strings
Definition: feature_group.h:22