Vowpal Wabbit
Classes | Namespaces | Typedefs | Functions
marginal.cc File Reference
#include <unordered_map>
#include "reductions.h"
#include "correctedMath.h"

Go to the source code of this file.

Classes

struct  MARGINAL::expert
 
struct  MARGINAL::data
 

Namespaces

 MARGINAL
 

Typedefs

typedef std::pair< double, double > MARGINAL::marginal
 
typedef std::pair< expert, expert > MARGINAL::expert_pair
 

Functions

float MARGINAL::get_adanormalhedge_weights (float R, float C)
 
template<bool is_learn>
void MARGINAL::make_marginal (data &sm, example &ec)
 
void MARGINAL::undo_marginal (data &sm, example &ec)
 
template<bool is_learn>
void MARGINAL::compute_expert_loss (data &sm, example &ec)
 
void MARGINAL::update_marginal (data &sm, example &ec)
 
template<bool is_learn>
void MARGINAL::predict_or_learn (data &sm, LEARNER::single_learner &base, example &ec)
 
void MARGINAL::save_load (data &sm, io_buf &io, bool read, bool text)
 
LEARNER::base_learner * marginal_setup (options_i &options, vw &all)
 

Function Documentation

◆ marginal_setup()

/// @brief Set up the `--marginal` reduction.
///
/// Registers the marginal options in the "VW options" group and parses them. If
/// `--marginal` was not supplied, nothing is enabled and nullptr is returned. Otherwise
/// `data::id_features[c]` is set for every character `c` of the `--marginal` argument
/// (presumably namespace characters — confirm against MARGINAL::make_marginal). A
/// single-line learner using predict_or_learn<true/false> and save_load is then stacked
/// on the base learner returned by setup_base().
///
/// @param options  option parser; the marginal option group is added and parsed here.
/// @param all      global vw state; its address is stored in data::all.
/// @return the new base learner, or nullptr when `--marginal` was not supplied.
///
/// Definition at line 351 of file marginal.cc.
/// References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(),
/// LEARNER::as_singleline(), LEARNER::init_learner(), LEARNER::make_base(), VW::config::make_option(),
/// MARGINAL::save_load(), setup_base(), and VW::config::options_i::was_supplied().
/// Referenced by parse_reductions().
LEARNER::base_learner* marginal_setup(options_i& options, vw& all)
{
  free_ptr<MARGINAL::data> d = scoped_calloc_or_throw<MARGINAL::data>();
  std::string marginal;

  option_group_definition marginal_options("VW options");
  marginal_options.add(make_option("marginal", marginal).keep().help("substitute marginal label estimates for ids"));
  marginal_options.add(
      make_option("initial_denominator", d->initial_denominator).default_value(1.f).help("initial denominator"));
  marginal_options.add(
      make_option("initial_numerator", d->initial_numerator).default_value(0.5f).help("initial numerator"));
  marginal_options.add(make_option("compete", d->compete).help("enable competition with marginal features"));
  marginal_options.add(
      make_option("update_before_learn", d->update_before_learn).help("update marginal values before learning"));
  marginal_options.add(make_option("unweighted_marginals", d->unweighted_marginals)
                           .help("ignore importance weights when computing marginals"));
  marginal_options.add(
      make_option("decay", d->decay).default_value(0.f).help("decay multiplier per event (1e-3 for example)"));
  options.add_and_parse(marginal_options);

  // Reduction not requested: returning nullptr leaves it out of the stack.
  if (!options.was_supplied("marginal"))
  {
    return nullptr;
  }

  d->all = &all;

  // Flag every character that appears in --marginal. A single pass over the string
  // marks exactly the same entries as testing all 256 byte values with find().
  for (const char c : marginal)
    d->id_features[static_cast<unsigned char>(c)] = true;

  // init_learner returns a reference to the newly created learner; keep it so the
  // save/load hook can be attached before converting it to a base learner.
  auto& ret =
      init_learner(d, as_singleline(setup_base(options, all)), predict_or_learn<true>, predict_or_learn<false>);
  ret.set_save_load(save_load);

  return make_base(ret);
}
void save_load(data &sm, io_buf &io, bool read, bool text)
Definition: marginal.cc:244
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
std::pair< double, double > marginal
Definition: marginal.cc:16
std::unique_ptr< T, free_fn > free_ptr
Definition: memory.h:34
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
virtual bool was_supplied(const std::string &key)=0
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222