Commit 18c802a

fix flash_attn_supported override for large head_dim configs
The override in test_dot_product_attention unconditionally forced flash_attn_supported=True for pad_between_seqs=False configs, including base_5_*/base_6_* (head_dim 512/1024) where both FA2 and FA3 reject head_dim > 256. This caused 960 "no backend available" failures across A100, H100, and L40 in pipeline 48086204.

Restrict the override to pad_between_seqs=True only, which is the case where FA3 supports the feature via seqused_q/seqused_k but the backend checker doesn't know about it yet. For pad_between_seqs=False, trust get_available_attention_backends() as-is.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
1 parent 5fceae2 commit 18c802a

1 file changed

Lines changed: 11 additions & 17 deletions

tests/pytorch/attention/test_attention.py

@@ -197,24 +197,18 @@ def test_dot_product_attention(
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
 
-    # Flash Attention requires bottom-right-diagonal causal mask for cross-attention
-    cross_attn_causal = config.max_seqlen_q != config.max_seqlen_kv and config.attn_mask_type in [
-        "causal",
-        "padding_causal",
-    ]
-    sm = get_device_compute_capability()
     # FA3 natively supports pad_between_seqs via seqused_q/seqused_k (SM90 only).
-    # FA2 does not support pad_between_seqs and is not available on SM >= 100.
-    if not cross_attn_causal and (
-        pad_between_seqs
-        and FlashAttentionUtils.v3_is_installed
-        and sm == (9, 0)
-        or not pad_between_seqs
-        and FlashAttentionUtils.is_installed
-        and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
-        and sm < (10, 0)
-    ):
-        flash_attn_supported = True
+    # Override flash_attn_supported only for pad_between_seqs=True because
+    # get_available_attention_backends doesn't know about FA3's seqused support yet.
+    # For pad_between_seqs=False, trust the backend checker's result as-is.
+    if pad_between_seqs:
+        cross_attn_causal = (
+            config.max_seqlen_q != config.max_seqlen_kv
+            and config.attn_mask_type in ["causal", "padding_causal"]
+        )
+        sm = get_device_compute_capability()
+        if not cross_attn_causal and FlashAttentionUtils.v3_is_installed and sm == (9, 0):
+            flash_attn_supported = True
 
     # Skip if only unfused backend is supported
     if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
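
The behavioral difference between the old and new override can be sketched in isolation. The snippet below is a minimal illustration, not code from the test file. `Env`, `old_override`, and `new_override` are hypothetical names, and the plain booleans and tuple in `Env` stand in for the `FlashAttentionUtils` flags and `get_device_compute_capability()`. `checker_result` represents the `flash_attn_supported` value returned by the backend checker, assumed here to be False for a head_dim 512 config.

```python
from dataclasses import dataclass


@dataclass
class Env:
    """Hypothetical stand-in for installed-library flags and device state."""

    fa_installed: bool = True   # stands in for FlashAttentionUtils.is_installed
    fa_v2_3_plus: bool = True   # stands in for FlashAttentionUtils.v2_3_plus
    fa3_installed: bool = True  # stands in for FlashAttentionUtils.v3_is_installed
    sm: tuple = (9, 0)          # stands in for get_device_compute_capability()


def old_override(checker_result, pad_between_seqs, cross_attn_causal, window_left, env):
    """Pre-commit condition, with the and/or precedence made explicit."""
    if not cross_attn_causal and (
        (pad_between_seqs and env.fa3_installed and env.sm == (9, 0))
        or (
            not pad_between_seqs
            and env.fa_installed
            and (window_left == -1 or env.fa_v2_3_plus)
            and env.sm < (10, 0)
        )
    ):
        return True  # forced on, regardless of what the checker said
    return checker_result


def new_override(checker_result, pad_between_seqs, cross_attn_causal, env):
    """Post-commit condition: only the pad_between_seqs=True case is overridden."""
    if pad_between_seqs and not cross_attn_causal and env.fa3_installed and env.sm == (9, 0):
        return True
    return checker_result  # pad_between_seqs=False: trust the checker as-is


env = Env()
# Checker rejected flash attention (e.g. head_dim > 256), pad_between_seqs=False.
print(old_override(False, False, False, -1, env))  # old code overrides the checker
print(new_override(False, False, False, env))      # new code keeps the checker's answer
```

With `pad_between_seqs=False`, the old condition replaces the checker's rejection with True, which matches the "no backend available" failures described in the commit message, while the new condition returns the checker's result unchanged.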
