File tree Expand file tree Collapse file tree
pytorch/attention/dot_product_attention Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -455,6 +455,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
455455 (cudnn_runtime_version >= 91301 ||
456456 (cudnn_runtime_version < 91301 &&
457457 softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX )) &&
458+ // max_logit
459+ // pre-9.21: no (the composite softmax node rejects the Stats + Max output combination)
460+ // 9.21+: yes (Stats + Max via the unified softmax node)
461+ (!return_max_logit || cudnn_runtime_version >= 92100 ) &&
458462 // determinism on Blackwell
459463 // pre-9.18.1: fwd: deterministic; bwd: non-deterministic
460464 // 9.18.1+: fwd: deterministic; bwd: non-deterministic/deterministic
Original file line number Diff line number Diff line change @@ -672,6 +672,13 @@ def _disable_all_flash_attention() -> None:
672672 if use_flash_attention :
673673 use_flash_attention = False
674674 logger .debug ("Disabling FlashAttention for max_logit" )
675+ # FusedAttention emits max_logit alongside the softmax stats, which cuDNN only
676+ # supports through the unified softmax node introduced in cuDNN 9.21.0. On older
677+ # cuDNN the composite softmax node rejects the stats+max combination, so fall back
678+ # to UnfusedDotProductAttention.
679+ if use_fused_attention and cudnn_version < (9 , 21 , 0 ):
680+ use_fused_attention = False
681+ logger .debug ("Disabling FusedAttention for max_logit for cuDNN < 9.21.0" )
675682 if fp8 and fp8_meta ["recipe" ].fp8_dpa :
676683 use_flash_attention = False
677684 use_fused_attention = False
You can’t perform that action at this time.
0 commit comments