Skip to content

Commit 0f7d258

Browse files
feat: support flux model on mlu device. (#1138)
Co-authored-by: a120092009 <zhaoty0121@gmail.com>
1 parent 535b3bc commit 0f7d258

14 files changed

Lines changed: 85 additions & 53 deletions

xllm/core/framework/dit_model_context.cpp

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -74,14 +74,21 @@ const QuantArgs& DiTModelContext::get_quant_args(
7474
}
7575
}
7676

77-
#if defined(USE_NPU)
77+
#if defined(USE_NPU) || defined(USE_CUDA) || defined(USE_MLU)
7878
ModelContext DiTModelContext::get_model_context(
7979
const std::string& component) const {
80+
#if defined(USE_NPU)
8081
return ModelContext(parallel_args_,
8182
get_model_args(component),
8283
get_quant_args(component),
8384
tensor_options_,
8485
context_);
86+
#else
87+
return ModelContext(parallel_args_,
88+
get_model_args(component),
89+
get_quant_args(component),
90+
tensor_options_);
91+
#endif
8592
}
8693
#endif
8794

xllm/core/framework/dit_model_context.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ class DiTModelContext {
4444

4545
const QuantArgs& get_quant_args(const std::string& component) const;
4646

47-
#if defined(USE_NPU)
47+
#if defined(USE_NPU) || defined(USE_CUDA) || defined(USE_MLU)
4848
ModelContext get_model_context(const std::string& component) const;
4949
#endif
5050

xllm/core/framework/parallel_state/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ cc_library(
66
parallel_state
77
HDRS
88
mapping_npu.h
9-
dit_mapping_npu.h
9+
dit_mapping.h
1010
rank_generator.h
1111
parallel_args.h
1212
parallel_state.h
@@ -21,7 +21,7 @@ cc_library(
2121
dit_collective_communicator.h
2222
SRCS
2323
mapping_npu.cpp
24-
dit_mapping_npu.cpp
24+
dit_mapping.cpp
2525
parallel_state.cpp
2626
parallel_state_async.cpp
2727
process_group.cpp

xllm/core/framework/parallel_state/dit_collective_communicator.cpp

Lines changed: 16 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -42,22 +42,20 @@ DiTCollectiveCommunicator::DiTCollectiveCommunicator(int32_t global_rank,
4242
int32_t dit_sp_size,
4343
int32_t dit_cfg_size)
4444
: CollectiveCommunicatorBase(global_rank, world_size) {
45-
#if defined(USE_NPU)
46-
DiTMappingNPU::Options dit_mapping_options;
47-
dit_mapping_options.dit_tp_size(dit_tp_size)
48-
.dit_sp_size(dit_sp_size)
49-
.dit_cfg_size(dit_cfg_size)
50-
.dit_dp_size(dit_dp_size);
51-
dit_mapping_npu_ = std::make_unique<DiTMappingNPU>(
52-
world_size, global_rank, dit_mapping_options);
5345
parallel_args_ = std::make_unique<ParallelArgs>(global_rank,
5446
world_size,
5547
dit_dp_size,
5648
dit_tp_size,
5749
dit_sp_size,
5850
dit_cfg_size,
59-
nullptr);
60-
#endif
51+
/*process_group=*/nullptr);
52+
DiTMapping::Options dit_mapping_options;
53+
dit_mapping_options.dit_tp_size(dit_tp_size)
54+
.dit_sp_size(dit_sp_size)
55+
.dit_cfg_size(dit_cfg_size)
56+
.dit_dp_size(dit_dp_size);
57+
dit_mapping_ = std::make_unique<DiTMapping>(
58+
world_size, global_rank, dit_mapping_options);
6159
}
6260

6361
void DiTCollectiveCommunicator::create_process_groups(
@@ -87,14 +85,13 @@ void DiTCollectiveCommunicator::create_process_groups(
8785

8886
parallel_args_->process_group_ = process_group_.get();
8987

90-
if (tp_size > 1) {
91-
auto tp_parallel_info = dit_mapping_npu_->get_parallel_info("tp");
88+
if (tp_size > 1 && dit_mapping_) {
89+
auto tp_parallel_info = dit_mapping_->get_parallel_info("tp");
9290
auto group_id = tp_parallel_info.current_group_id();
9391
auto num_group = tp_parallel_info.num_group();
9492
auto local_rank = tp_parallel_info.rank();
9593
auto& rank_per_group = tp_parallel_info.rank_per_group()[group_id];
9694
int port_offset = group_id + 1;
97-
#if defined(USE_NPU)
9895
dit_tp_group_ = create_process_group(global_rank,
9996
local_rank,
10097
rank_per_group,
@@ -105,18 +102,16 @@ void DiTCollectiveCommunicator::create_process_groups(
105102
"tp_group",
106103
device);
107104
parallel_args_->dit_tp_group_ = dit_tp_group_.get();
108-
#endif
109105
port += num_group;
110106
}
111107

112-
if (sp_size > 1) {
113-
auto sp_parallel_info = dit_mapping_npu_->get_parallel_info("sp");
108+
if (sp_size > 1 && dit_mapping_) {
109+
auto sp_parallel_info = dit_mapping_->get_parallel_info("sp");
114110
auto group_id = sp_parallel_info.current_group_id();
115111
auto num_group = sp_parallel_info.num_group();
116112
auto local_rank = sp_parallel_info.rank();
117113
auto& rank_per_group = sp_parallel_info.rank_per_group()[group_id];
118114
int port_offset = group_id + 1;
119-
#if defined(USE_NPU)
120115
dit_sp_group_ = create_process_group(global_rank,
121116
local_rank,
122117
rank_per_group,
@@ -127,18 +122,16 @@ void DiTCollectiveCommunicator::create_process_groups(
127122
"sp_group",
128123
device);
129124
parallel_args_->dit_sp_group_ = dit_sp_group_.get();
130-
#endif
131125
port += num_group;
132126
}
133127

134-
if (cfg_size > 1) {
135-
auto cfg_parallel_info = dit_mapping_npu_->get_parallel_info("cfg");
128+
if (cfg_size > 1 && dit_mapping_) {
129+
auto cfg_parallel_info = dit_mapping_->get_parallel_info("cfg");
136130
auto group_id = cfg_parallel_info.current_group_id();
137131
auto num_group = cfg_parallel_info.num_group();
138132
auto local_rank = cfg_parallel_info.rank();
139133
auto& rank_per_group = cfg_parallel_info.rank_per_group()[group_id];
140134
int port_offset = group_id + 1;
141-
#if defined(USE_NPU)
142135
dit_cfg_group_ = create_process_group(global_rank,
143136
local_rank,
144137
rank_per_group,
@@ -149,18 +142,16 @@ void DiTCollectiveCommunicator::create_process_groups(
149142
"cfg_group",
150143
device);
151144
parallel_args_->dit_cfg_group_ = dit_cfg_group_.get();
152-
#endif
153145
port += num_group;
154146
}
155147

156-
if (dp_size > 1) {
157-
auto dp_parallel_info = dit_mapping_npu_->get_parallel_info("dp");
148+
if (dp_size > 1 && dit_mapping_) {
149+
auto dp_parallel_info = dit_mapping_->get_parallel_info("dp");
158150
auto group_id = dp_parallel_info.current_group_id();
159151
auto num_group = dp_parallel_info.num_group();
160152
auto local_rank = dp_parallel_info.rank();
161153
auto& rank_per_group = dp_parallel_info.rank_per_group()[group_id];
162154
int port_offset = group_id + 1;
163-
#if defined(USE_NPU)
164155
dit_dp_group_ = create_process_group(global_rank,
165156
local_rank,
166157
rank_per_group,
@@ -171,7 +162,6 @@ void DiTCollectiveCommunicator::create_process_groups(
171162
"dp_group",
172163
device);
173164
parallel_args_->dit_dp_group_ = dit_dp_group_.get();
174-
#endif
175165
port += num_group;
176166
}
177167
}

xllm/core/framework/parallel_state/dit_collective_communicator.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ limitations under the License.
1616
#pragma once
1717

1818
#include "collective_communicator_base.h"
19-
#include "dit_mapping_npu.h"
19+
#include "dit_mapping.h"
2020

2121
namespace xllm {
2222

@@ -38,7 +38,7 @@ class DiTCollectiveCommunicator : public CollectiveCommunicatorBase {
3838
const ParallelArgs* parallel_args() override;
3939

4040
private:
41-
std::unique_ptr<DiTMappingNPU> dit_mapping_npu_{nullptr};
41+
std::unique_ptr<DiTMapping> dit_mapping_{nullptr};
4242
std::unique_ptr<ParallelArgs> parallel_args_;
4343
std::unique_ptr<ProcessGroup> process_group_;
4444
std::unique_ptr<ProcessGroup> dit_tp_group_;

xllm/core/framework/parallel_state/dit_mapping_npu.cpp renamed to xllm/core/framework/parallel_state/dit_mapping.cpp

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -13,15 +13,15 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "dit_mapping_npu.h"
16+
#include "dit_mapping.h"
1717

1818
#include <glog/logging.h>
1919

2020
namespace xllm {
2121

22-
DiTMappingNPU::DiTMappingNPU(const int32_t world_size,
23-
const int32_t rank,
24-
const Options& options)
22+
DiTMapping::DiTMapping(const int32_t world_size,
23+
const int32_t rank,
24+
const Options& options)
2525
: rank_(rank), options_(options), world_size_(world_size) {
2626
tp_.backend("hccl");
2727
sp_.backend("hccl");
@@ -41,7 +41,7 @@ DiTMappingNPU::DiTMappingNPU(const int32_t world_size,
4141
set_group_by_type(dp_, "dp");
4242
}
4343

44-
void DiTMappingNPU::parse_parallel_info() {
44+
void DiTMapping::parse_parallel_info() {
4545
if (options_.dit_tp_size() != -1) {
4646
tp_.group_size(options_.dit_tp_size());
4747
}
@@ -56,7 +56,7 @@ void DiTMappingNPU::parse_parallel_info() {
5656
}
5757
}
5858

59-
void DiTMappingNPU::validate() {
59+
void DiTMapping::validate() {
6060
CHECK(cfg_.group_size() * tp_.group_size() * sp_.group_size() *
6161
dp_.group_size() ==
6262
world_size_)
@@ -84,8 +84,8 @@ void DiTMappingNPU::validate() {
8484
". Please check `cfg` .";
8585
}
8686

87-
void DiTMappingNPU::set_group_by_type(ParallelInfo& parallel_info,
88-
const std::string& group_type) {
87+
void DiTMapping::set_group_by_type(ParallelInfo& parallel_info,
88+
const std::string& group_type) {
8989
auto rank_per_group = rank_generator_->get_ranks(group_type);
9090
parallel_info.rank_per_group(rank_per_group);
9191
auto group_size = rank_per_group[0].size();
@@ -99,7 +99,7 @@ void DiTMappingNPU::set_group_by_type(ParallelInfo& parallel_info,
9999
parallel_info.rank(local_rank);
100100
}
101101

102-
std::tuple<int32_t, int32_t> DiTMappingNPU::get_current_group_id(
102+
std::tuple<int32_t, int32_t> DiTMapping::get_current_group_id(
103103
const std::vector<std::vector<int32_t>>& rank_per_group,
104104
int32_t target_rank_id) {
105105
for (int32_t idx = 0; idx < rank_per_group.size(); ++idx) {
@@ -112,7 +112,7 @@ std::tuple<int32_t, int32_t> DiTMappingNPU::get_current_group_id(
112112
return std::make_tuple(-1, -1);
113113
}
114114

115-
const ParallelInfo& DiTMappingNPU::get_parallel_info(
115+
const ParallelInfo& DiTMapping::get_parallel_info(
116116
const std::string& group_type) const {
117117
if (group_type == "tp") {
118118
return tp_;
@@ -127,7 +127,7 @@ const ParallelInfo& DiTMappingNPU::get_parallel_info(
127127
}
128128
}
129129

130-
nlohmann::json DiTMappingNPU::to_json() {
130+
nlohmann::json DiTMapping::to_json() {
131131
nlohmann::json data;
132132

133133
data["SpSize"] = options_.dit_sp_size();

xllm/core/framework/parallel_state/dit_mapping_npu.h renamed to xllm/core/framework/parallel_state/dit_mapping.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ limitations under the License.
2323
#include "rank_generator.h"
2424
namespace xllm {
2525

26-
class DiTMappingNPU final {
26+
class DiTMapping final {
2727
public:
2828
struct Options {
2929
// cfg size
@@ -36,9 +36,9 @@ class DiTMappingNPU final {
3636
PROPERTY(int32_t, dit_dp_size) = -1;
3737
};
3838

39-
DiTMappingNPU(const int32_t world_size,
40-
const int32_t rank,
41-
const Options& options);
39+
DiTMapping(const int32_t world_size,
40+
const int32_t rank,
41+
const Options& options);
4242

4343
int32_t get_num_nodes();
4444

xllm/core/framework/parallel_state/mlu_process_group.h

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,30 @@ class ProcessGroupImpl : public ProcessGroup {
5555
pg_ = std::make_unique<torch_mlu::ProcessGroupCNCL>(
5656
store, rank, rank_size, pg_options);
5757
}
58+
59+
ProcessGroupImpl(int32_t global_rank,
60+
int32_t local_rank,
61+
const std::vector<int32_t>& group_ranks,
62+
int32_t world_size,
63+
int32_t rank_size,
64+
int32_t port,
65+
const std::string& host,
66+
const std::string& group_name,
67+
const torch::Device& device)
68+
: ProcessGroup(global_rank, world_size, device) {
69+
c10::intrusive_ptr<torch_mlu::ProcessGroupCNCL::Options> pg_options =
70+
torch_mlu::ProcessGroupCNCL::Options::create();
71+
pg_options->group_name = group_name;
72+
std::vector<size_t> ranks_unsigned;
73+
ranks_unsigned.reserve(group_ranks.size());
74+
for (int32_t rank : group_ranks) {
75+
ranks_unsigned.push_back(static_cast<size_t>(rank));
76+
}
77+
pg_options->global_ranks_in_group = ranks_unsigned;
78+
auto store = create_tcp_store(host, port, local_rank);
79+
pg_ = std::make_unique<torch_mlu::ProcessGroupCNCL>(
80+
store, local_rank, rank_size, pg_options);
81+
}
5882
};
5983

6084
} // namespace xllm

xllm/core/framework/parallel_state/process_group.cpp

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -198,7 +198,6 @@ std::unique_ptr<ProcessGroup> create_process_group(
198198
rank, world_size, rank_size, port, trans, host, group_name, device);
199199
}
200200

201-
#if defined(USE_NPU)
202201
// TODO: This function is used by DiT models, since the DiT communication group
203202
// info have already been calculated by rank_generator, we only need to pass the
204203
// info to create the process groups. For any device that want to reuse the
@@ -224,5 +223,4 @@ std::unique_ptr<ProcessGroup> create_process_group(
224223
group_name,
225224
device);
226225
}
227-
#endif
228226
} // namespace xllm

xllm/models/dit/clip_text_model.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,9 @@ limitations under the License.
1616

1717
#pragma once
1818

19+
#if defined(USE_NPU)
1920
#include <atb/atb_infer.h>
21+
#endif
2022
#include <c10/core/ScalarType.h>
2123
#include <torch/torch.h>
2224

@@ -27,13 +29,17 @@ limitations under the License.
2729
#include "core/framework/kv_cache/kv_cache.h"
2830
#include "core/framework/model/model_input_params.h"
2931
#include "core/framework/model_context.h"
32+
#if defined(USE_NPU)
3033
#include "core/layers/npu/npu_siglip_encoder_layer_impl.h"
34+
#endif
3135
#include "models/model_registry.h"
3236
#include "processors/clip_image_processor.h"
3337
#include "processors/clip_input_processor.h"
3438
#include "processors/pywarpper_image_processor.h"
3539
#include "xllm/core/layers/common/add_matmul.h"
40+
#if defined(USE_NPU)
3641
#include "xllm_atb_layers/core/include/atb_speed/log.h"
42+
#endif
3743

3844
namespace xllm {
3945
// clip_text_model compatible with huggingface weights

0 commit comments

Comments (0)