Vowpal Wabbit
Functions
classweight.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner* classweight_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ classweight_setup()

LEARNER::base_learner* classweight_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 71 of file classweight.cc.

References VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), LEARNER::make_base(), VW::config::make_option(), prediction_type::multiclass, LEARNER::learner< T, E >::pred_type, vw::quiet, prediction_type::scalar, setup_base(), THROW, vw::trace_message, and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

72 {
73  std::vector<std::string> classweight_array;
74  auto cweights = scoped_calloc_or_throw<classweights>();
75  option_group_definition new_options("importance weight classes");
76  new_options.add(make_option("classweight", classweight_array).help("importance weight multiplier for class"));
77  options.add_and_parse(new_options);
78 
79  if (!options.was_supplied("classweight"))
80  return nullptr;
81 
82  for (auto& s : classweight_array) cweights->load_string(s);
83 
84  if (!all.quiet)
85  all.trace_message << "parsed " << cweights->weights.size() << " class weights" << std::endl;
86 
87  LEARNER::single_learner* base = as_singleline(setup_base(options, all));
88 
89  LEARNER::learner<classweights>* ret;
90  if (base->pred_type == prediction_type::scalar)
91  ret = &LEARNER::init_learner<classweights>(cweights, base, predict_or_learn<true, prediction_type::scalar>,
92  predict_or_learn<false, prediction_type::scalar>);
93  else if (base->pred_type == prediction_type::multiclass)
94  ret = &LEARNER::init_learner<classweights>(cweights, base, predict_or_learn<true, prediction_type::multiclass>,
95  predict_or_learn<false, prediction_type::multiclass>);
96  else
97  THROW("--classweight not implemented for this type of prediction");
98  return make_base(*ret);
99 }
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
bool quiet
Definition: global_data.h:487
virtual void add_and_parse(const option_group_definition &group)=0
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
vw_ostream trace_message
Definition: global_data.h:424
virtual bool was_supplied(const std::string &key)=0
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
prediction_type::prediction_type_t pred_type
Definition: learner.h:149
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define THROW(args)
Definition: vw_exception.h:181