Commit 4a2527f

feat: Support ZeRO-2 based on DistributedOptimizer

1 parent 45c11cb commit 4a2527f

11 files changed: +391 additions, -32 deletions

example/gpt2/main.cc

Lines changed: 6 additions & 3 deletions
@@ -52,6 +52,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
 // optimization
 DEFINE_double(learning_rate, 1e-4, "learning rate");
 DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer (only takes effect when DP>1)");
+DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only takes effect when use_distributed_optimizer=true)");
 // evaluation
 DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
 DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
@@ -106,6 +107,7 @@ const std::unordered_map<std::string, GPT2::ModelType> kStrToModelType = {
 DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
 DEFINE_validator(device,
                  [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
+DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 1 && value <= 3; });

 void Train(const nn::parallel::Rank &rank) {
     using namespace nn::parallel;
@@ -211,8 +213,8 @@ void Train(const nn::parallel::Rank &rank) {
             model, pp_world_size, num_micro_batches, shapes, pp_rank, rank.thread_rank(),
             std::dynamic_pointer_cast<GPT2>(model)->GetChunkSize());
         if (ddp_world_size > 1) {
-            auto ddp_config
-                = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
+            auto ddp_config = DistributedDataParallelConfig{
+                .use_distributed_optimizer = FLAGS_use_distributed_optimizer, .zero_stage = FLAGS_zero_stage};
             auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
             for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
                 (*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
@@ -224,7 +226,8 @@ void Train(const nn::parallel::Rank &rank) {
     // before wrapping the model with DistributedDataParallel (DDP).
     // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
    // are created during the conversion.
-    auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
+    auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer,
+                                                    .zero_stage = FLAGS_zero_stage};
     model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
 }

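The two new flags above are the user-facing switch for this feature. A sketch of how ZeRO-2 would be enabled from the command line (the binary name is a placeholder; only --use_distributed_optimizer and --zero_stage are defined by this commit, and zero_stage only takes effect when the DistributedOptimizer is enabled and DP > 1):

    # assumed invocation, binary name illustrative
    ./gpt2_main --use_distributed_optimizer=true --zero_stage=2
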
example/llama3/main.cc

Lines changed: 6 additions & 3 deletions
@@ -51,6 +51,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
 // optimization
 DEFINE_double(learning_rate, 1e-5, "learning rate");
 DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer (only takes effect when DP>1)");
+DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only takes effect when use_distributed_optimizer=true)");
 // evaluation
 DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
 DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
@@ -88,6 +89,7 @@ constexpr char kDtypeBF16[] = "bfloat16";
 DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
 DEFINE_validator(device,
                  [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
+DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 1 && value <= 3; });

 void Train(const nn::parallel::Rank &rank) {
     using namespace nn::parallel;
@@ -190,8 +192,8 @@ void Train(const nn::parallel::Rank &rank) {
             model, pp_world_size, num_micro_batches, shapes, pp_rank, rank.thread_rank(),
             std::dynamic_pointer_cast<LLaMA3>(model)->GetChunkSize());
         if (ddp_world_size > 1) {
-            auto ddp_config
-                = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
+            auto ddp_config = DistributedDataParallelConfig{
+                .use_distributed_optimizer = FLAGS_use_distributed_optimizer, .zero_stage = FLAGS_zero_stage};
             auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
             for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
                 (*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
@@ -204,7 +206,8 @@ void Train(const nn::parallel::Rank &rank) {
     // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
     // are created during the conversion.

-    auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
+    auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer,
+                                                    .zero_stage = FLAGS_zero_stage};
     model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
 }

infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h

Lines changed: 7 additions & 1 deletion
@@ -40,6 +40,12 @@ class DistributedDataParallelConfig {
     // In this case, grad reduce is triggered immediately when a grad is ready or till all grads are ready.
     bool overlap_grad_reduce = true;

+    // ZeRO-DP stage for memory optimization (only takes effect when use_distributed_optimizer=true)
+    // ZeRO-1: optimizer state partitioning (the default)
+    // ZeRO-2: gradient partitioning
+    // ZeRO-3: parameter partitioning
+    int zero_stage = 1;
+
     // Whether to overlap parameter all-gather with forward compute.
     bool overlap_param_gather = true;

@@ -59,7 +65,7 @@ class DistributedDataParallelConfig {
     // Maximum number of parameters in each ParamAndGradBucket.
     // NOTE(zbl): This is distinct from DDP Reducer's MB-based bucket caps.
     // TODO(zbl): Unify the definition of the bucket_size argument for users.
-    size_t bucket_size_in_elements = 40000000;
+    size_t bucket_size_in_elements = 1000000;

     // Whether to pad bucket sizes to improve NCCL bus bandwidth utilization.
     bool pad_buckets_for_high_nccl_busbw = false;
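For reference, a minimal sketch of selecting ZeRO-2 through this config; both fields are from this header, while the construction around them is illustrative:

    // Request gradient partitioning (ZeRO-2) on top of the DistributedOptimizer.
    auto cfg = DistributedDataParallelConfig{};
    cfg.use_distributed_optimizer = true;  // required, otherwise zero_stage is ignored
    cfg.zero_stage = 2;  // 1: shard optimizer states; 2: also shard grads; 3: also shard params (not implemented yet)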

infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h

Lines changed: 24 additions & 2 deletions
@@ -22,8 +22,8 @@ namespace infini_train::nn::parallel {
 class ParamAndGradBucket {
 public:
     ParamAndGradBucket(const std::vector<std::shared_ptr<Tensor>> &params, const std::shared_ptr<Tensor> &param_data,
-                       const std::shared_ptr<Tensor> &grad_data, size_t offset, size_t num_elements_unpadded,
-                       float gradient_scaling_factor, size_t bucket_id);
+                       DataType param_dtype, const std::shared_ptr<Tensor> &grad_data, DataType grad_dtype,
+                       size_t offset, size_t num_elements_unpadded, float gradient_scaling_factor, size_t bucket_id);

     size_t bucket_id() const { return bucket_id_; }

@@ -33,6 +33,10 @@ class ParamAndGradBucket {

     const std::shared_ptr<Tensor> &grad_data() const { return grad_data_; }

+    DataType param_dtype() const { return param_dtype_; }
+
+    DataType grad_dtype() const { return grad_dtype_; }
+
     size_t offset() const { return offset_; }

     size_t num_elements_unpadded() const { return num_elements_unpadded_; }
@@ -49,6 +53,8 @@ class ParamAndGradBucket {
     std::vector<std::shared_ptr<Tensor>> params_;
     std::shared_ptr<Tensor> param_data_;
     std::shared_ptr<Tensor> grad_data_;
+    DataType param_dtype_;
+    DataType grad_dtype_;

     size_t offset_ = 0;
     size_t num_elements_unpadded_ = 0;
@@ -73,6 +79,11 @@ class ParamAndGradBucketGroup {
     // Start grad reduce
     void StartGradSync();

+    // Accumulate a parameter grad into the bucket buffer
+    // ZeRO-2: Use this function to take over autograd::AccumulateGrad::Backward
+    void AccumulateParamGrad(const std::shared_ptr<Tensor> &parameter, const std::shared_ptr<Tensor> &grad,
+                             bool overwrite, float learning_rate);
+
     // Wait for gradient reduce to complete
     void FinishGradSync();

@@ -87,6 +98,9 @@ class ParamAndGradBucketGroup {

     const std::vector<std::shared_ptr<ParamAndGradBucket>> &buckets() const { return buckets_; }

+    // ZeRO-2: Get a bucket's local grad shard buffer
+    std::shared_ptr<Tensor> GetLocalGradShardBuffer(size_t bucket_idx) const;
+
     const DistributedDataParallelConfig &config() const { return ddp_config_; }

 private:
@@ -98,12 +112,20 @@ class ParamAndGradBucketGroup {

     std::unordered_set<Tensor *> params_;
     std::unordered_set<Tensor *> params_with_grad_;
+    // Tensor -> (Bucket, Bucket Index)
+    std::unordered_map<Tensor *, std::pair<std::shared_ptr<ParamAndGradBucket>, size_t>> param_to_bucket_;

     // TODO(zbl): Implement CoalescedWork for aggregated works,
     // following Megatron-LM's _coalescing_manager
     std::vector<std::shared_ptr<Work>> grad_reduce_work_list_;
+    std::vector<size_t> grad_reduce_bucket_indices_;
     std::vector<std::shared_ptr<Work>> param_gather_work_list_;

+    // ZeRO-2: persistent grad shard buffers and temporary full grad buffers
+    std::vector<std::shared_ptr<Tensor>> grad_shard_buffer_list_;
+    std::vector<std::shared_ptr<Tensor>> temp_full_grad_buffer_list_;
+    std::vector<bool> temp_full_grad_initialized_;
+
     std::shared_ptr<ParamAndGradBucketGroup> next_param_gather_bucket_group_ = nullptr;

     std::vector<std::vector<std::shared_ptr<Tensor>>> param_buffer_shard_list_;
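Taken together, the new members sketch the ZeRO-2 data path: each bucket gets a temporary full-size grad buffer for accumulation and a persistent shard-size buffer that keeps only this rank's slice after reduction. A rough reading of the intended flow, inferred from these declarations (the implementation file is not part of this excerpt, so this is an assumption, not the committed code):

    // Inferred per-bucket flow under ZeRO-2 (sketch only):
    // 1. AccumulateParamGrad() stages each param's grad into
    //    temp_full_grad_buffer_list_[bucket_idx]; the first write (or an overwrite
    //    request) copies, later grads accumulate, tracked via temp_full_grad_initialized_.
    // 2. StartGradSync() reduce-scatters the full buffer across the DP group so that
    //    grad_shard_buffer_list_[bucket_idx] ends up holding only this rank's slice.
    // 3. GetLocalGradShardBuffer(bucket_idx) exposes that slice to the
    //    DistributedOptimizer, so no rank has to retain the full reduced gradient.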

infini_train/include/tensor.h

Lines changed: 8 additions & 0 deletions
@@ -227,6 +227,12 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
     std::shared_ptr<autograd::AccumulateGrad> grad_accumulator();
     void ResetAccumulator();

+    // ZeRO-2: Use this function to take over AccumulateGrad::Backward
+    using GradAccumulateBypass
+        = std::function<bool(const std::shared_ptr<Tensor> &grad_output, bool overwrite, float learning_rate)>;
+    GradAccumulateBypass grad_accumulate_bypass();
+    void SetGradAccumulateBypass(GradAccumulateBypass);
+
     void RegisterPostAccumulateGradHook(std::shared_ptr<autograd::PostAccumulateGradHook> hook);

     autograd::PostAccumulateGradHook *post_accumulate_grad_hook() const;
@@ -241,6 +247,8 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
     // a strong reference to the accumulator to manage its lifetime.
     std::shared_ptr<autograd::AccumulateGrad> grad_accumulator_ = nullptr;
     std::shared_ptr<autograd::PostAccumulateGradHook> post_accumulate_grad_hook_ = nullptr;
+    // ZeRO-2: Use this function to take over AccumulateGrad::Backward
+    GradAccumulateBypass grad_accumulate_bypass_ = nullptr;

     bool grad_overwrite_once_ = false;
 };
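The bypass is the hook that lets DDP reroute gradients away from param->grad() under ZeRO-2. A minimal sketch of the contract (it mirrors the real usage in distributed_data_parallel.cc below; the lambda body here is illustrative):

    // Return true to signal the grad was consumed out-of-band, making
    // AccumulateGrad::Backward return early; return false to fall back to the
    // normal accumulation into param->grad().
    param->SetGradAccumulateBypass(
        [](const std::shared_ptr<Tensor> &grad_output, bool overwrite, float learning_rate) {
            // e.g. stage grad_output into a shared bucket buffer here
            return true;
        });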

infini_train/src/autograd/accumulate.cc

Lines changed: 9 additions & 1 deletion
@@ -25,8 +25,16 @@ AccumulateGrad::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_output
     device->SetDevice();

     if (grad_output) {
+        const bool overwrite = tensor_->ConsumeGradOverwriteFlag();
+        // ZeRO-2: Use a bypass function to perform grad accumulation in the temp full grad buffer
+        auto bypass = tensor_->grad_accumulate_bypass();
+        if (bypass && bypass(grad_output, overwrite, learning_rate_)) {
+            tensor_->ResetAccumulator();
+            return {};
+        }
+
         if (grad) {
-            if (tensor_->ConsumeGradOverwriteFlag()) {
+            if (overwrite) {
                 // If the tensor is marked to overwrite its current grad on the next grad update
                 // See notes in `infini_train::nn::parallel::Reducer::PrepareForBackward()`
                 // NOTE(zbl): must copy, cannot change grad buffer address

infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc

Lines changed: 39 additions & 1 deletion
@@ -24,6 +24,17 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> mod
                                                  const DistributedDataParallelConfig ddp_config)
     : ddp_config_(ddp_config),
       ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(thread_rank))) {
+    CHECK(ddp_config_.zero_stage >= 1 && ddp_config_.zero_stage <= 3)
+        << "DistributedDataParallel: zero_stage must be in 1/2/3.";
+    if (ddp_config_.zero_stage >= 3) {
+        LOG(FATAL) << "DistributedDataParallel: ZeRO-3 is not implemented yet.";
+    }
+    if (!ddp_config_.use_distributed_optimizer && ddp_config_.zero_stage > 1) {
+        LOG(WARNING) << "DistributedDataParallel: zero_stage is ignored because "
+                        "use_distributed_optimizer is false.";
+        ddp_config_.zero_stage = 1;
+    }
+
     for (auto &param : module->Parameters()) {
         auto device = param->GetDevice();
         CHECK_EQ(device->Index(), thread_rank) << "All parameters must be on the same device as the module";
@@ -79,6 +90,7 @@ void DistributedDataParallel::BuildParamAndGradBuffers() {
             continue;
         }

+        // At this point, zero_stage is already aligned with use_distributed_optimizer.
         auto buffer = std::make_shared<ParamAndGradBuffer>(param_list, param_dtype, grad_dtype, ddp_pg_, ddp_config_);

         param_grad_buffers_.push_back(buffer);
@@ -112,6 +124,32 @@ void DistributedDataParallel::BuildParamAndGradBuffers() {
 }

 void DistributedDataParallel::RegisterBackwardHooks() {
+    if (ddp_config_.zero_stage >= 2) {
+        auto &module = modules_.at(kModuleName);
+        for (auto &param : module->Parameters()) {
+            if (!param->requires_grad()) {
+                continue;
+            }
+            auto it = param_to_bucket_group_.find(param.get());
+            if (it == param_to_bucket_group_.end()) {
+                continue;
+            }
+            std::weak_ptr<ParamAndGradBucketGroup> weak_group = it->second;
+            param->SetGradAccumulateBypass(
+                [weak_group, param](const std::shared_ptr<Tensor> &grad_output, bool overwrite, float learning_rate) {
+                    if (auto group = weak_group.lock()) {
+                        group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate);
+                        if (group->config().overlap_grad_reduce) {
+                            group->RegisterGradReady(param);
+                        }
+                        return true;
+                    }
+                    return false;
+                });
+        }
+        return;
+    }
+
     class DDPPostAccumulateHook final : public autograd::PostAccumulateGradHook {
     public:
         DDPPostAccumulateHook(DistributedDataParallel *ddp, const std::weak_ptr<Tensor> param)
@@ -143,7 +181,7 @@ void DistributedDataParallel::OnGradReady(const std::shared_ptr<Tensor> &param)
     auto it = param_to_bucket_group_.find(param.get());
     if (it != param_to_bucket_group_.end()) {
         CHECK(param->requires_grad());
-        if (ddp_config_.overlap_grad_reduce) {
+        if (ddp_config_.overlap_grad_reduce && (ddp_config_.zero_stage < 2)) {
            CHECK(param->grad()) << "param.grad being None is not safe when overlap_grad_reduce is True";
        }

infini_train/src/nn/parallel/ddp/distributed_optimizer.cc

Lines changed: 14 additions & 3 deletions
@@ -35,10 +35,13 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() {
     shard_params_.clear();

     for (const auto &group : bucket_groups_) {
-        for (const auto &bucket : group->buckets()) {
+        const bool use_grad_shard = group->config().zero_stage >= 2;
+        const auto &buckets = group->buckets();
+        for (size_t bucket_idx = 0; bucket_idx < buckets.size(); ++bucket_idx) {
+            const auto &bucket = buckets[bucket_idx];

             auto bucket_param = bucket->param_data();
-            auto bucket_grad = bucket->grad_data();
+            auto bucket_grad = use_grad_shard ? group->GetLocalGradShardBuffer(bucket_idx) : bucket->grad_data();

             CHECK(bucket_param) << "DistributedOptimizer requires param buffer.";
             CHECK(bucket_grad) << "DistributedOptimizer requires grad buffer.";
@@ -65,7 +68,9 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() {
             CHECK_GT(piece_numel, 0);

             const size_t param_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_param->Dtype());
-            const size_t grad_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_grad->Dtype());
+            // Adjust the offset, since under ZeRO-2 bucket_grad is already the local shard of the grad.
+            auto offset = use_grad_shard ? (local_start - bucket_shard_start) : local_start;
+            size_t grad_piece_offset_bytes = offset * kDataTypeToSize.at(bucket_grad->Dtype());

             auto param_piece = std::make_shared<Tensor>(*bucket_param, param_piece_offset_bytes,
                                                         std::vector<int64_t>{static_cast<int64_t>(piece_numel)});
@@ -74,6 +79,12 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() {
                                                         std::vector<int64_t>{static_cast<int64_t>(piece_numel)});

             param_piece->set_grad(grad_piece);
+            // if (use_grad_shard) {
+            //     // NOTE(zbl): Under ZeRO-2, param->grad() is the shard of the grad, not the full grad.
+            //     // The binding is done in the constructor of DistributedOptimizer.
+            //     // Not until backward is finished will the value of param->grad() be updated.
+            //     param->set_grad(grad_piece);
+            // }
             shard_params_.push_back(param_piece);
         }
     }
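
The offset adjustment is easiest to see with concrete numbers (illustrative, not from the commit): take DP = 4 and a bucket of 1,000,000 elements, so each rank's grad shard holds 250,000 elements and rank 1 owns [250000, 500000).

    // For a param piece on rank 1 with local_start = 300000, bucket_shard_start = 250000:
    //   ZeRO-1: offset into the full grad buffer      = local_start = 300000
    //   ZeRO-2: offset into the rank-local grad shard = local_start - bucket_shard_start
    //                                                 = 300000 - 250000 = 50000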
