Skip to content

Commit 94e90b9

Browse files
committed
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-cublasmp
2 parents 1f2710b + f8bda5d commit 94e90b9

52 files changed

Lines changed: 3661 additions & 1892 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

benchmarks/linear/benchmark_grouped_linear.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# See LICENSE for license information.
44

55
import argparse
6+
import os
67
import torch
78
import torch.utils.benchmark as benchmark
89
import pandas as pd
@@ -185,6 +186,8 @@ def run_benchmark_linear(
185186
x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
186187
ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
187188
m_splits = [m // num_gemms] * num_gemms if m_splits_provided is None else m_splits_provided
189+
if bool(int(os.getenv("NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM", "0"))):
190+
m_splits = torch.tensor(m_splits, dtype=torch.int64, device=device)
188191
# Bias is not supported for GroupedLinear benchmark
189192
bias = None
190193

qa/L0_pytorch_unittest/test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_P
2929
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_custom_recipe.xml $TE_PATH/tests/pytorch/test_custom_recipe.py || test_fail "test_custom_recipe.py"
3030
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
3131
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
32+
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_grouped_linear.xml $TE_PATH/tests/pytorch/test_grouped_linear.py || test_fail "test_grouped_linear.py"
3233
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
3334
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
3435
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"

tests/pytorch/attention/test_attention.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -344,12 +344,36 @@ def test_dpa_num_splits(dtype, model_configs, model):
344344
@pytest.mark.skipif(
345345
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
346346
)
347-
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
348347
@pytest.mark.parametrize("dtype", param_types_lean)
349348
@pytest.mark.parametrize("model_configs", [model_configs_fa4_base])
350349
@pytest.mark.parametrize("model", model_configs_fa4_base.keys())
351350
def test_dpa_fa4_base(dtype, model_configs, model):
352-
"""Test DotProductAttention with FA4: base configs, extended head dims, GQA, num_splits"""
351+
"""Test DotProductAttention with FA4: base configs, GQA, num_splits"""
352+
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
353+
354+
355+
# head_dim=256 is supported only on SM100 via FA4's dedicated kernel
356+
# (flash_attn/cute/sm100_hd256_2cta_fmha_*.py), available in flash-attn-4 > 4.0.0b10.
357+
# On other architectures, _validate_head_dims rejects (256, 256), FA4 is disabled, and
358+
# the test would silently fall back to another backend — defeating the purpose. Gate
359+
# explicitly so the CI signal is unambiguous.
360+
model_configs_fa4_hdim256 = {
361+
"fa4_hdim256": ModelConfig(2, 1024, 8, 256, attn_mask_type="causal"),
362+
}
363+
364+
365+
@pytest.mark.skipif(
366+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
367+
)
368+
@pytest.mark.skipif(
369+
device_compute_capability not in ((10, 0), (10, 3)),
370+
reason="FA4 head_dim=256 dedicated kernel is SM100/103-only.",
371+
)
372+
@pytest.mark.parametrize("dtype", param_types_lean)
373+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_hdim256])
374+
@pytest.mark.parametrize("model", model_configs_fa4_hdim256.keys())
375+
def test_dpa_fa4_hdim256(dtype, model_configs, model):
376+
"""Test DotProductAttention with FA4: head_dim=256 dedicated kernel on SM100"""
353377
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
354378

355379

@@ -369,7 +393,6 @@ def test_dpa_fa4_base(dtype, model_configs, model):
369393
@pytest.mark.skipif(
370394
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
371395
)
372-
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
373396
@pytest.mark.parametrize("dtype", param_types_lean)
374397
@pytest.mark.parametrize("model_configs", [model_configs_fa4_mla])
375398
@pytest.mark.parametrize("model", model_configs_fa4_mla.keys())
@@ -396,7 +419,6 @@ def test_dpa_fa4_mla(dtype, model_configs, model):
396419
@pytest.mark.skipif(
397420
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
398421
)
399-
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
400422
@pytest.mark.parametrize("dtype", param_types_lean)
401423
@pytest.mark.parametrize("model_configs", [model_configs_fa4_swa])
402424
@pytest.mark.parametrize("model", model_configs_fa4_swa.keys())
@@ -420,7 +442,6 @@ def test_dpa_fa4_sliding_window(dtype, model_configs, model, qkv_layout):
420442
@pytest.mark.skipif(
421443
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
422444
)
423-
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
424445
@pytest.mark.parametrize("dtype", param_types_lean)
425446
@pytest.mark.parametrize("model_configs", [model_configs_fa4_varlen])
426447
@pytest.mark.parametrize("model", model_configs_fa4_varlen.keys())
@@ -446,7 +467,6 @@ def test_dpa_fa4_varlen(dtype, model_configs, model, qkv_layout):
446467
@pytest.mark.skipif(
447468
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
448469
)
449-
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
450470
@pytest.mark.parametrize("dtype", param_types_lean)
451471
@pytest.mark.parametrize("model_configs", [model_configs_fa4_mask])
452472
@pytest.mark.parametrize("model", model_configs_fa4_mask.keys())

0 commit comments

Comments
 (0)