Skip to content

Commit ae53b5b

Browse files
Guard max_logit in fused attention for cuDNN < 9.21.0
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
1 parent 720ec27 commit ae53b5b

2 files changed

Lines changed: 11 additions & 0 deletions

File tree

transformer_engine/common/fused_attn/fused_attn.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
455455
(cudnn_runtime_version >= 91301 ||
456456
(cudnn_runtime_version < 91301 &&
457457
softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) &&
458+
// max_logit
459+
// pre-9.21: no (the composite softmax node rejects the Stats + Max output combination)
460+
// 9.21+: yes (Stats + Max via the unified softmax node)
461+
(!return_max_logit || cudnn_runtime_version >= 92100) &&
458462
// determinism on Blackwell
459463
// pre-9.18.1: fwd: deterministic; bwd: non-deterministic
460464
// 9.18.1+: fwd: deterministic; bwd: non-deterministic/deterministic

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,13 @@ def _disable_all_flash_attention() -> None:
672672
if use_flash_attention:
673673
use_flash_attention = False
674674
logger.debug("Disabling FlashAttention for max_logit")
675+
# FusedAttention emits max_logit alongside the softmax stats, which cuDNN only
676+
# supports through the unified softmax node introduced in cuDNN 9.21.0. On older
677+
# cuDNN the composite softmax node rejects the stats+max combination, so fall back
678+
# to UnfusedDotProductAttention.
679+
if use_fused_attention and cudnn_version < (9, 21, 0):
680+
use_fused_attention = False
681+
logger.debug("Disabling FusedAttention for max_logit for cuDNN < 9.21.0")
675682
if fp8 and fp8_meta["recipe"].fp8_dpa:
676683
use_flash_attention = False
677684
use_fused_attention = False

0 commit comments

Comments
 (0)