diff --git a/litgpt/model.py b/litgpt/model.py index b60b0506b6..a66b592673 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -343,6 +343,12 @@ def scaled_dot_product_attention( scores = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype) y = scores @ v else: + # 修复 #2220: 在单步解码 (decoding) 阶段,q 的序列长度为 1, + # 此时不需要因果掩码,丢弃 mask 可以让 PyTorch 启用 Flash Attention 加速。 + if mask is not None and q.size(2) == 1: + mask = None + + # 注意这里!把 F. 换成了完整的 torch.nn.functional. y = torch.nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None )