@@ -68,6 +68,30 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
6868 register_module (" scheduler" , scheduler_);
6969 register_module (" transformer" , transformer_);
7070 register_module (" vae_image_processor" , vae_image_processor_);
71+
72+ use_layer3d_rope_ = context.get_model_context (" transformer" )
73+ .get_model_args ()
74+ .use_layer3d_rope ();
75+ std::vector<int64_t > axes_dims_rope =
76+ context.get_model_context (" transformer" )
77+ .get_model_args ()
78+ .axes_dims_rope ();
79+ // Positional embedding
80+ if (use_layer3d_rope_) {
81+ pos_embed_3d_rope_ = register_module (
82+ " pos_embed" ,
83+ QwenEmbedLayer3DRope (context.get_model_context (" transformer" ),
84+ /* theta=*/ 10000 ,
85+ axes_dims_rope,
86+ true ));
87+ } else {
88+ pos_embed_ = register_module (
89+ " pos_embed" ,
90+ QwenEmbedRope (context.get_model_context (" transformer" ),
91+ /* theta=*/ 10000 ,
92+ axes_dims_rope,
93+ true ));
94+ }
7195 }
7296
7397 std::vector<torch::Tensor> _extract_masked_hidden (
@@ -461,51 +485,31 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
461485 if (do_true_cfg && negative_prompt_embeds_mask.defined ()) {
462486 negative_txt_seq_lens = negative_prompt_embeds_mask.sum (1 );
463487 }
464- /*
465- if (prompt_embeds.size(1) % FLAGS_sp_size != 0) {
466- int64_t pad_len =
467- FLAGS_sp_size - prompt_embeds.size(1) % FLAGS_sp_size;
468- std::vector<int64_t> pad_with = {
469- 0,
470- 0, // 第3维（height）不pad
471- 0,
472- pad_len, // 第2维（channels）前后pad
473- 0,
474- 0}; // 第1维（batch）不pad
475- std::vector<int64_t> pad_with_mask = {
476- // 第3维（height）不pad
477- 0,
478- pad_len, // 第2维（channels）前后pad
479- 0,
480- 0}; // 第1维（batch）不pad
481- prompt_embeds = torch::pad(prompt_embeds, pad_with, "constant", 0);
482- prompt_embeds_mask =
483- torch::pad(prompt_embeds_mask, pad_with_mask, "constant", 0);
484- }
485488
486- if (negative_prompt_embeds.size(1) % FLAGS_sp_size != 0) {
487- int64_t pad_len = FLAGS_sp_size -
488- negative_prompt_embeds.size(1) % FLAGS_sp_size;
489- std::vector<int64_t> pad_with = {
490- 0,
491- 0, // 第3维（height）不pad
492- 0,
493- pad_len, // 第2维（channels）前后pad
494- 0,
495- 0}; // 第1维（batch）不pad
496- std::vector<int64_t> pad_with_mask = {
497- // 第3维（height）不pad
498- 0,
499- pad_len, // 第2维（channels）前后pad
500- 0,
501- 0};
502- negative_prompt_embeds =
503- torch::pad(negative_prompt_embeds, pad_with, "constant", 0);
504- negative_prompt_embeds_mask =
505- torch::pad(negative_prompt_embeds_mask, pad_with_mask, "constant", 0);
506- }
507- */
508489 scheduler_->set_begin_index (0 );
490+
491+ auto origin_text_seq_len = prompt_embeds.size (1 );
492+ auto origin_neg_text_seq_len = negative_prompt_embeds.size (1 );
493+ std::tuple<torch::Tensor, torch::Tensor> image_rotary_emb_pos;
494+ std::tuple<torch::Tensor, torch::Tensor> image_rotary_emb_neg;
495+ if (use_layer3d_rope_) {
496+ image_rotary_emb_pos = pos_embed_3d_rope_->forward (
497+ main_shape, origin_text_seq_len, prompt_embeds.device ());
498+ image_rotary_emb_neg = pos_embed_3d_rope_->forward (
499+ main_shape, origin_neg_text_seq_len, prompt_embeds.device ());
500+ } else {
501+ image_rotary_emb_pos =
502+ pos_embed_->forward (main_shape,
503+ origin_text_seq_len,
504+ prompt_embeds.device (),
505+ /* max_txt_seq_len=*/ std::nullopt );
506+ image_rotary_emb_neg =
507+ pos_embed_->forward (main_shape,
508+ origin_neg_text_seq_len,
509+ prompt_embeds.device (),
510+ /* max_txt_seq_len=*/ std::nullopt );
511+ }
512+
509513 for (int64_t i = 0 ; i < timesteps.size (0 ); ++i) {
510514 auto t = timesteps[i];
511515 current_timestep_ = t;
@@ -530,6 +534,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
530534 timestep_expanded / 1000.0 ,
531535 main_shape,
532536 txt_seq_lens,
537+ image_rotary_emb_pos,
533538 /* use_cfg=*/ false ,
534539 /* step_index=*/ i);
535540 noise_pred = noise_pred.slice (1 , 0 , final_latents.size (1 ));
@@ -544,6 +549,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
544549 timestep_expanded / 1000.0 ,
545550 main_shape,
546551 negative_txt_seq_lens,
552+ image_rotary_emb_neg,
547553 /* use_cfg=*/ true ,
548554 /* step_index=*/ i);
549555
@@ -567,6 +573,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
567573 timestep_expanded / 1000.0 ,
568574 main_shape,
569575 txt_seq_lens,
576+ image_rotary_emb_pos,
570577 /* use_cfg=*/ false ,
571578 /* step_index=*/ i);
572579 noise_pred = noise_pred.slice (1 , 0 , final_latents.size (1 ));
@@ -577,6 +584,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
577584 timestep_expanded / 1000.0 ,
578585 main_shape,
579586 negative_txt_seq_lens,
587+ image_rotary_emb_neg,
580588 /* use_cfg=*/ true ,
581589 /* step_index=*/ i);
582590
@@ -653,6 +661,9 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
653661 torch::Tensor current_timestep_;
654662 string prompt_template_encode_;
655663 const ModelArgs& vae_model_args_;
664+ bool use_layer3d_rope_;
665+ QwenEmbedRope pos_embed_{nullptr };
666+ QwenEmbedLayer3DRope pos_embed_3d_rope_{nullptr };
656667};
657668
658669REGISTER_MODEL_ARGS (Qwen2Tokenizer, [&] {});
0 commit comments