Vowpal Wabbit
Functions
kernel_svm.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner * kernel_svm_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ kernel_svm_setup()

LEARNER::base_learner* kernel_svm_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 867 of file kernel_svm.cc. (Note: the excerpt below omits source lines 948–950, where the learner `l` returned by `make_base(l)` is created via LEARNER::init_learner.)

References vw::active, VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), vw::all_reduce, f, vw::get_random_state(), getLossFunction(), LEARNER::init_learner(), vw::l2_lambda, learn(), vw::loss, LEARNER::make_base(), VW::config::make_option(), predict(), save_load(), LEARNER::learner< T, E >::set_save_load(), SVM_KER_LIN, SVM_KER_POLY, SVM_KER_RBF, AllReduce::total, and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

868 {
869  auto params = scoped_calloc_or_throw<svm_params>();
870  std::string kernel_type;
871  float bandwidth = 1.f;
872  int degree = 2;
873 
874  bool ksvm = false;
875 
876  option_group_definition new_options("Kernel SVM");
877  new_options.add(make_option("ksvm", ksvm).keep().help("kernel svm"))
878  .add(make_option("reprocess", params->reprocess).default_value(1).help("number of reprocess steps for LASVM"))
879  .add(make_option("pool_greedy", params->active_pool_greedy).help("use greedy selection on mini pools"))
880  .add(make_option("para_active", params->para_active).help("do parallel active learning"))
881  .add(make_option("pool_size", params->pool_size).default_value(1).help("size of pools for active learning"))
882  .add(make_option("subsample", params->subsample)
883  .default_value(1)
884  .help("number of items to subsample from the pool"))
885  .add(make_option("kernel", kernel_type)
886  .keep()
887  .default_value("linear")
888  .help("type of kernel (rbf or linear (default))"))
889  .add(make_option("bandwidth", bandwidth).keep().default_value(1.f).help("bandwidth of rbf kernel"))
890  .add(make_option("degree", degree).keep().default_value(2).help("degree of poly kernel"));
891  options.add_and_parse(new_options);
892 
893  if (!ksvm)
894  {
895  return nullptr;
896  }
897 
898  std::string loss_function = "hinge";
899  float loss_parameter = 0.0;
900  delete all.loss;
901  all.loss = getLossFunction(all, loss_function, (float)loss_parameter);
902 
903  params->model = &calloc_or_throw<svm_model>();
904  params->model->num_support = 0;
905  params->maxcache = 1024 * 1024 * 1024;
906  params->loss_sum = 0.;
907  params->all = &all;
908  params->_random_state = all.get_random_state();
909 
910  // This param comes from the active reduction.
911  // During options refactor: this changes the semantics a bit - now this will only be true if --active was supplied and
912  // NOT --simulation
913  if (all.active)
914  params->active = true;
915  if (params->active)
916  params->active_c = 1.;
917 
918  params->pool = calloc_or_throw<svm_example*>(params->pool_size);
919  params->pool_pos = 0;
920 
921  if (!options.was_supplied("subsample") && params->para_active)
922  params->subsample = (size_t)ceil(params->pool_size / all.all_reduce->total);
923 
924  params->lambda = all.l2_lambda;
925  if (params->lambda == 0.)
926  params->lambda = 1.;
927  params->all->trace_message << "Lambda = " << params->lambda << endl;
928  params->all->trace_message << "Kernel = " << kernel_type << endl;
929 
930  if (kernel_type.compare("rbf") == 0)
931  {
932  params->kernel_type = SVM_KER_RBF;
933  params->all->trace_message << "bandwidth = " << bandwidth << endl;
934  params->kernel_params = &calloc_or_throw<double>();
935  *((float*)params->kernel_params) = bandwidth;
936  }
937  else if (kernel_type.compare("poly") == 0)
938  {
939  params->kernel_type = SVM_KER_POLY;
940  params->all->trace_message << "degree = " << degree << endl;
941  params->kernel_params = &calloc_or_throw<int>();
942  *((int*)params->kernel_params) = degree;
943  }
944  else
945  params->kernel_type = SVM_KER_LIN;
946 
947  params->all->weights.stride_shift(0);
948 
951  return make_base(l);
952 }
loss_function * loss
Definition: global_data.h:523
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
Definition: learner.h:257
void learn(svm_params &params, single_learner &, example &ec)
Definition: kernel_svm.cc:833
#define SVM_KER_LIN
Definition: kernel_svm.cc:33
const size_t total
Definition: allreduce.h:80
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
AllReduce * all_reduce
Definition: global_data.h:381
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
float l2_lambda
Definition: global_data.h:445
bool active
Definition: global_data.h:489
virtual bool was_supplied(const std::string &key)=0
void predict(svm_params &params, svm_example **ec_arr, float *scores, size_t n)
Definition: kernel_svm.cc:457
#define SVM_KER_POLY
Definition: kernel_svm.cc:35
#define SVM_KER_RBF
Definition: kernel_svm.cc:34
void save_load(svm_params &params, io_buf &model_file, bool read, bool text)
Definition: kernel_svm.cc:376
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
loss_function * getLossFunction(vw &all, std::string funcName, float function_parameter)
float f
Definition: cache.cc:40