File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -198,15 +198,17 @@ def test_dot_product_attention(
198198 flash_attn_supported , fused_attn_supported , unfused_attn_supported = available_backends
199199
200200 # Flash Attention requires bottom-right-diagonal causal mask for cross-attention
201- cross_attn_causal = (
202- config . max_seqlen_q != config . max_seqlen_kv
203- and config . attn_mask_type in [ "causal" , " padding_causal"]
204- )
201+ cross_attn_causal = config . max_seqlen_q != config . max_seqlen_kv and config . attn_mask_type in [
202+ "causal" ,
203+ " padding_causal",
204+ ]
205205 sm = get_device_compute_capability ()
206206 # FA3 natively supports pad_between_seqs via seqused_q/seqused_k (SM90 only).
207207 # FA2 does not support pad_between_seqs and is not available on SM >= 100.
208208 if not cross_attn_causal and (
209- pad_between_seqs and FlashAttentionUtils .v3_is_installed and sm == (9 , 0 )
209+ pad_between_seqs
210+ and FlashAttentionUtils .v3_is_installed
211+ and sm == (9 , 0 )
210212 or not pad_between_seqs
211213 and FlashAttentionUtils .is_installed
212214 and (config .window_size [0 ] == - 1 or FlashAttentionUtils .v2_3_plus )
You can’t perform that action at this time.
0 commit comments