File tree Expand file tree Collapse file tree
infini_train/src/nn/parallel/ddp Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -157,18 +157,20 @@ void ParamAndGradBucketGroup::RegisterGradReady(const std::shared_ptr<Tensor> &p
157157 return ;
158158 }
159159
160- // Only register grads as ready when processing the last microbatch
160+ // TODO(zbl): Only register grads as ready and trigger grad sync when processing the last microbatch
161+ // For now, is_last_microbatch_ is always true
161162 if (is_last_microbatch_) {
162163 if (!parameter || params_.find (parameter.get ()) == params_.end ()) {
163164 return ;
164165 }
165166
166167 const bool inserted = params_with_grad_.insert (parameter.get ()).second ;
167- if (!inserted) {
168- LOG (FATAL ) << " ParamAndGradBucketGroup: RegisterGradReady() was called twice for the same parameter in a "
169- " bucket group." ;
170- return ;
171- }
168+ // TODO(zbl): check this if sync is only done in last mircobatch
169+ // if (!inserted) {
170+ // LOG(FATAL) << "ParamAndGradBucketGroup: RegisterGradReady() was called twice for the same parameter in a "
171+ // "bucket group.";
172+ // return;
173+ // }
172174
173175 if (params_with_grad_.size () == params_.size ()) {
174176 // All param grads are ready in this group, trigger grad sync
@@ -301,6 +303,8 @@ void ParamAndGradBucketGroup::StartGradSync() {
301303 }
302304
303305 grad_reduce_dispatched_ = true ;
306+ // FIXME(zbl): no need to clear params_with_grad_ here if grad sync is only done on last microbatch
307+ params_with_grad_.clear ();
304308}
305309
306310void ParamAndGradBucketGroup::FinishGradSync () {
Original file line number Diff line number Diff line change @@ -267,8 +267,8 @@ for ((id=0; id<num_builds; ++id)); do
267267 arg_str=" $( args_string_for_test " $gi " " $ti " ) "
268268
269269 # gpt2
270- gpt2_cmd=" ${prefix} ./gpt2 --input_bin ${GPT2_INPUT_BIN} --llmc_filepath ${GPT2_LLMC_FILEPATH} --device cuda ${arg_str} "
271- run_and_log " $gpt2_cmd " " gpt2_${test_id}${log_suffix} " " $profile_flag " " $group_tag "
270+ # gpt2_cmd="${prefix}./gpt2 --input_bin ${GPT2_INPUT_BIN} --llmc_filepath ${GPT2_LLMC_FILEPATH} --device cuda ${arg_str}"
271+ # run_and_log "$gpt2_cmd" "gpt2_${test_id}${log_suffix}" "$profile_flag" "$group_tag"
272272
273273 # llama3
274274 llama3_cmd=" ${prefix} ./llama3 --input_bin ${LLAMA3_INPUT_BIN} --llmc_filepath ${LLAMA3_LLMC_FILEPATH} --device cuda ${arg_str} "
Original file line number Diff line number Diff line change 1414 "id" : " build_1" ,
1515 "profile" : false ,
1616 "cmd" : " cmake -DUSE_CUDA=ON -DUSE_NCCL=ON .. && make -j"
17- },
18- {
19- "id" : " build_2" ,
20- "profile" : true ,
21- "cmd" : " cmake -DUSE_CUDA=ON -DUSE_NCCL=ON -DPROFILE_MODE=ON .. && make -j"
2217 }
2318 ],
2419 "test_groups" : [
You can’t perform that action at this time.
0 commit comments