@@ -248,39 +248,40 @@ std::shared_ptr<Work> ProcessGroup::Recv(std::vector<std::shared_ptr<Tensor>> te
248248 }
249249}
250250
251- std::vector<std::shared_ptr<Tensor>>
252- ProcessGroup::BroadCast ( const std::vector<std::shared_ptr<Tensor>> &input_tensors ) const {
251+ std::vector<std::shared_ptr<Tensor>> ProcessGroup::BroadCast ( const std::vector<std::shared_ptr<Tensor>> &input_tensors,
252+ int root_group_rank ) const {
253253 std::vector<std::shared_ptr<Tensor>> outputs;
254254 std::vector<core::Stream *> streams;
255255 std::vector<core::CclComm *> comms;
256- std::vector<Device> devices;
257256
258- CHECK_EQ (world_size_, comms_. size ()) ;
259- for ( size_t i = 0 ; i < world_size_; ++i) {
260- auto device = devices_[i];
257+ // Only iterate over this process's devices (in single-process mode this equals world_size_ ;
258+ // in multi-process mode it is a strict subset).
259+ for ( const auto & device : devices_) {
261260 for (const auto &input_tensor : input_tensors) {
262261 outputs.push_back (std::make_shared<Tensor>(input_tensor->Dims (), input_tensor->Dtype (), device));
263262 }
264- devices.push_back (device);
265263 streams.push_back (runtime_impl_->GetStream (device));
266264 comms.push_back (device_comm_map_.at (device.index ()));
267265 }
268266
269- int root = -1 ;
270- for (size_t i = 0 ; i < devices.size (); ++i) {
271- if (devices[i] == input_tensors[0 ]->GetDevice ()) {
272- root = static_cast <int >(i);
273- break ;
274- }
267+ // Determine NCCL root (= group rank of the source). In single-process mode the caller may
268+ // omit it and we infer from input_tensors[0]->GetDevice(); in multi-process mode the source
269+ // may not be on this process, so the caller must provide the group rank explicitly.
270+ int root = root_group_rank;
271+ if (root < 0 ) {
272+ auto it = global_group_rank_map_.find (input_tensors[0 ]->GetDevice ().Rank ().GlobalRank ());
273+ CHECK (it != global_group_rank_map_.end ())
274+ << " BroadCast: root device not found in group and root_group_rank was not provided" ;
275+ root = it->second ;
275276 }
276- CHECK_NE (root, -1 ) << " Root not found in input devices" ;
277277
278- core::CclGroupGuard ccl_group_guard (devices[0 ].type ());
279- for (size_t i = 0 ; i < devices.size (); ++i) {
280- core::DeviceGuard guard (devices[i]);
278+ core::CclGroupGuard ccl_group_guard (devices_[0 ].type ());
279+ for (size_t i = 0 ; i < devices_.size (); ++i) {
280+ core::DeviceGuard guard (devices_[i]);
281+ const int local_group_rank = global_group_rank_map_.at (devices_[i].Rank ().GlobalRank ());
281282 for (size_t j = 0 ; j < input_tensors.size (); ++j) {
282283 const auto &input_tensor = input_tensors[j];
283- const void *send_buffer = (static_cast < int >(i) == root ? input_tensor->DataPtr () : nullptr );
284+ const void *send_buffer = (local_group_rank == root ? input_tensor->DataPtr () : nullptr );
284285 ccl_impl_->Broadcast (send_buffer, outputs[i * input_tensors.size () + j]->DataPtr (),
285286 input_tensor->NumElements (), input_tensor->Dtype (), root, comms[i], streams[i]);
286287 }
@@ -330,30 +331,169 @@ ProcessGroup::ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<T
330331}
331332
332333std::vector<std::shared_ptr<Tensor>> ProcessGroup::Scatter (const std::shared_ptr<Tensor> &tensor,
333- std::vector<Device> devices, int64_t dim) const {
334+ std::vector<Device> devices, int64_t dim,
335+ int src_group_rank) const {
336+ CHECK_EQ (devices.size (), static_cast <size_t >(world_size_)) << " Scatter expects one device per group rank" ;
337+ CHECK_GT (devices.size (), 0 );
338+ CHECK (tensor != nullptr ) << " Scatter: tensor carrying full shape/dtype must be provided on every process" ;
339+
340+ // Resolve src rank: explicit overrides inference from tensor device.
341+ int src_rank = src_group_rank;
342+ if (src_rank < 0 ) {
343+ for (size_t i = 0 ; i < devices.size (); ++i) {
344+ if (tensor->GetDevice () == devices[i]) {
345+ src_rank = static_cast <int >(i);
346+ break ;
347+ }
348+ }
349+ CHECK_NE (src_rank, -1 ) << " Source device not found in input devices" ;
350+ }
351+ CHECK_GE (src_rank, 0 );
352+ CHECK_LT (src_rank, world_size_);
353+
354+ // Identify local group ranks (in the same order as devices_).
355+ std::vector<int > local_group_ranks;
356+ local_group_ranks.reserve (devices_.size ());
357+ for (const auto &d : devices_) { local_group_ranks.push_back (global_group_rank_map_.at (d.Rank ().GlobalRank ())); }
358+ const auto src_local_it = std::find (local_group_ranks.begin (), local_group_ranks.end (), src_rank);
359+ const bool src_is_local = src_local_it != local_group_ranks.end ();
360+
361+ // Source splits only when it owns the full tensor. Shard shape is identical for all ranks
362+ // when the dim is evenly divisible; we rely on that for preallocation on non-src processes.
363+ CHECK_EQ (tensor->Dims ()[dim] % static_cast <int64_t >(devices.size ()), 0 )
364+ << " Scatter: dim size must be divisible by world size" ;
365+ const int64_t shard_size = tensor->Dims ()[dim] / static_cast <int64_t >(devices.size ());
366+ std::vector<std::shared_ptr<Tensor>> split_tensors;
367+ if (src_is_local) {
368+ split_tensors = tensor->Split (shard_size, dim);
369+ CHECK_EQ (split_tensors.size (), devices.size ());
370+ }
371+
372+ std::vector<int64_t > shard_dims = tensor->Dims ();
373+ shard_dims[dim] = shard_size;
374+ const DataType shard_dtype = tensor->Dtype ();
375+
376+ // Preallocate output shards for this process's local devices.
334377 std::vector<std::shared_ptr<Tensor>> outputs;
335- auto split_tensors = tensor->Split (tensor->Dims ()[dim] / devices.size (), dim);
336- std::vector<core::Stream *> streams;
337- std::vector<core::CclComm *> comms;
338- int src_rank = -1 ;
378+ outputs.reserve (devices_.size ());
379+ for (const auto &d : devices_) { outputs.push_back (std::make_shared<Tensor>(shard_dims, shard_dtype, d)); }
339380
340- for (size_t i = 0 ; i < devices.size (); ++i) {
341- if (tensor->GetDevice () == devices[i]) {
342- src_rank = static_cast <int >(i);
381+ // Single-process mode: all devices live here, keep the symmetric Send/Recv loop for clarity.
382+ if (global::GetNnodes () == 1 && global::GetNprocPerNode () == 1 ) {
383+ std::vector<core::Stream *> streams;
384+ std::vector<core::CclComm *> comms;
385+ streams.reserve (devices.size ());
386+ comms.reserve (devices.size ());
387+ for (const auto &d : devices) {
388+ streams.push_back (runtime_impl_->GetStream (d));
389+ comms.push_back (device_comm_map_.at (d.index ()));
343390 }
344- outputs.push_back (std::make_shared<Tensor>(split_tensors[i]->Dims (), split_tensors[i]->Dtype (), devices[i]));
345- streams.push_back (runtime_impl_->GetStream (devices[i]));
346- comms.push_back (device_comm_map_.at (devices[i].index ()));
391+ core::CclGroupGuard ccl_group_guard (devices[0 ].type ());
392+ for (size_t i = 0 ; i < devices.size (); ++i) {
393+ core::DeviceGuard guard (devices[i]);
394+ ccl_impl_->Send (split_tensors[i]->DataPtr (), split_tensors[i]->NumElements (), shard_dtype,
395+ static_cast <int >(i), comms[src_rank], streams[src_rank]);
396+ ccl_impl_->Recv (outputs[i]->DataPtr (), outputs[i]->NumElements (), shard_dtype, src_rank, comms[i],
397+ streams[i]);
398+ }
399+ return outputs;
347400 }
348- CHECK_NE (src_rank, -1 ) << " Source device not found in input devices" ;
349401
350- core::CclGroupGuard ccl_group_guard (devices[0 ].type ());
351- for (size_t i = 0 ; i < devices.size (); ++i) {
352- core::DeviceGuard guard (devices[i]);
353- ccl_impl_->Send (split_tensors[i]->DataPtr (), split_tensors[i]->NumElements (), tensor->Dtype (), i,
354- comms[src_rank], streams[src_rank]);
355- ccl_impl_->Recv (outputs[i]->DataPtr (), outputs[i]->NumElements (), tensor->Dtype (), src_rank, comms[i],
356- streams[i]);
402+ // Multi-process mode: each process handles only its local device(s).
403+ core::CclGroupGuard ccl_group_guard (devices_[0 ].type ());
404+
405+ // Src issues a Send to every non-src group rank (including group ranks hosted in other processes).
406+ if (src_is_local) {
407+ const size_t src_local_idx = static_cast <size_t >(src_local_it - local_group_ranks.begin ());
408+ const auto &src_device = devices_[src_local_idx];
409+ core::DeviceGuard guard (src_device);
410+ auto *stream = runtime_impl_->GetStream (src_device);
411+ auto *comm = device_comm_map_.at (src_device.index ());
412+ for (int dst = 0 ; dst < world_size_; ++dst) {
413+ if (dst == src_rank) {
414+ continue ;
415+ }
416+ ccl_impl_->Send (split_tensors[dst]->DataPtr (), split_tensors[dst]->NumElements (), shard_dtype, dst, comm,
417+ stream);
418+ }
419+ }
420+
421+ // Every local device posts either a local copy (if it is src) or a Recv from src.
422+ for (size_t i = 0 ; i < devices_.size (); ++i) {
423+ const auto &local_device = devices_[i];
424+ const int local_rank = local_group_ranks[i];
425+ if (src_is_local && local_rank == src_rank) {
426+ outputs[i]->CopyFrom (split_tensors[src_rank]);
427+ continue ;
428+ }
429+ core::DeviceGuard guard (local_device);
430+ auto *stream = runtime_impl_->GetStream (local_device);
431+ auto *comm = device_comm_map_.at (local_device.index ());
432+ ccl_impl_->Recv (outputs[i]->DataPtr (), outputs[i]->NumElements (), shard_dtype, src_rank, comm, stream);
433+ }
434+ return outputs;
435+ }
436+
437+ std::vector<std::shared_ptr<Tensor>> ProcessGroup::Scatter (const std::shared_ptr<Tensor> &tensor, int64_t dim,
438+ int src_group_rank) const {
439+ CHECK (tensor != nullptr ) << " Scatter: tensor carrying full shape/dtype must be provided on every process" ;
440+ CHECK_GE (src_group_rank, 0 );
441+ CHECK_LT (src_group_rank, world_size_);
442+ CHECK_GT (devices_.size (), 0 );
443+ const int src_rank = src_group_rank;
444+
445+ // Identify local group ranks (in the same order as devices_).
446+ std::vector<int > local_group_ranks;
447+ local_group_ranks.reserve (devices_.size ());
448+ for (const auto &d : devices_) { local_group_ranks.push_back (global_group_rank_map_.at (d.Rank ().GlobalRank ())); }
449+ const auto src_local_it = std::find (local_group_ranks.begin (), local_group_ranks.end (), src_rank);
450+ const bool src_is_local = src_local_it != local_group_ranks.end ();
451+
452+ CHECK_EQ (tensor->Dims ()[dim] % static_cast <int64_t >(world_size_), 0 )
453+ << " Scatter: dim size must be divisible by world size" ;
454+ const int64_t shard_size = tensor->Dims ()[dim] / static_cast <int64_t >(world_size_);
455+ std::vector<std::shared_ptr<Tensor>> split_tensors;
456+ if (src_is_local) {
457+ split_tensors = tensor->Split (shard_size, dim);
458+ CHECK_EQ (split_tensors.size (), static_cast <size_t >(world_size_));
459+ }
460+
461+ std::vector<int64_t > shard_dims = tensor->Dims ();
462+ shard_dims[dim] = shard_size;
463+ const DataType shard_dtype = tensor->Dtype ();
464+
465+ std::vector<std::shared_ptr<Tensor>> outputs;
466+ outputs.reserve (devices_.size ());
467+ for (const auto &d : devices_) { outputs.push_back (std::make_shared<Tensor>(shard_dims, shard_dtype, d)); }
468+
469+ core::CclGroupGuard ccl_group_guard (devices_[0 ].type ());
470+
471+ if (src_is_local) {
472+ const size_t src_local_idx = static_cast <size_t >(src_local_it - local_group_ranks.begin ());
473+ const auto &src_device = devices_[src_local_idx];
474+ core::DeviceGuard guard (src_device);
475+ auto *stream = runtime_impl_->GetStream (src_device);
476+ auto *comm = device_comm_map_.at (src_device.index ());
477+ for (int dst = 0 ; dst < world_size_; ++dst) {
478+ if (dst == src_rank) {
479+ continue ;
480+ }
481+ ccl_impl_->Send (split_tensors[dst]->DataPtr (), split_tensors[dst]->NumElements (), shard_dtype, dst, comm,
482+ stream);
483+ }
484+ }
485+
486+ for (size_t i = 0 ; i < devices_.size (); ++i) {
487+ const auto &local_device = devices_[i];
488+ const int local_rank = local_group_ranks[i];
489+ if (src_is_local && local_rank == src_rank) {
490+ outputs[i]->CopyFrom (split_tensors[src_rank]);
491+ continue ;
492+ }
493+ core::DeviceGuard guard (local_device);
494+ auto *stream = runtime_impl_->GetStream (local_device);
495+ auto *comm = device_comm_map_.at (local_device.index ());
496+ ccl_impl_->Recv (outputs[i]->DataPtr (), outputs[i]->NumElements (), shard_dtype, src_rank, comm, stream);
357497 }
358498 return outputs;
359499}
0 commit comments