@@ -9,10 +9,10 @@ namespace infini_train::nn::parallel {
 DistributedOptimizer::DistributedOptimizer(OptimizerCreator creator,
                                            const std::vector<std::shared_ptr<Tensor>> &full_params,
                                            const std::vector<std::shared_ptr<Module>> &model_chunks,
-                                           size_t dp_world_size, size_t dp_rank)
-    : Optimizer(full_params), dp_world_size_(dp_world_size), dp_rank_(dp_rank) {
+                                           size_t ddp_world_size, size_t ddp_rank)
+    : Optimizer(full_params), ddp_world_size_(ddp_world_size), ddp_rank_(ddp_rank) {
 
-    CHECK(dp_world_size_ > 1) << "DistributedOptimizer: dp_world_size must be greater than 1.";
+    CHECK(ddp_world_size_ > 1) << "DistributedOptimizer: ddp_world_size must be greater than 1.";
 
     for (size_t i = 0; i < model_chunks.size(); ++i) {
         auto ddp_chunk = std::dynamic_pointer_cast<DistributedDataParallel>(model_chunks[i]);
@@ -43,9 +43,9 @@ void DistributedOptimizer::BuildShardParamsAndBindGrads() {
     CHECK(bucket_param) << "DistributedOptimizer requires param buffer.";
     CHECK(bucket_grad) << "DistributedOptimizer requires grad buffer.";
 
-    CHECK_EQ(bucket_param->NumElements() % dp_world_size_, 0);
-    const size_t bucket_shard_numel = bucket_param->NumElements() / dp_world_size_;
-    const size_t bucket_shard_start = dp_rank_ * bucket_shard_numel;
+    CHECK_EQ(bucket_param->NumElements() % ddp_world_size_, 0);
+    const size_t bucket_shard_numel = bucket_param->NumElements() / ddp_world_size_;
+    const size_t bucket_shard_start = ddp_rank_ * bucket_shard_numel;
     const size_t bucket_shard_end = bucket_shard_start + bucket_shard_numel;
 
     // Iterate over params in the bucket and build each param (or param_shard) separately
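
As a minimal standalone sketch of the shard-range arithmetic in the hunk above (the world size and bucket size below are illustrative assumptions, not values taken from the patch): each rank owns one contiguous, equally sized slice of the flattened bucket buffer, so the shard boundaries follow directly from ddp_rank and ddp_world_size.

#include <cstddef>
#include <cstdio>

int main() {
    const size_t ddp_world_size = 4;   // assumed example value
    const size_t bucket_numel = 1024;  // assumed bucket size, divisible by ddp_world_size
    const size_t bucket_shard_numel = bucket_numel / ddp_world_size;  // 256 elements per rank
    for (size_t ddp_rank = 0; ddp_rank < ddp_world_size; ++ddp_rank) {
        const size_t bucket_shard_start = ddp_rank * bucket_shard_numel;
        const size_t bucket_shard_end = bucket_shard_start + bucket_shard_numel;
        // Prints the half-open element range owned by each rank: [0,256), [256,512), ...
        std::printf("rank %zu owns [%zu, %zu)\n", ddp_rank, bucket_shard_start, bucket_shard_end);
    }
    return 0;
}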