Vowpal Wabbit
Functions
bfgs.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner* bfgs_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ bfgs_setup()

LEARNER::base_learner* bfgs_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 1093 of file bfgs.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), vw::audit, vw::bfgs, LEARNER::end_pass(), VW::config::options_i::get_typed_option(), vw::hash_inv, vw::hessian_on, shared_data::holdout_best_loss, vw::holdout_set_off, init_driver(), LEARNER::init_learner(), LEARNER::make_base(), VW::config::make_option(), vw::numpasses, vw::quiet, save_load(), vw::sd, parameters::stride(), parameters::stride_shift(), THROW, vw::training, and vw::weights.

Referenced by parse_reductions().

1094 {
1095  auto b = scoped_calloc_or_throw<bfgs>();
1096  bool conjugate_gradient = false;
1097  bool bfgs_option = false;
1098  option_group_definition bfgs_outer_options("LBFGS and Conjugate Gradient options");
1099  bfgs_outer_options.add(
1100  make_option("conjugate_gradient", conjugate_gradient).keep().help("use conjugate gradient based optimization"));
1101 
1102  option_group_definition bfgs_inner_options("LBFGS and Conjugate Gradient options");
1103  bfgs_inner_options.add(make_option("bfgs", bfgs_option).keep().help("use conjugate gradient based optimization"));
1104  bfgs_inner_options.add(make_option("hessian_on", all.hessian_on).help("use second derivative in line search"));
1105  bfgs_inner_options.add(make_option("mem", b->m).default_value(15).help("memory in bfgs"));
1106  bfgs_inner_options.add(
1107  make_option("termination", b->rel_threshold).default_value(0.001f).help("Termination threshold"));
1108 
1109  options.add_and_parse(bfgs_outer_options);
1110  if (!conjugate_gradient)
1111  {
1112  options.add_and_parse(bfgs_inner_options);
1113  if (!bfgs_option)
1114  {
1115  return nullptr;
1116  }
1117  }
1118 
1119  b->all = &all;
1120  b->wolfe1_bound = 0.01;
1121  b->first_hessian_on = true;
1122  b->first_pass = true;
1123  b->gradient_pass = true;
1124  b->preconditioner_pass = true;
1125  b->backstep_on = false;
1126  b->final_pass = all.numpasses;
1127  b->no_win_counter = 0;
1128 
1129  if (!all.holdout_set_off)
1130  {
1131  all.sd->holdout_best_loss = FLT_MAX;
1132  b->early_stop_thres = options.get_typed_option<size_t>("early_terminate").value();
1133  }
1134 
1135  if (b->m == 0)
1136  all.hessian_on = true;
1137 
1138  if (!all.quiet)
1139  {
1140  if (b->m > 0)
1141  b->all->trace_message << "enabling BFGS based optimization ";
1142  else
1143  b->all->trace_message << "enabling conjugate gradient optimization via BFGS ";
1144  if (all.hessian_on)
1145  b->all->trace_message << "with curvature calculation" << std::endl;
1146  else
1147  b->all->trace_message << "**without** curvature calculation" << std::endl;
1148  }
1149 
1150  if (all.numpasses < 2 && all.training)
1151  THROW("you must make at least 2 passes to use BFGS");
1152 
1153  all.bfgs = true;
1154  all.weights.stride_shift(2);
1155 
1156  void (*learn_ptr)(bfgs&, base_learner&, example&) = nullptr;
1157  if (all.audit)
1158  learn_ptr = learn<true>;
1159  else
1160  learn_ptr = learn<false>;
1161 
1163  if (all.audit || all.hash_inv)
1164  l = &init_learner(b, learn_ptr, predict<true>, all.weights.stride());
1165  else
1166  l = &init_learner(b, learn_ptr, predict<false>, all.weights.stride());
1167 
1168  l->set_save_load(save_load);
1169  l->set_init_driver(init_driver);
1170  l->set_end_pass(end_pass);
1171 
1172  return make_base(*l);
1173 }
parameters weights
Definition: global_data.h:537
bool hash_inv
Definition: global_data.h:541
uint32_t stride()
double holdout_best_loss
Definition: global_data.h:161
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
bool quiet
Definition: global_data.h:487
Definition: bfgs.cc:62
virtual void add_and_parse(const option_group_definition &group)=0
bool holdout_set_off
Definition: global_data.h:499
bool training
Definition: global_data.h:488
bool hessian_on
Definition: global_data.h:413
void end_pass(bfgs &b)
Definition: bfgs.cc:897
void save_load(bfgs &b, io_buf &model_file, bool read, bool text)
Definition: bfgs.cc:1026
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
shared_data * sd
Definition: global_data.h:375
typed_option< T > & get_typed_option(const std::string &key)
Definition: options.h:120
bool bfgs
Definition: global_data.h:412
size_t numpasses
Definition: global_data.h:451
void init_driver(bfgs &b)
Definition: bfgs.cc:1091
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
uint32_t stride_shift()
bool audit
Definition: global_data.h:486
#define THROW(args)
Definition: vw_exception.h:181