Skip to content

Commit b0a3c64

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6650d86 commit b0a3c64

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

tests/pytorch/attention/test_attention.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,17 @@ def test_dot_product_attention(
198198
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
199199

200200
# Flash Attention requires bottom-right-diagonal causal mask for cross-attention
201-
cross_attn_causal = (
202-
config.max_seqlen_q != config.max_seqlen_kv
203-
and config.attn_mask_type in ["causal", "padding_causal"]
204-
)
201+
cross_attn_causal = config.max_seqlen_q != config.max_seqlen_kv and config.attn_mask_type in [
202+
"causal",
203+
"padding_causal",
204+
]
205205
sm = get_device_compute_capability()
206206
# FA3 natively supports pad_between_seqs via seqused_q/seqused_k (SM90 only).
207207
# FA2 does not support pad_between_seqs and is not available on SM >= 100.
208208
if not cross_attn_causal and (
209-
pad_between_seqs and FlashAttentionUtils.v3_is_installed and sm == (9, 0)
209+
pad_between_seqs
210+
and FlashAttentionUtils.v3_is_installed
211+
and sm == (9, 0)
210212
or not pad_between_seqs
211213
and FlashAttentionUtils.is_installed
212214
and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)

0 commit comments

Comments
 (0)