Vowpal Wabbit
allreduce.h
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD
4 license as described in the file LICENSE.
5  */
6 
7 // This implements the allreduce function of MPI.
8 #pragma once
9 
10 #include <string>
11 #include <algorithm>
12 
13 #ifdef _WIN32
14 #define NOMINMAX
15 #include <WinSock2.h>
16 #include <WS2tcpip.h>
17 typedef unsigned int uint32_t;
18 typedef unsigned short uint16_t;
19 typedef int socklen_t;
20 typedef SOCKET socket_t;
21 #define CLOSESOCK closesocket
22 namespace std
23 {
24 // forward declare promise as C++/CLI doesn't allow usage in header files
25 template <typename T>
26 class promise;
27 
28 class condition_variable;
29 
30 class mutex;
31 } // namespace std
32 #else
33 #include <sys/socket.h>
34 #include <sys/socket.h>
35 #include <netinet/in.h>
36 #include <netinet/tcp.h>
37 #include <netdb.h>
38 #include <stdlib.h>
39 #include <stdio.h>
40 #include <unistd.h>
41 #include <string.h>
42 typedef int socket_t;
43 #define CLOSESOCK close
44 #include <future>
45 #endif
46 #include "vw_exception.h"
47 #include <cassert>
48 
// Chunk size, in bytes, used by the socket-based reduce code in this header:
// pass_up() never sends more than this many bytes per send() call, reduce()
// never asks recv() for more than this many bytes, and reduce() sizes its
// per-child staging buffers from it.
49 constexpr size_t ar_buf_size = 1 << 16;
50 
// Bundle of sockets used by one node, plus the name of the master those
// sockets were set up for.
// NOTE(review): listing lines 54-56 are absent from this view. The index at
// the end of this page lists `socket_t parent` (54), `socket_t children[2]`
// (55) and `~node_socks()` (56) at those positions, so the brace block below
// is presumably the destructor body - confirm against the real header.
51 struct node_socks
52 {
// Compared against the requested span server before each socket all_reduce;
// an empty string is what the constructor sets and what the cleanup code
// below treats as "nothing to close".
53  std::string current_master;
// Cleanup: only when current_master is non-empty, close every socket whose
// value is not -1 (the parent and both children).
57  {
58  if (current_master != "")
59  {
60  if (parent != -1)
61  CLOSESOCK(this->parent);
62  if (children[0] != -1)
63  CLOSESOCK(this->children[0]);
64  if (children[1] != -1)
65  CLOSESOCK(this->children[1]);
66  }
67  }
// Starts with an empty master name. The socket members are not assigned here.
68  node_socks() { current_master = ""; }
69 };
70 
71 template <class T, void (*f)(T&, const T&)>
72 void addbufs(T* buf1, const T* buf2, const size_t n)
73 {
74  for (size_t i = 0; i < n; i++) f(buf1[i], buf2[i]);
75 }
76 
/// Base class holding what every all-reduce participant knows about itself:
/// how many nodes take part and which one this is.
class AllReduce
{
 public:
  /// Total number of participating nodes.
  const size_t total;

  /// Id of this node; the constructor asserts it is smaller than `total`.
  const size_t node;

  /// Copied from the constructor argument; not read anywhere in this header.
  bool quiet;

  /// @param ptotal  total number of nodes
  /// @param pnode   id of this node (must satisfy pnode < ptotal)
  /// @param pquiet  stored in `quiet`; defaults to false
  AllReduce(size_t ptotal, const size_t pnode, bool pquiet = false) : total{ptotal}, node{pnode}, quiet{pquiet}
  {
    // Same check as comparing the members: they were just copied from these arguments.
    assert(pnode < ptotal);
  }

  /// Virtual so derived implementations can be destroyed through a base pointer.
  virtual ~AllReduce() = default;
};
91 
/// Untyped (pointer, length) pair describing a region of memory.
/// Not referenced by any other code in this header.
struct Data
{
  /// Start of the region; the pointee type is not recorded.
  void* buffer;

