@@ -1390,7 +1390,8 @@ class QwenDoubleStreamAttnProcessorBase : public torch::nn::Module {
13901390 const torch::Tensor& encoder_hidden_states, // Text stream
13911391 const torch::Tensor& encoder_hidden_states_mask = torch::Tensor(),
13921392 const torch::Tensor& attention_mask = torch::Tensor(),
1393- const std::tuple<at::Tensor, at::Tensor>& image_rotary_emb = {}) = 0 ;
1393+ const std::tuple<torch::Tensor, torch::Tensor>& image_rotary_emb =
1394+ {}) = 0 ;
13941395
13951396 virtual void load_state_dict (const StateDict& state_dict) {
13961397 attn_->load_state_dict (state_dict);
@@ -1405,7 +1406,7 @@ class QwenDoubleStreamAttnProcessorBase : public torch::nn::Module {
14051406};
14061407
14071408// Implementation of attention forward with communication & computation overlap
1408- class QwenDoubleStreamAttnProcessorCMO2_0Impl
1409+ class QwenDoubleStreamAttnProcessorCMO2_0Impl final
14091410 : public QwenDoubleStreamAttnProcessorBase {
14101411 public:
14111412 QwenDoubleStreamAttnProcessorCMO2_0Impl (Attention&& attn_module,
@@ -1426,7 +1427,7 @@ class QwenDoubleStreamAttnProcessorCMO2_0Impl
14261427 const torch::Tensor& encoder_hidden_states, // Text stream
14271428 const torch::Tensor& encoder_hidden_states_mask = torch::Tensor(),
14281429 const torch::Tensor& attention_mask = torch::Tensor(),
1429- const std::tuple<at ::Tensor, at ::Tensor>& image_rotary_emb = {})
1430+ const std::tuple<torch ::Tensor, torch ::Tensor>& image_rotary_emb = {})
14301431 override {
14311432 // Compute QKV for image stream (sample projections)
14321433 // auto reshape_dims = std::vector<int64_t>{heads / FLAGS_sp_size,
@@ -1562,8 +1563,8 @@ class QwenDoubleStreamAttnProcessorCMO2_0Impl
15621563 /* atten_mask*/ torch::nullopt ,
15631564 /* scale=*/ pow (joint_query.size (3 ), -0.5 ),
15641565 /* keep_prob=*/ 1.0 ,
1565- /* pre_tockens =*/ 65535 ,
1566- /* next_tockens =*/ 65535 );
1566+ /* pre_tokens =*/ 65535 ,
1567+ /* next_tokens =*/ 65535 );
15671568
15681569 auto joint_hidden_states = std::get<0 >(results);
15691570 // Reshape back
@@ -1634,7 +1635,7 @@ class QwenDoubleStreamAttnProcessorCMO2_0Impl
16341635TORCH_MODULE (QwenDoubleStreamAttnProcessorCMO2_0);
16351636
16361637// Implementation of attention forward
1637- class QwenDoubleStreamAttnProcessor2_0Impl
1638+ class QwenDoubleStreamAttnProcessor2_0Impl final
16381639 : public QwenDoubleStreamAttnProcessorBase {
16391640 public:
16401641 QwenDoubleStreamAttnProcessor2_0Impl (Attention&& attn_module,
@@ -1647,7 +1648,7 @@ class QwenDoubleStreamAttnProcessor2_0Impl
16471648 const torch::Tensor& encoder_hidden_states, // Text stream
16481649 const torch::Tensor& encoder_hidden_states_mask = torch::Tensor(),
16491650 const torch::Tensor& attention_mask = torch::Tensor(),
1650- const std::tuple<at ::Tensor, at ::Tensor>& image_rotary_emb = {})
1651+ const std::tuple<torch ::Tensor, torch ::Tensor>& image_rotary_emb = {})
16511652 override {
16521653 // int64_t seq_txt = encoder_hidden_states.size(1);
16531654 // int64_t seq_img = hidden_states.size(1);
@@ -2115,17 +2116,6 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module {
21152116 out_channels = (out_channels > 0 ) ? out_channels : in_channels;
21162117 auto inner_dim = num_attention_heads * attention_head_dim;
21172118
2118- // Positional embedding
2119- if (use_layer3d_rope_) {
2120- pos_embed_3d_rope_ = register_module (
2121- " pos_embed" ,
2122- QwenEmbedLayer3DRope (context, /* theta=*/ 10000 , axes_dims_rope, true ));
2123- } else {
2124- pos_embed_ = register_module (
2125- " pos_embed" ,
2126- QwenEmbedRope (context, /* theta=*/ 10000 , axes_dims_rope, true ));
2127- }
2128-
21292119 // Time-text embedding
21302120 time_text_embed_ = register_module (
21312121 " time_text_embed" ,
@@ -2178,6 +2168,7 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module {
21782168 torch::Tensor timestep = torch::Tensor(),
21792169 std::vector<std::vector<int64_t>> img_shapes = {},
21802170 torch::Tensor txt_seq_lens = torch::Tensor(),
2171+ const std::tuple<torch::Tensor, torch::Tensor>& image_rotary_emb = {},
21812172 bool use_cfg = false ,
21822173 int64_t step_idx = 0 ,
21832174 torch::Tensor addition_t_cond = torch::Tensor(),
@@ -2212,8 +2203,6 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module {
22122203 modulate_index = torch::Tensor ();
22132204 }
22142205
2215- auto origin_text_seq_len = encoder_hidden_states.size (1 );
2216-
22172206 // padding mask for sequence parallel scene
22182207 auto padded_encoder_hidden_states_mask =
22192208 xllm::dit::SequenceParallelPadManager::getInstance ().pad_tensor (
@@ -2245,16 +2234,6 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module {
22452234 padded_encoder_hidden_states_mask);
22462235 auto temb = time_text_embed_->forward (
22472236 new_timestep, new_hidden_states, addition_t_cond);
2248- std::tuple<torch::Tensor, torch::Tensor> image_rotary_emb;
2249- if (use_layer3d_rope_) {
2250- image_rotary_emb = pos_embed_3d_rope_->forward (
2251- img_shapes, origin_text_seq_len, new_hidden_states.device ());
2252- } else {
2253- image_rotary_emb = pos_embed_->forward (img_shapes,
2254- origin_text_seq_len,
2255- new_hidden_states.device (),
2256- /* max_txt_seq_len=*/ std::nullopt );
2257- }
22582237
22592238 std::unordered_map<std::string, torch::Tensor> block_attention_kwargs;
22602239 if (new_encoder_hidden_states_mask.has_value () &&
@@ -2398,8 +2377,6 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module {
23982377
23992378 private:
24002379 torch::TensorOptions options_;
2401- QwenEmbedRope pos_embed_{nullptr };
2402- QwenEmbedLayer3DRope pos_embed_3d_rope_{nullptr };
24032380 QwenTimestepProjEmbeddings time_text_embed_{nullptr };
24042381 RMSNorm txt_norm_{nullptr };
24052382 layer::AddMatmulWeightTransposed img_in_{nullptr };