From 832b4d26dd5480d7c11d42d622fb902ef7df33db Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Tue, 17 Mar 2026 16:08:55 +0800 Subject: [PATCH 1/3] fix mask in SP --- src/diffusers/models/attention_dispatch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 5b1f831ed060..cfb796d05af6 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -813,6 +813,9 @@ def _native_attention_forward_op( if return_lse: raise ValueError("Native attention does not support return_lse=True") + if attn_mask is not None and attn_mask.dim() == 2: + attn_mask = attn_mask[:, None, None, :] + # used for backward pass if _save_ctx: ctx.save_for_backward(query, key, value) From 1e535f0d6de63debb7569cce70117ae2a6969f2d Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Tue, 17 Mar 2026 18:03:34 +0800 Subject: [PATCH 2/3] change the modification to qwen specific --- src/diffusers/models/attention_dispatch.py | 3 --- src/diffusers/models/transformers/transformer_qwenimage.py | 1 + 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index cfb796d05af6..5b1f831ed060 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -813,9 +813,6 @@ def _native_attention_forward_op( if return_lse: raise ValueError("Native attention does not support return_lse=True") - if attn_mask is not None and attn_mask.dim() == 2: - attn_mask = attn_mask[:, None, None, :] - # used for backward pass if _save_ctx: ctx.save_for_backward(query, key, value) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index a54cb3b8e092..a76e4dbc93b3 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -934,6 +934,7 @@ def forward( batch_size, image_seq_len = hidden_states.shape[:2] image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + joint_attention_mask = joint_attention_mask[:, None, None, :] block_attention_kwargs["attention_mask"] = joint_attention_mask for index_block, block in enumerate(self.transformer_blocks): From ea4330947b7602ec4dea968cf13f912aebd75926 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Tue, 24 Mar 2026 11:42:42 +0800 Subject: [PATCH 3/3] drop xfail since qwen-image mask is fixed --- tests/models/testing_utils/parallelism.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index 2b6aab59a662..bea832904041 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -200,7 +200,6 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1): f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}" ) - @pytest.mark.xfail(reason="Context parallel may not support batch_size > 1") @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) def test_context_parallel_batch_inputs(self, cp_type): self.test_context_parallel_inference(cp_type, batch_size=2)