@@ -344,12 +344,36 @@ def test_dpa_num_splits(dtype, model_configs, model):
344344@pytest .mark .skipif (
345345 not FlashAttentionUtils .v4_is_installed , reason = "Flash-attn v4 (flash-attn-4) is required."
346346)
347- @pytest .mark .skipif (get_cudnn_version () < (8 , 9 , 1 ), reason = "cuDNN 8.9.1+ is required." )
348347@pytest .mark .parametrize ("dtype" , param_types_lean )
349348@pytest .mark .parametrize ("model_configs" , [model_configs_fa4_base ])
350349@pytest .mark .parametrize ("model" , model_configs_fa4_base .keys ())
351350def test_dpa_fa4_base (dtype , model_configs , model ):
352- """Test DotProductAttention with FA4: base configs, extended head dims, GQA, num_splits"""
351+ """Test DotProductAttention with FA4: base configs, GQA, num_splits"""
352+ test_dot_product_attention (dtype , model_configs , model , False , True , None , False , False )
353+
354+
355+ # head_dim=256 is supported only on SM100 via FA4's dedicated kernel
356+ # (flash_attn/cute/sm100_hd256_2cta_fmha_*.py), available in flash-attn-4 > 4.0.0b10.
357+ # On other architectures, _validate_head_dims rejects (256, 256), FA4 is disabled, and
358+ # the test would silently fall back to another backend — defeating the purpose. Gate
359+ # explicitly so the CI signal is unambiguous.
360+ model_configs_fa4_hdim256 = {
361+ "fa4_hdim256" : ModelConfig (2 , 1024 , 8 , 256 , attn_mask_type = "causal" ),
362+ }
363+
364+
365+ @pytest .mark .skipif (
366+ not FlashAttentionUtils .v4_is_installed , reason = "Flash-attn v4 (flash-attn-4) is required."
367+ )
368+ @pytest .mark .skipif (
369+ device_compute_capability not in ((10 , 0 ), (10 , 3 )),
370+ reason = "FA4 head_dim=256 dedicated kernel is SM100/103-only." ,
371+ )
372+ @pytest .mark .parametrize ("dtype" , param_types_lean )
373+ @pytest .mark .parametrize ("model_configs" , [model_configs_fa4_hdim256 ])
374+ @pytest .mark .parametrize ("model" , model_configs_fa4_hdim256 .keys ())
375+ def test_dpa_fa4_hdim256 (dtype , model_configs , model ):
376+ """Test DotProductAttention with FA4: head_dim=256 dedicated kernel on SM100"""
353377 test_dot_product_attention (dtype , model_configs , model , False , True , None , False , False )
354378
355379
@@ -369,7 +393,6 @@ def test_dpa_fa4_base(dtype, model_configs, model):
369393@pytest .mark .skipif (
370394 not FlashAttentionUtils .v4_is_installed , reason = "Flash-attn v4 (flash-attn-4) is required."
371395)
372- @pytest .mark .skipif (get_cudnn_version () < (8 , 9 , 1 ), reason = "cuDNN 8.9.1+ is required." )
373396@pytest .mark .parametrize ("dtype" , param_types_lean )
374397@pytest .mark .parametrize ("model_configs" , [model_configs_fa4_mla ])
375398@pytest .mark .parametrize ("model" , model_configs_fa4_mla .keys ())
@@ -396,7 +419,6 @@ def test_dpa_fa4_mla(dtype, model_configs, model):
396419@pytest .mark .skipif (
397420 not FlashAttentionUtils .v4_is_installed , reason = "Flash-attn v4 (flash-attn-4) is required."
398421)
399- @pytest .mark .skipif (get_cudnn_version () < (8 , 9 , 1 ), reason = "cuDNN 8.9.1+ is required." )
400422@pytest .mark .parametrize ("dtype" , param_types_lean )
401423@pytest .mark .parametrize ("model_configs" , [model_configs_fa4_swa ])
402424@pytest .mark .parametrize ("model" , model_configs_fa4_swa .keys ())
@@ -420,7 +442,6 @@ def test_dpa_fa4_sliding_window(dtype, model_configs, model, qkv_layout):
420442@pytest .mark .skipif (
421443 not FlashAttentionUtils .v4_is_installed , reason = "Flash-attn v4 (flash-attn-4) is required."
422444)
423- @pytest .mark .skipif (get_cudnn_version () < (8 , 9 , 1 ), reason = "cuDNN 8.9.1+ is required." )
424445@pytest .mark .parametrize ("dtype" , param_types_lean )
425446@pytest .mark .parametrize ("model_configs" , [model_configs_fa4_varlen ])
426447@pytest .mark .parametrize ("model" , model_configs_fa4_varlen .keys ())
@@ -446,7 +467,6 @@ def test_dpa_fa4_varlen(dtype, model_configs, model, qkv_layout):
446467@pytest .mark .skipif (
447468 not FlashAttentionUtils .v4_is_installed , reason = "Flash-attn v4 (flash-attn-4) is required."
448469)
449- @pytest .mark .skipif (get_cudnn_version () < (8 , 9 , 1 ), reason = "cuDNN 8.9.1+ is required." )
450470@pytest .mark .parametrize ("dtype" , param_types_lean )
451471@pytest .mark .parametrize ("model_configs" , [model_configs_fa4_mask ])
452472@pytest .mark .parametrize ("model" , model_configs_fa4_mask .keys ())
0 commit comments