61 changes: 44 additions & 17 deletions gloo/allreduce_shm.cc
@@ -1,4 +1,5 @@
#include "gloo/allreduce_shm.h"
#include "gloo/types.h"

#include <assert.h>
#include <errno.h>
@@ -444,11 +445,6 @@ AllreduceSharedMemoryData::~AllreduceSharedMemoryData() {

void shm(const detail::AllreduceOptionsImpl& opts) {
const auto& context = opts.context;
if (context->shmData == nullptr) {
context->shmData = std::make_shared<AllreduceSharedMemoryData>(
context->rank, context->size);
context->shmData->initialize();
}
const size_t data_size = opts.elements * opts.elementSize;
auto& in = opts.in;
auto& out = opts.out;
@@ -485,20 +481,51 @@ void shm(const detail::AllreduceOptionsImpl& opts) {
}

void* data = out[0].get()->ptr;
auto tag = opts.tag;
std::unique_ptr<transport::UnboundBuffer> tagBuffer =
context->createUnboundBuffer(&tag, sizeof(tag));
transport::UnboundBuffer* tag_ptr = tagBuffer.get();
const auto slot = Slot::build(kAllreduceSlotPrefix, opts.tag);

{
// Use mutex to make context->shmData thread safe.
std::unique_lock<std::mutex> lock(context->shmDataMutex);

if (context->shmData == nullptr) {
context->shmData = std::make_shared<AllreduceSharedMemoryData>(
context->rank, context->size);
context->shmData->initialize();
}

for (int offset = 0; offset < data_size;
offset += Allreduceworkspace::MAX_BUF_SIZE) {
auto data_ptr = ((char*)(data) + offset);
size_t chunk_size = data_size - offset > Allreduceworkspace::MAX_BUF_SIZE
? Allreduceworkspace::MAX_BUF_SIZE
: data_size - offset;
size_t chunk_el = chunk_size / (data_size / opts.elements);
if (chunk_size < Allreduceworkspace::NAIVE_ALLREDUCE_THRESHOLD) {
symmetric_naive_all_reduce(
data_ptr, opts.elementSize, chunk_size, chunk_el, opts);
// In async mode there may be many allreduce ops executing simultaneously.
// However, shmData is expected to be occupied exclusively. We use a unique tag
// to synchronize the different ranks.
if (context->rank == 0) {
for (int i = 1; i < context->size; i++) {
tag_ptr->send(i, slot);
tag_ptr->waitSend();
}
} else {
distributed_naive_reduce(
data_ptr, opts.elementSize, chunk_size, chunk_el, opts);
lock.unlock();
tag_ptr->recv(0, slot);
tag_ptr->waitRecv();
lock.lock();
}

for (int offset = 0; offset < data_size;
offset += Allreduceworkspace::MAX_BUF_SIZE) {
auto data_ptr = ((char*)(data) + offset);
size_t chunk_size = data_size - offset > Allreduceworkspace::MAX_BUF_SIZE
? Allreduceworkspace::MAX_BUF_SIZE
: data_size - offset;
size_t chunk_el = chunk_size / (data_size / opts.elements);
if (chunk_size < Allreduceworkspace::NAIVE_ALLREDUCE_THRESHOLD) {
symmetric_naive_all_reduce(
data_ptr, opts.elementSize, chunk_size, chunk_el, opts);
} else {
distributed_naive_reduce(
data_ptr, opts.elementSize, chunk_size, chunk_el, opts);
}
}
}

2 changes: 2 additions & 0 deletions gloo/context.h
@@ -37,6 +37,8 @@ class Context {

std::shared_ptr<AllreduceSharedMemoryData> shmData;

std::mutex shmDataMutex;
Member

What happens if there are multiple gloo process groups? Does that cause any issues?

Member

Also, can we put this under shmData?

Contributor Author

  1. For the multiple-process-group scenario, I use the gloo context's address to generate a unique ID for the shm name, so different groups use different shm buffers for the allreduce op. I've verified this with a PyTorch test that uses multiple process groups, and it passed.
  2. I don't think we can put this under shmData. On the first run shmData is not yet initialized (nullptr); if multiple threads reach this point, we need to ensure that only one thread performs the initialization.
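
The reasoning in point 2 can be sketched as follows. This is a minimal illustration under stated assumptions, not the PR's code: `FakeContext`, `FakeShmData`, `ensureShmData`, and `countInitsWithThreads` are hypothetical names, and the real `initialize()` allocates a shm buffer rather than doing nothing. The point is only that the mutex has to live next to the pointer, not inside the object the pointer will eventually refer to, because the object does not exist yet when the first threads arrive.

```cpp
#include <cassert>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

// Hypothetical stand-in for AllreduceSharedMemoryData.
struct FakeShmData {
  void initialize() {}  // the real setup is expensive; here it is a no-op
};

// Hypothetical stand-in for the gloo Context members discussed above.
struct FakeContext {
  std::shared_ptr<FakeShmData> shmData;  // nullptr until first use
  std::mutex shmDataMutex;               // exists before shmData does
  int initCalls = 0;                     // only touched while the mutex is held
};

// Lazily create shmData; only the first thread to take the lock initializes.
void ensureShmData(FakeContext& ctx) {
  std::lock_guard<std::mutex> lock(ctx.shmDataMutex);
  if (ctx.shmData == nullptr) {
    ctx.shmData = std::make_shared<FakeShmData>();
    ctx.shmData->initialize();
    ++ctx.initCalls;
  }
}

// Race several threads on one context and report how many initializations ran.
int countInitsWithThreads(int numThreads) {
  FakeContext ctx;
  std::vector<std::thread> threads;
  for (int i = 0; i < numThreads; ++i) {
    threads.emplace_back([&ctx] { ensureShmData(ctx); });
  }
  for (auto& t : threads) {
    t.join();
  }
  return ctx.initCalls;
}
```

Because each context carries its own mutex and its own pointer, two contexts never contend with each other in this sketch; whether that matches the real multi-group behavior depends on the shm naming described in point 1.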

> I don't think we can put this under shmData. On the first run shmData is not yet initialized (nullptr); if multiple threads reach this point, we need to ensure that only one thread performs the initialization.

Could we move this to allreduce_shm.cc and make both static? That way we keep the global context clean of shm specifics. What do you think?

Contributor Author

If we make them static, it works only in the multi-process scenario. In a multi-threaded scenario such as the gloo unit tests, each thread represents a rank, so shm_data would be initialized only once and shared by all ranks, which is not what we want.
Although the thread_local keyword makes a static variable unique per thread, it may cause a performance issue. In a real workload such as PyTorch, the calling sequence is more like:

context init -> call allreduce -> call allreduce -> call allreduce

thread_local would cause the shm data to be initialized on every allreduce call. The initialization is very expensive because it allocates the shm buffer.

> thread_local will make shm data initialized every time calling allreduce

Is that correct? I think it will be initialized only the first time, but I see your point. Okay, I think we can at least wrap this in the macro you created.
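
A small sketch to check the thread_local question, assuming nothing about gloo itself: `ExpensiveState`, `fakeAllreduce`, and `countInits` are hypothetical names. A function-local `thread_local` object is constructed the first time each thread reaches its declaration and is then reused by that thread, so the cost repeats only if every allreduce call runs on a fresh thread.

```cpp
#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

// Counts constructions across all threads.
std::atomic<int> g_initCount{0};

// Hypothetical stand-in for the expensive shm setup.
struct ExpensiveState {
  ExpensiveState() { g_initCount.fetch_add(1); }
  int calls = 0;
};

// Hypothetical stand-in for an allreduce entry point.
void fakeAllreduce() {
  // Constructed on the first call made by each thread, then reused.
  thread_local ExpensiveState state;
  ++state.calls;
}

// Run callsPerThread calls on each of numThreads threads and report
// how many times ExpensiveState was constructed in total.
int countInits(int numThreads, int callsPerThread) {
  g_initCount = 0;
  std::vector<std::thread> threads;
  for (int i = 0; i < numThreads; ++i) {
    threads.emplace_back([callsPerThread] {
      for (int c = 0; c < callsPerThread; ++c) {
        fakeAllreduce();
      }
    });
  }
  for (auto& t : threads) {
    t.join();
  }
  return g_initCount.load();
}
```

With one thread making three calls the state is built once; with four threads it is built four times, once per thread, regardless of how many calls each thread makes.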

Contributor Author

OK, I've wrapped shm_data's declaration in the macro I created.


std::shared_ptr<transport::Device>& getDevice();

std::unique_ptr<transport::Pair>& getPair(int i);