Vowpal Wabbit
spanning_tree.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5 */
6 
7 #include "spanning_tree.h"
8 #include "vw_exception.h"
9 
10 #include <string.h>
11 #include <errno.h>
12 #include <stdlib.h>
13 #include <stdio.h>
14 #include <string>
15 #include <iostream>
16 #include <fstream>
17 #include <cmath>
18 #include <map>
19 #include <future>
20 
21 struct client
22 {
23  uint32_t client_ip;
25 };
26 
27 struct partial
28 {
30  size_t filled;
31 };
32 
33 static int socket_sort(const void* s1, const void* s2)
34 {
35  client* socket1 = (client*)s1;
36  client* socket2 = (client*)s2;
37  if (socket1->client_ip != socket2->client_ip)
38  return socket1->client_ip - socket2->client_ip;
39  else
40  return (int)(socket1->socket - socket2->socket);
41 }
42 
// Lays out a binary tree over the index range
// [offset, offset + source_count) and records its shape.
//
//   parent       out: parent[i] receives the index of node i's parent for
//                every node except the root of this (sub)tree, which the
//                caller is expected to fill in.
//   kid_count    out: kid_count[i] receives the number of children (0, 1 or 2).
//   source_count number of nodes in this (sub)tree.
//   offset       index of the first node of this (sub)tree.
//
// Returns the index of the root, or -1 when source_count is 0 (there is no
// node to act as root and nothing is written).
//
// The left subtree is always a perfect tree of 2^height - 1 nodes, where
// height = floor(log2(source_count)); the root follows it and the remaining
// nodes form the right subtree.
int build_tree(int* parent, uint16_t* kid_count, size_t source_count, int offset)
{
  if (source_count == 0)
    return -1;

  if (source_count == 1)
  {
    kid_count[offset] = 0;
    return offset;
  }

  // floor(log2(source_count)) in integer arithmetic: avoids depending on how
  // log(x) / log(2.0) rounds at exact powers of two.
  int height = 0;
  for (size_t remaining = source_count; remaining > 1; remaining >>= 1) ++height;

  const int root = (1 << height) - 1;
  const int left_count = root;
  const int left_child = build_tree(parent, kid_count, left_count, offset);
  const int oroot = root + offset;
  parent[left_child] = oroot;

  const size_t right_count = source_count - left_count - 1;
  if (right_count > 0)
  {
    const int right_offset = oroot + 1;
    const int right_child = build_tree(parent, kid_count, right_count, right_offset);
    parent[right_child] = oroot;
    kid_count[oroot] = 2;
  }
  else
  {
    kid_count[oroot] = 1;
  }

  return oroot;
}
73 
74 void fail_send(const socket_t fd, const void* buf, const int count)
75 {
76  if (send(fd, (char*)buf, count, 0) == -1)
77  THROWERRNO("send: ");
78 }
79 
80 namespace VW
81 {
82 SpanningTree::SpanningTree(uint16_t port, bool quiet) : m_stop(false), m_port(port), m_future(nullptr), m_quiet(quiet)
83 {
84 #ifdef _WIN32
85  WSAData wsaData;
86  int lastError = WSAStartup(MAKEWORD(2, 2), &wsaData);
87  if (lastError != 0)
88  THROWERRNO("WSAStartup() returned error:" << lastError);
89 #endif
90 
91  sock = socket(PF_INET, SOCK_STREAM, 0);
92  if (sock < 0)
93  THROWERRNO("socket: ");
94 
95  int on = 1;
96  if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char*)&on, sizeof(on)) < 0)
97  THROWERRNO("setsockopt SO_REUSEADDR: ");
98 
99  sockaddr_in address;
100  address.sin_family = AF_INET;
101  address.sin_addr.s_addr = htonl(INADDR_ANY);
102 
103  address.sin_port = htons(port);
104  if (::bind(sock, (sockaddr*)&address, sizeof(address)) < 0)
105  THROWERRNO("bind failed for " << inet_ntoa(address.sin_addr));
106 
107  sockaddr_in bound_addr;
108  memset(&bound_addr, 0, sizeof(bound_addr));
109  socklen_t len = sizeof(bound_addr);
110  if (::getsockname(sock, (sockaddr*)&bound_addr, &len) < 0)
111  THROWERRNO("getsockname: " << inet_ntoa(bound_addr.sin_addr));
112 
113  // which port did we bind too (if m_port is 0 this will give us the actual port)
114  m_port = ntohs(bound_addr.sin_port);
115 }
116 
118 {
119  Stop();
120  delete m_future;
121 }
122 
123 short unsigned int SpanningTree::BoundPort() { return m_port; }
124 
126 {
127  // launch async
128  if (m_future == nullptr)
129  {
130  m_future = new std::future<void>;
131  }
132 
133  *m_future = std::async(std::launch::async, &SpanningTree::Run, this);
134 }
135 
137 {
138  m_stop = true;
139 #ifndef _WIN32
140  // just close won't unblock the accept
141  shutdown(sock, SHUT_RD);
142 #endif
143  CLOSESOCK(sock);
144 
145  // wait for run to stop
146  if (m_future != nullptr)
147  {
148  m_future->get();
149  }
150 }
151 
153 {
154  std::map<size_t, partial> partial_nodesets;
155  while (!m_stop)
156  {
157  if (listen(sock, 1024) < 0)
158  THROWERRNO("listen: ");
159 
160  sockaddr_in client_address;
161  socklen_t size = sizeof(client_address);
162  socket_t f = accept(sock, (sockaddr*)&client_address, &size);
163 #ifdef _WIN32
164  if (f == INVALID_SOCKET)
165  {
166 #else
167  if (f < 0)
168  {
169 #endif
170  break;
171  }
172 
173  char dotted_quad[INET_ADDRSTRLEN];
174  if (NULL == inet_ntop(AF_INET, &(client_address.sin_addr), dotted_quad, INET_ADDRSTRLEN))
175  THROWERRNO("inet_ntop: ");
176 
177  char hostname[NI_MAXHOST];
178  char servInfo[NI_MAXSERV];
179  if (getnameinfo((sockaddr*)&client_address, sizeof(sockaddr), hostname, NI_MAXHOST, servInfo, NI_MAXSERV, 0))
180  THROWERRNO("getnameinfo: ");
181 
182  if (!m_quiet)
183  std::cerr << "inbound connection from " << dotted_quad << "(" << hostname << ':' << ntohs(m_port)
184  << ") serv=" << servInfo << std::endl;
185 
186  size_t nonce = 0;
187  if (recv(f, (char*)&nonce, sizeof(nonce), 0) != sizeof(nonce))
188  {
189  THROW(dotted_quad << "(" << hostname << ':' << ntohs(m_port) << "): nonce read failed, exiting");
190  }
191  else
192  {
193  if (!m_quiet)
194  std::cerr << dotted_quad << "(" << hostname << ':' << ntohs(m_port) << "): nonce=" << nonce << std::endl;
195  }
196  size_t total = 0;
197  if (recv(f, (char*)&total, sizeof(total), 0) != sizeof(total))
198  {
199  THROW(dotted_quad << "(" << hostname << ':' << ntohs(m_port) << "): total node count read failed, exiting");
200  }
201  else
202  {
203  if (!m_quiet)
204  std::cerr << dotted_quad << "(" << hostname << ':' << ntohs(m_port) << "): total=" << total << std::endl;
205  }
206  size_t id = 0;
207  if (recv(f, (char*)&id, sizeof(id), 0) != sizeof(id))
208  {
209  THROW(dotted_quad << "(" << hostname << ':' << ntohs(m_port) << "): node id read failed, exiting");
210  }
211  else
212  {
213  if (!m_quiet)
214  std::cerr << dotted_quad << "(" << hostname << ':' << ntohs(m_port) << "): node id=" << id << std::endl;
215  }
216 
217  int ok = true;
218  if (id >= total)
219  {
220  if (!m_quiet)
221  std::cout << dotted_quad << "(" << hostname << ':' << ntohs(m_port) << "): invalid id=" << id
222  << " >= " << total << " !" << std::endl;
223  ok = false;
224  }
225  partial partial_nodeset;
226 
227  if (partial_nodesets.find(nonce) == partial_nodesets.end())
228  {
229  partial_nodeset.nodes = (client*)calloc(total, sizeof(client));
230  for (size_t i = 0; i < total; i++) partial_nodeset.nodes[i].client_ip = (uint32_t)-1;
231  partial_nodeset.filled = 0;
232  }
233  else
234  {
235  partial_nodeset = partial_nodesets[nonce];
236  partial_nodesets.erase(nonce);
237  }
238 
239  if (ok && partial_nodeset.nodes[id].client_ip != (uint32_t)-1)
240  ok = false;
241  fail_send(f, &ok, sizeof(ok));
242 
243  if (ok)
244  {
245  partial_nodeset.nodes[id].client_ip = client_address.sin_addr.s_addr;
246  partial_nodeset.nodes[id].socket = f;
247  partial_nodeset.filled++;
248  }
249  if (partial_nodeset.filled != total) // Need to wait for more connections
250  {
251  partial_nodesets[nonce] = partial_nodeset;
252  for (size_t i = 0; i < total; i++)
253  {
254  if (partial_nodeset.nodes[i].client_ip == (uint32_t)-1)
255  {
256  if (!m_quiet)
257  std::cout << "nonce " << nonce << " still waiting for " << (total - partial_nodeset.filled)
258  << " nodes out of " << total << " for example node " << i << std::endl;
259  break;
260  }
261  }
262  }
263  else
264  {
265  // Time to make the spanning tree
266  qsort(partial_nodeset.nodes, total, sizeof(client), socket_sort);
267 
268  int* parent = (int*)calloc(total, sizeof(int));
269  uint16_t* kid_count = (uint16_t*)calloc(total, sizeof(uint16_t));
270 
271  int root = build_tree(parent, kid_count, total, 0);
272  parent[root] = -1;
273 
274  for (size_t i = 0; i < total; i++)
275  {
276  fail_send(partial_nodeset.nodes[i].socket, &kid_count[i], sizeof(kid_count[i]));
277  }
278 
279  uint16_t* client_ports = (uint16_t*)calloc(total, sizeof(uint16_t));
280 
281  for (size_t i = 0; i < total; i++)
282  {
283  int done = 0;
284  if (recv(partial_nodeset.nodes[i].socket, (char*)&(client_ports[i]), sizeof(client_ports[i]), 0) <
285  (int)sizeof(client_ports[i]))
286 
287  if (!m_quiet)
288  std::cerr << " Port read failed for node " << i << " read " << done << std::endl;
289  } // all clients have bound to their ports.
290 
291  for (size_t i = 0; i < total; i++)
292  {
293  if (parent[i] >= 0)
294  {
295  fail_send(partial_nodeset.nodes[i].socket, &partial_nodeset.nodes[parent[i]].client_ip,
296  sizeof(partial_nodeset.nodes[parent[i]].client_ip));
297  fail_send(partial_nodeset.nodes[i].socket, &client_ports[parent[i]], sizeof(client_ports[parent[i]]));
298  }
299  else
300  {
301  int bogus = -1;
302  uint32_t bogus2 = -1;
303  fail_send(partial_nodeset.nodes[i].socket, &bogus2, sizeof(bogus2));
304  fail_send(partial_nodeset.nodes[i].socket, &bogus, sizeof(bogus));
305  }
306  CLOSESOCK(partial_nodeset.nodes[i].socket);
307  }
308  free(client_ports);
309  free(partial_nodeset.nodes);
310  free(parent);
311  free(kid_count);
312  }
313  }
314 
315 #ifdef _WIN32
316  WSACleanup();
317 #endif
318 }
319 } // namespace VW
int build_tree(int *parent, uint16_t *kid_count, size_t source_count, int offset)
size_t filled
void fail_send(const socket_t fd, const void *buf, const int count)
static int socket_sort(const void *s1, const void *s2)
#define CLOSESOCK
Definition: allreduce.h:43
socket_t socket
float id(float in)
Definition: scorer.cc:51
short unsigned int BoundPort()
#define THROWERRNO(args)
Definition: vw_exception.h:167
std::future< void > * m_future
Definition: spanning_tree.h:51
Definition: autolink.cc:11
int socket_t
Definition: allreduce.h:42
client * nodes
#define THROW(args)
Definition: vw_exception.h:181
float f
Definition: cache.cc:40
uint32_t client_ip