Commit c8c8401

add _native_npu_attention support mask shape like [B,1,1,S] (#13490)
* add _native_npu_attention support mask shape like [B,1,1,S]
* fix style

Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent 77f8cf8 commit c8c8401

1 file changed

Lines changed: 8 additions & 9 deletions

File tree

src/diffusers/models/attention_dispatch.py

@@ -1521,17 +1521,16 @@ def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mas
     if attn_mask is not None and torch.all(attn_mask != 0):
         attn_mask = None
 
-    # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k]
+    # Reshape Attention Mask: [batch_size, seq_len_k] or [batch_size, 1, 1, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k]
     # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
-    if (
-        attn_mask is not None
-        and attn_mask.ndim == 2
-        and attn_mask.shape[0] == query.shape[0]
-        and attn_mask.shape[1] == key.shape[1]
-    ):
-        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+    if attn_mask is not None:
+        if attn_mask.ndim == 2 and attn_mask.shape[0] == query.shape[0] and attn_mask.shape[1] == key.shape[1]:
+            batch_size, seq_len_q, seq_len_kv = attn_mask.shape[0], query.shape[1], key.shape[1]
+            attn_mask = attn_mask.unsqueeze(1).expand(batch_size, seq_len_q, seq_len_kv).unsqueeze(1).contiguous()
+        elif attn_mask.ndim == 4 and attn_mask.shape[1:3] == (1, 1):
+            attn_mask = attn_mask.expand(-1, -1, query.shape[1], -1).contiguous()
+
         attn_mask = ~attn_mask.to(torch.bool)
-        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
 
     return attn_mask
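The sketch below is a minimal CPU-only illustration of the mask normalization this commit adds. The helper name normalize_npu_attn_mask and the demo shapes are made up for illustration; the body mirrors the reshaping logic from _maybe_modify_attn_mask_npu using plain torch, so the two supported input shapes can be checked without torch_npu or an Ascend NPU.

import torch


def normalize_npu_attn_mask(query: torch.Tensor, key: torch.Tensor, attn_mask):
    # An all-ones mask keeps every key position, so it can be dropped entirely.
    if attn_mask is not None and torch.all(attn_mask != 0):
        attn_mask = None

    # [batch_size, seq_len_k] or [batch_size, 1, 1, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
    if attn_mask is not None:
        if attn_mask.ndim == 2 and attn_mask.shape[0] == query.shape[0] and attn_mask.shape[1] == key.shape[1]:
            batch_size, seq_len_q, seq_len_kv = attn_mask.shape[0], query.shape[1], key.shape[1]
            attn_mask = attn_mask.unsqueeze(1).expand(batch_size, seq_len_q, seq_len_kv).unsqueeze(1).contiguous()
        elif attn_mask.ndim == 4 and attn_mask.shape[1:3] == (1, 1):
            attn_mask = attn_mask.expand(-1, -1, query.shape[1], -1).contiguous()

        # The fused NPU attention treats True as "mask this position out",
        # so the keep-mask (1 = attend) is inverted after broadcasting.
        attn_mask = ~attn_mask.to(torch.bool)

    return attn_mask


if __name__ == "__main__":
    B, Sq, Skv, D = 2, 4, 6, 8
    query = torch.randn(B, Sq, D)                  # [B, S_q, D]
    key = torch.randn(B, Skv, D)                   # [B, S_kv, D]

    mask_2d = torch.tensor([[1, 1, 1, 0, 0, 0],
                            [1, 1, 1, 1, 1, 0]])   # [B, S_kv], 1 = attend
    mask_4d = mask_2d[:, None, None, :]            # [B, 1, 1, S_kv], the newly supported shape

    out_2d = normalize_npu_attn_mask(query, key, mask_2d)
    out_4d = normalize_npu_attn_mask(query, key, mask_4d)
    print(out_2d.shape, out_4d.shape)              # both torch.Size([2, 1, 4, 6])
    assert torch.equal(out_2d, out_4d)             # both paths yield the same boolean mask

Both input shapes end up broadcast over the query dimension to [B, 1, S_q, S_kv] and inverted to a boolean "masked out" tensor, which is why the 4D [B, 1, 1, S] case only needs an expand rather than the unsqueeze/expand chain used for the 2D case.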
