|
53 | 53 | ) |
54 | 54 |
|
55 | 55 | _current_file = pathlib.Path(__file__).resolve() |
56 | | -sys.path.append(str(_current_file.parent.parent)) |
| 56 | +sys.path = [str(_current_file.parent.parent)] + sys.path |
57 | 57 | from utils import ( |
58 | 58 | reset_rng_states, |
59 | 59 | compare_and_assert, |
@@ -362,6 +362,139 @@ def test_dpa_num_splits(dtype, model_configs, model): |
362 | 362 | ) |
363 | 363 |
|
364 | 364 |
|
| 365 | +# ============================== |
| 366 | +# Flash Attention 4 (FA4) tests |
| 367 | +# ============================== |
| 368 | + |
| 369 | +model_configs_fa4_base = { |
| 370 | + # test: ModelConfig(b, sq, hq, dqk) |
| 371 | + # Standard head dims |
| 372 | + "fa4_base_1": ModelConfig(4, 128, 16, 64), |
| 373 | + "fa4_base_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"), |
| 374 | + "fa4_base_3": ModelConfig(2, 1024, 8, 96, attn_mask_type="causal"), |
| 375 | + # GQA |
| 376 | + "fa4_gqa_1": ModelConfig(2, 1024, 32, 128, num_gqa_groups=8, attn_mask_type="causal"), |
| 377 | + "fa4_gqa_2": ModelConfig(2, 1024, 16, 128, num_gqa_groups=1, attn_mask_type="causal"), |
| 378 | + # num_splits |
| 379 | + "fa4_splits_1": ModelConfig(2, 2048, 24, 128, num_splits=2), |
| 380 | + "fa4_splits_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4), |
| 381 | +} |
| 382 | + |
| 383 | + |
| 384 | +@pytest.mark.skipif( |
| 385 | + not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required." |
| 386 | +) |
| 387 | +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") |
| 388 | +@pytest.mark.parametrize("dtype", param_types_lean) |
| 389 | +@pytest.mark.parametrize("model_configs", [model_configs_fa4_base]) |
| 390 | +@pytest.mark.parametrize("model", model_configs_fa4_base.keys()) |
| 391 | +def test_dpa_fa4_base(dtype, model_configs, model): |
| 392 | + """Test DotProductAttention with FA4: standard head dims, GQA, num_splits""" |
| 393 | + test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) |
| 394 | + |
| 395 | + |
| 396 | +model_configs_fa4_mla = { |
| 397 | + # test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv) |
| 398 | + "fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64), |
| 399 | + "fa4_mla_2": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), |
| 400 | + "fa4_mla_3": ModelConfig(2, 1024, 16, 96, head_dim_v=64, attn_mask_type="causal"), |
| 401 | + # dqk=128, dv=96: FA4 SM100 backward has dK_reduce_ncol misalignment for dV; |
| 402 | + # the backend filter should reject FA4 and fall back to another backend. |
| 403 | + "fa4_mla_4": ModelConfig(2, 1024, 16, 128, head_dim_v=96, attn_mask_type="causal"), |
| 404 | + # DeepSeek-style MLA: dqk=192, dv=128 (supported on SM100 as special case) |
| 405 | + "fa4_mla_deepseek": ModelConfig(2, 1024, 16, 192, head_dim_v=128, attn_mask_type="causal"), |
| 406 | +} |
| 407 | + |
| 408 | + |
| 409 | +@pytest.mark.skipif( |
| 410 | + not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required." |
| 411 | +) |
| 412 | +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") |
| 413 | +@pytest.mark.parametrize("dtype", param_types_lean) |
| 414 | +@pytest.mark.parametrize("model_configs", [model_configs_fa4_mla]) |
| 415 | +@pytest.mark.parametrize("model", model_configs_fa4_mla.keys()) |
| 416 | +def test_dpa_fa4_mla(dtype, model_configs, model): |
| 417 | + """Test DotProductAttention with FA4: MLA (head_dim_qk != head_dim_v)""" |
| 418 | + test_dot_product_attention( |
| 419 | + dtype, model_configs, model, False, True, "bshd_bshd_bshd", False, False |
| 420 | + ) |
| 421 | + |
| 422 | + |
| 423 | +model_configs_fa4_swa = { |
| 424 | + # test: ModelConfig(b, sq, hq, dqk, window_size=(left, right)) |
| 425 | + "fa4_swa_1": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", window_size=(128, 0)), |
| 426 | + "fa4_swa_2": ModelConfig(2, 2048, 24, 64, attn_mask_type="causal", window_size=(64, 0)), |
| 427 | + "fa4_swa_3": ModelConfig( |
| 428 | + 2, 2048, 16, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(256, 0) |
| 429 | + ), |
| 430 | + "fa4_swa_4": ModelConfig( |
| 431 | + 2, 2048, 16, 128, attn_mask_type="padding_causal", window_size=(128, 0) |
| 432 | + ), |
| 433 | +} |
| 434 | + |
| 435 | + |
| 436 | +@pytest.mark.skipif( |
| 437 | + not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required." |
| 438 | +) |
| 439 | +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") |
| 440 | +@pytest.mark.parametrize("dtype", param_types_lean) |
| 441 | +@pytest.mark.parametrize("model_configs", [model_configs_fa4_swa]) |
| 442 | +@pytest.mark.parametrize("model", model_configs_fa4_swa.keys()) |
| 443 | +@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "bshd_bshd_bshd"]) |
| 444 | +def test_dpa_fa4_sliding_window(dtype, model_configs, model, qkv_layout): |
| 445 | + """Test DotProductAttention with FA4: sliding window attention""" |
| 446 | + test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False) |
| 447 | + |
| 448 | + |
| 449 | +model_configs_fa4_varlen = { |
| 450 | + # test: ModelConfig(b, sq, hq, dqk) |
| 451 | + "fa4_varlen_1": ModelConfig(4, 128, 16, 64, attn_mask_type="padding"), |
| 452 | + "fa4_varlen_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="padding_causal"), |
| 453 | + "fa4_varlen_3": ModelConfig( |
| 454 | + 2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal" |
| 455 | + ), |
| 456 | + "fa4_varlen_4": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"), |
| 457 | +} |
| 458 | + |
| 459 | + |
| 460 | +@pytest.mark.skipif( |
| 461 | + not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required." |
| 462 | +) |
| 463 | +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") |
| 464 | +@pytest.mark.parametrize("dtype", param_types_lean) |
| 465 | +@pytest.mark.parametrize("model_configs", [model_configs_fa4_varlen]) |
| 466 | +@pytest.mark.parametrize("model", model_configs_fa4_varlen.keys()) |
| 467 | +@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "bshd_bshd_bshd"]) |
| 468 | +def test_dpa_fa4_varlen(dtype, model_configs, model, qkv_layout): |
| 469 | + """Test DotProductAttention with FA4: variable-length sequences (thd and bshd layouts)""" |
| 470 | + test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False) |
| 471 | + |
| 472 | + |
| 473 | +model_configs_fa4_mask = { |
| 474 | + # test: ModelConfig(b, sq, hq, dqk) |
| 475 | + "fa4_mask_no_mask": ModelConfig(2, 1024, 16, 128), |
| 476 | + "fa4_mask_causal": ModelConfig(2, 1024, 16, 128, attn_mask_type="causal"), |
| 477 | + "fa4_mask_causal_br": ModelConfig(2, 1024, 16, 128, attn_mask_type="causal_bottom_right"), |
| 478 | + "fa4_mask_padding": ModelConfig(2, 1024, 16, 128, attn_mask_type="padding"), |
| 479 | + "fa4_mask_padding_causal": ModelConfig(2, 1024, 16, 128, attn_mask_type="padding_causal"), |
| 480 | + "fa4_mask_padding_causal_br": ModelConfig( |
| 481 | + 2, 1024, 16, 128, attn_mask_type="padding_causal_bottom_right" |
| 482 | + ), |
| 483 | +} |
| 484 | + |
| 485 | + |
| 486 | +@pytest.mark.skipif( |
| 487 | + not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required." |
| 488 | +) |
| 489 | +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") |
| 490 | +@pytest.mark.parametrize("dtype", param_types_lean) |
| 491 | +@pytest.mark.parametrize("model_configs", [model_configs_fa4_mask]) |
| 492 | +@pytest.mark.parametrize("model", model_configs_fa4_mask.keys()) |
| 493 | +def test_dpa_fa4_mask(dtype, model_configs, model): |
| 494 | + """Test DotProductAttention with FA4: various attention mask types""" |
| 495 | + test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) |
| 496 | + |
| 497 | + |
365 | 498 | model_configs_softmax = { |
366 | 499 | # test: ModelConfig(b, sq, hq, dqk) |
367 | 500 | "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8), |
|
0 commit comments