55
66#include < algorithm>
77#include < functional>
8+ #include < stdexcept>
9+
810#include " comm.h"
911
1012namespace xgboost ::collective {
@@ -18,14 +20,14 @@ class AllgatherFunctor {
1820 AllgatherFunctor (std::int32_t world_size, std::int32_t rank)
1921 : world_size_{world_size}, rank_{rank} {}
2022
21- void operator ()(char const * input, std::size_t bytes, std::string * buffer) const {
22- if (buffer->empty ()) {
23+ void operator ()(char const * input, std::size_t bytes, AlignedByteBuffer * buffer) const {
24+ if (buffer->Empty ()) {
2325 // Resize the buffer if this is the first request.
24- buffer->resize (bytes * world_size_);
26+ buffer->Resize (bytes * world_size_);
2527 }
2628
2729 // Splice the input into the common buffer.
28- buffer->replace (rank_ * bytes, bytes, input, bytes );
30+ buffer->Replace (rank_ * bytes, bytes, input);
2931 }
3032
3133 private:
@@ -44,11 +46,11 @@ class AllgatherVFunctor {
4446 std::map<std::size_t , std::string_view>* data)
4547 : world_size_{world_size}, rank_{rank}, data_{data} {}
4648
47- void operator ()(char const * input, std::size_t bytes, std::string * buffer) const {
49+ void operator ()(char const * input, std::size_t bytes, AlignedByteBuffer * buffer) const {
4850 data_->emplace (rank_, std::string_view{input, bytes});
4951 if (data_->size () == static_cast <std::size_t >(world_size_)) {
5052 for (auto const & kv : *data_) {
51- buffer->append (kv.second );
53+ buffer->Append (kv.second );
5254 }
5355 data_->clear ();
5456 }
@@ -70,14 +72,16 @@ class AllreduceFunctor {
7072 AllreduceFunctor (ArrayInterfaceHandler::Type dataType, Op operation)
7173 : data_type_{dataType}, operation_{operation} {}
7274
73- void operator ()(char const * input, std::size_t bytes, std::string * buffer) const {
74- if (buffer->empty ()) {
75+ void operator ()(char const * input, std::size_t bytes, AlignedByteBuffer * buffer) const {
76+ if (buffer->Empty ()) {
7577 // Copy the input if this is the first request.
76- buffer->assign (input, bytes);
78+ buffer->Assign (input, bytes);
7779 } else {
7880 auto n_bytes_type = DispatchDType (data_type_, [](auto t) { return sizeof (t); });
81+ CHECK_EQ (bytes % n_bytes_type, 0 ) << " Input size is not a multiple of its element size." ;
82+ CHECK_EQ (buffer->Size (), bytes) << " Input size differs across allreduce calls." ;
7983 // Apply the reduce_operation to the input and the buffer.
80- Accumulate (input, bytes / n_bytes_type, & buffer-> front () );
84+ Accumulate (input, bytes, buffer);
8185 }
8286 }
8387
@@ -128,39 +132,41 @@ class AllreduceFunctor {
128132 }
129133 }
130134
131- void Accumulate (char const * input, std::size_t size, char * buffer) const {
135+ void Accumulate (char const * input, std::size_t bytes, AlignedByteBuffer * buffer) const {
132136 using Type = ArrayInterfaceHandler::Type;
137+ auto data = buffer->Data ();
138+ auto size = bytes / DispatchDType (data_type_, [](auto t) { return sizeof (t); });
133139 switch (data_type_) {
134140 case Type::kI1 :
135- Accumulate (reinterpret_cast <std::int8_t *>(buffer ),
141+ Accumulate (reinterpret_cast <std::int8_t *>(data ),
136142 reinterpret_cast <std::int8_t const *>(input), size, operation_);
137143 break ;
138144 case Type::kU1 :
139- Accumulate (reinterpret_cast <std::uint8_t *>(buffer ),
145+ Accumulate (reinterpret_cast <std::uint8_t *>(data ),
140146 reinterpret_cast <std::uint8_t const *>(input), size, operation_);
141147 break ;
142148 case Type::kI4 :
143- Accumulate (reinterpret_cast <std::int32_t *>(buffer ),
149+ Accumulate (reinterpret_cast <std::int32_t *>(data ),
144150 reinterpret_cast <std::int32_t const *>(input), size, operation_);
145151 break ;
146152 case Type::kU4 :
147- Accumulate (reinterpret_cast <std::uint32_t *>(buffer ),
153+ Accumulate (reinterpret_cast <std::uint32_t *>(data ),
148154 reinterpret_cast <std::uint32_t const *>(input), size, operation_);
149155 break ;
150156 case Type::kI8 :
151- Accumulate (reinterpret_cast <std::int64_t *>(buffer ),
157+ Accumulate (reinterpret_cast <std::int64_t *>(data ),
152158 reinterpret_cast <std::int64_t const *>(input), size, operation_);
153159 break ;
154160 case Type::kU8 :
155- Accumulate (reinterpret_cast <std::uint64_t *>(buffer ),
161+ Accumulate (reinterpret_cast <std::uint64_t *>(data ),
156162 reinterpret_cast <std::uint64_t const *>(input), size, operation_);
157163 break ;
158164 case Type::kF4 :
159- Accumulate (reinterpret_cast <float *>(buffer ), reinterpret_cast <float const *>(input), size,
165+ Accumulate (reinterpret_cast <float *>(data ), reinterpret_cast <float const *>(input), size,
160166 operation_);
161167 break ;
162168 case Type::kF8 :
163- Accumulate (reinterpret_cast <double *>(buffer ), reinterpret_cast <double const *>(input), size,
169+ Accumulate (reinterpret_cast <double *>(data ), reinterpret_cast <double const *>(input), size,
164170 operation_);
165171 break ;
166172 default :
@@ -182,10 +188,10 @@ class BroadcastFunctor {
182188
183189 BroadcastFunctor (std::int32_t rank, std::int32_t root) : rank_{rank}, root_{root} {}
184190
185- void operator ()(char const * input, std::size_t bytes, std::string * buffer) const {
191+ void operator ()(char const * input, std::size_t bytes, AlignedByteBuffer * buffer) const {
186192 if (rank_ == root_) {
187193 // Copy the input if this is the root.
188- buffer->assign (input, bytes);
194+ buffer->Assign (input, bytes);
189195 }
190196 }
191197
@@ -246,9 +252,7 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*
246252 HandlerFunctor const & functor) {
247253 // Pass through if there is only 1 client.
248254 if (world_size_ == 1 ) {
249- if (input != output->data ()) {
250- output->assign (input, bytes);
251- }
255+ output->assign (input, bytes);
252256 return ;
253257 }
254258
@@ -263,7 +267,7 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*
263267
264268 if (received_ == world_size_) {
265269 LOG (DEBUG ) << functor.name << " rank " << rank << " : all requests received" ;
266- output->assign (buffer_);
270+ output->assign (buffer_. Data (), buffer_. Size () );
267271 sent_++;
268272 lock.unlock ();
269273 cv_.notify_all ();
@@ -274,14 +278,14 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*
274278 cv_.wait (lock, [this ] { return received_ == world_size_; });
275279
276280 LOG (DEBUG ) << functor.name << " rank " << rank << " : sending reply" ;
277- output->assign (buffer_);
281+ output->assign (buffer_. Data (), buffer_. Size () );
278282 sent_++;
279283
280284 if (sent_ == world_size_) {
281285 LOG (DEBUG ) << functor.name << " rank " << rank << " : all replies sent" ;
282286 sent_ = 0 ;
283287 received_ = 0 ;
284- buffer_.clear ();
288+ buffer_.Clear ();
285289 sequence_number_++;
286290 lock.unlock ();
287291 cv_.notify_all ();
0 commit comments