Vowpal Wabbit
sender.cc
Go to the documentation of this file.
1 #include <vector>
2 #ifdef _WIN32
3 #define NOMINMAX
4 #include <WinSock2.h>
5 #ifndef SHUT_RD
6 #define SHUT_RD SD_RECEIVE
7 #endif
8 
9 #ifndef SHUT_WR
10 #define SHUT_WR SD_SEND
11 #endif
12 
13 #ifndef SHUT_RDWR
14 #define SHUT_RDWR SD_BOTH
15 #endif
16 #else
17 #include <netdb.h>
18 #endif
19 #include "io_buf.h"
20 #include "cache.h"
21 #include "network.h"
22 #include "reductions.h"
23 
24 using namespace VW::config;
25 
26 struct sender
27 {
29  int sd;
30  vw* all; // loss ring_size others
32  size_t sent_index;
34 
36  {
37  buf->files.delete_v();
38  buf->space.delete_v();
39  free(delay_ring);
40  delete buf;
41  }
42 };
43 
44 void open_sockets(sender& s, std::string host)
45 {
46  s.sd = open_socket(host.c_str());
47  s.buf = new io_buf();
48  s.buf->files.push_back(s.sd);
49 }
50 
51 void send_features(io_buf* b, example& ec, uint32_t mask)
52 {
53  // note: subtracting 1 b/c not sending constant
54  output_byte(*b, (unsigned char)(ec.indices.size() - 1));
55 
56  for (namespace_index ns : ec.indices)
57  {
58  if (ns == constant_namespace)
59  continue;
60  output_features(*b, ns, ec.feature_space[ns], mask);
61  }
62  b->flush();
63 }
64 
66 {
67  float res, weight;
68 
69  get_prediction(s.sd, res, weight);
70  example& ec = *s.delay_ring[s.received_index++ % s.all->p->ring_size];
71  ec.pred.scalar = res;
72 
73  label_data& ld = ec.l.simple;
74  ec.loss = s.all->loss->getLoss(s.all->sd, ec.pred.scalar, ld.label) * ec.weight;
75 
76  return_simple_example(*(s.all), nullptr, ec);
77 }
78 
80 {
81  if (s.received_index + s.all->p->ring_size / 2 - 1 == s.sent_index)
82  receive_result(s);
83 
84  s.all->set_minmax(s.all->sd, ec.l.simple.label);
85  s.all->p->lp.cache_label(&ec.l, *s.buf); // send label information.
86  cache_tag(*s.buf, ec.tag);
87  send_features(s.buf, ec, (uint32_t)s.all->parse_mask);
88  s.delay_ring[s.sent_index++ % s.all->p->ring_size] = &ec;
89 }
90 
92 
94 {
95  // close our outputs to signal finishing.
96  while (s.received_index != s.sent_index) receive_result(s);
97  shutdown(s.buf->files[0], SHUT_WR);
98 }
99 
101 {
102  std::string host;
103 
104  option_group_definition sender_options("Network sending");
105  sender_options.add(make_option("sendto", host).keep().help("send examples to <host>"));
106  options.add_and_parse(sender_options);
107 
108  if (!options.was_supplied("sendto"))
109  {
110  return nullptr;
111  }
112 
113  auto s = scoped_calloc_or_throw<sender>();
114  s->sd = -1;
115  open_sockets(*s.get(), host);
116 
117  s->all = &all;
118  s->delay_ring = calloc_or_throw<example*>(all.p->ring_size);
119 
121  l.set_finish_example(finish_example);
122  l.set_end_examples(end_examples);
123  return make_base(l);
124 }
v_array< char > tag
Definition: example.h:63
void open_sockets(sender &s, std::string host)
Definition: sender.cc:44
v_array< namespace_index > indices
loss_function * loss
Definition: global_data.h:523
float scalar
Definition: example.h:45
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
void output_byte(io_buf &cache, unsigned char s)
Definition: cache.cc:144
virtual void add_and_parse(const option_group_definition &group)=0
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
size_t size() const
Definition: v_array.h:68
parser * p
Definition: global_data.h:377
std::array< features, NUM_NAMESPACES > feature_space
void receive_result(sender &s)
Definition: sender.cc:65
void(* set_minmax)(shared_data *sd, float label)
Definition: global_data.h:394
virtual float getLoss(shared_data *, float prediction, float label)=0
void output_features(io_buf &cache, unsigned char index, features &fs, uint64_t mask)
Definition: cache.cc:153
virtual void flush()
Definition: io_buf.h:194
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
void push_back(const T &new_ele)
Definition: v_array.h:107
shared_data * sd
Definition: global_data.h:375
void learn(sender &s, LEARNER::single_learner &, example &ec)
Definition: sender.cc:79
v_array< int > files
Definition: io_buf.h:64
example ** delay_ring
Definition: sender.cc:31
virtual bool was_supplied(const std::string &key)=0
int sd
Definition: sender.cc:29
int open_socket(const char *host, unsigned short port)
io_buf * buf
Definition: sender.cc:28
unsigned char namespace_index
void(* cache_label)(void *, io_buf &cache)
Definition: label_parser.h:14
float get_prediction(example *ec)
Definition: parser.cc:919
const size_t ring_size
Definition: parser.h:80
Definition: io_buf.h:54
size_t received_index
Definition: sender.cc:33
void finish_example(vw &, example &)
Definition: parser.cc:881
v_array< char > space
Definition: io_buf.h:62
float loss
Definition: example.h:70
void send_features(io_buf *b, example &ec, uint32_t mask)
Definition: sender.cc:51
float weight
option_group_definition & add(T &&op)
Definition: options.h:90
vw * all
Definition: sender.cc:30
polylabel l
Definition: example.h:57
uint64_t parse_mask
Definition: global_data.h:453
~sender()
Definition: sender.cc:35
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * sender_setup(options_i &options, vw &all)
Definition: sender.cc:100
size_t sent_index
Definition: sender.cc:32
polyprediction pred
Definition: example.h:60
void delete_v()
Definition: v_array.h:98
void cache_tag(io_buf &cache, v_array< char > tag)
Definition: cache.cc:192
constexpr unsigned char constant_namespace
Definition: constant.h:22
float weight
Definition: example.h:62
void end_examples(sender &s)
Definition: sender.cc:93
Definition: sender.cc:26
label_parser lp
Definition: parser.h:102
void return_simple_example(vw &all, void *, example &ec)