Skip to content

Commit bd0a8b0

Browse files
feat: move rope calculation from QwenImageEditPlus pipeline to transformer.
1 parent 974685c commit bd0a8b0

2 files changed

Lines changed: 63 additions & 75 deletions

File tree

xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,30 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
6868
register_module("scheduler", scheduler_);
6969
register_module("transformer", transformer_);
7070
register_module("vae_image_processor", vae_image_processor_);
71+
72+
use_layer3d_rope_ = context.get_model_context("transformer")
73+
.get_model_args()
74+
.use_layer3d_rope();
75+
std::vector<int64_t> axes_dims_rope =
76+
context.get_model_context("transformer")
77+
.get_model_args()
78+
.axes_dims_rope();
79+
// Positional embedding
80+
if (use_layer3d_rope_) {
81+
pos_embed_3d_rope_ = register_module(
82+
"pos_embed",
83+
QwenEmbedLayer3DRope(context.get_model_context("transformer"),
84+
/*theta=*/10000,
85+
axes_dims_rope,
86+
true));
87+
} else {
88+
pos_embed_ = register_module(
89+
"pos_embed",
90+
QwenEmbedRope(context.get_model_context("transformer"),
91+
/*theta=*/10000,
92+
axes_dims_rope,
93+
true));
94+
}
7195
}
7296

7397
std::vector<torch::Tensor> _extract_masked_hidden(
@@ -461,51 +485,31 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
461485
if (do_true_cfg && negative_prompt_embeds_mask.defined()) {
462486
negative_txt_seq_lens = negative_prompt_embeds_mask.sum(1);
463487
}
464-
/*
465-
if (prompt_embeds.size(1) % FLAGS_sp_size != 0) {
466-
int64_t pad_len =
467-
FLAGS_sp_size - prompt_embeds.size(1) % FLAGS_sp_size;
468-
std::vector<int64_t> pad_with = {
469-
0,
470-
0, // dim 3 (height): no pad
471-
0,
472-
pad_len, // dim 2 (channels): padded at the end
473-
0,
474-
0}; // dim 1 (batch): no pad
475-
std::vector<int64_t> pad_with_mask = {
476-
// dim 3 (height): no pad
477-
0,
478-
pad_len, // dim 2 (channels): padded at the end
479-
0,
480-
0}; // dim 1 (batch): no pad
481-
prompt_embeds = torch::pad(prompt_embeds, pad_with, "constant", 0);
482-
prompt_embeds_mask =
483-
torch::pad(prompt_embeds_mask, pad_with_mask, "constant", 0);
484-
}
485488

486-
if (negative_prompt_embeds.size(1) % FLAGS_sp_size != 0) {
487-
int64_t pad_len = FLAGS_sp_size -
488-
negative_prompt_embeds.size(1) % FLAGS_sp_size;
489-
std::vector<int64_t> pad_with = {
490-
0,
491-
0, // dim 3 (height): no pad
492-
0,
493-
pad_len, // dim 2 (channels): padded at the end
494-
0,
495-
0}; // dim 1 (batch): no pad
496-
std::vector<int64_t> pad_with_mask = {
497-
// dim 3 (height): no pad
498-
0,
499-
pad_len, // dim 2 (channels): padded at the end
500-
0,
501-
0};
502-
negative_prompt_embeds =
503-
torch::pad(negative_prompt_embeds, pad_with, "constant", 0);
504-
negative_prompt_embeds_mask =
505-
torch::pad(negative_prompt_embeds_mask, pad_with_mask, "constant", 0);
506-
}
507-
*/
508489
scheduler_->set_begin_index(0);
490+
491+
int64_t origin_text_seq_len = prompt_embeds.size(1);
492+
int64_t origin_neg_text_seq_len = negative_prompt_embeds.size(1);
493+
std::tuple<torch::Tensor, torch::Tensor> image_rotary_emb_pos;
494+
std::tuple<torch::Tensor, torch::Tensor> image_rotary_emb_neg;
495+
if (use_layer3d_rope_) {
496+
image_rotary_emb_pos = pos_embed_3d_rope_->forward(
497+
main_shape, origin_text_seq_len, prompt_embeds.device());
498+
image_rotary_emb_neg = pos_embed_3d_rope_->forward(
499+
main_shape, origin_neg_text_seq_len, prompt_embeds.device());
500+
} else {
501+
image_rotary_emb_pos =
502+
pos_embed_->forward(main_shape,
503+
origin_text_seq_len,
504+
prompt_embeds.device(),
505+
/*max_txt_seq_len=*/std::nullopt);
506+
image_rotary_emb_neg =
507+
pos_embed_->forward(main_shape,
508+
origin_neg_text_seq_len,
509+
prompt_embeds.device(),
510+
/*max_txt_seq_len=*/std::nullopt);
511+
}
512+
509513
for (int64_t i = 0; i < timesteps.size(0); ++i) {
510514
auto t = timesteps[i];
511515
current_timestep_ = t;
@@ -530,6 +534,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
530534
timestep_expanded / 1000.0,
531535
main_shape,
532536
txt_seq_lens,
537+
image_rotary_emb_pos,
533538
/*use_cfg=*/false,
534539
/*step_index=*/i);
535540
noise_pred = noise_pred.slice(1, 0, final_latents.size(1));
@@ -544,6 +549,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
544549
timestep_expanded / 1000.0,
545550
main_shape,
546551
negative_txt_seq_lens,
552+
image_rotary_emb_neg,
547553
/*use_cfg=*/true,
548554
/*step_index=*/i);
549555

