File tree Expand file tree Collapse file tree
transformer_engine/pytorch/attention/dot_product_attention Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -473,9 +473,14 @@ def get_attention_backend(
473473 # On SM90, prefer FA3 over FA4 when FA3 is available.
474474 # FA3 is more mature on Hopper; FA4's SM90 backward has limitations
475475 # (MLA, non-standard head dims, SplitKV).
476- if use_flash_attention_4 and use_flash_attention_3 and device_compute_capability == (9 , 0 ):
477- if FlashAttentionUtils .v4_is_installed :
478- logger .debug ("Disabling FlashAttention 4 to prefer FlashAttention 3 on SM90" )
476+ if (
477+ device_compute_capability == (9 , 0 )
478+ and use_flash_attention_3
479+ and FlashAttentionUtils .v3_is_installed
480+ and use_flash_attention_4
481+ and FlashAttentionUtils .v4_is_installed
482+ ):
483+ logger .debug ("Disabling FlashAttention 4 to prefer FlashAttention 3 on SM90" )
479484 use_flash_attention_4 = False
480485
481486 # Filter: Data type
You can’t perform that action at this time.
0 commit comments