
Commit 791c75e

Chamberlain0w0 authored and kilinchange committed
fix: fix rank argument in ddp multi-node training
1 parent 733ad19 commit 791c75e

4 files changed: 14 additions & 11 deletions
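Why the signature change matters: DistributedDataParallel used to receive only rank.thread_rank(), the process's local index on its node. On a single node that coincides with the global rank, but on multiple nodes every node has a worker with thread rank 0, 1, ..., so any lookup keyed by it collides across nodes. Passing the whole Rank lets the constructor choose the right view per use: GlobalRank() to name the data-parallel process group, thread_rank() to validate local device placement. A minimal sketch of the distinction, assuming per-node bookkeeping inside Rank (the class below is illustrative, not the project's implementation; only the thread_rank()/GlobalRank() accessor names come from this commit):

    #include <iostream>

    // Illustrative stand-in for nn::parallel::Rank (fields and constructor assumed).
    class Rank {
    public:
        Rank(int node_rank, int thread_rank, int threads_per_node)
            : node_rank_(node_rank), thread_rank_(thread_rank), threads_per_node_(threads_per_node) {}
        int thread_rank() const { return thread_rank_; }  // local index on this node
        int GlobalRank() const { return node_rank_ * threads_per_node_ + thread_rank_; }  // unique job-wide
    private:
        int node_rank_, thread_rank_, threads_per_node_;
    };

    int main() {
        // Worker 1 on each of two 8-worker nodes:
        Rank a{/*node_rank=*/0, /*thread_rank=*/1, /*threads_per_node=*/8};
        Rank b{/*node_rank=*/1, /*thread_rank=*/1, /*threads_per_node=*/8};
        std::cout << a.thread_rank() << " " << b.thread_rank() << "\n"; // 1 1 -> collides as a key
        std::cout << a.GlobalRank() << " " << b.GlobalRank() << "\n";   // 1 9 -> unique per process
    }

With 8 workers per node, node 1's worker 1 reports thread rank 1 but global rank 9, which is why the process-group lookup must use the global value.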


example/gpt2/main.cc (3 additions & 3 deletions)

@@ -220,8 +220,8 @@ void Train(const nn::parallel::Rank &rank) {
                 = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
             auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
             for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
-                (*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
-                                                                                        rank.thread_rank(), ddp_config);
+                (*mutable_chunks)[chunk_id]
+                    = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank, ddp_config);
             }
         }
     } else if (ddp_world_size > 1) {
@@ -230,7 +230,7 @@ void Train(const nn::parallel::Rank &rank) {
         // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
         // are created during the conversion.
         auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
-        model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
+        model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
     }

     DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
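The comment about hook loss deserves a note: DDP registers its gradient hooks on the parameter tensors that exist at wrap time, and dtype/device conversions typically create new tensors rather than mutating in place, so wrapping before converting silently discards the hooks. A toy sketch of that failure mode, with an invented FakeTensor standing in for a real parameter:

    #include <functional>
    #include <iostream>
    #include <memory>

    // Invented stand-in for a parameter tensor carrying a gradient hook.
    struct FakeTensor {
        std::function<void()> grad_hook;
    };

    int main() {
        auto param = std::make_shared<FakeTensor>();
        param->grad_hook = [] { std::cout << "all-reduce would fire\n"; }; // set at DDP wrap time

        // A conversion that rebuilds the parameter (as casts and device moves
        // commonly do) replaces the tensor, and the hook goes with the old one:
        param = std::make_shared<FakeTensor>();

        if (!param->grad_hook) {
            std::cout << "hook lost -> convert first, wrap with DDP last\n";
        }
    }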

example/llama3/main.cc (3 additions & 3 deletions)

@@ -199,8 +199,8 @@ void Train(const nn::parallel::Rank &rank) {
                 = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
             auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
             for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
-                (*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
-                                                                                        rank.thread_rank(), ddp_config);
+                (*mutable_chunks)[chunk_id]
+                    = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank, ddp_config);
             }
         }
     } else if (ddp_world_size > 1) {
@@ -210,7 +210,7 @@ void Train(const nn::parallel::Rank &rank) {
         // are created during the conversion.

         auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
-        model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
+        model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
     }

     DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),

infini_train/include/nn/parallel/ddp/distributed_data_parallel.h (2 additions & 1 deletion)

@@ -11,14 +11,15 @@ class Tensor;
 class Device;
 namespace nn::parallel {
 class DistributedDataParallelConfig;
+class Rank;
 } // namespace nn::parallel
 } // namespace infini_train

 namespace infini_train::nn::parallel {

 class DistributedDataParallel : public nn::Module {
 public:
-    DistributedDataParallel(std::shared_ptr<nn::Module> module, int thread_rank,
+    DistributedDataParallel(std::shared_ptr<nn::Module> module, const Rank &rank,
                             DistributedDataParallelConfig ddp_config);

     std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
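Aside on the header change: Rank is only forward-declared rather than pulled in via rank.h, since a const Rank & parameter in a declaration needs just an incomplete type; the full definition is included only by the .cc file, which actually calls rank.GlobalRank() and rank.thread_rank(). A compilable illustration of the pattern (the stand-in class name is invented; namespaces abbreviated relative to the real header):

    namespace infini_train::nn::parallel { class Rank; } // forward declaration only

    namespace infini_train::nn::parallel {
    class DdpDecl { // stand-in for the real DistributedDataParallel declaration
    public:
        explicit DdpDecl(const Rank &rank); // fine: Rank may be incomplete here
    };
    } // namespace infini_train::nn::parallel

    int main() {} // the point is that this translation unit compiles without rank.h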

infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc (6 additions & 4 deletions)

@@ -11,6 +11,7 @@
 #include "infini_train/include/nn/modules/module.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
 #include "infini_train/include/nn/parallel/process_group.h"
+#include "infini_train/include/nn/parallel/rank.h"
 #include "infini_train/include/nn/parallel/utils.h"
 #include "infini_train/include/tensor.h"

@@ -19,21 +20,22 @@ namespace {
 constexpr char kModuleName[] = "module";
 } // namespace

-DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> module, int thread_rank,
+DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> module, const Rank &rank,
                                                  const DistributedDataParallelConfig ddp_config)
     : ddp_config_(ddp_config),
-      ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(thread_rank))) {
+      ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(rank.GlobalRank()))) {
     for (auto &param : module->Parameters()) {
         auto device = param->GetDevice();
-        CHECK_EQ(device.index(), thread_rank) << "All parameters must be on the same device as the module";
+        CHECK_EQ(device.index(), rank.thread_rank()) << "All parameters must be on the same device as the module";
         if (!ddp_config.gradient_bucketing_enabled && !ddp_config.use_distributed_optimizer) {
             auto hook = std::make_unique<infini_train::autograd::AllReducePostAccumulateHook>(
                 function::ReduceOpType::kAvg, ddp_pg_);
             param->RegisterPostAccumulateGradHook(std::move(hook));
         }
     }
     for (auto &buffer : module->Buffers()) {
-        CHECK_EQ(buffer->GetDevice().index(), thread_rank) << "All buffers must be on the same device as the module";
+        CHECK_EQ(buffer->GetDevice().index(), rank.thread_rank())
+            << "All buffers must be on the same device as the module";
     }
     modules_[kModuleName] = std::move(module);
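For orientation, the constructor loop above follows the classic DDP recipe: when gradient bucketing and the distributed optimizer are both disabled, every parameter gets a post-accumulate hook that averages its gradient across the data-parallel group as soon as it is ready. A toy sketch of that mechanism, with invented types standing in for Tensor, the hook, and the process group:

    #include <functional>
    #include <iostream>
    #include <vector>

    // Invented stand-in: a parameter whose hook fires once its gradient is accumulated.
    struct FakeParam {
        float grad = 0.0f;
        std::function<void(FakeParam &)> post_accumulate_hook;
        void AccumulateGrad(float g) {
            grad += g;
            if (post_accumulate_hook) {
                post_accumulate_hook(*this); // analogue of the registered hook firing
            }
        }
    };

    int main() {
        std::vector<FakeParam> params(2);
        for (auto &p : params) {
            p.post_accumulate_hook = [](FakeParam &param) {
                // Stand-in for AllReducePostAccumulateHook with ReduceOpType::kAvg:
                // a real implementation launches all-reduce(avg) over ddp_pg_ here.
                std::cout << "all-reduce(avg) on grad " << param.grad << "\n";
            };
        }
        params[0].AccumulateGrad(0.5f); // the backward pass would trigger these
        params[1].AccumulateGrad(1.0f);
    }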
