Vowpal Wabbit
Classes | Functions
lrq.cc File Reference
#include <cfloat>
#include <climits>
#include <cstring>
#include "reductions.h"
#include "rand48.h"
#include "vw_exception.h"
#include "parse_args.h"

Go to the source code of this file.

Classes

struct  LRQstate
 

Functions

bool valid_int (const char *s)
 
bool cheesyrbit (uint64_t &seed)
 
float cheesyrand (uint64_t x)
 
constexpr bool example_is_test (example &ec)
 
void reset_seed (LRQstate &lrq)
 
template<bool is_learn>
void predict_or_learn (LRQstate &lrq, single_learner &base, example &ec)
 
base_learner * lrq_setup (options_i &options, vw &all)
 

Function Documentation

◆ cheesyrand()

float cheesyrand ( uint64_t  x)
inline

Definition at line 34 of file lrq.cc.

References merand48().

Referenced by predict_or_learn().

35 {
36  uint64_t seed = x;
37 
38  return merand48(seed);
39 }
float merand48(uint64_t &initial)
Definition: rand48.cc:16

◆ cheesyrbit()

bool cheesyrbit ( uint64_t &  seed)
inline

Definition at line 32 of file lrq.cc.

References merand48().

Referenced by predict_or_learn().

32 { return merand48(seed) > 0.5; }
float merand48(uint64_t &initial)
Definition: rand48.cc:16

◆ example_is_test()

constexpr bool example_is_test ( example &  ec)
inline

Definition at line 41 of file lrq.cc.

References example::l, label_data::label, and polylabel::simple.

Referenced by predict_or_learn().

41 { return ec.l.simple.label == FLT_MAX; }
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
polylabel l
Definition: example.h:57

◆ lrq_setup()

base_learner* lrq_setup ( options_i &  options,
vw &  all 
)

Definition at line 159 of file lrq.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), LEARNER::init_learner(), LEARNER::make_base(), VW::config::make_option(), vw::quiet, vw::random_seed, reset_seed(), LEARNER::learner< T, E >::set_end_pass(), setup_base(), spoof_hex_encoded_namespaces(), THROW, vw::trace_message, valid_int(), VW::config::options_i::was_supplied(), and vw::wpp.

Referenced by parse_reductions().

160 {
161  auto lrq = scoped_calloc_or_throw<LRQstate>();
162  std::vector<std::string> lrq_names;
163  option_group_definition new_options("Low Rank Quadratics");
164  new_options.add(make_option("lrq", lrq_names).keep().help("use low rank quadratic features"))
165  .add(make_option("lrqdropout", lrq->dropout).keep().help("use dropout training for low rank quadratic features"));
166  options.add_and_parse(new_options);
167 
168  if (!options.was_supplied("lrq"))
169  return nullptr;
170 
171  uint32_t maxk = 0;
172  lrq->all = &all;
173 
174  for (auto& lrq_name : lrq_names) lrq_name = spoof_hex_encoded_namespaces(lrq_name);
175 
176  new (&lrq->lrpairs) std::set<std::string>(lrq_names.begin(), lrq_names.end());
177 
178  lrq->initial_seed = lrq->seed = all.random_seed | 8675309;
179 
180  if (!all.quiet)
181  {
182  all.trace_message << "creating low rank quadratic features for pairs: ";
183  if (lrq->dropout)
184  all.trace_message << "(using dropout) ";
185  }
186 
187  for (std::string const& i : lrq->lrpairs)
188  {
189  if (!all.quiet)
190  {
191  if ((i.length() < 3) || !valid_int(i.c_str() + 2))
192  THROW("error, low-rank quadratic features must involve two sets and a rank.");
193 
194  all.trace_message << i << " ";
195  }
196  // TODO: colon-syntax
197 
198  unsigned int k = atoi(i.c_str() + 2);
199 
200  lrq->lrindices[(int)i[0]] = true;
201  lrq->lrindices[(int)i[1]] = true;
202 
203  maxk = std::max(k, k);
204  }
205 
206  if (!all.quiet)
207  all.trace_message << std::endl;
208 
209  all.wpp = all.wpp * (uint64_t)(1 + maxk);
211  lrq, as_singleline(setup_base(options, all)), predict_or_learn<true>, predict_or_learn<false>, 1 + maxk);
213 
214  // TODO: leaks memory ?
215  return make_base(l);
216 }
uint64_t random_seed
Definition: global_data.h:491
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
bool quiet
Definition: global_data.h:487
virtual void add_and_parse(const option_group_definition &group)=0
std::string spoof_hex_encoded_namespaces(const std::string &arg)
Definition: parse_args.cc:568
bool valid_int(const char *s)
Definition: lrq.cc:22
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
vw_ostream trace_message
Definition: global_data.h:424
virtual bool was_supplied(const std::string &key)=0
void reset_seed(LRQstate &lrq)
Definition: lrq.cc:43
uint32_t wpp
Definition: global_data.h:432
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void set_end_pass(void(*f)(T &))
Definition: learner.h:286
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define THROW(args)
Definition: vw_exception.h:181

