Skip to content

Commit 439caa9

Browse files
fix test skips for FA3 pad_between_seqs and deterministic CP tests
- test_attention.py: Guard flash_attn_supported override for thd+pad_between_seqs to require FA3 installed + SM90. FA2 path retained for non-pad_between_seqs. - test_attention_with_cp.py: Skip fused attention CP tests in deterministic mode for post_scale_bias (requires_grad) and non-vanilla softmax configs, which have no deterministic cuDNN backend available.
1 parent fd24692 commit 439caa9

2 files changed

Lines changed: 18 additions & 3 deletions

File tree

tests/pytorch/attention/test_attention.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,13 @@ def test_dot_product_attention(
198198
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
199199

200200
# FA3 natively supports pad_between_seqs via seqused_q/seqused_k.
201-
# FA2 does not support pad_between_seqs, but _run_dot_product_attention
202-
# manually pads and unpads the input and output of FlashAttention for testing purposes.
203-
# Flash Attention is not supported on SM100+
201+
# FA2 does not support pad_between_seqs
202+
# Flash Attention is not supported on SM > 90
204203
if (
205204
pad_between_seqs
205+
and FlashAttentionUtils.v3_is_installed
206+
and get_device_compute_capability() == (9, 0)
207+
or not pad_between_seqs
206208
and FlashAttentionUtils.is_installed
207209
and not (
208210
config.max_seqlen_q != config.max_seqlen_kv

tests/pytorch/attention/test_attention_with_cp.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,22 @@ def test_cp_with_fused_attention(
374374
is_training=is_training,
375375
)
376376
_, fused_attn_supported, _ = available_backends
377+
378+
# Skip any tests if not supported by the configs
377379
if not fused_attn_supported:
378380
pytest.skip("No attention backend available.")
379381

382+
deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
383+
if deterministic:
384+
if config.softmax_type != "vanilla":
385+
pytest.skip(
386+
"Deterministic mode does not support non-vanilla softmax with FusedAttention"
387+
)
388+
if config.attn_bias_type == "post_scale_bias" and is_training:
389+
pytest.skip(
390+
"Deterministic mode does not support post_scale_bias with requires_grad"
391+
)
392+
380393
run_distributed(
381394
get_bash_arguments(
382395
num_gpus_per_node=num_gpus,

0 commit comments

Comments
 (0)