Skip to content

Commit 7c2ad36

Browse files
committed
add _native_npu_attention support for mask shapes like [B, 1, 1, S]
1 parent 26b046c commit 7c2ad36

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/diffusers/models/attention_dispatch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1528,7 +1528,7 @@ def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mas
15281528
batch_size, seq_len_q, seq_len_kv = attn_mask.shape[0], query.shape[1], key.shape[1]
15291529
attn_mask = attn_mask.unsqueeze(1).expand(batch_size, seq_len_q, seq_len_kv).unsqueeze(1).contiguous()
15301530
elif attn_mask.ndim == 4 and attn_mask.shape[1:3] == (1, 1):
1531-
attn_mask = attn_mask.expand(-1, -1, query.shape[1], -1)
1531+
attn_mask = attn_mask.expand(-1, -1, query.shape[1], -1).contiguous()
15321532

15331533
attn_mask = ~attn_mask.to(torch.bool)
15341534

0 commit comments

Comments (0)