Skip to content

Commit 0e5a84b

Browse files
authored
Fix alignment for in-memory allreduce buffer (#12082)
1 parent de523b2 commit 0e5a84b

2 files changed

Lines changed: 84 additions & 30 deletions

File tree

src/collective/in_memory_handler.cc

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
#include <algorithm>
77
#include <functional>
8+
#include <stdexcept>
9+
810
#include "comm.h"
911

1012
namespace xgboost::collective {
@@ -18,14 +20,14 @@ class AllgatherFunctor {
1820
AllgatherFunctor(std::int32_t world_size, std::int32_t rank)
1921
: world_size_{world_size}, rank_{rank} {}
2022

21-
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
22-
if (buffer->empty()) {
23+
void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
24+
if (buffer->Empty()) {
2325
// Resize the buffer if this is the first request.
24-
buffer->resize(bytes * world_size_);
26+
buffer->Resize(bytes * world_size_);
2527
}
2628

2729
// Splice the input into the common buffer.
28-
buffer->replace(rank_ * bytes, bytes, input, bytes);
30+
buffer->Replace(rank_ * bytes, bytes, input);
2931
}
3032

3133
private:
@@ -44,11 +46,11 @@ class AllgatherVFunctor {
4446
std::map<std::size_t, std::string_view>* data)
4547
: world_size_{world_size}, rank_{rank}, data_{data} {}
4648

47-
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
49+
void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
4850
data_->emplace(rank_, std::string_view{input, bytes});
4951
if (data_->size() == static_cast<std::size_t>(world_size_)) {
5052
for (auto const& kv : *data_) {
51-
buffer->append(kv.second);
53+
buffer->Append(kv.second);
5254
}
5355
data_->clear();
5456
}
@@ -70,14 +72,16 @@ class AllreduceFunctor {
7072
AllreduceFunctor(ArrayInterfaceHandler::Type dataType, Op operation)
7173
: data_type_{dataType}, operation_{operation} {}
7274

73-
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
74-
if (buffer->empty()) {
75+
void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
76+
if (buffer->Empty()) {
7577
// Copy the input if this is the first request.
76-
buffer->assign(input, bytes);
78+
buffer->Assign(input, bytes);
7779
} else {
7880
auto n_bytes_type = DispatchDType(data_type_, [](auto t) { return sizeof(t); });
81+
CHECK_EQ(bytes % n_bytes_type, 0) << "Input size is not a multiple of its element size.";
82+
CHECK_EQ(buffer->Size(), bytes) << "Input size differs across allreduce calls.";
7983
// Apply the reduce_operation to the input and the buffer.
80-
Accumulate(input, bytes / n_bytes_type, &buffer->front());
84+
Accumulate(input, bytes, buffer);
8185
}
8286
}
8387

@@ -128,39 +132,41 @@ class AllreduceFunctor {
128132
}
129133
}
130134

131-
void Accumulate(char const* input, std::size_t size, char* buffer) const {
135+
void Accumulate(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
132136
using Type = ArrayInterfaceHandler::Type;
137+
auto data = buffer->Data();
138+
auto size = bytes / DispatchDType(data_type_, [](auto t) { return sizeof(t); });
133139
switch (data_type_) {
134140
case Type::kI1:
135-
Accumulate(reinterpret_cast<std::int8_t*>(buffer),
141+
Accumulate(reinterpret_cast<std::int8_t*>(data),
136142
reinterpret_cast<std::int8_t const*>(input), size, operation_);
137143
break;
138144
case Type::kU1:
139-
Accumulate(reinterpret_cast<std::uint8_t*>(buffer),
145+
Accumulate(reinterpret_cast<std::uint8_t*>(data),
140146
reinterpret_cast<std::uint8_t const*>(input), size, operation_);
141147
break;
142148
case Type::kI4:
143-
Accumulate(reinterpret_cast<std::int32_t*>(buffer),
149+
Accumulate(reinterpret_cast<std::int32_t*>(data),
144150
reinterpret_cast<std::int32_t const*>(input), size, operation_);
145151
break;
146152
case Type::kU4:
147-
Accumulate(reinterpret_cast<std::uint32_t*>(buffer),
153+
Accumulate(reinterpret_cast<std::uint32_t*>(data),
148154
reinterpret_cast<std::uint32_t const*>(input), size, operation_);
149155
break;
150156
case Type::kI8:
151-
Accumulate(reinterpret_cast<std::int64_t*>(buffer),
157+
Accumulate(reinterpret_cast<std::int64_t*>(data),
152158
reinterpret_cast<std::int64_t const*>(input), size, operation_);
153159
break;
154160
case Type::kU8:
155-
Accumulate(reinterpret_cast<std::uint64_t*>(buffer),
161+
Accumulate(reinterpret_cast<std::uint64_t*>(data),
156162
reinterpret_cast<std::uint64_t const*>(input), size, operation_);
157163
break;
158164
case Type::kF4:
159-
Accumulate(reinterpret_cast<float*>(buffer), reinterpret_cast<float const*>(input), size,
165+
Accumulate(reinterpret_cast<float*>(data), reinterpret_cast<float const*>(input), size,
160166
operation_);
161167
break;
162168
case Type::kF8:
163-
Accumulate(reinterpret_cast<double*>(buffer), reinterpret_cast<double const*>(input), size,
169+
Accumulate(reinterpret_cast<double*>(data), reinterpret_cast<double const*>(input), size,
164170
operation_);
165171
break;
166172
default:
@@ -182,10 +188,10 @@ class BroadcastFunctor {
182188

183189
BroadcastFunctor(std::int32_t rank, std::int32_t root) : rank_{rank}, root_{root} {}
184190

185-
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
191+
void operator()(char const* input, std::size_t bytes, AlignedByteBuffer* buffer) const {
186192
if (rank_ == root_) {
187193
// Copy the input if this is the root.
188-
buffer->assign(input, bytes);
194+
buffer->Assign(input, bytes);
189195
}
190196
}
191197

@@ -246,9 +252,7 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*
246252
HandlerFunctor const& functor) {
247253
// Pass through if there is only 1 client.
248254
if (world_size_ == 1) {
249-
if (input != output->data()) {
250-
output->assign(input, bytes);
251-
}
255+
output->assign(input, bytes);
252256
return;
253257
}
254258

@@ -263,7 +267,7 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*
263267

264268
if (received_ == world_size_) {
265269
LOG(DEBUG) << functor.name << " rank " << rank << ": all requests received";
266-
output->assign(buffer_);
270+
output->assign(buffer_.Data(), buffer_.Size());
267271
sent_++;
268272
lock.unlock();
269273
cv_.notify_all();
@@ -274,14 +278,14 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*
274278
cv_.wait(lock, [this] { return received_ == world_size_; });
275279

276280
LOG(DEBUG) << functor.name << " rank " << rank << ": sending reply";
277-
output->assign(buffer_);
281+
output->assign(buffer_.Data(), buffer_.Size());
278282
sent_++;
279283

280284
if (sent_ == world_size_) {
281285
LOG(DEBUG) << functor.name << " rank " << rank << ": all replies sent";
282286
sent_ = 0;
283287
received_ = 0;
284-
buffer_.clear();
288+
buffer_.Clear();
285289
sequence_number_++;
286290
lock.unlock();
287291
cv_.notify_all();

src/collective/in_memory_handler.h

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,63 @@
33
*/
44
#pragma once
55
#include <condition_variable>
6+
#include <cstddef>
7+
#include <cstring>
68
#include <map>
79
#include <string>
10+
#include <vector>
811

912
#include "../data/array_interface.h"
1013
#include "comm.h"
1114

1215
namespace xgboost::collective {
16+
class AlignedByteBuffer {
17+
using StorageT = std::max_align_t;
18+
19+
public:
20+
[[nodiscard]] bool Empty() const { return size_ == 0; }
21+
[[nodiscard]] std::size_t Size() const { return size_; }
22+
23+
[[nodiscard]] char* Data() { return reinterpret_cast<char*>(storage_.data()); }
24+
[[nodiscard]] char const* Data() const { return reinterpret_cast<char const*>(storage_.data()); }
25+
26+
void Clear() {
27+
storage_.clear();
28+
size_ = 0;
29+
}
30+
31+
void Resize(std::size_t n_bytes) {
32+
storage_.resize((n_bytes + sizeof(StorageT) - 1) / sizeof(StorageT));
33+
size_ = n_bytes;
34+
}
35+
36+
void Assign(char const* input, std::size_t n_bytes) {
37+
this->Resize(n_bytes);
38+
if (n_bytes != 0) {
39+
std::memcpy(this->Data(), input, n_bytes);
40+
}
41+
}
42+
43+
void Replace(std::size_t pos, std::size_t n_bytes, char const* input) {
44+
CHECK_LE(pos + n_bytes, size_);
45+
if (n_bytes != 0) {
46+
std::memcpy(this->Data() + pos, input, n_bytes);
47+
}
48+
}
49+
50+
void Append(std::string_view data) {
51+
auto old_size = size_;
52+
this->Resize(size_ + data.size());
53+
if (!data.empty()) {
54+
std::memcpy(this->Data() + old_size, data.data(), data.size());
55+
}
56+
}
57+
58+
private:
59+
std::vector<StorageT> storage_{};
60+
std::size_t size_{0};
61+
};
62+
1363
/**
1464
* @brief Handles collective communication primitives in memory.
1565
*
@@ -116,10 +166,10 @@ class InMemoryHandler {
116166
void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number,
117167
std::int32_t rank, HandlerFunctor const& functor);
118168

119-
std::int32_t world_size_{}; /// Number of workers.
169+
std::int32_t world_size_{}; /// Number of workers.
120170
std::int64_t received_{}; /// Number of calls received with the current sequence.
121-
std::int64_t sent_{}; /// Number of calls completed with the current sequence.
122-
std::string buffer_{}; /// A shared common buffer.
171+
std::int64_t sent_{}; /// Number of calls completed with the current sequence.
172+
AlignedByteBuffer buffer_{}; /// A shared common buffer.
123173
std::map<std::size_t, std::string_view> aux_{}; /// A shared auxiliary map.
124174
uint64_t sequence_number_{}; /// Call sequence number.
125175
mutable std::mutex mutex_; /// Lock.

0 commit comments

Comments
 (0)