Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
10e4cfc
[PyTorch] Add pad_between_seqs support for FlashAttention 3 with CP
sudhakarsingh27 Apr 24, 2026
2a49dee
[PyTorch] Add pad_between_seqs tests for CP and non-CP FlashAttention
sudhakarsingh27 Apr 24, 2026
34e3d62
[QA] Add CP deterministic tests to L3 and support TE_PATH in FA test
sudhakarsingh27 Apr 24, 2026
4745f98
[PyTorch] Fix FA3 deterministic gate to match upstream backward const…
sudhakarsingh27 Apr 24, 2026
4be004f
[PyTorch] Disable FlashAttention 4 for pad_between_seqs with THD
sudhakarsingh27 Apr 24, 2026
c476f15
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 24, 2026
a2b0f1b
[QA] Fix cutlass-dsl utils shadow in FA versions test
sudhakarsingh27 Apr 25, 2026
b94e175
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 28, 2026
fc9182f
skip tests which OOM in deterministic+backward+hopper+large_configs a…
sudhakarsingh27 Apr 29, 2026
636666f
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 29, 2026
7928bc9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
1585ebb
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 Apr 29, 2026
2464f43
make cp det and nondet tests run in parallel whenever possible
sudhakarsingh27 Apr 30, 2026
789ccf0
Merge branch 'main' into flash_attn_pad_bw_seqs
sudhakarsingh27 May 1, 2026
0a32185
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 May 4, 2026
c33cf2d
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 May 4, 2026
13ba004
[QA] L3: gate CP tests per-arch to avoid CI timeout
sudhakarsingh27 May 5, 2026
e41bb96
[QA] L3: skip pre-installed FA3 build, per-FA junit XMLs
sudhakarsingh27 May 5, 2026
7b8ca1e
b200 shouldnt run FA3 even if present
sudhakarsingh27 May 6, 2026
e02b658
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 May 6, 2026
9389309
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 May 9, 2026
77941e0
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 May 15, 2026
d8e8ba4
Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine in…
sudhakarsingh27 May 22, 2026
8794aa8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2026
908ca2b
Merge branch 'main' into flash_attn_pad_bw_seqs
sudhakarsingh27 May 22, 2026
c4b6e07
L3: drop stale RUN_L3_TESTS=1 note; use flash_attn_3 for FA3 check
sudhakarsingh27 May 22, 2026
d3bd4e4
Address review nits: bHSS-gated OOM skip; drop Dockerfile.base specifics
sudhakarsingh27 May 22, 2026
0638d58
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2026
3b1e4ce
Merge branch 'main' into flash_attn_pad_bw_seqs
sudhakarsingh27 May 22, 2026
1563b10
Name the OOM-skip threshold and explain the 128*bHSS workspace observ…
sudhakarsingh27 May 23, 2026
a27e301
Reword OOM-skip comment as observations, not cuDNN-internal claims
sudhakarsingh27 May 23, 2026
2b05809
Merge branch 'main' into flash_attn_pad_bw_seqs
sudhakarsingh27 May 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion qa/L1_pytorch_distributed_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,31 @@ mkdir -p "$XML_LOG_DIR"

pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

# Run CP tests (deterministic + non-deterministic) first so they can be parallelized.
# Each needs 4 GPUs, so >=8 GPUs allows them to run concurrently on disjoint GPU sets.
NUM_GPUS=$(python3 -c "import torch; print(torch.cuda.device_count())")
echo "Detected $NUM_GPUS GPU(s)"
if [ "$NUM_GPUS" -ge 8 ]; then
echo "Running CP tests in parallel: non-deterministic on GPUs 0-3, deterministic on GPUs 4-7"
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py &
PID_CP_NONDET=$!
CUDA_VISIBLE_DEVICES=4,5,6,7 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py &
PID_CP_DET=$!
Comment thread
sudhakarsingh27 marked this conversation as resolved.
wait $PID_CP_NONDET || test_fail "test_attention_with_cp.py"
wait $PID_CP_DET || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py"
else
echo "Running CP tests sequentially: need >=8 GPUs for parallel execution"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py"
fi

