Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
10e4cfc
[PyTorch] Add pad_between_seqs support for FlashAttention 3 with CP
sudhakarsingh27 Apr 24, 2026
2a49dee
[PyTorch] Add pad_between_seqs tests for CP and non-CP FlashAttention
sudhakarsingh27 Apr 24, 2026
34e3d62
[QA] Add CP deterministic tests to L3 and support TE_PATH in FA test
sudhakarsingh27 Apr 24, 2026
4745f98
[PyTorch] Fix FA3 deterministic gate to match upstream backward const…
sudhakarsingh27 Apr 24, 2026
4be004f
[PyTorch] Disable FlashAttention 4 for pad_between_seqs with THD
sudhakarsingh27 Apr 24, 2026
c476f15
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 24, 2026
a2b0f1b
[QA] Fix cutlass-dsl utils shadow in FA versions test
sudhakarsingh27 Apr 25, 2026
b94e175
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 28, 2026
fc9182f
skip tests which OOM in deterministic+backward+hopper+large_configs a…
sudhakarsingh27 Apr 29, 2026
636666f
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 29, 2026
7928bc9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
1585ebb
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 Apr 29, 2026
2464f43
make cp det and nondet tests run in parallel whenever possible
sudhakarsingh27 Apr 30, 2026
789ccf0
Merge branch 'main' into flash_attn_pad_bw_seqs
sudhakarsingh27 May 1, 2026
0a32185
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 May 4, 2026
c33cf2d
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 May 4, 2026
13ba004
[QA] L3: gate CP tests per-arch to avoid CI timeout
sudhakarsingh27 May 5, 2026
e41bb96
[QA] L3: skip pre-installed FA3 build, per-FA junit XMLs
sudhakarsingh27 May 5, 2026
7b8ca1e
b200 shouldnt run FA3 even if present
sudhakarsingh27 May 6, 2026
e02b658
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 May 6, 2026
9389309
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 May 9, 2026
77941e0
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 May 15, 2026
d8e8ba4
Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine in…
sudhakarsingh27 May 22, 2026
8794aa8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2026
908ca2b
Merge branch 'main' into flash_attn_pad_bw_seqs
sudhakarsingh27 May 22, 2026
c4b6e07
L3: drop stale RUN_L3_TESTS=1 note; use flash_attn_3 for FA3 check
sudhakarsingh27 May 22, 2026
d3bd4e4
Address review nits: bHSS-gated OOM skip; drop Dockerfile.base specifics
sudhakarsingh27 May 22, 2026
0638d58
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2026
3b1e4ce
Merge branch 'main' into flash_attn_pad_bw_seqs
sudhakarsingh27 May 22, 2026
1563b10
Name the OOM-skip threshold and explain the 128*bHSS workspace observ…
sudhakarsingh27 May 23, 2026
a27e301
Reword OOM-skip comment as observations, not cuDNN-internal claims
sudhakarsingh27 May 23, 2026
2b05809
Merge branch 'main' into flash_attn_pad_bw_seqs
sudhakarsingh27 May 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qa/L1_pytorch_distributed_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py"
Expand Down
50 changes: 47 additions & 3 deletions qa/L3_pytorch_FA_versions_test/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,25 @@
#
# See LICENSE for license information.

set -e
function error_exit() {
echo "Error: $1"
exit 1
}

function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}

RET=0
FAILED_CASES=""

: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

pip3 install pytest==8.2.1
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

# Limit parallel build jobs to avoid overwhelming system resources
export MAX_JOBS=32
Expand Down Expand Up @@ -40,7 +52,39 @@ do
cd ../../
fi

# Ensure local test utils is found before nvidia-cutlass-dsl's utils package
export PYTHONPATH=$TE_PATH/tests/pytorch:${PYTHONPATH:-}

# Run tests
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py
NUM_GPUS=$(nvidia-smi -L | wc -l)
echo "Detected $NUM_GPUS GPU(s)"
if [ "$NUM_GPUS" -ge 5 ]; then
CP_NUM_GPUS=$(( NUM_GPUS - 1 > 4 ? 4 : NUM_GPUS - 1 ))
CP_GPUS=$(seq -s, 1 $CP_NUM_GPUS)
echo "Running tests in parallel: test_attention.py on GPU 0, test_attention_with_cp.py on GPUs $CP_GPUS ($CP_NUM_GPUS GPUs)"

CUDA_VISIBLE_DEVICES=0 NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \
--junitxml=$XML_LOG_DIR/pytest.xml \
$TE_PATH/tests/pytorch/attention/test_attention.py &
PID_ATTN=$!

CUDA_VISIBLE_DEVICES=$CP_GPUS NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \
--junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml \
$TE_PATH/tests/pytorch/attention/test_attention_with_cp.py &
Comment thread
sudhakarsingh27 marked this conversation as resolved.
Outdated
PID_CP=$!

wait $PID_ATTN || test_fail "test_attention.py"
wait $PID_CP || test_fail "test_attention_with_cp.py"
else
echo "Running tests sequentially: need >=5 GPUs for parallel execution (1 for test_attention + 4 for test_attention_with_cp)"
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
fi
done

