Commit 9e55a25

[PyTorch] Fix FA4 selection when FA3 is unavailable. (#2909)
Fix FA4 selection when FA3 is unavailable.

Signed-off-by: Björn Buschkämper <bjoern.buschkaemper@gmail.com>
1 parent ab60f4c commit 9e55a25

1 file changed

Lines changed: 8 additions & 3 deletions

transformer_engine/pytorch/attention/dot_product_attention/utils.py

@@ -473,9 +473,14 @@ def get_attention_backend(
     # On SM90, prefer FA3 over FA4 when FA3 is available.
     # FA3 is more mature on Hopper; FA4's SM90 backward has limitations
     # (MLA, non-standard head dims, SplitKV).
-    if use_flash_attention_4 and use_flash_attention_3 and device_compute_capability == (9, 0):
-        if FlashAttentionUtils.v4_is_installed:
-            logger.debug("Disabling FlashAttention 4 to prefer FlashAttention 3 on SM90")
+    if (
+        device_compute_capability == (9, 0)
+        and use_flash_attention_3
+        and FlashAttentionUtils.v3_is_installed
+        and use_flash_attention_4
+        and FlashAttentionUtils.v4_is_installed
+    ):
+        logger.debug("Disabling FlashAttention 4 to prefer FlashAttention 3 on SM90")
         use_flash_attention_4 = False
 
     # Filter: Data type
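
The behavioral difference in the hunk above can be sketched as two standalone predicates. This is an illustrative model only, not code from Transformer Engine: the function names `fa4_flag_before` and `fa4_flag_after` and their parameters are hypothetical stand-ins for the local flags and the `FlashAttentionUtils.v3_is_installed` / `v4_is_installed` attributes. It also assumes, based on the commit title, that `use_flash_attention_3` can still be `True` at this point even when FA3 is not installed.

```python
SM90 = (9, 0)


def fa4_flag_before(cc, use_fa3, use_fa4, v3_installed, v4_installed):
    """Model of the pre-fix logic: returns the resulting use_flash_attention_4."""
    # The installation check only guarded the debug log, so FA4 was disabled
    # whenever both flags were set on SM90, whether or not FA3 was installed.
    if use_fa4 and use_fa3 and cc == SM90:
        return False
    return use_fa4


def fa4_flag_after(cc, use_fa3, use_fa4, v3_installed, v4_installed):
    """Model of the post-fix logic: returns the resulting use_flash_attention_4."""
    # FA4 is only disabled when FA3 is both enabled and installed (and FA4 is
    # enabled and installed as well), so an absent FA3 no longer blocks FA4.
    if cc == SM90 and use_fa3 and v3_installed and use_fa4 and v4_installed:
        return False
    return use_fa4


# FA3 flag set but FA3 not installed, FA4 installed, on SM90.
case = dict(cc=SM90, use_fa3=True, use_fa4=True, v3_installed=False, v4_installed=True)
before = fa4_flag_before(**case)
after = fa4_flag_after(**case)
print(before, after)
```

Under these assumptions, the two predicates differ only when FA3 is missing on SM90: the old form turns FA4 off anyway, while the new form leaves it enabled. When both packages are installed, both forms still prefer FA3.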
