Vowpal Wabbit
Classes | Functions | Variables
VW::shared_feature_merger Namespace Reference

Classes

struct  sfm_data
 

Functions

bool use_reduction (config::options_i &options)
 
template<bool is_learn>
void predict_or_learn (sfm_data &, LEARNER::multi_learner &base, multi_ex &ec_seq)
 
LEARNER::base_learner * shared_feature_merger_setup (config::options_i &options, vw &all)
 

Variables

static const std::vector< std::string > option_strings
 

Function Documentation

◆ predict_or_learn()

template<bool is_learn>
void VW::shared_feature_merger::predict_or_learn ( sfm_data & ,
LEARNER::multi_learner & base,
multi_ex & ec_seq 
)

Definition at line 34 of file shared_feature_merger.cc.

References LabelDict::add_example_namespaces_from_example(), example_predict::begin(), LabelDict::del_example_namespaces_from_example(), CB::ec_is_example_header(), LEARNER::learner< T, E >::learn(), example::pred, LEARNER::learner< T, E >::predict(), and THROW.

35 {
36  if (ec_seq.size() == 0)
37  THROW("cb_adf: At least one action must be provided for an example to be valid.");
38 
39  multi_ex::value_type shared_example = nullptr;
40 
41  const bool has_example_header = CB::ec_is_example_header(*ec_seq[0]);
42  if (has_example_header)
43  {
44  shared_example = ec_seq[0];
45  ec_seq.erase(ec_seq.begin());
46  // merge sequences
47  for (auto& example : ec_seq) LabelDict::add_example_namespaces_from_example(*example, *shared_example);
48  std::swap(ec_seq[0]->pred, shared_example->pred);
49  }
50  if (ec_seq.size() == 0)
51  return;
52  if (is_learn)
53  base.learn(ec_seq);
54  else
55  base.predict(ec_seq);
56 
57  if (has_example_header)
58  {
59  for (auto& example : ec_seq) LabelDict::del_example_namespaces_from_example(*example, *shared_example);
60  std::swap(shared_example->pred, ec_seq[0]->pred);
61  ec_seq.insert(ec_seq.begin(), shared_example);
62  }
63 }
void predict(E &ec, size_t i=0)
Definition: learner.h:169
bool ec_is_example_header(example const &ec)
Definition: cb.cc:170
void del_example_namespaces_from_example(example &target, example &source)
iterator begin()
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
#define THROW(args)
Definition: vw_exception.h:181
void add_example_namespaces_from_example(example &target, example &source)

◆ shared_feature_merger_setup()

LEARNER::base_learner * VW::shared_feature_merger::shared_feature_merger_setup ( config::options_i & options,
vw & all 
)

Definition at line 65 of file shared_feature_merger.cc.

References LEARNER::as_multiline(), LEARNER::init_learner(), LEARNER::make_base(), setup_base(), and use_reduction().

Referenced by parse_reductions().

66 {
67  if (!use_reduction(options))
68  return nullptr;
69 
70  auto data = scoped_calloc_or_throw<sfm_data>();
71 
72  auto* base = LEARNER::as_multiline(setup_base(options, all));
73  auto& learner = LEARNER::init_learner(data, base, predict_or_learn<true>, predict_or_learn<false>);
74 
75  // TODO: Incorrect feature numbers will be reported without merging the example namespaces from the
76  // shared example in a finish_example function. However, its too expensive to perform the full operation.
77 
78  return LEARNER::make_base(learner);
79 }
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
bool use_reduction(config::options_i &options)
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468

◆ use_reduction()

bool VW::shared_feature_merger::use_reduction ( config::options_i & options)

Definition at line 19 of file shared_feature_merger.cc.

References VW::config::options_i::was_supplied().

Referenced by shared_feature_merger_setup().

20 {
21  for (const auto& opt : option_strings)
22  {
23  if (options.was_supplied(opt))
24  return true;
25  }
26  return false;
27 }
static const std::vector< std::string > option_strings

Variable Documentation

◆ option_strings

const std::vector<std::string> VW::shared_feature_merger::option_strings
static
Initial value:
= {
"csoaa_ldf", "wap_ldf", "cb_adf", "explore_eval", "cbify_ldf", "cb_explore_adf", "warm_cb"}

Definition at line 16 of file shared_feature_merger.cc.