  /// Size of the region. NOTE(review): presumably a byte count - confirm against users.
  size_t length;
};
97 
// Synchronization state shared by the threads of a thread-based all-reduce.
// NOTE(review): the class header line (listing line 98) is absent from this
// view; the constructor and destructor below show the type is AllReduceSync.
// All member functions are only declared here; their definitions live elsewhere.
99 {
100  private:
// Held by pointer; on Windows this header only forward-declares std::mutex
// and std::condition_variable (see the C++/CLI note near the top of the file).
101  std::mutex* m_mutex;
102  std::condition_variable* m_cv;
103 
104  // total number of threads we wait for
105  size_t m_total;
106 
107  // number of threads reached the barrier
108  uint32_t m_count;
109 
110  // current wait-barrier-run required to protect against spurious wakeups of m_cv->wait(...)
111  bool m_run;
112 
113  public:
// `total` is presumably stored in m_total - confirm in the definition.
114  AllReduceSync(const size_t total);
115 
116  ~AllReduceSync();
117 
// Called by AllReduceThreads::all_reduce once after registering its buffer
// and once after finishing its slice. Presumably blocks until all m_total
// threads have arrived - confirm in the definition.
118  void waitForSynchronization();
119 
// Table of per-thread buffer pointers; AllReduceThreads::all_reduce casts it
// to T** and indexes it by node id.
120  void** buffers;
121 };
122 
// Thread-based all-reduce: every participating thread calls all_reduce() on
// its own buffer and the buffers are combined through a shared AllReduceSync.
// NOTE(review): the class header (listing line 123) and the private members
// (listing lines 126-127) are absent from this view; the index at the end of
// this page lists `AllReduceSync * m_sync` at line 126.
124 {
125  private:
128 
129  public:
// Variant taking an existing instance as `root`. NOTE(review): presumably
// shares root's synchronization object - confirm in the definition.
130  AllReduceThreads(AllReduceThreads* root, const size_t ptotal, const size_t pnode, bool quiet = false);
131 
// Variant without a root instance.
132  AllReduceThreads(const size_t ptotal, const size_t pnode, bool quiet = false);
133 
134  virtual ~AllReduceThreads();
135 
136  template <class T, void (*f)(T&, const T&)>
137  void all_reduce(T* buffer, const size_t n)
138  { // register buffer
139  T** buffers = (T**)m_sync->buffers;
140  buffers[node] = buffer;
141  m_sync->waitForSynchronization();
142 
143  size_t blockSize = n / total;
144  size_t index;
145  size_t end;
146 
147  if (blockSize == 0)
148  {
149  if (node < n)
150  {
151  index = node;
152  end = node + 1;
153  }
154  else
155  { // more threads than bytes --> don't do any work
156  index = end = 0;
157  }
158  }
159  else
160  {
161  index = node * blockSize;
162  end = node == total - 1 ? n : (node + 1) * blockSize;
163  }
164 
165  for (; index < end; index++)
166  { // Perform transposed AllReduce to help data locallity
167  T& first = buffers[0][index];
168 
169  for (size_t i = 1; i < total; i++) f(first, buffers[i][index]);
170 
171  // Broadcast back
172  for (size_t i = 1; i < total; i++) buffers[i][index] = first;
173  }
174 
175  m_sync->waitForSynchronization();
176  }
177 };
178 
// Socket-based all-reduce: data is reduced up a tree of sockets (reduce) and
// the result is then sent back down (broadcast).
// NOTE(review): the class header (listing line 179) and one private member
// (listing line 182) are absent from this view; the index at the end of this
// page lists `node_socks socks` at line 182. The constructor's initializer
// list shows that AllReduce is a base class.
180 {
181  private:
// Compared with socks.current_master in all_reduce(); a mismatch triggers all_reduce_init().
183  std::string span_server;
// Stored from the constructor argument; not read anywhere in this header.
184  int port;
185  size_t unique_id; // unique id for each node in the network, id == 0 means extra io.
186 
// Declared only; defined elsewhere. NOTE(review): presumably (re)establishes
// the parent/child sockets and updates socks.current_master - confirm.
187  void all_reduce_init();
188 
189  template <class T>
190  void pass_up(char* buffer, size_t left_read_pos, size_t right_read_pos, size_t& parent_sent_pos)
191  {
192  size_t my_bufsize =
193  std::min(ar_buf_size, std::min(left_read_pos, right_read_pos) / sizeof(T) * sizeof(T) - parent_sent_pos);
194 
195  if (my_bufsize > 0)
196  { // going to pass up this chunk of data to the parent
197  int write_size = send(socks.parent, buffer + parent_sent_pos, (int)my_bufsize, 0);
198  if (write_size < 0)
199  THROW("Write to parent failed " << my_bufsize << " " << write_size << " " << parent_sent_pos << " "
200  << left_read_pos << " " << right_read_pos);
201 
202  parent_sent_pos += write_size;
203  }
204  }
205 
206  template <class T, void (*f)(T&, const T&)>
207  void reduce(char* buffer, const size_t n)
208  {
209  fd_set fds;
210  FD_ZERO(&fds);
211  if (socks.children[0] != -1)
212  FD_SET(socks.children[0], &fds);
213  if (socks.children[1] != -1)
214  FD_SET(socks.children[1], &fds);
215 
216  socket_t max_fd = std::max(socks.children[0], socks.children[1]) + 1;
217  size_t child_read_pos[2] = {0, 0}; // First unread float from left and right children
218  int child_unprocessed[2] = {0, 0}; // The number of bytes sent by the child but not yet added to the buffer
219  char child_read_buf[2][ar_buf_size + sizeof(T) - 1];
220  size_t parent_sent_pos = 0; // First unsent float to parent
221  // parent_sent_pos <= left_read_pos
222  // parent_sent_pos <= right_read_pos
223 
224  if (socks.children[0] == -1)
225  {
226  child_read_pos[0] = n;
227  }
228  if (socks.children[1] == -1)
229  {
230  child_read_pos[1] = n;
231  }
232 
233  while (parent_sent_pos < n || child_read_pos[0] < n || child_read_pos[1] < n)
234  {
235  if (socks.parent != -1)
236  pass_up<T>(buffer, child_read_pos[0], child_read_pos[1], parent_sent_pos);
237 
238  if (parent_sent_pos >= n && child_read_pos[0] >= n && child_read_pos[1] >= n)
239  break;
240 
241  if (child_read_pos[0] < n || child_read_pos[1] < n)
242  {
243  if (max_fd > 0 && select((int)max_fd, &fds, nullptr, nullptr, nullptr) == -1)
244  THROWERRNO("select");
245 
246  for (int i = 0; i < 2; i++)
247  {
248  if (socks.children[i] != -1 && FD_ISSET(socks.children[i], &fds))
249  { // there is data to be left from left child
250  if (child_read_pos[i] == n)
251  THROW("I think child has no data to send but he thinks he has "
252  << FD_ISSET(socks.children[0], &fds) << " " << FD_ISSET(socks.children[1], &fds));
253 
254  size_t count = std::min(ar_buf_size, n - child_read_pos[i]);
255  int read_size = recv(socks.children[i], &child_read_buf[i][child_unprocessed[i]], (int)count, 0);
256  if (read_size == -1)
257  THROWERRNO("recv from child");
258 
259  addbufs<T, f>((T*)buffer + child_read_pos[i] / sizeof(T), (T*)child_read_buf[i],
260  (child_read_pos[i] + read_size) / sizeof(T) - child_read_pos[i] / sizeof(T));
261 
262  child_read_pos[i] += read_size;
263  int old_unprocessed = child_unprocessed[i];
264  child_unprocessed[i] = child_read_pos[i] % (int)sizeof(T);
265  for (int j = 0; j < child_unprocessed[i]; j++)
266  {
267  child_read_buf[i][j] =
268  child_read_buf[i][((old_unprocessed + read_size) / (int)sizeof(T)) * sizeof(T) + j];
269  }
270 
271  if (child_read_pos[i] == n) // Done reading parent
272  FD_CLR(socks.children[i], &fds);
273  }
274  else if (socks.children[i] != -1 && child_read_pos[i] != n)
275  FD_SET(socks.children[i], &fds);
276  }
277  }
278  if (socks.parent == -1 && child_read_pos[0] == n && child_read_pos[1] == n)
279  parent_sent_pos = n;
280  }
281  }
282 
// Counterparts of pass_up()/reduce() for the downward direction; only declared
// here, defined elsewhere. broadcast() is called by all_reduce() right after
// reduce() with the same buffer and byte count.
283  void pass_down(char* buffer, const size_t parent_read_pos, size_t& children_sent_pos);
284  void broadcast(char* buffer, const size_t n);
285 
// Socket helpers; only declared here, defined elsewhere.
286  socket_t sock_connect(const uint32_t ip, const int port);
287  socket_t getsock();
288 
289  public:
// Stores the connection parameters and forwards ptotal/pnode/pquiet to the
// AllReduce base. No connection is opened here; all_reduce() calls
// all_reduce_init() when span_server differs from socks.current_master.
290  AllReduceSockets(std::string pspan_server, const int pport, const size_t punique_id, size_t ptotal,
291  const size_t pnode, bool pquiet)
292  : AllReduce(ptotal, pnode, pquiet), span_server(pspan_server), port(pport), unique_id(punique_id)
293  {
294  }
295 
296  virtual ~AllReduceSockets() = default;
297 
298  template <class T, void (*f)(T&, const T&)>
299  void all_reduce(T* buffer, const size_t n)
300  {
301  if (span_server != socks.current_master)
302  all_reduce_init();
303  reduce<T, f>((char*)buffer, n * sizeof(T));
304  broadcast((char*)buffer, n * sizeof(T));
305  }
306 };
constexpr size_t ar_buf_size
Definition: allreduce.h:49
~node_socks()
Definition: allreduce.h:56
AllReduce(size_t ptotal, const size_t pnode, bool pquiet=false)
Definition: allreduce.h:84
void pass_up(char *buffer, size_t left_read_pos, size_t right_read_pos, size_t &parent_sent_pos)
Definition: allreduce.h:190
socket_t children[2]
Definition: allreduce.h:55
Definition: allreduce.h:92
size_t length
Definition: allreduce.h:95
node_socks socks
Definition: allreduce.h:182
bool quiet
Definition: allreduce.h:82
void addbufs(T *buf1, const T *buf2, const size_t n)
Definition: allreduce.h:72
const size_t total
Definition: allreduce.h:80
const size_t node
Definition: allreduce.h:81
socket_t parent
Definition: allreduce.h:54
#define CLOSESOCK
Definition: allreduce.h:43
AllReduceSockets(std::string pspan_server, const int pport, const size_t punique_id, size_t ptotal, const size_t pnode, bool pquiet)
Definition: allreduce.h:290
std::string current_master
Definition: allreduce.h:53
void reduce(char *buffer, const size_t n)
Definition: allreduce.h:207
void waitForSynchronization()
uint32_t m_count
Definition: allreduce.h:108
#define THROWERRNO(args)
Definition: vw_exception.h:167
void * buffer
Definition: allreduce.h:94
std::mutex * m_mutex
Definition: allreduce.h:101
std::string span_server
Definition: allreduce.h:183
AllReduceSync * m_sync
Definition: allreduce.h:126
int socket_t
Definition: allreduce.h:42
void ** buffers
Definition: allreduce.h:120
size_t unique_id
Definition: allreduce.h:185
void all_reduce(T *buffer, const size_t n)
Definition: allreduce.h:299
size_t m_total
Definition: allreduce.h:105
std::condition_variable * m_cv
Definition: allreduce.h:102
bool children(log_multi &b, uint32_t &current, uint32_t &class_index, uint32_t label)
Definition: log_multi.cc:178
#define THROW(args)
Definition: vw_exception.h:181
void all_reduce(T *buffer, const size_t n)
Definition: allreduce.h:137
float f
Definition: cache.cc:40
node_socks()
Definition: allreduce.h:68