@@ -1521,17 +1521,16 @@ def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mas
     if attn_mask is not None and torch.all(attn_mask != 0):
         attn_mask = None
 
-    # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
+    # Reshape Attention Mask: [batch_size, seq_len_k] or [batch_size, 1, 1, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
     # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
-    if (
-        attn_mask is not None
-        and attn_mask.ndim == 2
-        and attn_mask.shape[0] == query.shape[0]
-        and attn_mask.shape[1] == key.shape[1]
-    ):
-        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+    if attn_mask is not None:
+        if attn_mask.ndim == 2 and attn_mask.shape[0] == query.shape[0] and attn_mask.shape[1] == key.shape[1]:
+            batch_size, seq_len_q, seq_len_kv = attn_mask.shape[0], query.shape[1], key.shape[1]
+            attn_mask = attn_mask.unsqueeze(1).expand(batch_size, seq_len_q, seq_len_kv).unsqueeze(1).contiguous()
+        elif attn_mask.ndim == 4 and attn_mask.shape[1:3] == (1, 1):
+            attn_mask = attn_mask.expand(-1, -1, query.shape[1], -1).contiguous()
+
         attn_mask = ~attn_mask.to(torch.bool)
-        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
 
     return attn_mask
 
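For readers following along without an Ascend device, here is a minimal standalone sketch of the mask handling after this change, using plain `torch` only. The function name `normalize_attn_mask` and the toy tensor shapes are illustrative, not part of the commit; the body just mirrors the reshaping and inversion in the diff above.

```python
import torch


def normalize_attn_mask(query, key, attn_mask):
    """Illustrative mirror of the mask handling in the diff above."""
    # Drop the mask entirely when every position is attendable.
    if attn_mask is not None and torch.all(attn_mask != 0):
        attn_mask = None

    if attn_mask is not None:
        if attn_mask.ndim == 2 and attn_mask.shape[0] == query.shape[0] and attn_mask.shape[1] == key.shape[1]:
            # [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
            batch_size, seq_len_q, seq_len_kv = attn_mask.shape[0], query.shape[1], key.shape[1]
            attn_mask = attn_mask.unsqueeze(1).expand(batch_size, seq_len_q, seq_len_kv).unsqueeze(1).contiguous()
        elif attn_mask.ndim == 4 and attn_mask.shape[1:3] == (1, 1):
            # [batch_size, 1, 1, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
            attn_mask = attn_mask.expand(-1, -1, query.shape[1], -1).contiguous()

        # Invert: the incoming mask uses nonzero for "attend", while the fused
        # NPU kernel reads True as "mask this position out" (see the linked
        # npu_fusion_attention doc).
        attn_mask = ~attn_mask.to(torch.bool)

    return attn_mask


# Toy tensors laid out as [batch, seq_len, dim], so shape[1] is the sequence
# dimension, matching how the diff reads query.shape[1] and key.shape[1].
q, k = torch.randn(2, 5, 8), torch.randn(2, 7, 8)
pad_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0],
                         [1, 1, 1, 0, 0, 0, 0]])

print(normalize_attn_mask(q, k, pad_mask).shape)                   # torch.Size([2, 1, 5, 7])
print(normalize_attn_mask(q, k, pad_mask.view(2, 1, 1, 7)).shape)  # torch.Size([2, 1, 5, 7])
print(normalize_attn_mask(q, k, torch.ones(2, 7)))                 # None
```

Both input shapes land on the same materialized [batch_size, 1, seq_len_q, seq_len_k] boolean mask; the trailing `.contiguous()` calls realize the broadcasted copies before the tensor is handed to the fused kernel.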