@@ -42,22 +42,20 @@ DiTCollectiveCommunicator::DiTCollectiveCommunicator(int32_t global_rank,
4242 int32_t dit_sp_size,
4343 int32_t dit_cfg_size)
4444 : CollectiveCommunicatorBase(global_rank, world_size) {
45- #if defined(USE_NPU)
46- DiTMappingNPU::Options dit_mapping_options;
47- dit_mapping_options.dit_tp_size (dit_tp_size)
48- .dit_sp_size (dit_sp_size)
49- .dit_cfg_size (dit_cfg_size)
50- .dit_dp_size (dit_dp_size);
51- dit_mapping_npu_ = std::make_unique<DiTMappingNPU>(
52- world_size, global_rank, dit_mapping_options);
5345 parallel_args_ = std::make_unique<ParallelArgs>(global_rank,
5446 world_size,
5547 dit_dp_size,
5648 dit_tp_size,
5749 dit_sp_size,
5850 dit_cfg_size,
59- nullptr );
60- #endif
51+ /* process_group=*/ nullptr );
52+ DiTMapping::Options dit_mapping_options;
53+ dit_mapping_options.dit_tp_size (dit_tp_size)
54+ .dit_sp_size (dit_sp_size)
55+ .dit_cfg_size (dit_cfg_size)
56+ .dit_dp_size (dit_dp_size);
57+ dit_mapping_ = std::make_unique<DiTMapping>(
58+ world_size, global_rank, dit_mapping_options);
6159}
6260
6361void DiTCollectiveCommunicator::create_process_groups (
@@ -87,14 +85,13 @@ void DiTCollectiveCommunicator::create_process_groups(
8785
8886 parallel_args_->process_group_ = process_group_.get ();
8987
90- if (tp_size > 1 ) {
91- auto tp_parallel_info = dit_mapping_npu_ ->get_parallel_info (" tp" );
88+ if (tp_size > 1 && dit_mapping_ ) {
89+ auto tp_parallel_info = dit_mapping_ ->get_parallel_info (" tp" );
9290 auto group_id = tp_parallel_info.current_group_id ();
9391 auto num_group = tp_parallel_info.num_group ();
9492 auto local_rank = tp_parallel_info.rank ();
9593 auto & rank_per_group = tp_parallel_info.rank_per_group ()[group_id];
9694 int port_offset = group_id + 1 ;
97- #if defined(USE_NPU)
9895 dit_tp_group_ = create_process_group (global_rank,
9996 local_rank,
10097 rank_per_group,
@@ -105,18 +102,16 @@ void DiTCollectiveCommunicator::create_process_groups(
105102 " tp_group" ,
106103 device);
107104 parallel_args_->dit_tp_group_ = dit_tp_group_.get ();
108- #endif
109105 port += num_group;
110106 }
111107
112- if (sp_size > 1 ) {
113- auto sp_parallel_info = dit_mapping_npu_ ->get_parallel_info (" sp" );
108+ if (sp_size > 1 && dit_mapping_ ) {
109+ auto sp_parallel_info = dit_mapping_ ->get_parallel_info (" sp" );
114110 auto group_id = sp_parallel_info.current_group_id ();
115111 auto num_group = sp_parallel_info.num_group ();
116112 auto local_rank = sp_parallel_info.rank ();
117113 auto & rank_per_group = sp_parallel_info.rank_per_group ()[group_id];
118114 int port_offset = group_id + 1 ;
119- #if defined(USE_NPU)
120115 dit_sp_group_ = create_process_group (global_rank,
121116 local_rank,
122117 rank_per_group,
@@ -127,18 +122,16 @@ void DiTCollectiveCommunicator::create_process_groups(
127122 " sp_group" ,
128123 device);
129124 parallel_args_->dit_sp_group_ = dit_sp_group_.get ();
130- #endif
131125 port += num_group;
132126 }
133127
134- if (cfg_size > 1 ) {
135- auto cfg_parallel_info = dit_mapping_npu_ ->get_parallel_info (" cfg" );
128+ if (cfg_size > 1 && dit_mapping_ ) {
129+ auto cfg_parallel_info = dit_mapping_ ->get_parallel_info (" cfg" );
136130 auto group_id = cfg_parallel_info.current_group_id ();
137131 auto num_group = cfg_parallel_info.num_group ();
138132 auto local_rank = cfg_parallel_info.rank ();
139133 auto & rank_per_group = cfg_parallel_info.rank_per_group ()[group_id];
140134 int port_offset = group_id + 1 ;
141- #if defined(USE_NPU)
142135 dit_cfg_group_ = create_process_group (global_rank,
143136 local_rank,
144137 rank_per_group,
@@ -149,18 +142,16 @@ void DiTCollectiveCommunicator::create_process_groups(
149142 " cfg_group" ,
150143 device);
151144 parallel_args_->dit_cfg_group_ = dit_cfg_group_.get ();
152- #endif
153145 port += num_group;
154146 }
155147
156- if (dp_size > 1 ) {
157- auto dp_parallel_info = dit_mapping_npu_ ->get_parallel_info (" dp" );
148+ if (dp_size > 1 && dit_mapping_ ) {
149+ auto dp_parallel_info = dit_mapping_ ->get_parallel_info (" dp" );
158150 auto group_id = dp_parallel_info.current_group_id ();
159151 auto num_group = dp_parallel_info.num_group ();
160152 auto local_rank = dp_parallel_info.rank ();
161153 auto & rank_per_group = dp_parallel_info.rank_per_group ()[group_id];
162154 int port_offset = group_id + 1 ;
163- #if defined(USE_NPU)
164155 dit_dp_group_ = create_process_group (global_rank,
165156 local_rank,
166157 rank_per_group,
@@ -171,7 +162,6 @@ void DiTCollectiveCommunicator::create_process_groups(
171162 " dp_group" ,
172163 device);
173164 parallel_args_->dit_dp_group_ = dit_dp_group_.get ();
174- #endif
175165 port += num_group;
176166 }
177167}
0 commit comments