python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py"
Expand Down
85 changes: 77 additions & 8 deletions qa/L3_pytorch_FA_versions_test/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,44 @@
#
# See LICENSE for license information.

set -e
function error_exit() {
echo "Error: $1"
exit 1
}

function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}

RET=0
FAILED_CASES=""

: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

pip3 install pytest==8.2.1
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

# Limit parallel build jobs to avoid overwhelming system resources
export MAX_JOBS=32

# Iterate over Flash Attention versions
sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"`
export FLASH_ATTN_CUDA_ARCHS=$sm_arch
# CP tests are expensive and run only once per arch:
# - sm90 (H100): FA3 (3.0.0b1) - context_parallel.py only supports FA3 on Hopper
# - sm>90 (B200): latest FA4 - FA3 is not built/installed for sm>90
# Non-CP tests still run for every FA version in the array.
if [ $sm_arch -gt 90 ]
then
FA_versions=(2.8.3 4.0.0b8)
FA_versions=(2.8.3 4.0.0b11)
CP_FA_VERSION="${FA_versions[-1]}"
elif [ $sm_arch -eq 90 ]
then
FA_versions=(2.7.3 2.8.3 3.0.0b1 4.0.0b8)
FA_versions=(2.8.3 3.0.0b1 4.0.0b11)
CP_FA_VERSION="3.0.0b1"
fi

for fa_version in "${FA_versions[@]}"
Expand All @@ -35,12 +53,63 @@ do
then
pip3 install flash-attn-4==${fa_version} nvidia-cutlass-dsl[cu13]==4.4.2 --no-build-isolation
else
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper && python setup.py install
cd ../../
# FA3 source build (~20 min). Skip if FA3 is already installed.
if python3 -c "import flash_attn_3" 2>/dev/null; then
echo "FA3 already installed (from base image); skipping source build"
else
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper && python setup.py install
cd ../../
fi
fi

# Ensure local test utils is found before nvidia-cutlass-dsl's utils package
export PYTHONPATH=$TE_PATH/tests/pytorch:${PYTHONPATH:-}

# Run tests
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py
NUM_GPUS=$(nvidia-smi -L | wc -l)
echo "Detected $NUM_GPUS GPU(s)"

# Suffix junit XMLs with the FA version so per-iteration results are preserved
# (otherwise pytest.xml is overwritten on each loop iteration and we lose timing
# data for all but the last FA version).
fa_tag="${fa_version//./_}"
XML_ATTN="$XML_LOG_DIR/pytest_test_attention_fa${fa_tag}.xml"
XML_CP="$XML_LOG_DIR/pytest_test_attention_with_cp_fa${fa_tag}.xml"

if [ "$fa_version" = "$CP_FA_VERSION" ]; then
echo "Running CP tests with FA $fa_version (CP version for sm$sm_arch)"
if [ "$NUM_GPUS" -ge 5 ]; then
CP_NUM_GPUS=$(( NUM_GPUS - 1 > 4 ? 4 : NUM_GPUS - 1 ))
CP_GPUS=$(seq -s, 1 $CP_NUM_GPUS)
echo "Running tests in parallel: test_attention.py on GPU 0, test_attention_with_cp.py on GPUs $CP_GPUS ($CP_NUM_GPUS GPUs)"

CUDA_VISIBLE_DEVICES=0 NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \
--junitxml=$XML_ATTN \
$TE_PATH/tests/pytorch/attention/test_attention.py &
PID_ATTN=$!

CUDA_VISIBLE_DEVICES=$CP_GPUS NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \
--junitxml=$XML_CP \
$TE_PATH/tests/pytorch/attention/test_attention_with_cp.py &
PID_CP=$!