@@ -567,6 +573,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
567573
timestep_expanded / 1000.0,
568574
main_shape,
569575
txt_seq_lens,
576+
image_rotary_emb_pos,
570577
/*use_cfg=*/false,
571578
/*step_index=*/i);
572579
noise_pred = noise_pred.slice(1, 0, final_latents.size(1));
@@ -577,6 +584,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
577584
timestep_expanded / 1000.0,
578585
main_shape,
579586
negative_txt_seq_lens,
587+
image_rotary_emb_neg,
580588
/*use_cfg=*/true,
581589
/*step_index=*/i);
582590

@@ -653,6 +661,9 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
653661
torch::Tensor current_timestep_;
654662
string prompt_template_encode_;
655663
const ModelArgs& vae_model_args_;
664+
bool use_layer3d_rope_;
665+
QwenEmbedRope pos_embed_{nullptr};
666+
QwenEmbedLayer3DRope pos_embed_3d_rope_{nullptr};
656667
};
657668

658669
REGISTER_MODEL_ARGS(Qwen2Tokenizer, [&] {});

xllm/models/dit/npu/qwen_image_edit/transformer_qwen_image.h

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,7 +1390,8 @@ class QwenDoubleStreamAttnProcessorBase : public torch::nn::Module {
13901390
const torch::Tensor& encoder_hidden_states, // Text stream
13911391
const torch::Tensor& encoder_hidden_states_mask = torch::Tensor(),
13921392
const torch::Tensor& attention_mask = torch::Tensor(),
1393-
const std::tuple<at::Tensor, at::Tensor>& image_rotary_emb = {}) = 0;
1393+
const std::tuple<torch::Tensor, torch::Tensor>& image_rotary_emb =
1394+
{}) = 0;
13941395

13951396
virtual void load_state_dict(const StateDict& state_dict) {
13961397
attn_->load_state_dict(state_dict);
@@ -1405,7 +1406,7 @@ class QwenDoubleStreamAttnProcessorBase : public torch::nn::Module {
14051406
};
14061407

14071408
// Implementation of attention forward with communication & computation overlap
1408-
class QwenDoubleStreamAttnProcessorCMO2_0Impl
1409+
class QwenDoubleStreamAttnProcessorCMO2_0Impl final
14091410
: public QwenDoubleStreamAttnProcessorBase {
14101411
public:
14111412
QwenDoubleStreamAttnProcessorCMO2_0Impl(Attention&& attn_module,
@@ -1426,7 +1427,7 @@ class QwenDoubleStreamAttnProcessorCMO2_0Impl
14261427
const torch::Tensor& encoder_hidden_states, // Text stream
14271428
const torch::Tensor& encoder_hidden_states_mask = torch::Tensor(),
14281429
const torch::Tensor& attention_mask = torch::Tensor(),
1429-
const std::tuple<at::Tensor, at::Tensor>& image_rotary_emb = {})
1430+
const std::tuple<torch::Tensor, torch::Tensor>& image_rotary_emb = {})
14301431
override {
14311432
// Compute QKV for image stream (sample projections)
14321433
// auto reshape_dims = std::vector<int64_t>{heads / FLAGS_sp_size,
@@ -1562,8 +1563,8 @@ class QwenDoubleStreamAttnProcessorCMO2_0Impl
15621563
/*atten_mask*/ torch::nullopt,
15631564
/*scale=*/pow(joint_query.size(3), -0.5),
15641565
/*keep_prob=*/1.0,
1565-
/*pre_tockens=*/65535,
1566-
/*next_tockens=*/65535);
1566+
/*pre_tokens=*/65535,
1567+
/*next_tokens=*/65535);
15671568

