Vowpal Wabbit
Functions
nn.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner* nn_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ nn_setup()

LEARNER::base_learner* nn_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 417 of file nn.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), LEARNER::end_pass(), finish_example(), vw::get_random_state(), getLossFunction(), LEARNER::init_learner(), LEARNER::make_base(), VW::config::make_option(), multipredict(), nn::multitask, vw::quiet, vw::random_seed, LEARNER::learner< T, E >::set_end_pass(), LEARNER::learner< T, E >::set_finish_example(), LEARNER::learner< T, E >::set_multipredict(), setup_base(), vw::training, and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

418 {
419  auto n = scoped_calloc_or_throw<nn>();
420  bool meanfield = false;
421  option_group_definition new_options("Neural Network");
422  new_options.add(make_option("nn", n->k).keep().help("Sigmoidal feedforward network with <k> hidden units"))
423  .add(make_option("inpass", n->inpass)
424  .keep()
425  .help("Train or test sigmoidal feedforward network with input passthrough."))
426  .add(make_option("multitask", n->multitask).keep().help("Share hidden layer across all reduced tasks."))
427  .add(make_option("dropout", n->dropout).keep().help("Train or test sigmoidal feedforward network using dropout."))
428  .add(make_option("meanfield", meanfield).help("Train or test sigmoidal feedforward network using mean field."));
429  options.add_and_parse(new_options);
430 
431  if (!options.was_supplied("nn"))
432  return nullptr;
433 
434  n->all = &all;
435  n->_random_state = all.get_random_state();
436 
437  if (n->multitask && !all.quiet)
438  std::cerr << "using multitask sharing for neural network " << (all.training ? "training" : "testing") << std::endl;
439 
440  if (options.was_supplied("meanfield"))
441  {
442  n->dropout = false;
443  if (!all.quiet)
444  std::cerr << "using mean field for neural network " << (all.training ? "training" : "testing") << std::endl;
445  }
446 
447  if (n->dropout && !all.quiet)
448  std::cerr << "using dropout for neural network " << (all.training ? "training" : "testing") << std::endl;
449 
450  if (n->inpass && !all.quiet)
451  std::cerr << "using input passthrough for neural network " << (all.training ? "training" : "testing") << std::endl;
452 
453  n->finished_setup = false;
454  n->squared_loss = getLossFunction(all, "squared", 0);
455 
456  n->xsubi = all.random_seed;
457 
458  n->save_xsubi = n->xsubi;
459 
460  n->hidden_units = calloc_or_throw<float>(n->k);
461  n->dropped_out = calloc_or_throw<bool>(n->k);
462  n->hidden_units_pred = calloc_or_throw<polyprediction>(n->k);
463  n->hiddenbias_pred = calloc_or_throw<polyprediction>(n->k);
464 
465  auto base = as_singleline(setup_base(options, all));
466  n->increment = base->increment; // Indexing of output layer is odd.
467  nn& nv = *n.get();
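468  learner<nn, example>& l =
469  init_learner(n, base, predict_or_learn_multi<true, true>, predict_or_learn_multi<false, true>, n->k + 1);
470  if (nv.multitask)
471  l.set_multipredict(multipredict);
472  l.set_finish_example(finish_example);
473  l.set_end_pass(end_pass);
474 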
475  return make_base(l);
476 }
void end_pass(nn &n)
Definition: nn.cc:146
void set_multipredict(void(*u)(T &, L &, E &, size_t, size_t, polyprediction *, bool))
Definition: learner.h:217
uint64_t random_seed
Definition: global_data.h:491
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
bool quiet
Definition: global_data.h:487
virtual void add_and_parse(const option_group_definition &group)=0
bool training
Definition: global_data.h:488
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
bool multitask
Definition: nn.cc:38
void multipredict(nn &n, single_learner &base, example &ec, size_t count, size_t step, polyprediction *pred, bool finalize_predictions)
Definition: nn.cc:391
virtual bool was_supplied(const std::string &key)=0
Definition: nn.cc:24
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void set_end_pass(void(*f)(T &))
Definition: learner.h:286
void finish_example(vw &all, nn &, example &ec)
Definition: nn.cc:409
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
loss_function * getLossFunction(vw &all, std::string funcName, float function_parameter)