
Commit afdda57

zhtmike and sayakpaul authored

Fix the attention mask in ulysses SP for QwenImage (#13278)

* fix mask in SP
* change the modification to qwen specific
* drop xfail since qwen-image mask is fixed

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

1 parent 5fc2bd2 commit afdda57

File tree

2 files changed: +1 −1 lines changed

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -933,6 +933,7 @@ def forward(
         batch_size, image_seq_len = hidden_states.shape[:2]
         image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
         joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
+        joint_attention_mask = joint_attention_mask[:, None, None, :]
         block_attention_kwargs["attention_mask"] = joint_attention_mask

         for index_block, block in enumerate(self.transformer_blocks):
```
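The one-line fix inserts singleton head and query dimensions into the 2D token mask so it broadcasts against 4D attention scores. A minimal NumPy sketch of that broadcasting, with illustrative sizes that are not diffusers' actual values:

```python
import numpy as np

# Illustrative sizes only (hypothetical, not taken from QwenImage)
batch_size, text_len, image_len = 2, 7, 16

# 2D joint mask over text + image tokens: shape (batch, text_len + image_len)
encoder_hidden_states_mask = np.ones((batch_size, text_len), dtype=bool)
image_mask = np.ones((batch_size, image_len), dtype=bool)
joint_attention_mask = np.concatenate([encoder_hidden_states_mask, image_mask], axis=1)
assert joint_attention_mask.shape == (2, 23)

# The fix: add singleton head and query dims so the mask broadcasts
# against attention scores of shape (batch, num_heads, q_len, kv_len)
joint_attention_mask = joint_attention_mask[:, None, None, :]
assert joint_attention_mask.shape == (2, 1, 1, 23)
```

Without the extra dimensions, a `(batch, kv_len)` mask cannot line up with `(batch, num_heads, q_len, kv_len)` scores, which is what the Ulysses sequence-parallel path tripped over.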

tests/models/testing_utils/parallelism.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -200,7 +200,6 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
                 f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
             )

-    @pytest.mark.xfail(reason="Context parallel may not support batch_size > 1")
     @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
     def test_context_parallel_batch_inputs(self, cp_type):
         self.test_context_parallel_inference(cp_type, batch_size=2)
```
