Commit 4745f98

[PyTorch] Fix FA3 deterministic gate to match upstream backward constraint
The previous check disabled FA3 for deterministic mode whenever head_dim_qk > 128, which was overly conservative: FA3 forward supports deterministic execution at any head dim. The actual constraint from flash_api.cpp is that the backward pass does not support deterministic mode when max(head_size, head_size_v) >= 256.

Narrow the gate to only disable FA3 during training (backward) and raise the threshold to >= 256, checking both head_dim_qk and head_dim_v to handle MLA configs with asymmetric head dimensions.

Ref: https://github.com/Dao-AILab/flash-attention/blob/ac6f2eb5/hopper/flash_api.cpp#L1370

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
1 parent 34e3d62 commit 4745f98

1 file changed

Lines changed: 5 additions & 2 deletions

transformer_engine/pytorch/attention/dot_product_attention/utils.py

@@ -1315,9 +1315,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
             )
             use_flash_attention_2 = False
     if use_flash_attention_3 and deterministic and FlashAttentionUtils.v3_is_installed:
-        if head_dim_qk > 128:
+        if is_training and max(head_dim_qk, head_dim_v) >= 256:
             logger.debug(
-                "Disabling FlashAttention 3 for deterministic execution with head_dim_qk > 128."
+                "Disabling FlashAttention 3 for deterministic backward with"
+                " max(head_dim_qk, head_dim_v) >= 256. Found: head_dim_qk = %s, head_dim_v = %s.",
+                head_dim_qk,
+                head_dim_v,
             )
             use_flash_attention_3 = False
     if use_fused_attention and deterministic:
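
To make the behavioral change easier to check, the old and new gates can be written as standalone predicates. This is only an illustrative sketch: `fa3_allowed_old` and `fa3_allowed_new` are hypothetical helper names, not functions in Transformer Engine, and they model just the condition touched by this diff, assuming FA3 is installed and was otherwise selected.

```python
def fa3_allowed_old(deterministic: bool, head_dim_qk: int) -> bool:
    """Previous gate (hypothetical helper): reject any deterministic run
    with head_dim_qk > 128, regardless of training vs. inference."""
    return not (deterministic and head_dim_qk > 128)


def fa3_allowed_new(
    deterministic: bool, is_training: bool, head_dim_qk: int, head_dim_v: int
) -> bool:
    """Narrowed gate (hypothetical helper): reject only deterministic
    training, where the backward pass runs, when the larger of the two
    head dims is >= 256."""
    return not (
        deterministic and is_training and max(head_dim_qk, head_dim_v) >= 256
    )


if __name__ == "__main__":
    # (deterministic, is_training, head_dim_qk, head_dim_v)
    cases = [
        (True, True, 192, 128),   # asymmetric dims, both below 256
        (True, False, 256, 256),  # inference only, no backward pass
        (True, True, 256, 256),   # deterministic backward at head dim 256
        (False, True, 256, 256),  # non-deterministic, gate does not apply
    ]
    for det, train, d_qk, d_v in cases:
        old = fa3_allowed_old(det, d_qk)
        new = fa3_allowed_new(det, train, d_qk, d_v)
        print(f"det={det} train={train} qk={d_qk} v={d_v}: old={old} new={new}")
```

Under this sketch, a deterministic training run with head_dim_qk = 192 and head_dim_v = 128 was rejected by the old gate and is accepted by the new one, while deterministic training with either head dim at 256 or above is still rejected.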
