Skip to content

Commit 8459156

Browse files
committed
refactor: add multi-process Scatter overload and use it for LoRA lora_A init
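Add a ProcessGroup::Scatter(tensor, dim, src_group_rank) overload in which each process materializes shards only for its own local devices. Use it in LoRARowParallelLinear to replace broadcast+slice, avoiding the tp_size-fold communication volume during init. Also add an optional root_group_rank argument to ProcessGroup::BroadCast and use it in LoRAColumnParallelLinear in place of the zero-init + AllReduce.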
1 parent c614ec6 commit 8459156

3 files changed

Lines changed: 223 additions & 49 deletions

File tree

infini_train/include/nn/parallel/process_group.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,25 @@ class ProcessGroup {
5959
bool async_op = false) const;
6060

6161
// Legacy communication APIs (Single-stream)
62-
virtual std::vector<std::shared_ptr<Tensor>>
63-
BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const;
62+
// If root_group_rank is -1, infer root from input_tensors[0]'s device (single-process mode).
63+
// In multi-process mode, the caller must pass the source's group rank on every rank.
64+
virtual std::vector<std::shared_ptr<Tensor>> BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
65+
int root_group_rank = -1) const;
6466

6567
virtual std::vector<std::shared_ptr<Tensor>>
6668
ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<Tensor>>> &grads, Device destination) const;
6769

70+
// Single-process / DataParallel form: `devices` enumerates all target devices (must be local
71+
// to this process). Source is inferred from `tensor->GetDevice()` when `src_group_rank` is -1.
6872
virtual std::vector<std::shared_ptr<Tensor>> Scatter(const std::shared_ptr<Tensor> &tensor,
69-
std::vector<Device> devices, int64_t dim) const;
73+
std::vector<Device> devices, int64_t dim,
74+
int src_group_rank = -1) const;
75+
76+
// Multi-process-friendly form (TP init etc.): each process only materializes shard(s) for
77+
// its own local device(s) in this group. `tensor` must carry the full shape/dtype on every
78+
// process; data is only read on the src process.
79+
virtual std::vector<std::shared_ptr<Tensor>> Scatter(const std::shared_ptr<Tensor> &tensor, int64_t dim,
80+
int src_group_rank) const;
7081

7182
virtual std::shared_ptr<Tensor> Gather(const std::vector<std::shared_ptr<Tensor>> &tensors, Device destination,
7283
int64_t dim) const;

infini_train/src/nn/lora/lora_parallel_linear.cc

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,16 @@ void LoRAColumnParallelLinear::InitLoRAWeights() {
102102
->Get(parallel::GetTensorParallelProcessGroupName(global_rank));
103103
const int tp_rank = tp_group->GetGroupRank(global_rank);
104104

105-
// Only TP rank 0 generates random values; others zero-init.
106-
// AllReduce(sum) then broadcasts rank-0's values to all TP ranks.
105+
// TP rank 0 generates random values; Broadcast replicates to other ranks.
107106
if (tp_rank == 0) {
108107
if (config_.use_kaiming_a) {
109108
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
110109
} else {
111110
init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
112111
}
113-
} else {
114-
init::Zeros(parameters_[kParamLoraAName]);
115112
}
116-
tp_group->AllReduce(parameters_[kParamLoraAName]);
113+
auto broadcasted = tp_group->BroadCast({parameters_[kParamLoraAName]}, /*root_group_rank=*/0);
114+
parameters_[kParamLoraAName]->CopyFrom(broadcasted[0]);
117115
} else {
118116
if (config_.use_kaiming_a) {
119117
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
@@ -303,17 +301,42 @@ void LoRARowParallelLinear::InitLoRAWeights() {
303301
// lora_B: [out_features, rank] - replicated
304302

305303
// lora_A: [rank, in_features_per_partition]
304+
// TP rank 0 generates the full [lora_rank, in_features] tensor; Scatter then delivers to
305+
// each TP rank only its own shard along dim=1 (no full-tensor broadcast + slice).
306306
parameters_[kParamLoraAName]
307307
= std::make_shared<Tensor>(std::vector<int64_t>{config_.rank, in_features_per_partition_}, DataType::kFLOAT32,
308308
device_)
309309
->RequiresGrad();
310-
if (config_.use_kaiming_a) {
311-
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
310+
311+
if (parallel::global::GetTensorParallelSize() > 1) {
312+
const auto global_rank = device_.Rank().GlobalRank();
313+
auto *tp_group = parallel::ProcessGroupFactory::Instance(device_.type())
314+
->Get(parallel::GetTensorParallelProcessGroupName(global_rank));
315+
const int tp_rank = tp_group->GetGroupRank(global_rank);
316+
const int tp_size = parallel::global::GetTensorParallelSize();
317+
318+
// TP rank 0 generates full [lora_rank, in_features]; scatter shards along dim=1 to all ranks.
319+
// Non-src processes pass a tensor carrying only shape/dtype (contents unread).
320+
auto full_lora_A = std::make_shared<Tensor>(
321+
std::vector<int64_t>{config_.rank, in_features_per_partition_ * tp_size}, DataType::kFLOAT32, device_);
322+
if (tp_rank == 0) {
323+
if (config_.use_kaiming_a) {
324+
init::KaimingUniform(full_lora_A, config_.kaiming_a_param);
325+
} else {
326+
init::Normal(full_lora_A, 0.0f, 0.02f);
327+
}
328+
}
329+
auto shards = tp_group->Scatter(full_lora_A, /*dim=*/1, /*src_group_rank=*/0);
330+
parameters_[kParamLoraAName]->CopyFrom(shards[0]);
312331
} else {
313-
init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
332+
if (config_.use_kaiming_a) {
333+
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
334+
} else {
335+
init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
336+
}
314337
}
315338

316-
// lora_B: [out_features, rank]
339+
// lora_B: [out_features, rank] - replicated, zeros
317340
parameters_[kParamLoraBName]
318341
= std::make_shared<Tensor>(std::vector<int64_t>{out_features_, config_.rank}, DataType::kFLOAT32, device_)
319342
->RequiresGrad();

infini_train/src/nn/parallel/process_group.cc

Lines changed: 177 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -248,39 +248,40 @@ std::shared_ptr<Work> ProcessGroup::Recv(std::vector<std::shared_ptr<Tensor>> te
248248
}
249249
}
250250

251-
std::vector<std::shared_ptr<Tensor>>
252-
ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const {
251+
std::vector<std::shared_ptr<Tensor>> ProcessGroup::BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
252+
int root_group_rank) const {
253253
std::vector<std::shared_ptr<Tensor>> outputs;
254254
std::vector<core::Stream *> streams;
255255
std::vector<core::CclComm *> comms;
256-
std::vector<Device> devices;
257256

258-
CHECK_EQ(world_size_, comms_.size());
259-
for (size_t i = 0; i < world_size_; ++i) {
260-
auto device = devices_[i];
257+
// Only iterate over this process's devices (in single-process mode there are world_size_ of
258+
// them; in multi-process mode they are a strict subset of the group's devices).
259+
for (const auto &device : devices_) {
261260
for (const auto &input_tensor : input_tensors) {
262261
outputs.push_back(std::make_shared<Tensor>(input_tensor->Dims(), input_tensor->Dtype(), device));
263262
}
264-
devices.push_back(device);
265263
streams.push_back(runtime_impl_->GetStream(device));
266264
comms.push_back(device_comm_map_.at(device.index()));
267265
}
268266

269-
int root = -1;
270-
for (size_t i = 0; i < devices.size(); ++i) {
271-
if (devices[i] == input_tensors[0]->GetDevice()) {
272-
root = static_cast<int>(i);
273-
break;
274-
}
267+
// Determine NCCL root (= group rank of the source). In single-process mode the caller may
268+
// omit it and we infer from input_tensors[0]->GetDevice(); in multi-process mode the source
269+
// may not be on this process, so the caller must provide the group rank explicitly.
270+
int root = root_group_rank;
271+
if (root < 0) {
272+
auto it = global_group_rank_map_.find(input_tensors[0]->GetDevice().Rank().GlobalRank());
273+
CHECK(it != global_group_rank_map_.end())
274+
<< "BroadCast: root device not found in group and root_group_rank was not provided";
275+
root = it->second;
275276
}
276-
CHECK_NE(root, -1) << "Root not found in input devices";
277277

278-
core::CclGroupGuard ccl_group_guard(devices[0].type());
279-
for (size_t i = 0; i < devices.size(); ++i) {
280-
core::DeviceGuard guard(devices[i]);
278+
core::CclGroupGuard ccl_group_guard(devices_[0].type());
279+
for (size_t i = 0; i < devices_.size(); ++i) {
280+
core::DeviceGuard guard(devices_[i]);
281+
const int local_group_rank = global_group_rank_map_.at(devices_[i].Rank().GlobalRank());
281282
for (size_t j = 0; j < input_tensors.size(); ++j) {
282283
const auto &input_tensor = input_tensors[j];
283-
const void *send_buffer = (static_cast<int>(i) == root ? input_tensor->DataPtr() : nullptr);
284+
const void *send_buffer = (local_group_rank == root ? input_tensor->DataPtr() : nullptr);
284285
ccl_impl_->Broadcast(send_buffer, outputs[i * input_tensors.size() + j]->DataPtr(),
285286
input_tensor->NumElements(), input_tensor->Dtype(), root, comms[i], streams[i]);
286287
}
@@ -330,30 +331,169 @@ ProcessGroup::ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<T
330331
}
331332

332333
std::vector<std::shared_ptr<Tensor>> ProcessGroup::Scatter(const std::shared_ptr<Tensor> &tensor,
333-
std::vector<Device> devices, int64_t dim) const {
334+
std::vector<Device> devices, int64_t dim,
335+
int src_group_rank) const {
336+
CHECK_EQ(devices.size(), static_cast<size_t>(world_size_)) << "Scatter expects one device per group rank";
337+
CHECK_GT(devices.size(), 0);
338+
CHECK(tensor != nullptr) << "Scatter: tensor carrying full shape/dtype must be provided on every process";
339+
340+
// Resolve src rank: explicit overrides inference from tensor device.
341+
int src_rank = src_group_rank;
342+
if (src_rank < 0) {
343+
for (size_t i = 0; i < devices.size(); ++i) {
344+
if (tensor->GetDevice() == devices[i]) {
345+
src_rank = static_cast<int>(i);
346+
break;
347+
}
348+
}
349+
CHECK_NE(src_rank, -1) << "Source device not found in input devices";
350+
}
351+
CHECK_GE(src_rank, 0);
352+
CHECK_LT(src_rank, world_size_);
353+
354+
// Identify local group ranks (in the same order as devices_).
355+
std::vector<int> local_group_ranks;
356+
local_group_ranks.reserve(devices_.size());
357+
for (const auto &d : devices_) { local_group_ranks.push_back(global_group_rank_map_.at(d.Rank().GlobalRank())); }
358+
const auto src_local_it = std::find(local_group_ranks.begin(), local_group_ranks.end(), src_rank);
359+
const bool src_is_local = src_local_it != local_group_ranks.end();
360+
361+
// The full tensor is split only on the process that hosts the source rank. The CHECK below
362+
// enforces even divisibility, so all shards share one shape; non-src processes preallocate with it.
363+
CHECK_EQ(tensor->Dims()[dim] % static_cast<int64_t>(devices.size()), 0)
364+
<< "Scatter: dim size must be divisible by world size";
365+
const int64_t shard_size = tensor->Dims()[dim] / static_cast<int64_t>(devices.size());
366+
std::vector<std::shared_ptr<Tensor>> split_tensors;
367+
if (src_is_local) {
368+
split_tensors = tensor->Split(shard_size, dim);
369+
CHECK_EQ(split_tensors.size(), devices.size());
370+
}
371+
372+
std::vector<int64_t> shard_dims = tensor->Dims();
373+
shard_dims[dim] = shard_size;
374+
const DataType shard_dtype = tensor->Dtype();
375+
376+
// Preallocate output shards for this process's local devices.
334377
std::vector<std::shared_ptr<Tensor>> outputs;
335-
auto split_tensors = tensor->Split(tensor->Dims()[dim] / devices.size(), dim);
336-
std::vector<core::Stream *> streams;
337-
std::vector<core::CclComm *> comms;
338-
int src_rank = -1;
378+
outputs.reserve(devices_.size());
379+
for (const auto &d : devices_) { outputs.push_back(std::make_shared<Tensor>(shard_dims, shard_dtype, d)); }
339380

340-
for (size_t i = 0; i < devices.size(); ++i) {
341-
if (tensor->GetDevice() == devices[i]) {
342-
src_rank = static_cast<int>(i);
381+
// Single-process mode: all devices live here, keep the symmetric Send/Recv loop for clarity.
382+
if (global::GetNnodes() == 1 && global::GetNprocPerNode() == 1) {
383+
std::vector<core::Stream *> streams;
384+
std::vector<core::CclComm *> comms;
385+
streams.reserve(devices.size());
386+
comms.reserve(devices.size());
387+
for (const auto &d : devices) {
388+
streams.push_back(runtime_impl_->GetStream(d));
389+
comms.push_back(device_comm_map_.at(d.index()));
343390
}
344-
outputs.push_back(std::make_shared<Tensor>(split_tensors[i]->Dims(), split_tensors[i]->Dtype(), devices[i]));
345-
streams.push_back(runtime_impl_->GetStream(devices[i]));
346-
comms.push_back(device_comm_map_.at(devices[i].index()));
391+
core::CclGroupGuard ccl_group_guard(devices[0].type());
392+
for (size_t i = 0; i < devices.size(); ++i) {
393+
core::DeviceGuard guard(devices[i]);
394+
ccl_impl_->Send(split_tensors[i]->DataPtr(), split_tensors[i]->NumElements(), shard_dtype,
395+
static_cast<int>(i), comms[src_rank], streams[src_rank]);
396+
ccl_impl_->Recv(outputs[i]->DataPtr(), outputs[i]->NumElements(), shard_dtype, src_rank, comms[i],
397+
streams[i]);
398+
}
399+
return outputs;
347400
}
348-
CHECK_NE(src_rank, -1) << "Source device not found in input devices";
349401

350-
core::CclGroupGuard ccl_group_guard(devices[0].type());
351-
for (size_t i = 0; i < devices.size(); ++i) {
352-
core::DeviceGuard guard(devices[i]);
353-
ccl_impl_->Send(split_tensors[i]->DataPtr(), split_tensors[i]->NumElements(), tensor->Dtype(), i,
354-
comms[src_rank], streams[src_rank]);
355-
ccl_impl_->Recv(outputs[i]->DataPtr(), outputs[i]->NumElements(), tensor->Dtype(), src_rank, comms[i],
356-
streams[i]);
402+
// Multi-process mode: each process handles only its local device(s).
403+
core::CclGroupGuard ccl_group_guard(devices_[0].type());
404+
405+
// Src issues a Send to every non-src group rank (including group ranks hosted in other processes).
406+
if (src_is_local) {
407+
const size_t src_local_idx = static_cast<size_t>(src_local_it - local_group_ranks.begin());
408+
const auto &src_device = devices_[src_local_idx];
409+
core::DeviceGuard guard(src_device);
410+
auto *stream = runtime_impl_->GetStream(src_device);
411+
auto *comm = device_comm_map_.at(src_device.index());
412+
for (int dst = 0; dst < world_size_; ++dst) {
413+
if (dst == src_rank) {
414+
continue;
415+
}
416+
ccl_impl_->Send(split_tensors[dst]->DataPtr(), split_tensors[dst]->NumElements(), shard_dtype, dst, comm,
417+
stream);
418+
}
419+
}
420+
421+
// Every local device posts either a local copy (if it is src) or a Recv from src.
422+
for (size_t i = 0; i < devices_.size(); ++i) {
423+
const auto &local_device = devices_[i];
424+
const int local_rank = local_group_ranks[i];
425+
if (src_is_local && local_rank == src_rank) {
426+
outputs[i]->CopyFrom(split_tensors[src_rank]);
427+
continue;
428+
}
429+
core::DeviceGuard guard(local_device);
430+
auto *stream = runtime_impl_->GetStream(local_device);
431+
auto *comm = device_comm_map_.at(local_device.index());
432+
ccl_impl_->Recv(outputs[i]->DataPtr(), outputs[i]->NumElements(), shard_dtype, src_rank, comm, stream);
433+
}
434+
return outputs;
435+
}
436+
437+
std::vector<std::shared_ptr<Tensor>> ProcessGroup::Scatter(const std::shared_ptr<Tensor> &tensor, int64_t dim,
438+
int src_group_rank) const {
439+
CHECK(tensor != nullptr) << "Scatter: tensor carrying full shape/dtype must be provided on every process";
440+
CHECK_GE(src_group_rank, 0);
441+
CHECK_LT(src_group_rank, world_size_);
442+
CHECK_GT(devices_.size(), 0);
443+
const int src_rank = src_group_rank;
444+
445+
// Identify local group ranks (in the same order as devices_).
446+
std::vector<int> local_group_ranks;
447+
local_group_ranks.reserve(devices_.size());
448+
for (const auto &d : devices_) { local_group_ranks.push_back(global_group_rank_map_.at(d.Rank().GlobalRank())); }
449+
const auto src_local_it = std::find(local_group_ranks.begin(), local_group_ranks.end(), src_rank);
450+
const bool src_is_local = src_local_it != local_group_ranks.end();
451+
452+
CHECK_EQ(tensor->Dims()[dim] % static_cast<int64_t>(world_size_), 0)
453+
<< "Scatter: dim size must be divisible by world size";
454+
const int64_t shard_size = tensor->Dims()[dim] / static_cast<int64_t>(world_size_);
455+
std::vector<std::shared_ptr<Tensor>> split_tensors;
456+
if (src_is_local) {
457+
split_tensors = tensor->Split(shard_size, dim);
458+
CHECK_EQ(split_tensors.size(), static_cast<size_t>(world_size_));
459+
}
460+
461+
std::vector<int64_t> shard_dims = tensor->Dims();
462+
shard_dims[dim] = shard_size;
463+
const DataType shard_dtype = tensor->Dtype();
464+
465+
std::vector<std::shared_ptr<Tensor>> outputs;
466+
outputs.reserve(devices_.size());
467+
for (const auto &d : devices_) { outputs.push_back(std::make_shared<Tensor>(shard_dims, shard_dtype, d)); }
468+
469+
core::CclGroupGuard ccl_group_guard(devices_[0].type());
470+
471+
if (src_is_local) {
472+
const size_t src_local_idx = static_cast<size_t>(src_local_it - local_group_ranks.begin());
473+
const auto &src_device = devices_[src_local_idx];
474+
core::DeviceGuard guard(src_device);
475+
auto *stream = runtime_impl_->GetStream(src_device);
476+
auto *comm = device_comm_map_.at(src_device.index());
477+
for (int dst = 0; dst < world_size_; ++dst) {
478+
if (dst == src_rank) {
479+
continue;
480+
}
481+
ccl_impl_->Send(split_tensors[dst]->DataPtr(), split_tensors[dst]->NumElements(), shard_dtype, dst, comm,
482+
stream);
483+
}
484+
}
485+
486+
for (size_t i = 0; i < devices_.size(); ++i) {
487+
const auto &local_device = devices_[i];
488+
const int local_rank = local_group_ranks[i];
489+
if (src_is_local && local_rank == src_rank) {
490+
outputs[i]->CopyFrom(split_tensors[src_rank]);
491+
continue;
492+
}
493+
core::DeviceGuard guard(local_device);
494+
auto *stream = runtime_impl_->GetStream(local_device);
495+
auto *comm = device_comm_map_.at(local_device.index());
496+
ccl_impl_->Recv(outputs[i]->DataPtr(), outputs[i]->NumElements(), shard_dtype, src_rank, comm, stream);
357497
}
358498
return outputs;
359499
}

0 commit comments

Comments
 (0)