
Commit 80ea313

[PyTorch] Add pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen) (#2596)
* [PyTorch] Add pad_between_seqs support for FlashAttention 3 with CP

  Add support for padding between sequences (pad_between_seqs) in the FlashAttention 3 backend when used with context parallelism (CP).

  Key changes:
  - backends.py: Pass fa_pad_between_seqs through to FA3 forward/backward
  - context_parallel.py: Handle pad_between_seqs in A2A and P2P CP paths, zero FA3 padding garbage in CP forward, fix a2a backward alignment
  - dot_product_attention.py: Auto-detect pad_between_seqs from cu_seqlens
  - utils.py: Gate FA3 deterministic backward for hdim>=256, fix flash_attn_supported override for cross-attention and large head_dim, disable UnfusedDotProductAttention for pad_between_seqs, add SM100+ FA3 skip

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [PyTorch] Add pad_between_seqs tests for CP and non-CP FlashAttention

  Add test parametrization for pad_between_seqs in flash attention tests. Update run_attention_with_cp.py to support the new parameter and fix batch boundary alignment in the non-CP FA3 path. Run tests in parallel when multiple GPUs are available.

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [QA] Add CP deterministic tests to L3 and support TE_PATH in FA test

  Add deterministic CP test runs to L3 FA versions test. Support TE_PATH positional arg and fix GPU threshold for parallel test execution.

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [PyTorch] Fix FA3 deterministic gate to match upstream backward constraint

  The previous check disabled FA3 for deterministic mode whenever head_dim_qk > 128, which was overly conservative: FA3 forward supports deterministic execution at any head dim. The actual constraint from flash_api.cpp is that the backward pass does not support deterministic mode when max(head_size, head_size_v) >= 256.

  Narrow the gate to only disable FA3 during training (backward) and raise the threshold to >= 256, checking both head_dim_qk and head_dim_v to handle MLA configs with asymmetric head dimensions.

  Ref: https://github.com/Dao-AILab/flash-attention/blob/ac6f2eb5/hopper/flash_api.cpp#L1370

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [PyTorch] Disable FlashAttention 4 for pad_between_seqs with THD

  The pad_between_seqs gate in get_attention_backend only disabled FlashAttention 2, letting FA4 leak through to the test-time fused-vs-flash comparison. On B200 runners that install flash-attn-4, this caused test_dpa_qkv_layout_thd to compare FusedAttention against an FA4 output whose padded positions contain garbage, producing 48 numerics failures in L3_pytorch_FA_versions_test--B200_1GPU.

  The log message already claimed FA4 would be disabled; this change makes the code match the message: set use_flash_attention_4 = False alongside use_flash_attention_2 when pad_between_seqs is True. FA3 continues to support pad_between_seqs via seqused_q/seqused_k.

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [QA] Fix cutlass-dsl utils shadow in FA versions test

  FA4 install brings in nvidia-cutlass-dsl, whose `import cutlass` adds cutlass/base_dsl/ to sys.path. That directory contains a utils/ package that shadows tests/pytorch/utils.py, breaking collection of test_attention_with_cp.py with:

    ImportError: cannot import name 'ModelConfig' from 'utils'

  Prepend $TE_PATH/tests/pytorch to PYTHONPATH so the local utils.py is always resolved first, regardless of what FA4 dependencies install.

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* skip tests which OOM in deterministic+backward+hopper+large_configs as it's a known cuDNN issue

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* make cp det and nondet tests run in parallel whenever possible

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [QA] L3: gate CP tests per-arch to avoid CI timeout

  PR 2596 added deterministic CP runs to the L3 FA-versions matrix, multiplying CP wall time across every FA version and causing CI timeouts (pipeline 50243000). Run CP tests once per arch instead, picking the FA version each arch's CP code path actually supports:
  - sm90 (H100): FA3 3.0.0b1 - context_parallel.py is FA3-only on Hopper (use_flash_attn_3 threaded throughout, FA4 not wired in; pad_between_seqs gated on use_flash_attn_3 at lines 1038, 1366)
  - sm>90 (B200): latest FA4 - FA3 is not built/installed for sm>90

  Non-CP test_attention.py still runs for every FA version in the array.

  Also drop FA 2.7.3 from the sm90 list (no longer maintained as a target) and bump the FA4 pin from 4.0.0b8 to 4.0.0b11. b8 has an SM90 backward kernel bug fixed by upstream PR #2513 in b11 (get_smem_store_C() got multiple values for argument 'transpose').

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [QA] L3: skip pre-installed FA3 build, per-FA junit XMLs

  Three follow-ups on top of 13ba004 (L3 per-arch CP gating):

  1. Skip the inline FA3 source build when flash_attn_interface is already importable. This makes the script a no-op on FA3 install when the base image has FA3 baked in (companion to TE !573 on te_ci, which auto-sets INSTALL_FA3=${RUN_L3_TESTS} so FA3 is preinstalled for L3 pipelines). Saves ~20 min of L3 H100 wall time once both land. Falls back to the existing inline build when FA3 is not pre-installed.
  2. Suffix junit XMLs with the FA version (pytest_test_attention_fa2_8_3.xml etc.) so per-iteration results are preserved instead of overwritten. Pipeline 50348672 had no per-FA timing visibility because pytest.xml was clobbered by each loop iteration.
  3. Include FA version in test_fail messages so CI dashboards show which FA iteration caused a failure (was "test_attention.py", now "test_attention.py (FA 2.8.3)").

  Also fold the CP_FA_VERSION assignment into the same if-block as FA_versions (was a separate if-block immediately after) since the two are arch-keyed in lockstep.

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* B200 shouldn't run FA3 even if present

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* L3: drop stale RUN_L3_TESTS=1 note; use flash_attn_3 for FA3 check

  Address two pending review comments:

  1. The "auto-set when RUN_L3_TESTS=1" annotation on the base-image FA3 preinstall is no longer accurate; drop it so readers don't grep for a coupling that doesn't exist.
  2. `flash_attn_interface` reads like a generic FA API even though the top-level shim is only created by the FA3 install. Switching to `import flash_attn_3` makes the FA3-specific intent unambiguous and matches the FA3 package layout produced by the source build.

  Local validation on H100 (sm90) with FA3 active, TE worktree resolving to the editable install (verified via three-layer import check from /tmp): test_attention_with_cp.py parallel det+nondet: 45 passed / 0 failed nondet (3:52), 33 passed / 0 failed det (2:55). 33 pad-True nondet passes + 21 pad-True det passes confirm the FA3+THD+CP path is exercised; 5 det OOM cases skip cleanly via the existing inline guard.

  Same test scope is exercised by L1_pytorch_distributed_unittest (parallel det+nondet) and the FA3 iteration of L3_pytorch_FA_versions_test; the changes here are L3-only documentation/detection tweaks and do not alter the Python test code, but the L1+L3 CP execution was re-run on the cleaned PR head end-to-end as proof.

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Address review nits: bHSS-gated OOM skip; drop Dockerfile.base specifics

  1. Det FusedAttention backward THD/sm90 OOM skip: gate on the actual memory pressure (b*H*S*S) instead of num_heads >= 20. The cuDNN workspace is proportional to bHSS, so a future config with H >= 20 but small b or S would be needlessly skipped under the old guard, while a config with H < 20 but large b*S that hit the same OOM wouldn't be caught. Threshold 1e9 empirically matches the existing 5-case skip set on the test_essential fused subset (cp_2_0, cp_2_2, cp_3_1, cp_4_2, cp_4_3; bHSS in 1.07B–4.29B) and lets cp_1_0/cp_2_1/cp_2_4/cp_3_2/cp_3_4 (bHSS ~0.40B) keep running.
  2. L3 FA3 install comment: drop the "Dockerfile.base INSTALL_FA3=1" reference. The detection check is the contract; mentioning a specific image variable couples this script to an out-of-tree provisioning detail that may evolve independently.

  Local validation on H100 (sm90) with FA3 active and TE worktree resolving to editable (verified via /tmp-cwd three-layer import check after reinstall; the /usr/local TE shadow had reappeared between sessions): test_attention_with_cp.py parallel det+nondet: 45 passed / 0 failed nondet (4:09), 33 passed / 0 failed det (3:14). 33 pad-True nondet passes + 21 pad-True det passes; 5 det OOM cases skip via the new bHSS gate, the same cases as the old num_heads-only gate.

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Name the OOM-skip threshold and explain the 128*bHSS workspace observation

  Address review nits on the deterministic THD-backward OOM guard:

  1. Replace the magic number 1_000_000_000 with the named constant SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30, so the value is searchable and labeled.
  2. Replace the prefatory comment with a short note tying the number to cuDNN's actual workspace request (~128 * bHSS bytes, measured on cuDNN 9.21.0 sm90; see local sweep). At bHSS = 1<<30 the request is 128 GiB, which doesn't fit on H100's 80 GB.
  3. Flag the b>=3 caveat for future readers: cuDNN rounds the batch up internally so workspace grows super-linearly past b=2 (b=4 asks for 4x the b=2 workspace, not 2x). The current fused-essential matrix is all b=2, so the threshold stays correct for what the test exercises; the note is there so the next person doesn't have to rediscover it.

  Skip set is unchanged: cp_2_0, cp_2_2, cp_3_1, cp_4_2, cp_4_3.

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Reword OOM-skip comment as observations, not cuDNN-internal claims

  We measured the workspace request from outside cuDNN, so the comment should say "observed" rather than asserting what cuDNN does. Reframes the ~128 * bHSS bytes formula and the super-linear b>=3 behavior as empirical observations from our sweep. No code change.

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

---------

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
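The arithmetic behind the bHSS gate can be checked in a few lines. The sketch below is illustrative only: it assumes the observed relation of roughly 128 bytes of workspace per unit of b*H*S*S quoted in the commit message, assumes the gate comparison is inclusive, and uses hypothetical helper names and example shapes that are not taken from the test suite.

```python
# Named threshold from the commit message.
SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30
# Observed from outside cuDNN (an empirical figure, not a documented one).
OBSERVED_BYTES_PER_BHSS = 128


def bhss(batch: int, heads: int, seqlen: int) -> int:
    """Memory-pressure proxy b * H * S * S used by the gate."""
    return batch * heads * seqlen * seqlen


def estimated_workspace_gib(batch: int, heads: int, seqlen: int) -> float:
    """Estimated workspace in GiB; likely an underestimate for batch >= 3,
    where the commit reports super-linear growth."""
    return OBSERVED_BYTES_PER_BHSS * bhss(batch, heads, seqlen) / (1 << 30)


def should_skip(batch: int, heads: int, seqlen: int) -> bool:
    """Assumption: the gate is inclusive (bHSS >= threshold is skipped)."""
    return bhss(batch, heads, seqlen) >= SM90_DET_FUSED_THD_BWD_MAX_BHSS


# Hypothetical shapes chosen to land on either side of the threshold.
print(estimated_workspace_gib(2, 32, 4096), should_skip(2, 32, 4096))
print(estimated_workspace_gib(2, 12, 4096), should_skip(2, 12, 4096))
```

Gating on the product rather than on the head count alone means a config with many heads but short sequences is not skipped needlessly.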
1 parent dc9af4a commit 80ea313
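The narrowed FA3 deterministic gate described in the commit message reduces to a small predicate. This is a minimal sketch of the stated rule, not the code in utils.py; the function name and argument names are hypothetical.

```python
def fa3_allowed_under_determinism(
    head_dim_qk: int,
    head_dim_v: int,
    is_training: bool,
    deterministic: bool,
) -> bool:
    """Return False only when FA3 would need a deterministic backward pass
    with max(head_dim_qk, head_dim_v) >= 256."""
    if not deterministic:
        return True  # no determinism requested, nothing to gate
    if not is_training:
        return True  # forward-only: deterministic at any head dim
    return max(head_dim_qk, head_dim_v) < 256


# Asymmetric (MLA-style) head dims are covered by taking the max of both.
print(fa3_allowed_under_determinism(192, 256, is_training=True, deterministic=True))
```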

9 files changed

Lines changed: 371 additions & 107 deletions


qa/L1_pytorch_distributed_unittest/test.sh

Lines changed: 18 additions & 1 deletion
@@ -22,14 +22,31 @@ mkdir -p "$XML_LOG_DIR"
 
 pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 
+# Run CP tests (deterministic + non-deterministic) first so they can be parallelized.
+# Each needs 4 GPUs, so >=8 GPUs allows them to run concurrently on disjoint GPU sets.
+NUM_GPUS=$(python3 -c "import torch; print(torch.cuda.device_count())")
+echo "Detected $NUM_GPUS GPU(s)"
+if [ "$NUM_GPUS" -ge 8 ]; then
+    echo "Running CP tests in parallel: non-deterministic on GPUs 0-3, deterministic on GPUs 4-7"
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py &
+    PID_CP_NONDET=$!
+    CUDA_VISIBLE_DEVICES=4,5,6,7 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py &
+    PID_CP_DET=$!
+    wait $PID_CP_NONDET || test_fail "test_attention_with_cp.py"
+    wait $PID_CP_DET || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py"
+else
+    echo "Running CP tests sequentially: need >=8 GPUs for parallel execution"
+    python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
+    NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py"
+fi
+
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py"
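The GPU split used by this script for the two CP runs can be restated as a small planning function. It is a sketch of the decision only (it launches nothing), and the function name and return shape are hypothetical.

```python
def plan_cp_runs(num_gpus: int) -> dict:
    """Decide how the non-deterministic and deterministic CP runs are placed.

    With 8 or more GPUs each run gets its own disjoint set of 4 devices and
    both can run concurrently; otherwise they run one after the other on
    whatever devices are visible.
    """
    nondet_env = {}
    det_env = {"NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0"}
    parallel = num_gpus >= 8
    if parallel:
        nondet_env["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
        det_env["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
    return {
        "parallel": parallel,
        "runs": [("non-deterministic", nondet_env), ("deterministic", det_env)],
    }


print(plan_cp_runs(8))
print(plan_cp_runs(4))
```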

qa/L3_pytorch_FA_versions_test/test.sh

Lines changed: 77 additions & 8 deletions
@@ -2,26 +2,44 @@
 #
 # See LICENSE for license information.
 
-set -e
+function error_exit() {
+    echo "Error: $1"
+    exit 1
+}
+
+function test_fail() {
+    RET=1
+    FAILED_CASES="$FAILED_CASES $1"
+    echo "Error: sub-test failed: $1"
+}
+
+RET=0
+FAILED_CASES=""
 
 : ${TE_PATH:=/opt/transformerengine}
 : ${XML_LOG_DIR:=/logs}
 mkdir -p "$XML_LOG_DIR"
 
-pip3 install pytest==8.2.1
+pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 
 # Limit parallel build jobs to avoid overwhelming system resources
 export MAX_JOBS=32
 
 # Iterate over Flash Attention versions
 sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"`
 export FLASH_ATTN_CUDA_ARCHS=$sm_arch
+# CP tests are expensive and run only once per arch:
+# - sm90 (H100): FA3 (3.0.0b1) - context_parallel.py only supports FA3 on Hopper
+# - sm>90 (B200): latest FA4 - FA3 is not built/installed for sm>90
+# Non-CP tests still run for every FA version in the array.
 if [ $sm_arch -gt 90 ]
 then
-    FA_versions=(2.8.3 4.0.0b8)
+    FA_versions=(2.8.3 4.0.0b11)
+    CP_FA_VERSION="${FA_versions[-1]}"
 elif [ $sm_arch -eq 90 ]
 then
-    FA_versions=(2.7.3 2.8.3 3.0.0b1 4.0.0b8)
+    FA_versions=(2.8.3 3.0.0b1 4.0.0b11)
+    CP_FA_VERSION="3.0.0b1"
 fi
 
 for fa_version in "${FA_versions[@]}"
@@ -35,12 +53,63 @@ do
     then
         pip3 install flash-attn-4==${fa_version} nvidia-cutlass-dsl[cu13]==4.4.2 --no-build-isolation
     else
-        git clone https://github.com/Dao-AILab/flash-attention.git
-        cd flash-attention/hopper && python setup.py install
-        cd ../../
+        # FA3 source build (~20 min). Skip if FA3 is already installed.
+        if python3 -c "import flash_attn_3" 2>/dev/null; then
+            echo "FA3 already installed (from base image); skipping source build"
+        else
+            git clone https://github.com/Dao-AILab/flash-attention.git
+            cd flash-attention/hopper && python setup.py install
+            cd ../../
+        fi
     fi
 
+    # Ensure local test utils is found before nvidia-cutlass-dsl's utils package
+    export PYTHONPATH=$TE_PATH/tests/pytorch:${PYTHONPATH:-}
+
     # Run tests
-    NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py
+    NUM_GPUS=$(nvidia-smi -L | wc -l)
+    echo "Detected $NUM_GPUS GPU(s)"
+
+    # Suffix junit XMLs with the FA version so per-iteration results are preserved
+    # (otherwise pytest.xml is overwritten on each loop iteration and we lose timing
+    # data for all but the last FA version).
+    fa_tag="${fa_version//./_}"
+    XML_ATTN="$XML_LOG_DIR/pytest_test_attention_fa${fa_tag}.xml"
+    XML_CP="$XML_LOG_DIR/pytest_test_attention_with_cp_fa${fa_tag}.xml"
+
+    if [ "$fa_version" = "$CP_FA_VERSION" ]; then
+        echo "Running CP tests with FA $fa_version (CP version for sm$sm_arch)"
+        if [ "$NUM_GPUS" -ge 5 ]; then
+            CP_NUM_GPUS=$(( NUM_GPUS - 1 > 4 ? 4 : NUM_GPUS - 1 ))
+            CP_GPUS=$(seq -s, 1 $CP_NUM_GPUS)
+            echo "Running tests in parallel: test_attention.py on GPU 0, test_attention_with_cp.py on GPUs $CP_GPUS ($CP_NUM_GPUS GPUs)"
+
+            CUDA_VISIBLE_DEVICES=0 NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \
+                --junitxml=$XML_ATTN \
+                $TE_PATH/tests/pytorch/attention/test_attention.py &
+            PID_ATTN=$!
 
+            CUDA_VISIBLE_DEVICES=$CP_GPUS NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \
+                --junitxml=$XML_CP \
+                $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py &
+            PID_CP=$!
+
+            wait $PID_ATTN || test_fail "test_attention.py (FA $fa_version)"
+            wait $PID_CP || test_fail "test_attention_with_cp.py (FA $fa_version)"
+        else
+            echo "Running tests sequentially: need >=5 GPUs for parallel execution (1 for test_attention + 4 for test_attention_with_cp)"
+            NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_ATTN $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py (FA $fa_version)"
+            NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_CP $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py (FA $fa_version)"
+        fi
+    else
+        echo "Skipping CP tests for FA $fa_version (CP only runs with FA $CP_FA_VERSION on sm$sm_arch)"
+        NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_ATTN $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py (FA $fa_version)"
+    fi
 done
+
+if [ "$RET" -ne 0 ]; then
+    echo "Error in the following test cases:$FAILED_CASES"
+    exit 1
+fi
+echo "All tests passed"
+exit 0

tests/pytorch/attention/run_attention_with_cp.py

Lines changed: 56 additions & 40 deletions
@@ -47,6 +47,7 @@ def generate_input_shapes(
     config: ModelConfig,
     world_size: int,
     kernel_backend: str,
+    fa_pad_between_seqs: str = "False",
 ):
     if qkv_format == "bshd":
         q_input_shape = (
115116
).cuda()
116117
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
117118

118-
# Since FlashAttention doesn't support pad b/w sequences, and FusedAttention does,
119-
# cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only.
120-
if kernel_backend == "FusedAttention":
119+
# Generate padded data (cu_seqlens_q reflects non-padded lengths, so it
120+
# differs from cu_seqlens_q_padded) for FusedAttention always, and for
121+
# FlashAttention only when its test param requests it. DPA auto-detects
122+
# pad_between_seqs downstream from the cu_seqlens_q vs cu_seqlens_q_padded
123+
# mismatch.
124+
if kernel_backend == "FusedAttention" or fa_pad_between_seqs == "True":
121125
cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
122126

123127
# NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded`
@@ -196,6 +200,7 @@ def run_dpa_with_cp(
     scaling_mode="delayed",
     f16_O="False",
     is_training="True",
+    fa_pad_between_seqs="False",
     deterministic="False",
     log_level=logging.WARNING,
 ):
@@ -314,7 +319,7 @@ def run_dpa_with_cp(
         cu_seqlens_kv,
         cu_seqlens_q_padded,
         cu_seqlens_kv_padded,
-    ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
+    ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend, fa_pad_between_seqs)
     q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
     k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
     v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
@@ -557,11 +562,11 @@ def run_dpa_with_cp(
                 tensors_to_deq[i] = tensor.dequantize()
         if not fp8_bwd:
             tensors[0], tensors[5] = tensors_to_deq
-    for i, tensor in enumerate(tensors):
+    for tensor, name in zip(tensors, names):
         # dbias/dbias_ could be None, so skip check for it
         if tensor is not None:
-            assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN"
-            assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf"
+            assert torch.all(~torch.isnan(tensor)), f"{name} has nan values"
+            assert torch.all(~torch.isinf(tensor)), f"{name} has inf values"
     out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors
 
     ############ compare results between CP and no-CP ############
@@ -617,49 +622,60 @@ def run_dpa_with_cp(
         if is_training:
             dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
             dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
-            dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_]
             cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
             cu_seqlens_q = get_cu_seqlens_on_cp_rank(
                 cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
             )
-            cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q
-            num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1]
-            for x in [dq, out, dq_, out_]:
-                assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0
-                for b in range(config.batch_size):
-                    assert (
-                        num_pads_q[b] == 0
-                        or torch.count_nonzero(
-                            x[
-                                (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[
-                                    b + 1
-                                ]
-                            ]
-                        ).item()
-                        == 0
-                    )
+            num_pads_q = (cu_seqlens_q_padded - cu_seqlens_q)[1:] - (
+                cu_seqlens_q_padded - cu_seqlens_q
+            )[:-1]
             cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size
             cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
                 cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
             )
-            cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv
-            num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1]
-            for x in [dk, dv, dk_, dv_]:
-                assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0
-                for b in range(config.batch_size):
-                    assert (
-                        num_pads_kv[b] == 0
-                        or torch.count_nonzero(
-                            x[
-                                (
-                                    cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]
-                                ) : cu_seqlens_kv_padded[b + 1]
-                            ]
-                        ).item()
-                        == 0
+            num_pads_kv = (cu_seqlens_kv_padded - cu_seqlens_kv)[1:] - (
+                cu_seqlens_kv_padded - cu_seqlens_kv
+            )[:-1]
+            # FA3 leaves garbage at padding positions despite seqused_q/k (tile spillover).
+            # Forward out_ can't be pre-zeroed because FA3's custom op returns out_ as an
+            # output rather than mutating it in-place, triggering PyTorch's aliasing constraint.
+            # Backward dq/dk/dv CAN be pre-zeroed because FA3 marks them as mutated inputs.
+            if fa_pad_between_seqs == "True":
+                # out_ is a view inside the CP custom autograd Function, so in-place
+                # zeroing is blocked by PyTorch. Clone to break the view relationship.
+                out_ = out_.clone()
+                for x in [out, out_, dq]:
+                    for b in range(config.batch_size):
+                        x[
+                            cu_seqlens_q_padded[b + 1] - num_pads_q[b] : cu_seqlens_q_padded[b + 1]
+                        ] = 0.0
+                    x[cu_seqlens_q_padded[-1] :] = 0.0
+                for x in [dk, dv]:
+                    for b in range(config.batch_size):
+                        x[
+                            cu_seqlens_kv_padded[b + 1]
+                            - num_pads_kv[b] : cu_seqlens_kv_padded[b + 1]
+                        ] = 0.0
+                    x[cu_seqlens_kv_padded[-1] :] = 0.0
+                # Verify CP backward tensors have clean padding (pre-zeroed in context_parallel.py).
+                for xname, x, cu, np_ in [
+                    ("dq_", dq_, cu_seqlens_q_padded, num_pads_q),
+                    ("dk_", dk_, cu_seqlens_kv_padded, num_pads_kv),
+                    ("dv_", dv_, cu_seqlens_kv_padded, num_pads_kv),
+                ]:
+                    nnz = torch.count_nonzero(x[cu[-1] :]).item()
+                    assert nnz == 0, (
+                        f"{xname} has {nnz} nonzero values in tail padding — "
+                        "context_parallel.py should zero padding positions"
                     )
+                    for b in range(config.batch_size):
+                        if np_[b] > 0:
+                            nnz = torch.count_nonzero(x[cu[b + 1] - np_[b] : cu[b + 1]]).item()
+                            assert nnz == 0, (
+                                f"{xname} has {nnz} nonzero values in batch {b} padding — "
+                                "context_parallel.py should zero padding positions"
+                            )
         else:
-            # Forward-only: reshape only out/out_ for comparison
             out = out.index_select(0, seq_idx_q).contiguous()
             out_ = out_
 

tests/pytorch/attention/test_attention.py

Lines changed: 14 additions & 20 deletions
@@ -124,7 +124,7 @@ def reset_global_fp8_state():
 @pytest.mark.parametrize("workspace_opt", [True, False])
 @pytest.mark.parametrize("qkv_layout", [None])
 @pytest.mark.parametrize("swa", [False])
-@pytest.mark.parametrize("pad_between_seqs", [False])
+@pytest.mark.parametrize("pad_between_seqs", [False, True])
 def test_dot_product_attention(
     dtype,
     model_configs,
@@ -157,6 +157,8 @@ def test_dot_product_attention(
 
     config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
     qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
+    if pad_between_seqs and qkv_format != "thd":
+        pytest.skip("pad_between_seqs only applies to THD format!")
     if qkv_format == "thd" and "padding" not in config.attn_mask_type:
         config.attn_mask_type = (
             "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
@@ -195,19 +197,6 @@ def test_dot_product_attention(
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
 
-    # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
-    # mannually pads and unpads the input and output of FlashAttention for testing purposes
-    if (
-        pad_between_seqs
-        and FlashAttentionUtils.is_installed
-        and not (
-            config.max_seqlen_q != config.max_seqlen_kv
-            and config.attn_mask_type in ["causal", "padding_causal"]
-        )
-        and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
-    ):
-        flash_attn_supported = True
-
     # Skip if only unfused backend is supported
     if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
         pytest.skip("Less than two backends to compare.")
@@ -1301,12 +1290,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
         block.softmax_offset.requires_grad = True
 
     # Run a forward and backward pass
-    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
+    if backend in ["UnfusedDotProductAttention"]:
         q = inp_orig[0]
         k = inp_orig[1]
         v = inp_orig[2]
         d_out = out_grad_orig
-    if backend == "FusedAttention":
+    if backend in ["FusedAttention", "FlashAttention"]:
         q = inp[0]
         k = inp[1]
         v = inp[2]
@@ -1322,14 +1311,19 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
         max_seqlen_kv=config.max_seqlen_kv,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_kv=cu_seqlens_kv,
-        cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
-        cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
+        cu_seqlens_q_padded=(
+            cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
+        ),
+        cu_seqlens_kv_padded=(
+            cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
+        ),
         attn_mask_type=config.attn_mask_type,
         checkpoint_core_attention=ckpt_attn,
         core_attention_bias_type=config.attn_bias_type,
         core_attention_bias=bias,
         alibi_slopes=alibi_slopes,
         fast_zero_fill=True,
+        pad_between_seqs=pad_between_seqs,
         # Only pass num_splits when exercising the FlashAttention path
         num_splits=config.num_splits if backend == "FlashAttention" else 1,
     )
@@ -1343,12 +1337,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     if is_training and config.softmax_type != "vanilla":
         d_softmax_offset = block.softmax_offset.grad
 
-    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
+    if backend in ["UnfusedDotProductAttention"]:
         if is_training:
             return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
         else:
             return out, max_logit, (None, None, None, d_softmax_offset)
-    if backend == "FusedAttention":
+    if backend in ["FusedAttention", "FlashAttention"]:
         if qkv_format == "thd" and pad_between_seqs:
             out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
             if is_training:
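The new skip in test_dot_product_attention depends on reducing a qkv_layout string to its base format. The sketch below isolates that one-liner from the diff and the skip condition built on it; the wrapper function names are hypothetical and the layout strings are only examples.

```python
def qkv_format_from_layout(qkv_layout: str) -> str:
    """Strip the packing digits and keep the first tensor's format,
    using the same expression as the test."""
    return qkv_layout.replace("3", "").replace("2", "").split("_")[0]


def skips_pad_between_seqs(pad_between_seqs: bool, qkv_layout: str) -> bool:
    """pad_between_seqs only applies to THD, so other formats are skipped."""
    return pad_between_seqs and qkv_format_from_layout(qkv_layout) != "thd"


for layout in ["sb3hd", "bshd_bs2hd", "thd_t2hd", "thd_thd_thd"]:
    print(layout, qkv_format_from_layout(layout), skips_pad_between_seqs(True, layout))
```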