◆ predict_or_learn()

template<bool is_learn>
void predict_or_learn ( LRQstate &  lrq,
single_learner &  base,
example &  ec 
)

Definition at line 50 of file lrq.cc.

References LRQstate::all, vw::audit, cheesyrand(), cheesyrbit(), example::confidence, LRQstate::dropout, example::example_counter, example_is_test(), f, example_predict::feature_space, example_predict::ft_offset, vw::hash_inv, example_predict::indices, features::indicies, LEARNER::learner< T, E >::learn(), example::loss, LRQstate::lrindices, LRQstate::lrpairs, LRQstate::orig_size, example::pred, LEARNER::learner< T, E >::predict(), v_array< T >::push_back(), features::push_back(), polyprediction::scalar, LRQstate::seed, features::space_names, stride_shift(), parameters::stride_shift(), features::values, and vw::weights.

51 {
52  vw& all = *lrq.all;
53 
54  // Remember original features
55 
56  memset(lrq.orig_size, 0, sizeof(lrq.orig_size));
57  for (namespace_index i : ec.indices)
58  {
59  if (lrq.lrindices[i])
60  lrq.orig_size[i] = ec.feature_space[i].size();
61  }
62 
63  size_t which = ec.example_counter;
64  float first_prediction = 0;
65  float first_loss = 0;
66  float first_uncertainty = 0;
67  unsigned int maxiter = (is_learn && !example_is_test(ec)) ? 2 : 1;
68 
69  bool do_dropout = lrq.dropout && is_learn && !example_is_test(ec);
70  float scale = (!lrq.dropout || do_dropout) ? 1.f : 0.5f;
71 
72  uint32_t stride_shift = lrq.all->weights.stride_shift();
73  for (unsigned int iter = 0; iter < maxiter; ++iter, ++which)
74  {
75  // Add left LRQ features, holding right LRQ features fixed
76  // and vice versa
77  // TODO: what happens with --lrq ab2 --lrq ac2
78  // i.e. namespace occurs multiple times (?)
79 
80  for (std::string const& i : lrq.lrpairs)
81  {
82  unsigned char left = i[which % 2];
83  unsigned char right = i[(which + 1) % 2];
84  unsigned int k = atoi(i.c_str() + 2);
85 
86  features& left_fs = ec.feature_space[left];
87  for (unsigned int lfn = 0; lfn < lrq.orig_size[left]; ++lfn)
88  {
89  float lfx = left_fs.values[lfn];
90  uint64_t lindex = left_fs.indicies[lfn] + ec.ft_offset;
91  for (unsigned int n = 1; n <= k; ++n)
92  {
93  if (!do_dropout || cheesyrbit(lrq.seed))
94  {
95  uint64_t lwindex = (lindex + ((uint64_t)n << stride_shift));
96  weight* lw = &lrq.all->weights[lwindex];
97 
98  // perturb away from saddle point at (0, 0)
99  if (is_learn && !example_is_test(ec) && *lw == 0)
100  *lw = cheesyrand(lwindex); // not sure if lw needs a weight mask?
101 
102  features& right_fs = ec.feature_space[right];
103  for (unsigned int rfn = 0; rfn < lrq.orig_size[right]; ++rfn)
104  {
105  // NB: ec.ft_offset added by base learner
106  float rfx = right_fs.values[rfn];
107  uint64_t rindex = right_fs.indicies[rfn];
108  uint64_t rwindex = (rindex + ((uint64_t)n << stride_shift));
109 
110  right_fs.push_back(scale * *lw * lfx * rfx, rwindex);
111 
112  if (all.audit || all.hash_inv)
113  {
114  std::stringstream new_feature_buffer;
115  new_feature_buffer << right << '^' << right_fs.space_names[rfn].get()->second << '^' << n;
116 
117 #ifdef _WIN32
118  char* new_space = _strdup("lrq");
119  char* new_feature = _strdup(new_feature_buffer.str().c_str());
120 #else
121  char* new_space = strdup("lrq");
122  char* new_feature = strdup(new_feature_buffer.str().c_str());
123 #endif
124  right_fs.space_names.push_back(audit_strings_ptr(new audit_strings(new_space, new_feature)));
125  }
126  }
127  }
128  }
129  }
130  }
131 
132  if (is_learn)
133  base.learn(ec);
134  else
135  base.predict(ec);
136 
137  // Restore example
138  if (iter == 0)
139  {
140  first_prediction = ec.pred.scalar;
141  first_loss = ec.loss;
142  first_uncertainty = ec.confidence;
143  }
144  else
145  {
146  ec.pred.scalar = first_prediction;
147  ec.loss = first_loss;
148  ec.confidence = first_uncertainty;
149  }
150 
151  for (std::string const& i : lrq.lrpairs)
152  {
153  unsigned char right = i[(which + 1) % 2];
154  ec.feature_space[right].truncate_to(lrq.orig_size[right]);
155  }
156  }
157 }
v_array< namespace_index > indices
size_t example_counter
Definition: example.h:64
parameters weights
Definition: global_data.h:537
void predict(E &ec, size_t i=0)
Definition: learner.h:169
uint64_t stride_shift(const stagewise_poly &poly, uint64_t idx)
void push_back(feature_value v, feature_index i)
float scalar
Definition: example.h:45
std::shared_ptr< audit_strings > audit_strings_ptr
Definition: feature_group.h:23
bool hash_inv
Definition: global_data.h:541
v_array< feature_index > indicies
the core definition of a set of features.
float confidence
Definition: example.h:72
v_array< feature_value > values
size_t orig_size[256]
Definition: lrq.cc:15
constexpr bool example_is_test(example &ec)
Definition: lrq.cc:41
std::set< std::string > lrpairs
Definition: lrq.cc:16
float cheesyrand(uint64_t x)
Definition: lrq.cc:34
std::array< features, NUM_NAMESPACES > feature_space
bool dropout
Definition: lrq.cc:17
void push_back(const T &new_ele)
Definition: v_array.h:107
uint64_t seed
Definition: lrq.cc:18
unsigned char namespace_index
bool cheesyrbit(uint64_t &seed)
Definition: lrq.cc:32
float loss
Definition: example.h:70
float weight
v_array< audit_strings_ptr > space_names
bool lrindices[256]
Definition: lrq.cc:14
uint32_t stride_shift()
bool audit
Definition: global_data.h:486
polyprediction pred
Definition: example.h:60
vw * all
Definition: lrq.cc:13
void learn(E &ec, size_t i=0)
Definition: learner.h:160
float f
Definition: cache.cc:40
std::pair< std::string, std::string > audit_strings
Definition: feature_group.h:22

