Vowpal Wabbit
Functions
cs_active.h File Reference

Go to the source code of this file.

Functions

LEARNER::base_learner* cs_active_setup (VW::config::options_i &options, vw &all)
 

Function Documentation

◆ cs_active_setup()

LEARNER::base_learner* cs_active_setup ( VW::config::options_i & options,
vw & all 
)

Definition at line 310 of file cs_active.cc.

References VW::config::option_group_definition::add(), add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), vw::cost_sensitive, COST_SENSITIVE::cs_label, f, finish_example(), loss_function::getType(), LEARNER::init_learner(), vw::loss, parser::lp, LEARNER::make_base(), VW::config::make_option(), prediction_type::multilabels, vw::p, vw::sd, LEARNER::learner< T, E >::set_finish_example(), vw::set_minmax, setup_base(), THROW, vw::trace_message, and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

311 {
312  auto data = scoped_calloc_or_throw<cs_active>();
313 
314  bool simulation = false;
315  int domination;
316  option_group_definition new_options("Cost-sensitive Active Learning");
317  new_options
318  .add(make_option("cs_active", data->num_classes).keep().help("Cost-sensitive active learning with <k> costs"))
319  .add(make_option("simulation", simulation).help("cost-sensitive active learning simulation mode"))
320  .add(make_option("baseline", data->is_baseline).help("cost-sensitive active learning baseline"))
321  .add(make_option("domination", domination)
322  .default_value(1)
323  .help("cost-sensitive active learning use domination. Default 1"))
324  .add(make_option("mellowness", data->c0).default_value(0.1f).help("mellowness parameter c_0. Default 0.1."))
325  .add(make_option("range_c", data->c1)
326  .default_value(0.5f)
327  .help("parameter controlling the threshold for per-label cost uncertainty. Default 0.5."))
328  .add(make_option("max_labels", data->max_labels).default_value(-1).help("maximum number of label queries."))
329  .add(make_option("min_labels", data->min_labels).default_value(-1).help("minimum number of label queries."))
330  .add(make_option("cost_max", data->cost_max).default_value(1.f).help("cost upper bound. Default 1."))
331  .add(make_option("cost_min", data->cost_min).default_value(0.f).help("cost lower bound. Default 0."))
332  // TODO replace with trace and quiet
333  .add(make_option("csa_debug", data->print_debug_stuff).help("print debug stuff for cs_active"));
334  options.add_and_parse(new_options);
335 
336  data->use_domination = true;
337  if (options.was_supplied("domination") && !domination)
338  data->use_domination = false;
339 
340  if (!options.was_supplied("cs_active"))
341  return nullptr;
342 
343  data->all = &all;
344  data->t = 1;
345 
346  auto loss_function_type = all.loss->getType();
347  if (loss_function_type != "squared")
348  THROW("error: you can't use non-squared loss with cs_active");
349 
350  if (options.was_supplied("lda"))
351  THROW("error: you can't combine lda and active learning");
352 
353  if (options.was_supplied("active"))
354  THROW("error: you can't use --cs_active and --active at the same time");
355 
356  if (options.was_supplied("active_cover"))
357  THROW("error: you can't use --cs_active and --active_cover at the same time");
358 
359  if (options.was_supplied("csoaa"))
360  THROW("error: you can't use --cs_active and --csoaa at the same time");
361 
362  if (!options.was_supplied("adax"))
363  all.trace_message << "WARNING: --cs_active should be used with --adax" << endl;
364 
365  all.p->lp = cs_label; // assigning the label parser
366  all.set_minmax(all.sd, data->cost_max);
367  all.set_minmax(all.sd, data->cost_min);
368  for (uint32_t i = 0; i < data->num_classes + 1; i++) data->examples_by_queries.push_back(0);
369 
370  learner<cs_active, example>& l = simulation
371  ? init_learner(data, as_singleline(setup_base(options, all)), predict_or_learn<true, true>,
372  predict_or_learn<false, true>, data->num_classes, prediction_type::multilabels)
373  : init_learner(data, as_singleline(setup_base(options, all)), predict_or_learn<true, false>,
374  predict_or_learn<false, false>, data->num_classes, prediction_type::multilabels);
375 
376  l.set_finish_example(finish_example);
377  base_learner* b = make_base(l);
378  all.cost_sensitive = b;
379  return b;
380 }
loss_function * loss
Definition: global_data.h:523
LEARNER::base_learner * cost_sensitive
Definition: global_data.h:385
label_parser cs_label
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
parser * p
Definition: global_data.h:377
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void(* set_minmax)(shared_data *sd, float label)
Definition: global_data.h:394
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
shared_data * sd
Definition: global_data.h:375
vw_ostream trace_message
Definition: global_data.h:424
void finish_example(vw &all, cs_active &cs_a, example &ec)
Definition: cs_active.cc:308
virtual bool was_supplied(const std::string &key)=0
virtual std::string getType()=0
int add(svm_params &params, svm_example *fec)
Definition: kernel_svm.cc:546
void predict_or_learn(cs_active &cs_a, single_learner &base, example &ec)
Definition: cs_active.cc:176
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define THROW(args)
Definition: vw_exception.h:181
float f
Definition: cache.cc:40
label_parser lp
Definition: parser.h:102