Skip to content

Commit 8a4cc1a

Browse files
authored
bugfix: fix build error on cuda and ilu device. (#1252)
1 parent 47acd9f commit 8a4cc1a

3 files changed

Lines changed: 14 additions & 0 deletions

File tree

xllm/core/framework/parallel_state/dit_collective_communicator.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ void DiTCollectiveCommunicator::create_process_groups(
9292
auto local_rank = tp_parallel_info.rank();
9393
auto& rank_per_group = tp_parallel_info.rank_per_group()[group_id];
9494
int port_offset = group_id + 1;
95+
#if defined(USE_NPU) || defined(USE_MLU)
9596
dit_tp_group_ = create_process_group(global_rank,
9697
local_rank,
9798
rank_per_group,
@@ -101,6 +102,7 @@ void DiTCollectiveCommunicator::create_process_groups(
101102
host,
102103
"tp_group",
103104
device);
105+
#endif
104106
parallel_args_->dit_tp_group_ = dit_tp_group_.get();
105107
port += num_group;
106108
}
@@ -112,6 +114,7 @@ void DiTCollectiveCommunicator::create_process_groups(
112114
auto local_rank = sp_parallel_info.rank();
113115
auto& rank_per_group = sp_parallel_info.rank_per_group()[group_id];
114116
int port_offset = group_id + 1;
117+
#if defined(USE_NPU) || defined(USE_MLU)
115118
dit_sp_group_ = create_process_group(global_rank,
116119
local_rank,
117120
rank_per_group,
@@ -121,6 +124,7 @@ void DiTCollectiveCommunicator::create_process_groups(
121124
host,
122125
"sp_group",
123126
device);
127+
#endif
124128
parallel_args_->dit_sp_group_ = dit_sp_group_.get();
125129
port += num_group;
126130
}
@@ -132,6 +136,7 @@ void DiTCollectiveCommunicator::create_process_groups(
132136
auto local_rank = cfg_parallel_info.rank();
133137
auto& rank_per_group = cfg_parallel_info.rank_per_group()[group_id];
134138
int port_offset = group_id + 1;
139+
#if defined(USE_NPU) || defined(USE_MLU)
135140
dit_cfg_group_ = create_process_group(global_rank,
136141
local_rank,
137142
rank_per_group,
@@ -141,6 +146,7 @@ void DiTCollectiveCommunicator::create_process_groups(
141146
host,
142147
"cfg_group",
143148
device);
149+
#endif
144150
parallel_args_->dit_cfg_group_ = dit_cfg_group_.get();
145151
port += num_group;
146152
}
@@ -152,6 +158,7 @@ void DiTCollectiveCommunicator::create_process_groups(
152158
auto local_rank = dp_parallel_info.rank();
153159
auto& rank_per_group = dp_parallel_info.rank_per_group()[group_id];
154160
int port_offset = group_id + 1;
161+
#if defined(USE_NPU) || defined(USE_MLU)
155162
dit_dp_group_ = create_process_group(global_rank,
156163
local_rank,
157164
rank_per_group,
@@ -161,6 +168,7 @@ void DiTCollectiveCommunicator::create_process_groups(
161168
host,
162169
"dp_group",
163170
device);
171+
#endif
164172
parallel_args_->dit_dp_group_ = dit_dp_group_.get();
165173
port += num_group;
166174
}

xllm/core/framework/parallel_state/process_group.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ std::unique_ptr<ProcessGroup> create_process_group(
198198
rank, world_size, rank_size, port, trans, host, group_name, device);
199199
}
200200

201+
#if defined(USE_NPU) || defined(USE_MLU)
202+
// we only support DiT models onNPU and MLU for now.
201203
// TODO: This function is used by DiT models, since the DiT communication group
202204
// info have already been calculated by rank_generator, we only need to pass the
203205
// info to create the process groups. For any device that want to reuse the
@@ -223,4 +225,5 @@ std::unique_ptr<ProcessGroup> create_process_group(
223225
group_name,
224226
device);
225227
}
228+
#endif
226229
} // namespace xllm

xllm/core/framework/parallel_state/process_group.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ std::unique_ptr<xllm::ProcessGroup> create_process_group(
133133
const std::string& group_name,
134134
const torch::Device& device);
135135

136+
#if defined(USE_NPU) || defined(USE_MLU)
137+
// for DiT models
136138
std::unique_ptr<xllm::ProcessGroup> create_process_group(
137139
int32_t global_rank,
138140
int32_t local_rank,
@@ -143,5 +145,6 @@ std::unique_ptr<xllm::ProcessGroup> create_process_group(
143145
const std::string& host,
144146
const std::string& group_name,
145147
const torch::Device& device);
148+
#endif
146149

147150
} // namespace xllm

0 commit comments

Comments
 (0)