@@ -92,6 +92,7 @@ void DiTCollectiveCommunicator::create_process_groups(
9292 auto local_rank = tp_parallel_info.rank ();
9393 auto & rank_per_group = tp_parallel_info.rank_per_group ()[group_id];
9494 int port_offset = group_id + 1 ;
95+ #if defined(USE_NPU) || defined(USE_MLU)
9596 dit_tp_group_ = create_process_group (global_rank,
9697 local_rank,
9798 rank_per_group,
@@ -101,6 +102,7 @@ void DiTCollectiveCommunicator::create_process_groups(
101102 host,
102103 " tp_group" ,
103104 device);
105+ #endif
104106 parallel_args_->dit_tp_group_ = dit_tp_group_.get ();
105107 port += num_group;
106108 }
@@ -112,6 +114,7 @@ void DiTCollectiveCommunicator::create_process_groups(
112114 auto local_rank = sp_parallel_info.rank ();
113115 auto & rank_per_group = sp_parallel_info.rank_per_group ()[group_id];
114116 int port_offset = group_id + 1 ;
117+ #if defined(USE_NPU) || defined(USE_MLU)
115118 dit_sp_group_ = create_process_group (global_rank,
116119 local_rank,
117120 rank_per_group,
@@ -121,6 +124,7 @@ void DiTCollectiveCommunicator::create_process_groups(
121124 host,
122125 " sp_group" ,
123126 device);
127+ #endif
124128 parallel_args_->dit_sp_group_ = dit_sp_group_.get ();
125129 port += num_group;
126130 }
@@ -132,6 +136,7 @@ void DiTCollectiveCommunicator::create_process_groups(
132136 auto local_rank = cfg_parallel_info.rank ();
133137 auto & rank_per_group = cfg_parallel_info.rank_per_group ()[group_id];
134138 int port_offset = group_id + 1 ;
139+ #if defined(USE_NPU) || defined(USE_MLU)
135140 dit_cfg_group_ = create_process_group (global_rank,
136141 local_rank,
137142 rank_per_group,
@@ -141,6 +146,7 @@ void DiTCollectiveCommunicator::create_process_groups(
141146 host,
142147 " cfg_group" ,
143148 device);
149+ #endif
144150 parallel_args_->dit_cfg_group_ = dit_cfg_group_.get ();
145151 port += num_group;
146152 }
@@ -152,6 +158,7 @@ void DiTCollectiveCommunicator::create_process_groups(
152158 auto local_rank = dp_parallel_info.rank ();
153159 auto & rank_per_group = dp_parallel_info.rank_per_group ()[group_id];
154160 int port_offset = group_id + 1 ;
161+ #if defined(USE_NPU) || defined(USE_MLU)
155162 dit_dp_group_ = create_process_group (global_rank,
156163 local_rank,
157164 rank_per_group,
@@ -161,6 +168,7 @@ void DiTCollectiveCommunicator::create_process_groups(
161168 host,
162169 " dp_group" ,
163170 device);
171+ #endif
164172 parallel_args_->dit_dp_group_ = dit_dp_group_.get ();
165173 port += num_group;
166174 }
0 commit comments