Commit f918ae3

refactor: in-place BroadCast + ScatterFromRank for TP-aware init
Replace BroadCast's allocate-then-return signature with an in-place form (void return) that takes tensors pre-grouped per local device. This lets the root rank broadcast directly out of the source tensor, with no self-copy and no extra allocation. Add ScatterFromRank as the multi-process counterpart to Scatter, for the same reason. Use both in LoRA*ParallelLinear so that TP rank-0 initialization no longer pays a tp_size-fold communication or scratch-allocation cost.
1 parent c614ec6 commit f918ae3

4 files changed

Lines changed: 154 additions & 41 deletions
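
At the call sites, the change boils down to two new collective calls, condensed here from the LoRA diffs below (the surrounding setup, process-group lookup and parameter allocation, is exactly what the diff shows and is not repeated):

// LoRAColumnParallelLinear: replicate rank 0's random lora_A to all TP ranks.
// Before: zero-init on ranks != 0, then tp_group->AllReduce(lora_A) to sum-broadcast.
tp_group->BroadCast({parameters_[kParamLoraAName]}, /*root_group_rank=*/0);

// LoRARowParallelLinear: rank 0 fills full_lora_A [rank, in_features]; every rank
// receives only its dim=1 shard into its pre-allocated per-partition parameter.
tp_group->ScatterFromRank({parameters_[kParamLoraAName]}, full_lora_A, /*dim=*/1, /*src_group_rank=*/0);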

infini_train/include/nn/parallel/process_group.h

Lines changed: 8 additions & 2 deletions
@@ -59,15 +59,21 @@ class ProcessGroup {
                                           bool async_op = false) const;

     // Legacy communication APIs (Single-stream)
-    virtual std::vector<std::shared_ptr<Tensor>>
-    BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const;
+    // In-place broadcast of tensors grouped as [device0 tensors..., device1 tensors...].
+    // Pass root_group_rank in multi-process mode; -1 infers it from tensors[0].
+    virtual void BroadCast(const std::vector<std::shared_ptr<Tensor>> &tensors, int root_group_rank = -1) const;

     virtual std::vector<std::shared_ptr<Tensor>>
     ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<Tensor>>> &grads, Device destination) const;

     virtual std::vector<std::shared_ptr<Tensor>> Scatter(const std::shared_ptr<Tensor> &tensor,
                                                          std::vector<Device> devices, int64_t dim) const;

+    // Multi-process-friendly in-place scatter. Outputs are this process's local shard(s);
+    // full tensor data is read only on src_group_rank.
+    virtual void ScatterFromRank(const std::vector<std::shared_ptr<Tensor>> &outputs,
+                                 const std::shared_ptr<Tensor> &tensor, int64_t dim, int src_group_rank) const;
+
     virtual std::shared_ptr<Tensor> Gather(const std::vector<std::shared_ptr<Tensor>> &tensors, Device destination,
                                            int64_t dim) const;
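
The grouping convention in the new BroadCast signature matters for callers with more than one local device: tensors are laid out device-major, so tensor j for local device i sits at flat index i * num_tensors_per_device + j. A minimal stand-alone sketch of that indexing (plain C++, no infini_train types; the device and tensor names are illustrative only):

#include <cstdio>
#include <string>
#include <vector>

int main() {
    // Hypothetical example: 2 local devices, 3 tensors to broadcast per device.
    const std::vector<std::string> devices = {"cuda:0", "cuda:1"};
    const std::vector<std::string> names = {"lora_A", "lora_B", "bias"};
    const size_t num_tensors_per_device = names.size();

    // Device-major flattening: [device0 tensors..., device1 tensors...],
    // mirroring the layout the in-place BroadCast expects.
    for (size_t i = 0; i < devices.size(); ++i) {
        for (size_t j = 0; j < num_tensors_per_device; ++j) {
            const size_t flat = i * num_tensors_per_device + j;
            std::printf("flat %zu -> %s on %s\n", flat, names[j].c_str(), devices[i].c_str());
        }
    }
    return 0;
}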

infini_train/src/autograd/comm.cc

Lines changed: 14 additions & 1 deletion
@@ -82,8 +82,21 @@ std::vector<std::shared_ptr<Tensor>> Broadcast::Forward(const std::vector<std::s
             << "Broadcast function not implemented for tensors on different device type";
     }

+    std::vector<std::shared_ptr<Tensor>> outputs;
+    outputs.reserve(target_gpus_.size() * input_tensors.size());
+    for (const auto &device : target_gpus_) {
+        for (const auto &tensor : input_tensors) {
+            if (device == input_device_) {
+                outputs.push_back(tensor);
+            } else {
+                outputs.push_back(std::make_shared<Tensor>(tensor->Dims(), tensor->Dtype(), device));
+            }
+        }
+    }
+
     // TODO(dcj): mark non differentiable
-    return pg_->BroadCast(input_tensors);
+    pg_->BroadCast(outputs, pg_->GetGroupRank(input_device_.Rank().GlobalRank()));
+    return outputs;
 }

 void Broadcast::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
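
The pre-allocation loop above is what makes the root's broadcast truly in-place: for the source device the function pushes the input shared_ptr itself rather than a fresh tensor, so that output entry aliases the input's storage. A small stand-alone illustration of the aliasing (plain C++; Buffer is a made-up stand-in for Tensor):

#include <cassert>
#include <memory>
#include <vector>

struct Buffer {
    std::vector<float> data;
};

int main() {
    auto input = std::make_shared<Buffer>(Buffer{{1.0f, 2.0f, 3.0f}});

    std::vector<std::shared_ptr<Buffer>> outputs;
    const bool is_source_device = true;  // pretend this iteration handles the input's own device
    if (is_source_device) {
        outputs.push_back(input);  // alias the input: no copy, no allocation
    } else {
        outputs.push_back(std::make_shared<Buffer>(Buffer{std::vector<float>(input->data.size())}));
    }

    // The source entry shares storage with the input, so broadcasting "into" it touches no new memory.
    assert(outputs[0].get() == input.get());
    return 0;
}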

infini_train/src/nn/lora/lora_parallel_linear.cc

Lines changed: 30 additions & 9 deletions
@@ -102,18 +102,15 @@ void LoRAColumnParallelLinear::InitLoRAWeights() {
                               ->Get(parallel::GetTensorParallelProcessGroupName(global_rank));
         const int tp_rank = tp_group->GetGroupRank(global_rank);

-        // Only TP rank 0 generates random values; others zero-init.
-        // AllReduce(sum) then broadcasts rank-0's values to all TP ranks.
+        // TP rank 0 generates random values; broadcast replicates to other ranks in-place.
         if (tp_rank == 0) {
             if (config_.use_kaiming_a) {
                 init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
             } else {
                 init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
             }
-        } else {
-            init::Zeros(parameters_[kParamLoraAName]);
         }
-        tp_group->AllReduce(parameters_[kParamLoraAName]);
+        tp_group->BroadCast({parameters_[kParamLoraAName]}, /*root_group_rank=*/0);
     } else {
         if (config_.use_kaiming_a) {
             init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
@@ -303,17 +300,41 @@ void LoRARowParallelLinear::InitLoRAWeights() {
     // lora_B: [out_features, rank] - replicated

     // lora_A: [rank, in_features_per_partition]
+    // TP rank 0 generates the full [lora_rank, in_features] tensor and scatters shards along dim=1,
+    // so each TP rank ends up holding only its own partition.
     parameters_[kParamLoraAName]
         = std::make_shared<Tensor>(std::vector<int64_t>{config_.rank, in_features_per_partition_}, DataType::kFLOAT32,
                                    device_)
               ->RequiresGrad();
-    if (config_.use_kaiming_a) {
-        init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
+
+    if (parallel::global::GetTensorParallelSize() > 1) {
+        const auto global_rank = device_.Rank().GlobalRank();
+        auto *tp_group = parallel::ProcessGroupFactory::Instance(device_.type())
+                             ->Get(parallel::GetTensorParallelProcessGroupName(global_rank));
+        const int tp_rank = tp_group->GetGroupRank(global_rank);
+        const int tp_size = parallel::global::GetTensorParallelSize();
+
+        // TP rank 0 generates full [lora_rank, in_features]; scatter shards along dim=1 to all ranks.
+        // Non-src processes pass a tensor carrying only shape/dtype (contents unread).
+        auto full_lora_A = std::make_shared<Tensor>(
+            std::vector<int64_t>{config_.rank, in_features_per_partition_ * tp_size}, DataType::kFLOAT32, device_);
+        if (tp_rank == 0) {
+            if (config_.use_kaiming_a) {
+                init::KaimingUniform(full_lora_A, config_.kaiming_a_param);
+            } else {
+                init::Normal(full_lora_A, 0.0f, 0.02f);
+            }
+        }
+        tp_group->ScatterFromRank({parameters_[kParamLoraAName]}, full_lora_A, /*dim=*/1, /*src_group_rank=*/0);
     } else {
-        init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
+        if (config_.use_kaiming_a) {
+            init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
+        } else {
+            init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
+        }
     }

-    // lora_B: [out_features, rank]
+    // lora_B: [out_features, rank] - replicated, zeros
     parameters_[kParamLoraBName]
         = std::make_shared<Tensor>(std::vector<int64_t>{out_features_, config_.rank}, DataType::kFLOAT32, device_)
               ->RequiresGrad();
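
To make the shapes concrete: the full tensor built on every rank is [config_.rank, in_features_per_partition_ * tp_size], and ScatterFromRank hands each TP rank one dim=1 slice of width in_features_per_partition_. A stand-alone check of that arithmetic with hypothetical sizes (lora rank 16, per-partition width 1024, tp_size 4 are made up for illustration):

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Hypothetical configuration, not taken from the repo.
    const int64_t lora_rank = 16;
    const int64_t in_features_per_partition = 1024;
    const int64_t tp_size = 4;

    // Full lora_A generated on TP rank 0, and the shard each rank keeps after the scatter.
    const std::vector<int64_t> full_dims = {lora_rank, in_features_per_partition * tp_size};
    std::vector<int64_t> shard_dims = full_dims;
    const int64_t dim = 1;
    assert(full_dims[dim] % tp_size == 0);  // mirrors the divisibility CHECK in ScatterFromRank
    shard_dims[dim] = full_dims[dim] / tp_size;

    std::printf("full lora_A   : [%lld, %lld]\n", (long long)full_dims[0], (long long)full_dims[1]);
    std::printf("per-rank shard: [%lld, %lld]\n", (long long)shard_dims[0], (long long)shard_dims[1]);
    // Each non-source rank now receives 1/tp_size of the data instead of the full tensor.
    return 0;
}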

infini_train/src/nn/parallel/process_group.cc

Lines changed: 102 additions & 29 deletions
@@ -248,45 +248,52 @@ std::shared_ptr<Work> ProcessGroup::Recv(std::vector<std::shared_ptr<Tensor>> te
     }
 }

-std::vector<std::shared_ptr<Tensor>>
-ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const {
-    std::vector<std::shared_ptr<Tensor>> outputs;
+void ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &tensors, int root_group_rank) const {
+    CHECK_GT(tensors.size(), 0);
+    CHECK_GT(devices_.size(), 0);
+    CHECK_EQ(tensors.size() % devices_.size(), 0)
+        << "BroadCast: tensors must be grouped by local device with the same tensor count per device";
+    const size_t num_tensors_per_device = tensors.size() / devices_.size();
+
     std::vector<core::Stream *> streams;
     std::vector<core::CclComm *> comms;
-    std::vector<Device> devices;
+    std::vector<int> local_group_ranks;
+    streams.reserve(devices_.size());
+    comms.reserve(devices_.size());
+    local_group_ranks.reserve(devices_.size());

-    CHECK_EQ(world_size_, comms_.size());
-    for (size_t i = 0; i < world_size_; ++i) {
-        auto device = devices_[i];
-        for (const auto &input_tensor : input_tensors) {
-            outputs.push_back(std::make_shared<Tensor>(input_tensor->Dims(), input_tensor->Dtype(), device));
-        }
-        devices.push_back(device);
+    for (const auto &device : devices_) {
         streams.push_back(runtime_impl_->GetStream(device));
         comms.push_back(device_comm_map_.at(device.index()));
+        local_group_ranks.push_back(global_group_rank_map_.at(device.Rank().GlobalRank()));
     }

-    int root = -1;
-    for (size_t i = 0; i < devices.size(); ++i) {
-        if (devices[i] == input_tensors[0]->GetDevice()) {
-            root = static_cast<int>(i);
-            break;
-        }
+    // Determine NCCL root (= group rank of the source). In single-process mode the caller may
+    // omit it and we infer from tensors[0]->GetDevice(); in multi-process mode the source
+    // may not be on this process, so the caller must provide the group rank explicitly.
+    int root = root_group_rank;
+    if (root < 0) {
+        auto it = global_group_rank_map_.find(tensors[0]->GetDevice().Rank().GlobalRank());
+        CHECK(it != global_group_rank_map_.end())
+            << "BroadCast: root device not found in group and root_group_rank was not provided";
+        root = it->second;
     }
-    CHECK_NE(root, -1) << "Root not found in input devices";
-
-    core::CclGroupGuard ccl_group_guard(devices[0].type());
-    for (size_t i = 0; i < devices.size(); ++i) {
-        core::DeviceGuard guard(devices[i]);
-        for (size_t j = 0; j < input_tensors.size(); ++j) {
-            const auto &input_tensor = input_tensors[j];
-            const void *send_buffer = (static_cast<int>(i) == root ? input_tensor->DataPtr() : nullptr);
-            ccl_impl_->Broadcast(send_buffer, outputs[i * input_tensors.size() + j]->DataPtr(),
-                                 input_tensor->NumElements(), input_tensor->Dtype(), root, comms[i], streams[i]);
+    CHECK_GE(root, 0);
+    CHECK_LT(root, world_size_);
+
+    core::CclGroupGuard ccl_group_guard(devices_[0].type());
+    for (size_t i = 0; i < devices_.size(); ++i) {
+        core::DeviceGuard guard(devices_[i]);
+        const int local_group_rank = local_group_ranks[i];
+        for (size_t j = 0; j < num_tensors_per_device; ++j) {
+            const auto &tensor = tensors[i * num_tensors_per_device + j];
+            CHECK(tensor != nullptr) << "BroadCast: null tensor";
+            CHECK_EQ(tensor->GetDevice(), devices_[i]) << "BroadCast: tensors must match local device grouping";
+            const void *send_buffer = (local_group_rank == root ? tensor->DataPtr() : nullptr);
+            ccl_impl_->Broadcast(send_buffer, tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), root, comms[i],
+                                 streams[i]);
         }
     }
-
-    return outputs;
 }

 std::vector<std::shared_ptr<Tensor>>
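
The root handling above has two paths: an explicit root_group_rank from the caller, or (for single-process use) a fallback that maps the global rank of tensors[0]'s device through global_group_rank_map_. A small stand-alone sketch of that fallback; ResolveRoot is a hypothetical helper and the ranks are made up for illustration:

#include <cassert>
#include <unordered_map>

// Resolve the group rank to use as broadcast root, mirroring the fallback above:
// a non-negative argument wins; otherwise look the source's global rank up in the group map.
int ResolveRoot(int root_group_rank, int source_global_rank,
                const std::unordered_map<int, int> &global_group_rank_map) {
    if (root_group_rank >= 0) {
        return root_group_rank;
    }
    const auto it = global_group_rank_map.find(source_global_rank);
    assert(it != global_group_rank_map.end());  // the source must belong to this group
    return it->second;
}

int main() {
    // Hypothetical TP group: global ranks 4..7 map to group ranks 0..3.
    const std::unordered_map<int, int> group_map = {{4, 0}, {5, 1}, {6, 2}, {7, 3}};
    assert(ResolveRoot(2, /*source_global_rank=*/6, group_map) == 2);   // explicit root wins
    assert(ResolveRoot(-1, /*source_global_rank=*/6, group_map) == 2);  // inferred from the source device
    return 0;
}
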
@@ -358,6 +365,72 @@ std::vector<std::shared_ptr<Tensor>> ProcessGroup::Scatter(const std::shared_ptr
     return outputs;
 }

+void ProcessGroup::ScatterFromRank(const std::vector<std::shared_ptr<Tensor>> &outputs,
+                                   const std::shared_ptr<Tensor> &tensor, int64_t dim, int src_group_rank) const {
+    CHECK(tensor != nullptr) << "Scatter: tensor carrying full shape/dtype must be provided on every process";
+    CHECK_GE(src_group_rank, 0);
+    CHECK_LT(src_group_rank, world_size_);
+    CHECK_GT(devices_.size(), 0);
+    CHECK_EQ(outputs.size(), devices_.size()) << "ScatterFromRank: expects one output per local group device";
+    const int src_rank = src_group_rank;
+
+    // Identify local group ranks (in the same order as devices_).
+    std::vector<int> local_group_ranks;
+    local_group_ranks.reserve(devices_.size());
+    for (const auto &d : devices_) { local_group_ranks.push_back(global_group_rank_map_.at(d.Rank().GlobalRank())); }
+    const auto src_local_it = std::find(local_group_ranks.begin(), local_group_ranks.end(), src_rank);
+    const bool src_is_local = src_local_it != local_group_ranks.end();
+
+    CHECK_EQ(tensor->Dims()[dim] % static_cast<int64_t>(world_size_), 0)
+        << "Scatter: dim size must be divisible by world size";
+    const int64_t shard_size = tensor->Dims()[dim] / static_cast<int64_t>(world_size_);
+    std::vector<std::shared_ptr<Tensor>> split_tensors;
+    if (src_is_local) {
+        split_tensors = tensor->Split(shard_size, dim);
+        CHECK_EQ(split_tensors.size(), static_cast<size_t>(world_size_));
+    }
+
+    std::vector<int64_t> shard_dims = tensor->Dims();
+    shard_dims[dim] = shard_size;
+    const DataType shard_dtype = tensor->Dtype();
+    for (size_t i = 0; i < outputs.size(); ++i) {
+        CHECK(outputs[i] != nullptr) << "ScatterFromRank: null output tensor";
+        CHECK_EQ(outputs[i]->GetDevice(), devices_[i]) << "ScatterFromRank: output device mismatch";
+        CHECK(outputs[i]->Dims() == shard_dims) << "ScatterFromRank: output shape mismatch";
+        CHECK(outputs[i]->Dtype() == shard_dtype) << "ScatterFromRank: output dtype mismatch";
+    }
+
+    core::CclGroupGuard ccl_group_guard(devices_[0].type());
+
+    if (src_is_local) {
+        const size_t src_local_idx = static_cast<size_t>(src_local_it - local_group_ranks.begin());
+        const auto &src_device = devices_[src_local_idx];
+        core::DeviceGuard guard(src_device);
+        auto *stream = runtime_impl_->GetStream(src_device);
+        auto *comm = device_comm_map_.at(src_device.index());
+        for (int dst = 0; dst < world_size_; ++dst) {
+            if (dst == src_rank) {
+                continue;
+            }
+            ccl_impl_->Send(split_tensors[dst]->DataPtr(), split_tensors[dst]->NumElements(), shard_dtype, dst, comm,
+                            stream);
+        }
+    }
+
+    for (size_t i = 0; i < devices_.size(); ++i) {
+        const auto &local_device = devices_[i];
+        const int local_rank = local_group_ranks[i];
+        if (src_is_local && local_rank == src_rank) {
+            outputs[i]->CopyFrom(split_tensors[src_rank]);
+            continue;
+        }
+        core::DeviceGuard guard(local_device);
+        auto *stream = runtime_impl_->GetStream(local_device);
+        auto *comm = device_comm_map_.at(local_device.index());
+        ccl_impl_->Recv(outputs[i]->DataPtr(), outputs[i]->NumElements(), shard_dtype, src_rank, comm, stream);
+    }
+}
+
 std::shared_ptr<Tensor> ProcessGroup::Gather(const std::vector<std::shared_ptr<Tensor>> &tensors, Device destination,
                                              int64_t dim) const {
     int64_t num_devices = static_cast<int64_t>(tensors.size());
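
Communication-wise, ScatterFromRank is point-to-point: the source rank splits the full tensor along dim and sends shard d to group rank d (skipping itself, which just copies locally), while every other rank posts a single Recv for its own shard. A stand-alone simulation of that pattern over plain vectors (no CCL, no infini_train types; world size 4 and the shard size are arbitrary):

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    const int world_size = 4;
    const int src_rank = 0;
    const int shard_size = 3;

    // Full data exists only on the source rank.
    std::vector<float> full(static_cast<size_t>(world_size) * shard_size);
    for (size_t i = 0; i < full.size(); ++i) { full[i] = static_cast<float>(i); }

    // Each rank's pre-allocated output shard (stand-in for the `outputs` tensors).
    std::vector<std::vector<float>> outputs(world_size, std::vector<float>(shard_size));

    // mailbox[d] stands in for the Send/Recv pair between src_rank and rank d.
    std::vector<std::vector<float>> mailbox(world_size);
    for (int dst = 0; dst < world_size; ++dst) {
        std::vector<float> shard(full.begin() + dst * shard_size, full.begin() + (dst + 1) * shard_size);
        if (dst == src_rank) {
            outputs[dst] = shard;  // local copy on the source, no communication
        } else {
            mailbox[dst] = shard;  // Send(shard, dst) issued by the source
        }
    }
    for (int r = 0; r < world_size; ++r) {
        if (r == src_rank) { continue; }
        outputs[r] = mailbox[r];   // Recv(own shard, src_rank) on every other rank
    }

    // Every rank ends up holding exactly its contiguous slice of the full tensor.
    for (int r = 0; r < world_size; ++r) {
        for (int j = 0; j < shard_size; ++j) { assert(outputs[r][j] == full[r * shard_size + j]); }
    }
    return 0;
}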
