Commit 4745f98
[PyTorch] Fix FA3 deterministic gate to match upstream backward constraint

The previous check disabled FA3 in deterministic mode whenever
head_dim_qk > 128, which was overly conservative: the FA3 forward pass
supports deterministic execution at any head dim. The actual constraint
from flash_api.cpp is that the backward pass does not support
deterministic mode when max(head_size, head_size_v) >= 256.

Narrow the gate so that it disables FA3 only during training (backward),
raise the threshold to >= 256, and check both head_dim_qk and head_dim_v
to handle MLA configs with asymmetric head dimensions.

Ref: https://github.com/Dao-AILab/flash-attention/blob/ac6f2eb5/hopper/flash_api.cpp#L1370
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

1 parent 34e3d62 commit 4745f98
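
The gate change described in the commit message can be sketched roughly as follows. This is an illustration only, not the code from the patch: the function names old_fa3_gate and new_fa3_gate and the is_training flag are hypothetical, and only the thresholds and the head_dim_qk / head_dim_v names are taken from the message.

```python
def old_fa3_gate(deterministic: bool, head_dim_qk: int) -> bool:
    """Hypothetical sketch of the previous check: returns True if FA3 stays enabled."""
    # Old behavior per the commit message: any deterministic run with
    # head_dim_qk > 128 disabled FA3, forward-only runs included.
    return not (deterministic and head_dim_qk > 128)


def new_fa3_gate(
    deterministic: bool,
    is_training: bool,
    head_dim_qk: int,
    head_dim_v: int,
) -> bool:
    """Hypothetical sketch of the narrowed check: returns True if FA3 stays enabled."""
    # Only the backward pass is constrained, so forward-only (inference)
    # runs keep FA3 regardless of head dim.
    backward_needed = is_training
    # Take the max of both head dims so asymmetric (e.g. MLA) configs are covered.
    too_large = max(head_dim_qk, head_dim_v) >= 256
    return not (deterministic and backward_needed and too_large)
```

Under these assumptions, a deterministic training run with head_dim_qk = 192 and head_dim_v = 128 was rejected by the old check but passes the new one, while a run with either head dim at 256 is still rejected during training.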
1 file changed
Lines changed: 5 additions & 2 deletions