1 #include <unordered_map> 10 std::unordered_map<uint32_t, float>
weights;
14 std::stringstream ss(source);
16 while (std::getline(ss, item,
','))
18 std::stringstream inner_ss(item);
21 std::getline(inner_ss, klass,
':');
22 std::getline(inner_ss, weight,
':');
24 if (!klass.size() || !weight.size())
26 THROW(
"error: while parsing --classweight " << item);
29 int klass_int = std::stoi(klass);
30 float weight_double = std::stof(weight);
32 weights[klass_int] = weight_double;
38 auto got = weights.find(klass);
39 if (got == weights.end())
46 template <
bool is_learn,
int pred_type>
73 std::vector<std::string> classweight_array;
74 auto cweights = scoped_calloc_or_throw<classweights>();
76 new_options.add(
make_option(
"classweight", classweight_array).help(
"importance weight multiplier for class"));
82 for (
auto& s : classweight_array) cweights->load_string(s);
85 all.
trace_message <<
"parsed " << cweights->weights.size() <<
" class weights" << std::endl;
91 ret = &LEARNER::init_learner<classweights>(cweights, base, predict_or_learn<true, prediction_type::scalar>,
92 predict_or_learn<false, prediction_type::scalar>);
94 ret = &LEARNER::init_learner<classweights>(cweights, base, predict_or_learn<true, prediction_type::multiclass>,
95 predict_or_learn<false, prediction_type::multiclass>);
97 THROW(
"--classweight not implemented for this type of prediction");
void predict(E &ec, size_t i=0)
LEARNER::base_learner * classweight_setup(options_i &options, vw &all)
base_learner * make_base(learner< T, E > &base)
void load_string(std::string const &source)
virtual void add_and_parse(const option_group_definition &group)=0
std::unordered_map< uint32_t, float > weights
single_learner * as_singleline(learner< T, E > *l)
MULTICLASS::label_t multi
float get_class_weight(uint32_t klass)
virtual bool was_supplied(const std::string &key)=0
typed_option< T > make_option(std::string name, T &location)
prediction_type::prediction_type_t pred_type
LEARNER::base_learner * setup_base(options_i &options, vw &all)
void learn(E &ec, size_t i=0)
static void predict_or_learn(classweights &cweights, LEARNER::single_learner &base, example &ec)