if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
96 changes: 56 additions & 40 deletions tests/pytorch/attention/run_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def generate_input_shapes(
config: ModelConfig,
world_size: int,
kernel_backend: str,
fa_pad_between_seqs: str = "False",
):
if qkv_format == "bshd":
q_input_shape = (
Expand Down Expand Up @@ -105,9 +106,12 @@ def generate_input_shapes(
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)

# Since FlashAttention doesn't support pad b/w sequences, and FusedAttention does,
# cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only.
if kernel_backend == "FusedAttention":
# Generate padded data (cu_seqlens_q reflects non-padded lengths, so it
# differs from cu_seqlens_q_padded) for FusedAttention always, and for
# FlashAttention only when its test param requests it. DPA auto-detects
# pad_between_seqs downstream from the cu_seqlens_q vs cu_seqlens_q_padded
# mismatch.
if kernel_backend == "FusedAttention" or fa_pad_between_seqs == "True":
cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()

# NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded`
Expand Down Expand Up @@ -186,6 +190,7 @@ def run_dpa_with_cp(
scaling_mode="delayed",
f16_O="False",
is_training="True",
fa_pad_between_seqs="False",
deterministic="False",
log_level=logging.WARNING,
):
Expand Down Expand Up @@ -288,7 +293,7 @@ def run_dpa_with_cp(
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend, fa_pad_between_seqs)
q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
Expand Down Expand Up @@ -531,11 +536,11 @@ def run_dpa_with_cp(
tensors_to_deq[i] = tensor.dequantize()
if not fp8_bwd:
tensors[0], tensors[5] = tensors_to_deq
for i, tensor in enumerate(tensors):
for tensor, name in zip(tensors, names):
# dbias/dbias_ could be None, so skip check for it
if tensor is not None:
assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN"
assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf"
assert torch.all(~torch.isnan(tensor)), f"{name} has nan values"
assert torch.all(~torch.isinf(tensor)), f"{name} has inf values"
out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors

############ compare results between CP and no-CP ############
Expand Down Expand Up @@ -588,49 +593,60 @@ def run_dpa_with_cp(
if is_training:
dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_]
cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
cu_seqlens_q = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
)
cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q
num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1]
for x in [dq, out, dq_, out_]:
assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0
for b in range(config.batch_size):
assert (
num_pads_q[b] == 0
or torch.count_nonzero(
x[
(cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[
b + 1
]
]
).item()
== 0
)
num_pads_q = (cu_seqlens_q_padded - cu_seqlens_q)[1:] - (
cu_seqlens_q_padded - cu_seqlens_q
)[:-1]
cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size
cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
)
cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv
num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1]
for x in [dk, dv, dk_, dv_]:
assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0
for b in range(config.batch_size):
assert (
num_pads_kv[b] == 0
or torch.count_nonzero(
x[
(
cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]
) : cu_seqlens_kv_padded[b + 1]
]
).item()
== 0
num_pads_kv = (cu_seqlens_kv_padded - cu_seqlens_kv)[1:] - (
cu_seqlens_kv_padded - cu_seqlens_kv
)[:-1]
# FA3 leaves garbage at padding positions despite seqused_q/k (tile spillover).
# Forward out_ can't be pre-zeroed because FA3's custom op returns out_ as an
# output rather than mutating it in-place, triggering PyTorch's aliasing constraint.
# Backward dq/dk/dv CAN be pre-zeroed because FA3 marks them as mutated inputs.
if fa_pad_between_seqs == "True":
# out_ is a view inside the CP custom autograd Function, so in-place
# zeroing is blocked by PyTorch. Clone to break the view relationship.
out_ = out_.clone()
for x in [out, out_, dq]:
for b in range(config.batch_size):
x[
cu_seqlens_q_padded[b + 1] - num_pads_q[b] : cu_seqlens_q_padded[b + 1]
] = 0.0
x[cu_seqlens_q_padded[-1] :] = 0.0
for x in [dk, dv]:
for b in range(config.batch_size):
x[
cu_seqlens_kv_padded[b + 1]
- num_pads_kv[b] : cu_seqlens_kv_padded[b + 1]
] = 0.0
x[cu_seqlens_kv_padded[-1] :] = 0.0
# Verify CP backward tensors have clean padding (pre-zeroed in context_parallel.py).
for xname, x, cu, np_ in [
("dq_", dq_, cu_seqlens_q_padded, num_pads_q),
("dk_", dk_, cu_seqlens_kv_padded, num_pads_kv),
("dv_", dv_, cu_seqlens_kv_padded, num_pads_kv),
]:
nnz = torch.count_nonzero(x[cu[-1] :]).item()
assert nnz == 0, (
f"{xname} has {nnz} nonzero values in tail padding — "
"context_parallel.py should zero padding positions"
)
for b in range(config.batch_size):
if np_[b] > 0:
nnz = torch.count_nonzero(x[cu[b + 1] - np_[b] : cu[b + 1]]).item()
assert nnz == 0, (
f"{xname} has {nnz} nonzero values in batch {b} padding — "
"context_parallel.py should zero padding positions"
)
else:
# Forward-only: reshape only out/out_ for comparison
out = out.index_select(0, seq_idx_q).contiguous()
out_ = out_