wait $PID_ATTN || test_fail "test_attention.py (FA $fa_version)"
wait $PID_CP || test_fail "test_attention_with_cp.py (FA $fa_version)"
else
echo "Running tests sequentially: need >=5 GPUs for parallel execution (1 for test_attention + 4 for test_attention_with_cp)"
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_ATTN $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py (FA $fa_version)"
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_CP $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py (FA $fa_version)"
fi
else
echo "Skipping CP tests for FA $fa_version (CP only runs with FA $CP_FA_VERSION on sm$sm_arch)"
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_ATTN $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py (FA $fa_version)"
fi
done

if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
96 changes: 56 additions & 40 deletions tests/pytorch/attention/run_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def generate_input_shapes(
config: ModelConfig,
world_size: int,
kernel_backend: str,
fa_pad_between_seqs: str = "False",
):
if qkv_format == "bshd":
q_input_shape = (
Expand Down Expand Up @@ -115,9 +116,12 @@ def generate_input_shapes(
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)

# Since FlashAttention doesn't support pad b/w sequences, and FusedAttention does,
# cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only.
if kernel_backend == "FusedAttention":
# Generate padded data (cu_seqlens_q reflects non-padded lengths, so it
# differs from cu_seqlens_q_padded) for FusedAttention always, and for
# FlashAttention only when its test param requests it. DPA auto-detects
# pad_between_seqs downstream from the cu_seqlens_q vs cu_seqlens_q_padded
# mismatch.
if kernel_backend == "FusedAttention" or fa_pad_between_seqs == "True":
cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()

# NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded`
Expand Down Expand Up @@ -196,6 +200,7 @@ def run_dpa_with_cp(
scaling_mode="delayed",
f16_O="False",
is_training="True",
fa_pad_between_seqs="False",
deterministic="False",
log_level=logging.WARNING,
):
Expand Down Expand Up @@ -314,7 +319,7 @@ def run_dpa_with_cp(
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend, fa_pad_between_seqs)
q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
Expand Down Expand Up @@ -557,11 +562,11 @@ def run_dpa_with_cp(
tensors_to_deq[i] = tensor.dequantize()
if not fp8_bwd:
tensors[0], tensors[5] = tensors_to_deq
for i, tensor in enumerate(tensors):
for tensor, name in zip(tensors, names):
# dbias/dbias_ could be None, so skip check for it
if tensor is not None:
assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN"
assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf"
assert torch.all(~torch.isnan(tensor)), f"{name} has nan values"
assert torch.all(~torch.isinf(tensor)), f"{name} has inf values"
out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors

############ compare results between CP and no-CP ############
Expand Down Expand Up @@ -617,49 +622,60 @@ def run_dpa_with_cp(
if is_training:
dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_]
cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
cu_seqlens_q = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
)
cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q
num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1]
for x in [dq, out, dq_, out_]:
assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0
for b in range(config.batch_size):
assert (
num_pads_q[b] == 0
or torch.count_nonzero(
x[
(cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[
b + 1
]
]
).item()
== 0
)
num_pads_q = (cu_seqlens_q_padded - cu_seqlens_q)[1:] - (
cu_seqlens_q_padded - cu_seqlens_q
)[:-1]
cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size
cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
)
cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv
num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1]
for x in [dk, dv, dk_, dv_]:
assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0
for b in range(config.batch_size):
assert (
num_pads_kv[b] == 0
or torch.count_nonzero(
x[
(
cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]
) : cu_seqlens_kv_padded[b + 1]
]
).item()
== 0
num_pads_kv = (cu_seqlens_kv_padded - cu_seqlens_kv)[1:] - (
cu_seqlens_kv_padded - cu_seqlens_kv
)[:-1]
# FA3 leaves garbage at padding positions despite seqused_q/k (tile spillover).
# Forward out_ can't be pre-zeroed because FA3's custom op returns out_ as an
# output rather than mutating it in-place, triggering PyTorch's aliasing constraint.
# Backward dq/dk/dv CAN be pre-zeroed because FA3 marks them as mutated inputs.
if fa_pad_between_seqs == "True":
# out_ is a view inside the CP custom autograd Function, so in-place
# zeroing is blocked by PyTorch. Clone to break the view relationship.
out_ = out_.clone()
for x in [out, out_, dq]:
for b in range(config.batch_size):
x[
cu_seqlens_q_padded[b + 1] - num_pads_q[b] : cu_seqlens_q_padded[b + 1]
] = 0.0
x[cu_seqlens_q_padded[-1] :] = 0.0
for x in [dk, dv]:
for b in range(config.batch_size):
x[
cu_seqlens_kv_padded[b + 1]
- num_pads_kv[b] : cu_seqlens_kv_padded[b + 1]
] = 0.0
x[cu_seqlens_kv_padded[-1] :] = 0.0
# Verify CP backward tensors have clean padding (pre-zeroed in context_parallel.py).
for xname, x, cu, np_ in [
("dq_", dq_, cu_seqlens_q_padded, num_pads_q),
("dk_", dk_, cu_seqlens_kv_padded, num_pads_kv),
("dv_", dv_, cu_seqlens_kv_padded, num_pads_kv),
]:
nnz = torch.count_nonzero(x[cu[-1] :]).item()
assert nnz == 0, (
f"{xname} has {nnz} nonzero values in tail padding — "
"context_parallel.py should zero padding positions"
)
for b in range(config.batch_size):
if np_[b] > 0:
nnz = torch.count_nonzero(x[cu[b + 1] - np_[b] : cu[b + 1]]).item()
assert nnz == 0, (
f"{xname} has {nnz} nonzero values in batch {b} padding — "
"context_parallel.py should zero padding positions"
)
else:
# Forward-only: reshape only out/out_ for comparison
out = out.index_select(0, seq_idx_q).contiguous()
out_ = out_

