17 typedef unsigned int uint32_t;
18 typedef unsigned short uint16_t;
19 typedef int socklen_t;
21 #define CLOSESOCK closesocket 28 class condition_variable;
33 #include <sys/socket.h> 34 #include <sys/socket.h> 35 #include <netinet/in.h> 36 #include <netinet/tcp.h> 43 #define CLOSESOCK close 58 if (current_master !=
"")
62 if (children[0] != -1)
64 if (children[1] != -1)
71 template <
class T,
void (*f)(T&, const T&)>
72 void addbufs(T* buf1,
const T* buf2,
const size_t n)
74 for (
size_t i = 0; i < n; i++)
f(buf1[i], buf2[i]);
84 AllReduce(
size_t ptotal,
const size_t pnode,
bool pquiet =
false) : total(ptotal), node(pnode), quiet(pquiet)
118 void waitForSynchronization();
132 AllReduceThreads(
const size_t ptotal,
const size_t pnode,
bool quiet =
false);
136 template <
class T,
void (*f)(T&, const T&)>
139 T** buffers = (T**)m_sync->
buffers;
140 buffers[
node] = buffer;
143 size_t blockSize = n / total;
161 index =
node * blockSize;
162 end =
node == total - 1 ? n : (
node + 1) * blockSize;
165 for (; index < end; index++)
167 T& first = buffers[0][index];
169 for (
size_t i = 1; i < total; i++)
f(first, buffers[i][index]);
172 for (
size_t i = 1; i < total; i++) buffers[i][index] = first;
187 void all_reduce_init();
190 void pass_up(
char* buffer,
size_t left_read_pos,
size_t right_read_pos,
size_t& parent_sent_pos)
193 std::min(
ar_buf_size, std::min(left_read_pos, right_read_pos) /
sizeof(T) *
sizeof(T) - parent_sent_pos);
197 int write_size = send(socks.
parent, buffer + parent_sent_pos, (
int)my_bufsize, 0);
199 THROW(
"Write to parent failed " << my_bufsize <<
" " << write_size <<
" " << parent_sent_pos <<
" " 200 << left_read_pos <<
" " << right_read_pos);
202 parent_sent_pos += write_size;
206 template <
class T,
void (*f)(T&, const T&)>
207 void reduce(
char* buffer,
const size_t n)
217 size_t child_read_pos[2] = {0, 0};
218 int child_unprocessed[2] = {0, 0};
219 char child_read_buf[2][
ar_buf_size +
sizeof(T) - 1];
220 size_t parent_sent_pos = 0;
226 child_read_pos[0] = n;
230 child_read_pos[1] = n;
233 while (parent_sent_pos < n || child_read_pos[0] < n || child_read_pos[1] < n)
236 pass_up<T>(buffer, child_read_pos[0], child_read_pos[1], parent_sent_pos);
238 if (parent_sent_pos >= n && child_read_pos[0] >= n && child_read_pos[1] >= n)
241 if (child_read_pos[0] < n || child_read_pos[1] < n)
243 if (max_fd > 0 && select((
int)max_fd, &fds,
nullptr,
nullptr,
nullptr) == -1)
246 for (
int i = 0; i < 2; i++)
250 if (child_read_pos[i] == n)
251 THROW(
"I think child has no data to send but he thinks he has " 252 << FD_ISSET(socks.
children[0], &fds) <<
" " << FD_ISSET(socks.
children[1], &fds));
254 size_t count = std::min(
ar_buf_size, n - child_read_pos[i]);
255 int read_size = recv(socks.
children[i], &child_read_buf[i][child_unprocessed[i]], (
int)count, 0);
259 addbufs<T, f>((T*)buffer + child_read_pos[i] /
sizeof(T), (T*)child_read_buf[i],
260 (child_read_pos[i] + read_size) /
sizeof(T) - child_read_pos[i] /
sizeof(T));
262 child_read_pos[i] += read_size;
263 int old_unprocessed = child_unprocessed[i];
264 child_unprocessed[i] = child_read_pos[i] % (int)
sizeof(T);
265 for (
int j = 0; j < child_unprocessed[i]; j++)
267 child_read_buf[i][j] =
268 child_read_buf[i][((old_unprocessed + read_size) / (
int)
sizeof(T)) *
sizeof(T) + j];
271 if (child_read_pos[i] == n)
274 else if (socks.
children[i] != -1 && child_read_pos[i] != n)
278 if (socks.
parent == -1 && child_read_pos[0] == n && child_read_pos[1] == n)
283 void pass_down(
char* buffer,
const size_t parent_read_pos,
size_t& children_sent_pos);
284 void broadcast(
char* buffer,
const size_t n);
286 socket_t sock_connect(
const uint32_t ip,
const int port);
290 AllReduceSockets(std::string pspan_server,
const int pport,
const size_t punique_id,
size_t ptotal,
291 const size_t pnode,
bool pquiet)
292 :
AllReduce(ptotal, pnode, pquiet), span_server(pspan_server), port(pport), unique_id(punique_id)
298 template <
class T,
void (*f)(T&, const T&)>
303 reduce<T, f>((
char*)buffer, n *
sizeof(T));
304 broadcast((
char*)buffer, n *
sizeof(T));
constexpr size_t ar_buf_size
AllReduce(size_t ptotal, const size_t pnode, bool pquiet=false)
void pass_up(char *buffer, size_t left_read_pos, size_t right_read_pos, size_t &parent_sent_pos)
void addbufs(T *buf1, const T *buf2, const size_t n)
AllReduceSockets(std::string pspan_server, const int pport, const size_t punique_id, size_t ptotal, const size_t pnode, bool pquiet)
std::string current_master
void reduce(char *buffer, const size_t n)
void waitForSynchronization()
void all_reduce(T *buffer, const size_t n)
std::condition_variable * m_cv
bool children(log_multi &b, uint32_t ¤t, uint32_t &class_index, uint32_t label)
void all_reduce(T *buffer, const size_t n)