Skip to content

Commit 477a198

Browse files
fix: fix some descriptions in comments
1 parent e9610d9 commit 477a198

2 files changed

Lines changed: 9 additions & 15 deletions

File tree

infini_train/src/nn/parallel/ddp/distributed_optimizer.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,9 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() {
7979
std::vector<int64_t>{static_cast<int64_t>(piece_numel)});
8080

8181
param_piece->set_grad(grad_piece);
82-
// if (use_grad_shard) {
83-
// // NOTE(zbl): Under ZeRO-2, param->grad() is the shard of grad, not the full grad.
84-
// // The binding is done in the construnctor of DistributedOptimizer.
85-
// // Not until backward is finished, the value of param->grad() will be updated.
86-
// param->set_grad(grad_piece);
87-
// }
82+
// NOTE(zbl): Do not call `param->set_grad(grad_piece);` under ZeRO-2.
83+
// The base optimizer updates param_piece views only; original param->grad()
84+
// would be a partial flattened shard and does not represent the full parameter grad.
8885
shard_params_.push_back(param_piece);
8986
}
9087
}
@@ -135,7 +132,7 @@ void DistributedOptimizer::Step() {
135132

136133
// 3. Gather updated param shards back to full params
137134
StartParamSync(/*force_sync=*/false);
138-
// FIXME(zbl): Call sync before param is actually used in next step
135+
// TODO(zbl): Delay sync call until param is actually used in next step
139136
FinishParamSync(/*skip_next_bucket_dispatch=*/true);
140137
}
141138

infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ void ParamAndGradBucket::ScaleGradients(float scaling_factor) {
8686

8787
// FIXME(zbl): should perform in-place multiply
8888
// grad_data_ *= scaling_factor;
89-
LOG(FATAL) << "ParamAndGradBucket: Should not arrive here";
89+
LOG(FATAL) << "ParamAndGradBuffer::ScaleGradients(): Inplace multiply not implemented yet.";
9090
}
9191

9292
ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vector<std::shared_ptr<ParamAndGradBucket>> &buckets,
@@ -107,8 +107,7 @@ ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vector<std::shared_p
107107
}
108108
if (rank_in_collective_pg_ == -1) {
109109
auto param = *params_.begin();
110-
// FIXME(zbl): get correct rank in multi-node settings
111-
rank_in_collective_pg_ = collective_pg_->GetGroupRank(param->GetDevice().Rank().thread_rank());
110+
rank_in_collective_pg_ = collective_pg_->GetGroupRank(param->GetDevice().Rank().GlobalRank());
112111
}
113112

114113
param_buffer_shard_list_.resize(buckets_.size());
@@ -168,9 +167,7 @@ void ParamAndGradBucketGroup::RegisterGradReady(const std::shared_ptr<Tensor> &p
168167
// TODO(zbl): check this if sync is only done in last mircobatch
169168
// if (!inserted) {
170169
// LOG(FATAL) << "ParamAndGradBucketGroup: RegisterGradReady() was called twice for the same parameter in a
171-
// "
172-
// "bucket group.";
173-
// return;
170+
// bucket group."; return;
174171
// }
175172

176173
if (params_with_grad_.size() == params_.size()) {
@@ -304,7 +301,7 @@ void ParamAndGradBucketGroup::StartGradSync() {
304301
}
305302

306303
grad_reduce_dispatched_ = true;
307-
// FIXME(zbl): no need to clear params_with_grad_ here if grad sync is only done on last microbatch
304+
// TODO(zbl): no need to clear params_with_grad_ here if grad sync is only done on last microbatch
308305
params_with_grad_.clear();
309306
}
310307

@@ -637,7 +634,7 @@ void ParamAndGradBuffer::ScaleGradients(float scaling_factor) {
637634

638635
// FIXME(zbl): should perform in-place multiply
639636
// grad_data_ *= scaling_factor;
640-
LOG(FATAL) << "Should not arrive here";
637+
LOG(FATAL) << "ParamAndGradBuffer::ScaleGradients(): Inplace multiply not implemented yet.";
641638
}
642639

643640
void ParamAndGradBuffer::Reset(bool need_rebind) {

0 commit comments

Comments
 (0)