diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index c5419b9f107e..d88aef4dcf2a 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -933,6 +933,7 @@ def forward( batch_size, image_seq_len = hidden_states.shape[:2] image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + joint_attention_mask = joint_attention_mask[:, None, None, :] block_attention_kwargs["attention_mask"] = joint_attention_mask for index_block, block in enumerate(self.transformer_blocks): diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index 2b6aab59a662..bea832904041 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -200,7 +200,6 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1): f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}" ) - @pytest.mark.xfail(reason="Context parallel may not support batch_size > 1") @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) def test_context_parallel_batch_inputs(self, cp_type): self.test_context_parallel_inference(cp_type, batch_size=2)