@@ -248,45 +248,52 @@ std::shared_ptr<Work> ProcessGroup::Recv(std::vector<std::shared_ptr<Tensor>> te
   }
 }
 
-std::vector<std::shared_ptr<Tensor>>
-ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const {
-  std::vector<std::shared_ptr<Tensor>> outputs;
+void ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &tensors, int root_group_rank) const {
+  CHECK_GT(tensors.size(), 0);
+  CHECK_GT(devices_.size(), 0);
+  CHECK_EQ(tensors.size() % devices_.size(), 0)
+      << "BroadCast: tensors must be grouped by local device with the same tensor count per device";
+  const size_t num_tensors_per_device = tensors.size() / devices_.size();
+
   std::vector<core::Stream *> streams;
   std::vector<core::CclComm *> comms;
-  std::vector<Device> devices;
+  std::vector<int> local_group_ranks;
+  streams.reserve(devices_.size());
+  comms.reserve(devices_.size());
+  local_group_ranks.reserve(devices_.size());
 
-  CHECK_EQ(world_size_, comms_.size());
-  for (size_t i = 0; i < world_size_; ++i) {
-    auto device = devices_[i];
-    for (const auto &input_tensor : input_tensors) {
-      outputs.push_back(std::make_shared<Tensor>(input_tensor->Dims(), input_tensor->Dtype(), device));
-    }
-    devices.push_back(device);
+  for (const auto &device : devices_) {
     streams.push_back(runtime_impl_->GetStream(device));
     comms.push_back(device_comm_map_.at(device.index()));
+    local_group_ranks.push_back(global_group_rank_map_.at(device.Rank().GlobalRank()));
   }
 
-  int root = -1;
-  for (size_t i = 0; i < devices.size(); ++i) {
-    if (devices[i] == input_tensors[0]->GetDevice()) {
-      root = static_cast<int>(i);
-      break;
-    }
+  // Determine NCCL root (= group rank of the source). In single-process mode the caller may
+  // omit it and we infer it from tensors[0]->GetDevice(); in multi-process mode the source
+  // may not be on this process, so the caller must provide the group rank explicitly.
+  int root = root_group_rank;
+  if (root < 0) {
+    auto it = global_group_rank_map_.find(tensors[0]->GetDevice().Rank().GlobalRank());
+    CHECK(it != global_group_rank_map_.end())
+        << "BroadCast: root device not found in group and root_group_rank was not provided";
+    root = it->second;
   }
-  CHECK_NE(root, -1) << "Root not found in input devices";
-
-  core::CclGroupGuard ccl_group_guard(devices[0].type());
-  for (size_t i = 0; i < devices.size(); ++i) {
-    core::DeviceGuard guard(devices[i]);
-    for (size_t j = 0; j < input_tensors.size(); ++j) {
-      const auto &input_tensor = input_tensors[j];
-      const void *send_buffer = (static_cast<int>(i) == root ? input_tensor->DataPtr() : nullptr);
-      ccl_impl_->Broadcast(send_buffer, outputs[i * input_tensors.size() + j]->DataPtr(),
-                           input_tensor->NumElements(), input_tensor->Dtype(), root, comms[i], streams[i]);
+  CHECK_GE(root, 0);
+  CHECK_LT(root, world_size_);
+
+  core::CclGroupGuard ccl_group_guard(devices_[0].type());
+  for (size_t i = 0; i < devices_.size(); ++i) {
+    core::DeviceGuard guard(devices_[i]);
+    const int local_group_rank = local_group_ranks[i];
+    for (size_t j = 0; j < num_tensors_per_device; ++j) {
+      const auto &tensor = tensors[i * num_tensors_per_device + j];
+      CHECK(tensor != nullptr) << "BroadCast: null tensor";
+      CHECK_EQ(tensor->GetDevice(), devices_[i]) << "BroadCast: tensors must match local device grouping";
+      const void *send_buffer = (local_group_rank == root ? tensor->DataPtr() : nullptr);
+      ccl_impl_->Broadcast(send_buffer, tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), root, comms[i],
+                           streams[i]);
     }
   }
-
-  return outputs;
 }
 
 std::vector<std::shared_ptr<Tensor>>
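
Call-site note for reviewers: BroadCast no longer allocates and returns per-device replicas; the caller now pre-allocates one tensor per local device (in devices_ order) and the broadcast writes in place. A minimal sketch of the new convention, assuming a group handle pg with a hypothetical Devices() accessor and a source tensor src already resident on the first local device; these caller-side names are illustrative assumptions, not APIs introduced by this commit:

    // One slot per local device, in devices_ order; the source slot carries the payload.
    std::vector<std::shared_ptr<Tensor>> replicas;
    for (const auto &device : pg->Devices()) {  // hypothetical accessor over devices_
      replicas.push_back(device == src->GetDevice()
                             ? src
                             : std::make_shared<Tensor>(src->Dims(), src->Dtype(), device));
    }
    // A negative root lets single-process callers infer the root from replicas[0]
    // (i.e. the first local device); multi-process callers must pass the source's
    // group rank explicitly.
    pg->BroadCast(replicas, /*root_group_rank=*/-1);
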
@@ -358,6 +365,72 @@ std::vector<std::shared_ptr<Tensor>> ProcessGroup::Scatter(const std::shared_ptr
   return outputs;
 }
 
+void ProcessGroup::ScatterFromRank(const std::vector<std::shared_ptr<Tensor>> &outputs,
+                                   const std::shared_ptr<Tensor> &tensor, int64_t dim, int src_group_rank) const {
+  CHECK(tensor != nullptr) << "ScatterFromRank: tensor carrying full shape/dtype must be provided on every process";
+  CHECK_GE(src_group_rank, 0);
+  CHECK_LT(src_group_rank, world_size_);
+  CHECK_GT(devices_.size(), 0);
+  CHECK_EQ(outputs.size(), devices_.size()) << "ScatterFromRank: expects one output per local group device";
+  const int src_rank = src_group_rank;
+
+  // Identify local group ranks (in the same order as devices_).
+  std::vector<int> local_group_ranks;
+  local_group_ranks.reserve(devices_.size());
+  for (const auto &d : devices_) { local_group_ranks.push_back(global_group_rank_map_.at(d.Rank().GlobalRank())); }
+  const auto src_local_it = std::find(local_group_ranks.begin(), local_group_ranks.end(), src_rank);
+  const bool src_is_local = src_local_it != local_group_ranks.end();
+
+  CHECK_EQ(tensor->Dims()[dim] % static_cast<int64_t>(world_size_), 0)
+      << "ScatterFromRank: dim size must be divisible by world size";
+  const int64_t shard_size = tensor->Dims()[dim] / static_cast<int64_t>(world_size_);
+  std::vector<std::shared_ptr<Tensor>> split_tensors;
+  if (src_is_local) {
+    split_tensors = tensor->Split(shard_size, dim);
+    CHECK_EQ(split_tensors.size(), static_cast<size_t>(world_size_));
+  }
+
+  std::vector<int64_t> shard_dims = tensor->Dims();
+  shard_dims[dim] = shard_size;
+  const DataType shard_dtype = tensor->Dtype();
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    CHECK(outputs[i] != nullptr) << "ScatterFromRank: null output tensor";
+    CHECK_EQ(outputs[i]->GetDevice(), devices_[i]) << "ScatterFromRank: output device mismatch";
+    CHECK(outputs[i]->Dims() == shard_dims) << "ScatterFromRank: output shape mismatch";
+    CHECK(outputs[i]->Dtype() == shard_dtype) << "ScatterFromRank: output dtype mismatch";
+  }
+
+  core::CclGroupGuard ccl_group_guard(devices_[0].type());
+
+  if (src_is_local) {
+    const size_t src_local_idx = static_cast<size_t>(src_local_it - local_group_ranks.begin());
+    const auto &src_device = devices_[src_local_idx];
+    core::DeviceGuard guard(src_device);
+    auto *stream = runtime_impl_->GetStream(src_device);
+    auto *comm = device_comm_map_.at(src_device.index());
+    for (int dst = 0; dst < world_size_; ++dst) {
+      if (dst == src_rank) {
+        continue;
+      }
+      ccl_impl_->Send(split_tensors[dst]->DataPtr(), split_tensors[dst]->NumElements(), shard_dtype, dst, comm,
+                      stream);
+    }
+  }
+
+  for (size_t i = 0; i < devices_.size(); ++i) {
+    const auto &local_device = devices_[i];
+    const int local_rank = local_group_ranks[i];
+    if (src_is_local && local_rank == src_rank) {
+      outputs[i]->CopyFrom(split_tensors[src_rank]);
+      continue;
+    }
+    core::DeviceGuard guard(local_device);
+    auto *stream = runtime_impl_->GetStream(local_device);
+    auto *comm = device_comm_map_.at(local_device.index());
+    ccl_impl_->Recv(outputs[i]->DataPtr(), outputs[i]->NumElements(), shard_dtype, src_rank, comm, stream);
+  }
+}
+
 std::shared_ptr<Tensor> ProcessGroup::Gather(const std::vector<std::shared_ptr<Tensor>> &tensors, Device destination,
                                              int64_t dim) const {
   int64_t num_devices = static_cast<int64_t>(tensors.size());
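
Usage note: a rough sketch of how a multi-process caller might drive ScatterFromRank. Here pg, Devices(), and WorldSize() are placeholder accessors for the private devices_/world_size_ members, and full is an assumed full-shape tensor; none of these names come from this commit. Every process allocates one shard-shaped output per local device and passes full (its payload is needed only on the source process) together with the source's group rank:

    const int64_t dim = 0;                  // dimension to shard along
    const int src_group_rank = 0;           // group rank holding the full tensor
    std::vector<int64_t> shard_dims = full->Dims();
    shard_dims[dim] /= pg->WorldSize();     // must divide evenly (checked inside)
    std::vector<std::shared_ptr<Tensor>> outputs;
    for (const auto &device : pg->Devices()) {
      outputs.push_back(std::make_shared<Tensor>(shard_dims, full->Dtype(), device));
    }
    pg->ScatterFromRank(outputs, full, dim, src_group_rank);
    // On return, outputs[i] holds the shard owned by devices_[i]'s group rank.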