Vowpal Wabbit
Classes | Namespaces | Functions
topk.cc File Reference
#include <cfloat>
#include <sstream>
#include <queue>
#include <utility>
#include "topk.h"
#include "learner.h"
#include "parse_args.h"
#include "vw.h"

Go to the source code of this file.

Classes

class  VW::topk
 

Namespaces

 VW
 

Functions

void print_result (int file_descriptor, std::pair< VW::topk::const_iterator_t, VW::topk::const_iterator_t > const &view)
 
void output_example (vw &all, example &ec)
 
template<bool is_learn>
void predict_or_learn (VW::topk &d, LEARNER::single_learner &base, multi_ex &ec_seq)
 
void finish_example (vw &all, VW::topk &d, multi_ex &ec_seq)
 
LEARNER::base_learner * topk_setup (options_i &options, vw &all)
 

Function Documentation

◆ finish_example()

void finish_example ( vw & all,
VW::topk & d,
multi_ex & ec_seq 
)

Definition at line 124 of file topk.cc.

References VW::topk::clear_container(), vw::final_prediction_sink, VW::finish_example(), VW::topk::get_container_view(), output_example(), and print_result().

125 {
126  for (auto ec : ec_seq) output_example(all, *ec);
127  for (auto sink : all.final_prediction_sink) print_result(sink, d.get_container_view());
128  d.clear_container();
129  VW::finish_example(all, ec_seq);
130 }
void clear_container()
Definition: topk.cc:79
v_array< int > final_prediction_sink
Definition: global_data.h:518
void output_example(vw &all, example &ec)
Definition: topk.cc:104
void print_result(int file_descriptor, std::pair< VW::topk::const_iterator_t, VW::topk::const_iterator_t > const &view)
Definition: topk.cc:81
std::pair< const_iterator_t, const_iterator_t > get_container_view()
Definition: topk.cc:74
void finish_example(vw &, example &)
Definition: parser.cc:881

◆ output_example()

void output_example ( vw & all,
example & ec 
)

Definition at line 104 of file topk.cc.

References example::l, label_data::label, example::loss, example::num_features, CB::print_update(), vw::sd, polylabel::simple, example::test_only, shared_data::update(), example::weight, and shared_data::weighted_labels.

Referenced by finish_example().

105 {
106  label_data& ld = ec.l.simple;
107 
108  all.sd->update(ec.test_only, ld.label != FLT_MAX, ec.loss, ec.weight, ec.num_features);
109  if (ld.label != FLT_MAX)
110  all.sd->weighted_labels += ((double)ld.label) * ec.weight;
111 
112  print_update(all, ec);
113 }
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
shared_data * sd
Definition: global_data.h:375
size_t num_features
Definition: example.h:67
double weighted_labels
Definition: global_data.h:144
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
float loss
Definition: example.h:70
polylabel l
Definition: example.h:57
float weight
Definition: example.h:62
bool test_only
Definition: example.h:76

◆ predict_or_learn()

template<bool is_learn>
void predict_or_learn ( VW::topk & d,
LEARNER::single_learner & base,
multi_ex & ec_seq 
)

Definition at line 116 of file topk.cc.

References VW::topk::learn(), and VW::topk::predict().

117 {
118  if (is_learn)
119  d.learn(base, ec_seq);
120  else
121  d.predict(base, ec_seq);
122 }
void predict(LEARNER::single_learner &base, multi_ex &ec_seq)
Definition: topk.cc:43
void learn(LEARNER::single_learner &base, multi_ex &ec_seq)
Definition: topk.cc:52

◆ print_result()

void print_result ( int  file_descriptor,
std::pair< VW::topk::const_iterator_t, VW::topk::const_iterator_t > const &  view 
)

Definition at line 81 of file topk.cc.

References print_tag().

Referenced by finish_example().

82 {
83  if (file_descriptor >= 0)
84  {
85  std::stringstream ss;
86  for (auto it = view.first; it != view.second; it++)
87  {
88  ss << std::fixed << it->first << " ";
89  print_tag(ss, it->second);
90  ss << " \n";
91  }
92  ss << '\n';
93  ssize_t len = ss.str().size();
94 #ifdef _WIN32
95  ssize_t t = _write(file_descriptor, ss.str().c_str(), (unsigned int)len);
96 #else
97  ssize_t t = write(file_descriptor, ss.str().c_str(), (unsigned int)len);
98 #endif
99  if (t != len)
100  std::cerr << "write error: " << strerror(errno) << std::endl;
101  }
102 }
int print_tag(std::stringstream &ss, v_array< char > tag)
Definition: global_data.cc:81

◆ topk_setup()

LEARNER::base_learner* topk_setup ( options_i & options,
vw & all 
)

Definition at line 132 of file topk.cc.

References VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), VW::finish_example(), LEARNER::init_learner(), LEARNER::make_base(), VW::config::make_option(), setup_base(), and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

133 {
134  uint32_t K;
135  option_group_definition new_options("Top K");
136  new_options.add(make_option("top", K).keep().help("top k recommendation"));
137  options.add_and_parse(new_options);
138 
139  if (!options.was_supplied("top"))
140  return nullptr;
141 
142  auto data = scoped_calloc_or_throw<VW::topk>(K);
143 
144  LEARNER::learner<VW::topk, multi_ex>& l =
145  init_learner(data, as_singleline(setup_base(options, all)), predict_or_learn<true>, predict_or_learn<false>);
146  l.set_finish_example(finish_example);
147 
148  return make_base(l);
149 }
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
virtual bool was_supplied(const std::string &key)=0
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
void finish_example(vw &all, VW::topk &d, multi_ex &ec_seq)
Definition: topk.cc:124
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222