diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 837d573d8c4d..b3bd55db48dd 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1521,17 +1521,16 @@ def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mas
     if attn_mask is not None and torch.all(attn_mask != 0):
         attn_mask = None
 
-    # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k]
+    # Reshape Attention Mask: [batch_size, seq_len_k] or [batch_size, 1, 1, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
     # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
-    if (
-        attn_mask is not None
-        and attn_mask.ndim == 2
-        and attn_mask.shape[0] == query.shape[0]
-        and attn_mask.shape[1] == key.shape[1]
-    ):
-        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+    if attn_mask is not None:
+        if attn_mask.ndim == 2 and attn_mask.shape[0] == query.shape[0] and attn_mask.shape[1] == key.shape[1]:
+            batch_size, seq_len_q, seq_len_kv = attn_mask.shape[0], query.shape[1], key.shape[1]
+            attn_mask = attn_mask.unsqueeze(1).expand(batch_size, seq_len_q, seq_len_kv).unsqueeze(1).contiguous()
+        elif attn_mask.ndim == 4 and attn_mask.shape[1:3] == (1, 1):
+            attn_mask = attn_mask.expand(-1, -1, query.shape[1], -1).contiguous()
+
         attn_mask = ~attn_mask.to(torch.bool)
-        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
 
     return attn_mask
 
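
The sketch below (not part of the diff) mirrors the two mask-normalization branches added above, so the shape handling can be checked outside the NPU path with plain PyTorch. The helper name normalize_attn_mask, the tensor sizes, and the [batch, seq, heads, dim] layout (implied by the query.shape[1] / key.shape[1] indexing in the diff) are illustrative assumptions.

import torch

def normalize_attn_mask(query, key, attn_mask):
    # Drop the mask entirely when nothing is padded.
    if attn_mask is not None and torch.all(attn_mask != 0):
        return None
    if attn_mask is None:
        return None
    if attn_mask.ndim == 2 and attn_mask.shape[0] == query.shape[0] and attn_mask.shape[1] == key.shape[1]:
        # [batch_size, seq_len_kv] -> [batch_size, 1, seq_len_q, seq_len_kv]
        b, sq, skv = attn_mask.shape[0], query.shape[1], key.shape[1]
        attn_mask = attn_mask.unsqueeze(1).expand(b, sq, skv).unsqueeze(1).contiguous()
    elif attn_mask.ndim == 4 and attn_mask.shape[1:3] == (1, 1):
        # [batch_size, 1, 1, seq_len_kv] -> [batch_size, 1, seq_len_q, seq_len_kv]
        attn_mask = attn_mask.expand(-1, -1, query.shape[1], -1).contiguous()
    # npu_fusion_attention expects True to mean "mask out", so invert the keep-mask.
    return ~attn_mask.to(torch.bool)

q = torch.randn(2, 5, 8, 64)  # [batch, seq_len_q, heads, head_dim] (assumed layout)
k = torch.randn(2, 7, 8, 64)  # [batch, seq_len_kv, heads, head_dim]
mask_2d = torch.tensor([[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]])
mask_4d = mask_2d[:, None, None, :]
print(normalize_attn_mask(q, k, mask_2d).shape)  # torch.Size([2, 1, 5, 7])
print(normalize_attn_mask(q, k, mask_4d).shape)  # torch.Size([2, 1, 5, 7])

Both input conventions collapse to the same [batch_size, 1, seq_len_q, seq_len_kv] boolean mask, which is what the fixed code hands to the fused attention kernel.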