We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d5bf1af commit 717fb6eCopy full SHA for 717fb6e
1 file changed
transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -464,11 +464,7 @@ def get_attention_backend(
464
# On SM90, prefer FA3 over FA4 when FA3 is available.
465
# FA3 is more mature on Hopper; FA4's SM90 backward has limitations
466
# (MLA, non-standard head dims, SplitKV).
467
- if (
468
- use_flash_attention_4
469
- and use_flash_attention_3
470
- and device_compute_capability == (9, 0)
471
- ):
+ if use_flash_attention_4 and use_flash_attention_3 and device_compute_capability == (9, 0):
472
if FlashAttentionUtils.v4_is_installed:
473
logger.debug("Disabling FlashAttention 4 to prefer FlashAttention 3 on SM90")
474
use_flash_attention_4 = False
0 commit comments