Vowpal Wabbit
allreduce_sockets.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 /*
7 This implements the allreduce function of MPI. Code primarily by
8 Alekh Agarwal and John Langford, with help from Olivier Chapelle.
9  */
10 #include <iostream>
11 #include <sstream>
12 #include <cstdio>
13 #include <cmath>
14 #include <ctime>
15 #include <errno.h>
16 #include <string.h>
17 #include <stdlib.h>
18 #ifdef _WIN32
19 #define NOMINMAX
20 #include <WinSock2.h>
21 #include <Windows.h>
22 #include <WS2tcpip.h>
23 #include <io.h>
24 #else
25 #include <unistd.h>
26 #include <arpa/inet.h>
27 #endif
28 #include <sys/timeb.h>
29 #include "allreduce.h"
30 #include "vw_exception.h"
31 
32 using std::cerr;
33 using std::endl;
34 
35 // port is already in network order
36 socket_t AllReduceSockets::sock_connect(const uint32_t ip, const int port)
37 {
38  socket_t sock = socket(PF_INET, SOCK_STREAM, 0);
39  if (sock == -1)
40  THROWERRNO("socket");
41 
42  sockaddr_in far_end;
43  far_end.sin_family = AF_INET;
44  far_end.sin_port = port;
45 
46  far_end.sin_addr = *(in_addr*)&ip;
47  memset(&far_end.sin_zero, '\0', 8);
48 
49  {
50  char dotted_quad[INET_ADDRSTRLEN];
51  if (nullptr == inet_ntop(AF_INET, &(far_end.sin_addr), dotted_quad, INET_ADDRSTRLEN))
52  THROWERRNO("inet_ntop");
53 
54  char hostname[NI_MAXHOST];
55  char servInfo[NI_MAXSERV];
56  if (getnameinfo((sockaddr*)&far_end, sizeof(sockaddr), hostname, NI_MAXHOST, servInfo, NI_MAXSERV, NI_NUMERICSERV))
57  THROWERRNO("getnameinfo(" << dotted_quad << ")");
58 
59  if (!quiet)
60  cerr << "connecting to " << dotted_quad << " = " << hostname << ':' << ntohs(port) << endl;
61  }
62 
63  size_t count = 0;
64  int ret;
65  while ((ret = connect(sock, (sockaddr*)&far_end, sizeof(far_end))) == -1 && count < 100)
66  {
67  count++;
68  std::stringstream msg;
69  if (!quiet)
70  {
71  msg << "connect attempt " << count << " failed: " << strerror(errno);
72  cerr << msg.str() << endl;
73  }
74 #ifdef _WIN32
75  Sleep(1);
76 #else
77  sleep(1);
78 #endif
79  }
80  if (ret == -1)
81  THROW("cannot connect");
82  return sock;
83 }
84 
// Creates a TCP/IPv4 stream socket, applies the socket options below and returns it.
// Failure to create the socket is fatal (THROWERRNO); a failing setsockopt() is only
// logged to stderr (unless `quiet` is set) and the socket is still returned.
// NOTE(review): the signature line is absent from this listing; the call sites
// `sock = getsock();` suggest this is the body of getsock() - confirm.
86 {
87  socket_t sock = socket(PF_INET, SOCK_STREAM, 0);
88  if (sock < 0)
89  THROWERRNO("socket");
90 
91  // SO_REUSEADDR will allow port rebinding on Windows, causing multiple instances
92  // of VW on the same machine to potentially contact the wrong tree node.
// For that reason the option is only set on non-Windows builds.
93 #ifndef _WIN32
94  int on = 1;
95  if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char*)&on, sizeof(on)) < 0)
96  {
97  if (!quiet)
98  cerr << "setsockopt SO_REUSEADDR: " << strerror(errno) << endl;
99  }
100 #endif
101 
102  // Enable TCP Keep Alive to prevent socket leaks
103  int enableTKA = 1;
104  if (setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (char*)&enableTKA, sizeof(enableTKA)) < 0)
105  {
106  if (!quiet)
107  cerr << "setsockopt SO_KEEPALIVE: " << strerror(errno) << endl;
108  }
109 
110  return sock;
111 }
112 
// Joins this node to the tree handed out by the span server:
//   1. connect to `span_server` and send unique_id, total and node;
//   2. read the acknowledgement, the number of children and (after reporting our own
//      listening port) the parent's address and port;
//   3. connect to the parent, if any, and accept one connection per child.
// On return socks.parent and socks.children[] hold the connected sockets, or -1 where
// there is no peer. Every failed step throws via THROW / THROWERRNO.
// NOTE(review): the signature line and listing line 127 are absent from this listing;
// this is presumably the body of the allreduce initialisation routine - confirm against
// the real source before relying on this block.
114 {
115 #ifdef _WIN32
// Winsock has to be started before any socket call is made.
116  WSAData wsaData;
117  int lastError = WSAStartup(MAKEWORD(2, 2), &wsaData);
118  if (lastError != 0)
119  THROWERRNO("WSAStartup() returned error:" << lastError);
120 #endif
121 
// Resolve the span server's host name.
122  struct hostent* master = gethostbyname(span_server.c_str());
123 
124  if (master == nullptr)
125  THROWERRNO("gethostbyname(" << span_server << ")");
126 
128 
// Take the first address of the host entry and read it as a 32-bit IPv4 address.
129  uint32_t master_ip = *((uint32_t*)master->h_addr);
130 
// sock_connect() expects the port in network byte order, hence htons().
131  socket_t master_sock = sock_connect(master_ip, htons(port));
// The three values below are written as raw in-memory bytes; a short write counts as failure.
132  if (send(master_sock, (const char*)&unique_id, sizeof(unique_id), 0) < (int)sizeof(unique_id))
133  {
134  THROW("write unique_id=" << unique_id << " to span server failed");
135  }
136  else
137  {
138  if (!quiet)
139  cerr << "wrote unique_id=" << unique_id << endl;
140  }
141  if (send(master_sock, (const char*)&total, sizeof(total), 0) < (int)sizeof(total))
142  {
143  THROW("write total=" << total << " to span server failed");
144  }
145  else
146  {
147  if (!quiet)
148  cerr << "wrote total=" << total << endl;
149  }
150  if (send(master_sock, (char*)&node, sizeof(node), 0) < (int)sizeof(node))
151  {
152  THROW("write node=" << node << " to span server failed");
153  }
154  else
155  {
156  if (!quiet)
157  cerr << "wrote node=" << node << endl;
158  }
// The server answers with an int; zero is treated as a rejection (see the THROW below).
159  int ok;
160  if (recv(master_sock, (char*)&ok, sizeof(ok), 0) < (int)sizeof(ok))
161  {
162  THROW("read ok from span server failed");
163  }
164  else
165  {
166  if (!quiet)
167  cerr << "read ok=" << ok << endl;
168  }
169  if (!ok)
170  THROW("mapper already connected");
171 
172  uint16_t kid_count;
173  uint16_t parent_port;
174  uint32_t parent_ip;
175 
// Number of children that will connect to this node.
176  if (recv(master_sock, (char*)&kid_count, sizeof(kid_count), 0) < (int)sizeof(kid_count))
177  {
178  THROW("read kid_count from span server failed");
179  }
180  else
181  {
182  if (!quiet)
183  cerr << "read kid_count=" << kid_count << endl;
184  }
185 
// Only a node with children needs a listening socket. The search for a free port starts
// at 26544; netport is kept in network byte order throughout.
186  socket_t sock = -1;
187  short unsigned int netport = htons(26544);
188  if (kid_count > 0)
189  {
190  sock = getsock();
191  sockaddr_in address;
192  address.sin_family = AF_INET;
193  address.sin_addr.s_addr = htonl(INADDR_ANY);
194  address.sin_port = netport;
195 
// Loop until bind() and listen() have both succeeded.
196  bool listening = false;
197  while (!listening)
198  {
199  if (::bind(sock, (sockaddr*)&address, sizeof(address)) < 0)
200  {
201 #ifdef _WIN32
202  if (WSAGetLastError() == WSAEADDRINUSE)
203 #else
204  if (errno == EADDRINUSE)
205 #endif
206  {
// Port already taken: try the next port number.
207  netport = htons(ntohs(netport) + 1);
208  address.sin_port = netport;
209  }
210  else
211  THROWERRNO("bind");
212  }
213  else
214  {
215  if (listen(sock, kid_count) < 0)
216  {
// listen() failed: log it, replace the socket with a fresh one and go back to bind().
217  if (!quiet)
218  cerr << "listen: " << strerror(errno) << endl;
219  CLOSESOCK(sock);
220  sock = getsock();
221  }
222  else
223  {
224  listening = true;
225  }
226  }
227  }
228  }
229 
// Report the port to the server. This is sent even when kid_count is 0, in which case
// it is the untouched default and nothing is listening on it.
230  if (send(master_sock, (const char*)&netport, sizeof(netport), 0) < (int)sizeof(netport))
231  THROW("write netport failed!");
232 
233  if (recv(master_sock, (char*)&parent_ip, sizeof(parent_ip), 0) < (int)sizeof(parent_ip))
234  {
235  THROW("read parent_ip failed!");
236  }
237  else
238  {
// Formatting the address is for logging only; an inet_ntop failure is not fatal here.
239  char dotted_quad[INET_ADDRSTRLEN];
240  if (nullptr == inet_ntop(AF_INET, (char*)&parent_ip, dotted_quad, INET_ADDRSTRLEN))
241  {
242  if (!quiet)
243  cerr << "read parent_ip=" << parent_ip << "(inet_ntop: " << strerror(errno) << ")" << endl;
244  }
245  else
246  {
247  if (!quiet)
248  cerr << "read parent_ip=" << dotted_quad << endl;
249  }
250  }
251  if (recv(master_sock, (char*)&parent_port, sizeof(parent_port), 0) < (int)sizeof(parent_port))
252  {
253  THROW("read parent_port failed!");
254  }
255  else
256  {
// NOTE(review): parent_port is handed unchanged to sock_connect(), which expects network
// byte order, so the value logged here is presumably not in host order - confirm.
257  if (!quiet)
258  cerr << "read parent_port=" << parent_port << endl;
259  }
260 
// The exchange with the span server is finished.
261  CLOSESOCK(master_sock);
262 
// An all-ones parent_ip means this node has no parent.
263  if (parent_ip != (uint32_t)-1)
264  {
265  socks.parent = sock_connect(parent_ip, parent_port);
266  }
267  else
268  socks.parent = -1;
269 
270  socks.children[0] = -1;
271  socks.children[1] = -1;
// Accept one connection per announced child.
// NOTE(review): socks.children is declared with two elements; a kid_count above 2 would
// write past the end of the array - confirm the server never sends more than 2.
272  for (int i = 0; i < kid_count; i++)
273  {
274  sockaddr_in child_address;
275  socklen_t size = sizeof(child_address);
276  socket_t f = accept(sock, (sockaddr*)&child_address, &size);
277  if (f < 0)
278  THROWERRNO("accept");
279 
280  // char hostname[NI_MAXHOST];
281  // char servInfo[NI_MAXSERV];
282  // getnameinfo((sockaddr *) &child_address, sizeof(sockaddr), hostname, NI_MAXHOST, servInfo, NI_MAXSERV,
283  // NI_NUMERICSERV); cerr << "connected to " << hostname << ':' << ntohs(port) << endl;
284  socks.children[i] = f;
285  }
286 
// The listening socket only exists when there were children to accept.
287  if (kid_count > 0)
288  CLOSESOCK(sock);
289 }
290 
291 void AllReduceSockets::pass_down(char* buffer, const size_t parent_read_pos, size_t& children_sent_pos)
292 {
293  size_t my_bufsize = std::min(ar_buf_size, (parent_read_pos - children_sent_pos));
294 
295  if (my_bufsize > 0)
296  {
297  // going to pass up this chunk of data to the children
298  if (socks.children[0] != -1 &&
299  send(socks.children[0], buffer + children_sent_pos, (int)my_bufsize, 0) < (int)my_bufsize)
300  {
301  THROW("Write to left child failed");
302  }
303  if (socks.children[1] != -1 &&
304  send(socks.children[1], buffer + children_sent_pos, (int)my_bufsize, 0) < (int)my_bufsize)
305  {
306  THROW("Write to right child failed");
307  }
308 
309  children_sent_pos += my_bufsize;
310  }
311 }
312 
313 void AllReduceSockets::broadcast(char* buffer, const size_t n)
314 {
315  size_t parent_read_pos = 0; // First unread float from parent
316  size_t children_sent_pos = 0; // First unsent float to children
317  // parent_sent_pos <= left_read_pos
318  // parent_sent_pos <= right_read_pos
319 
320  if (socks.parent == -1)
321  {
322  parent_read_pos = n;
323  }
324  if (socks.children[0] == -1 && socks.children[1] == -1)
325  children_sent_pos = n;
326 
327  while (parent_read_pos < n || children_sent_pos < n)
328  {
329  pass_down(buffer, parent_read_pos, children_sent_pos);
330  if (parent_read_pos >= n && children_sent_pos >= n)
331  break;
332 
333  if (socks.parent != -1)
334  {
335  // there is data to be read from the parent
336  if (parent_read_pos == n)
337  THROW("I think parent has no data to send but he thinks he has");
338 
339  size_t count = std::min(ar_buf_size, n - parent_read_pos);
340  int read_size = recv(socks.parent, buffer + parent_read_pos, (int)count, 0);
341  if (read_size == -1)
342  {
343  THROW(" recv from parent: " << strerror(errno));
344  }
345  parent_read_pos += read_size;
346  }
347  }
348 }
constexpr size_t ar_buf_size
Definition: allreduce.h:49
socket_t children[2]
Definition: allreduce.h:55
node_socks socks
Definition: allreduce.h:182
bool quiet
Definition: allreduce.h:82
const size_t total
Definition: allreduce.h:80
socket_t parent
Definition: allreduce.h:54
#define CLOSESOCK
Definition: allreduce.h:43
std::string current_master
Definition: allreduce.h:53
void pass_down(char *buffer, const size_t parent_read_pos, size_t &children_sent_pos)
#define THROWERRNO(args)
Definition: vw_exception.h:167
std::string span_server
Definition: allreduce.h:183
int socket_t
Definition: allreduce.h:42
socket_t sock_connect(const uint32_t ip, const int port)
size_t unique_id
Definition: allreduce.h:185
void broadcast(char *buffer, const size_t n)
#define THROW(args)
Definition: vw_exception.h:181
float f
Definition: cache.cc:40