Commit 80ea313
[PyTorch] Add
* [PyTorch] Add pad_between_seqs support for FlashAttention 3 with CP
Add support for padding between sequences (pad_between_seqs) in the
FlashAttention 3 backend when used with context parallelism (CP).
Key changes:
- backends.py: Pass fa_pad_between_seqs through to FA3 forward/backward
- context_parallel.py: Handle pad_between_seqs in A2A and P2P CP paths,
zero FA3 padding garbage in CP forward, fix a2a backward alignment
- dot_product_attention.py: Auto-detect pad_between_seqs from cu_seqlens
- utils.py: Gate FA3 deterministic backward for hdim>=256, fix
flash_attn_supported override for cross-attention and large head_dim,
disable UnfusedDotProductAttention for pad_between_seqs, add SM100+
FA3 skip
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* [PyTorch] Add pad_between_seqs tests for CP and non-CP FlashAttention
Add test parametrization for pad_between_seqs in flash attention tests.
Update run_attention_with_cp.py to support the new parameter and fix
batch boundary alignment in the non-CP FA3 path. Run tests in parallel
when multiple GPUs are available.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* [QA] Add CP deterministic tests to L3 and support TE_PATH in FA test
Add deterministic CP test runs to L3 FA versions test. Support TE_PATH
positional arg and fix GPU threshold for parallel test execution.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* [PyTorch] Fix FA3 deterministic gate to match upstream backward constraint
The previous check disabled FA3 for deterministic mode whenever
head_dim_qk > 128, which was overly conservative — FA3 forward supports
deterministic execution at any head dim. The actual constraint from
flash_api.cpp is that the backward pass does not support deterministic
mode when max(head_size, head_size_v) >= 256.
Narrow the gate to only disable FA3 during training (backward) and
raise the threshold to >= 256, checking both head_dim_qk and head_dim_v
to handle MLA configs with asymmetric head dimensions.
Ref: https://github.com/Dao-AILab/flash-attention/blob/ac6f2eb5/hopper/flash_api.cpp#L1370
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* [PyTorch] Disable FlashAttention 4 for pad_between_seqs with THD
The pad_between_seqs gate in get_attention_backend only disabled
FlashAttention 2, letting FA4 leak through to the test-time
fused-vs-flash comparison. On B200 runners that install flash-attn-4,
this caused test_dpa_qkv_layout_thd to compare FusedAttention against
an FA4 output whose padded positions contain garbage, producing 48
numerics failures in L3_pytorch_FA_versions_test--B200_1GPU.
The log message already claimed FA4 would be disabled — this change
makes the code match the message: set use_flash_attention_4 = False
alongside use_flash_attention_2 when pad_between_seqs is True. FA3
continues to support pad_between_seqs via seqused_q/seqused_k.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* [QA] Fix cutlass-dsl utils shadow in FA versions test
FA4 install brings in nvidia-cutlass-dsl, whose `import cutlass`
adds cutlass/base_dsl/ to sys.path. That directory contains a utils/
package that shadows tests/pytorch/utils.py, breaking collection of
test_attention_with_cp.py with:
ImportError: cannot import name 'ModelConfig' from 'utils'
Prepend $TE_PATH/tests/pytorch to PYTHONPATH so the local utils.py
is always resolved first, regardless of what FA4 dependencies install.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* skip tests which OOM in deterministic+backward+hopper+large_configs as its a known cudnn issue
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* make cp det and nondet tests run in parallel whenever possible
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* [QA] L3: gate CP tests per-arch to avoid CI timeout
PR 2596 added deterministic CP runs to the L3 FA-versions matrix, multiplying
CP wall time across every FA version and causing CI timeouts (pipeline
50243000). Run CP tests once per arch instead, picking the FA version each
arch's CP code path actually supports:
- sm90 (H100): FA3 3.0.0b1 - context_parallel.py is FA3-only on Hopper
(use_flash_attn_3 threaded throughout, FA4
not wired in; pad_between_seqs gated on
use_flash_attn_3 at lines 1038, 1366)
- sm>90 (B200): latest FA4 - FA3 is not built/installed for sm>90
Non-CP test_attention.py still runs for every FA version in the array.
Also drop FA 2.7.3 from the sm90 list (no longer maintained as a target)
and bump the FA4 pin from 4.0.0b8 to 4.0.0b11. b8 has an SM90 backward
kernel bug fixed by upstream PR #2513 in b11
(get_smem_store_C() got multiple values for argument 'transpose').
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* [QA] L3: skip pre-installed FA3 build, per-FA junit XMLs
Three follow-ups on top of 13ba004 (L3 per-arch CP gating):
1. Skip the inline FA3 source build when flash_attn_interface is already
importable. This makes the script a no-op on FA3 install when the base
image has FA3 baked in (companion to TE !573 on te_ci, which auto-sets
INSTALL_FA3=${RUN_L3_TESTS} so FA3 is preinstalled for L3 pipelines).
Saves ~20 min of L3 H100 wall time once both land. Falls back to the
existing inline build when FA3 is not pre-installed.
2. Suffix junit XMLs with the FA version (pytest_test_attention_fa2_8_3.xml
etc.) so per-iteration results are preserved instead of overwritten.
Pipeline 50348672 had no per-FA timing visibility because pytest.xml
was clobbered by each loop iteration.
3. Include FA version in test_fail messages so CI dashboards show which
FA iteration caused a failure (was "test_attention.py", now
"test_attention.py (FA 2.8.3)").
Also fold the CP_FA_VERSION assignment into the same if-block as
FA_versions (was a separate if-block immediately after) since the two
are arch-keyed in lockstep.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* b200 shouldnt run FA3 even if present
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* L3: drop stale RUN_L3_TESTS=1 note; use flash_attn_3 for FA3 check
Address two pending review comments:
1. The "auto-set when RUN_L3_TESTS=1" annotation on the base-image FA3
preinstall is no longer accurate; drop it so readers don't grep for a
coupling that doesn't exist.
2. `flash_attn_interface` reads like a generic FA API even though the
top-level shim is only created by the FA3 install. Switching to
`import flash_attn_3` makes the FA3-specific intent unambiguous and
matches the FA3 package layout produced by the source build.
Local validation on H100 (sm90) with FA3 active, TE worktree resolving
to the editable install (verified via three-layer import check from
/tmp): test_attention_with_cp.py parallel det+nondet — 45 passed / 0
failed nondet (3:52), 33 passed / 0 failed det (2:55). 33 pad-True
nondet passes + 21 pad-True det passes confirm the FA3+THD+CP path is
exercised; 5 det OOM cases skip cleanly via the existing inline guard.
Same test scope is exercised by L1_pytorch_distributed_unittest
(parallel det+nondet) and the FA3 iteration of L3_pytorch_FA_versions_test;
the changes here are L3-only documentation/detection tweaks and do not
alter the Python test code, but the L1+L3 CP execution was re-run on
the cleaned PR head end-to-end as proof.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* Address review nits: bHSS-gated OOM skip; drop Dockerfile.base specifics
1. Det FusedAttention backward THD/sm90 OOM skip: gate on the actual
memory pressure (b*H*S*S) instead of num_heads >= 20. The cuDNN
workspace is proportional to bHSS, so a future config with H >= 20
but small b or S would be needlessly skipped under the old guard,
while a config with H < 20 but large b*S that hit the same OOM
wouldn't be caught. Threshold 1e9 empirically matches the existing
5-case skip set on the test_essential fused subset (cp_2_0, cp_2_2,
cp_3_1, cp_4_2, cp_4_3 — bHSS in 1.07B–4.29B) and lets cp_1_0/
cp_2_1/cp_2_4/cp_3_2/cp_3_4 (bHSS ~0.40B) keep running.
2. L3 FA3 install comment: drop the "Dockerfile.base INSTALL_FA3=1"
reference. The detection check is the contract; mentioning a
specific image variable couples this script to an out-of-tree
provisioning detail that may evolve independently.
Local validation on H100 (sm90) with FA3 active and TE worktree
resolving to editable (verified via /tmp-cwd three-layer import check
after reinstall — the /usr/local TE shadow had reappeared between
sessions): test_attention_with_cp.py parallel det+nondet — 45 passed /
0 failed nondet (4:09), 33 passed / 0 failed det (3:14). 33 pad-True
nondet passes + 21 pad-True det passes; 5 det OOM cases skip via the
new bHSS gate — same cases as the old num_heads-only gate.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Name the OOM-skip threshold and explain the 128*bHSS workspace observation
Address review nits on the deterministic THD-backward OOM guard:
1. Replace the magic number 1_000_000_000 with the named constant
SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30, so the value is searchable
and labeled.
2. Replace the prefatory comment with a short note tying the number to
cuDNN's actual workspace request (~128 * bHSS bytes, measured on
cuDNN 9.21.0 sm90 — see local sweep). At bHSS = 1<<30 the request is
128 GiB, which doesn't fit on H100's 80 GB.
3. Flag the b>=3 caveat for future readers: cuDNN rounds the batch up
internally so workspace grows super-linearly past b=2 (b=4 asks for
4x the b=2 workspace, not 2x). The current fused-essential matrix is
all b=2, so the threshold stays correct for what the test exercises;
the note is there so the next person doesn't have to rediscover it.
Skip set is unchanged — cp_2_0, cp_2_1, cp_3_1, cp_4_2, cp_4_3.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
* Reword OOM-skip comment as observations, not cuDNN-internal claims
We measured the workspace request from outside cuDNN, so the comment should
say "observed" rather than asserting what cuDNN does. Reframes the ~128 *
bHSS bytes formula and the super-linear b>=3 behavior as empirical
observations from our sweep.
No code change.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen) (#2596)1 parent dc9af4a commit 80ea313
9 files changed
Lines changed: 371 additions & 107 deletions
File tree
- qa
- L1_pytorch_distributed_unittest
- L3_pytorch_FA_versions_test
- tests/pytorch/attention
- transformer_engine/pytorch/attention/dot_product_attention
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
25 | 43 | | |
26 | 44 | | |
27 | 45 | | |
28 | 46 | | |
29 | 47 | | |
30 | 48 | | |
31 | 49 | | |
32 | | - | |
33 | 50 | | |
34 | 51 | | |
35 | 52 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2 | 2 | | |
3 | 3 | | |
4 | 4 | | |
5 | | - | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
6 | 18 | | |
7 | 19 | | |
8 | 20 | | |
9 | 21 | | |
10 | 22 | | |
11 | | - | |
| 23 | + | |
12 | 24 | | |
13 | 25 | | |
14 | 26 | | |
15 | 27 | | |
16 | 28 | | |
17 | 29 | | |
18 | 30 | | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
19 | 35 | | |
20 | 36 | | |
21 | | - | |
| 37 | + | |
| 38 | + | |
22 | 39 | | |
23 | 40 | | |
24 | | - | |
| 41 | + | |
| 42 | + | |
25 | 43 | | |
26 | 44 | | |
27 | 45 | | |
| |||
35 | 53 | | |
36 | 54 | | |
37 | 55 | | |
38 | | - | |
39 | | - | |
40 | | - | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
41 | 64 | | |
42 | 65 | | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
43 | 69 | | |
44 | | - | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
45 | 91 | | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
46 | 108 | | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
47 | 47 | | |
48 | 48 | | |
49 | 49 | | |
| 50 | + | |
50 | 51 | | |
51 | 52 | | |
52 | 53 | | |
| |||
115 | 116 | | |
116 | 117 | | |
117 | 118 | | |
118 | | - | |
119 | | - | |
120 | | - | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
121 | 125 | | |
122 | 126 | | |
123 | 127 | | |
| |||
196 | 200 | | |
197 | 201 | | |
198 | 202 | | |
| 203 | + | |
199 | 204 | | |
200 | 205 | | |
201 | 206 | | |
| |||
314 | 319 | | |
315 | 320 | | |
316 | 321 | | |
317 | | - | |
| 322 | + | |
318 | 323 | | |
319 | 324 | | |
320 | 325 | | |
| |||
557 | 562 | | |
558 | 563 | | |
559 | 564 | | |
560 | | - | |
| 565 | + | |
561 | 566 | | |
562 | 567 | | |
563 | | - | |
564 | | - | |
| 568 | + | |
| 569 | + | |
565 | 570 | | |
566 | 571 | | |
567 | 572 | | |
| |||
617 | 622 | | |
618 | 623 | | |
619 | 624 | | |
620 | | - | |
621 | 625 | | |
622 | 626 | | |
623 | 627 | | |
624 | 628 | | |
625 | | - | |
626 | | - | |
627 | | - | |
628 | | - | |
629 | | - | |
630 | | - | |
631 | | - | |
632 | | - | |
633 | | - | |
634 | | - | |
635 | | - | |
636 | | - | |
637 | | - | |
638 | | - | |
639 | | - | |
640 | | - | |
| 629 | + | |
| 630 | + | |
| 631 | + | |
641 | 632 | | |
642 | 633 | | |
643 | 634 | | |
644 | 635 | | |
645 | | - | |
646 | | - | |
647 | | - | |
648 | | - | |
649 | | - | |
650 | | - | |
651 | | - | |
652 | | - | |
653 | | - | |
654 | | - | |
655 | | - | |
656 | | - | |
657 | | - | |
658 | | - | |
659 | | - | |
| 636 | + | |
| 637 | + | |
| 638 | + | |
| 639 | + | |
| 640 | + | |
| 641 | + | |
| 642 | + | |
| 643 | + | |
| 644 | + | |
| 645 | + | |
| 646 | + | |
| 647 | + | |
| 648 | + | |
| 649 | + | |
| 650 | + | |
| 651 | + | |
| 652 | + | |
| 653 | + | |
| 654 | + | |
| 655 | + | |
| 656 | + | |
| 657 | + | |
| 658 | + | |
| 659 | + | |
| 660 | + | |
| 661 | + | |
| 662 | + | |
| 663 | + | |
| 664 | + | |
| 665 | + | |
| 666 | + | |
| 667 | + | |
| 668 | + | |
| 669 | + | |
660 | 670 | | |
| 671 | + | |
| 672 | + | |
| 673 | + | |
| 674 | + | |
| 675 | + | |
| 676 | + | |
| 677 | + | |
661 | 678 | | |
662 | | - | |
663 | 679 | | |
664 | 680 | | |
665 | 681 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
124 | 124 | | |
125 | 125 | | |
126 | 126 | | |
127 | | - | |
| 127 | + | |
128 | 128 | | |
129 | 129 | | |
130 | 130 | | |
| |||
157 | 157 | | |
158 | 158 | | |
159 | 159 | | |
| 160 | + | |
| 161 | + | |
160 | 162 | | |
161 | 163 | | |
162 | 164 | | |
| |||
195 | 197 | | |
196 | 198 | | |
197 | 199 | | |
198 | | - | |
199 | | - | |
200 | | - | |
201 | | - | |
202 | | - | |
203 | | - | |
204 | | - | |
205 | | - | |
206 | | - | |
207 | | - | |
208 | | - | |
209 | | - | |
210 | | - | |
211 | 200 | | |
212 | 201 | | |
213 | 202 | | |
| |||
1301 | 1290 | | |
1302 | 1291 | | |
1303 | 1292 | | |
1304 | | - | |
| 1293 | + | |
1305 | 1294 | | |
1306 | 1295 | | |
1307 | 1296 | | |
1308 | 1297 | | |
1309 | | - | |
| 1298 | + | |
1310 | 1299 | | |
1311 | 1300 | | |
1312 | 1301 | | |
| |||
1322 | 1311 | | |
1323 | 1312 | | |
1324 | 1313 | | |
1325 | | - | |
1326 | | - | |
| 1314 | + | |
| 1315 | + | |
| 1316 | + | |
| 1317 | + | |
| 1318 | + | |
| 1319 | + | |
1327 | 1320 | | |
1328 | 1321 | | |
1329 | 1322 | | |
1330 | 1323 | | |
1331 | 1324 | | |
1332 | 1325 | | |
| 1326 | + | |
1333 | 1327 | | |
1334 | 1328 | | |
1335 | 1329 | | |
| |||
1343 | 1337 | | |
1344 | 1338 | | |
1345 | 1339 | | |
1346 | | - | |
| 1340 | + | |
1347 | 1341 | | |
1348 | 1342 | | |
1349 | 1343 | | |
1350 | 1344 | | |
1351 | | - | |
| 1345 | + | |
1352 | 1346 | | |
1353 | 1347 | | |
1354 | 1348 | | |
| |||
0 commit comments