Skip to content

Commit 6c94b36

Browse files
fix flash_attn_supported override for cross-attention causal mask
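Factor out the cross_attn_causal check so that it guards both the FA3 (pad_between_seqs) and FA2 branches of the flash_attn_supported override, instead of only the FA2 branch. This avoids no-backend errors when FA3 is installed but Flash Attention does not support a non-bottom-right causal mask with different Q/KV sequence lengths.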
1 parent fc26d1c commit 6c94b36

1 file changed

Lines changed: 11 additions & 12 deletions

File tree

tests/pytorch/attention/test_attention.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -197,21 +197,20 @@ def test_dot_product_attention(
197197
)
198198
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
199199

200-
# FA3 natively supports pad_between_seqs via seqused_q/seqused_k.
201-
# FA2 does not support pad_between_seqs
202-
# Flash Attention is not supported on SM > 90
203-
if (
204-
pad_between_seqs
205-
and FlashAttentionUtils.v3_is_installed
206-
and get_device_compute_capability() == (9, 0)
200+
# Flash Attention requires bottom-right-diagonal causal mask for cross-attention
201+
cross_attn_causal = (
202+
config.max_seqlen_q != config.max_seqlen_kv
203+
and config.attn_mask_type in ["causal", "padding_causal"]
204+
)
205+
sm = get_device_compute_capability()
206+
# FA3 natively supports pad_between_seqs via seqused_q/seqused_k (SM90 only).
207+
# FA2 does not support pad_between_seqs and is not available on SM >= 100.
208+
if not cross_attn_causal and (
209+
pad_between_seqs and FlashAttentionUtils.v3_is_installed and sm == (9, 0)
207210
or not pad_between_seqs
208211
and FlashAttentionUtils.is_installed
209-
and not (
210-
config.max_seqlen_q != config.max_seqlen_kv
211-
and config.attn_mask_type in ["causal", "padding_causal"]
212-
)
213212
and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
214-
and get_device_compute_capability() < (10, 0)
213+
and sm < (10, 0)
215214
):
216215
flash_attn_supported = True
217216

0 commit comments

Comments
 (0)