Vowpal Wabbit
shared_feature_merger.cc
Go to the documentation of this file.
2 #include "cb.h"
3 #include "example.h"
4 #include "label_dictionary.h"
5 #include "learner.h"
6 #include "options.h"
7 #include "parse_args.h"
8 #include "vw.h"
9 
10 #include <iterator>
11 
12 namespace VW
13 {
14 namespace shared_feature_merger
15 {
16 static const std::vector<std::string> option_strings = {
17  "csoaa_ldf", "wap_ldf", "cb_adf", "explore_eval", "cbify_ldf", "cb_explore_adf", "warm_cb"};
18 
20 {
21  for (const auto& opt : option_strings)
22  {
23  if (options.was_supplied(opt))
24  return true;
25  }
26  return false;
27 }
28 
29 struct sfm_data
30 {
31 };
32 
33 template <bool is_learn>
35 {
36  if (ec_seq.size() == 0)
37  THROW("cb_adf: At least one action must be provided for an example to be valid.");
38 
39  multi_ex::value_type shared_example = nullptr;
40 
41  const bool has_example_header = CB::ec_is_example_header(*ec_seq[0]);
42  if (has_example_header)
43  {
44  shared_example = ec_seq[0];
45  ec_seq.erase(ec_seq.begin());
46  // merge sequences
47  for (auto& example : ec_seq) LabelDict::add_example_namespaces_from_example(*example, *shared_example);
48  std::swap(ec_seq[0]->pred, shared_example->pred);
49  }
50  if (ec_seq.size() == 0)
51  return;
52  if (is_learn)
53  base.learn(ec_seq);
54  else
55  base.predict(ec_seq);
56 
57  if (has_example_header)
58  {
59  for (auto& example : ec_seq) LabelDict::del_example_namespaces_from_example(*example, *shared_example);
60  std::swap(shared_example->pred, ec_seq[0]->pred);
61  ec_seq.insert(ec_seq.begin(), shared_example);
62  }
63 }
64 
66 {
67  if (!use_reduction(options))
68  return nullptr;
69 
70  auto data = scoped_calloc_or_throw<sfm_data>();
71 
72  auto* base = LEARNER::as_multiline(setup_base(options, all));
73  auto& learner = LEARNER::init_learner(data, base, predict_or_learn<true>, predict_or_learn<false>);
74 
75  // TODO: Incorrect feature numbers will be reported without merging the example namespaces from the
76  // shared example in a finish_example function. However, its too expensive to perform the full operation.
77 
78  return LEARNER::make_base(learner);
79 }
80 
81 } // namespace shared_feature_merger
82 
83 } // namespace VW
void predict(E &ec, size_t i=0)
Definition: learner.h:169
bool ec_is_example_header(example const &ec)
Definition: cb.cc:170
void del_example_namespaces_from_example(example &target, example &source)
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
bool use_reduction(config::options_i &options)
LEARNER::base_learner * shared_feature_merger_setup(config::options_i &options, vw &all)
static const std::vector< std::string > option_strings
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
virtual bool was_supplied(const std::string &key)=0
std::vector< example * > multi_ex
Definition: example.h:122
void predict_or_learn(sfm_data &, LEARNER::multi_learner &base, multi_ex &ec_seq)
Definition: autolink.cc:11
iterator begin()
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
#define THROW(args)
Definition: vw_exception.h:181
multi_learner * as_multiline(learner< T, E > *l)
Definition: learner.h:468
void add_example_namespaces_from_example(example &target, example &source)