Vowpal Wabbit
Functions
lrqfa.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner* lrqfa_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ lrqfa_setup()

LEARNER::base_learner* lrqfa_setup ( VW::config::options_i & options,
vw & all
)

Definition at line 133 of file lrqfa.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), LEARNER::init_learner(), LEARNER::make_base(), VW::config::make_option(), setup_base(), spoof_hex_encoded_namespaces(), VW::config::options_i::was_supplied(), and vw::wpp.

Referenced by parse_reductions().

134 {
135  std::string lrqfa;
136  option_group_definition new_options("Low Rank Quadratics FA");
137  new_options.add(make_option("lrqfa", lrqfa).keep().help("use low rank quadratic features with field aware weights"));
138  options.add_and_parse(new_options);
139 
140  if (!options.was_supplied("lrqfa"))
141  return nullptr;
142 
143  auto lrq = scoped_calloc_or_throw<LRQFAstate>();
144  lrq->all = &all;
145 
146  std::string lrqopt = spoof_hex_encoded_namespaces(lrqfa);
147  size_t last_index = lrqopt.find_last_not_of("0123456789");
148  new (&lrq->field_name) std::string(lrqopt.substr(0, last_index + 1)); // make sure there is no duplicates
149  lrq->k = atoi(lrqopt.substr(last_index + 1).c_str());
150 
151  int fd_id = 0;
152  for (char i : lrq->field_name) lrq->field_id[(int)i] = fd_id++;
153 
154  all.wpp = all.wpp * (uint64_t)(1 + lrq->k);
155  learner<LRQFAstate, example>& l = init_learner(lrq, as_singleline(setup_base(options, all)), predict_or_learn<true>,
156  predict_or_learn<false>, 1 + lrq->field_name.size() * lrq->k);
157 
158  return make_base(l);
159 }
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
std::string spoof_hex_encoded_namespaces(const std::string &arg)
Definition: parse_args.cc:568
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
virtual bool was_supplied(const std::string &key)=0
uint32_t wpp
Definition: global_data.h:432
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222