Expand Down
34 changes: 14 additions & 20 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def reset_global_fp8_state():
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
@pytest.mark.parametrize("pad_between_seqs", [False, True])
def test_dot_product_attention(
dtype,
model_configs,
Expand Down Expand Up @@ -157,6 +157,8 @@ def test_dot_product_attention(

config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
if pad_between_seqs and qkv_format != "thd":
pytest.skip("pad_between_seqs only applies to THD format!")
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
config.attn_mask_type = (
"padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
Expand Down Expand Up @@ -195,19 +197,6 @@ def test_dot_product_attention(
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# mannually pads and unpads the input and output of FlashAttention for testing purposes
if (
pad_between_seqs
and FlashAttentionUtils.is_installed
and not (
config.max_seqlen_q != config.max_seqlen_kv
and config.attn_mask_type in ["causal", "padding_causal"]
)
and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
):
flash_attn_supported = True

# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
Expand Down Expand Up @@ -1330,12 +1319,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
block.softmax_offset.requires_grad = True

# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if backend in ["UnfusedDotProductAttention"]:
q = inp_orig[0]
k = inp_orig[1]
v = inp_orig[2]
d_out = out_grad_orig
if backend == "FusedAttention":
if backend in ["FusedAttention", "FlashAttention"]:
q = inp[0]
k = inp[1]
v = inp[2]
Expand All @@ -1351,14 +1340,19 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
cu_seqlens_q_padded=(
cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
),
cu_seqlens_kv_padded=(
cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
),
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
pad_between_seqs=pad_between_seqs,
# Only pass num_splits when exercising the FlashAttention path
num_splits=config.num_splits if backend == "FlashAttention" else 1,
)
Expand All @@ -1372,12 +1366,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad

if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if backend in ["UnfusedDotProductAttention"]:
if is_training:
return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, max_logit, (None, None, None, d_softmax_offset)
if backend == "FusedAttention":
if backend in ["FusedAttention", "FlashAttention"]:
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if is_training:
Expand Down
28 changes: 27 additions & 1 deletion tests/pytorch/attention/test_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,20 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", qkv_formats)
@pytest.mark.parametrize("cp_comm_type", cp_comm_types)
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("pad_between_seqs", [False, True])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type, pad_between_seqs):
num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
if num_gpus > torch.cuda.device_count():
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")

if pad_between_seqs:
if qkv_format != "thd":
pytest.skip("pad_between_seqs only applies to THD format!")
if not FlashAttentionUtils.v3_is_installed:
pytest.skip("pad_between_seqs with CP requires Flash Attention v3!")
if cp_comm_type == "a2a+p2p":
pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about AG?


config = model_configs_flash_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
Expand Down Expand Up @@ -148,6 +157,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
qkv_format=qkv_format,
kernel_backend="FlashAttention",
cp_comm_type=cp_comm_type,
fa_pad_between_seqs=pad_between_seqs,
log_level=pytest_logging_level,
),
)
Expand Down Expand Up @@ -386,6 +396,7 @@ def test_cp_with_fused_attention(
is_training=is_training,
deterministic=_deterministic,
)

_, fused_attn_supported, _ = available_backends
if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]:
config_copy = copy.deepcopy(config)
Expand All @@ -404,6 +415,21 @@ def test_cp_with_fused_attention(
if not fused_attn_supported:
pytest.skip("No attention backend available.")

deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
if deterministic:
if config.softmax_type != "vanilla":
pytest.skip(
"Deterministic mode does not support non-vanilla softmax with FusedAttention"
)
if config.attn_bias_type == "post_scale_bias" and is_training:
pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad")
if qkv_format == "thd" and config.num_heads >= 20 and get_device_compute_capability() == (9, 0):
pytest.skip(
"Deterministic FusedAttention backward with THD format OOMs on sm90"
" for this particular test config since cuDNN reserves memory"
" proportional to bHSS (known cuDNN issue)."
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The motivation for this makes sense to me but seems like the way we are skipping the test is viewing it from a slightly narrow lens. What I mean by that is the main issue is total memory (bhSS) but we seem to be guarding on head dims only

This skip guard would not be correct if tomorrow someone were to add a test with small b,S and H>20 (IIUC) - it almost makes it seem that the issue is the num_heads rather than the total memory

Is there a better way to do this ?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — gated on the actual b*H*S*S product instead of num_heads in d3bd4e4. Threshold of 1e9 empirically matches the existing 5-case skip set on the test_essential fused subset (cp_2_0, cp_2_2, cp_3_1, cp_4_2, cp_4_3 — bHSS 1.07B–4.29B) and lets the smaller configs (cp_1_0/cp_2_1/cp_2_4/cp_3_2/cp_3_4, all ~0.40B) keep running. Local det+nondet still 33/0 + 45/0 with 5 OOM skips fired by the new gate.


run_distributed(
get_bash_arguments(
num_gpus_per_node=num_gpus,
Expand Down
Loading
Loading