◆ reset_seed()

void reset_seed ( LRQstate &  lrq)

Definition at line 43 of file lrq.cc.

References LRQstate::all, vw::bfgs, LRQstate::initial_seed, and LRQstate::seed.

Referenced by lrq_setup().

44 {
45  if (lrq.all->bfgs)
46  lrq.seed = lrq.initial_seed;
47 }
uint64_t seed
Definition: lrq.cc:18
bool bfgs
Definition: global_data.h:412
uint64_t initial_seed
Definition: lrq.cc:19
vw * all
Definition: lrq.cc:13

◆ valid_int()

bool valid_int ( const char *  s)

Definition at line 22 of file lrq.cc.

Referenced by lrq_setup().

/**
 * Checks that a string is a plain non-negative decimal integer.
 *
 * The whole string must consist of one or more characters '0'..'9' and its
 * value must not exceed INT_MAX. Signs, whitespace and hex prefixes are
 * rejected so that a string accepted here is read as the same number by a
 * base-10 parser such as atoi(), which the callers use afterwards.
 *
 * @param s  NUL-terminated string to check; a null pointer is rejected.
 * @return   true when @p s is a non-empty run of decimal digits whose value
 *           fits in an int, false otherwise.
 */
bool valid_int(const char* s)
{
  if (s == nullptr || *s == '\0')
    return false;

  uint64_t value = 0;
  for (const char* p = s; *p != '\0'; ++p)
  {
    if (*p < '0' || *p > '9')
      return false;

    // value never exceeds INT_MAX before this step, so the 64-bit
    // accumulator cannot wrap around.
    value = value * 10 + static_cast<uint64_t>(*p - '0');
    if (value > static_cast<uint64_t>(INT_MAX))
      return false;
  }

  return true;
}