15681569
auto joint_hidden_states = std::get<0>(results);
15691570
// Reshape back
@@ -1634,7 +1635,7 @@ class QwenDoubleStreamAttnProcessorCMO2_0Impl
16341635
TORCH_MODULE(QwenDoubleStreamAttnProcessorCMO2_0);
16351636

16361637
// Implementation of attention forward
1637-
class QwenDoubleStreamAttnProcessor2_0Impl
1638+
class QwenDoubleStreamAttnProcessor2_0Impl final
16381639
: public QwenDoubleStreamAttnProcessorBase {
16391640
public:
16401641
QwenDoubleStreamAttnProcessor2_0Impl(Attention&& attn_module,
@@ -1647,7 +1648,7 @@ class QwenDoubleStreamAttnProcessor2_0Impl
16471648
const torch::Tensor& encoder_hidden_states, // Text stream
16481649
const torch::Tensor& encoder_hidden_states_mask = torch::Tensor(),
16491650
const torch::Tensor& attention_mask = torch::Tensor(),
1650-
const std::tuple<at::Tensor, at::Tensor>& image_rotary_emb = {})
1651+
const std::tuple<torch::Tensor, torch::Tensor>& image_rotary_emb = {})
16511652
override {
16521653
// int64_t seq_txt = encoder_hidden_states.size(1);
16531654
// int64_t seq_img = hidden_states.size(1);
@@ -2115,17 +2116,6 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module {
21152116
out_channels = (out_channels > 0) ? out_channels : in_channels;
21162117
auto inner_dim = num_attention_heads * attention_head_dim;
21172118

2118-
// Positional embedding
2119-
if (use_layer3d_rope_) {
2120-
pos_embed_3d_rope_ = register_module(
2121-
"pos_embed",
2122-
QwenEmbedLayer3DRope(context, /*theta=*/10000, axes_dims_rope, true));
2123-
} else {
2124-
pos_embed_ = register_module(
2125-
"pos_embed",
2126-
QwenEmbedRope(context, /*theta=*/10000, axes_dims_rope, true));
2127-
}
2128-
21292119
// Time-text embedding
21302120
time_text_embed_ = register_module(
21312121
"time_text_embed",
@@ -2178,6 +2168,7 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module {
21782168
torch::Tensor timestep = torch::Tensor(),
21792169
std::vector<std::vector<int64_t>> img_shapes = {},
21802170
torch::Tensor txt_seq_lens = torch::Tensor(),
2171+
const std::tuple<torch::Tensor, torch::Tensor>& image_rotary_emb = {},
21812172
bool use_cfg = false,
21822173
int64_t step_idx = 0,
21832174
torch::Tensor addition_t_cond = torch::Tensor(),
@@ -2212,8 +2203,6 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module {
22122203
modulate_index = torch::Tensor();
22132204
}
22142205

2215-
auto origin_text_seq_len = encoder_hidden_states.size(1);
2216-
22172206
// padding mask for sequence parallel scene
22182207
auto padded_encoder_hidden_states_mask =
22192208
xllm::dit::SequenceParallelPadManager::getInstance().pad_tensor(
@@ -2245,16 +2234,6 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module {
22452234
padded_encoder_hidden_states_mask);
22462235
auto temb = time_text_embed_->forward(
22472236
new_timestep, new_hidden_states, addition_t_cond);
2248-
std::tuple<torch::Tensor, torch::Tensor> image_rotary_emb;
2249-
if (use_layer3d_rope_) {
2250-
image_rotary_emb = pos_embed_3d_rope_->forward(
2251-
img_shapes, origin_text_seq_len, new_hidden_states.device());
2252-
} else {
2253-
image_rotary_emb = pos_embed_->forward(img_shapes,
2254-
origin_text_seq_len,
2255-
new_hidden_states.device(),
2256-
/*max_txt_seq_len=*/std::nullopt);
2257-
}
22582237

22592238
std::unordered_map<std::string, torch::Tensor> block_attention_kwargs;
22602239
if (new_encoder_hidden_states_mask.has_value() &&
@@ -2398,8 +2377,6 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module {
23982377

23992378
private:
24002379
torch::TensorOptions options_;
2401-
QwenEmbedRope pos_embed_{nullptr};
2402-
QwenEmbedLayer3DRope pos_embed_3d_rope_{nullptr};
24032380
QwenTimestepProjEmbeddings time_text_embed_{nullptr};
24042381
RMSNorm txt_norm_{nullptr};
24052382
layer::AddMatmulWeightTransposed img_in_{nullptr};

0 commit comments

Comments
 (0)