Vowpal Wabbit
allreduce_threads.cc
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5 */
6 /*
7 This implements the allreduce function using threads.
8 */
9 #include "allreduce.h"
10 #include <future>
11 
12 AllReduceSync::AllReduceSync(const size_t total) : m_total(total), m_count(0), m_run(true)
13 {
14  m_mutex = new std::mutex;
15  m_cv = new std::condition_variable;
16  buffers = new void*[total];
17 }
18 
20 {
21  delete m_mutex;
22  delete m_cv;
23  delete[] buffers;
24 }
25 
27 {
28  std::unique_lock<std::mutex> l(*m_mutex);
29  m_count++;
30 
31  if (m_count >= m_total)
32  {
33  assert(m_count == m_total);
34 
35  m_cv->notify_all();
36 
37  // order of m_count before or after notify_all doesn't matter
38  // since the lock is still hold at this point in time.
39  m_count = 0;
40 
41  // flip for the next run
42  m_run = !m_run;
43  }
44  else
45  {
46  bool current_run = m_run;
47  // this predicate cannot depend on m_count, as somebody can race ahead and m_count++
48  // FYI just wait can spuriously wake-up
49  m_cv->wait(l, [this, current_run] { return m_run != current_run; });
50  }
51 }
52 
53 AllReduceThreads::AllReduceThreads(AllReduceThreads* root, const size_t ptotal, const size_t pnode, bool pquiet)
54  : AllReduce(ptotal, pnode, pquiet), m_sync(root->m_sync), m_syncOwner(false)
55 {
56 }
57 
58 AllReduceThreads::AllReduceThreads(const size_t ptotal, const size_t pnode, bool pquiet)
59  : AllReduce(ptotal, pnode, pquiet), m_sync(new AllReduceSync(ptotal)), m_syncOwner(true)
60 {
61 }
62 
64 {
65  if (m_syncOwner)
66  {
67  delete m_sync;
68  }
69 }
virtual ~AllReduceThreads()
AllReduceThreads(AllReduceThreads *root, const size_t ptotal, const size_t pnode, bool quiet=false)
AllReduceSync(const size_t total)
void waitForSynchronization()
uint32_t m_count
Definition: allreduce.h:108
std::mutex * m_mutex
Definition: allreduce.h:101
AllReduceSync * m_sync
Definition: allreduce.h:126
void ** buffers
Definition: allreduce.h:120
size_t m_total
Definition: allreduce.h:105
std::condition_variable * m_cv
Definition: allreduce.h:102