@@ -86,7 +86,7 @@ void ParamAndGradBucket::ScaleGradients(float scaling_factor) {
8686
8787 // FIXME(zbl): should perform in-place multiply
8888 // grad_data_ *= scaling_factor;
89- LOG (FATAL) << " ParamAndGradBucket: Should not arrive here " ;
89+ LOG (FATAL) << " ParamAndGradBuffer::ScaleGradients(): Inplace multiply not implemented yet. " ;
9090}
9191
9292ParamAndGradBucketGroup::ParamAndGradBucketGroup (const std::vector<std::shared_ptr<ParamAndGradBucket>> &buckets,
@@ -107,8 +107,7 @@ ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vector<std::shared_p
107107 }
108108 if (rank_in_collective_pg_ == -1 ) {
109109 auto param = *params_.begin ();
110- // FIXME(zbl): get correct rank in multi-node settings
111- rank_in_collective_pg_ = collective_pg_->GetGroupRank (param->GetDevice ().Rank ().thread_rank ());
110+ rank_in_collective_pg_ = collective_pg_->GetGroupRank (param->GetDevice ().Rank ().GlobalRank ());
112111 }
113112
114113 param_buffer_shard_list_.resize (buckets_.size ());
@@ -168,9 +167,7 @@ void ParamAndGradBucketGroup::RegisterGradReady(const std::shared_ptr<Tensor> &p
168167 // TODO(zbl): check this if sync is only done in last mircobatch
169168 // if (!inserted) {
170169 // LOG(FATAL) << "ParamAndGradBucketGroup: RegisterGradReady() was called twice for the same parameter in a
171- // "
172- // "bucket group.";
173- // return;
170+ // bucket group."; return;
174171 // }
175172
176173 if (params_with_grad_.size () == params_.size ()) {
@@ -304,7 +301,7 @@ void ParamAndGradBucketGroup::StartGradSync() {
304301 }
305302
306303 grad_reduce_dispatched_ = true ;
307- // FIXME (zbl): no need to clear params_with_grad_ here if grad sync is only done on last microbatch
304+ // TODO (zbl): no need to clear params_with_grad_ here if grad sync is only done on last microbatch
308305 params_with_grad_.clear ();
309306}
310307
@@ -637,7 +634,7 @@ void ParamAndGradBuffer::ScaleGradients(float scaling_factor) {
637634
638635 // FIXME(zbl): should perform in-place multiply
639636 // grad_data_ *= scaling_factor;
640- LOG (FATAL) << " Should not arrive here " ;
637+ LOG (FATAL) << " ParamAndGradBuffer::ScaleGradients(): Inplace multiply not implemented yet. " ;
641638}
642639
643640void ParamAndGradBuffer::Reset (bool need_rebind) {
0 commit comments