Expand Down
34 changes: 14 additions & 20 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def reset_global_fp8_state():
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
@pytest.mark.parametrize("pad_between_seqs", [False, True])
def test_dot_product_attention(
dtype,
model_configs,
Expand Down Expand Up @@ -157,6 +157,8 @@ def test_dot_product_attention(

config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
if pad_between_seqs and qkv_format != "thd":
pytest.skip("pad_between_seqs only applies to THD format!")
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
config.attn_mask_type = (
"padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
Expand Down Expand Up @@ -195,19 +197,6 @@ def test_dot_product_attention(
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# mannually pads and unpads the input and output of FlashAttention for testing purposes
if (
pad_between_seqs
and FlashAttentionUtils.is_installed
and not (
config.max_seqlen_q != config.max_seqlen_kv
and config.attn_mask_type in ["causal", "padding_causal"]
)
and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
):
flash_attn_supported = True

# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
Expand Down Expand Up @@ -1301,12 +1290,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
block.softmax_offset.requires_grad = True

# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if backend in ["UnfusedDotProductAttention"]:
q = inp_orig[0]
k = inp_orig[1]
v = inp_orig[2]
d_out = out_grad_orig
if backend == "FusedAttention":
if backend in ["FusedAttention", "FlashAttention"]:
q = inp[0]
k = inp[1]
v = inp[2]
Expand All @@ -1322,14 +1311,19 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
cu_seqlens_q_padded=(
cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
),
cu_seqlens_kv_padded=(
cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
),
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
pad_between_seqs=pad_between_seqs,
# Only pass num_splits when exercising the FlashAttention path
num_splits=config.num_splits if backend == "FlashAttention" else 1,
)
Expand All @@ -1343,12 +1337,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad

if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if backend in ["UnfusedDotProductAttention"]:
if is_training:
return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, max_logit, (None, None, None, d_softmax_offset)
if backend == "FusedAttention":
if backend in ["FusedAttention", "FlashAttention"]:
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if is_training:
Expand Down
Loading
Loading