Vowpal Wabbit
Functions
lrq.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner* lrq_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ lrq_setup()

LEARNER::base_learner* lrq_setup(VW::config::options_i& options, vw& all)

Definition at line 159 of file lrq.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), LEARNER::init_learner(), LEARNER::make_base(), VW::config::make_option(), vw::quiet, vw::random_seed, reset_seed(), LEARNER::learner< T, E >::set_end_pass(), setup_base(), spoof_hex_encoded_namespaces(), THROW, vw::trace_message, valid_int(), VW::config::options_i::was_supplied(), and vw::wpp.

Referenced by parse_reductions().

160 {
161  auto lrq = scoped_calloc_or_throw<LRQstate>();
162  std::vector<std::string> lrq_names;
163  option_group_definition new_options("Low Rank Quadratics");
164  new_options.add(make_option("lrq", lrq_names).keep().help("use low rank quadratic features"))
165  .add(make_option("lrqdropout", lrq->dropout).keep().help("use dropout training for low rank quadratic features"));
166  options.add_and_parse(new_options);
167 
168  if (!options.was_supplied("lrq"))
169  return nullptr;
170 
171  uint32_t maxk = 0;
172  lrq->all = &all;
173 
174  for (auto& lrq_name : lrq_names) lrq_name = spoof_hex_encoded_namespaces(lrq_name);
175 
176  new (&lrq->lrpairs) std::set<std::string>(lrq_names.begin(), lrq_names.end());
177 
178  lrq->initial_seed = lrq->seed = all.random_seed | 8675309;
179 
180  if (!all.quiet)
181  {
182  all.trace_message << "creating low rank quadratic features for pairs: ";
183  if (lrq->dropout)
184  all.trace_message << "(using dropout) ";
185  }
186 
187  for (std::string const& i : lrq->lrpairs)
188  {
189  if (!all.quiet)
190  {
191  if ((i.length() < 3) || !valid_int(i.c_str() + 2))
192  THROW("error, low-rank quadratic features must involve two sets and a rank.");
193 
194  all.trace_message << i << " ";
195  }
196  // TODO: colon-syntax
197 
198  unsigned int k = atoi(i.c_str() + 2);
199 
200  lrq->lrindices[(int)i[0]] = true;
201  lrq->lrindices[(int)i[1]] = true;
202 
203  maxk = std::max(k, k);
204  }
205 
206  if (!all.quiet)
207  all.trace_message << std::endl;
208 
209  all.wpp = all.wpp * (uint64_t)(1 + maxk);
211  lrq, as_singleline(setup_base(options, all)), predict_or_learn<true>, predict_or_learn<false>, 1 + maxk);
213 
214  // TODO: leaks memory ?
215  return make_base(l);
216 }
uint64_t random_seed
Definition: global_data.h:491
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
bool quiet
Definition: global_data.h:487
virtual void add_and_parse(const option_group_definition &group)=0
std::string spoof_hex_encoded_namespaces(const std::string &arg)
Definition: parse_args.cc:568
bool valid_int(const char *s)
Definition: lrq.cc:22
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
vw_ostream trace_message
Definition: global_data.h:424
virtual bool was_supplied(const std::string &key)=0
void reset_seed(LRQstate &lrq)
Definition: lrq.cc:43
uint32_t wpp
Definition: global_data.h:432
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void set_end_pass(void(*f)(T &))
Definition: learner.h:286
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define THROW(args)
Definition: vw_exception.h:181