Vowpal Wabbit
topk.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
#include <cerrno>
#include <cfloat>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <sstream>
#include <string>
#include <utility>

#include "topk.h"
#include "learner.h"
#include "parse_args.h"
#include "vw.h"
15 
16 using namespace VW::config;
17 
18 namespace VW
19 {
20 class topk
21 {
22  using container_t = std::multimap<float, v_array<char>>;
23 
24  public:
25  using const_iterator_t = container_t::const_iterator;
26  topk(uint32_t k_num);
27 
28  void predict(LEARNER::single_learner& base, multi_ex& ec_seq);
29  void learn(LEARNER::single_learner& base, multi_ex& ec_seq);
30  std::pair<const_iterator_t, const_iterator_t> get_container_view();
31  void clear_container();
32 
33  private:
34  void update_priority_queue(float pred, v_array<char>& tag);
35 
36  const uint32_t _k_num;
38 };
39 } // namespace VW
40 
41 VW::topk::topk(uint32_t k_num) : _k_num(k_num) {}
42 
44 {
45  for (auto ec : ec_seq)
46  {
47  base.predict(*ec);
48  update_priority_queue(ec->pred.scalar, ec->tag);
49  }
50 }
51 
53 {
54  for (auto ec : ec_seq)
55  {
56  base.learn(*ec);
57  update_priority_queue(ec->pred.scalar, ec->tag);
58  }
59 }
60 
62 {
63  if (_pr_queue.size() < _k_num)
64  {
65  _pr_queue.insert({pred, tag});
66  }
67  else if (_pr_queue.begin()->first < pred)
68  {
69  _pr_queue.erase(_pr_queue.begin());
70  _pr_queue.insert({pred, tag});
71  }
72 }
73 
74 std::pair<VW::topk::const_iterator_t, VW::topk::const_iterator_t> VW::topk::get_container_view()
75 {
76  return {_pr_queue.cbegin(), _pr_queue.cend()};
77 }
78 
80 
81 void print_result(int file_descriptor, std::pair<VW::topk::const_iterator_t, VW::topk::const_iterator_t> const& view)
82 {
83  if (file_descriptor >= 0)
84  {
85  std::stringstream ss;
86  for (auto it = view.first; it != view.second; it++)
87  {
88  ss << std::fixed << it->first << " ";
89  print_tag(ss, it->second);
90  ss << " \n";
91  }
92  ss << '\n';
93  ssize_t len = ss.str().size();
94 #ifdef _WIN32
95  ssize_t t = _write(file_descriptor, ss.str().c_str(), (unsigned int)len);
96 #else
97  ssize_t t = write(file_descriptor, ss.str().c_str(), (unsigned int)len);
98 #endif
99  if (t != len)
100  std::cerr << "write error: " << strerror(errno) << std::endl;
101  }
102 }
103 
104 void output_example(vw& all, example& ec)
105 {
106  label_data& ld = ec.l.simple;
107 
108  all.sd->update(ec.test_only, ld.label != FLT_MAX, ec.loss, ec.weight, ec.num_features);
109  if (ld.label != FLT_MAX)
110  all.sd->weighted_labels += ((double)ld.label) * ec.weight;
111 
112  print_update(all, ec);
113 }
114 
115 template <bool is_learn>
117 {
118  if (is_learn)
119  d.learn(base, ec_seq);
120  else
121  d.predict(base, ec_seq);
122 }
123 
124 void finish_example(vw& all, VW::topk& d, multi_ex& ec_seq)
125 {
126  for (auto ec : ec_seq) output_example(all, *ec);
127  for (auto sink : all.final_prediction_sink) print_result(sink, d.get_container_view());
128  d.clear_container();
129  VW::finish_example(all, ec_seq);
130 }
131 
133 {
134  uint32_t K;
135  option_group_definition new_options("Top K");
136  new_options.add(make_option("top", K).keep().help("top k recommendation"));
137  options.add_and_parse(new_options);
138 
139  if (!options.was_supplied("top"))
140  return nullptr;
141 
142  auto data = scoped_calloc_or_throw<VW::topk>(K);
143 
145  init_learner(data, as_singleline(setup_base(options, all)), predict_or_learn<true>, predict_or_learn<false>);
146  l.set_finish_example(finish_example);
147 
148  return make_base(l);
149 }
void predict(LEARNER::single_learner &base, multi_ex &ec_seq)
Definition: topk.cc:43
const uint32_t _k_num
Definition: topk.cc:36
void predict(E &ec, size_t i=0)
Definition: learner.h:169
void clear_container()
Definition: topk.cc:79
v_array< int > final_prediction_sink
Definition: global_data.h:518
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
void learn(LEARNER::single_learner &base, multi_ex &ec_seq)
Definition: topk.cc:52
int print_tag(std::stringstream &ss, v_array< char > tag)
Definition: global_data.cc:81
virtual void add_and_parse(const option_group_definition &group)=0
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
void output_example(vw &all, example &ec)
Definition: topk.cc:104
topk(uint32_t k_num)
Definition: topk.cc:41
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
void update_priority_queue(float pred, v_array< char > &tag)
Definition: topk.cc:61
void print_result(int file_descriptor, std::pair< VW::topk::const_iterator_t, VW::topk::const_iterator_t > const &view)
Definition: topk.cc:81
std::pair< const_iterator_t, const_iterator_t > get_container_view()
Definition: topk.cc:74
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
shared_data * sd
Definition: global_data.h:375
container_t::const_iterator const_iterator_t
Definition: topk.cc:25
size_t num_features
Definition: example.h:67
virtual bool was_supplied(const std::string &key)=0
double weighted_labels
Definition: global_data.h:144
std::multimap< float, v_array< char > > container_t
Definition: topk.cc:22
void predict_or_learn(VW::topk &d, LEARNER::single_learner &base, multi_ex &ec_seq)
Definition: topk.cc:116
void finish_example(vw &, example &)
Definition: parser.cc:881
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
float loss
Definition: example.h:70
option_group_definition & add(T &&op)
Definition: options.h:90
std::vector< example * > multi_ex
Definition: example.h:122
polylabel l
Definition: example.h:57
LEARNER::base_learner * topk_setup(options_i &options, vw &all)
Definition: topk.cc:132
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
Definition: topk.cc:20
Definition: autolink.cc:11
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
void predict(bfgs &b, base_learner &, example &ec)
Definition: bfgs.cc:956
void learn(E &ec, size_t i=0)
Definition: learner.h:160
void learn(bfgs &b, base_learner &base, example &ec)
Definition: bfgs.cc:965
float weight
Definition: example.h:62
container_t _pr_queue
Definition: topk.cc:37
bool test_only
Definition: example.h:76