From 0465d676032b7a3e751342e89a2464c4b3efde50 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 8 Jun 2026 22:00:07 +0000 Subject: [PATCH 01/11] [PyTorch] Move advanced grouped MLP tests from test_fusible_ops to test_grouped_linear test_fusible_ops.py is the general op-fuser test suite; detailed tests for grouped-linear-specific features belong in test_grouped_linear.py. Moved to test_grouped_linear.py: - test_grouped_linear_cuda_graph_safe (CUDA graph capture) - test_grouped_mlp_single_weight_numerics (single_grouped_weight equivalence) - test_grouped_mlp_overwrite_main_grad (MegatronFSDP overwrite convention) - test_grouped_mlp_cuda_graph_safe_mxfp8 (CUDA graph + MXFP8) Kept in test_fusible_ops.py with reduced parametrization: - test_grouped_linear: dropped single_grouped_weight/bias and delay_wgrad_compute axes (hardcoded to False) - test_grouped_mlp: dropped single_grouped_weight/bias, accumulate_into_main_grad, and delay_wgrad_compute axes; reduced hidden_size to a single value Also adds NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 and NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 to the test_grouped_linear.py invocation in the QA script, matching the env vars already set for test_fusible_ops.py and required by the moved tests. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Tim Moon --- qa/L0_pytorch_unittest/test.sh | 2 +- tests/pytorch/test_fusible_ops.py | 794 +-------------------------- tests/pytorch/test_grouped_linear.py | 771 ++++++++++++++++++++++++++ 3 files changed, 780 insertions(+), 787 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 2485e84b5e..08530be3ca 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -29,7 +29,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_P python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_custom_recipe.xml $TE_PATH/tests/pytorch/test_custom_recipe.py || test_fail "test_custom_recipe.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_grouped_linear.xml $TE_PATH/tests/pytorch/test_grouped_linear.py || test_fail "test_grouped_linear.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_grouped_linear.xml $TE_PATH/tests/pytorch/test_grouped_linear.py || test_fail "test_grouped_linear.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 14a52249f3..3c7b2c90d2 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2079,9 +2079,6 @@ def test_dropout( @pytest.mark.parametrize("quantized_weight", (False, True)) @pytest.mark.parametrize("input_requires_grad", (False, True)) @pytest.mark.parametrize("weight_requires_grad", (False, True)) - @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) - @pytest.mark.parametrize("single_grouped_weight", (False, True)) - @pytest.mark.parametrize("single_grouped_bias", (False, True)) def test_grouped_linear( self, *, @@ -2096,9 +2093,9 @@ def test_grouped_linear( quantized_weight: bool, input_requires_grad: bool, weight_requires_grad: bool, - delay_wgrad_compute: bool, - single_grouped_weight: bool, - single_grouped_bias: bool, + delay_wgrad_compute: bool = False, + single_grouped_weight: bool = False, + single_grouped_bias: bool = False, ) -> None: """Grouped GEMM""" if os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0") == "0" and ( @@ -2280,191 +2277,6 @@ def test_grouped_linear( else: assert b_test.grad is None - @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16)) - @pytest.mark.parametrize( - "quantization", - [None] + (["mxfp8"] if mxfp8_available else []), - ) - @pytest.mark.parametrize("quantized_weight", (False, True)) - @pytest.mark.parametrize("bias", (False, True)) - @pytest.mark.parametrize("single_grouped_weight", (False, True)) - @pytest.mark.parametrize("single_grouped_bias", (False, True)) - @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) - def test_grouped_linear_cuda_graph_safe( - self, - *, - dtype: torch.dtype, - quantization: Optional[str], - quantized_weight: bool, - bias: bool, - single_grouped_weight: bool, - single_grouped_bias: bool, - accumulate_into_main_grad: bool, - device: torch.device = "cuda", - group_size: int = 4, - in_features: int = 128, - out_features: int = 128, - split_alignment: int = 128, - token_padding: int = 256, - ) -> None: - """GroupedLinear forward+backward should be CUDA graph capturable. - - Exercises the grouped-tensor / cublas-grouped-gemm path which uses - GPU-resident split offsets and is the only flow safe to capture. - """ - if os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0") == "0" and ( - single_grouped_weight or single_grouped_bias - ): - pytest.skip( - "single_grouped_weight/single_grouped_bias requires" - " NVTE_GROUPED_LINEAR_SINGLE_PARAM=1" - ) - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("Grouped GEMM CUDA-graph-safe path requires SM100+ (Blackwell)") - # Skip invalid configurations - if quantization is None and quantized_weight: - pytest.skip("quantized_weight requires a quantization recipe") - if single_grouped_bias and not bias: - pytest.skip("single_grouped_bias requires bias=True") - - # Split sizes (statically pinned for graph capture) - split_sizes = [split_alignment * (i + 1) for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) - # Pad input tokens to validate the sync-free flow - in_shape = (split_sizes.sum().item() + token_padding, in_features) - out_shape = (in_shape[0], out_features) - - recipe = make_recipe(quantization) - with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): - op = te_ops.GroupedLinear( - group_size, - in_features, - out_features, - bias=bias, - device=device, - dtype=dtype, - accumulate_into_main_grad=accumulate_into_main_grad, - single_grouped_weight=single_grouped_weight, - single_grouped_bias=single_grouped_bias, - ) - - def _weight_params() -> list[torch.nn.Parameter]: - if single_grouped_weight: - return [op.weight] - return [getattr(op, f"weight{i}") for i in range(group_size)] - - def _bias_params() -> list[torch.nn.Parameter]: - if not bias: - return [] - if single_grouped_bias: - return [op.bias] - return [getattr(op, f"bias{i}") for i in range(group_size)] - - def _init_main_grads(value: float = 0.0) -> None: - if not accumulate_into_main_grad: - return - with torch.no_grad(): - for w in _weight_params(): - if getattr(w, "main_grad", None) is None: - w.main_grad = torch.empty(w.size(), device=device, dtype=torch.float32) - w.main_grad.fill_(value) - - def _collect_main_grads() -> list[torch.Tensor]: - return [w.main_grad.detach().clone() for w in _weight_params()] - - def _zero_param_grads() -> None: - for param in op.parameters(): - if param.grad is None: - param.grad = torch.zeros_like(param) - else: - param.grad.zero_() - - static_split_sizes = split_sizes.clone() - - def train_step( - x: torch.Tensor, - dy: torch.Tensor, - out_buf: torch.Tensor, - *, - use_graphed: bool, - ) -> torch.Tensor: - with te.autocast(enabled=quantization is not None, recipe=recipe): - out = ( - graphed_module(x, static_split_sizes) - if use_graphed - else op(x, static_split_sizes) - ) - out.backward(dy) - out_buf.copy_(out) - return out_buf - - _init_main_grads(0.0) - - static_x = torch.randn(in_shape, device=device, dtype=dtype, requires_grad=True) - static_dy = torch.randn(out_shape, device=device, dtype=dtype) - static_out_buf = torch.empty(out_shape, device=device, dtype=dtype) - - graphed_module = te.make_graphed_callables( - op, - (static_x, static_split_sizes), - num_warmup_iters=3, - enabled=quantization is not None, - recipe=recipe, - ) - - # Replace static buffers with fresh data (graph captures must replay - # against new inputs without re-recording). - fresh_x = torch.randn_like(static_x) - fresh_dy = torch.randn_like(static_dy) - with torch.no_grad(): - static_x.copy_(fresh_x) - static_dy.copy_(fresh_dy) - - # Reset grads & main_grads so the captured iteration starts fresh. - _zero_param_grads() - _init_main_grads(0.5) - if static_x.grad is not None: - static_x.grad.zero_() - - # Replay the graph - graph_out = ( - train_step(static_x, static_dy, static_out_buf, use_graphed=True).detach().clone() - ) - torch.cuda.synchronize() - graph_dx = static_x.grad.detach().clone() - if accumulate_into_main_grad: - graph_main_grads = _collect_main_grads() - graph_param_grads: list[torch.Tensor] = [] - else: - graph_main_grads = [] - graph_param_grads = [param.grad.detach().clone() for param in op.parameters()] - - # Reference: same op invoked eagerly with the same fresh inputs and - # the same starting grad/main_grad state. - _zero_param_grads() - _init_main_grads(0.5) - static_x.grad.zero_() - - expected_x = fresh_x.detach().clone().requires_grad_(True) - expected_dy = fresh_dy.detach().clone() - with te.autocast(enabled=quantization is not None, recipe=recipe): - expected_out = op(expected_x, static_split_sizes) - expected_out.backward(expected_dy) - - tols = dtype_tols(dtype) - if quantization is not None: - tols = quantization_tols(quantization) - - assert_close(graph_out, expected_out, **tols) - assert_close(graph_dx, expected_x.grad, **tols) - if accumulate_into_main_grad: - for g, w in zip(graph_main_grads, _weight_params()): - assert_close(g, w.main_grad, **tols) - else: - for g, param in zip(graph_param_grads, op.parameters()): - assert_close(g, param.grad, **tols) - @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128))) @pytest.mark.parametrize("input_requires_grad", (False, True)) @pytest.mark.parametrize("scales_requires_grad", (False, True)) @@ -3696,12 +3508,8 @@ def test_layernorm_mlp( @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _grouped_mlp_quantization_list) - @pytest.mark.parametrize("single_grouped_weight", (False, True)) - @pytest.mark.parametrize("single_grouped_bias", (False, True)) - @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) @pytest.mark.parametrize("glu_interleave_size", (None, 32)) - @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) - @pytest.mark.parametrize("hidden_size", (128, 256)) + @pytest.mark.parametrize("hidden_size", (128,)) @pytest.mark.parametrize( "activation", ( @@ -3719,13 +3527,13 @@ def test_grouped_mlp( hidden_size: int, dtype: torch.dtype, quantization: Optional[str], - single_grouped_weight: bool, - single_grouped_bias: bool, - accumulate_into_main_grad: bool, + single_grouped_weight: bool = False, + single_grouped_bias: bool = False, + accumulate_into_main_grad: bool = False, device: torch.device = "cuda", split_alignment: int = 256, glu_interleave_size: Optional[int], - delay_wgrad_compute: bool, + delay_wgrad_compute: bool = False, activation: str, ) -> None: """GroupedLinear + scaled activation + GroupedLinear""" @@ -4109,592 +3917,6 @@ def _make_module(): assert_close(fc1.weight.grad, fc1_w_ref_grad, **tols) assert_close(fc2.weight.grad, fc2_w_ref_grad, **tols) - @pytest.mark.parametrize( - "dtype", - tuple(dtype for dtype in _dtypes if dtype in (torch.float16, torch.bfloat16)), - ) - @pytest.mark.parametrize("bias", (False, True)) - @pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu")) - @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) - def test_grouped_mlp_single_weight_numerics( - self, - *, - dtype: torch.dtype, - bias: bool, - activation: str, - device: torch.device = "cuda", - group_size: int = 4, - hidden_size: int = 256, - split_alignment: int = 256, - glu_interleave_size: int = 32, - ) -> None: - """single_grouped_weight=True/False should match exactly for fused MXFP8 grouped MLP.""" - - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): - pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") - if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): - pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") - - split_sizes = [split_alignment * (i + 1) for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) - in_shape = (split_sizes.sum().item(), hidden_size) - recipe = make_recipe("mxfp8") - - x_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) - probs_base = torch.empty((in_shape[0],), device=device, dtype=dtype).uniform_(-0.25, 0.25) - dy_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) - fc1_ws_base = [ - torch.empty((2 * hidden_size, hidden_size), device=device, dtype=dtype).uniform_( - -0.25, 0.25 - ) - for _ in range(group_size) - ] - fc2_ws_base = [ - torch.empty((hidden_size, hidden_size), device=device, dtype=dtype).uniform_( - -0.25, 0.25 - ) - for _ in range(group_size) - ] - fc1_bs_base = ( - [ - torch.empty((2 * hidden_size,), device=device, dtype=dtype).uniform_(-0.5, 0.5) - for _ in range(group_size) - ] - if bias - else None - ) - fc2_bs_base = ( - [ - torch.empty((hidden_size,), device=device, dtype=dtype).uniform_(-0.5, 0.5) - for _ in range(group_size) - ] - if bias - else None - ) - - def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]: - with te.quantized_model_init(enabled=True, recipe=recipe): - scaled_act = ( - te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - if activation == "scaled_swiglu" - else te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) - ) - fc1 = te_ops.GroupedLinear( - group_size, - hidden_size, - 2 * hidden_size, - bias=bias, - device=device, - dtype=dtype, - single_grouped_weight=single_grouped_weight, - ) - fc2 = te_ops.GroupedLinear( - group_size, - hidden_size, - hidden_size, - bias=bias, - device=device, - dtype=dtype, - single_grouped_weight=single_grouped_weight, - scale_bias=bias, - ) - module = te_ops.Sequential(fc1, scaled_act, fc2) - - with torch.no_grad(): - if single_grouped_weight: - fc1_weights = fc1.weight.quantized_tensors - if fc1_weights is None: - fc1_weights = fc1.weight.split_into_quantized_tensors() - fc2_weights = fc2.weight.quantized_tensors - if fc2_weights is None: - fc2_weights = fc2.weight.split_into_quantized_tensors() - for group_idx in range(group_size): - if single_grouped_weight: - fc1_weights[group_idx].copy_(fc1_ws_base[group_idx]) - fc2_weights[group_idx].copy_(fc2_ws_base[group_idx]) - else: - getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_base[group_idx]) - getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_base[group_idx]) - if bias: - getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_base[group_idx]) - getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_base[group_idx]) - - x = x_base.detach().clone().requires_grad_(True) - probs = probs_base.detach().clone().requires_grad_(True) - dy = dy_base.detach().clone() - - with te.autocast(enabled=True, recipe=recipe): - fc2_extra = (split_sizes, probs) if bias else (split_sizes,) - y = module(x, split_sizes, probs, *fc2_extra) - y.backward(dy) - - forward_ops = module._module_groups[0]._forward_ops - backward_ops = module._module_groups[0]._backward_ops - assert len(forward_ops) == 1 - assert isinstance( - forward_ops[0][0], - te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, - ) - assert len(backward_ops) == 1 - assert isinstance( - backward_ops[0][0], - te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, - ) - - if single_grouped_weight: - fc1_dw = fc1.weight.grad.detach().clone() - fc2_dw = fc2.weight.grad.detach().clone() - else: - fc1_dw = torch.stack( - [ - getattr(fc1, f"weight{group_idx}").grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, - ) - fc2_dw = torch.stack( - [ - getattr(fc2, f"weight{group_idx}").grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, - ) - - fc1_db = None - fc2_db = None - if bias: - fc1_db = torch.stack( - [ - getattr(fc1, f"bias{group_idx}").grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, - ) - fc2_db = torch.stack( - [ - getattr(fc2, f"bias{group_idx}").grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, - ) - - return ( - y.detach().clone(), - x.grad.detach().clone(), - probs.grad.detach().clone(), - fc1_dw, - fc2_dw, - fc1_db, - fc2_db, - ) - - ( - y_false, - dx_false, - dprobs_false, - fc1_dw_false, - fc2_dw_false, - fc1_db_false, - fc2_db_false, - ) = _run_case(False) - ( - y_true, - dx_true, - dprobs_true, - fc1_dw_true, - fc2_dw_true, - fc1_db_true, - fc2_db_true, - ) = _run_case(True) - - torch.testing.assert_close(y_false, y_true, rtol=0, atol=0) - torch.testing.assert_close(dx_false, dx_true, rtol=0, atol=0) - torch.testing.assert_close(dprobs_false, dprobs_true, rtol=0, atol=0) - torch.testing.assert_close(fc1_dw_false, fc1_dw_true, rtol=0, atol=0) - torch.testing.assert_close(fc2_dw_false, fc2_dw_true, rtol=0, atol=0) - if bias: - bias_tols = {"rtol": 0.05, "atol": 0.015625} - torch.testing.assert_close(fc1_db_false, fc1_db_true, **bias_tols) - torch.testing.assert_close(fc2_db_false, fc2_db_true, **bias_tols) - - @pytest.mark.parametrize("single_grouped_weight", (False, True)) - @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) - @pytest.mark.parametrize("zero_out_wgrad", (False, True)) - @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) - def test_grouped_mlp_overwrite_main_grad( - self, - *, - single_grouped_weight: bool, - delay_wgrad_compute: bool, - zero_out_wgrad: bool, - dtype: torch.dtype = torch.bfloat16, - device: torch.device = "cuda", - group_size: int = 4, - hidden_size: int = 256, - split_alignment: int = 256, - glu_interleave_size: int = 32, - ) -> None: - """End-to-end check that the fused grouped-MLP backward writes the - wgrad into ``weight.main_grad`` correctly under the MegatronFSDP - ``overwrite_main_grad=True`` convention. - ``test_grouped_mlp`` already covers the standard Megatron-LM - ``fuse_wgrad_accumulation`` (DDP) path where the wgrad GEMM - *accumulates* into ``main_grad``. This test focuses exclusively on - the MegatronFSDP variant where the wgrad GEMM must *overwrite* - ``main_grad`` (because FSDP has already ReduceScattered the previous - accumulation), so ``main_grad`` after backward equals ``wgrad`` - regardless of the prior contents. - - Also exercises the MegatronFSDP ``zero_out_wgrad`` flag, which is - independent of ``main_grad`` and only controls whether the dummy - ``param.grad`` returned to autograd is zeroed (so downstream hooks - that read ``.grad`` don't see stale bytes from the cached dummy). - """ - - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): - pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") - if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): - pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") - - recipe = make_recipe("mxfp8") - split_sizes = [split_alignment * (i + 1) for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) - in_shape = (split_sizes.sum().item(), hidden_size) - x_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) - probs_base = torch.empty((in_shape[0],), device=device, dtype=dtype).uniform_(-0.25, 0.25) - dy_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) - fc1_ws_base = [ - torch.empty((2 * hidden_size, hidden_size), device=device, dtype=dtype).uniform_( - -0.25, 0.25 - ) - for _ in range(group_size) - ] - fc2_ws_base = [ - torch.empty((hidden_size, hidden_size), device=device, dtype=dtype).uniform_( - -0.25, 0.25 - ) - for _ in range(group_size) - ] - - def _build_module(*, accumulate_into_main_grad: bool): - with te.quantized_model_init(enabled=True, recipe=recipe): - fc1 = te_ops.GroupedLinear( - group_size, - hidden_size, - 2 * hidden_size, - bias=False, - device=device, - dtype=dtype, - single_grouped_weight=single_grouped_weight, - accumulate_into_main_grad=accumulate_into_main_grad, - delay_wgrad_compute=delay_wgrad_compute, - ) - fc2 = te_ops.GroupedLinear( - group_size, - hidden_size, - hidden_size, - bias=False, - device=device, - dtype=dtype, - single_grouped_weight=single_grouped_weight, - accumulate_into_main_grad=accumulate_into_main_grad, - delay_wgrad_compute=delay_wgrad_compute, - ) - scaled_act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - module = te_ops.Sequential(fc1, scaled_act, fc2) - - with torch.no_grad(): - if single_grouped_weight: - fc1_weights = ( - fc1.weight.quantized_tensors or fc1.weight.split_into_quantized_tensors() - ) - fc2_weights = ( - fc2.weight.quantized_tensors or fc2.weight.split_into_quantized_tensors() - ) - for group_idx in range(group_size): - fc1_weights[group_idx].copy_(fc1_ws_base[group_idx]) - fc2_weights[group_idx].copy_(fc2_ws_base[group_idx]) - else: - for group_idx in range(group_size): - getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_base[group_idx]) - getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_base[group_idx]) - return module, fc1, fc2 - - def _weight_params(fc): - if single_grouped_weight: - return [fc.weight] - return [getattr(fc, f"weight{i}") for i in range(group_size)] - - def _run_backward(module, fc1, fc2): - x = x_base.detach().clone().requires_grad_(True) - probs = probs_base.detach().clone().requires_grad_(True) - with te.autocast(enabled=True, recipe=recipe): - y = module(x, split_sizes, probs, split_sizes) - y.backward(dy_base) - if delay_wgrad_compute: - fc1.backward_dw() - fc2.backward_dw() - - # Reference run: vanilla autograd, no Megatron protocol. - ref_module, ref_fc1, ref_fc2 = _build_module(accumulate_into_main_grad=False) - _run_backward(ref_module, ref_fc1, ref_fc2) - ref_fc1_grads = [wp.grad.detach().clone() for wp in _weight_params(ref_fc1)] - ref_fc2_grads = [wp.grad.detach().clone() for wp in _weight_params(ref_fc2)] - - # Test run: main_grad fusion with overwrite_main_grad=True (MegatronFSDP). - # NaN sentinel makes a missed write loud (would surface as NaN diff). - test_module, test_fc1, test_fc2 = _build_module(accumulate_into_main_grad=True) - for fc in (test_fc1, test_fc2): - MegatronTrainingHelper.init_main_grad_buffers( - _weight_params(fc), - fill_value=float("nan"), - overwrite_main_grad=True, - zero_out_wgrad=zero_out_wgrad, - ) - _run_backward(test_module, test_fc1, test_fc2) - - # main_grad must be overwritten to exactly the ref wgrad (bitwise: - # the wgrad GEMM is deterministic across the two runs because the - # quantized weights and inputs are identical). - MegatronTrainingHelper.verify_main_grad_accumulation( - _weight_params(test_fc1), expected_main_grads=ref_fc1_grads - ) - MegatronTrainingHelper.verify_main_grad_accumulation( - _weight_params(test_fc2), expected_main_grads=ref_fc2_grads - ) - - @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("single_grouped_weight", (False, True)) - @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) - @pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu")) - @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) - def test_grouped_mlp_cuda_graph_safe_mxfp8( - self, - *, - dtype: torch.dtype, - single_grouped_weight: bool, - accumulate_into_main_grad: bool, - activation: str, - device: torch.device = "cuda", - group_size: int = 4, - hidden_size: int = 256, - split_alignment: int = 256, - glu_interleave_size: int = 32, - token_padding: int = 2048, - ) -> None: - """Grouped MLP forward+backward should be CUDA graph capturable (MXFP8).""" - - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): - pytest.skip("MXFP8 fused grouped MLP is not supported on this system") - if dtype not in (torch.bfloat16, torch.float16): - pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16") - - split_sizes = [split_alignment * (i + 1) for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) - # Pad the input tokens to validate the sync-free MOE - in_shape = (split_sizes.sum().item() + token_padding, hidden_size) - recipe = make_recipe("mxfp8") - with te.quantized_model_init(enabled=True, recipe=recipe): - fc1 = te_ops.GroupedLinear( - group_size, - hidden_size, - 2 * hidden_size, - bias=False, - device=device, - dtype=dtype, - single_grouped_weight=single_grouped_weight, - accumulate_into_main_grad=accumulate_into_main_grad, - ) - fc2 = te_ops.GroupedLinear( - group_size, - hidden_size, - hidden_size, - bias=False, - device=device, - dtype=dtype, - single_grouped_weight=single_grouped_weight, - accumulate_into_main_grad=accumulate_into_main_grad, - ) - scaled_act = ( - te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - if activation == "scaled_swiglu" - else te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) - ) - module = te_ops.Sequential( - fc1, - scaled_act, - fc2, - ) - - def _init_main_grads(value: float = 0.0) -> None: - if not accumulate_into_main_grad: - return - with torch.no_grad(): - if single_grouped_weight: - if getattr(fc1.weight, "main_grad", None) is None: - fc1.weight.main_grad = torch.empty( - fc1.weight.size(), - device=device, - dtype=torch.float32, - ) - if getattr(fc2.weight, "main_grad", None) is None: - fc2.weight.main_grad = torch.empty( - fc2.weight.size(), - device=device, - dtype=torch.float32, - ) - fc1.weight.main_grad.fill_(value) - fc2.weight.main_grad.fill_(value) - else: - for group_idx in range(group_size): - fc1_weight = getattr(fc1, f"weight{group_idx}") - fc2_weight = getattr(fc2, f"weight{group_idx}") - if getattr(fc1_weight, "main_grad", None) is None: - fc1_weight.main_grad = torch.empty( - fc1_weight.size(), - device=device, - dtype=torch.float32, - ) - if getattr(fc2_weight, "main_grad", None) is None: - fc2_weight.main_grad = torch.empty( - fc2_weight.size(), - device=device, - dtype=torch.float32, - ) - fc1_weight.main_grad.fill_(value) - fc2_weight.main_grad.fill_(value) - - def _collect_main_grads() -> tuple[torch.Tensor, torch.Tensor]: - if single_grouped_weight: - fc1_main_grad = fc1.weight.main_grad.detach().clone() - fc2_main_grad = fc2.weight.main_grad.detach().clone() - else: - fc1_main_grad = torch.stack( - [ - getattr(fc1, f"weight{group_idx}").main_grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, - ) - fc2_main_grad = torch.stack( - [ - getattr(fc2, f"weight{group_idx}").main_grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, - ) - return fc1_main_grad, fc2_main_grad - - static_split_sizes = split_sizes.clone() - - def train_step( - x: torch.Tensor, - probs: torch.Tensor, - dy: torch.Tensor, - out_buf: torch.Tensor, - *, - use_graphed: bool, - ) -> torch.Tensor: - with te.autocast(enabled=True, recipe=recipe): - out = ( - graphed_module(x, static_split_sizes, probs, static_split_sizes) - if use_graphed - else module(x, static_split_sizes, probs, static_split_sizes) - ) - out.backward(dy) - out_buf.copy_(out) - return out_buf - - _init_main_grads(0.0) - - static_x = torch.randn(in_shape, device=device, dtype=dtype, requires_grad=True) - static_probs = torch.randn((in_shape[0],), device=device, dtype=dtype, requires_grad=True) - static_dy = torch.randn(in_shape, device=device, dtype=dtype) - static_out_buf = torch.empty((in_shape[0], hidden_size), device=device, dtype=dtype) - - graphed_module = te.make_graphed_callables( - module, - (static_x, static_split_sizes, static_probs, static_split_sizes), - num_warmup_iters=3, - enabled=True, - recipe=recipe, - ) - - forward_ops = module._module_groups[0]._forward_ops - backward_ops = module._module_groups[0]._backward_ops - assert len(forward_ops) == 1 - assert isinstance( - forward_ops[0][0], - te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, - ) - assert len(backward_ops) == 1 - assert isinstance( - backward_ops[0][0], - te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, - ) - - fresh_x = torch.randn_like(static_x) - fresh_probs = torch.randn_like(static_probs) - fresh_dy = torch.randn_like(static_dy) - with torch.no_grad(): - static_x.copy_(fresh_x) - static_probs.copy_(fresh_probs) - static_dy.copy_(fresh_dy) - - for param in module.parameters(): - param.grad = torch.zeros_like(param) - _init_main_grads(0.5) - if static_x.grad is not None: - static_x.grad.zero_() - if static_probs.grad is not None: - static_probs.grad.zero_() - - graph_out = ( - train_step(static_x, static_probs, static_dy, static_out_buf, use_graphed=True) - .detach() - .clone() - ) - torch.cuda.synchronize() - graph_dx = static_x.grad.detach().clone() - graph_dprobs = static_probs.grad.detach().clone() - if accumulate_into_main_grad: - graph_fc1_main_grad, graph_fc2_main_grad = _collect_main_grads() - else: - graph_param_grads = [param.grad.detach().clone() for param in module.parameters()] - - for param in module.parameters(): - param.grad.zero_() - _init_main_grads(0.5) - static_x.grad.zero_() - static_probs.grad.zero_() - - expected_x = fresh_x.detach().clone().requires_grad_(True) - expected_probs = fresh_probs.detach().clone().requires_grad_(True) - expected_dy = fresh_dy.detach().clone() - with te.autocast(enabled=True, recipe=recipe): - expected_out = module( - expected_x, - static_split_sizes, - expected_probs, - static_split_sizes, - ) - expected_out.backward(expected_dy) - - tols = dtype_tols(dtype) - assert_close(graph_out, expected_out, **tols) - assert_close(graph_dx, expected_x.grad, **tols) - assert_close(graph_dprobs, expected_probs.grad, **tols) - if accumulate_into_main_grad: - expected_fc1_main_grad, expected_fc2_main_grad = _collect_main_grads() - assert_close(graph_fc1_main_grad, expected_fc1_main_grad, **tols) - assert_close(graph_fc2_main_grad, expected_fc2_main_grad, **tols) - else: - for graph_grad, param in zip(graph_param_grads, module.parameters()): - assert_close(graph_grad, param.grad, **tols) - class TestCustomOps: """Test with ops that are defined externally""" diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index c1a9e0a407..e2d107bc96 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -37,8 +37,11 @@ from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor import transformer_engine_torch as tex from utils import ( + MegatronTrainingHelper, ModelConfig, assert_close, + make_recipe, + quantization_tols, recipe_id, reset_rng_states, skip_unsupported_backward_override, @@ -1912,3 +1915,771 @@ def test_swizzle_scales_and_pack_ptrs_for_discrete_weights( swizzled_scales_buffer, expected_swizzled_scales_buffer, ) + + +@pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16)) +@pytest.mark.parametrize( + "quantization", + [None] + (["mxfp8"] if mxfp8_available else []), +) +@pytest.mark.parametrize("quantized_weight", (False, True)) +@pytest.mark.parametrize("bias", (False, True)) +@pytest.mark.parametrize("single_grouped_weight", (False, True)) +@pytest.mark.parametrize("single_grouped_bias", (False, True)) +@pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) +def test_grouped_linear_cuda_graph_safe( + *, + dtype: torch.dtype, + quantization: Optional[str], + quantized_weight: bool, + bias: bool, + single_grouped_weight: bool, + single_grouped_bias: bool, + accumulate_into_main_grad: bool, + device: torch.device = "cuda", + group_size: int = 4, + in_features: int = 128, + out_features: int = 128, + split_alignment: int = 128, + token_padding: int = 256, +) -> None: + """GroupedLinear forward+backward should be CUDA graph capturable. + + Exercises the grouped-tensor / cublas-grouped-gemm path which uses + GPU-resident split offsets and is the only flow safe to capture. + """ + if os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0") == "0" and ( + single_grouped_weight or single_grouped_bias + ): + pytest.skip( + "single_grouped_weight/single_grouped_bias requires" + " NVTE_GROUPED_LINEAR_SINGLE_PARAM=1" + ) + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM CUDA-graph-safe path requires SM100+ (Blackwell)") + # Skip invalid configurations + if quantization is None and quantized_weight: + pytest.skip("quantized_weight requires a quantization recipe") + if single_grouped_bias and not bias: + pytest.skip("single_grouped_bias requires bias=True") + + # Split sizes (statically pinned for graph capture) + split_sizes = [split_alignment * (i + 1) for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + # Pad input tokens to validate the sync-free flow + in_shape = (split_sizes.sum().item() + token_padding, in_features) + out_shape = (in_shape[0], out_features) + + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): + op = te.ops.GroupedLinear( + group_size, + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + accumulate_into_main_grad=accumulate_into_main_grad, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + ) + + def _weight_params() -> list[torch.nn.Parameter]: + if single_grouped_weight: + return [op.weight] + return [getattr(op, f"weight{i}") for i in range(group_size)] + + def _bias_params() -> list[torch.nn.Parameter]: + if not bias: + return [] + if single_grouped_bias: + return [op.bias] + return [getattr(op, f"bias{i}") for i in range(group_size)] + + def _init_main_grads(value: float = 0.0) -> None: + if not accumulate_into_main_grad: + return + with torch.no_grad(): + for w in _weight_params(): + if getattr(w, "main_grad", None) is None: + w.main_grad = torch.empty(w.size(), device=device, dtype=torch.float32) + w.main_grad.fill_(value) + + def _collect_main_grads() -> list[torch.Tensor]: + return [w.main_grad.detach().clone() for w in _weight_params()] + + def _zero_param_grads() -> None: + for param in op.parameters(): + if param.grad is None: + param.grad = torch.zeros_like(param) + else: + param.grad.zero_() + + static_split_sizes = split_sizes.clone() + + def train_step( + x: torch.Tensor, + dy: torch.Tensor, + out_buf: torch.Tensor, + *, + use_graphed: bool, + ) -> torch.Tensor: + with te.autocast(enabled=quantization is not None, recipe=recipe): + out = ( + graphed_module(x, static_split_sizes) + if use_graphed + else op(x, static_split_sizes) + ) + out.backward(dy) + out_buf.copy_(out) + return out_buf + + _init_main_grads(0.0) + + static_x = torch.randn(in_shape, device=device, dtype=dtype, requires_grad=True) + static_dy = torch.randn(out_shape, device=device, dtype=dtype) + static_out_buf = torch.empty(out_shape, device=device, dtype=dtype) + + graphed_module = te.make_graphed_callables( + op, + (static_x, static_split_sizes), + num_warmup_iters=3, + enabled=quantization is not None, + recipe=recipe, + ) + + # Replace static buffers with fresh data (graph captures must replay + # against new inputs without re-recording). + fresh_x = torch.randn_like(static_x) + fresh_dy = torch.randn_like(static_dy) + with torch.no_grad(): + static_x.copy_(fresh_x) + static_dy.copy_(fresh_dy) + + # Reset grads & main_grads so the captured iteration starts fresh. + _zero_param_grads() + _init_main_grads(0.5) + if static_x.grad is not None: + static_x.grad.zero_() + + # Replay the graph + graph_out = ( + train_step(static_x, static_dy, static_out_buf, use_graphed=True).detach().clone() + ) + torch.cuda.synchronize() + graph_dx = static_x.grad.detach().clone() + if accumulate_into_main_grad: + graph_main_grads = _collect_main_grads() + graph_param_grads: list[torch.Tensor] = [] + else: + graph_main_grads = [] + graph_param_grads = [param.grad.detach().clone() for param in op.parameters()] + + # Reference: same op invoked eagerly with the same fresh inputs and + # the same starting grad/main_grad state. + _zero_param_grads() + _init_main_grads(0.5) + static_x.grad.zero_() + + expected_x = fresh_x.detach().clone().requires_grad_(True) + expected_dy = fresh_dy.detach().clone() + with te.autocast(enabled=quantization is not None, recipe=recipe): + expected_out = op(expected_x, static_split_sizes) + expected_out.backward(expected_dy) + + tols = dtype_tols(dtype) + if quantization is not None: + tols = quantization_tols(quantization) + + assert_close(graph_out, expected_out, **tols) + assert_close(graph_dx, expected_x.grad, **tols) + if accumulate_into_main_grad: + for g, w in zip(graph_main_grads, _weight_params()): + assert_close(g, w.main_grad, **tols) + else: + for g, param in zip(graph_param_grads, op.parameters()): + assert_close(g, param.grad, **tols) + + +@pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16)) +@pytest.mark.parametrize("bias", (False, True)) +@pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu")) +@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) +def test_grouped_mlp_single_weight_numerics( + *, + dtype: torch.dtype, + bias: bool, + activation: str, + device: torch.device = "cuda", + group_size: int = 4, + hidden_size: int = 256, + split_alignment: int = 256, + glu_interleave_size: int = 32, +) -> None: + """single_grouped_weight=True/False should match exactly for fused MXFP8 grouped MLP.""" + + if not te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): + pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") + if not te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): + pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") + + split_sizes = [split_alignment * (i + 1) for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) + in_shape = (split_sizes.sum().item(), hidden_size) + recipe = make_recipe("mxfp8") + + x_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) + probs_base = torch.empty((in_shape[0],), device=device, dtype=dtype).uniform_(-0.25, 0.25) + dy_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) + fc1_ws_base = [ + torch.empty((2 * hidden_size, hidden_size), device=device, dtype=dtype).uniform_( + -0.25, 0.25 + ) + for _ in range(group_size) + ] + fc2_ws_base = [ + torch.empty((hidden_size, hidden_size), device=device, dtype=dtype).uniform_( + -0.25, 0.25 + ) + for _ in range(group_size) + ] + fc1_bs_base = ( + [ + torch.empty((2 * hidden_size,), device=device, dtype=dtype).uniform_(-0.5, 0.5) + for _ in range(group_size) + ] + if bias + else None + ) + fc2_bs_base = ( + [ + torch.empty((hidden_size,), device=device, dtype=dtype).uniform_(-0.5, 0.5) + for _ in range(group_size) + ] + if bias + else None + ) + + def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]: + with te.quantized_model_init(enabled=True, recipe=recipe): + scaled_act = ( + te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_swiglu" + else te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + ) + fc1 = te.ops.GroupedLinear( + group_size, + hidden_size, + 2 * hidden_size, + bias=bias, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + ) + fc2 = te.ops.GroupedLinear( + group_size, + hidden_size, + hidden_size, + bias=bias, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + scale_bias=bias, + ) + module = te.ops.Sequential(fc1, scaled_act, fc2) + + with torch.no_grad(): + if single_grouped_weight: + fc1_weights = fc1.weight.quantized_tensors + if fc1_weights is None: + fc1_weights = fc1.weight.split_into_quantized_tensors() + fc2_weights = fc2.weight.quantized_tensors + if fc2_weights is None: + fc2_weights = fc2.weight.split_into_quantized_tensors() + for group_idx in range(group_size): + if single_grouped_weight: + fc1_weights[group_idx].copy_(fc1_ws_base[group_idx]) + fc2_weights[group_idx].copy_(fc2_ws_base[group_idx]) + else: + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_base[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_base[group_idx]) + if bias: + getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_base[group_idx]) + getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_base[group_idx]) + + x = x_base.detach().clone().requires_grad_(True) + probs = probs_base.detach().clone().requires_grad_(True) + dy = dy_base.detach().clone() + + with te.autocast(enabled=True, recipe=recipe): + fc2_extra = (split_sizes, probs) if bias else (split_sizes,) + y = module(x, split_sizes, probs, *fc2_extra) + y.backward(dy) + + forward_ops = module._module_groups[0]._forward_ops + backward_ops = module._module_groups[0]._backward_ops + assert len(forward_ops) == 1 + assert isinstance( + forward_ops[0][0], + te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, + ) + assert len(backward_ops) == 1 + assert isinstance( + backward_ops[0][0], + te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, + ) + + if single_grouped_weight: + fc1_dw = fc1.weight.grad.detach().clone() + fc2_dw = fc2.weight.grad.detach().clone() + else: + fc1_dw = torch.stack( + [ + getattr(fc1, f"weight{group_idx}").grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, + ) + fc2_dw = torch.stack( + [ + getattr(fc2, f"weight{group_idx}").grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, + ) + + fc1_db = None + fc2_db = None + if bias: + fc1_db = torch.stack( + [ + getattr(fc1, f"bias{group_idx}").grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, + ) + fc2_db = torch.stack( + [ + getattr(fc2, f"bias{group_idx}").grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, + ) + + return ( + y.detach().clone(), + x.grad.detach().clone(), + probs.grad.detach().clone(), + fc1_dw, + fc2_dw, + fc1_db, + fc2_db, + ) + + ( + y_false, + dx_false, + dprobs_false, + fc1_dw_false, + fc2_dw_false, + fc1_db_false, + fc2_db_false, + ) = _run_case(False) + ( + y_true, + dx_true, + dprobs_true, + fc1_dw_true, + fc2_dw_true, + fc1_db_true, + fc2_db_true, + ) = _run_case(True) + + torch.testing.assert_close(y_false, y_true, rtol=0, atol=0) + torch.testing.assert_close(dx_false, dx_true, rtol=0, atol=0) + torch.testing.assert_close(dprobs_false, dprobs_true, rtol=0, atol=0) + torch.testing.assert_close(fc1_dw_false, fc1_dw_true, rtol=0, atol=0) + torch.testing.assert_close(fc2_dw_false, fc2_dw_true, rtol=0, atol=0) + if bias: + bias_tols = {"rtol": 0.05, "atol": 0.015625} + torch.testing.assert_close(fc1_db_false, fc1_db_true, **bias_tols) + torch.testing.assert_close(fc2_db_false, fc2_db_true, **bias_tols) + + +@pytest.mark.parametrize("single_grouped_weight", (False, True)) +@pytest.mark.parametrize("delay_wgrad_compute", (False, True)) +@pytest.mark.parametrize("zero_out_wgrad", (False, True)) +@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) +def test_grouped_mlp_overwrite_main_grad( + *, + single_grouped_weight: bool, + delay_wgrad_compute: bool, + zero_out_wgrad: bool, + dtype: torch.dtype = torch.bfloat16, + device: torch.device = "cuda", + group_size: int = 4, + hidden_size: int = 256, + split_alignment: int = 256, + glu_interleave_size: int = 32, +) -> None: + """End-to-end check that the fused grouped-MLP backward writes the + wgrad into ``weight.main_grad`` correctly under the MegatronFSDP + ``overwrite_main_grad=True`` convention. + ``test_grouped_mlp`` already covers the standard Megatron-LM + ``fuse_wgrad_accumulation`` (DDP) path where the wgrad GEMM + *accumulates* into ``main_grad``. This test focuses exclusively on + the MegatronFSDP variant where the wgrad GEMM must *overwrite* + ``main_grad`` (because FSDP has already ReduceScattered the previous + accumulation), so ``main_grad`` after backward equals ``wgrad`` + regardless of the prior contents. + + Also exercises the MegatronFSDP ``zero_out_wgrad`` flag, which is + independent of ``main_grad`` and only controls whether the dummy + ``param.grad`` returned to autograd is zeroed (so downstream hooks + that read ``.grad`` don't see stale bytes from the cached dummy). + """ + + if not te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): + pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") + if not te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): + pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") + + recipe = make_recipe("mxfp8") + split_sizes = [split_alignment * (i + 1) for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) + in_shape = (split_sizes.sum().item(), hidden_size) + x_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) + probs_base = torch.empty((in_shape[0],), device=device, dtype=dtype).uniform_(-0.25, 0.25) + dy_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) + fc1_ws_base = [ + torch.empty((2 * hidden_size, hidden_size), device=device, dtype=dtype).uniform_( + -0.25, 0.25 + ) + for _ in range(group_size) + ] + fc2_ws_base = [ + torch.empty((hidden_size, hidden_size), device=device, dtype=dtype).uniform_( + -0.25, 0.25 + ) + for _ in range(group_size) + ] + + def _build_module(*, accumulate_into_main_grad: bool): + with te.quantized_model_init(enabled=True, recipe=recipe): + fc1 = te.ops.GroupedLinear( + group_size, + hidden_size, + 2 * hidden_size, + bias=False, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, + ) + fc2 = te.ops.GroupedLinear( + group_size, + hidden_size, + hidden_size, + bias=False, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, + ) + scaled_act = te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + module = te.ops.Sequential(fc1, scaled_act, fc2) + + with torch.no_grad(): + if single_grouped_weight: + fc1_weights = ( + fc1.weight.quantized_tensors or fc1.weight.split_into_quantized_tensors() + ) + fc2_weights = ( + fc2.weight.quantized_tensors or fc2.weight.split_into_quantized_tensors() + ) + for group_idx in range(group_size): + fc1_weights[group_idx].copy_(fc1_ws_base[group_idx]) + fc2_weights[group_idx].copy_(fc2_ws_base[group_idx]) + else: + for group_idx in range(group_size): + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_base[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_base[group_idx]) + return module, fc1, fc2 + + def _weight_params(fc): + if single_grouped_weight: + return [fc.weight] + return [getattr(fc, f"weight{i}") for i in range(group_size)] + + def _run_backward(module, fc1, fc2): + x = x_base.detach().clone().requires_grad_(True) + probs = probs_base.detach().clone().requires_grad_(True) + with te.autocast(enabled=True, recipe=recipe): + y = module(x, split_sizes, probs, split_sizes) + y.backward(dy_base) + if delay_wgrad_compute: + fc1.backward_dw() + fc2.backward_dw() + + # Reference run: vanilla autograd, no Megatron protocol. + ref_module, ref_fc1, ref_fc2 = _build_module(accumulate_into_main_grad=False) + _run_backward(ref_module, ref_fc1, ref_fc2) + ref_fc1_grads = [wp.grad.detach().clone() for wp in _weight_params(ref_fc1)] + ref_fc2_grads = [wp.grad.detach().clone() for wp in _weight_params(ref_fc2)] + + # Test run: main_grad fusion with overwrite_main_grad=True (MegatronFSDP). + # NaN sentinel makes a missed write loud (would surface as NaN diff). + test_module, test_fc1, test_fc2 = _build_module(accumulate_into_main_grad=True) + for fc in (test_fc1, test_fc2): + MegatronTrainingHelper.init_main_grad_buffers( + _weight_params(fc), + fill_value=float("nan"), + overwrite_main_grad=True, + zero_out_wgrad=zero_out_wgrad, + ) + _run_backward(test_module, test_fc1, test_fc2) + + # main_grad must be overwritten to exactly the ref wgrad (bitwise: + # the wgrad GEMM is deterministic across the two runs because the + # quantized weights and inputs are identical). + MegatronTrainingHelper.verify_main_grad_accumulation( + _weight_params(test_fc1), expected_main_grads=ref_fc1_grads + ) + MegatronTrainingHelper.verify_main_grad_accumulation( + _weight_params(test_fc2), expected_main_grads=ref_fc2_grads + ) + + +@pytest.mark.parametrize("dtype", (torch.float32, torch.float16, torch.bfloat16)) +@pytest.mark.parametrize("single_grouped_weight", (False, True)) +@pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) +@pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu")) +@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) +def test_grouped_mlp_cuda_graph_safe_mxfp8( + *, + dtype: torch.dtype, + single_grouped_weight: bool, + accumulate_into_main_grad: bool, + activation: str, + device: torch.device = "cuda", + group_size: int = 4, + hidden_size: int = 256, + split_alignment: int = 256, + glu_interleave_size: int = 32, + token_padding: int = 2048, +) -> None: + """Grouped MLP forward+backward should be CUDA graph capturable (MXFP8).""" + + if not te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): + pytest.skip("MXFP8 fused grouped MLP is not supported on this system") + if dtype not in (torch.bfloat16, torch.float16): + pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16") + + split_sizes = [split_alignment * (i + 1) for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) + # Pad the input tokens to validate the sync-free MOE + in_shape = (split_sizes.sum().item() + token_padding, hidden_size) + recipe = make_recipe("mxfp8") + with te.quantized_model_init(enabled=True, recipe=recipe): + fc1 = te.ops.GroupedLinear( + group_size, + hidden_size, + 2 * hidden_size, + bias=False, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + fc2 = te.ops.GroupedLinear( + group_size, + hidden_size, + hidden_size, + bias=False, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + scaled_act = ( + te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_swiglu" + else te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + ) + module = te.ops.Sequential( + fc1, + scaled_act, + fc2, + ) + + def _init_main_grads(value: float = 0.0) -> None: + if not accumulate_into_main_grad: + return + with torch.no_grad(): + if single_grouped_weight: + if getattr(fc1.weight, "main_grad", None) is None: + fc1.weight.main_grad = torch.empty( + fc1.weight.size(), + device=device, + dtype=torch.float32, + ) + if getattr(fc2.weight, "main_grad", None) is None: + fc2.weight.main_grad = torch.empty( + fc2.weight.size(), + device=device, + dtype=torch.float32, + ) + fc1.weight.main_grad.fill_(value) + fc2.weight.main_grad.fill_(value) + else: + for group_idx in range(group_size): + fc1_weight = getattr(fc1, f"weight{group_idx}") + fc2_weight = getattr(fc2, f"weight{group_idx}") + if getattr(fc1_weight, "main_grad", None) is None: + fc1_weight.main_grad = torch.empty( + fc1_weight.size(), + device=device, + dtype=torch.float32, + ) + if getattr(fc2_weight, "main_grad", None) is None: + fc2_weight.main_grad = torch.empty( + fc2_weight.size(), + device=device, + dtype=torch.float32, + ) + fc1_weight.main_grad.fill_(value) + fc2_weight.main_grad.fill_(value) + + def _collect_main_grads() -> tuple[torch.Tensor, torch.Tensor]: + if single_grouped_weight: + fc1_main_grad = fc1.weight.main_grad.detach().clone() + fc2_main_grad = fc2.weight.main_grad.detach().clone() + else: + fc1_main_grad = torch.stack( + [ + getattr(fc1, f"weight{group_idx}").main_grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, + ) + fc2_main_grad = torch.stack( + [ + getattr(fc2, f"weight{group_idx}").main_grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, + ) + return fc1_main_grad, fc2_main_grad + + static_split_sizes = split_sizes.clone() + + def train_step( + x: torch.Tensor, + probs: torch.Tensor, + dy: torch.Tensor, + out_buf: torch.Tensor, + *, + use_graphed: bool, + ) -> torch.Tensor: + with te.autocast(enabled=True, recipe=recipe): + out = ( + graphed_module(x, static_split_sizes, probs, static_split_sizes) + if use_graphed + else module(x, static_split_sizes, probs, static_split_sizes) + ) + out.backward(dy) + out_buf.copy_(out) + return out_buf + + _init_main_grads(0.0) + + static_x = torch.randn(in_shape, device=device, dtype=dtype, requires_grad=True) + static_probs = torch.randn((in_shape[0],), device=device, dtype=dtype, requires_grad=True) + static_dy = torch.randn(in_shape, device=device, dtype=dtype) + static_out_buf = torch.empty((in_shape[0], hidden_size), device=device, dtype=dtype) + + graphed_module = te.make_graphed_callables( + module, + (static_x, static_split_sizes, static_probs, static_split_sizes), + num_warmup_iters=3, + enabled=True, + recipe=recipe, + ) + + forward_ops = module._module_groups[0]._forward_ops + backward_ops = module._module_groups[0]._backward_ops + assert len(forward_ops) == 1 + assert isinstance( + forward_ops[0][0], + te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, + ) + assert len(backward_ops) == 1 + assert isinstance( + backward_ops[0][0], + te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, + ) + + fresh_x = torch.randn_like(static_x) + fresh_probs = torch.randn_like(static_probs) + fresh_dy = torch.randn_like(static_dy) + with torch.no_grad(): + static_x.copy_(fresh_x) + static_probs.copy_(fresh_probs) + static_dy.copy_(fresh_dy) + + for param in module.parameters(): + param.grad = torch.zeros_like(param) + _init_main_grads(0.5) + if static_x.grad is not None: + static_x.grad.zero_() + if static_probs.grad is not None: + static_probs.grad.zero_() + + graph_out = ( + train_step(static_x, static_probs, static_dy, static_out_buf, use_graphed=True) + .detach() + .clone() + ) + torch.cuda.synchronize() + graph_dx = static_x.grad.detach().clone() + graph_dprobs = static_probs.grad.detach().clone() + if accumulate_into_main_grad: + graph_fc1_main_grad, graph_fc2_main_grad = _collect_main_grads() + else: + graph_param_grads = [param.grad.detach().clone() for param in module.parameters()] + + for param in module.parameters(): + param.grad.zero_() + _init_main_grads(0.5) + static_x.grad.zero_() + static_probs.grad.zero_() + + expected_x = fresh_x.detach().clone().requires_grad_(True) + expected_probs = fresh_probs.detach().clone().requires_grad_(True) + expected_dy = fresh_dy.detach().clone() + with te.autocast(enabled=True, recipe=recipe): + expected_out = module( + expected_x, + static_split_sizes, + expected_probs, + static_split_sizes, + ) + expected_out.backward(expected_dy) + + tols = dtype_tols(dtype) + assert_close(graph_out, expected_out, **tols) + assert_close(graph_dx, expected_x.grad, **tols) + assert_close(graph_dprobs, expected_probs.grad, **tols) + if accumulate_into_main_grad: + expected_fc1_main_grad, expected_fc2_main_grad = _collect_main_grads() + assert_close(graph_fc1_main_grad, expected_fc1_main_grad, **tols) + assert_close(graph_fc2_main_grad, expected_fc2_main_grad, **tols) + else: + for graph_grad, param in zip(graph_param_grads, module.parameters()): + assert_close(graph_grad, param.grad, **tols) From aada4aecffc8329c5e1004522065165338292e3c Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 8 Jun 2026 23:43:50 +0000 Subject: [PATCH 02/11] [PyTorch] Clean up grouped MLP test in test_fusible_ops - Simplify test_grouped_mlp to a focused sanity test: - Hardcode ScaledSwiGLU (drop 4-way activation + glu_interleave_size axes) - Drop single_grouped_weight/bias, accumulate_into_main_grad, delay_wgrad_compute branches (all were defaulted False and never parametrized) - Remove fused-op dispatch assertions (ForwardGroupedMLP_CuTeGEMMGLU etc.) that required NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 - Switch quantization list from _grouped_mlp_quantization_list to _quantization_list (drops nvfp4_rht which was always skipped for SwiGLU) - Clean up test_grouped_linear: - Remove env-var skip for single_grouped_weight/bias (dead code: those params are not parametrized here, so the condition never triggered) - Switch assertions from torch.testing.assert_close with manual .to(float64) to assert_close / assert_close_grads utilities - Remove now-unused imports: _cudnn_frontend_*, is_glu_activation, MegatronTrainingHelper, _grouped_mlp_quantization_list - QA script: drop NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 and NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 from test_fusible_ops.py invocation Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Tim Moon --- qa/L0_pytorch_unittest/test.sh | 2 +- tests/pytorch/test_fusible_ops.py | 354 ++++-------------------------- 2 files changed, 47 insertions(+), 309 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 08530be3ca..93490d6f81 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -43,7 +43,7 @@ NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_L python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_backward_override.xml $TE_PATH/tests/pytorch/test_backward_override.py || test_fail "test_backward_override.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 3c7b2c90d2..d38a8e2325 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -20,11 +20,6 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te import transformer_engine.pytorch.ops as te_ops -from transformer_engine.pytorch.ops._common import ( - _cudnn_frontend_supports_grouped_gemm_srelu, - _cudnn_frontend_version_supported, - is_glu_activation, -) from transformer_engine.pytorch.ops.fused import ( BackwardActivationBias, @@ -53,7 +48,6 @@ assert_close_grads, dtype_tols, make_recipe, - MegatronTrainingHelper, quantization_tols, reset_rng_states, ) @@ -80,10 +74,6 @@ if nvfp4_available: _quantization_list.append("nvfp4") _quantization_list.append("nvfp4_4over6") -_grouped_mlp_quantization_list = list(_quantization_list) -if nvfp4_available: - _grouped_mlp_quantization_list.append("nvfp4_rht") - @pytest.fixture(autouse=True, scope="function") def _reset_rng_states_per_test(): @@ -2098,13 +2088,7 @@ def test_grouped_linear( single_grouped_bias: bool = False, ) -> None: """Grouped GEMM""" - if os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0") == "0" and ( - single_grouped_weight or single_grouped_bias - ): - pytest.skip( - "single_grouped_weight/single_grouped_bias requires" - " NVTE_GROUPED_LINEAR_SINGLE_PARAM=1" - ) + # Split sizes split_sizes = [split_alignment * i for i in range(group_size)] random.shuffle(split_sizes) @@ -2124,9 +2108,6 @@ def test_grouped_linear( pytest.skip("Quantization scheme is not used") if quantization is not None and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") - if quantization == "nvfp4_4over6": - pytest.skip("NVFP4 4over6 grouped quantization is not supported") - if single_grouped_bias and not bias: pytest.skip("single_grouped_bias requires bias=True") if ( @@ -2238,44 +2219,29 @@ def test_grouped_linear( tols = quantization_tols(quantization) # Check results - y_test = y_test.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(y_test, y_ref, **tols) - if input_requires_grad: - dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(dx_test, x_ref.grad, **tols) - else: - assert x_test.grad is None + assert_close(y_test, y_ref, **tols) + assert_close_grads(x_test, x_ref, **tols) if single_grouped_weight: if weight_requires_grad: - dw_test_all = op.weight.grad.to(dtype=torch.float64, device="cpu") w_ref_grad = torch.stack([w.grad for w in ws_ref], dim=0) - torch.testing.assert_close(dw_test_all, w_ref_grad, **tols) + assert_close(op.weight.grad, w_ref_grad, **tols) else: assert op.weight.grad is None else: for group_idx in range(group_size): w_test = getattr(op, f"weight{group_idx}") - if weight_requires_grad: - dw_test = w_test.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(dw_test, ws_ref[group_idx].grad, **tols) - else: - assert w_test.grad is None + assert_close_grads(w_test, ws_ref[group_idx], **tols) if bias: if single_grouped_bias: if weight_requires_grad: - db_test_all = op.bias.grad.to(dtype=torch.float64, device="cpu") b_ref_grad = torch.stack([b.grad for b in bs_ref], dim=0) - torch.testing.assert_close(db_test_all, b_ref_grad, **tols) + assert_close(op.bias.grad, b_ref_grad, **tols) else: assert op.bias.grad is None else: for group_idx in range(group_size): b_test = getattr(op, f"bias{group_idx}") - if weight_requires_grad: - db_test = b_test.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(db_test, bs_ref[group_idx].grad, **tols) - else: - assert b_test.grad is None + assert_close_grads(b_test, bs_ref[group_idx], **tols) @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128))) @pytest.mark.parametrize("input_requires_grad", (False, True)) @@ -3507,95 +3473,35 @@ def test_layernorm_mlp( @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", _grouped_mlp_quantization_list) - @pytest.mark.parametrize("glu_interleave_size", (None, 32)) - @pytest.mark.parametrize("hidden_size", (128,)) - @pytest.mark.parametrize( - "activation", - ( - "scaled_swiglu", - "scaled_clamped_qgeglu", - "scaled_clamped_qgeglu_custom", - "scaled_srelu", - ), - ) + @pytest.mark.parametrize("quantization", _quantization_list) def test_grouped_mlp( self, *, group_size: int = 4, + hidden_size: int = 128, bias: bool, - hidden_size: int, dtype: torch.dtype, quantization: Optional[str], - single_grouped_weight: bool = False, - single_grouped_bias: bool = False, - accumulate_into_main_grad: bool = False, device: torch.device = "cuda", split_alignment: int = 256, - glu_interleave_size: Optional[int], - delay_wgrad_compute: bool = False, - activation: str, ) -> None: - """GroupedLinear + scaled activation + GroupedLinear""" + """GroupedLinear + ScaledSwiGLU + GroupedLinear""" # Split sizes split_sizes = [split_alignment * (i) for i in range(group_size)] random.shuffle(split_sizes) split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) - # Make input shape + # Tensor dimensions in_shape = (split_sizes.sum().item(), hidden_size) out_shape = in_shape + fc1_out_features = 2 * hidden_size # Skip invalid configurations with_quantization = quantization is not None - if activation == "scaled_swiglu": - scaled_act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - elif activation.startswith("scaled_clamped_qgeglu"): - scaled_act = te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) - elif activation == "scaled_srelu": - scaled_act = te_ops.ScaledSReLU() - else: - raise ValueError(f"Unexpected grouped MLP activation ({activation})") - activation_is_glu = is_glu_activation(scaled_act) maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) - if single_grouped_weight and quantization != "mxfp8": - pytest.skip("single_grouped_weight is only supported for MXFP8 quantization") - if single_grouped_bias and not bias: - pytest.skip("single_grouped_bias requires bias=True") if with_quantization and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") - if not activation_is_glu and quantization not in ("mxfp8", "nvfp4", "nvfp4_rht"): - pytest.skip("Scaled unary grouped MLP is only supported with MXFP8 or NVFP4") - if not activation_is_glu and glu_interleave_size is not None: - pytest.skip("Unary activations do not use GLU interleaving") - if quantization == "nvfp4_4over6": - pytest.skip("NVFP4 4over6 grouped quantization is not supported") - if activation == "scaled_srelu" and quantization in ("nvfp4", "nvfp4_rht") and bias: - pytest.skip("NVFP4 SReLU grouped MLP coverage is limited to no-bias") - if quantization == "nvfp4_rht": - if activation == "scaled_swiglu" and (bias or glu_interleave_size != 32): - pytest.skip("NVFP4 RHT SwiGLU grouped MLP coverage is limited to no-bias") - if activation not in ("scaled_swiglu", "scaled_srelu"): - pytest.skip("NVFP4 RHT grouped MLP coverage is limited to SwiGLU and SReLU") - if ( - with_quantization - and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") - and activation.startswith("scaled_clamped_qgeglu") - and bias - ): - # TODO: ksivaman: Need to debug numerics for this case. - pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") - fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size - # Activation parameters for clamped QGeGLU variants - if activation == "scaled_clamped_qgeglu_custom": - geglu_limit = 5.0 - geglu_alpha = 1.5 - geglu_offset = 0.5 - else: - geglu_limit = 7.0 - geglu_alpha = 1.702 - geglu_offset = 1.0 # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -3669,29 +3575,6 @@ def test_grouped_mlp( fc2_ws_test.append(fc2_w_test) fc2_bs_test.append(fc2_b_test) - def _apply_activation(x: torch.Tensor) -> torch.Tensor: - if activation_is_glu and glu_interleave_size is not None: - x = x.reshape( - -1, - 2 * hidden_size // (2 * glu_interleave_size), - 2, - glu_interleave_size, - ) - x = x.transpose(1, 2) - x = x.reshape(-1, 2 * hidden_size) - if activation == "scaled_swiglu": - x1, x2 = x.chunk(2, dim=-1) - return torch.nn.functional.silu(x1) * x2 - if activation.startswith("scaled_clamped_qgeglu"): - x1, x2 = x.chunk(2, dim=-1) - lim = torch.tensor(geglu_limit, device=x1.device, dtype=x1.dtype) - x1c = torch.minimum(x1, lim) - x2c = torch.clamp(x2, -lim, lim) - return (x2c + geglu_offset) * (x1c * torch.sigmoid(geglu_alpha * x1c)) - if activation == "scaled_srelu": - return torch.nn.functional.relu(x).square() - raise ValueError(f"Unexpected grouped MLP activation ({activation})") - # Reference implementation xs = torch.split(x_ref, split_sizes.tolist()) probs = torch.split(probs_ref, split_sizes.tolist()) @@ -3701,7 +3584,8 @@ def _apply_activation(x: torch.Tensor) -> torch.Tensor: fc1_out = torch.nn.functional.linear( x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx] ) - fc2_in = _apply_activation(fc1_out) + x1, x2 = fc1_out.chunk(2, dim=-1) + fc2_in = torch.nn.functional.silu(x1) * x2 fc2_in = fc2_in * probs[group_idx].unsqueeze(-1) y = torch.nn.functional.linear(fc2_in, fc2_ws_ref[group_idx]) if bias: @@ -3712,154 +3596,45 @@ def _apply_activation(x: torch.Tensor) -> torch.Tensor: # Construct operations recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=with_quantization, recipe=recipe): + fc1 = te_ops.GroupedLinear( + group_size, + hidden_size, + fc1_out_features, + bias=bias, + device=device, + dtype=dtype, + ) + act = te_ops.ScaledSwiGLU() + fc2 = te_ops.GroupedLinear( + group_size, + hidden_size, + hidden_size, + bias=bias, + device=device, + dtype=dtype, + scale_bias=bias, + ) + module = te_ops.Sequential(fc1, act, fc2) - def _make_scaled_act(): - if activation == "scaled_swiglu": - return te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - if activation == "scaled_clamped_qgeglu_custom": - return te_ops.ScaledClampedQGeGLU( - glu_interleave_size=glu_interleave_size, - limit=geglu_limit, - alpha=geglu_alpha, - glu_linear_offset=geglu_offset, - ) - if activation.startswith("scaled_clamped_qgeglu"): - return te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) - if activation == "scaled_srelu": - return te_ops.ScaledSReLU() - raise ValueError(f"Unexpected grouped MLP activation ({activation})") - - def _make_module(): - with te.quantized_model_init(enabled=with_quantization, recipe=recipe): - fc1_op = te_ops.GroupedLinear( - group_size, - hidden_size, - fc1_out_features, - bias=bias, - device=device, - dtype=dtype, - single_grouped_weight=single_grouped_weight, - single_grouped_bias=single_grouped_bias, - accumulate_into_main_grad=accumulate_into_main_grad, - delay_wgrad_compute=delay_wgrad_compute, - ) - - fc2_op = te_ops.GroupedLinear( - group_size, - hidden_size, - hidden_size, - bias=bias, - device=device, - dtype=dtype, - single_grouped_weight=single_grouped_weight, - single_grouped_bias=single_grouped_bias, - accumulate_into_main_grad=accumulate_into_main_grad, - delay_wgrad_compute=delay_wgrad_compute, - scale_bias=bias, - ) - return te_ops.Sequential(fc1_op, _make_scaled_act(), fc2_op), fc1_op, fc2_op - - module, fc1, fc2 = _make_module() - - # Copy weights + # Initialize params with torch.no_grad(): - if single_grouped_weight: - fc1_weights = fc1.weight.quantized_tensors - if fc1_weights is None: - fc1_weights = fc1.weight.split_into_quantized_tensors() - fc2_weights = fc2.weight.quantized_tensors - if fc2_weights is None: - fc2_weights = fc2.weight.split_into_quantized_tensors() for group_idx in range(group_size): - if single_grouped_weight: - fc1_weights[group_idx].copy_(fc1_ws_test[group_idx]) - fc2_weights[group_idx].copy_(fc2_ws_test[group_idx]) - else: - getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) - getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) if bias: - if single_grouped_bias: - fc1_bparts = fc1.bias.split_into_quantized_tensors() - fc2_bparts = fc2.bias.split_into_quantized_tensors() - fc1_bparts[group_idx].reshape(-1).copy_(fc1_bs_test[group_idx]) - fc2_bparts[group_idx].reshape(-1).copy_(fc2_bs_test[group_idx]) - else: - getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) - getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) - if accumulate_into_main_grad: - # 0.5 sentinel lets us reconstruct ``expected = ref_grad + 0.5`` - # below and detect a missed accumulation. - main_grad_sentinel = 0.5 - if single_grouped_weight: - weight_params_for_main_grad = [fc1.weight, fc2.weight] - else: - weight_params_for_main_grad = [ - getattr(fc, f"weight{i}") for fc in (fc1, fc2) for i in range(group_size) - ] - MegatronTrainingHelper.init_main_grad_buffers( - weight_params_for_main_grad, - fill_value=main_grad_sentinel, - overwrite_main_grad=False, - ) + getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) + getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test - # Fuse ops and perform forward and backward pass + # Forward and backward pass with te.autocast(enabled=with_quantization, recipe=recipe): fc2_extra = (split_sizes, probs_test) if bias else (split_sizes,) y_test = module(x_test, split_sizes, probs_test, *fc2_extra) y_test.backward(dy_test) - if delay_wgrad_compute: - fc1.backward_dw() - fc2.backward_dw() - - # Check for expected fusions - cudnn_frontend_supports_grouped_mlp = ( - _cudnn_frontend_supports_grouped_gemm_srelu() - if activation == "scaled_srelu" - else _cudnn_frontend_version_supported() - ) - expected_grouped_mlp_fusion = cudnn_frontend_supports_grouped_mlp and ( - ( - quantization == "mxfp8" - and dtype in (torch.bfloat16, torch.float16) - and ( - (not activation_is_glu and glu_interleave_size is None) - or (activation_is_glu and glu_interleave_size == 32) - ) - ) - or ( - quantization == "nvfp4_rht" - and dtype == torch.bfloat16 - and activation == "scaled_srelu" - and glu_interleave_size is None - ) - ) - if expected_grouped_mlp_fusion: - if activation_is_glu: - forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU - backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU - else: - forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMUnary - backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDUnary - if forward_cls.is_supported(): - forward_ops = module._module_groups[0]._forward_ops - assert len(forward_ops) == 1 - assert isinstance( - forward_ops[0][0], - forward_cls, - ) - if backward_cls is not None and backward_cls.is_supported(): - backward_ops = module._module_groups[0]._backward_ops - assert len(backward_ops) == 1 - assert isinstance( - backward_ops[0][0], - backward_cls, - ) # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} - if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): - tols = {"rtol": 0.25, "atol": 0.5} # Check values assert_close(y_test, y_ref, **tols) @@ -3867,55 +3642,18 @@ def _make_module(): assert_close_grads(probs_test, probs_ref, **tols) for group_idx in range(group_size): if bias: - if single_grouped_bias: - assert_close( - fc2.bias.grad[group_idx], - fc2_bs_ref[group_idx].grad, - **tols, - ) - assert_close( - fc1.bias.grad[group_idx], - fc1_bs_ref[group_idx].grad, - **tols, - ) - else: - assert_close_grads( - getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols - ) - assert_close_grads( - getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols - ) - if not single_grouped_weight and not accumulate_into_main_grad: assert_close_grads( - getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols + getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols ) assert_close_grads( - getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols + getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols ) - fc1_w_ref_grad = torch.stack([w.grad for w in fc1_ws_ref], dim=0) - fc2_w_ref_grad = torch.stack([w.grad for w in fc2_ws_ref], dim=0) - if accumulate_into_main_grad: - # main_grad should accumulate the ref wgrad onto the 0.5 sentinel. - # Per-param expected views must line up with - # ``weight_params_for_main_grad`` registered above. - fc1_expected = ( - [fc1_w_ref_grad + main_grad_sentinel] - if single_grouped_weight - else [g + main_grad_sentinel for g in fc1_w_ref_grad] + assert_close_grads( + getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols ) - fc2_expected = ( - [fc2_w_ref_grad + main_grad_sentinel] - if single_grouped_weight - else [g + main_grad_sentinel for g in fc2_w_ref_grad] - ) - MegatronTrainingHelper.verify_main_grad_accumulation( - weight_params_for_main_grad, - expected_main_grads=fc1_expected + fc2_expected, - **tols, + assert_close_grads( + getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols ) - elif single_grouped_weight: - assert_close(fc1.weight.grad, fc1_w_ref_grad, **tols) - assert_close(fc2.weight.grad, fc2_w_ref_grad, **tols) class TestCustomOps: From b2ce792f44c01fbd5db0edc98c904ea4a11e9842 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 8 Jun 2026 23:50:22 +0000 Subject: [PATCH 03/11] [PyTorch] Move test_grouped_gemm_quant_cute to test_grouped_linear The test validates that grouped_gemm_quant_wrapper_sm100 (a cuTE DSL kernel internal to the grouped MLP fusion) matches MXFP8 quantizer output. It does not exercise the op-fuser infrastructure at all, so it belongs in test_grouped_linear.py alongside the other grouped-MLP- specific tests. Also removes the now-unused `import transformer_engine_torch as tex` from test_fusible_ops.py. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 142 --------------------------- tests/pytorch/test_grouped_linear.py | 141 ++++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 142 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index d38a8e2325..8a5bc792ee 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -40,7 +40,6 @@ is_bf16_available, ) from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor -import transformer_engine_torch as tex # Import utility functions from utils import ( @@ -4302,144 +4301,3 @@ def test_linear_inference_loop( quantization=quantization, recipe=recipe, ) - - -def test_grouped_gemm_quant_cute_matches_mxfp8_quantized() -> None: - if not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("Requires SM100+ for grouped GEMM quant kernel.") - - try: - from cudnn import grouped_gemm_quant_wrapper_sm100 # pylint: disable=no-name-in-module - except ImportError as exc: - pytest.skip(f"grouped_gemm_quant_wrapper_sm100 unavailable: {exc}") - - device = torch.device("cuda") - dtype = torch.bfloat16 if is_bf16_available() else torch.float16 - num_groups = 4 - m = 256 - n = 512 - k = 512 - total_m = num_groups * m - split_sizes = torch.full((num_groups,), m, device=device, dtype=torch.int64) - - q = MXFP8Quantizer(fp8_dtype=te.DType.kFloat8E4M3, rowwise=True, columnwise=False) - q.optimize_for_gemm = False - - torch.manual_seed(0) - a_full = torch.randn(total_m, k, device=device, dtype=dtype) - weights = [torch.randn(n, k, device=device, dtype=dtype) for _ in range(num_groups)] - - grouped_a = tex.group_quantize(a_full, q, num_groups, split_sizes) - a_groups = grouped_a.split_into_quantized_tensors() - b_groups = [q(w) for w in weights] - - # Reference GEMM on dequantized tensors. - ref = torch.empty((total_m, n), device=device, dtype=torch.float32) - start = 0 - for group_idx in range(num_groups): - end = start + m - a_deq = a_groups[group_idx].dequantize(dtype=torch.float32) - b_deq = b_groups[group_idx].dequantize(dtype=torch.float32) - ref[start:end, :] = a_deq @ b_deq.t() - start = end - ref = ref.to(dtype=torch.bfloat16).to(torch.float32) - - # Allocate empty input tensors needed for cuTE DSL kernel - padded_offsets = torch.tensor( - [m * (i + 1) for i in range(num_groups)], - dtype=torch.int32, - device=device, - ) - inputs = { - "a_tensor": torch.empty(1, total_m, k, dtype=torch.float8_e4m3fn, device=device).permute( - 1, 2, 0 - ), - "b_tensor": torch.empty(num_groups, n, k, dtype=torch.float8_e4m3fn, device=device).permute( - 1, 2, 0 - ), - "sfa_tensor": torch.empty( - 1, - total_m // 128, - k // 128, - 32, - 4, - 4, - dtype=torch.float8_e8m0fnu, - device=device, - ).permute(3, 4, 1, 5, 2, 0), - "sfb_tensor": torch.empty( - num_groups, - n // 128, - k // 128, - 32, - 4, - 4, - dtype=torch.float8_e8m0fnu, - device=device, - ).permute(3, 4, 1, 5, 2, 0), - "alpha_tensor": torch.empty(num_groups, dtype=torch.float32, device=device), - "prob_tensor": torch.empty(total_m, 1, 1, dtype=torch.float32, device=device), - "padded_offsets_tensor": padded_offsets, - } - # Overwrite inputs with quantized data/scales from MXFP8 quantizer. - a_data = grouped_a.rowwise_data.view(total_m, k).view(dtype=torch.float8_e4m3fn) - a_data = a_data.unsqueeze(0).permute(1, 2, 0).contiguous() - inputs["a_tensor"].copy_(a_data) - - a_scales = grouped_a.scale_inv.view(dtype=torch.float8_e8m0fnu) - a_scales = a_scales.view(1, total_m // 128, 4, 32, k // 128, 4) - a_scales = a_scales.permute(0, 1, 4, 3, 2, 5).contiguous() - a_scales = a_scales.permute(3, 4, 1, 5, 2, 0).contiguous() - inputs["sfa_tensor"].copy_(a_scales) - - b_data = torch.cat([w._rowwise_data.reshape(-1) for w in b_groups]) - b_data = b_data.view(dtype=torch.float8_e4m3fn) - b_data = b_data.view(num_groups, n, k).permute(1, 2, 0).contiguous() - inputs["b_tensor"].copy_(b_data) - - b_scales = torch.cat([w._rowwise_scale_inv for w in b_groups]) - b_scales = b_scales.view(dtype=torch.float8_e8m0fnu) - b_scales = b_scales.view(num_groups, n // 128, 4, 32, k // 128, 4) - b_scales = b_scales.permute(0, 1, 4, 3, 2, 5).contiguous() - b_scales = b_scales.permute(3, 4, 1, 5, 2, 0).contiguous() - inputs["sfb_tensor"].copy_(b_scales) - - inputs["alpha_tensor"].fill_(1.0) - inputs["prob_tensor"].fill_(1.0) - - cute_out = grouped_gemm_quant_wrapper_sm100( - a_tensor=inputs["a_tensor"], - b_tensor=inputs["b_tensor"], - sfa_tensor=inputs["sfa_tensor"], - sfb_tensor=inputs["sfb_tensor"], - padded_offsets=inputs["padded_offsets_tensor"], - alpha_tensor=inputs["alpha_tensor"], - norm_const_tensor=None, - prob_tensor=inputs["prob_tensor"], - acc_dtype=torch.float32, - d_dtype=torch.bfloat16, - cd_major="n", - sf_vec_size=32, - discrete_col_sfd=True, - current_stream=None, - ) - - if isinstance(cute_out, dict): - outputs = cute_out - else: - d_tensor, d_col_tensor, amax_tensor, sfd_row_tensor, sfd_col_tensor = cute_out - outputs = { - "d_tensor": d_tensor, - "d_col_tensor": d_col_tensor, - "amax_tensor": amax_tensor, - "sfd_row_tensor": sfd_row_tensor, - "sfd_col_tensor": sfd_col_tensor, - } - - d_cute = outputs["d_tensor"] - if d_cute.dim() == 3: - d_cute = d_cute.squeeze(-1) - tols = dtype_tols(torch.bfloat16) - assert_close(d_cute[:total_m].float(), ref, **tols) diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index e2d107bc96..906203b54f 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -2683,3 +2683,144 @@ def train_step( else: for graph_grad, param in zip(graph_param_grads, module.parameters()): assert_close(graph_grad, param.grad, **tols) + + +def test_grouped_gemm_quant_cute_matches_mxfp8_quantized() -> None: + if not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Requires SM100+ for grouped GEMM quant kernel.") + + try: + from cudnn import grouped_gemm_quant_wrapper_sm100 # pylint: disable=no-name-in-module + except ImportError as exc: + pytest.skip(f"grouped_gemm_quant_wrapper_sm100 unavailable: {exc}") + + device = torch.device("cuda") + dtype = torch.bfloat16 if is_bf16_available() else torch.float16 + num_groups = 4 + m = 256 + n = 512 + k = 512 + total_m = num_groups * m + split_sizes = torch.full((num_groups,), m, device=device, dtype=torch.int64) + + q = MXFP8Quantizer(fp8_dtype=te.DType.kFloat8E4M3, rowwise=True, columnwise=False) + q.optimize_for_gemm = False + + torch.manual_seed(0) + a_full = torch.randn(total_m, k, device=device, dtype=dtype) + weights = [torch.randn(n, k, device=device, dtype=dtype) for _ in range(num_groups)] + + grouped_a = tex.group_quantize(a_full, q, num_groups, split_sizes) + a_groups = grouped_a.split_into_quantized_tensors() + b_groups = [q(w) for w in weights] + + # Reference GEMM on dequantized tensors. + ref = torch.empty((total_m, n), device=device, dtype=torch.float32) + start = 0 + for group_idx in range(num_groups): + end = start + m + a_deq = a_groups[group_idx].dequantize(dtype=torch.float32) + b_deq = b_groups[group_idx].dequantize(dtype=torch.float32) + ref[start:end, :] = a_deq @ b_deq.t() + start = end + ref = ref.to(dtype=torch.bfloat16).to(torch.float32) + + # Allocate empty input tensors needed for cuTE DSL kernel + padded_offsets = torch.tensor( + [m * (i + 1) for i in range(num_groups)], + dtype=torch.int32, + device=device, + ) + inputs = { + "a_tensor": torch.empty(1, total_m, k, dtype=torch.float8_e4m3fn, device=device).permute( + 1, 2, 0 + ), + "b_tensor": torch.empty(num_groups, n, k, dtype=torch.float8_e4m3fn, device=device).permute( + 1, 2, 0 + ), + "sfa_tensor": torch.empty( + 1, + total_m // 128, + k // 128, + 32, + 4, + 4, + dtype=torch.float8_e8m0fnu, + device=device, + ).permute(3, 4, 1, 5, 2, 0), + "sfb_tensor": torch.empty( + num_groups, + n // 128, + k // 128, + 32, + 4, + 4, + dtype=torch.float8_e8m0fnu, + device=device, + ).permute(3, 4, 1, 5, 2, 0), + "alpha_tensor": torch.empty(num_groups, dtype=torch.float32, device=device), + "prob_tensor": torch.empty(total_m, 1, 1, dtype=torch.float32, device=device), + "padded_offsets_tensor": padded_offsets, + } + # Overwrite inputs with quantized data/scales from MXFP8 quantizer. + a_data = grouped_a.rowwise_data.view(total_m, k).view(dtype=torch.float8_e4m3fn) + a_data = a_data.unsqueeze(0).permute(1, 2, 0).contiguous() + inputs["a_tensor"].copy_(a_data) + + a_scales = grouped_a.scale_inv.view(dtype=torch.float8_e8m0fnu) + a_scales = a_scales.view(1, total_m // 128, 4, 32, k // 128, 4) + a_scales = a_scales.permute(0, 1, 4, 3, 2, 5).contiguous() + a_scales = a_scales.permute(3, 4, 1, 5, 2, 0).contiguous() + inputs["sfa_tensor"].copy_(a_scales) + + b_data = torch.cat([w._rowwise_data.reshape(-1) for w in b_groups]) + b_data = b_data.view(dtype=torch.float8_e4m3fn) + b_data = b_data.view(num_groups, n, k).permute(1, 2, 0).contiguous() + inputs["b_tensor"].copy_(b_data) + + b_scales = torch.cat([w._rowwise_scale_inv for w in b_groups]) + b_scales = b_scales.view(dtype=torch.float8_e8m0fnu) + b_scales = b_scales.view(num_groups, n // 128, 4, 32, k // 128, 4) + b_scales = b_scales.permute(0, 1, 4, 3, 2, 5).contiguous() + b_scales = b_scales.permute(3, 4, 1, 5, 2, 0).contiguous() + inputs["sfb_tensor"].copy_(b_scales) + + inputs["alpha_tensor"].fill_(1.0) + inputs["prob_tensor"].fill_(1.0) + + cute_out = grouped_gemm_quant_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + norm_const_tensor=None, + prob_tensor=inputs["prob_tensor"], + acc_dtype=torch.float32, + d_dtype=torch.bfloat16, + cd_major="n", + sf_vec_size=32, + discrete_col_sfd=True, + current_stream=None, + ) + + if isinstance(cute_out, dict): + outputs = cute_out + else: + d_tensor, d_col_tensor, amax_tensor, sfd_row_tensor, sfd_col_tensor = cute_out + outputs = { + "d_tensor": d_tensor, + "d_col_tensor": d_col_tensor, + "amax_tensor": amax_tensor, + "sfd_row_tensor": sfd_row_tensor, + "sfd_col_tensor": sfd_col_tensor, + } + + d_cute = outputs["d_tensor"] + if d_cute.dim() == 3: + d_cute = d_cute.squeeze(-1) + tols = dtype_tols(torch.bfloat16) + assert_close(d_cute[:total_m].float(), ref, **tols) From bbfde9ce92e7ab697cd59dfbe6a5f4cc472a8b83 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 9 Jun 2026 01:21:19 +0000 Subject: [PATCH 04/11] [PyTorch] Port te.ops.GroupedLinear test to test_grouped_linear MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit test_grouped_linear.py only tested te.GroupedLinear (the high-level module API). The test_grouped_linear test in test_fusible_ops.py covers te.ops.GroupedLinear (the fuser/ops API) — a different thing. This commit brings that coverage into test_grouped_linear.py. Additions: - maybe_skip_quantization: skip helper for hardware/dim/dtype checks - make_reference_and_test_tensors: paired float64/CPU reference and target-dtype/CUDA test tensor construction (with quantization) - _ops_quantization_list: quantization parameter list for ops tests - test_ops_grouped_linear: full port of test_grouped_linear from test_fusible_ops.py, including the three axes that were stripped during cleanup (delay_wgrad_compute, single_grouped_weight, single_grouped_bias); uses te.ops.GroupedLinear throughout Also adds Float8CurrentScalingQuantizer, QuantizedTensor, QuantizerRole to imports, and import math. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Tim Moon --- tests/pytorch/test_grouped_linear.py | 652 +++++++++++++++++++++++++++ 1 file changed, 652 insertions(+) diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index 906203b54f..72800f17e2 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. +import math import os import random from typing import Dict, List, Optional, Sequence @@ -14,6 +15,7 @@ import transformer_engine.pytorch as te from transformer_engine.common import recipe from transformer_engine.pytorch import ( + Float8CurrentScalingQuantizer, Float8Quantizer, Fp8Padding, Fp8Unpadding, @@ -21,6 +23,8 @@ Linear, MXFP8Quantizer, NVFP4Quantizer, + QuantizedTensor, + QuantizerRole, autocast, is_bf16_available, quantized_model_init, @@ -36,10 +40,16 @@ ) from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor import transformer_engine_torch as tex +from transformer_engine.pytorch.ops._common import ( + _cudnn_frontend_supports_grouped_gemm_srelu, + _cudnn_frontend_version_supported, + is_glu_activation, +) from utils import ( MegatronTrainingHelper, ModelConfig, assert_close, + assert_close_grads, make_recipe, quantization_tols, recipe_id, @@ -166,6 +176,22 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]: if torch.cuda.get_device_capability() == (9, 0): use_cutlass_grouped_gemm.append(True) +_grouped_mlp_quantization_list: list = [None] +if fp8_available: + _grouped_mlp_quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) +if mxfp8_available: + _grouped_mlp_quantization_list.append("mxfp8") +if nvfp4_available: + _grouped_mlp_quantization_list.extend(("nvfp4", "nvfp4_4over6", "nvfp4_rht")) + +_ops_quantization_list: list = [None] +if fp8_available: + _ops_quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) +if mxfp8_available: + _ops_quantization_list.append("mxfp8") +if nvfp4_available: + _ops_quantization_list.extend(("nvfp4", "nvfp4_4over6")) + class TorchGroupedLinearWithPadding(nn.Module): @@ -269,6 +295,123 @@ def _test_grouped_linear_accuracy( return outputs +def maybe_skip_quantization( + quantization: Optional[str], + *, + dims=None, + device=None, + dtype: Optional[torch.dtype] = None, +) -> None: + """Skip test case if a quantization scheme is not supported on this hardware/config.""" + if quantization is None: + return + if device is not None and torch.device(device).type != "cuda": + pytest.skip("Quantization is only supported on CUDA devices") + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: + pytest.skip(reason_for_no_fp8) + if quantization == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") and not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + if dims is not None: + if not hasattr(dims, "__iter__"): + dims = (dims,) + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"): + if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: + pytest.skip("FP8 GEMMs require dims divisible by 16") + elif quantization == "mxfp8": + if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: + pytest.skip("MXFP8 GEMMs require dims divisible by 32") + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): + if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: + pytest.skip("NVFP4 GEMMs require dims divisible by 16") + if dtype is not None: + if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") and dtype != torch.bfloat16: + pytest.skip("NVFP4 quantization is only supported with BF16 data") + + +@torch.no_grad() +def make_reference_and_test_tensors( + shape, + *, + min: float = 0.0, + max: float = 1.0, + quantization: Optional[str] = None, + ref_dtype: torch.dtype = torch.float64, + ref_device: torch.device = "cpu", + test_dtype: torch.dtype = torch.float32, + test_device: torch.device = "cuda", + test_is_quantized: bool = False, + quantizer_role: Optional[QuantizerRole] = None, + requires_grad: bool = True, +): + """Paired tensors: float64/CPU reference for PyTorch ops, target-dtype/CUDA for TE ops.""" + ref = torch.empty(shape, dtype=ref_dtype, device=ref_device) + ref.uniform_(min, max) + test = ref.to(device=test_device, dtype=test_dtype) + if quantization is None: + if test_is_quantized: + raise ValueError("Quantization scheme not provided") + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + elif quantization in ("fp8", "fp8_delayed_scaling"): + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), + amax=torch.zeros(1, dtype=torch.float32, device=test_device), + fp8_dtype=te.DType.kFloat8E4M3, + ) + test = quantizer(test) + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=te.DType.kFloat8E4M3, + device=test_device, + ) + test = quantizer(test) + elif quantization == "mxfp8": + test = MXFP8Quantizer(fp8_dtype=te.DType.kFloat8E4M3)(test) + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_rht"): + tensor_type = "input" + if quantizer_role is not None: + tensor_type = quantizer_role.tensor_type + with_rht = quantization == "nvfp4_rht" and tensor_type != "weight" + test = NVFP4Quantizer( + with_rht=with_rht, + with_post_rht_amax=with_rht, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + )(test) + elif quantization == "nvfp4_4over6": + tensor_type = "input" + if quantizer_role is not None: + tensor_type = quantizer_role.tensor_type + nvfp4_use_4over6 = False + with_2d_quantization = False + nvfp4_e4m3_max = 448 + if tensor_type not in ("grad_output", "grad_input"): + nvfp4_use_4over6 = True + nvfp4_e4m3_max = 256 + if tensor_type == "weight": + with_2d_quantization = True + test = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=with_2d_quantization, + stochastic_rounding=False, + with_random_sign_mask=False, + nvfp4_use_4over6=nvfp4_use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + )(test) + else: + raise ValueError(f"Unsupported quantization scheme ({quantization})") + if isinstance(test, QuantizedTensor) and not test_is_quantized: + test = test.dequantize() + ref.copy_(test.to(dtype=ref.dtype)) + ref.requires_grad_(requires_grad) + test.requires_grad_(requires_grad) + return ref, test + + @pytest.mark.parametrize("dtype", param_types, ids=str) @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @@ -2102,6 +2245,515 @@ def train_step( assert_close(g, param.grad, **tols) +@pytest.mark.parametrize("delay_wgrad_compute", (False, True)) +@pytest.mark.parametrize("single_grouped_weight", (False, True)) +@pytest.mark.parametrize("single_grouped_bias", (False, True)) +@pytest.mark.parametrize("bias", (False, True)) +@pytest.mark.parametrize("dtype", param_types, ids=str) +@pytest.mark.parametrize("quantization", _ops_quantization_list) +@pytest.mark.parametrize("quantized_compute", (False, True)) +@pytest.mark.parametrize("quantized_weight", (False, True)) +@pytest.mark.parametrize("input_requires_grad", (False, True)) +@pytest.mark.parametrize("weight_requires_grad", (False, True)) +def test_ops_grouped_linear( + *, + group_size: int = 4, + bias: bool, + weight_shape: tuple = (128, 128), + split_alignment: int = 128, + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_compute: bool, + quantized_weight: bool, + input_requires_grad: bool, + weight_requires_grad: bool, + delay_wgrad_compute: bool, + single_grouped_weight: bool, + single_grouped_bias: bool, +) -> None: + """te.ops.GroupedLinear forward+backward accuracy""" + + # Split sizes + split_sizes = [split_alignment * i for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = (split_sizes.sum().item(), in_features) + out_shape = (in_shape[0], out_features) + + # Skip invalid configurations + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not used") + if quantization is not None and dtype not in (torch.bfloat16, torch.float16): + pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if single_grouped_bias and not bias: + pytest.skip("single_grouped_bias requires bias=True") + if ( + single_grouped_weight + and quantized_weight + and quantization in ("fp8_delayed_scaling", "fp8_current_scaling") + ): + pytest.skip( + "single_grouped_weight does not support FP8 delayed/current scaling " + "with quantized_model_init" + ) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + ws_ref, ws_test = [], [] + bs_ref, bs_test = [], [] + for _ in range(group_size): + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + quantization=quantization, + test_dtype=dtype, + test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), + requires_grad=weight_requires_grad, + ) + b_ref, b_test = None, None + if bias: + b_ref, b_test = make_reference_and_test_tensors( + out_features, + test_dtype=dtype, + test_device=device, + requires_grad=weight_requires_grad, + ) + ws_ref.append(w_ref) + ws_test.append(w_test) + bs_ref.append(b_ref) + bs_test.append(b_test) + + # Plain PyTorch reference implementation + xs_ref = torch.split(x_ref, split_sizes.tolist()) + ys_ref = [] + for x, w, b in zip(xs_ref, ws_ref, bs_ref): + ys_ref.append(torch.nn.functional.linear(x, w, bias=b)) + y_ref = torch.cat(ys_ref) + if input_requires_grad or weight_requires_grad: + y_ref.backward(dy_ref) + + # Construct te.ops.GroupedLinear + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): + op = te.ops.GroupedLinear( + group_size, + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + delay_wgrad_compute=delay_wgrad_compute, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + ) + with torch.no_grad(): + if single_grouped_weight: + op_weights = op.weight.quantized_tensors + if op_weights is None: + op_weights = op.weight.split_into_quantized_tensors() + if single_grouped_bias: + op_bias_parts = op.bias.split_into_quantized_tensors() + for group_idx in range(group_size): + if single_grouped_weight: + op_weights[group_idx].copy_(ws_test[group_idx]) + else: + getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx]) + if bias: + if single_grouped_bias: + op_bias_parts[group_idx].reshape(-1).copy_(bs_test[group_idx]) + else: + getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx]) + del ws_test, bs_test + for param in op.parameters(): + param.requires_grad_(requires_grad=weight_requires_grad) + + # Forward and backward pass + with te.autocast(enabled=quantized_compute, recipe=recipe): + y_test = op(x_test, split_sizes) + if input_requires_grad or weight_requires_grad: + y_test.backward(dy_test) + if delay_wgrad_compute and weight_requires_grad: + op.backward_dw() + + # Expected numerical tolerances + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = quantization_tols(quantization) + + # Check results + assert_close(y_test, y_ref, **tols) + assert_close_grads(x_test, x_ref, **tols) + if single_grouped_weight: + if weight_requires_grad: + w_ref_grad = torch.stack([w.grad for w in ws_ref], dim=0) + assert_close(op.weight.grad, w_ref_grad, **tols) + else: + assert op.weight.grad is None + else: + for group_idx in range(group_size): + w_test = getattr(op, f"weight{group_idx}") + assert_close_grads(w_test, ws_ref[group_idx], **tols) + if bias: + if single_grouped_bias: + if weight_requires_grad: + b_ref_grad = torch.stack([b.grad for b in bs_ref], dim=0) + assert_close(op.bias.grad, b_ref_grad, **tols) + else: + assert op.bias.grad is None + else: + for group_idx in range(group_size): + b_test = getattr(op, f"bias{group_idx}") + assert_close_grads(b_test, bs_ref[group_idx], **tols) + + +@pytest.mark.parametrize("bias", (False, True)) +@pytest.mark.parametrize("dtype", (torch.float32, torch.float16, torch.bfloat16)) +@pytest.mark.parametrize("quantization", _grouped_mlp_quantization_list) +@pytest.mark.parametrize("glu_interleave_size", (None, 32)) +@pytest.mark.parametrize("single_grouped_weight", (False, True)) +@pytest.mark.parametrize("single_grouped_bias", (False, True)) +@pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) +@pytest.mark.parametrize("delay_wgrad_compute", (False, True)) +@pytest.mark.parametrize( + "activation", + ("scaled_swiglu", "scaled_clamped_qgeglu", "scaled_clamped_qgeglu_custom", "scaled_srelu"), +) +def test_grouped_mlp( + *, + group_size: int = 4, + hidden_size: int = 128, + bias: bool, + dtype: torch.dtype, + quantization: Optional[str], + single_grouped_weight: bool, + single_grouped_bias: bool, + accumulate_into_main_grad: bool, + device: torch.device = "cuda", + split_alignment: int = 256, + glu_interleave_size: Optional[int], + delay_wgrad_compute: bool, + activation: str, +) -> None: + """GroupedLinear + scaled activation + GroupedLinear""" + if dtype == torch.bfloat16 and not is_bf16_available(): + pytest.skip("BF16 not available") + + # Build activation op to determine GLU vs unary + if activation == "scaled_swiglu": + scaled_act_ref = te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + elif activation.startswith("scaled_clamped_qgeglu"): + scaled_act_ref = te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + elif activation == "scaled_srelu": + scaled_act_ref = te.ops.ScaledSReLU() + else: + raise ValueError(f"Unexpected activation ({activation})") + activation_is_glu = is_glu_activation(scaled_act_ref) + + # Skip invalid configurations + with_quantization = quantization is not None + maybe_skip_quantization(quantization, device=device, dtype=dtype) + if single_grouped_weight and quantization != "mxfp8": + pytest.skip("single_grouped_weight is only supported for MXFP8 quantization") + if single_grouped_bias and not bias: + pytest.skip("single_grouped_bias requires bias=True") + if with_quantization and dtype not in (torch.bfloat16, torch.float16): + pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if not activation_is_glu and quantization not in ("mxfp8", "nvfp4", "nvfp4_rht"): + pytest.skip("Scaled unary grouped MLP is only supported with MXFP8 or NVFP4") + if not activation_is_glu and glu_interleave_size is not None: + pytest.skip("Unary activations do not use GLU interleaving") + if quantization == "nvfp4_4over6": + pytest.skip("NVFP4 4over6 grouped quantization is not supported") + if activation == "scaled_srelu" and quantization in ("nvfp4", "nvfp4_rht") and bias: + pytest.skip("NVFP4 SReLU grouped MLP coverage is limited to no-bias") + if quantization == "nvfp4_rht": + if activation == "scaled_swiglu" and (bias or glu_interleave_size != 32): + pytest.skip("NVFP4 RHT SwiGLU grouped MLP coverage is limited to no-bias") + if activation not in ("scaled_swiglu", "scaled_srelu"): + pytest.skip("NVFP4 RHT grouped MLP coverage is limited to SwiGLU and SReLU") + if ( + with_quantization + and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") + and activation.startswith("scaled_clamped_qgeglu") + and bias + ): + # TODO: ksivaman: Need to debug numerics for this case. + pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") + + fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size + if activation == "scaled_clamped_qgeglu_custom": + geglu_limit, geglu_alpha, geglu_offset = 5.0, 1.5, 0.5 + else: + geglu_limit, geglu_alpha, geglu_offset = 7.0, 1.702, 1.0 + + # Split sizes (one group intentionally empty to test the zero-token case) + split_sizes = [split_alignment * i for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + in_shape = (split_sizes.sum().item(), hidden_size) + out_shape = in_shape + + # Random data: float32 for reference path, target dtype for test path + x_base = torch.empty(in_shape, device=device).uniform_(-0.25, 0.25) + x_ref = x_base.clone().requires_grad_() + x_test = x_base.to(dtype).requires_grad_() + + dy_base = torch.empty(out_shape, device=device).uniform_(-0.25, 0.25) + dy_ref = dy_base.clone() + dy_test = dy_base.to(dtype) + + probs_base = torch.empty((in_shape[0],), device=device).uniform_(0.1, 1.0) + probs_ref = probs_base.clone().requires_grad_() + probs_test = probs_base.to(dtype).requires_grad_() + + fc1_ws_ref, fc1_ws_test = [], [] + fc1_bs_ref, fc1_bs_test = [], [] + fc2_ws_ref, fc2_ws_test = [], [] + fc2_bs_ref, fc2_bs_test = [], [] + for _ in range(group_size): + w1 = torch.empty((fc1_out_features, hidden_size), device=device).uniform_(-0.25, 0.25) + fc1_ws_ref.append(w1.clone().requires_grad_()) + fc1_ws_test.append(w1.to(dtype)) + w2 = torch.empty((hidden_size, hidden_size), device=device).uniform_(-0.25, 0.25) + fc2_ws_ref.append(w2.clone().requires_grad_()) + fc2_ws_test.append(w2.to(dtype)) + if bias: + b1 = torch.empty((fc1_out_features,), device=device).uniform_(-0.5, 0.5) + fc1_bs_ref.append(b1.clone().requires_grad_()) + fc1_bs_test.append(b1.to(dtype)) + b2 = torch.empty((hidden_size,), device=device).uniform_(-0.5, 0.5) + fc2_bs_ref.append(b2.clone().requires_grad_()) + fc2_bs_test.append(b2.to(dtype)) + else: + fc1_bs_ref.append(None) + fc1_bs_test.append(None) + fc2_bs_ref.append(None) + fc2_bs_test.append(None) + + def _apply_activation(x: torch.Tensor) -> torch.Tensor: + if activation_is_glu and glu_interleave_size is not None: + x = x.reshape(-1, 2 * hidden_size // (2 * glu_interleave_size), 2, glu_interleave_size) + x = x.transpose(1, 2).reshape(-1, 2 * hidden_size) + if activation == "scaled_swiglu": + x1, x2 = x.chunk(2, dim=-1) + return torch.nn.functional.silu(x1) * x2 + if activation.startswith("scaled_clamped_qgeglu"): + x1, x2 = x.chunk(2, dim=-1) + lim = torch.tensor(geglu_limit, device=x1.device, dtype=x1.dtype) + x1c = torch.minimum(x1, lim) + x2c = torch.clamp(x2, -lim, lim) + return (x2c + geglu_offset) * (x1c * torch.sigmoid(geglu_alpha * x1c)) + if activation == "scaled_srelu": + return torch.nn.functional.relu(x).square() + raise ValueError(f"Unexpected activation ({activation})") + + # Reference implementation (float32 PyTorch) + xs = torch.split(x_ref, split_sizes.tolist()) + probs = torch.split(probs_ref, split_sizes.tolist()) + ys = [] + for group_idx in range(group_size): + x = xs[group_idx] + fc1_out = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx]) + fc2_in = _apply_activation(fc1_out) * probs[group_idx].unsqueeze(-1) + y = torch.nn.functional.linear(fc2_in, fc2_ws_ref[group_idx]) + if bias: + y = y + fc2_bs_ref[group_idx] * probs[group_idx].unsqueeze(-1) + ys.append(y) + y_ref = torch.cat(ys) + y_ref.backward(dy_ref) + + # Construct TE module + recipe = make_recipe(quantization) + + def _make_scaled_act(): + if activation == "scaled_swiglu": + return te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_clamped_qgeglu_custom": + return te.ops.ScaledClampedQGeGLU( + glu_interleave_size=glu_interleave_size, + limit=geglu_limit, + alpha=geglu_alpha, + glu_linear_offset=geglu_offset, + ) + if activation.startswith("scaled_clamped_qgeglu"): + return te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_srelu": + return te.ops.ScaledSReLU() + raise ValueError(f"Unexpected activation ({activation})") + + with te.quantized_model_init(enabled=with_quantization, recipe=recipe): + fc1 = te.ops.GroupedLinear( + group_size, hidden_size, fc1_out_features, + bias=bias, device=device, dtype=dtype, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, + ) + fc2 = te.ops.GroupedLinear( + group_size, hidden_size, hidden_size, + bias=bias, device=device, dtype=dtype, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, + scale_bias=bias, + ) + module = te.ops.Sequential(fc1, _make_scaled_act(), fc2) + + # Copy weights + with torch.no_grad(): + if single_grouped_weight: + fc1_weights = fc1.weight.quantized_tensors + if fc1_weights is None: + fc1_weights = fc1.weight.split_into_quantized_tensors() + fc2_weights = fc2.weight.quantized_tensors + if fc2_weights is None: + fc2_weights = fc2.weight.split_into_quantized_tensors() + for group_idx in range(group_size): + if single_grouped_weight: + fc1_weights[group_idx].copy_(fc1_ws_test[group_idx]) + fc2_weights[group_idx].copy_(fc2_ws_test[group_idx]) + else: + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) + if bias: + if single_grouped_bias: + fc1_bparts = fc1.bias.split_into_quantized_tensors() + fc2_bparts = fc2.bias.split_into_quantized_tensors() + fc1_bparts[group_idx].reshape(-1).copy_(fc1_bs_test[group_idx]) + fc2_bparts[group_idx].reshape(-1).copy_(fc2_bs_test[group_idx]) + else: + getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) + getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) + if accumulate_into_main_grad: + main_grad_sentinel = 0.5 + if single_grouped_weight: + weight_params_for_main_grad = [fc1.weight, fc2.weight] + else: + weight_params_for_main_grad = [ + getattr(fc, f"weight{i}") for fc in (fc1, fc2) for i in range(group_size) + ] + MegatronTrainingHelper.init_main_grad_buffers( + weight_params_for_main_grad, + fill_value=main_grad_sentinel, + overwrite_main_grad=False, + ) + del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test + + # Forward and backward pass + with te.autocast(enabled=with_quantization, recipe=recipe): + fc2_extra = (split_sizes, probs_test) if bias else (split_sizes,) + y_test = module(x_test, split_sizes, probs_test, *fc2_extra) + y_test.backward(dy_test) + if delay_wgrad_compute: + fc1.backward_dw() + fc2.backward_dw() + + # Check for expected fusions + cudnn_frontend_supports_grouped_mlp = ( + _cudnn_frontend_supports_grouped_gemm_srelu() + if activation == "scaled_srelu" + else _cudnn_frontend_version_supported() + ) + expected_grouped_mlp_fusion = cudnn_frontend_supports_grouped_mlp and ( + ( + quantization == "mxfp8" + and dtype in (torch.bfloat16, torch.float16) + and ( + (not activation_is_glu and glu_interleave_size is None) + or (activation_is_glu and glu_interleave_size == 32) + ) + ) + or ( + quantization == "nvfp4_rht" + and dtype == torch.bfloat16 + and activation == "scaled_srelu" + and glu_interleave_size is None + ) + ) + if expected_grouped_mlp_fusion: + if activation_is_glu: + forward_cls = te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU + backward_cls = te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU + else: + forward_cls = te.ops.fused.ForwardGroupedMLP_CuTeGEMMUnary + backward_cls = te.ops.fused.BackwardGroupedMLP_CuTeGEMMDUnary + if forward_cls.is_supported(): + forward_ops = module._module_groups[0]._forward_ops + assert len(forward_ops) == 1 + assert isinstance(forward_ops[0][0], forward_cls) + if backward_cls is not None and backward_cls.is_supported(): + backward_ops = module._module_groups[0]._backward_ops + assert len(backward_ops) == 1 + assert isinstance(backward_ops[0][0], backward_cls) + + # Loose tols for sanity checking + tols = {"rtol": 0.125, "atol": 0.25} + if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): + tols = {"rtol": 0.25, "atol": 0.5} + + # Check values + assert_close(y_test, y_ref, **tols) + assert_close_grads(x_test, x_ref, **tols) + assert_close_grads(probs_test, probs_ref, **tols) + for group_idx in range(group_size): + if bias: + if single_grouped_bias: + assert_close(fc2.bias.grad[group_idx], fc2_bs_ref[group_idx].grad, **tols) + assert_close(fc1.bias.grad[group_idx], fc1_bs_ref[group_idx].grad, **tols) + else: + assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols) + assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols) + if not single_grouped_weight and not accumulate_into_main_grad: + assert_close_grads(getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols) + assert_close_grads(getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols) + fc1_w_ref_grad = torch.stack([w.grad for w in fc1_ws_ref], dim=0) + fc2_w_ref_grad = torch.stack([w.grad for w in fc2_ws_ref], dim=0) + if accumulate_into_main_grad: + fc1_expected = ( + [fc1_w_ref_grad + main_grad_sentinel] + if single_grouped_weight + else [g + main_grad_sentinel for g in fc1_w_ref_grad] + ) + fc2_expected = ( + [fc2_w_ref_grad + main_grad_sentinel] + if single_grouped_weight + else [g + main_grad_sentinel for g in fc2_w_ref_grad] + ) + MegatronTrainingHelper.verify_main_grad_accumulation( + weight_params_for_main_grad, + expected_main_grads=fc1_expected + fc2_expected, + **tols, + ) + elif single_grouped_weight: + assert_close(fc1.weight.grad, fc1_w_ref_grad, **tols) + assert_close(fc2.weight.grad, fc2_w_ref_grad, **tols) + + @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16)) @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu")) From 28da34080582fcf3577a9a228dc599d99b617719 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 9 Jun 2026 21:07:32 +0000 Subject: [PATCH 05/11] [PyTorch] Rationalize test_grouped_linear.py structure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Structural changes: - Organize into three labeled sections: te.GroupedLinear (module API), raw GEMM kernels (cpp_extensions), te.ops.GroupedLinear (ops/fuser API) - Move fused-path helpers and tests (_reset_fp8_state, _run_grouped_linear_path, test_grouped_linear_grouped_tensor_path_*, test_grouped_linear_fused_path_*) before the GEMM section so all te.GroupedLinear tests are contiguous Fold redundant tests: - test_grouped_linear_accuracy_single_gemm → num_gemms=[1,3,6] - test_padding_grouped_linear_accuracy_save_original_input → save_original_input parametrize axis on test_padding_grouped_linear_accuracy - test_grouped_linear_grouped_tensor_path_single_grouped_bias_delay_wgrad → single_grouped_bias parametrize axis on test_grouped_linear_grouped_tensor_path_matches_legacy Fix test_grouped_mlp reference precision and tolerances: - Replace float32 reference tensors with make_reference_and_test_tensors (float64 CPU reference, target dtype CUDA test) - Replace loose rtol=0.125/0.25 with dtype_tols()/quantization_tols() Promote shared helpers to utils.py: - maybe_skip_quantization: was duplicated in test_fusible_ops.py and test_grouped_linear.py; now in utils.py - make_reference_and_test_tensors: same - dtype_tols: was duplicated locally in test_grouped_linear.py; now imported from utils.py (where it already existed) Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 152 +--- tests/pytorch/test_grouped_linear.py | 1235 +++++++++++--------------- tests/pytorch/utils.py | 166 ++++ 3 files changed, 691 insertions(+), 862 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 8a5bc792ee..9134749fb4 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -32,7 +32,6 @@ ) from transformer_engine.pytorch import ( QuantizedTensor, - Float8CurrentScalingQuantizer, Float8Quantizer, MXFP8Quantizer, NVFP4Quantizer, @@ -47,6 +46,8 @@ assert_close_grads, dtype_tols, make_recipe, + make_reference_and_test_tensors, + maybe_skip_quantization, quantization_tols, reset_rng_states, ) @@ -81,155 +82,6 @@ def _reset_rng_states_per_test(): yield -def maybe_skip_quantization( - quantization: Optional[str], - *, - dims: Optional[Iterable[int] | int] = None, - device: Optional[torch.device | str] = None, - dtype: Optional[torch.dtype] = None, -) -> None: - """Skip test case if a quantization scheme is not supported""" - - # Don't skip if there is no quantization - if quantization is None: - return - - # Check if quantization scheme is supported on device - if device is not None and torch.device(device).type != "cuda": - pytest.skip("Quantization is only supported on CUDA devices") - if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: - pytest.skip(reason_for_no_fp8) - if quantization == "mxfp8" and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if ( - quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") - and not nvfp4_available - ): - pytest.skip(reason_for_no_nvfp4) - - # Check dims - if dims is not None: - if not isinstance(dims, Iterable): - dims = (dims,) - if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"): - if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: - pytest.skip("FP8 GEMMs require dims that are divisible by 16") - elif quantization == "mxfp8": - if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: - pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") - elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): - if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: - pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") - - # Check dtype - if dtype is not None: - if ( - quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") - and dtype != torch.bfloat16 - ): - pytest.skip("NVFP4 quantization is only supported with BF16 data") - - -@torch.no_grad() -def make_reference_and_test_tensors( - shape: int | Iterable[int], - *, - min: float = 0.0, - max: float = 1.0, - quantization: Optional[str] = None, - ref_dtype: torch.dtype = torch.float64, - ref_device: torch.device = "cpu", - test_dtype: torch.dtype = torch.float32, - test_device: torch.device = "cuda", - test_is_quantized: bool = False, - quantizer_role: Optional[QuantizerRole] = None, - requires_grad: bool = True, -) -> tuple[torch.Tensor, torch.Tensor]: - """Construct tensors with the same values - - The reference tensor is intended for use in plain PyTorch - operations in high precision. The test tensor is intended for use - in Transformer Engine operations. - - If a quantization scheme is provided, the tensor values are - quantized so that they are representable. - - """ - - # Random reference tensor - ref = torch.empty(shape, dtype=ref_dtype, device=ref_device) - ref.uniform_(min, max) - - # Construct test tensor from reference tensor - test = ref.to(device=test_device, dtype=test_dtype) - if quantization is None: - if test_is_quantized: - raise ValueError("Quantization scheme not provided") - if test.data_ptr() == ref.data_ptr(): - test = test.clone() - elif quantization in ("fp8", "fp8_delayed_scaling"): - quantizer = Float8Quantizer( - scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), - amax=torch.zeros(1, dtype=torch.float32, device=test_device), - fp8_dtype=te.DType.kFloat8E4M3, - ) - test = quantizer(test) - elif quantization == "fp8_current_scaling": - quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=te.DType.kFloat8E4M3, - device=test_device, - ) - test = quantizer(test) - elif quantization == "mxfp8": - test = MXFP8Quantizer(fp8_dtype=te.DType.kFloat8E4M3)(test) - elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_rht"): - tensor_type = "input" - if quantizer_role is not None: - tensor_type = quantizer_role.tensor_type - with_rht = quantization == "nvfp4_rht" and tensor_type != "weight" - test = NVFP4Quantizer( - with_rht=with_rht, - with_post_rht_amax=with_rht, - with_2d_quantization=False, - stochastic_rounding=False, - with_random_sign_mask=False, - )(test) - elif quantization == "nvfp4_4over6": - tensor_type = "input" - if quantizer_role is not None: - tensor_type = quantizer_role.tensor_type - - nvfp4_use_4over6 = False - with_2d_quantization = False - nvfp4_e4m3_max = 448 - if tensor_type not in ("grad_output", "grad_input"): - nvfp4_use_4over6 = True - nvfp4_e4m3_max = 256 - if tensor_type == "weight": - with_2d_quantization = True - - test = NVFP4Quantizer( - with_rht=False, - with_post_rht_amax=False, - with_2d_quantization=with_2d_quantization, - stochastic_rounding=False, - with_random_sign_mask=False, - nvfp4_use_4over6=nvfp4_use_4over6, - nvfp4_e4m3_max=nvfp4_e4m3_max, - )(test) - else: - raise ValueError(f"Unsupported quantization scheme ({quantization})") - if isinstance(test, QuantizedTensor) and not test_is_quantized: - test = test.dequantize() - - # Make sure reference and test tensors match each other - ref.copy_(test.to(dtype=ref.dtype)) - - ref.requires_grad_(requires_grad) - test.requires_grad_(requires_grad) - return ref, test - - def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: """Convert to an FP64 CPU tensor""" if tensor is None: diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index 72800f17e2..2cb0dcc420 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -2,7 +2,6 @@ # # See LICENSE for license information. -import math import os import random from typing import Dict, List, Optional, Sequence @@ -15,7 +14,6 @@ import transformer_engine.pytorch as te from transformer_engine.common import recipe from transformer_engine.pytorch import ( - Float8CurrentScalingQuantizer, Float8Quantizer, Fp8Padding, Fp8Unpadding, @@ -23,7 +21,6 @@ Linear, MXFP8Quantizer, NVFP4Quantizer, - QuantizedTensor, QuantizerRole, autocast, is_bf16_available, @@ -50,7 +47,10 @@ ModelConfig, assert_close, assert_close_grads, + dtype_tols, make_recipe, + make_reference_and_test_tensors, + maybe_skip_quantization, quantization_tols, recipe_id, reset_rng_states, @@ -142,16 +142,6 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> return supported_input_dtypes -def dtype_tols(dtype: torch.dtype) -> Dict[str, float]: - if dtype == torch.float32: - return dict(rtol=1.3e-6, atol=1e-5) - if dtype == torch.float16: - return dict(rtol=1e-3, atol=1e-5) - if dtype == torch.bfloat16: - return dict(rtol=1.6e-2, atol=1e-5) - raise ValueError(f"Unsupported dtype ({dtype})") - - param_types = [torch.float32, torch.float16] if is_bf16_available(): param_types.append(torch.bfloat16) @@ -295,125 +285,16 @@ def _test_grouped_linear_accuracy( return outputs -def maybe_skip_quantization( - quantization: Optional[str], - *, - dims=None, - device=None, - dtype: Optional[torch.dtype] = None, -) -> None: - """Skip test case if a quantization scheme is not supported on this hardware/config.""" - if quantization is None: - return - if device is not None and torch.device(device).type != "cuda": - pytest.skip("Quantization is only supported on CUDA devices") - if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: - pytest.skip(reason_for_no_fp8) - if quantization == "mxfp8" and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") and not nvfp4_available: - pytest.skip(reason_for_no_nvfp4) - if dims is not None: - if not hasattr(dims, "__iter__"): - dims = (dims,) - if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"): - if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: - pytest.skip("FP8 GEMMs require dims divisible by 16") - elif quantization == "mxfp8": - if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: - pytest.skip("MXFP8 GEMMs require dims divisible by 32") - elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): - if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: - pytest.skip("NVFP4 GEMMs require dims divisible by 16") - if dtype is not None: - if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") and dtype != torch.bfloat16: - pytest.skip("NVFP4 quantization is only supported with BF16 data") - - -@torch.no_grad() -def make_reference_and_test_tensors( - shape, - *, - min: float = 0.0, - max: float = 1.0, - quantization: Optional[str] = None, - ref_dtype: torch.dtype = torch.float64, - ref_device: torch.device = "cpu", - test_dtype: torch.dtype = torch.float32, - test_device: torch.device = "cuda", - test_is_quantized: bool = False, - quantizer_role: Optional[QuantizerRole] = None, - requires_grad: bool = True, -): - """Paired tensors: float64/CPU reference for PyTorch ops, target-dtype/CUDA for TE ops.""" - ref = torch.empty(shape, dtype=ref_dtype, device=ref_device) - ref.uniform_(min, max) - test = ref.to(device=test_device, dtype=test_dtype) - if quantization is None: - if test_is_quantized: - raise ValueError("Quantization scheme not provided") - if test.data_ptr() == ref.data_ptr(): - test = test.clone() - elif quantization in ("fp8", "fp8_delayed_scaling"): - quantizer = Float8Quantizer( - scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), - amax=torch.zeros(1, dtype=torch.float32, device=test_device), - fp8_dtype=te.DType.kFloat8E4M3, - ) - test = quantizer(test) - elif quantization == "fp8_current_scaling": - quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=te.DType.kFloat8E4M3, - device=test_device, - ) - test = quantizer(test) - elif quantization == "mxfp8": - test = MXFP8Quantizer(fp8_dtype=te.DType.kFloat8E4M3)(test) - elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_rht"): - tensor_type = "input" - if quantizer_role is not None: - tensor_type = quantizer_role.tensor_type - with_rht = quantization == "nvfp4_rht" and tensor_type != "weight" - test = NVFP4Quantizer( - with_rht=with_rht, - with_post_rht_amax=with_rht, - with_2d_quantization=False, - stochastic_rounding=False, - with_random_sign_mask=False, - )(test) - elif quantization == "nvfp4_4over6": - tensor_type = "input" - if quantizer_role is not None: - tensor_type = quantizer_role.tensor_type - nvfp4_use_4over6 = False - with_2d_quantization = False - nvfp4_e4m3_max = 448 - if tensor_type not in ("grad_output", "grad_input"): - nvfp4_use_4over6 = True - nvfp4_e4m3_max = 256 - if tensor_type == "weight": - with_2d_quantization = True - test = NVFP4Quantizer( - with_rht=False, - with_post_rht_amax=False, - with_2d_quantization=with_2d_quantization, - stochastic_rounding=False, - with_random_sign_mask=False, - nvfp4_use_4over6=nvfp4_use_4over6, - nvfp4_e4m3_max=nvfp4_e4m3_max, - )(test) - else: - raise ValueError(f"Unsupported quantization scheme ({quantization})") - if isinstance(test, QuantizedTensor) and not test_is_quantized: - test = test.dequantize() - ref.copy_(test.to(dtype=ref.dtype)) - ref.requires_grad_(requires_grad) - test.requires_grad_(requires_grad) - return ref, test +# ============================================================================= +# te.GroupedLinear (module API) +# +# Tests the high-level GroupedLinear module. Reference: sequential te.Linear +# modules with shared weights — bitwise match verifies grouping correctness. +# ============================================================================= @pytest.mark.parametrize("dtype", param_types, ids=str) -@pytest.mark.parametrize("num_gemms", [3, 6]) +@pytest.mark.parametrize("num_gemms", [1, 3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) @@ -667,22 +548,6 @@ def test_grouped_linear_accuracy_save_original_input( torch.testing.assert_close(o, o_ref, rtol=0, atol=0) -@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) -def test_grouped_linear_accuracy_single_gemm(recipe): - """Split the tests to save CI time""" - test_grouped_linear_accuracy( - dtype=torch.float32, - num_gemms=1, - bs=2, - model="126m", - recipe=recipe, - fp8_model_params=True, - fuse_wgrad_accumulation=True, - bias=True, - delay_wgrad_compute=False, - ) - - def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False): def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): @@ -776,6 +641,7 @@ def _generate_random_numbers(n, total_sum): return outputs +@pytest.mark.parametrize("save_original_input", [False, True]) @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @@ -791,10 +657,13 @@ def test_padding_grouped_linear_accuracy( fp8, recipe, fp8_model_params, + save_original_input, parallel_mode=None, ): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if save_original_input and recipe.delayed(): + pytest.skip("DelayedScaling recipe is not supported with save_original_input") skip_unsupported_backward_override( "grouped_linear", recipe, getattr(recipe, "backward_override", None) ) @@ -829,7 +698,7 @@ def test_padding_grouped_linear_accuracy( params_dtype=dtype, parallel_mode=parallel_mode, device="cuda", - save_original_input=False, + save_original_input=save_original_input, ).eval() # Share params @@ -854,152 +723,358 @@ def test_padding_grouped_linear_accuracy( torch.testing.assert_close(o, o_ref, rtol=0, atol=0) -@pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("num_gemms", [3]) -@pytest.mark.parametrize("bs", [1]) -@pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("fp8", [True]) -@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) -@pytest.mark.parametrize("fp8_model_params", [False]) -def test_padding_grouped_linear_accuracy_save_original_input( - dtype, - num_gemms, - bs, - model, - fp8, - recipe, - fp8_model_params, - parallel_mode=None, -): - if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: - pytest.skip("FP8 parameters are not supported in debug mode.") - if fp8 and recipe.delayed(): - pytest.skip("DelayedScaling recipe is not supported with save_original_input") - skip_unsupported_backward_override( - "grouped_linear", recipe, getattr(recipe, "backward_override", None) - ) +_FUSED_GROUPED_GEMM_ENV = "NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM" +_ALL_BOOLEAN = all_boolean +_mxfp8_available, _reason_for_no_mxfp8 = mxfp8_available, reason_for_no_mxfp8 - config = model_configs[model] - if config.max_seqlen_q % 16 != 0 and fp8: - pytest.skip("FP8 requires sequence length to be divisible by 16.") - if recipe is not None and recipe.nvfp4(): - if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): - pytest.skip( - f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" - ) +@pytest.fixture(autouse=True) +def _reset_fp8_state(monkeypatch): + monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "0") + yield + FP8GlobalStateManager.reset() + monkeypatch.delenv(_FUSED_GROUPED_GEMM_ENV, raising=False) - with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): - grouped_linear = TorchGroupedLinearWithPadding( - num_gemms, - config.hidden_size, - 4 * config.hidden_size, - bias=False, - params_dtype=dtype, - parallel_mode=parallel_mode, - fp8=fp8, - ).eval() - with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): - ref_grouped_linear = GroupedLinear( +def _clone_outputs(outputs): + return [None if out is None else out.detach().clone() for out in outputs] + + +def _run_grouped_linear_path( + *, + enable_grouped_tensor_path: bool, + fp8_recipe, + bias: bool, + fp8_model_params: bool, + delay_wgrad_compute: bool, + single_grouped_bias: bool = False, + x_base: torch.Tensor, + dy: torch.Tensor, + weights, + biases, + m_splits, + monkeypatch, +): + FP8GlobalStateManager.reset() + monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1" if enable_grouped_tensor_path else "0") + + dtype = x_base.dtype + num_gemms = len(m_splits) + in_features = weights[0].size(1) + out_features = weights[0].size(0) + use_fp8 = fp8_recipe is not None + + x = x_base.detach().clone().requires_grad_(True) + with quantized_model_init(enabled=fp8_model_params, recipe=fp8_recipe): + grouped_linear = GroupedLinear( num_gemms, - config.hidden_size, - 4 * config.hidden_size, - bias=False, + in_features, + out_features, + bias=bias, params_dtype=dtype, - parallel_mode=parallel_mode, device="cuda", - save_original_input=True, - ).eval() - - # Share params + delay_wgrad_compute=delay_wgrad_compute, + single_grouped_bias=single_grouped_bias, + ) with torch.no_grad(): - inner_grouped_linear = grouped_linear.linear_fn for i in range(num_gemms): - setattr( - ref_grouped_linear, - f"weight{i}", - Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), - ) + getattr(grouped_linear, f"weight{i}").copy_(weights[i]) + if bias: + getattr(grouped_linear, f"bias{i}").copy_(biases[i]) - outputs = _test_padding_grouped_linear_accuracy( - grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 - ) - outputs_ref = _test_padding_grouped_linear_accuracy( - ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 + # The fused path is the graph-safe path and accepts a CUDA tensor for split metadata. + # The legacy path still expects Python split sections in several places. + m_splits_arg = ( + torch.tensor(m_splits, dtype=torch.int64, device="cuda") + if enable_grouped_tensor_path + else m_splits ) + with autocast(enabled=use_fp8, recipe=fp8_recipe): + y = grouped_linear(x, m_splits_arg) + y.backward(dy) + if delay_wgrad_compute: + grouped_linear.backward_dw() - # Should be bit-wise match - for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + outputs = [y, x.grad] + for i in range(num_gemms): + outputs.append(getattr(grouped_linear, f"weight{i}").grad) + if bias: + outputs.append(getattr(grouped_linear, f"bias{i}").grad) + return _clone_outputs(outputs) @pytest.mark.parametrize( - "shape", + "fp8_recipe", [ - (1, 127, 128, 512), - (8, 15, 128, 512), - (8, 1027, 128, 512), - (16, 10027, 128, 512), + None, + pytest.param( + recipe.MXFP8BlockScaling(), + marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8), + ), ], + ids=["bf16", "mxfp8"], ) -@pytest.mark.parametrize("dtype", param_types, ids=str) -@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) -@pytest.mark.parametrize("accumulate", [False, True]) -@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm) -def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass, monkeypatch): - torch.manual_seed(0) - z, m, k, n = shape - - dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist() - m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) - assert m_splits.sum() == m and len(m_splits) == z - m_splits = m_splits.tolist() +@pytest.mark.parametrize("single_grouped_bias", _ALL_BOOLEAN) +@pytest.mark.parametrize("bias", _ALL_BOOLEAN) +@pytest.mark.parametrize("fp8_model_params", _ALL_BOOLEAN) +@pytest.mark.parametrize("delay_wgrad_compute", _ALL_BOOLEAN) +def test_grouped_linear_grouped_tensor_path_matches_legacy( + fp8_recipe, bias, fp8_model_params, delay_wgrad_compute, single_grouped_bias, monkeypatch +): + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("GroupedTensor grouped GEMM path requires SM100+") - if layout == "TN": - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input - out = [torch.randn(m, n, dtype=dtype, device="cuda")] # output - out_ref = [o.clone() for o in torch.split(out[0], m_splits)] - grad = False - single_output = True - elif layout == "NN": - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = list( - torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) - ) # grad_output - out = [torch.randn(m, k, dtype=dtype, device="cuda")] # dgrad - out_ref = [o.clone() for o in torch.split(out[0], m_splits)] - grad = True - single_output = True - else: # layout == "NT" - A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input - B = list( - torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) - ) # grad_output - out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad - out_ref = [o.clone() for o in out] - grad = True - single_output = False + use_fp8 = fp8_recipe is not None + if fp8_model_params and not use_fp8: + pytest.skip("fp8_model_params requires FP8") + if single_grouped_bias and not bias: + pytest.skip("single_grouped_bias requires bias=True") - if use_cutlass: - monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1") + dtype = torch.bfloat16 + num_gemms = 3 + in_features = 64 + out_features = 64 + m_splits = [128, 256, 384] + total_tokens = sum(m_splits) - for i in range(z): - general_gemm( - A[i], - B[i], - dtype, - grad=grad, - accumulate=accumulate, - layout=layout, - out=out_ref[i], - ) - if single_output: - out_ref = [torch.cat(out_ref)] + torch.manual_seed(1234) + x_base = (0.1 * torch.randn(total_tokens, in_features, device="cuda")).to(dtype) + dy = (0.1 * torch.randn(total_tokens, out_features, device="cuda")).to(dtype) + weights = [ + (0.1 * torch.randn(out_features, in_features, device="cuda")).to(dtype) + for _ in range(num_gemms) + ] + biases = None + if bias: + biases = [ + (0.1 * torch.randn(out_features, device="cuda")).to(dtype) for _ in range(num_gemms) + ] - general_grouped_gemm( - A, + outputs_legacy = _run_grouped_linear_path( + enable_grouped_tensor_path=False, + fp8_recipe=fp8_recipe, + bias=bias, + fp8_model_params=fp8_model_params, + delay_wgrad_compute=delay_wgrad_compute, + single_grouped_bias=single_grouped_bias, + x_base=x_base, + dy=dy, + weights=weights, + biases=biases, + m_splits=m_splits, + monkeypatch=monkeypatch, + ) + outputs_grouped_tensor = _run_grouped_linear_path( + enable_grouped_tensor_path=True, + fp8_recipe=fp8_recipe, + bias=bias, + fp8_model_params=fp8_model_params, + delay_wgrad_compute=delay_wgrad_compute, + single_grouped_bias=single_grouped_bias, + x_base=x_base, + dy=dy, + weights=weights, + biases=biases, + m_splits=m_splits, + monkeypatch=monkeypatch, + ) + + tols = dict(rtol=1e-2, atol=5e-3) + if use_fp8: + tols = dict(rtol=0.05, atol=0.05) + for grouped_tensor_out, legacy_out in zip(outputs_grouped_tensor, outputs_legacy): + assert grouped_tensor_out is not None + assert legacy_out is not None + torch.testing.assert_close(grouped_tensor_out.float(), legacy_out.float(), **tols) + + +@pytest.mark.parametrize( + "fp8_recipe", + [ + None, + pytest.param( + recipe.MXFP8BlockScaling(), + marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8), + ), + ], + ids=["bf16", "mxfp8"], +) +@pytest.mark.parametrize("bias", _ALL_BOOLEAN) +def test_grouped_linear_fused_path_cuda_graph_safe(fp8_recipe, bias, monkeypatch): + """Fused GroupedTensor GEMM path should be CUDA graph capturable.""" + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("GroupedTensor grouped GEMM path requires SM100+") + + monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1") + FP8GlobalStateManager.reset() + + use_fp8 = fp8_recipe is not None + dtype = torch.bfloat16 + device = "cuda" + num_gemms = 3 + in_features = 128 + out_features = 128 + split_sizes = [128, 256, 384] + total_tokens = sum(split_sizes) + static_m_splits = torch.tensor(split_sizes, dtype=torch.int64, device=device) + + grouped_linear = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device=device, + ) + + static_x = torch.randn(total_tokens, in_features, dtype=dtype, device=device) + static_x.requires_grad_(True) + static_dy = torch.randn(total_tokens, out_features, dtype=dtype, device=device) + static_out_buf = torch.empty(total_tokens, out_features, dtype=dtype, device=device) + + def _zero_grads(): + if static_x.grad is not None: + static_x.grad.zero_() + for param in grouped_linear.parameters(): + if param.grad is None: + param.grad = torch.zeros_like(param) + else: + param.grad.zero_() + + def _clone_param_grads(): + return [param.grad.detach().clone() for param in grouped_linear.parameters()] + + def _train_step(x, dy, out_buf, *, use_graphed): + with autocast(enabled=use_fp8, recipe=fp8_recipe): + out = ( + graphed_grouped_linear(x, static_m_splits) + if use_graphed + else grouped_linear(x, static_m_splits) + ) + out.backward(dy) + out_buf.copy_(out) + return out_buf + + graphed_grouped_linear = te.make_graphed_callables( + grouped_linear, + (static_x, static_m_splits), + num_warmup_iters=3, + enabled=use_fp8, + recipe=fp8_recipe, + ) + + fresh_x = torch.randn_like(static_x) + fresh_dy = torch.randn_like(static_dy) + with torch.no_grad(): + static_x.copy_(fresh_x) + static_dy.copy_(fresh_dy) + + _zero_grads() + graph_out = ( + _train_step( + static_x, + static_dy, + static_out_buf, + use_graphed=True, + ) + .detach() + .clone() + ) + torch.cuda.synchronize() + graph_dx = static_x.grad.detach().clone() + graph_param_grads = _clone_param_grads() + + _zero_grads() + expected_x = fresh_x.detach().clone().requires_grad_(True) + expected_dy = fresh_dy.detach().clone() + with autocast(enabled=use_fp8, recipe=fp8_recipe): + expected_out = grouped_linear(expected_x, static_m_splits) + expected_out.backward(expected_dy) + + tols = dict(rtol=1e-2, atol=5e-3) + if use_fp8: + tols = dict(rtol=0.05, atol=0.05) + torch.testing.assert_close(graph_out.float(), expected_out.float(), **tols) + torch.testing.assert_close(graph_dx.float(), expected_x.grad.float(), **tols) + for graph_grad, param in zip(graph_param_grads, grouped_linear.parameters()): + assert param.grad is not None + torch.testing.assert_close(graph_grad.float(), param.grad.float(), **tols) + + +# ============================================================================= +# Raw grouped GEMM kernels (cpp_extensions) +# +# Tests general_grouped_gemm and general_grouped_gemm_for_grouped_tensor +# directly. Reference: per-group general_gemm calls. +# ============================================================================= + + +@pytest.mark.parametrize( + "shape", + [ + (1, 127, 128, 512), + (8, 15, 128, 512), + (8, 1027, 128, 512), + (16, 10027, 128, 512), + ], +) +@pytest.mark.parametrize("dtype", param_types, ids=str) +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) +@pytest.mark.parametrize("accumulate", [False, True]) +@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm) +def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass, monkeypatch): + torch.manual_seed(0) + z, m, k, n = shape + + dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist() + m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) + assert m_splits.sum() == m and len(m_splits) == z + m_splits = m_splits.tolist() + + if layout == "TN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input + out = [torch.randn(m, n, dtype=dtype, device="cuda")] # output + out_ref = [o.clone() for o in torch.split(out[0], m_splits)] + grad = False + single_output = True + elif layout == "NN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = list( + torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) + ) # grad_output + out = [torch.randn(m, k, dtype=dtype, device="cuda")] # dgrad + out_ref = [o.clone() for o in torch.split(out[0], m_splits)] + grad = True + single_output = True + else: # layout == "NT" + A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input + B = list( + torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) + ) # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + out_ref = [o.clone() for o in out] + grad = True + single_output = False + + if use_cutlass: + monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1") + + for i in range(z): + general_gemm( + A[i], + B[i], + dtype, + grad=grad, + accumulate=accumulate, + layout=layout, + out=out_ref[i], + ) + if single_output: + out_ref = [torch.cat(out_ref)] + + general_grouped_gemm( + A, B, out, [None] * z, @@ -1500,446 +1575,153 @@ def test_grouped_gemm_grouped_tensor_mxfp8( grad = True else: # layout == "NT" A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input - B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output - out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad - grad = True - - out_ref = [o.clone() for o in out] - - transa = layout[0] == "T" - transb = layout[1] == "T" - a_is_weight = all(t.shape == A[0].shape for t in A) - a_rowwise, a_columnwise = transa, not transa - b_rowwise, b_columnwise = not transb, transb - grouped_A = _make_grouped_tensor_quantized_mxfp8( - A, - rowwise=a_rowwise, - columnwise=a_columnwise, - device="cuda", - is_weight=a_is_weight, - ) - grouped_B = _make_grouped_tensor_quantized_mxfp8( - B, rowwise=b_rowwise, columnwise=b_columnwise, device="cuda" - ) - A_fp8 = _per_tensor_quantize_mxfp8(A, rowwise=a_rowwise, columnwise=a_columnwise) - B_fp8 = _per_tensor_quantize_mxfp8(B, rowwise=b_rowwise, columnwise=b_columnwise) - - general_grouped_gemm( - A_fp8, - B_fp8, - out_ref, - [None] * z, - dtype, - m_splits=m_sizes, - grad=grad, - accumulate=accumulate, - layout=layout, - single_output=False, - ) - - device = A[0].device - - grouped_out = None - if case != "discrete_out": - if layout == "TN": - grouped_out = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) - elif layout == "NN": - grouped_out = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) - else: # layout == "NT" - grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) - _pack_grouped_tensor(grouped_out, out) - - grouped_out_input = out if case == "discrete_out" else grouped_out - grouped_A_input = A_fp8 if case == "discrete_in" else grouped_A - general_grouped_gemm_for_grouped_tensor( - grouped_A_input, - grouped_B, - grouped_out_input, - layout=layout, - accumulate=accumulate, - ) - - out_grouped = out if case == "discrete_out" else grouped_out.split_into_quantized_tensors() - tols = dict(rtol=0.125, atol=0.0675) # mxfp8 tolerance - - for o, o_ref in zip(out_grouped, out_ref): - torch.testing.assert_close(o, o_ref, **tols) - - -@pytest.mark.parametrize( - "shape", - [ - (1, 128, 128, 512), - (8, 1024, 128, 512), - (16, 4096, 128, 512), - ], -) -@pytest.mark.parametrize("accumulate", [False, True]) -def test_fp8_grouped_gemm(shape, accumulate): - if not fp8_available: - pytest.skip(reason_for_no_fp8) - - z, m, k, n = shape - m_splits = [m // z] * z - - dtype = torch.bfloat16 - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input - out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # output - out_ref = [o.clone() for o in out] - - # fp8 should be robust enough to this fake scale - scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze() - amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda") - - a_quantizers = [ - Float8Quantizer( - scale.clone(), - amax.clone(), - tex.DType.kFloat8E4M3, - ) - for _ in range(z) - ] - b_quantizers = [ - Float8Quantizer( - scale.clone(), - amax.clone(), - tex.DType.kFloat8E4M3, - ) - for _ in range(z) - ] - - A_fp8 = [] - B_fp8 = [] - - for i in range(z): - A_fp8.append(a_quantizers[i](A[i])) - B_fp8.append(b_quantizers[i](B[i])) - - # baseline - for i in range(z): - general_gemm( - A_fp8[i], - B_fp8[i], - dtype, - out=out_ref[i], - accumulate=accumulate, - ) - general_grouped_gemm( - A_fp8, - B_fp8, - out, - [None] * z, - dtype, - m_splits=m_splits, - accumulate=accumulate, - ) - - # should be bit-wise match - for o, o_ref in zip(out, out_ref): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - - -_FUSED_GROUPED_GEMM_ENV = "NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM" -_ALL_BOOLEAN = all_boolean -_mxfp8_available, _reason_for_no_mxfp8 = mxfp8_available, reason_for_no_mxfp8 - - -@pytest.fixture(autouse=True) -def _reset_fp8_state(monkeypatch): - monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "0") - yield - FP8GlobalStateManager.reset() - monkeypatch.delenv(_FUSED_GROUPED_GEMM_ENV, raising=False) - - -def _clone_outputs(outputs): - return [None if out is None else out.detach().clone() for out in outputs] - - -def _run_grouped_linear_path( - *, - enable_grouped_tensor_path: bool, - fp8_recipe, - bias: bool, - fp8_model_params: bool, - delay_wgrad_compute: bool, - x_base: torch.Tensor, - dy: torch.Tensor, - weights, - biases, - m_splits, - monkeypatch, -): - FP8GlobalStateManager.reset() - monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1" if enable_grouped_tensor_path else "0") - - dtype = x_base.dtype - num_gemms = len(m_splits) - in_features = weights[0].size(1) - out_features = weights[0].size(0) - use_fp8 = fp8_recipe is not None - - x = x_base.detach().clone().requires_grad_(True) - with quantized_model_init(enabled=fp8_model_params, recipe=fp8_recipe): - grouped_linear = GroupedLinear( - num_gemms, - in_features, - out_features, - bias=bias, - params_dtype=dtype, - device="cuda", - delay_wgrad_compute=delay_wgrad_compute, - ) - with torch.no_grad(): - for i in range(num_gemms): - getattr(grouped_linear, f"weight{i}").copy_(weights[i]) - if bias: - getattr(grouped_linear, f"bias{i}").copy_(biases[i]) - - # The fused path is the graph-safe path and accepts a CUDA tensor for split metadata. - # The legacy path still expects Python split sections in several places. - m_splits_arg = ( - torch.tensor(m_splits, dtype=torch.int64, device="cuda") - if enable_grouped_tensor_path - else m_splits - ) - with autocast(enabled=use_fp8, recipe=fp8_recipe): - y = grouped_linear(x, m_splits_arg) - y.backward(dy) - if delay_wgrad_compute: - grouped_linear.backward_dw() - - outputs = [y, x.grad] - for i in range(num_gemms): - outputs.append(getattr(grouped_linear, f"weight{i}").grad) - if bias: - outputs.append(getattr(grouped_linear, f"bias{i}").grad) - return _clone_outputs(outputs) - - -@pytest.mark.parametrize( - "fp8_recipe", - [ - None, - pytest.param( - recipe.MXFP8BlockScaling(), - marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8), - ), - ], - ids=["bf16", "mxfp8"], -) -@pytest.mark.parametrize("bias", _ALL_BOOLEAN) -@pytest.mark.parametrize("fp8_model_params", _ALL_BOOLEAN) -@pytest.mark.parametrize("delay_wgrad_compute", _ALL_BOOLEAN) -def test_grouped_linear_grouped_tensor_path_matches_legacy( - fp8_recipe, bias, fp8_model_params, delay_wgrad_compute, monkeypatch -): - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("GroupedTensor grouped GEMM path requires SM100+") - - use_fp8 = fp8_recipe is not None - if fp8_model_params and not use_fp8: - pytest.skip("fp8_model_params requires FP8") - - dtype = torch.bfloat16 - num_gemms = 3 - in_features = 64 - out_features = 64 - m_splits = [128, 256, 384] - total_tokens = sum(m_splits) - - torch.manual_seed(1234) - x_base = (0.1 * torch.randn(total_tokens, in_features, device="cuda")).to(dtype) - dy = (0.1 * torch.randn(total_tokens, out_features, device="cuda")).to(dtype) - weights = [ - (0.1 * torch.randn(out_features, in_features, device="cuda")).to(dtype) - for _ in range(num_gemms) - ] - biases = None - if bias: - biases = [ - (0.1 * torch.randn(out_features, device="cuda")).to(dtype) for _ in range(num_gemms) - ] - - outputs_legacy = _run_grouped_linear_path( - enable_grouped_tensor_path=False, - fp8_recipe=fp8_recipe, - bias=bias, - fp8_model_params=fp8_model_params, - delay_wgrad_compute=delay_wgrad_compute, - x_base=x_base, - dy=dy, - weights=weights, - biases=biases, - m_splits=m_splits, - monkeypatch=monkeypatch, - ) - outputs_grouped_tensor = _run_grouped_linear_path( - enable_grouped_tensor_path=True, - fp8_recipe=fp8_recipe, - bias=bias, - fp8_model_params=fp8_model_params, - delay_wgrad_compute=delay_wgrad_compute, - x_base=x_base, - dy=dy, - weights=weights, - biases=biases, - m_splits=m_splits, - monkeypatch=monkeypatch, - ) - - tols = dict(rtol=1e-2, atol=5e-3) - if use_fp8: - tols = dict(rtol=0.05, atol=0.05) - for grouped_tensor_out, legacy_out in zip(outputs_grouped_tensor, outputs_legacy): - assert grouped_tensor_out is not None - assert legacy_out is not None - torch.testing.assert_close(grouped_tensor_out.float(), legacy_out.float(), **tols) + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + grad = True + out_ref = [o.clone() for o in out] -def test_grouped_linear_grouped_tensor_path_single_grouped_bias_delay_wgrad(monkeypatch): - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("GroupedTensor grouped GEMM path requires SM100+") + transa = layout[0] == "T" + transb = layout[1] == "T" + a_is_weight = all(t.shape == A[0].shape for t in A) + a_rowwise, a_columnwise = transa, not transa + b_rowwise, b_columnwise = not transb, transb + grouped_A = _make_grouped_tensor_quantized_mxfp8( + A, + rowwise=a_rowwise, + columnwise=a_columnwise, + device="cuda", + is_weight=a_is_weight, + ) + grouped_B = _make_grouped_tensor_quantized_mxfp8( + B, rowwise=b_rowwise, columnwise=b_columnwise, device="cuda" + ) + A_fp8 = _per_tensor_quantize_mxfp8(A, rowwise=a_rowwise, columnwise=a_columnwise) + B_fp8 = _per_tensor_quantize_mxfp8(B, rowwise=b_rowwise, columnwise=b_columnwise) - monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1") + general_grouped_gemm( + A_fp8, + B_fp8, + out_ref, + [None] * z, + dtype, + m_splits=m_sizes, + grad=grad, + accumulate=accumulate, + layout=layout, + single_output=False, + ) - dtype = torch.bfloat16 - num_gemms = 3 - in_features = 64 - out_features = 64 - total_tokens = 64 + 96 + 128 - m_splits = torch.tensor([64, 96, 128], dtype=torch.int64, device="cuda") - x = torch.randn(total_tokens, in_features, dtype=dtype, device="cuda").requires_grad_() - dy = torch.randn(x.size(0), out_features, dtype=dtype, device="cuda") + device = A[0].device - grouped_linear = GroupedLinear( - num_gemms, - in_features, - out_features, - bias=True, - params_dtype=dtype, - device="cuda", - delay_wgrad_compute=True, - single_grouped_bias=True, + grouped_out = None + if case != "discrete_out": + if layout == "TN": + grouped_out = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + elif layout == "NN": + grouped_out = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + else: # layout == "NT" + grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) + _pack_grouped_tensor(grouped_out, out) + + grouped_out_input = out if case == "discrete_out" else grouped_out + grouped_A_input = A_fp8 if case == "discrete_in" else grouped_A + general_grouped_gemm_for_grouped_tensor( + grouped_A_input, + grouped_B, + grouped_out_input, + layout=layout, + accumulate=accumulate, ) - y = grouped_linear(x, m_splits) - y.backward(dy) - grouped_linear.backward_dw() + out_grouped = out if case == "discrete_out" else grouped_out.split_into_quantized_tensors() + tols = dict(rtol=0.125, atol=0.0675) # mxfp8 tolerance + + for o, o_ref in zip(out_grouped, out_ref): + torch.testing.assert_close(o, o_ref, **tols) @pytest.mark.parametrize( - "fp8_recipe", + "shape", [ - None, - pytest.param( - recipe.MXFP8BlockScaling(), - marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8), - ), + (1, 128, 128, 512), + (8, 1024, 128, 512), + (16, 4096, 128, 512), ], - ids=["bf16", "mxfp8"], ) -@pytest.mark.parametrize("bias", _ALL_BOOLEAN) -def test_grouped_linear_fused_path_cuda_graph_safe(fp8_recipe, bias, monkeypatch): - """Fused GroupedTensor GEMM path should be CUDA graph capturable.""" - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("GroupedTensor grouped GEMM path requires SM100+") +@pytest.mark.parametrize("accumulate", [False, True]) +def test_fp8_grouped_gemm(shape, accumulate): + if not fp8_available: + pytest.skip(reason_for_no_fp8) - monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1") - FP8GlobalStateManager.reset() + z, m, k, n = shape + m_splits = [m // z] * z - use_fp8 = fp8_recipe is not None dtype = torch.bfloat16 - device = "cuda" - num_gemms = 3 - in_features = 128 - out_features = 128 - split_sizes = [128, 256, 384] - total_tokens = sum(split_sizes) - static_m_splits = torch.tensor(split_sizes, dtype=torch.int64, device=device) - - grouped_linear = GroupedLinear( - num_gemms, - in_features, - out_features, - bias=bias, - params_dtype=dtype, - device=device, - ) - - static_x = torch.randn(total_tokens, in_features, dtype=dtype, device=device) - static_x.requires_grad_(True) - static_dy = torch.randn(total_tokens, out_features, dtype=dtype, device=device) - static_out_buf = torch.empty(total_tokens, out_features, dtype=dtype, device=device) - - def _zero_grads(): - if static_x.grad is not None: - static_x.grad.zero_() - for param in grouped_linear.parameters(): - if param.grad is None: - param.grad = torch.zeros_like(param) - else: - param.grad.zero_() + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input + out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # output + out_ref = [o.clone() for o in out] - def _clone_param_grads(): - return [param.grad.detach().clone() for param in grouped_linear.parameters()] + # fp8 should be robust enough to this fake scale + scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze() + amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda") - def _train_step(x, dy, out_buf, *, use_graphed): - with autocast(enabled=use_fp8, recipe=fp8_recipe): - out = ( - graphed_grouped_linear(x, static_m_splits) - if use_graphed - else grouped_linear(x, static_m_splits) - ) - out.backward(dy) - out_buf.copy_(out) - return out_buf + a_quantizers = [ + Float8Quantizer( + scale.clone(), + amax.clone(), + tex.DType.kFloat8E4M3, + ) + for _ in range(z) + ] + b_quantizers = [ + Float8Quantizer( + scale.clone(), + amax.clone(), + tex.DType.kFloat8E4M3, + ) + for _ in range(z) + ] - graphed_grouped_linear = te.make_graphed_callables( - grouped_linear, - (static_x, static_m_splits), - num_warmup_iters=3, - enabled=use_fp8, - recipe=fp8_recipe, - ) + A_fp8 = [] + B_fp8 = [] - fresh_x = torch.randn_like(static_x) - fresh_dy = torch.randn_like(static_dy) - with torch.no_grad(): - static_x.copy_(fresh_x) - static_dy.copy_(fresh_dy) + for i in range(z): + A_fp8.append(a_quantizers[i](A[i])) + B_fp8.append(b_quantizers[i](B[i])) - _zero_grads() - graph_out = ( - _train_step( - static_x, - static_dy, - static_out_buf, - use_graphed=True, + # baseline + for i in range(z): + general_gemm( + A_fp8[i], + B_fp8[i], + dtype, + out=out_ref[i], + accumulate=accumulate, ) - .detach() - .clone() + general_grouped_gemm( + A_fp8, + B_fp8, + out, + [None] * z, + dtype, + m_splits=m_splits, + accumulate=accumulate, ) - torch.cuda.synchronize() - graph_dx = static_x.grad.detach().clone() - graph_param_grads = _clone_param_grads() - _zero_grads() - expected_x = fresh_x.detach().clone().requires_grad_(True) - expected_dy = fresh_dy.detach().clone() - with autocast(enabled=use_fp8, recipe=fp8_recipe): - expected_out = grouped_linear(expected_x, static_m_splits) - expected_out.backward(expected_dy) + # should be bit-wise match + for o, o_ref in zip(out, out_ref): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - tols = dict(rtol=1e-2, atol=5e-3) - if use_fp8: - tols = dict(rtol=0.05, atol=0.05) - torch.testing.assert_close(graph_out.float(), expected_out.float(), **tols) - torch.testing.assert_close(graph_dx.float(), expected_x.grad.float(), **tols) - for graph_grad, param in zip(graph_param_grads, grouped_linear.parameters()): - assert param.grad is not None - torch.testing.assert_close(graph_grad.float(), param.grad.float(), **tols) + +# ============================================================================= +# te.ops.GroupedLinear (ops/fuser API) +# +# Tests the ops-API GroupedLinear and grouped MLP patterns. Reference: PyTorch +# F.linear per group via make_reference_and_test_tensors (float64 CPU). +# Tolerance: dtype_tols() / quantization_tols(). +# ============================================================================= @pytest.mark.parametrize("swizzle_type", ["mxfp8_rowwise", "mxfp8_columnwise", "nvfp4"]) @@ -2516,37 +2298,68 @@ def test_grouped_mlp( in_shape = (split_sizes.sum().item(), hidden_size) out_shape = in_shape - # Random data: float32 for reference path, target dtype for test path - x_base = torch.empty(in_shape, device=device).uniform_(-0.25, 0.25) - x_ref = x_base.clone().requires_grad_() - x_test = x_base.to(dtype).requires_grad_() - - dy_base = torch.empty(out_shape, device=device).uniform_(-0.25, 0.25) - dy_ref = dy_base.clone() - dy_test = dy_base.to(dtype) - - probs_base = torch.empty((in_shape[0],), device=device).uniform_(0.1, 1.0) - probs_ref = probs_base.clone().requires_grad_() - probs_test = probs_base.to(dtype).requires_grad_() + # Reference tensors: float64 CPU; test tensors: target dtype on CUDA + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + min=-0.25, max=0.25, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + min=-0.25, max=0.25, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + probs_ref, probs_test = make_reference_and_test_tensors( + (in_shape[0],), + min=0.1, max=1.0, + test_dtype=dtype, + test_device=device, + ) fc1_ws_ref, fc1_ws_test = [], [] fc1_bs_ref, fc1_bs_test = [], [] fc2_ws_ref, fc2_ws_test = [], [] fc2_bs_ref, fc2_bs_test = [], [] for _ in range(group_size): - w1 = torch.empty((fc1_out_features, hidden_size), device=device).uniform_(-0.25, 0.25) - fc1_ws_ref.append(w1.clone().requires_grad_()) - fc1_ws_test.append(w1.to(dtype)) - w2 = torch.empty((hidden_size, hidden_size), device=device).uniform_(-0.25, 0.25) - fc2_ws_ref.append(w2.clone().requires_grad_()) - fc2_ws_test.append(w2.to(dtype)) + w1_ref, w1_test = make_reference_and_test_tensors( + (fc1_out_features, hidden_size), + min=-0.25, max=0.25, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + fc1_ws_ref.append(w1_ref) + fc1_ws_test.append(w1_test) + w2_ref, w2_test = make_reference_and_test_tensors( + (hidden_size, hidden_size), + min=-0.25, max=0.25, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + fc2_ws_ref.append(w2_ref) + fc2_ws_test.append(w2_test) if bias: - b1 = torch.empty((fc1_out_features,), device=device).uniform_(-0.5, 0.5) - fc1_bs_ref.append(b1.clone().requires_grad_()) - fc1_bs_test.append(b1.to(dtype)) - b2 = torch.empty((hidden_size,), device=device).uniform_(-0.5, 0.5) - fc2_bs_ref.append(b2.clone().requires_grad_()) - fc2_bs_test.append(b2.to(dtype)) + b1_ref, b1_test = make_reference_and_test_tensors( + (fc1_out_features,), + min=-0.5, max=0.5, + test_dtype=dtype, + test_device=device, + ) + fc1_bs_ref.append(b1_ref) + fc1_bs_test.append(b1_test) + b2_ref, b2_test = make_reference_and_test_tensors( + (hidden_size,), + min=-0.5, max=0.5, + test_dtype=dtype, + test_device=device, + ) + fc2_bs_ref.append(b2_ref) + fc2_bs_test.append(b2_test) else: fc1_bs_ref.append(None) fc1_bs_test.append(None) @@ -2570,7 +2383,7 @@ def _apply_activation(x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.relu(x).square() raise ValueError(f"Unexpected activation ({activation})") - # Reference implementation (float32 PyTorch) + # Reference implementation (float64 CPU PyTorch) xs = torch.split(x_ref, split_sizes.tolist()) probs = torch.split(probs_ref, split_sizes.tolist()) ys = [] @@ -2713,8 +2526,6 @@ def _make_scaled_act(): # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} - if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): - tols = {"rtol": 0.25, "atol": 0.5} # Check values assert_close(y_test, y_ref, **tols) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index f068f3581e..d76fa6783a 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -18,8 +18,16 @@ import transformer_engine from transformer_engine.common.recipe import Recipe +import math from transformer_engine.pytorch import InferenceParams, QuantizedTensor from transformer_engine.pytorch import DType +from transformer_engine.pytorch import ( + Float8CurrentScalingQuantizer, + Float8Quantizer, + MXFP8Quantizer, + NVFP4Quantizer, + QuantizerRole, +) from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends from transformer_engine.pytorch.attention.dot_product_attention.utils import ( get_attention_backend, @@ -31,6 +39,11 @@ from transformer_engine.pytorch.module.base import get_dummy_wgrad +fp8_available, reason_for_no_fp8 = transformer_engine.pytorch.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = transformer_engine.pytorch.is_mxfp8_available(return_reason=True) +nvfp4_available, reason_for_no_nvfp4 = transformer_engine.pytorch.is_nvfp4_available(return_reason=True) + + def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: """Convert type name to PyTorch dtype""" if isinstance(dtype, torch.dtype): @@ -490,6 +503,159 @@ def assert_close_grads( assert_close(actual.grad, expected.grad, **kwargs) +def maybe_skip_quantization( + quantization: Optional[str], + *, + dims: Optional[Iterable[int] | int] = None, + device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, +) -> None: + """Skip test case if a quantization scheme is not supported""" + + # Check if dtype is supported + if dtype == torch.bfloat16 and not is_bf16_available(): + pytest.skip("BF16 requires SM 8.0+") + + # No quantization scheme + if quantization is None: + return + + # Check if quantization scheme is supported on device + if device is not None and torch.device(device).type != "cuda": + pytest.skip("Quantization is only supported on CUDA devices") + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: + pytest.skip(reason_for_no_fp8) + if quantization == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if ( + quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") + and not nvfp4_available + ): + pytest.skip(reason_for_no_nvfp4) + + # Check dims + if dims is not None: + if not isinstance(dims, Iterable): + dims = (dims,) + if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling"): + if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: + pytest.skip("FP8 GEMMs require dims that are divisible by 16") + elif quantization == "mxfp8": + if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: + pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): + if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: + pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") + + # Check dtype + if dtype is not None: + if ( + quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") + and dtype != torch.bfloat16 + ): + pytest.skip("NVFP4 quantization is only supported with BF16 data") + + +@torch.no_grad() +def make_reference_and_test_tensors( + shape: int | Iterable[int], + *, + min: float = 0.0, + max: float = 1.0, + quantization: Optional[str] = None, + ref_dtype: torch.dtype = torch.float64, + ref_device: torch.device = "cpu", + test_dtype: torch.dtype = torch.float32, + test_device: torch.device = "cuda", + test_is_quantized: bool = False, + quantizer_role: Optional[QuantizerRole] = None, + requires_grad: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """Construct tensors with the same values + + The reference tensor is intended for use in plain PyTorch + operations in high precision. The test tensor is intended for use + in Transformer Engine operations. + + If a quantization scheme is provided, the tensor values are + quantized so that they are representable. + + """ + + # Random reference tensor + ref = torch.empty(shape, dtype=ref_dtype, device=ref_device) + ref.uniform_(min, max) + + # Construct test tensor from reference tensor + test = ref.to(device=test_device, dtype=test_dtype) + if quantization is None: + if test_is_quantized: + raise ValueError("Quantization scheme not provided") + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + elif quantization in ("fp8", "fp8_delayed_scaling"): + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), + amax=torch.zeros(1, dtype=torch.float32, device=test_device), + fp8_dtype=DType.kFloat8E4M3, + ) + test = quantizer(test) + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=DType.kFloat8E4M3, + device=test_device, + ) + test = quantizer(test) + elif quantization == "mxfp8": + test = MXFP8Quantizer(fp8_dtype=DType.kFloat8E4M3)(test) + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_rht"): + tensor_type = "input" + if quantizer_role is not None: + tensor_type = quantizer_role.tensor_type + with_rht = quantization == "nvfp4_rht" and tensor_type != "weight" + test = NVFP4Quantizer( + with_rht=with_rht, + with_post_rht_amax=with_rht, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + )(test) + elif quantization == "nvfp4_4over6": + tensor_type = "input" + if quantizer_role is not None: + tensor_type = quantizer_role.tensor_type + + nvfp4_use_4over6 = False + with_2d_quantization = False + nvfp4_e4m3_max = 448 + if tensor_type not in ("grad_output", "grad_input"): + nvfp4_use_4over6 = True + nvfp4_e4m3_max = 256 + if tensor_type == "weight": + with_2d_quantization = True + + test = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=with_2d_quantization, + stochastic_rounding=False, + with_random_sign_mask=False, + nvfp4_use_4over6=nvfp4_use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + )(test) + else: + raise ValueError(f"Unsupported quantization scheme ({quantization})") + if isinstance(test, QuantizedTensor) and not test_is_quantized: + test = test.dequantize() + + # Make sure reference and test tensors match each other + ref.copy_(test.to(dtype=ref.dtype)) + + ref.requires_grad_(requires_grad) + test.requires_grad_(requires_grad) + return ref, test + + def run_distributed( args: Sequence[str], *, From 5f62abcd519a8e2bece4c059d9adc6d6f7e6a52e Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 10 Jun 2026 00:23:54 +0000 Subject: [PATCH 06/11] [PyTorch] Reorganize test_grouped_linear.py into test classes Group the four logical sections into test classes so structure is enforced by Python rather than relying on comments: - TestGroupedLinearModule: te.GroupedLinear module API tests - TestGroupedGemm: raw grouped GEMM kernel (cpp_extensions) tests - TestOpsGroupedLinear: te.ops.GroupedLinear tests - TestGroupedMLP: grouped MLP pattern tests Each class carries autouse fixtures for the environment variables it needs (NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM, NVTE_GROUPED_LINEAR_SINGLE_PARAM, NVTE_CUTEDSL_FUSED_GROUPED_MLP), replacing runtime os.environ skip checks. Also fix three coverage/correctness issues flagged in review: - _ALL_BOOLEAN / _mxfp8_available were aliases defined after the class that used them in decorators, causing NameError at collection time; replaced with the originals (all_boolean, mxfp8_available, etc.) - test_grouped_mlp was missing hidden_size=256 from parametrization - test_grouped_mlp weight tensors were missing quantizer_role="weight", which matters for NVFP4 RHT quantization behavior Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Tim Moon --- tests/pytorch/test_grouped_linear.py | 5438 +++++++++++++------------- tests/pytorch/utils.py | 6 +- 2 files changed, 2733 insertions(+), 2711 deletions(-) diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index 2cb0dcc420..8cece276c5 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -4,7 +4,7 @@ import os import random -from typing import Dict, List, Optional, Sequence +from typing import List, Optional, Sequence import pytest import torch @@ -57,7 +57,7 @@ skip_unsupported_backward_override, ) -# Only run FP8 tests on supported devices. +# Check supported quantization schemes fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) @@ -285,269 +285,6 @@ def _test_grouped_linear_accuracy( return outputs -# ============================================================================= -# te.GroupedLinear (module API) -# -# Tests the high-level GroupedLinear module. Reference: sequential te.Linear -# modules with shared weights — bitwise match verifies grouping correctness. -# ============================================================================= - - -@pytest.mark.parametrize("dtype", param_types, ids=str) -@pytest.mark.parametrize("num_gemms", [1, 3, 6]) -@pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) -@pytest.mark.parametrize("fp8_model_params", all_boolean) -@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) -@pytest.mark.parametrize("bias", all_boolean) -@pytest.mark.parametrize("delay_wgrad_compute", all_boolean) -def test_grouped_linear_accuracy( - dtype, - num_gemms, - bs, - model, - recipe, - fp8_model_params, - fuse_wgrad_accumulation, - bias, - delay_wgrad_compute, - parallel_mode=None, - use_cutlass=False, -): - fp8 = recipe is not None - if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: - pytest.skip("FP8 parameters are not supported in debug mode.") - if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: - pytest.skip("Delayed wgrad compute is not supported in debug mode.") - skip_unsupported_backward_override( - "grouped_linear", recipe, getattr(recipe, "backward_override", None) - ) - - config = model_configs[model] - if config.max_seqlen_q % 16 != 0 and fp8: - pytest.skip("FP8 requires sequence length to be divisible by 16.") - - if recipe is not None and recipe.nvfp4(): - if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): - pytest.skip( - f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" - ) - - with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): - grouped_linear = GroupedLinear( - num_gemms, - config.hidden_size, - 4 * config.hidden_size, - bias=bias, - params_dtype=dtype, - parallel_mode=parallel_mode, - device="cuda", - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - delay_wgrad_compute=delay_wgrad_compute, - save_original_input=False, - ).eval() - sequential_linear = torch.nn.ModuleList( - [ - Linear( - config.hidden_size, - 4 * config.hidden_size, - bias=bias, - params_dtype=dtype, - parallel_mode=parallel_mode, - device="cuda", - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - ).eval() - for _ in range(num_gemms) - ] - ) - - # Share params - with torch.no_grad(): - for i in range(num_gemms): - sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) - if bias: - sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) - if fuse_wgrad_accumulation: - weight_i = getattr(grouped_linear, f"weight{i}") - weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) - sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() - - outputs_ref = _test_grouped_linear_accuracy( - sequential_linear, - num_gemms, - bs, - dtype, - config, - recipe, - fp8, - fuse_wgrad_accumulation, - delay_wgrad_compute, - ) - outputs = _test_grouped_linear_accuracy( - grouped_linear, - num_gemms, - bs, - dtype, - config, - recipe, - fp8, - fuse_wgrad_accumulation, - delay_wgrad_compute, - ) - - for o, o_ref in zip(outputs, outputs_ref): - if use_cutlass: - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) - else: - # cuBLAS implementation should be bit-wise match - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - - -@pytest.mark.skipif( - torch.cuda.get_device_capability() != (9, 0), - reason="Only enable CUTLASS grouped gemm on Hopper", -) -@pytest.mark.parametrize("dtype", param_types, ids=str) -@pytest.mark.parametrize("num_gemms", [3, 6]) -@pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) -@pytest.mark.parametrize("delay_wgrad_compute", all_boolean) -def test_grouped_linear_accuracy_cutlass( - dtype, - num_gemms, - bs, - model, - fuse_wgrad_accumulation, - delay_wgrad_compute, - monkeypatch, -): - monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1") - test_grouped_linear_accuracy( - dtype, - num_gemms, - bs, - model, - None, - False, - fuse_wgrad_accumulation, - False, - delay_wgrad_compute, - None, - use_cutlass=True, - ) - - -@pytest.mark.parametrize("dtype", param_types, ids=str) -@pytest.mark.parametrize("num_gemms", [3]) -@pytest.mark.parametrize("bs", [1]) -@pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) -@pytest.mark.parametrize("fp8_model_params", [False]) -@pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) -@pytest.mark.parametrize("bias", [False]) -@pytest.mark.parametrize("delay_wgrad_compute", [True]) -def test_grouped_linear_accuracy_save_original_input( - dtype, - num_gemms, - bs, - model, - recipe, - fp8_model_params, - fuse_wgrad_accumulation, - bias, - delay_wgrad_compute, - parallel_mode=None, -): - fp8 = recipe is not None - if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: - pytest.skip("FP8 parameters are not supported in debug mode.") - if fp8 and recipe.delayed(): - pytest.skip("DelayedScaling recipe is not supported with save_original_input") - if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: - pytest.skip("Delayed wgrad compute is not supported in debug mode.") - skip_unsupported_backward_override( - "grouped_linear", recipe, getattr(recipe, "backward_override", None) - ) - - config = model_configs[model] - if config.max_seqlen_q % 16 != 0 and fp8: - pytest.skip("FP8 requires sequence length to be divisible by 16.") - - if recipe is not None and recipe.nvfp4(): - if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): - pytest.skip( - f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" - ) - - with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): - grouped_linear = GroupedLinear( - num_gemms, - config.hidden_size, - 4 * config.hidden_size, - bias=bias, - params_dtype=dtype, - parallel_mode=parallel_mode, - device="cuda", - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - delay_wgrad_compute=delay_wgrad_compute, - save_original_input=True, - ).eval() - sequential_linear = torch.nn.ModuleList( - [ - Linear( - config.hidden_size, - 4 * config.hidden_size, - bias=bias, - params_dtype=dtype, - parallel_mode=parallel_mode, - device="cuda", - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - ).eval() - for _ in range(num_gemms) - ] - ) - - # Share params - with torch.no_grad(): - for i in range(num_gemms): - sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) - if bias: - sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) - if fuse_wgrad_accumulation: - weight_i = getattr(grouped_linear, f"weight{i}") - weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) - sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() - - outputs_ref = _test_grouped_linear_accuracy( - sequential_linear, - num_gemms, - bs, - dtype, - config, - recipe, - fp8, - fuse_wgrad_accumulation, - delay_wgrad_compute, - ) - outputs = _test_grouped_linear_accuracy( - grouped_linear, - num_gemms, - bs, - dtype, - config, - recipe, - fp8, - fuse_wgrad_accumulation, - delay_wgrad_compute, - ) - - # Should be bit-wise match - for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - - def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False): def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): @@ -641,2018 +378,2081 @@ def _generate_random_numbers(n, total_sum): return outputs -@pytest.mark.parametrize("save_original_input", [False, True]) -@pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("num_gemms", [3, 6]) -@pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("fp8", [True]) -@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) -@pytest.mark.parametrize("fp8_model_params", all_boolean) -def test_padding_grouped_linear_accuracy( - dtype, - num_gemms, - bs, - model, - fp8, - recipe, - fp8_model_params, - save_original_input, - parallel_mode=None, -): - if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: - pytest.skip("FP8 parameters are not supported in debug mode.") - if save_original_input and recipe.delayed(): - pytest.skip("DelayedScaling recipe is not supported with save_original_input") - skip_unsupported_backward_override( - "grouped_linear", recipe, getattr(recipe, "backward_override", None) - ) - - config = model_configs[model] - if config.max_seqlen_q % 16 != 0 and fp8: - pytest.skip("FP8 requires sequence length to be divisible by 16.") - - if recipe is not None and recipe.nvfp4(): - if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): - pytest.skip( - f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" - ) +_FUSED_GROUPED_GEMM_ENV = "NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM" - with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): - grouped_linear = TorchGroupedLinearWithPadding( - num_gemms, - config.hidden_size, - 4 * config.hidden_size, - bias=False, - params_dtype=dtype, - parallel_mode=parallel_mode, - fp8=fp8, - ).eval() - with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): - ref_grouped_linear = GroupedLinear( - num_gemms, - config.hidden_size, - 4 * config.hidden_size, - bias=False, - params_dtype=dtype, - parallel_mode=parallel_mode, - device="cuda", - save_original_input=save_original_input, - ).eval() +class TestGroupedLinearModule: + """Tests for te.GroupedLinear module API. - # Share params - with torch.no_grad(): - inner_grouped_linear = grouped_linear.linear_fn - for i in range(num_gemms): - setattr( - ref_grouped_linear, - f"weight{i}", - Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), - ) + Reference: sequential te.Linear modules with shared weights. + """ - outputs = _test_padding_grouped_linear_accuracy( - grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 - ) - outputs_ref = _test_padding_grouped_linear_accuracy( - ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 - ) + @pytest.fixture(autouse=True) + def _set_fused_grouped_gemm_env(self, monkeypatch): + monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "0") + yield + monkeypatch.delenv(_FUSED_GROUPED_GEMM_ENV, raising=False) + + + @pytest.mark.parametrize("dtype", param_types, ids=str) + @pytest.mark.parametrize("num_gemms", [1, 3, 6]) + @pytest.mark.parametrize("bs", batch_sizes) + @pytest.mark.parametrize("model", ["126m"]) + @pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) + @pytest.mark.parametrize("fp8_model_params", all_boolean) + @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) + @pytest.mark.parametrize("bias", all_boolean) + @pytest.mark.parametrize("delay_wgrad_compute", all_boolean) + def test_grouped_linear_accuracy( + self, + dtype, + num_gemms, + bs, + model, + recipe, + fp8_model_params, + fuse_wgrad_accumulation, + bias, + delay_wgrad_compute, + parallel_mode=None, + use_cutlass=False, + ): + fp8 = recipe is not None + if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") + if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: + pytest.skip("Delayed wgrad compute is not supported in debug mode.") + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) - # Should be bit-wise match - for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + config = model_configs[model] + if config.max_seqlen_q % 16 != 0 and fp8: + pytest.skip("FP8 requires sequence length to be divisible by 16.") + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) -_FUSED_GROUPED_GEMM_ENV = "NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM" -_ALL_BOOLEAN = all_boolean -_mxfp8_available, _reason_for_no_mxfp8 = mxfp8_available, reason_for_no_mxfp8 + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + grouped_linear = GroupedLinear( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + delay_wgrad_compute=delay_wgrad_compute, + save_original_input=False, + ).eval() + sequential_linear = torch.nn.ModuleList( + [ + Linear( + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + for _ in range(num_gemms) + ] + ) + # Share params + with torch.no_grad(): + for i in range(num_gemms): + sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) + if bias: + sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + if fuse_wgrad_accumulation: + weight_i = getattr(grouped_linear, f"weight{i}") + weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) + sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() + + outputs_ref = _test_grouped_linear_accuracy( + sequential_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, + ) + outputs = _test_grouped_linear_accuracy( + grouped_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, + ) -@pytest.fixture(autouse=True) -def _reset_fp8_state(monkeypatch): - monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "0") - yield - FP8GlobalStateManager.reset() - monkeypatch.delenv(_FUSED_GROUPED_GEMM_ENV, raising=False) + for o, o_ref in zip(outputs, outputs_ref): + if use_cutlass: + torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + else: + # cuBLAS implementation should be bit-wise match + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + + @pytest.mark.skipif( + torch.cuda.get_device_capability() != (9, 0), + reason="Only enable CUTLASS grouped gemm on Hopper", + ) + @pytest.mark.parametrize("dtype", param_types, ids=str) + @pytest.mark.parametrize("num_gemms", [3, 6]) + @pytest.mark.parametrize("bs", batch_sizes) + @pytest.mark.parametrize("model", ["126m"]) + @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) + @pytest.mark.parametrize("delay_wgrad_compute", all_boolean) + def test_grouped_linear_accuracy_cutlass( + self, + dtype, + num_gemms, + bs, + model, + fuse_wgrad_accumulation, + delay_wgrad_compute, + monkeypatch, + ): + monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1") + self.test_grouped_linear_accuracy( + dtype, + num_gemms, + bs, + model, + None, + False, + fuse_wgrad_accumulation, + False, + delay_wgrad_compute, + None, + use_cutlass=True, + ) -def _clone_outputs(outputs): - return [None if out is None else out.detach().clone() for out in outputs] + @pytest.mark.parametrize("dtype", param_types, ids=str) + @pytest.mark.parametrize("num_gemms", [3]) + @pytest.mark.parametrize("bs", [1]) + @pytest.mark.parametrize("model", ["126m"]) + @pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) + @pytest.mark.parametrize("fp8_model_params", [False]) + @pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) + @pytest.mark.parametrize("bias", [False]) + @pytest.mark.parametrize("delay_wgrad_compute", [True]) + def test_grouped_linear_accuracy_save_original_input( + self, + dtype, + num_gemms, + bs, + model, + recipe, + fp8_model_params, + fuse_wgrad_accumulation, + bias, + delay_wgrad_compute, + parallel_mode=None, + ): + fp8 = recipe is not None + if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 and recipe.delayed(): + pytest.skip("DelayedScaling recipe is not supported with save_original_input") + if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: + pytest.skip("Delayed wgrad compute is not supported in debug mode.") + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) + config = model_configs[model] + if config.max_seqlen_q % 16 != 0 and fp8: + pytest.skip("FP8 requires sequence length to be divisible by 16.") -def _run_grouped_linear_path( - *, - enable_grouped_tensor_path: bool, - fp8_recipe, - bias: bool, - fp8_model_params: bool, - delay_wgrad_compute: bool, - single_grouped_bias: bool = False, - x_base: torch.Tensor, - dy: torch.Tensor, - weights, - biases, - m_splits, - monkeypatch, -): - FP8GlobalStateManager.reset() - monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1" if enable_grouped_tensor_path else "0") + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) - dtype = x_base.dtype - num_gemms = len(m_splits) - in_features = weights[0].size(1) - out_features = weights[0].size(0) - use_fp8 = fp8_recipe is not None + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + grouped_linear = GroupedLinear( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + delay_wgrad_compute=delay_wgrad_compute, + save_original_input=True, + ).eval() + sequential_linear = torch.nn.ModuleList( + [ + Linear( + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + for _ in range(num_gemms) + ] + ) - x = x_base.detach().clone().requires_grad_(True) - with quantized_model_init(enabled=fp8_model_params, recipe=fp8_recipe): - grouped_linear = GroupedLinear( + # Share params + with torch.no_grad(): + for i in range(num_gemms): + sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) + if bias: + sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + if fuse_wgrad_accumulation: + weight_i = getattr(grouped_linear, f"weight{i}") + weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) + sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() + + outputs_ref = _test_grouped_linear_accuracy( + sequential_linear, num_gemms, - in_features, - out_features, - bias=bias, - params_dtype=dtype, - device="cuda", - delay_wgrad_compute=delay_wgrad_compute, - single_grouped_bias=single_grouped_bias, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, + ) + outputs = _test_grouped_linear_accuracy( + grouped_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, ) - with torch.no_grad(): - for i in range(num_gemms): - getattr(grouped_linear, f"weight{i}").copy_(weights[i]) - if bias: - getattr(grouped_linear, f"bias{i}").copy_(biases[i]) - - # The fused path is the graph-safe path and accepts a CUDA tensor for split metadata. - # The legacy path still expects Python split sections in several places. - m_splits_arg = ( - torch.tensor(m_splits, dtype=torch.int64, device="cuda") - if enable_grouped_tensor_path - else m_splits - ) - with autocast(enabled=use_fp8, recipe=fp8_recipe): - y = grouped_linear(x, m_splits_arg) - y.backward(dy) - if delay_wgrad_compute: - grouped_linear.backward_dw() - - outputs = [y, x.grad] - for i in range(num_gemms): - outputs.append(getattr(grouped_linear, f"weight{i}").grad) - if bias: - outputs.append(getattr(grouped_linear, f"bias{i}").grad) - return _clone_outputs(outputs) - - -@pytest.mark.parametrize( - "fp8_recipe", - [ - None, - pytest.param( - recipe.MXFP8BlockScaling(), - marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8), - ), - ], - ids=["bf16", "mxfp8"], -) -@pytest.mark.parametrize("single_grouped_bias", _ALL_BOOLEAN) -@pytest.mark.parametrize("bias", _ALL_BOOLEAN) -@pytest.mark.parametrize("fp8_model_params", _ALL_BOOLEAN) -@pytest.mark.parametrize("delay_wgrad_compute", _ALL_BOOLEAN) -def test_grouped_linear_grouped_tensor_path_matches_legacy( - fp8_recipe, bias, fp8_model_params, delay_wgrad_compute, single_grouped_bias, monkeypatch -): - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("GroupedTensor grouped GEMM path requires SM100+") - - use_fp8 = fp8_recipe is not None - if fp8_model_params and not use_fp8: - pytest.skip("fp8_model_params requires FP8") - if single_grouped_bias and not bias: - pytest.skip("single_grouped_bias requires bias=True") - - dtype = torch.bfloat16 - num_gemms = 3 - in_features = 64 - out_features = 64 - m_splits = [128, 256, 384] - total_tokens = sum(m_splits) - - torch.manual_seed(1234) - x_base = (0.1 * torch.randn(total_tokens, in_features, device="cuda")).to(dtype) - dy = (0.1 * torch.randn(total_tokens, out_features, device="cuda")).to(dtype) - weights = [ - (0.1 * torch.randn(out_features, in_features, device="cuda")).to(dtype) - for _ in range(num_gemms) - ] - biases = None - if bias: - biases = [ - (0.1 * torch.randn(out_features, device="cuda")).to(dtype) for _ in range(num_gemms) - ] - - outputs_legacy = _run_grouped_linear_path( - enable_grouped_tensor_path=False, - fp8_recipe=fp8_recipe, - bias=bias, - fp8_model_params=fp8_model_params, - delay_wgrad_compute=delay_wgrad_compute, - single_grouped_bias=single_grouped_bias, - x_base=x_base, - dy=dy, - weights=weights, - biases=biases, - m_splits=m_splits, - monkeypatch=monkeypatch, - ) - outputs_grouped_tensor = _run_grouped_linear_path( - enable_grouped_tensor_path=True, - fp8_recipe=fp8_recipe, - bias=bias, - fp8_model_params=fp8_model_params, - delay_wgrad_compute=delay_wgrad_compute, - single_grouped_bias=single_grouped_bias, - x_base=x_base, - dy=dy, - weights=weights, - biases=biases, - m_splits=m_splits, - monkeypatch=monkeypatch, - ) - tols = dict(rtol=1e-2, atol=5e-3) - if use_fp8: - tols = dict(rtol=0.05, atol=0.05) - for grouped_tensor_out, legacy_out in zip(outputs_grouped_tensor, outputs_legacy): - assert grouped_tensor_out is not None - assert legacy_out is not None - torch.testing.assert_close(grouped_tensor_out.float(), legacy_out.float(), **tols) - - -@pytest.mark.parametrize( - "fp8_recipe", - [ - None, - pytest.param( - recipe.MXFP8BlockScaling(), - marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8), - ), - ], - ids=["bf16", "mxfp8"], -) -@pytest.mark.parametrize("bias", _ALL_BOOLEAN) -def test_grouped_linear_fused_path_cuda_graph_safe(fp8_recipe, bias, monkeypatch): - """Fused GroupedTensor GEMM path should be CUDA graph capturable.""" - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("GroupedTensor grouped GEMM path requires SM100+") + # Should be bit-wise match + for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1") - FP8GlobalStateManager.reset() - use_fp8 = fp8_recipe is not None - dtype = torch.bfloat16 - device = "cuda" - num_gemms = 3 - in_features = 128 - out_features = 128 - split_sizes = [128, 256, 384] - total_tokens = sum(split_sizes) - static_m_splits = torch.tensor(split_sizes, dtype=torch.int64, device=device) - - grouped_linear = GroupedLinear( + @pytest.mark.parametrize("save_original_input", [False, True]) + @pytest.mark.parametrize("dtype", param_types) + @pytest.mark.parametrize("num_gemms", [3, 6]) + @pytest.mark.parametrize("bs", batch_sizes) + @pytest.mark.parametrize("model", ["126m"]) + @pytest.mark.parametrize("fp8", [True]) + @pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) + @pytest.mark.parametrize("fp8_model_params", all_boolean) + def test_padding_grouped_linear_accuracy( + self, + dtype, num_gemms, - in_features, - out_features, - bias=bias, - params_dtype=dtype, - device=device, - ) + bs, + model, + fp8, + recipe, + fp8_model_params, + save_original_input, + parallel_mode=None, + ): + if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") + if save_original_input and recipe.delayed(): + pytest.skip("DelayedScaling recipe is not supported with save_original_input") + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) - static_x = torch.randn(total_tokens, in_features, dtype=dtype, device=device) - static_x.requires_grad_(True) - static_dy = torch.randn(total_tokens, out_features, dtype=dtype, device=device) - static_out_buf = torch.empty(total_tokens, out_features, dtype=dtype, device=device) + config = model_configs[model] + if config.max_seqlen_q % 16 != 0 and fp8: + pytest.skip("FP8 requires sequence length to be divisible by 16.") - def _zero_grads(): - if static_x.grad is not None: - static_x.grad.zero_() - for param in grouped_linear.parameters(): - if param.grad is None: - param.grad = torch.zeros_like(param) - else: - param.grad.zero_() + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) - def _clone_param_grads(): - return [param.grad.detach().clone() for param in grouped_linear.parameters()] + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + grouped_linear = TorchGroupedLinearWithPadding( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + parallel_mode=parallel_mode, + fp8=fp8, + ).eval() + + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + ref_grouped_linear = GroupedLinear( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + save_original_input=save_original_input, + ).eval() - def _train_step(x, dy, out_buf, *, use_graphed): - with autocast(enabled=use_fp8, recipe=fp8_recipe): - out = ( - graphed_grouped_linear(x, static_m_splits) - if use_graphed - else grouped_linear(x, static_m_splits) - ) - out.backward(dy) - out_buf.copy_(out) - return out_buf - - graphed_grouped_linear = te.make_graphed_callables( - grouped_linear, - (static_x, static_m_splits), - num_warmup_iters=3, - enabled=use_fp8, - recipe=fp8_recipe, - ) + # Share params + with torch.no_grad(): + inner_grouped_linear = grouped_linear.linear_fn + for i in range(num_gemms): + setattr( + ref_grouped_linear, + f"weight{i}", + Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), + ) - fresh_x = torch.randn_like(static_x) - fresh_dy = torch.randn_like(static_dy) - with torch.no_grad(): - static_x.copy_(fresh_x) - static_dy.copy_(fresh_dy) - - _zero_grads() - graph_out = ( - _train_step( - static_x, - static_dy, - static_out_buf, - use_graphed=True, - ) - .detach() - .clone() - ) - torch.cuda.synchronize() - graph_dx = static_x.grad.detach().clone() - graph_param_grads = _clone_param_grads() - - _zero_grads() - expected_x = fresh_x.detach().clone().requires_grad_(True) - expected_dy = fresh_dy.detach().clone() - with autocast(enabled=use_fp8, recipe=fp8_recipe): - expected_out = grouped_linear(expected_x, static_m_splits) - expected_out.backward(expected_dy) - - tols = dict(rtol=1e-2, atol=5e-3) - if use_fp8: - tols = dict(rtol=0.05, atol=0.05) - torch.testing.assert_close(graph_out.float(), expected_out.float(), **tols) - torch.testing.assert_close(graph_dx.float(), expected_x.grad.float(), **tols) - for graph_grad, param in zip(graph_param_grads, grouped_linear.parameters()): - assert param.grad is not None - torch.testing.assert_close(graph_grad.float(), param.grad.float(), **tols) - - -# ============================================================================= -# Raw grouped GEMM kernels (cpp_extensions) -# -# Tests general_grouped_gemm and general_grouped_gemm_for_grouped_tensor -# directly. Reference: per-group general_gemm calls. -# ============================================================================= - - -@pytest.mark.parametrize( - "shape", - [ - (1, 127, 128, 512), - (8, 15, 128, 512), - (8, 1027, 128, 512), - (16, 10027, 128, 512), - ], -) -@pytest.mark.parametrize("dtype", param_types, ids=str) -@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) -@pytest.mark.parametrize("accumulate", [False, True]) -@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm) -def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass, monkeypatch): - torch.manual_seed(0) - z, m, k, n = shape - - dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist() - m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) - assert m_splits.sum() == m and len(m_splits) == z - m_splits = m_splits.tolist() - - if layout == "TN": - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input - out = [torch.randn(m, n, dtype=dtype, device="cuda")] # output - out_ref = [o.clone() for o in torch.split(out[0], m_splits)] - grad = False - single_output = True - elif layout == "NN": - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = list( - torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) - ) # grad_output - out = [torch.randn(m, k, dtype=dtype, device="cuda")] # dgrad - out_ref = [o.clone() for o in torch.split(out[0], m_splits)] - grad = True - single_output = True - else: # layout == "NT" - A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input - B = list( - torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) - ) # grad_output - out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad - out_ref = [o.clone() for o in out] - grad = True - single_output = False + outputs = _test_padding_grouped_linear_accuracy( + grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 + ) + outputs_ref = _test_padding_grouped_linear_accuracy( + ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 + ) - if use_cutlass: - monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1") + # Should be bit-wise match + for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - for i in range(z): - general_gemm( - A[i], - B[i], - dtype, - grad=grad, - accumulate=accumulate, - layout=layout, - out=out_ref[i], + + @staticmethod + def _run_grouped_linear_path( + *, + enable_grouped_tensor_path: bool, + fp8_recipe, + bias: bool, + fp8_model_params: bool, + delay_wgrad_compute: bool, + single_grouped_bias: bool = False, + x_base: torch.Tensor, + dy: torch.Tensor, + weights, + biases, + m_splits, + monkeypatch, + ): + FP8GlobalStateManager.reset() + monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1" if enable_grouped_tensor_path else "0") + + dtype = x_base.dtype + num_gemms = len(m_splits) + in_features = weights[0].size(1) + out_features = weights[0].size(0) + use_fp8 = fp8_recipe is not None + + x = x_base.detach().clone().requires_grad_(True) + with quantized_model_init(enabled=fp8_model_params, recipe=fp8_recipe): + grouped_linear = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + delay_wgrad_compute=delay_wgrad_compute, + single_grouped_bias=single_grouped_bias, + ) + with torch.no_grad(): + for i in range(num_gemms): + getattr(grouped_linear, f"weight{i}").copy_(weights[i]) + if bias: + getattr(grouped_linear, f"bias{i}").copy_(biases[i]) + + # The fused path is the graph-safe path and accepts a CUDA tensor for split metadata. + # The legacy path still expects Python split sections in several places. + m_splits_arg = ( + torch.tensor(m_splits, dtype=torch.int64, device="cuda") + if enable_grouped_tensor_path + else m_splits ) - if single_output: - out_ref = [torch.cat(out_ref)] + with autocast(enabled=use_fp8, recipe=fp8_recipe): + y = grouped_linear(x, m_splits_arg) + y.backward(dy) + if delay_wgrad_compute: + grouped_linear.backward_dw() - general_grouped_gemm( - A, - B, - out, - [None] * z, - dtype, - m_splits=m_splits, - grad=grad, - accumulate=accumulate, - layout=layout, - single_output=single_output, - ) + outputs = [y, x.grad] + for i in range(num_gemms): + outputs.append(getattr(grouped_linear, f"weight{i}").grad) + if bias: + outputs.append(getattr(grouped_linear, f"bias{i}").grad) + return _clone_outputs(outputs) - for o, o_ref in zip(out, out_ref): - if not use_cutlass: - # cublas implementation should be bit-wise match - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - else: - torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) + @pytest.mark.parametrize( + "fp8_recipe", + [ + None, + pytest.param( + recipe.MXFP8BlockScaling(), + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + ), + ], + ids=["bf16", "mxfp8"], + ) + @pytest.mark.parametrize("single_grouped_bias", all_boolean) + @pytest.mark.parametrize("bias", all_boolean) + @pytest.mark.parametrize("fp8_model_params", all_boolean) + @pytest.mark.parametrize("delay_wgrad_compute", all_boolean) + def test_grouped_linear_grouped_tensor_path_matches_legacy( + self, + fp8_recipe, bias, fp8_model_params, delay_wgrad_compute, single_grouped_bias, monkeypatch + ): + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("GroupedTensor grouped GEMM path requires SM100+") + + use_fp8 = fp8_recipe is not None + if fp8_model_params and not use_fp8: + pytest.skip("fp8_model_params requires FP8") + if single_grouped_bias and not bias: + pytest.skip("single_grouped_bias requires bias=True") + + dtype = torch.bfloat16 + num_gemms = 3 + in_features = 64 + out_features = 64 + m_splits = [128, 256, 384] + total_tokens = sum(m_splits) + + torch.manual_seed(1234) + x_base = (0.1 * torch.randn(total_tokens, in_features, device="cuda")).to(dtype) + dy = (0.1 * torch.randn(total_tokens, out_features, device="cuda")).to(dtype) + weights = [ + (0.1 * torch.randn(out_features, in_features, device="cuda")).to(dtype) + for _ in range(num_gemms) + ] + biases = None + if bias: + biases = [ + (0.1 * torch.randn(out_features, device="cuda")).to(dtype) for _ in range(num_gemms) + ] -@pytest.mark.skipif( - torch.cuda.get_device_capability() != (9, 0), - reason="Only enable CUTLASS grouped gemm on Hopper", -) -@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) -def test_grouped_gemm_cutlass_empty_groups(layout, monkeypatch): - dtype = torch.bfloat16 - z, k, n = 1, 2048, 1536 - m_splits = [0] * z + outputs_legacy = self._run_grouped_linear_path( + enable_grouped_tensor_path=False, + fp8_recipe=fp8_recipe, + bias=bias, + fp8_model_params=fp8_model_params, + delay_wgrad_compute=delay_wgrad_compute, + single_grouped_bias=single_grouped_bias, + x_base=x_base, + dy=dy, + weights=weights, + biases=biases, + m_splits=m_splits, + monkeypatch=monkeypatch, + ) + outputs_grouped_tensor = self._run_grouped_linear_path( + enable_grouped_tensor_path=True, + fp8_recipe=fp8_recipe, + bias=bias, + fp8_model_params=fp8_model_params, + delay_wgrad_compute=delay_wgrad_compute, + single_grouped_bias=single_grouped_bias, + x_base=x_base, + dy=dy, + weights=weights, + biases=biases, + m_splits=m_splits, + monkeypatch=monkeypatch, + ) - if layout == "TN": - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = [torch.empty(0, k, dtype=dtype, device="cuda") for _ in range(z)] # input - out = [torch.empty(0, n, dtype=dtype, device="cuda")] # output - grad = False - single_output = True - elif layout == "NN": - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = [torch.empty(0, n, dtype=dtype, device="cuda") for _ in range(z)] # grad_output - out = [torch.empty(0, k, dtype=dtype, device="cuda")] # dgrad - grad = True - single_output = True - else: # layout == "NT" - A = [torch.empty(0, k, dtype=dtype, device="cuda") for _ in range(z)] # input - B = [torch.empty(0, n, dtype=dtype, device="cuda") for _ in range(z)] # grad_output - out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad - grad = True - single_output = False - - monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1") - general_grouped_gemm( - A, - B, - out, - [None] * z, - dtype, - m_splits=m_splits, - grad=grad, - layout=layout, - single_output=single_output, - ) - torch.cuda.synchronize() + tols = dict(rtol=1e-2, atol=5e-3) + if use_fp8: + tols = dict(rtol=0.05, atol=0.05) + for grouped_tensor_out, legacy_out in zip(outputs_grouped_tensor, outputs_legacy): + assert grouped_tensor_out is not None + assert legacy_out is not None + torch.testing.assert_close(grouped_tensor_out.float(), legacy_out.float(), **tols) - for tensor in out: - torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0) - - -def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: - data = grouped_tensor.rowwise_data - if data is None: - data = grouped_tensor.columnwise_data - if data is None: - raise ValueError("GroupedTensor has no data buffers to pack.") - offset = 0 - for tensor in tensors: - numel = tensor.numel() - data[offset : offset + numel].copy_(tensor.reshape(-1)) - offset += numel - - -def _make_grouped_tensor_from_splits( - m_sizes: List[int], - last_dim: int, - device: torch.device, - dtype: torch.dtype, -) -> GroupedTensor: - first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64) - return GroupedTensor.make_grouped_tensor( - num_tensors=len(m_sizes), - first_dims=first_dims, - last_dims=None, - logical_first_dim=sum(m_sizes), - logical_last_dim=last_dim, - quantizer=None, - device=device, - dtype=dtype, - ) + @pytest.mark.parametrize( + "fp8_recipe", + [ + None, + pytest.param( + recipe.MXFP8BlockScaling(), + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + ), + ], + ids=["bf16", "mxfp8"], + ) + @pytest.mark.parametrize("bias", all_boolean) + def test_grouped_linear_fused_path_cuda_graph_safe(self, fp8_recipe, bias, monkeypatch): + """Fused GroupedTensor GEMM path should be CUDA graph capturable.""" + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("GroupedTensor grouped GEMM path requires SM100+") + + monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1") + FP8GlobalStateManager.reset() -def _make_grouped_tensor_uniform( - num_tensors: int, - first_dim: int, - last_dim: int, - device: torch.device, - dtype: torch.dtype, -) -> GroupedTensor: - return GroupedTensor.make_grouped_tensor( - num_tensors=num_tensors, - first_dims=None, - last_dims=None, - logical_first_dim=num_tensors * first_dim, - logical_last_dim=last_dim, - quantizer=None, - device=device, - dtype=dtype, - ) + use_fp8 = fp8_recipe is not None + dtype = torch.bfloat16 + device = "cuda" + num_gemms = 3 + in_features = 128 + out_features = 128 + split_sizes = [128, 256, 384] + total_tokens = sum(split_sizes) + static_m_splits = torch.tensor(split_sizes, dtype=torch.int64, device=device) + grouped_linear = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device=device, + ) -def _apply_grouped_bias_ref( - base_outs: List[torch.Tensor], - bias: Optional[List[torch.Tensor]], - bias_scale: Optional[torch.Tensor], - m_sizes: List[int], - dtype: torch.dtype, -) -> List[torch.Tensor]: - """Reference: add (optionally per-row scaled) bias to each group's output, cast to ``dtype``.""" - if bias is None: - return list(base_outs) - if bias_scale is None: - return [(o.float() + b.float()).to(dtype) for o, b in zip(base_outs, bias)] - out = [] - offset = 0 - for i, ms in enumerate(m_sizes): - s = bias_scale[offset : offset + ms].unsqueeze(-1) - out.append((base_outs[i].float() + bias[i].float() * s).to(dtype)) - offset += ms - return out - - -@pytest.mark.parametrize( - "z, m, n, k", - [ - (4, 256, 256, 256), - (4, 512, 256, 512), - (4, 512, 512, 256), - (8, 512, 256, 512), - ], -) -@pytest.mark.parametrize("case", ["no_discrete", "discrete_in", "discrete_out"]) -@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) -@pytest.mark.parametrize("accumulate", [False, True]) -@pytest.mark.parametrize("use_bias_scale", [False, True]) -def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate, use_bias_scale) -> None: - if torch.cuda.get_device_capability() < (9, 0): - pytest.skip("Grouped GEMM requires Hopper (SM90) or newer.") - if torch.cuda.get_device_capability() < (10, 0): - if tex.get_cublasLt_version() < 130400: - pytest.skip("Grouped GEMM on Hopper requires cuBLAS 13.4+.") - if tex.get_cublasLt_version() < 130300: - pytest.skip("Grouped GEMM requires cuBLAS 13.3+.") - if not is_bf16_available(): - pytest.skip("bfloat16 is required for grouped GEMM test.") - - torch.manual_seed(0) - - dtype = torch.bfloat16 - - split_points = torch.randperm(m - 1)[: z - 1] + 1 - split_points = torch.sort(split_points).values.tolist() - m_sizes = [split_points[0]] - m_sizes += [b - a for a, b in zip(split_points[:-1], split_points[1:])] - m_sizes.append(m - split_points[-1]) - assert sum(m_sizes) == m and len(m_sizes) == z - - if layout == "NT": - A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input - B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output - out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad - out_ref = [torch.matmul(B[i].transpose(0, 1).float(), A[i].float()) for i in range(z)] - else: - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = [ - torch.randn(ms, k if layout == "TN" else n, dtype=dtype, device="cuda") - for ms in m_sizes - ] # TN --> input, NN --> grad_output - out = [ - torch.randn(ms, n if layout == "TN" else k, dtype=dtype, device="cuda") - for ms in m_sizes - ] # TN --> output, NN --> dgrad - if layout == "NN": - out_ref = [torch.matmul(B[i].float(), A[i].float()) for i in range(z)] - else: # layout == "TN" - out_ref = [torch.matmul(B[i].float(), A[i].transpose(0, 1).float()) for i in range(z)] - - if accumulate: - out_ref = [out[i].float() + o for i, o in enumerate(out_ref)] - - # Bias is applied after GEMM (broadcasted along rows) - # Match kernel behavior: GEMM output is already in output dtype when bias is added. - out_ref_no_bias = [o.to(dtype) for o in out_ref] - if layout == "TN": - bias_last_dim = n - else: # layout == "NT" or "NN" - bias_last_dim = k - bias = ( - [torch.randn(1, bias_last_dim, dtype=dtype, device="cuda") for _ in range(z)] - if case != "discrete_out" - else None - ) - bias_scale = None - if use_bias_scale and bias is not None and layout != "NT": - bias_scale = torch.randn(m, device="cuda", dtype=torch.float32) - # Bias add in grouped kernel accumulates in FP32 for BF16/FP16. - out_ref = _apply_grouped_bias_ref(out_ref_no_bias, bias, bias_scale, m_sizes, dtype) - # Create grouped tensors based on case - device = A[0].device - grouped_A = A - grouped_out = out - grouped_out_bias = [o.clone() for o in out] - grouped_out_no_bias = [o.clone() for o in out] - grouped_bias = None - if layout == "TN": - grouped_A = ( - _make_grouped_tensor_uniform(z, n, k, device, dtype) if case != "discrete_in" else A - ) # weight - grouped_B = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) # input - if case != "discrete_out": - grouped_out = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # output - grouped_out_bias = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) - grouped_out_no_bias = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) - elif layout == "NN": - grouped_A = ( - _make_grouped_tensor_uniform(z, n, k, device, dtype) if case != "discrete_in" else A - ) # weight - grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output - if case != "discrete_out": - grouped_out = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) - grouped_out_bias = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) - grouped_out_no_bias = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) - else: # layout == "NT" - grouped_A = ( - _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) - if case != "discrete_in" - else A - ) # input - grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output - if case != "discrete_out": - grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) # wgrad - grouped_out_bias = _make_grouped_tensor_uniform(z, n, k, device, dtype) - grouped_out_no_bias = _make_grouped_tensor_uniform(z, n, k, device, dtype) - _pack_grouped_tensor(grouped_B, B) - if case != "discrete_out": - _pack_grouped_tensor(grouped_out, out) - _pack_grouped_tensor(grouped_out_bias, out) - _pack_grouped_tensor(grouped_out_no_bias, out) - if case != "discrete_in": - _pack_grouped_tensor(grouped_A, A) - - if bias is not None: - grouped_bias = _make_grouped_tensor_uniform(z, 1, bias_last_dim, device, dtype) - _pack_grouped_tensor(grouped_bias, bias) - - general_grouped_gemm_for_grouped_tensor( - grouped_A, - grouped_B, - grouped_out_no_bias, - layout=layout, - accumulate=accumulate, - bias=None, - ) - general_grouped_gemm_for_grouped_tensor( - grouped_A, - grouped_B, - grouped_out_bias, - layout=layout, - accumulate=accumulate, - bias=grouped_bias, - bias_scale=bias_scale, - ) - out_grouped_no_bias = ( - grouped_out_no_bias - if isinstance(grouped_out_no_bias, list) - else grouped_out_no_bias.split_into_quantized_tensors() - ) - out_grouped_bias = ( - grouped_out_bias - if isinstance(grouped_out_bias, list) - else grouped_out_bias.split_into_quantized_tensors() - ) + static_x = torch.randn(total_tokens, in_features, dtype=dtype, device=device) + static_x.requires_grad_(True) + static_dy = torch.randn(total_tokens, out_features, dtype=dtype, device=device) + static_out_buf = torch.empty(total_tokens, out_features, dtype=dtype, device=device) + + def _zero_grads(): + if static_x.grad is not None: + static_x.grad.zero_() + for param in grouped_linear.parameters(): + if param.grad is None: + param.grad = torch.zeros_like(param) + else: + param.grad.zero_() - out_grouped_manual_bias = _apply_grouped_bias_ref( - out_grouped_no_bias, bias, bias_scale, m_sizes, dtype - ) - tols = dtype_tols(dtype) - for o, o_ref in zip(out_grouped_no_bias, out_ref_no_bias): - torch.testing.assert_close(o, o_ref, **tols) - if bias is not None: - for o, o_ref in zip(out_grouped_bias, out_grouped_manual_bias): - torch.testing.assert_close(o, o_ref, **tols) + def _clone_param_grads(): + return [param.grad.detach().clone() for param in grouped_linear.parameters()] + def _train_step(x, dy, out_buf, *, use_graphed): + with autocast(enabled=use_fp8, recipe=fp8_recipe): + out = ( + graphed_grouped_linear(x, static_m_splits) + if use_graphed + else grouped_linear(x, static_m_splits) + ) + out.backward(dy) + out_buf.copy_(out) + return out_buf + + graphed_grouped_linear = te.make_graphed_callables( + grouped_linear, + (static_x, static_m_splits), + num_warmup_iters=3, + enabled=use_fp8, + recipe=fp8_recipe, + ) -@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) -@pytest.mark.parametrize("accumulate", [False, True]) -@pytest.mark.parametrize("quant_type", ["bf16", "mxfp8"]) -def test_grouped_gemm_grouped_tensor_zero_work(layout, accumulate, quant_type) -> None: - """Grouped GEMM with all-zero split sizes (zero total work). + fresh_x = torch.randn_like(static_x) + fresh_dy = torch.randn_like(static_dy) + with torch.no_grad(): + static_x.copy_(fresh_x) + static_dy.copy_(fresh_dy) + + _zero_grads() + graph_out = ( + _train_step( + static_x, + static_dy, + static_out_buf, + use_graphed=True, + ) + .detach() + .clone() + ) + torch.cuda.synchronize() + graph_dx = static_x.grad.detach().clone() + graph_param_grads = _clone_param_grads() - For wgrad (NT layout) the output should be zero when not accumulating, - or unchanged when accumulating with beta=1. - """ - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") - if not is_bf16_available(): - pytest.skip("bfloat16 is required for grouped GEMM test.") - if quant_type == "mxfp8" and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - - z = 4 - k, n = 256, 256 - dtype = torch.bfloat16 - device = torch.device("cuda") - use_mxfp8 = quant_type == "mxfp8" - - transa = layout[0] == "T" - transb = layout[1] == "T" - zero_first_dims = torch.zeros(z, dtype=torch.int64, device=device) - - def _make_zero_tokens_grouped_tensor(logical_last_dim, is_a): - """Create a GroupedTensor with non-zero logical_shape but zero first_dims.""" - buf = torch.randn(0, logical_last_dim, dtype=dtype, device=device) - if use_mxfp8: - if is_a: - rowwise, columnwise = transa, not transa - else: - rowwise, columnwise = not transb, transb - quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - rowwise=rowwise, - columnwise=columnwise, + _zero_grads() + expected_x = fresh_x.detach().clone().requires_grad_(True) + expected_dy = fresh_dy.detach().clone() + with autocast(enabled=use_fp8, recipe=fp8_recipe): + expected_out = grouped_linear(expected_x, static_m_splits) + expected_out.backward(expected_dy) + + tols = dict(rtol=1e-2, atol=5e-3) + if use_fp8: + tols = dict(rtol=0.05, atol=0.05) + torch.testing.assert_close(graph_out.float(), expected_out.float(), **tols) + torch.testing.assert_close(graph_dx.float(), expected_x.grad.float(), **tols) + for graph_grad, param in zip(graph_param_grads, grouped_linear.parameters()): + assert param.grad is not None + torch.testing.assert_close(graph_grad.float(), param.grad.float(), **tols) + + +def _clone_outputs(outputs): + return [None if out is None else out.detach().clone() for out in outputs] + + +@pytest.fixture(autouse=True) +def _reset_fp8_state(): + yield + FP8GlobalStateManager.reset() + + +class TestGroupedGemm: + """Tests for raw grouped GEMM kernels (cpp_extensions).""" + + @pytest.mark.parametrize( + "shape", + [ + (1, 127, 128, 512), + (8, 15, 128, 512), + (8, 1027, 128, 512), + (16, 10027, 128, 512), + ], + ) + @pytest.mark.parametrize("dtype", param_types, ids=str) + @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) + @pytest.mark.parametrize("accumulate", [False, True]) + @pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm) + def test_grouped_gemm(self, shape, dtype, layout, accumulate, use_cutlass, monkeypatch): + torch.manual_seed(0) + z, m, k, n = shape + + dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist() + m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) + assert m_splits.sum() == m and len(m_splits) == z + m_splits = m_splits.tolist() + + if layout == "TN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input + out = [torch.randn(m, n, dtype=dtype, device="cuda")] # output + out_ref = [o.clone() for o in torch.split(out[0], m_splits)] + grad = False + single_output = True + elif layout == "NN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = list( + torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) + ) # grad_output + out = [torch.randn(m, k, dtype=dtype, device="cuda")] # dgrad + out_ref = [o.clone() for o in torch.split(out[0], m_splits)] + grad = True + single_output = True + else: # layout == "NT" + A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input + B = list( + torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) + ) # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + out_ref = [o.clone() for o in out] + grad = True + single_output = False + + if use_cutlass: + monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1") + + for i in range(z): + general_gemm( + A[i], + B[i], + dtype, + grad=grad, + accumulate=accumulate, + layout=layout, + out=out_ref[i], ) - quantizer.optimize_for_gemm = True - return tex.group_quantize(buf, quantizer, z, zero_first_dims) + if single_output: + out_ref = [torch.cat(out_ref)] + + general_grouped_gemm( + A, + B, + out, + [None] * z, + dtype, + m_splits=m_splits, + grad=grad, + accumulate=accumulate, + layout=layout, + single_output=single_output, + ) + + for o, o_ref in zip(out, out_ref): + if not use_cutlass: + # cublas implementation should be bit-wise match + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + else: + torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) + + + @pytest.mark.skipif( + torch.cuda.get_device_capability() != (9, 0), + reason="Only enable CUTLASS grouped gemm on Hopper", + ) + @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) + def test_grouped_gemm_cutlass_empty_groups(self, layout, monkeypatch): + dtype = torch.bfloat16 + z, k, n = 1, 2048, 1536 + m_splits = [0] * z + + if layout == "TN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.empty(0, k, dtype=dtype, device="cuda") for _ in range(z)] # input + out = [torch.empty(0, n, dtype=dtype, device="cuda")] # output + grad = False + single_output = True + elif layout == "NN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.empty(0, n, dtype=dtype, device="cuda") for _ in range(z)] # grad_output + out = [torch.empty(0, k, dtype=dtype, device="cuda")] # dgrad + grad = True + single_output = True + else: # layout == "NT" + A = [torch.empty(0, k, dtype=dtype, device="cuda") for _ in range(z)] # input + B = [torch.empty(0, n, dtype=dtype, device="cuda") for _ in range(z)] # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + grad = True + single_output = False + + monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1") + general_grouped_gemm( + A, + B, + out, + [None] * z, + dtype, + m_splits=m_splits, + grad=grad, + layout=layout, + single_output=single_output, + ) + torch.cuda.synchronize() + + for tensor in out: + torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0) + + + @staticmethod + def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: + data = grouped_tensor.rowwise_data + if data is None: + data = grouped_tensor.columnwise_data + if data is None: + raise ValueError("GroupedTensor has no data buffers to pack.") + offset = 0 + for tensor in tensors: + numel = tensor.numel() + data[offset : offset + numel].copy_(tensor.reshape(-1)) + offset += numel + + + @staticmethod + def _make_grouped_tensor_from_splits( + m_sizes: List[int], + last_dim: int, + device: torch.device, + dtype: torch.dtype, + ) -> GroupedTensor: + first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64) return GroupedTensor.make_grouped_tensor( - num_tensors=z, - first_dims=zero_first_dims, + num_tensors=len(m_sizes), + first_dims=first_dims, last_dims=None, - logical_first_dim=k, - logical_last_dim=logical_last_dim, + logical_first_dim=sum(m_sizes), + logical_last_dim=last_dim, quantizer=None, device=device, dtype=dtype, ) - if layout in ("TN", "NN"): - weight_tensors = [torch.randn(n, k, dtype=dtype, device=device) for _ in range(z)] - if use_mxfp8: - grouped_A = _make_grouped_tensor_quantized_mxfp8( - weight_tensors, - rowwise=transa, - columnwise=not transa, - device=device, - ) - else: - grouped_A = _make_grouped_tensor_uniform(z, n, k, device, dtype) - _pack_grouped_tensor(grouped_A, weight_tensors) - else: # NT - grouped_A = _make_zero_tokens_grouped_tensor(k, is_a=True) - - b_last_dim = k if layout == "TN" else n - grouped_B = _make_zero_tokens_grouped_tensor(b_last_dim, is_a=False) - - if layout == "NT": - out = [torch.randn(n, k, dtype=dtype, device=device) for _ in range(z)] - grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) - _pack_grouped_tensor(grouped_out, out) - else: - out = [torch.zeros(0, dtype=dtype, device=device) for _ in range(z)] - out_last_dim = n if layout == "TN" else k - grouped_out = GroupedTensor.make_grouped_tensor( - num_tensors=z, - first_dims=zero_first_dims, + + @staticmethod + def _make_grouped_tensor_uniform( + num_tensors: int, + first_dim: int, + last_dim: int, + device: torch.device, + dtype: torch.dtype, + ) -> GroupedTensor: + return GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + first_dims=None, last_dims=None, - logical_first_dim=k, - logical_last_dim=out_last_dim, + logical_first_dim=num_tensors * first_dim, + logical_last_dim=last_dim, quantizer=None, device=device, dtype=dtype, ) - out_before = [o.clone() for o in out] - general_grouped_gemm_for_grouped_tensor( - grouped_A, - grouped_B, - grouped_out, - layout=layout, - accumulate=accumulate, - ) + @staticmethod + def _apply_grouped_bias_ref( + base_outs: List[torch.Tensor], + bias: Optional[List[torch.Tensor]], + bias_scale: Optional[torch.Tensor], + m_sizes: List[int], + dtype: torch.dtype, + ) -> List[torch.Tensor]: + """Reference: add (optionally per-row scaled) bias to each group's output, cast to ``dtype``.""" + if bias is None: + return list(base_outs) + if bias_scale is None: + return [(o.float() + b.float()).to(dtype) for o, b in zip(base_outs, bias)] + out = [] + offset = 0 + for i, ms in enumerate(m_sizes): + s = bias_scale[offset : offset + ms].unsqueeze(-1) + out.append((base_outs[i].float() + bias[i].float() * s).to(dtype)) + offset += ms + return out + + + @pytest.mark.parametrize( + "z, m, n, k", + [ + (4, 256, 256, 256), + (4, 512, 256, 512), + (4, 512, 512, 256), + (8, 512, 256, 512), + ], + ) + @pytest.mark.parametrize("case", ["no_discrete", "discrete_in", "discrete_out"]) + @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) + @pytest.mark.parametrize("accumulate", [False, True]) + @pytest.mark.parametrize("use_bias_scale", [False, True]) + def test_grouped_gemm_grouped_tensor(self, z, m, n, k, case, layout, accumulate, use_bias_scale) -> None: + if torch.cuda.get_device_capability() < (9, 0): + pytest.skip("Grouped GEMM requires Hopper (SM90) or newer.") + if torch.cuda.get_device_capability() < (10, 0): + if tex.get_cublasLt_version() < 130400: + pytest.skip("Grouped GEMM on Hopper requires cuBLAS 13.4+.") + if tex.get_cublasLt_version() < 130300: + pytest.skip("Grouped GEMM requires cuBLAS 13.3+.") + if not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + + torch.manual_seed(0) + + dtype = torch.bfloat16 + + split_points = torch.randperm(m - 1)[: z - 1] + 1 + split_points = torch.sort(split_points).values.tolist() + m_sizes = [split_points[0]] + m_sizes += [b - a for a, b in zip(split_points[:-1], split_points[1:])] + m_sizes.append(m - split_points[-1]) + assert sum(m_sizes) == m and len(m_sizes) == z + + if layout == "NT": + A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + out_ref = [torch.matmul(B[i].transpose(0, 1).float(), A[i].float()) for i in range(z)] + else: + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [ + torch.randn(ms, k if layout == "TN" else n, dtype=dtype, device="cuda") + for ms in m_sizes + ] # TN --> input, NN --> grad_output + out = [ + torch.randn(ms, n if layout == "TN" else k, dtype=dtype, device="cuda") + for ms in m_sizes + ] # TN --> output, NN --> dgrad + if layout == "NN": + out_ref = [torch.matmul(B[i].float(), A[i].float()) for i in range(z)] + else: # layout == "TN" + out_ref = [torch.matmul(B[i].float(), A[i].transpose(0, 1).float()) for i in range(z)] - out_result = ( - grouped_out if isinstance(grouped_out, list) else grouped_out.split_into_quantized_tensors() - ) - for i in range(z): - if out_result[i].numel() == 0: - continue if accumulate: - torch.testing.assert_close(out_result[i], out_before[i]) + out_ref = [out[i].float() + o for i, o in enumerate(out_ref)] + + # Bias is applied after GEMM (broadcasted along rows) + # Match kernel behavior: GEMM output is already in output dtype when bias is added. + out_ref_no_bias = [o.to(dtype) for o in out_ref] + if layout == "TN": + bias_last_dim = n + else: # layout == "NT" or "NN" + bias_last_dim = k + bias = ( + [torch.randn(1, bias_last_dim, dtype=dtype, device="cuda") for _ in range(z)] + if case != "discrete_out" + else None + ) + bias_scale = None + if use_bias_scale and bias is not None and layout != "NT": + bias_scale = torch.randn(m, device="cuda", dtype=torch.float32) + # Bias add in grouped kernel accumulates in FP32 for BF16/FP16. + out_ref = self._apply_grouped_bias_ref(out_ref_no_bias, bias, bias_scale, m_sizes, dtype) + # Create grouped tensors based on case + device = A[0].device + grouped_A = A + grouped_out = out + grouped_out_bias = [o.clone() for o in out] + grouped_out_no_bias = [o.clone() for o in out] + grouped_bias = None + if layout == "TN": + grouped_A = ( + self._make_grouped_tensor_uniform(z, n, k, device, dtype) if case != "discrete_in" else A + ) # weight + grouped_B = self._make_grouped_tensor_from_splits(m_sizes, k, device, dtype) # input + if case != "discrete_out": + grouped_out = self._make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # output + grouped_out_bias = self._make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + grouped_out_no_bias = self._make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + elif layout == "NN": + grouped_A = ( + self._make_grouped_tensor_uniform(z, n, k, device, dtype) if case != "discrete_in" else A + ) # weight + grouped_B = self._make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output + if case != "discrete_out": + grouped_out = self._make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + grouped_out_bias = self._make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + grouped_out_no_bias = self._make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + else: # layout == "NT" + grouped_A = ( + self._make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + if case != "discrete_in" + else A + ) # input + grouped_B = self._make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output + if case != "discrete_out": + grouped_out = self._make_grouped_tensor_uniform(z, n, k, device, dtype) # wgrad + grouped_out_bias = self._make_grouped_tensor_uniform(z, n, k, device, dtype) + grouped_out_no_bias = self._make_grouped_tensor_uniform(z, n, k, device, dtype) + self._pack_grouped_tensor(grouped_B, B) + if case != "discrete_out": + self._pack_grouped_tensor(grouped_out, out) + self._pack_grouped_tensor(grouped_out_bias, out) + self._pack_grouped_tensor(grouped_out_no_bias, out) + if case != "discrete_in": + self._pack_grouped_tensor(grouped_A, A) + + if bias is not None: + grouped_bias = self._make_grouped_tensor_uniform(z, 1, bias_last_dim, device, dtype) + self._pack_grouped_tensor(grouped_bias, bias) + + general_grouped_gemm_for_grouped_tensor( + grouped_A, + grouped_B, + grouped_out_no_bias, + layout=layout, + accumulate=accumulate, + bias=None, + ) + general_grouped_gemm_for_grouped_tensor( + grouped_A, + grouped_B, + grouped_out_bias, + layout=layout, + accumulate=accumulate, + bias=grouped_bias, + bias_scale=bias_scale, + ) + out_grouped_no_bias = ( + grouped_out_no_bias + if isinstance(grouped_out_no_bias, list) + else grouped_out_no_bias.split_into_quantized_tensors() + ) + out_grouped_bias = ( + grouped_out_bias + if isinstance(grouped_out_bias, list) + else grouped_out_bias.split_into_quantized_tensors() + ) + + out_grouped_manual_bias = self._apply_grouped_bias_ref( + out_grouped_no_bias, bias, bias_scale, m_sizes, dtype + ) + tols = dtype_tols(dtype) + for o, o_ref in zip(out_grouped_no_bias, out_ref_no_bias): + torch.testing.assert_close(o, o_ref, **tols) + if bias is not None: + for o, o_ref in zip(out_grouped_bias, out_grouped_manual_bias): + torch.testing.assert_close(o, o_ref, **tols) + + + @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) + @pytest.mark.parametrize("accumulate", [False, True]) + @pytest.mark.parametrize("quant_type", ["bf16", "mxfp8"]) + def test_grouped_gemm_grouped_tensor_zero_work(self, layout, accumulate, quant_type) -> None: + """Grouped GEMM with all-zero split sizes (zero total work). + + For wgrad (NT layout) the output should be zero when not accumulating, + or unchanged when accumulating with beta=1. + """ + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + if quant_type == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + z = 4 + k, n = 256, 256 + dtype = torch.bfloat16 + device = torch.device("cuda") + use_mxfp8 = quant_type == "mxfp8" + + transa = layout[0] == "T" + transb = layout[1] == "T" + zero_first_dims = torch.zeros(z, dtype=torch.int64, device=device) + + def _make_zero_tokens_grouped_tensor(logical_last_dim, is_a): + """Create a GroupedTensor with non-zero logical_shape but zero first_dims.""" + buf = torch.randn(0, logical_last_dim, dtype=dtype, device=device) + if use_mxfp8: + if is_a: + rowwise, columnwise = transa, not transa + else: + rowwise, columnwise = not transb, transb + quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=rowwise, + columnwise=columnwise, + ) + quantizer.optimize_for_gemm = True + return tex.group_quantize(buf, quantizer, z, zero_first_dims) + return GroupedTensor.make_grouped_tensor( + num_tensors=z, + first_dims=zero_first_dims, + last_dims=None, + logical_first_dim=k, + logical_last_dim=logical_last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + if layout in ("TN", "NN"): + weight_tensors = [torch.randn(n, k, dtype=dtype, device=device) for _ in range(z)] + if use_mxfp8: + grouped_A = self._make_grouped_tensor_quantized_mxfp8( + weight_tensors, + rowwise=transa, + columnwise=not transa, + device=device, + ) + else: + grouped_A = self._make_grouped_tensor_uniform(z, n, k, device, dtype) + self._pack_grouped_tensor(grouped_A, weight_tensors) + else: # NT + grouped_A = _make_zero_tokens_grouped_tensor(k, is_a=True) + + b_last_dim = k if layout == "TN" else n + grouped_B = _make_zero_tokens_grouped_tensor(b_last_dim, is_a=False) + + if layout == "NT": + out = [torch.randn(n, k, dtype=dtype, device=device) for _ in range(z)] + grouped_out = self._make_grouped_tensor_uniform(z, n, k, device, dtype) + self._pack_grouped_tensor(grouped_out, out) else: - torch.testing.assert_close(out_result[i], torch.zeros_like(out_result[i])) - - -def _make_grouped_tensor_quantized_mxfp8( - tensors: List[torch.Tensor], - *, - rowwise: bool, - columnwise: bool, - device: torch.device, - is_weight: bool = False, -) -> GroupedTensor: - """Create a quantized MXFP8 GroupedTensor from a list of per-expert tensors. - - For weights (uniform per-expert shape), we generally won't keep it swizzled since we - might need for future dequantize operations. Swizzling is done internally within - general_grouped_gemm_for_grouped_tensor call. - - For non-weight tensors (inputs / grad_outputs), we still pass - ``first_dims`` and keep ``optimize_for_gemm=True``; so the kernel must emit the - already-swizzled layout up front. - """ - if not tensors: - raise ValueError("Expected non-empty tensor list for grouped quantization.") - quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - rowwise=rowwise, - columnwise=columnwise, - ) - quantizer.optimize_for_gemm = not is_weight - grouped_input = torch.cat(tensors, dim=0) - if is_weight: - first_dims = None - else: - first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device) - return tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims) - - -def _per_tensor_quantize_mxfp8( - tensors: List[torch.Tensor], - *, - rowwise: bool, - columnwise: bool, -) -> List: - """Quantize each tensor individually with MXFP8. - Used to build reference discrete inputs for grouped GEMM. - """ - quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - rowwise=rowwise, - columnwise=columnwise, - ) - return [quantizer(t) for t in tensors] + out = [torch.zeros(0, dtype=dtype, device=device) for _ in range(z)] + out_last_dim = n if layout == "TN" else k + grouped_out = GroupedTensor.make_grouped_tensor( + num_tensors=z, + first_dims=zero_first_dims, + last_dims=None, + logical_first_dim=k, + logical_last_dim=out_last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + out_before = [o.clone() for o in out] -@pytest.mark.parametrize( - "shape", - [ - (1, 128, 128, 512), - (8, 1024, 128, 512), - (16, 4096, 128, 512), - (2, 256, 2880, 2880), - ], -) -@pytest.mark.parametrize("accumulate", [False, True]) -@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) -@pytest.mark.parametrize("case", ["no_discrete", "discrete_in", "discrete_out"]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_grouped_gemm_grouped_tensor_mxfp8( - shape, accumulate, layout: str, case: str, dtype: torch.dtype -) -> None: - if tex.get_cublasLt_version() < 130300: - pytest.skip("Grouped GEMM requires cuBLAS 13.3+.") - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") - if dtype == torch.bfloat16 and not is_bf16_available(): - pytest.skip("bfloat16 is required for grouped GEMM test.") - - torch.manual_seed(0) - z, m, k, n = shape - m_sizes = [m // z] * z - - if layout == "TN": - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input - out = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # output - grad = False - elif layout == "NN": - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output - out = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # dgrad - grad = True - else: # layout == "NT" - A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input - B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output - out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad - grad = True - - out_ref = [o.clone() for o in out] - - transa = layout[0] == "T" - transb = layout[1] == "T" - a_is_weight = all(t.shape == A[0].shape for t in A) - a_rowwise, a_columnwise = transa, not transa - b_rowwise, b_columnwise = not transb, transb - grouped_A = _make_grouped_tensor_quantized_mxfp8( - A, - rowwise=a_rowwise, - columnwise=a_columnwise, - device="cuda", - is_weight=a_is_weight, - ) - grouped_B = _make_grouped_tensor_quantized_mxfp8( - B, rowwise=b_rowwise, columnwise=b_columnwise, device="cuda" - ) - A_fp8 = _per_tensor_quantize_mxfp8(A, rowwise=a_rowwise, columnwise=a_columnwise) - B_fp8 = _per_tensor_quantize_mxfp8(B, rowwise=b_rowwise, columnwise=b_columnwise) - - general_grouped_gemm( - A_fp8, - B_fp8, - out_ref, - [None] * z, - dtype, - m_splits=m_sizes, - grad=grad, - accumulate=accumulate, - layout=layout, - single_output=False, - ) + general_grouped_gemm_for_grouped_tensor( + grouped_A, + grouped_B, + grouped_out, + layout=layout, + accumulate=accumulate, + ) + + out_result = ( + grouped_out if isinstance(grouped_out, list) else grouped_out.split_into_quantized_tensors() + ) + for i in range(z): + if out_result[i].numel() == 0: + continue + if accumulate: + torch.testing.assert_close(out_result[i], out_before[i]) + else: + torch.testing.assert_close(out_result[i], torch.zeros_like(out_result[i])) + + + @staticmethod + def _make_grouped_tensor_quantized_mxfp8( + tensors: List[torch.Tensor], + *, + rowwise: bool, + columnwise: bool, + device: torch.device, + is_weight: bool = False, + ) -> GroupedTensor: + """Create a quantized MXFP8 GroupedTensor from a list of per-expert tensors. + + For weights (uniform per-expert shape), we generally won't keep it swizzled since we + might need for future dequantize operations. Swizzling is done internally within + general_grouped_gemm_for_grouped_tensor call. + + For non-weight tensors (inputs / grad_outputs), we still pass + ``first_dims`` and keep ``optimize_for_gemm=True``; so the kernel must emit the + already-swizzled layout up front. + """ + if not tensors: + raise ValueError("Expected non-empty tensor list for grouped quantization.") + quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=rowwise, + columnwise=columnwise, + ) + quantizer.optimize_for_gemm = not is_weight + grouped_input = torch.cat(tensors, dim=0) + if is_weight: + first_dims = None + else: + first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device) + return tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims) + + + @staticmethod + def _per_tensor_quantize_mxfp8( + tensors: List[torch.Tensor], + *, + rowwise: bool, + columnwise: bool, + ) -> List: + """Quantize each tensor individually with MXFP8. + Used to build reference discrete inputs for grouped GEMM. + """ + quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=rowwise, + columnwise=columnwise, + ) + return [quantizer(t) for t in tensors] - device = A[0].device - grouped_out = None - if case != "discrete_out": + @pytest.mark.parametrize( + "shape", + [ + (1, 128, 128, 512), + (8, 1024, 128, 512), + (16, 4096, 128, 512), + (2, 256, 2880, 2880), + ], + ) + @pytest.mark.parametrize("accumulate", [False, True]) + @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) + @pytest.mark.parametrize("case", ["no_discrete", "discrete_in", "discrete_out"]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_grouped_gemm_grouped_tensor_mxfp8( + self, + shape, accumulate, layout: str, case: str, dtype: torch.dtype + ) -> None: + if tex.get_cublasLt_version() < 130300: + pytest.skip("Grouped GEMM requires cuBLAS 13.3+.") + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if dtype == torch.bfloat16 and not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + + torch.manual_seed(0) + z, m, k, n = shape + m_sizes = [m // z] * z + if layout == "TN": - grouped_out = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + out = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # output + grad = False elif layout == "NN": - grouped_out = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # dgrad + grad = True else: # layout == "NT" - grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) - _pack_grouped_tensor(grouped_out, out) - - grouped_out_input = out if case == "discrete_out" else grouped_out - grouped_A_input = A_fp8 if case == "discrete_in" else grouped_A - general_grouped_gemm_for_grouped_tensor( - grouped_A_input, - grouped_B, - grouped_out_input, - layout=layout, - accumulate=accumulate, - ) + A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + grad = True - out_grouped = out if case == "discrete_out" else grouped_out.split_into_quantized_tensors() - tols = dict(rtol=0.125, atol=0.0675) # mxfp8 tolerance + out_ref = [o.clone() for o in out] - for o, o_ref in zip(out_grouped, out_ref): - torch.testing.assert_close(o, o_ref, **tols) + transa = layout[0] == "T" + transb = layout[1] == "T" + a_is_weight = all(t.shape == A[0].shape for t in A) + a_rowwise, a_columnwise = transa, not transa + b_rowwise, b_columnwise = not transb, transb + grouped_A = self._make_grouped_tensor_quantized_mxfp8( + A, + rowwise=a_rowwise, + columnwise=a_columnwise, + device="cuda", + is_weight=a_is_weight, + ) + grouped_B = self._make_grouped_tensor_quantized_mxfp8( + B, rowwise=b_rowwise, columnwise=b_columnwise, device="cuda" + ) + A_fp8 = self._per_tensor_quantize_mxfp8(A, rowwise=a_rowwise, columnwise=a_columnwise) + B_fp8 = self._per_tensor_quantize_mxfp8(B, rowwise=b_rowwise, columnwise=b_columnwise) + + general_grouped_gemm( + A_fp8, + B_fp8, + out_ref, + [None] * z, + dtype, + m_splits=m_sizes, + grad=grad, + accumulate=accumulate, + layout=layout, + single_output=False, + ) + device = A[0].device -@pytest.mark.parametrize( - "shape", - [ - (1, 128, 128, 512), - (8, 1024, 128, 512), - (16, 4096, 128, 512), - ], -) -@pytest.mark.parametrize("accumulate", [False, True]) -def test_fp8_grouped_gemm(shape, accumulate): - if not fp8_available: - pytest.skip(reason_for_no_fp8) - - z, m, k, n = shape - m_splits = [m // z] * z - - dtype = torch.bfloat16 - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input - out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # output - out_ref = [o.clone() for o in out] - - # fp8 should be robust enough to this fake scale - scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze() - amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda") - - a_quantizers = [ - Float8Quantizer( - scale.clone(), - amax.clone(), - tex.DType.kFloat8E4M3, - ) - for _ in range(z) - ] - b_quantizers = [ - Float8Quantizer( - scale.clone(), - amax.clone(), - tex.DType.kFloat8E4M3, - ) - for _ in range(z) - ] - - A_fp8 = [] - B_fp8 = [] - - for i in range(z): - A_fp8.append(a_quantizers[i](A[i])) - B_fp8.append(b_quantizers[i](B[i])) - - # baseline - for i in range(z): - general_gemm( - A_fp8[i], - B_fp8[i], - dtype, - out=out_ref[i], + grouped_out = None + if case != "discrete_out": + if layout == "TN": + grouped_out = self._make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + elif layout == "NN": + grouped_out = self._make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + else: # layout == "NT" + grouped_out = self._make_grouped_tensor_uniform(z, n, k, device, dtype) + self._pack_grouped_tensor(grouped_out, out) + + grouped_out_input = out if case == "discrete_out" else grouped_out + grouped_A_input = A_fp8 if case == "discrete_in" else grouped_A + general_grouped_gemm_for_grouped_tensor( + grouped_A_input, + grouped_B, + grouped_out_input, + layout=layout, accumulate=accumulate, ) - general_grouped_gemm( - A_fp8, - B_fp8, - out, - [None] * z, - dtype, - m_splits=m_splits, - accumulate=accumulate, - ) - # should be bit-wise match - for o, o_ref in zip(out, out_ref): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + out_grouped = out if case == "discrete_out" else grouped_out.split_into_quantized_tensors() + tols = dict(rtol=0.125, atol=0.0675) # mxfp8 tolerance + for o, o_ref in zip(out_grouped, out_ref): + torch.testing.assert_close(o, o_ref, **tols) -# ============================================================================= -# te.ops.GroupedLinear (ops/fuser API) -# -# Tests the ops-API GroupedLinear and grouped MLP patterns. Reference: PyTorch -# F.linear per group via make_reference_and_test_tensors (float64 CPU). -# Tolerance: dtype_tols() / quantization_tols(). -# ============================================================================= + @pytest.mark.parametrize( + "shape", + [ + (1, 128, 128, 512), + (8, 1024, 128, 512), + (16, 4096, 128, 512), + ], + ) + @pytest.mark.parametrize("accumulate", [False, True]) + def test_fp8_grouped_gemm(self, shape, accumulate): + if not fp8_available: + pytest.skip(reason_for_no_fp8) -@pytest.mark.parametrize("swizzle_type", ["mxfp8_rowwise", "mxfp8_columnwise", "nvfp4"]) -def test_swizzle_scales_and_pack_ptrs_for_discrete_weights( - swizzle_type: str, - num_tensors: int = 4, - shape: Sequence[int] = (160, 96), -): - """Helper function for preparing discrete weights for cuDNN group GEMM kernel""" + z, m, k, n = shape + m_splits = [m // z] * z - # Skip unsupported configurations - if not mxfp8_available and swizzle_type in ("mxfp8_rowwise", "mxfp8_columnwise"): - pytest.skip(reason_for_no_mxfp8) - if not nvfp4_available and swizzle_type == "nvfp4": - pytest.skip(reason_for_no_nvfp4) + dtype = torch.bfloat16 + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input + out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # output + out_ref = [o.clone() for o in out] - # Construct quantizer - quantizer = None - if swizzle_type in ("mxfp8_rowwise", "mxfp8_columnwise"): - quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - rowwise=swizzle_type == "mxfp8_rowwise", - columnwise=swizzle_type == "mxfp8_columnwise", - ) - elif swizzle_type == "nvfp4": - quantizer = NVFP4Quantizer( - columnwise=False, - with_rht=False, - with_post_rht_amax=False, - with_2d_quantization=False, - stochastic_rounding=False, - with_random_sign_mask=False, - ) - - # Per-expert tensors: unquantized, quantized with compact scales, - # quantized with swizzled scales - device = torch.device("cuda") - unquantized_tensors = [ - torch.randn(shape, dtype=torch.bfloat16, device=device) for _ in range(num_tensors) - ] - quantizer.optimize_for_gemm = False - tensors_with_compact_scales = [quantizer(t) for t in unquantized_tensors] - quantizer.optimize_for_gemm = True - tensors_with_swizzled_scales = [quantizer(t) for t in unquantized_tensors] - - # Extract data and scale buffers - if swizzle_type in ("mxfp8_rowwise", "nvfp4"): - data_tensors = [qx._rowwise_data for qx in tensors_with_compact_scales] - scale_tensors = [qx._rowwise_scale_inv for qx in tensors_with_compact_scales] - ref_scale_tensors = [qx._rowwise_scale_inv for qx in tensors_with_swizzled_scales] - elif swizzle_type == "mxfp8_columnwise": - data_tensors = [qx._columnwise_data for qx in tensors_with_compact_scales] - scale_tensors = [qx._columnwise_scale_inv for qx in tensors_with_compact_scales] - ref_scale_tensors = [qx._columnwise_scale_inv for qx in tensors_with_swizzled_scales] - else: - raise ValueError("Unrecogized swizzle type") + # fp8 should be robust enough to this fake scale + scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze() + amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda") - # Call the helper function - data_ptrs, scale_ptrs, swizzled_scales_buffer = ( - tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( - data_tensors, - scale_tensors, - swizzle_type, - device, + a_quantizers = [ + Float8Quantizer( + scale.clone(), + amax.clone(), + tex.DType.kFloat8E4M3, + ) + for _ in range(z) + ] + b_quantizers = [ + Float8Quantizer( + scale.clone(), + amax.clone(), + tex.DType.kFloat8E4M3, + ) + for _ in range(z) + ] + + A_fp8 = [] + B_fp8 = [] + + for i in range(z): + A_fp8.append(a_quantizers[i](A[i])) + B_fp8.append(b_quantizers[i](B[i])) + + # baseline + for i in range(z): + general_gemm( + A_fp8[i], + B_fp8[i], + dtype, + out=out_ref[i], + accumulate=accumulate, + ) + general_grouped_gemm( + A_fp8, + B_fp8, + out, + [None] * z, + dtype, + m_splits=m_splits, + accumulate=accumulate, ) - ) - # Check data pointer values - expected_data_ptrs = torch.tensor( - [t.data_ptr() for t in data_tensors], - dtype=torch.int64, - device="cpu", - ) - assert_close(data_ptrs, expected_data_ptrs) - - # Check scale pointer values - scale_bytes = scale_tensors[0].numel() * scale_tensors[0].element_size() - expected_scale_ptrs = torch.tensor( - [swizzled_scales_buffer.data_ptr() + i * scale_bytes for i in range(num_tensors)], - dtype=torch.int64, - device="cpu", - ) - assert_close(scale_ptrs, expected_scale_ptrs) + # should be bit-wise match + for o, o_ref in zip(out, out_ref): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - # Check swizzled scale values - swizzled_scales_buffer = swizzled_scales_buffer.view(torch.uint8) - expected_swizzled_scales_buffer = ( - torch.cat(ref_scale_tensors).view(torch.uint8).view_as(swizzled_scales_buffer) - ) - assert_close( - swizzled_scales_buffer, - expected_swizzled_scales_buffer, - ) - # Poison the padded compact scales - if swizzle_type == "mxfp8_rowwise": - unpadded_scale_shape = (shape[0], shape[1] // 32) - elif swizzle_type == "mxfp8_columnwise": - unpadded_scale_shape = (shape[0] // 32, shape[1]) - elif swizzle_type == "nvfp4": - unpadded_scale_shape = (shape[0], shape[1] // 16) - for scale in scale_tensors: - scale[unpadded_scale_shape[0] :, :].view(torch.uint8).fill_(-1) - scale[:, unpadded_scale_shape[1] :].view(torch.uint8).fill_(-1) - - # Check that swizzling removes poisoned pad scales - _, _, swizzled_scales_buffer = ( - tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( - data_tensors, - scale_tensors, - swizzle_type, - device, - ) - ) - assert_close( - swizzled_scales_buffer, - expected_swizzled_scales_buffer, - ) +class TestGroupedLinearOps: + """Tests for te.ops.GroupedLinear (ops/fuser API).""" + @pytest.fixture(autouse=True) + def _set_single_param_env(self, monkeypatch): + monkeypatch.setenv("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "1") -@pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16)) -@pytest.mark.parametrize( - "quantization", - [None] + (["mxfp8"] if mxfp8_available else []), -) -@pytest.mark.parametrize("quantized_weight", (False, True)) -@pytest.mark.parametrize("bias", (False, True)) -@pytest.mark.parametrize("single_grouped_weight", (False, True)) -@pytest.mark.parametrize("single_grouped_bias", (False, True)) -@pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) -def test_grouped_linear_cuda_graph_safe( - *, - dtype: torch.dtype, - quantization: Optional[str], - quantized_weight: bool, - bias: bool, - single_grouped_weight: bool, - single_grouped_bias: bool, - accumulate_into_main_grad: bool, - device: torch.device = "cuda", - group_size: int = 4, - in_features: int = 128, - out_features: int = 128, - split_alignment: int = 128, - token_padding: int = 256, -) -> None: - """GroupedLinear forward+backward should be CUDA graph capturable. - - Exercises the grouped-tensor / cublas-grouped-gemm path which uses - GPU-resident split offsets and is the only flow safe to capture. - """ - if os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0") == "0" and ( - single_grouped_weight or single_grouped_bias + @pytest.mark.parametrize("swizzle_type", ["mxfp8_rowwise", "mxfp8_columnwise", "nvfp4"]) + def test_swizzle_scales_and_pack_ptrs_for_discrete_weights( + self, + swizzle_type: str, + num_tensors: int = 4, + shape: Sequence[int] = (160, 96), ): - pytest.skip( - "single_grouped_weight/single_grouped_bias requires" - " NVTE_GROUPED_LINEAR_SINGLE_PARAM=1" - ) - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("Grouped GEMM CUDA-graph-safe path requires SM100+ (Blackwell)") - # Skip invalid configurations - if quantization is None and quantized_weight: - pytest.skip("quantized_weight requires a quantization recipe") - if single_grouped_bias and not bias: - pytest.skip("single_grouped_bias requires bias=True") - - # Split sizes (statically pinned for graph capture) - split_sizes = [split_alignment * (i + 1) for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) - # Pad input tokens to validate the sync-free flow - in_shape = (split_sizes.sum().item() + token_padding, in_features) - out_shape = (in_shape[0], out_features) - - recipe = make_recipe(quantization) - with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): - op = te.ops.GroupedLinear( - group_size, - in_features, - out_features, - bias=bias, - device=device, - dtype=dtype, - accumulate_into_main_grad=accumulate_into_main_grad, - single_grouped_weight=single_grouped_weight, - single_grouped_bias=single_grouped_bias, - ) + """Helper function for preparing discrete weights for cuDNN group GEMM kernel""" - def _weight_params() -> list[torch.nn.Parameter]: - if single_grouped_weight: - return [op.weight] - return [getattr(op, f"weight{i}") for i in range(group_size)] + # Skip unsupported configurations + if not mxfp8_available and swizzle_type in ("mxfp8_rowwise", "mxfp8_columnwise"): + pytest.skip(reason_for_no_mxfp8) + if not nvfp4_available and swizzle_type == "nvfp4": + pytest.skip(reason_for_no_nvfp4) - def _bias_params() -> list[torch.nn.Parameter]: - if not bias: - return [] - if single_grouped_bias: - return [op.bias] - return [getattr(op, f"bias{i}") for i in range(group_size)] + # Construct quantizer + quantizer = None + if swizzle_type in ("mxfp8_rowwise", "mxfp8_columnwise"): + quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=swizzle_type == "mxfp8_rowwise", + columnwise=swizzle_type == "mxfp8_columnwise", + ) + elif swizzle_type == "nvfp4": + quantizer = NVFP4Quantizer( + columnwise=False, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + ) - def _init_main_grads(value: float = 0.0) -> None: - if not accumulate_into_main_grad: - return - with torch.no_grad(): - for w in _weight_params(): - if getattr(w, "main_grad", None) is None: - w.main_grad = torch.empty(w.size(), device=device, dtype=torch.float32) - w.main_grad.fill_(value) - - def _collect_main_grads() -> list[torch.Tensor]: - return [w.main_grad.detach().clone() for w in _weight_params()] - - def _zero_param_grads() -> None: - for param in op.parameters(): - if param.grad is None: - param.grad = torch.zeros_like(param) - else: - param.grad.zero_() + # Per-expert tensors: unquantized, quantized with compact scales, + # quantized with swizzled scales + device = torch.device("cuda") + unquantized_tensors = [ + torch.randn(shape, dtype=torch.bfloat16, device=device) for _ in range(num_tensors) + ] + quantizer.optimize_for_gemm = False + tensors_with_compact_scales = [quantizer(t) for t in unquantized_tensors] + quantizer.optimize_for_gemm = True + tensors_with_swizzled_scales = [quantizer(t) for t in unquantized_tensors] + + # Extract data and scale buffers + if swizzle_type in ("mxfp8_rowwise", "nvfp4"): + data_tensors = [qx._rowwise_data for qx in tensors_with_compact_scales] + scale_tensors = [qx._rowwise_scale_inv for qx in tensors_with_compact_scales] + ref_scale_tensors = [qx._rowwise_scale_inv for qx in tensors_with_swizzled_scales] + elif swizzle_type == "mxfp8_columnwise": + data_tensors = [qx._columnwise_data for qx in tensors_with_compact_scales] + scale_tensors = [qx._columnwise_scale_inv for qx in tensors_with_compact_scales] + ref_scale_tensors = [qx._columnwise_scale_inv for qx in tensors_with_swizzled_scales] + else: + raise ValueError("Unrecogized swizzle type") + + # Call the helper function + data_ptrs, scale_ptrs, swizzled_scales_buffer = ( + tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( + data_tensors, + scale_tensors, + swizzle_type, + device, + ) + ) - static_split_sizes = split_sizes.clone() + # Check data pointer values + expected_data_ptrs = torch.tensor( + [t.data_ptr() for t in data_tensors], + dtype=torch.int64, + device="cpu", + ) + assert_close(data_ptrs, expected_data_ptrs) + + # Check scale pointer values + scale_bytes = scale_tensors[0].numel() * scale_tensors[0].element_size() + expected_scale_ptrs = torch.tensor( + [swizzled_scales_buffer.data_ptr() + i * scale_bytes for i in range(num_tensors)], + dtype=torch.int64, + device="cpu", + ) + assert_close(scale_ptrs, expected_scale_ptrs) - def train_step( - x: torch.Tensor, - dy: torch.Tensor, - out_buf: torch.Tensor, - *, - use_graphed: bool, - ) -> torch.Tensor: - with te.autocast(enabled=quantization is not None, recipe=recipe): - out = ( - graphed_module(x, static_split_sizes) - if use_graphed - else op(x, static_split_sizes) + # Check swizzled scale values + swizzled_scales_buffer = swizzled_scales_buffer.view(torch.uint8) + expected_swizzled_scales_buffer = ( + torch.cat(ref_scale_tensors).view(torch.uint8).view_as(swizzled_scales_buffer) + ) + assert_close( + swizzled_scales_buffer, + expected_swizzled_scales_buffer, + ) + + # Poison the padded compact scales + if swizzle_type == "mxfp8_rowwise": + unpadded_scale_shape = (shape[0], shape[1] // 32) + elif swizzle_type == "mxfp8_columnwise": + unpadded_scale_shape = (shape[0] // 32, shape[1]) + elif swizzle_type == "nvfp4": + unpadded_scale_shape = (shape[0], shape[1] // 16) + for scale in scale_tensors: + scale[unpadded_scale_shape[0] :, :].view(torch.uint8).fill_(-1) + scale[:, unpadded_scale_shape[1] :].view(torch.uint8).fill_(-1) + + # Check that swizzling removes poisoned pad scales + _, _, swizzled_scales_buffer = ( + tex.grouped_mlp_experimental.swizzle_scales_and_pack_ptrs_for_discrete_weights( + data_tensors, + scale_tensors, + swizzle_type, + device, ) - out.backward(dy) - out_buf.copy_(out) - return out_buf - - _init_main_grads(0.0) - - static_x = torch.randn(in_shape, device=device, dtype=dtype, requires_grad=True) - static_dy = torch.randn(out_shape, device=device, dtype=dtype) - static_out_buf = torch.empty(out_shape, device=device, dtype=dtype) - - graphed_module = te.make_graphed_callables( - op, - (static_x, static_split_sizes), - num_warmup_iters=3, - enabled=quantization is not None, - recipe=recipe, - ) + ) + assert_close( + swizzled_scales_buffer, + expected_swizzled_scales_buffer, + ) - # Replace static buffers with fresh data (graph captures must replay - # against new inputs without re-recording). - fresh_x = torch.randn_like(static_x) - fresh_dy = torch.randn_like(static_dy) - with torch.no_grad(): - static_x.copy_(fresh_x) - static_dy.copy_(fresh_dy) - - # Reset grads & main_grads so the captured iteration starts fresh. - _zero_param_grads() - _init_main_grads(0.5) - if static_x.grad is not None: - static_x.grad.zero_() - # Replay the graph - graph_out = ( - train_step(static_x, static_dy, static_out_buf, use_graphed=True).detach().clone() - ) - torch.cuda.synchronize() - graph_dx = static_x.grad.detach().clone() - if accumulate_into_main_grad: - graph_main_grads = _collect_main_grads() - graph_param_grads: list[torch.Tensor] = [] - else: - graph_main_grads = [] - graph_param_grads = [param.grad.detach().clone() for param in op.parameters()] - - # Reference: same op invoked eagerly with the same fresh inputs and - # the same starting grad/main_grad state. - _zero_param_grads() - _init_main_grads(0.5) - static_x.grad.zero_() - - expected_x = fresh_x.detach().clone().requires_grad_(True) - expected_dy = fresh_dy.detach().clone() - with te.autocast(enabled=quantization is not None, recipe=recipe): - expected_out = op(expected_x, static_split_sizes) - expected_out.backward(expected_dy) - - tols = dtype_tols(dtype) - if quantization is not None: - tols = quantization_tols(quantization) - - assert_close(graph_out, expected_out, **tols) - assert_close(graph_dx, expected_x.grad, **tols) - if accumulate_into_main_grad: - for g, w in zip(graph_main_grads, _weight_params()): - assert_close(g, w.main_grad, **tols) - else: - for g, param in zip(graph_param_grads, op.parameters()): - assert_close(g, param.grad, **tols) - - -@pytest.mark.parametrize("delay_wgrad_compute", (False, True)) -@pytest.mark.parametrize("single_grouped_weight", (False, True)) -@pytest.mark.parametrize("single_grouped_bias", (False, True)) -@pytest.mark.parametrize("bias", (False, True)) -@pytest.mark.parametrize("dtype", param_types, ids=str) -@pytest.mark.parametrize("quantization", _ops_quantization_list) -@pytest.mark.parametrize("quantized_compute", (False, True)) -@pytest.mark.parametrize("quantized_weight", (False, True)) -@pytest.mark.parametrize("input_requires_grad", (False, True)) -@pytest.mark.parametrize("weight_requires_grad", (False, True)) -def test_ops_grouped_linear( - *, - group_size: int = 4, - bias: bool, - weight_shape: tuple = (128, 128), - split_alignment: int = 128, - dtype: torch.dtype, - device: torch.device = "cuda", - quantization: Optional[str], - quantized_compute: bool, - quantized_weight: bool, - input_requires_grad: bool, - weight_requires_grad: bool, - delay_wgrad_compute: bool, - single_grouped_weight: bool, - single_grouped_bias: bool, -) -> None: - """te.ops.GroupedLinear forward+backward accuracy""" - - # Split sizes - split_sizes = [split_alignment * i for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) - - # Make input and weight shapes consistent - out_features, in_features = weight_shape - in_shape = (split_sizes.sum().item(), in_features) - out_shape = (in_shape[0], out_features) - - # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) - maybe_skip_quantization(quantization, dims=out_shape) - if quantization is None and (quantized_compute or quantized_weight): - pytest.skip("Quantization scheme is not specified") - if quantization is not None and not (quantized_compute or quantized_weight): - pytest.skip("Quantization scheme is not used") - if quantization is not None and dtype not in (torch.bfloat16, torch.float16): - pytest.skip("Quantized group GEMM is only supported with BF16/FP16") - if single_grouped_bias and not bias: - pytest.skip("single_grouped_bias requires bias=True") - if ( - single_grouped_weight - and quantized_weight - and quantization in ("fp8_delayed_scaling", "fp8_current_scaling") - ): - pytest.skip( - "single_grouped_weight does not support FP8 delayed/current scaling " - "with quantized_model_init" - ) - - # Random data - x_ref, x_test = make_reference_and_test_tensors( - in_shape, - quantization=quantization, - test_dtype=dtype, - test_device=device, - requires_grad=input_requires_grad, + @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16)) + @pytest.mark.parametrize( + "quantization", + [None] + (["mxfp8"] if mxfp8_available else []), ) - dy_ref, dy_test = make_reference_and_test_tensors( - out_shape, - quantization=quantization, - test_dtype=dtype, - test_device=device, - requires_grad=False, - ) - ws_ref, ws_test = [], [] - bs_ref, bs_test = [], [] - for _ in range(group_size): - w_ref, w_test = make_reference_and_test_tensors( - (out_features, in_features), - quantization=quantization, - test_dtype=dtype, - test_device=device, - quantizer_role=QuantizerRole(tensor_type="weight"), - requires_grad=weight_requires_grad, - ) - b_ref, b_test = None, None - if bias: - b_ref, b_test = make_reference_and_test_tensors( + @pytest.mark.parametrize("quantized_weight", (False, True)) + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("single_grouped_weight", (False, True)) + @pytest.mark.parametrize("single_grouped_bias", (False, True)) + @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) + def test_grouped_linear_cuda_graph_safe( + self, + *, + dtype: torch.dtype, + quantization: Optional[str], + quantized_weight: bool, + bias: bool, + single_grouped_weight: bool, + single_grouped_bias: bool, + accumulate_into_main_grad: bool, + device: torch.device = "cuda", + group_size: int = 4, + in_features: int = 128, + out_features: int = 128, + split_alignment: int = 128, + token_padding: int = 256, + ) -> None: + """GroupedLinear forward+backward should be CUDA graph capturable. + + Exercises the grouped-tensor / cublas-grouped-gemm path which uses + GPU-resident split offsets and is the only flow safe to capture. + """ + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM CUDA-graph-safe path requires SM100+ (Blackwell)") + # Skip invalid configurations + if quantization is None and quantized_weight: + pytest.skip("quantized_weight requires a quantization recipe") + if single_grouped_bias and not bias: + pytest.skip("single_grouped_bias requires bias=True") + + # Split sizes (statically pinned for graph capture) + split_sizes = [split_alignment * (i + 1) for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + # Pad input tokens to validate the sync-free flow + in_shape = (split_sizes.sum().item() + token_padding, in_features) + out_shape = (in_shape[0], out_features) + + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): + op = te.ops.GroupedLinear( + group_size, + in_features, out_features, - test_dtype=dtype, - test_device=device, - requires_grad=weight_requires_grad, + bias=bias, + device=device, + dtype=dtype, + accumulate_into_main_grad=accumulate_into_main_grad, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, ) - ws_ref.append(w_ref) - ws_test.append(w_test) - bs_ref.append(b_ref) - bs_test.append(b_test) - - # Plain PyTorch reference implementation - xs_ref = torch.split(x_ref, split_sizes.tolist()) - ys_ref = [] - for x, w, b in zip(xs_ref, ws_ref, bs_ref): - ys_ref.append(torch.nn.functional.linear(x, w, bias=b)) - y_ref = torch.cat(ys_ref) - if input_requires_grad or weight_requires_grad: - y_ref.backward(dy_ref) - # Construct te.ops.GroupedLinear - recipe = make_recipe(quantization) - with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): - op = te.ops.GroupedLinear( - group_size, - in_features, - out_features, - bias=bias, - device=device, - dtype=dtype, - delay_wgrad_compute=delay_wgrad_compute, - single_grouped_weight=single_grouped_weight, - single_grouped_bias=single_grouped_bias, - ) - with torch.no_grad(): - if single_grouped_weight: - op_weights = op.weight.quantized_tensors - if op_weights is None: - op_weights = op.weight.split_into_quantized_tensors() - if single_grouped_bias: - op_bias_parts = op.bias.split_into_quantized_tensors() - for group_idx in range(group_size): + def _weight_params() -> list[torch.nn.Parameter]: if single_grouped_weight: - op_weights[group_idx].copy_(ws_test[group_idx]) - else: - getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx]) - if bias: - if single_grouped_bias: - op_bias_parts[group_idx].reshape(-1).copy_(bs_test[group_idx]) + return [op.weight] + return [getattr(op, f"weight{i}") for i in range(group_size)] + + def _bias_params() -> list[torch.nn.Parameter]: + if not bias: + return [] + if single_grouped_bias: + return [op.bias] + return [getattr(op, f"bias{i}") for i in range(group_size)] + + def _init_main_grads(value: float = 0.0) -> None: + if not accumulate_into_main_grad: + return + with torch.no_grad(): + for w in _weight_params(): + if getattr(w, "main_grad", None) is None: + w.main_grad = torch.empty(w.size(), device=device, dtype=torch.float32) + w.main_grad.fill_(value) + + def _collect_main_grads() -> list[torch.Tensor]: + return [w.main_grad.detach().clone() for w in _weight_params()] + + def _zero_param_grads() -> None: + for param in op.parameters(): + if param.grad is None: + param.grad = torch.zeros_like(param) else: - getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx]) - del ws_test, bs_test - for param in op.parameters(): - param.requires_grad_(requires_grad=weight_requires_grad) - - # Forward and backward pass - with te.autocast(enabled=quantized_compute, recipe=recipe): - y_test = op(x_test, split_sizes) - if input_requires_grad or weight_requires_grad: - y_test.backward(dy_test) - if delay_wgrad_compute and weight_requires_grad: - op.backward_dw() - - # Expected numerical tolerances - tols = dtype_tols(dtype) - if dtype == torch.float32: - tols = dtype_tols(torch.float16) # TF32 GEMM - if quantized_compute: - tols = quantization_tols(quantization) - - # Check results - assert_close(y_test, y_ref, **tols) - assert_close_grads(x_test, x_ref, **tols) - if single_grouped_weight: - if weight_requires_grad: - w_ref_grad = torch.stack([w.grad for w in ws_ref], dim=0) - assert_close(op.weight.grad, w_ref_grad, **tols) - else: - assert op.weight.grad is None - else: - for group_idx in range(group_size): - w_test = getattr(op, f"weight{group_idx}") - assert_close_grads(w_test, ws_ref[group_idx], **tols) - if bias: - if single_grouped_bias: - if weight_requires_grad: - b_ref_grad = torch.stack([b.grad for b in bs_ref], dim=0) - assert_close(op.bias.grad, b_ref_grad, **tols) - else: - assert op.bias.grad is None + param.grad.zero_() + + static_split_sizes = split_sizes.clone() + + def train_step( + x: torch.Tensor, + dy: torch.Tensor, + out_buf: torch.Tensor, + *, + use_graphed: bool, + ) -> torch.Tensor: + with te.autocast(enabled=quantization is not None, recipe=recipe): + out = ( + graphed_module(x, static_split_sizes) + if use_graphed + else op(x, static_split_sizes) + ) + out.backward(dy) + out_buf.copy_(out) + return out_buf + + _init_main_grads(0.0) + + static_x = torch.randn(in_shape, device=device, dtype=dtype, requires_grad=True) + static_dy = torch.randn(out_shape, device=device, dtype=dtype) + static_out_buf = torch.empty(out_shape, device=device, dtype=dtype) + + graphed_module = te.make_graphed_callables( + op, + (static_x, static_split_sizes), + num_warmup_iters=3, + enabled=quantization is not None, + recipe=recipe, + ) + + # Replace static buffers with fresh data (graph captures must replay + # against new inputs without re-recording). + fresh_x = torch.randn_like(static_x) + fresh_dy = torch.randn_like(static_dy) + with torch.no_grad(): + static_x.copy_(fresh_x) + static_dy.copy_(fresh_dy) + + # Reset grads & main_grads so the captured iteration starts fresh. + _zero_param_grads() + _init_main_grads(0.5) + if static_x.grad is not None: + static_x.grad.zero_() + + # Replay the graph + graph_out = ( + train_step(static_x, static_dy, static_out_buf, use_graphed=True).detach().clone() + ) + torch.cuda.synchronize() + graph_dx = static_x.grad.detach().clone() + if accumulate_into_main_grad: + graph_main_grads = _collect_main_grads() + graph_param_grads: list[torch.Tensor] = [] else: - for group_idx in range(group_size): - b_test = getattr(op, f"bias{group_idx}") - assert_close_grads(b_test, bs_ref[group_idx], **tols) - - -@pytest.mark.parametrize("bias", (False, True)) -@pytest.mark.parametrize("dtype", (torch.float32, torch.float16, torch.bfloat16)) -@pytest.mark.parametrize("quantization", _grouped_mlp_quantization_list) -@pytest.mark.parametrize("glu_interleave_size", (None, 32)) -@pytest.mark.parametrize("single_grouped_weight", (False, True)) -@pytest.mark.parametrize("single_grouped_bias", (False, True)) -@pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) -@pytest.mark.parametrize("delay_wgrad_compute", (False, True)) -@pytest.mark.parametrize( - "activation", - ("scaled_swiglu", "scaled_clamped_qgeglu", "scaled_clamped_qgeglu_custom", "scaled_srelu"), -) -def test_grouped_mlp( - *, - group_size: int = 4, - hidden_size: int = 128, - bias: bool, - dtype: torch.dtype, - quantization: Optional[str], - single_grouped_weight: bool, - single_grouped_bias: bool, - accumulate_into_main_grad: bool, - device: torch.device = "cuda", - split_alignment: int = 256, - glu_interleave_size: Optional[int], - delay_wgrad_compute: bool, - activation: str, -) -> None: - """GroupedLinear + scaled activation + GroupedLinear""" - if dtype == torch.bfloat16 and not is_bf16_available(): - pytest.skip("BF16 not available") - - # Build activation op to determine GLU vs unary - if activation == "scaled_swiglu": - scaled_act_ref = te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - elif activation.startswith("scaled_clamped_qgeglu"): - scaled_act_ref = te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) - elif activation == "scaled_srelu": - scaled_act_ref = te.ops.ScaledSReLU() - else: - raise ValueError(f"Unexpected activation ({activation})") - activation_is_glu = is_glu_activation(scaled_act_ref) - - # Skip invalid configurations - with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device, dtype=dtype) - if single_grouped_weight and quantization != "mxfp8": - pytest.skip("single_grouped_weight is only supported for MXFP8 quantization") - if single_grouped_bias and not bias: - pytest.skip("single_grouped_bias requires bias=True") - if with_quantization and dtype not in (torch.bfloat16, torch.float16): - pytest.skip("Quantized group GEMM is only supported with BF16/FP16") - if not activation_is_glu and quantization not in ("mxfp8", "nvfp4", "nvfp4_rht"): - pytest.skip("Scaled unary grouped MLP is only supported with MXFP8 or NVFP4") - if not activation_is_glu and glu_interleave_size is not None: - pytest.skip("Unary activations do not use GLU interleaving") - if quantization == "nvfp4_4over6": - pytest.skip("NVFP4 4over6 grouped quantization is not supported") - if activation == "scaled_srelu" and quantization in ("nvfp4", "nvfp4_rht") and bias: - pytest.skip("NVFP4 SReLU grouped MLP coverage is limited to no-bias") - if quantization == "nvfp4_rht": - if activation == "scaled_swiglu" and (bias or glu_interleave_size != 32): - pytest.skip("NVFP4 RHT SwiGLU grouped MLP coverage is limited to no-bias") - if activation not in ("scaled_swiglu", "scaled_srelu"): - pytest.skip("NVFP4 RHT grouped MLP coverage is limited to SwiGLU and SReLU") - if ( - with_quantization - and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") - and activation.startswith("scaled_clamped_qgeglu") - and bias - ): - # TODO: ksivaman: Need to debug numerics for this case. - pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") + graph_main_grads = [] + graph_param_grads = [param.grad.detach().clone() for param in op.parameters()] - fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size - if activation == "scaled_clamped_qgeglu_custom": - geglu_limit, geglu_alpha, geglu_offset = 5.0, 1.5, 0.5 - else: - geglu_limit, geglu_alpha, geglu_offset = 7.0, 1.702, 1.0 - - # Split sizes (one group intentionally empty to test the zero-token case) - split_sizes = [split_alignment * i for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) - in_shape = (split_sizes.sum().item(), hidden_size) - out_shape = in_shape - - # Reference tensors: float64 CPU; test tensors: target dtype on CUDA - x_ref, x_test = make_reference_and_test_tensors( - in_shape, - min=-0.25, max=0.25, - quantization=quantization, - test_dtype=dtype, - test_device=device, - ) - dy_ref, dy_test = make_reference_and_test_tensors( - out_shape, - min=-0.25, max=0.25, - test_dtype=dtype, - test_device=device, - requires_grad=False, - ) - probs_ref, probs_test = make_reference_and_test_tensors( - (in_shape[0],), - min=0.1, max=1.0, - test_dtype=dtype, - test_device=device, - ) + # Reference: same op invoked eagerly with the same fresh inputs and + # the same starting grad/main_grad state. + _zero_param_grads() + _init_main_grads(0.5) + static_x.grad.zero_() - fc1_ws_ref, fc1_ws_test = [], [] - fc1_bs_ref, fc1_bs_test = [], [] - fc2_ws_ref, fc2_ws_test = [], [] - fc2_bs_ref, fc2_bs_test = [], [] - for _ in range(group_size): - w1_ref, w1_test = make_reference_and_test_tensors( - (fc1_out_features, hidden_size), - min=-0.25, max=0.25, + expected_x = fresh_x.detach().clone().requires_grad_(True) + expected_dy = fresh_dy.detach().clone() + with te.autocast(enabled=quantization is not None, recipe=recipe): + expected_out = op(expected_x, static_split_sizes) + expected_out.backward(expected_dy) + + tols = dtype_tols(dtype) + if quantization is not None: + tols = quantization_tols(quantization) + + assert_close(graph_out, expected_out, **tols) + assert_close(graph_dx, expected_x.grad, **tols) + if accumulate_into_main_grad: + for g, w in zip(graph_main_grads, _weight_params()): + assert_close(g, w.main_grad, **tols) + else: + for g, param in zip(graph_param_grads, op.parameters()): + assert_close(g, param.grad, **tols) + + + @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) + @pytest.mark.parametrize("single_grouped_weight", (False, True)) + @pytest.mark.parametrize("single_grouped_bias", (False, True)) + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("dtype", param_types, ids=str) + @pytest.mark.parametrize("quantization", _ops_quantization_list) + @pytest.mark.parametrize("quantized_compute", (False, True)) + @pytest.mark.parametrize("quantized_weight", (False, True)) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("weight_requires_grad", (False, True)) + def test_ops_grouped_linear( + self, + *, + group_size: int = 4, + bias: bool, + weight_shape: tuple = (128, 128), + split_alignment: int = 128, + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_compute: bool, + quantized_weight: bool, + input_requires_grad: bool, + weight_requires_grad: bool, + delay_wgrad_compute: bool, + single_grouped_weight: bool, + single_grouped_bias: bool, + ) -> None: + """te.ops.GroupedLinear forward+backward accuracy""" + + # Split sizes + split_sizes = [split_alignment * i for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = (split_sizes.sum().item(), in_features) + out_shape = (in_shape[0], out_features) + + # Skip invalid configurations + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not used") + if quantization is not None and dtype not in (torch.bfloat16, torch.float16): + pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if single_grouped_bias and not bias: + pytest.skip("single_grouped_bias requires bias=True") + if ( + single_grouped_weight + and quantized_weight + and quantization in ("fp8_delayed_scaling", "fp8_current_scaling") + ): + pytest.skip( + "single_grouped_weight does not support FP8 delayed/current scaling " + "with quantized_model_init" + ) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, quantization=quantization, test_dtype=dtype, test_device=device, + requires_grad=input_requires_grad, ) - fc1_ws_ref.append(w1_ref) - fc1_ws_test.append(w1_test) - w2_ref, w2_test = make_reference_and_test_tensors( - (hidden_size, hidden_size), - min=-0.25, max=0.25, + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, quantization=quantization, test_dtype=dtype, test_device=device, + requires_grad=False, ) - fc2_ws_ref.append(w2_ref) - fc2_ws_test.append(w2_test) - if bias: - b1_ref, b1_test = make_reference_and_test_tensors( - (fc1_out_features,), - min=-0.5, max=0.5, - test_dtype=dtype, - test_device=device, - ) - fc1_bs_ref.append(b1_ref) - fc1_bs_test.append(b1_test) - b2_ref, b2_test = make_reference_and_test_tensors( - (hidden_size,), - min=-0.5, max=0.5, + ws_ref, ws_test = [], [] + bs_ref, bs_test = [], [] + for _ in range(group_size): + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), + requires_grad=weight_requires_grad, ) - fc2_bs_ref.append(b2_ref) - fc2_bs_test.append(b2_test) - else: - fc1_bs_ref.append(None) - fc1_bs_test.append(None) - fc2_bs_ref.append(None) - fc2_bs_test.append(None) - - def _apply_activation(x: torch.Tensor) -> torch.Tensor: - if activation_is_glu and glu_interleave_size is not None: - x = x.reshape(-1, 2 * hidden_size // (2 * glu_interleave_size), 2, glu_interleave_size) - x = x.transpose(1, 2).reshape(-1, 2 * hidden_size) - if activation == "scaled_swiglu": - x1, x2 = x.chunk(2, dim=-1) - return torch.nn.functional.silu(x1) * x2 - if activation.startswith("scaled_clamped_qgeglu"): - x1, x2 = x.chunk(2, dim=-1) - lim = torch.tensor(geglu_limit, device=x1.device, dtype=x1.dtype) - x1c = torch.minimum(x1, lim) - x2c = torch.clamp(x2, -lim, lim) - return (x2c + geglu_offset) * (x1c * torch.sigmoid(geglu_alpha * x1c)) - if activation == "scaled_srelu": - return torch.nn.functional.relu(x).square() - raise ValueError(f"Unexpected activation ({activation})") - - # Reference implementation (float64 CPU PyTorch) - xs = torch.split(x_ref, split_sizes.tolist()) - probs = torch.split(probs_ref, split_sizes.tolist()) - ys = [] - for group_idx in range(group_size): - x = xs[group_idx] - fc1_out = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx]) - fc2_in = _apply_activation(fc1_out) * probs[group_idx].unsqueeze(-1) - y = torch.nn.functional.linear(fc2_in, fc2_ws_ref[group_idx]) - if bias: - y = y + fc2_bs_ref[group_idx] * probs[group_idx].unsqueeze(-1) - ys.append(y) - y_ref = torch.cat(ys) - y_ref.backward(dy_ref) - - # Construct TE module - recipe = make_recipe(quantization) - - def _make_scaled_act(): - if activation == "scaled_swiglu": - return te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - if activation == "scaled_clamped_qgeglu_custom": - return te.ops.ScaledClampedQGeGLU( - glu_interleave_size=glu_interleave_size, - limit=geglu_limit, - alpha=geglu_alpha, - glu_linear_offset=geglu_offset, + b_ref, b_test = None, None + if bias: + b_ref, b_test = make_reference_and_test_tensors( + out_features, + test_dtype=dtype, + test_device=device, + requires_grad=weight_requires_grad, + ) + ws_ref.append(w_ref) + ws_test.append(w_test) + bs_ref.append(b_ref) + bs_test.append(b_test) + + # Plain PyTorch reference implementation + xs_ref = torch.split(x_ref, split_sizes.tolist()) + ys_ref = [] + for x, w, b in zip(xs_ref, ws_ref, bs_ref): + ys_ref.append(torch.nn.functional.linear(x, w, bias=b)) + y_ref = torch.cat(ys_ref) + if input_requires_grad or weight_requires_grad: + y_ref.backward(dy_ref) + + # Construct te.ops.GroupedLinear + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): + op = te.ops.GroupedLinear( + group_size, + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + delay_wgrad_compute=delay_wgrad_compute, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, ) - if activation.startswith("scaled_clamped_qgeglu"): - return te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) - if activation == "scaled_srelu": - return te.ops.ScaledSReLU() - raise ValueError(f"Unexpected activation ({activation})") - - with te.quantized_model_init(enabled=with_quantization, recipe=recipe): - fc1 = te.ops.GroupedLinear( - group_size, hidden_size, fc1_out_features, - bias=bias, device=device, dtype=dtype, - single_grouped_weight=single_grouped_weight, - single_grouped_bias=single_grouped_bias, - accumulate_into_main_grad=accumulate_into_main_grad, - delay_wgrad_compute=delay_wgrad_compute, - ) - fc2 = te.ops.GroupedLinear( - group_size, hidden_size, hidden_size, - bias=bias, device=device, dtype=dtype, - single_grouped_weight=single_grouped_weight, - single_grouped_bias=single_grouped_bias, - accumulate_into_main_grad=accumulate_into_main_grad, - delay_wgrad_compute=delay_wgrad_compute, - scale_bias=bias, - ) - module = te.ops.Sequential(fc1, _make_scaled_act(), fc2) - - # Copy weights - with torch.no_grad(): - if single_grouped_weight: - fc1_weights = fc1.weight.quantized_tensors - if fc1_weights is None: - fc1_weights = fc1.weight.split_into_quantized_tensors() - fc2_weights = fc2.weight.quantized_tensors - if fc2_weights is None: - fc2_weights = fc2.weight.split_into_quantized_tensors() - for group_idx in range(group_size): + with torch.no_grad(): if single_grouped_weight: - fc1_weights[group_idx].copy_(fc1_ws_test[group_idx]) - fc2_weights[group_idx].copy_(fc2_ws_test[group_idx]) - else: - getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) - getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) - if bias: - if single_grouped_bias: - fc1_bparts = fc1.bias.split_into_quantized_tensors() - fc2_bparts = fc2.bias.split_into_quantized_tensors() - fc1_bparts[group_idx].reshape(-1).copy_(fc1_bs_test[group_idx]) - fc2_bparts[group_idx].reshape(-1).copy_(fc2_bs_test[group_idx]) + op_weights = op.weight.quantized_tensors + if op_weights is None: + op_weights = op.weight.split_into_quantized_tensors() + if single_grouped_bias: + op_bias_parts = op.bias.split_into_quantized_tensors() + for group_idx in range(group_size): + if single_grouped_weight: + op_weights[group_idx].copy_(ws_test[group_idx]) else: - getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) - getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) - if accumulate_into_main_grad: - main_grad_sentinel = 0.5 - if single_grouped_weight: - weight_params_for_main_grad = [fc1.weight, fc2.weight] + getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx]) + if bias: + if single_grouped_bias: + op_bias_parts[group_idx].reshape(-1).copy_(bs_test[group_idx]) + else: + getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx]) + del ws_test, bs_test + for param in op.parameters(): + param.requires_grad_(requires_grad=weight_requires_grad) + + # Forward and backward pass + with te.autocast(enabled=quantized_compute, recipe=recipe): + y_test = op(x_test, split_sizes) + if input_requires_grad or weight_requires_grad: + y_test.backward(dy_test) + if delay_wgrad_compute and weight_requires_grad: + op.backward_dw() + + # Expected numerical tolerances + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = quantization_tols(quantization) + + # Check results + assert_close(y_test, y_ref, **tols) + assert_close_grads(x_test, x_ref, **tols) + if single_grouped_weight: + if weight_requires_grad: + w_ref_grad = torch.stack([w.grad for w in ws_ref], dim=0) + assert_close(op.weight.grad, w_ref_grad, **tols) else: - weight_params_for_main_grad = [ - getattr(fc, f"weight{i}") for fc in (fc1, fc2) for i in range(group_size) - ] - MegatronTrainingHelper.init_main_grad_buffers( - weight_params_for_main_grad, - fill_value=main_grad_sentinel, - overwrite_main_grad=False, - ) - del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test - - # Forward and backward pass - with te.autocast(enabled=with_quantization, recipe=recipe): - fc2_extra = (split_sizes, probs_test) if bias else (split_sizes,) - y_test = module(x_test, split_sizes, probs_test, *fc2_extra) - y_test.backward(dy_test) - if delay_wgrad_compute: - fc1.backward_dw() - fc2.backward_dw() - - # Check for expected fusions - cudnn_frontend_supports_grouped_mlp = ( - _cudnn_frontend_supports_grouped_gemm_srelu() - if activation == "scaled_srelu" - else _cudnn_frontend_version_supported() - ) - expected_grouped_mlp_fusion = cudnn_frontend_supports_grouped_mlp and ( - ( - quantization == "mxfp8" - and dtype in (torch.bfloat16, torch.float16) - and ( - (not activation_is_glu and glu_interleave_size is None) - or (activation_is_glu and glu_interleave_size == 32) - ) - ) - or ( - quantization == "nvfp4_rht" - and dtype == torch.bfloat16 - and activation == "scaled_srelu" - and glu_interleave_size is None - ) - ) - if expected_grouped_mlp_fusion: - if activation_is_glu: - forward_cls = te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU - backward_cls = te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU + assert op.weight.grad is None else: - forward_cls = te.ops.fused.ForwardGroupedMLP_CuTeGEMMUnary - backward_cls = te.ops.fused.BackwardGroupedMLP_CuTeGEMMDUnary - if forward_cls.is_supported(): - forward_ops = module._module_groups[0]._forward_ops - assert len(forward_ops) == 1 - assert isinstance(forward_ops[0][0], forward_cls) - if backward_cls is not None and backward_cls.is_supported(): - backward_ops = module._module_groups[0]._backward_ops - assert len(backward_ops) == 1 - assert isinstance(backward_ops[0][0], backward_cls) - - # Loose tols for sanity checking - tols = {"rtol": 0.125, "atol": 0.25} - - # Check values - assert_close(y_test, y_ref, **tols) - assert_close_grads(x_test, x_ref, **tols) - assert_close_grads(probs_test, probs_ref, **tols) - for group_idx in range(group_size): + for group_idx in range(group_size): + w_test = getattr(op, f"weight{group_idx}") + assert_close_grads(w_test, ws_ref[group_idx], **tols) if bias: if single_grouped_bias: - assert_close(fc2.bias.grad[group_idx], fc2_bs_ref[group_idx].grad, **tols) - assert_close(fc1.bias.grad[group_idx], fc1_bs_ref[group_idx].grad, **tols) + if weight_requires_grad: + b_ref_grad = torch.stack([b.grad for b in bs_ref], dim=0) + assert_close(op.bias.grad, b_ref_grad, **tols) + else: + assert op.bias.grad is None else: - assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols) - assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols) - if not single_grouped_weight and not accumulate_into_main_grad: - assert_close_grads(getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols) - assert_close_grads(getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols) - fc1_w_ref_grad = torch.stack([w.grad for w in fc1_ws_ref], dim=0) - fc2_w_ref_grad = torch.stack([w.grad for w in fc2_ws_ref], dim=0) - if accumulate_into_main_grad: - fc1_expected = ( - [fc1_w_ref_grad + main_grad_sentinel] - if single_grouped_weight - else [g + main_grad_sentinel for g in fc1_w_ref_grad] - ) - fc2_expected = ( - [fc2_w_ref_grad + main_grad_sentinel] - if single_grouped_weight - else [g + main_grad_sentinel for g in fc2_w_ref_grad] + for group_idx in range(group_size): + b_test = getattr(op, f"bias{group_idx}") + assert_close_grads(b_test, bs_ref[group_idx], **tols) + + +class TestGroupedMLP: + """Tests for grouped MLP patterns (te.ops.GroupedLinear + activation).""" + + @pytest.fixture(autouse=True) + def _set_envvars(self, monkeypatch): + monkeypatch.setenv("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "1") + monkeypatch.setenv("NVTE_CUTEDSL_FUSED_GROUPED_MLP", "1") + + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("dtype", (torch.float32, torch.float16, torch.bfloat16)) + @pytest.mark.parametrize("quantization", _grouped_mlp_quantization_list) + @pytest.mark.parametrize("glu_interleave_size", (None, 32)) + @pytest.mark.parametrize("single_grouped_weight", (False, True)) + @pytest.mark.parametrize("single_grouped_bias", (False, True)) + @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) + @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) + @pytest.mark.parametrize("hidden_size", (128, 256)) + @pytest.mark.parametrize( + "activation", + ("scaled_swiglu", "scaled_clamped_qgeglu", "scaled_clamped_qgeglu_custom", "scaled_srelu"), + ) + def test_grouped_mlp( + self, + *, + group_size: int = 4, + hidden_size: int, + bias: bool, + dtype: torch.dtype, + quantization: Optional[str], + single_grouped_weight: bool, + single_grouped_bias: bool, + accumulate_into_main_grad: bool, + device: torch.device = "cuda", + split_alignment: int = 256, + glu_interleave_size: Optional[int], + delay_wgrad_compute: bool, + activation: str, + ) -> None: + """GroupedLinear + scaled activation + GroupedLinear""" + if dtype == torch.bfloat16 and not is_bf16_available(): + pytest.skip("BF16 not available") + + # Build activation op to determine GLU vs unary + if activation == "scaled_swiglu": + scaled_act_ref = te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + elif activation.startswith("scaled_clamped_qgeglu"): + scaled_act_ref = te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + elif activation == "scaled_srelu": + scaled_act_ref = te.ops.ScaledSReLU() + else: + raise ValueError(f"Unexpected activation ({activation})") + activation_is_glu = is_glu_activation(scaled_act_ref) + + # Skip invalid configurations + with_quantization = quantization is not None + maybe_skip_quantization(quantization, device=device, dtype=dtype) + if single_grouped_weight and quantization != "mxfp8": + pytest.skip("single_grouped_weight is only supported for MXFP8 quantization") + if single_grouped_bias and not bias: + pytest.skip("single_grouped_bias requires bias=True") + if with_quantization and dtype not in (torch.bfloat16, torch.float16): + pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if not activation_is_glu and quantization not in ("mxfp8", "nvfp4", "nvfp4_rht"): + pytest.skip("Scaled unary grouped MLP is only supported with MXFP8 or NVFP4") + if not activation_is_glu and glu_interleave_size is not None: + pytest.skip("Unary activations do not use GLU interleaving") + if quantization == "nvfp4_4over6": + pytest.skip("NVFP4 4over6 grouped quantization is not supported") + if activation == "scaled_srelu" and quantization in ("nvfp4", "nvfp4_rht") and bias: + pytest.skip("NVFP4 SReLU grouped MLP coverage is limited to no-bias") + if quantization == "nvfp4_rht": + if activation == "scaled_swiglu" and (bias or glu_interleave_size != 32): + pytest.skip("NVFP4 RHT SwiGLU grouped MLP coverage is limited to no-bias") + if activation not in ("scaled_swiglu", "scaled_srelu"): + pytest.skip("NVFP4 RHT grouped MLP coverage is limited to SwiGLU and SReLU") + if ( + with_quantization + and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") + and activation.startswith("scaled_clamped_qgeglu") + and bias + ): + # TODO: ksivaman: Need to debug numerics for this case. + pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") + + fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size + if activation == "scaled_clamped_qgeglu_custom": + geglu_limit, geglu_alpha, geglu_offset = 5.0, 1.5, 0.5 + else: + geglu_limit, geglu_alpha, geglu_offset = 7.0, 1.702, 1.0 + + # Split sizes (one group intentionally empty to test the zero-token case) + split_sizes = [split_alignment * i for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + in_shape = (split_sizes.sum().item(), hidden_size) + out_shape = in_shape + + # Reference tensors: float64 CPU; test tensors: target dtype on CUDA + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + min=-0.25, max=0.25, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + min=-0.25, max=0.25, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + probs_ref, probs_test = make_reference_and_test_tensors( + (in_shape[0],), + min=0.1, max=1.0, + test_dtype=dtype, + test_device=device, ) - MegatronTrainingHelper.verify_main_grad_accumulation( - weight_params_for_main_grad, - expected_main_grads=fc1_expected + fc2_expected, - **tols, - ) - elif single_grouped_weight: - assert_close(fc1.weight.grad, fc1_w_ref_grad, **tols) - assert_close(fc2.weight.grad, fc2_w_ref_grad, **tols) - - -@pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16)) -@pytest.mark.parametrize("bias", (False, True)) -@pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu")) -@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) -def test_grouped_mlp_single_weight_numerics( - *, - dtype: torch.dtype, - bias: bool, - activation: str, - device: torch.device = "cuda", - group_size: int = 4, - hidden_size: int = 256, - split_alignment: int = 256, - glu_interleave_size: int = 32, -) -> None: - """single_grouped_weight=True/False should match exactly for fused MXFP8 grouped MLP.""" - - if not te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): - pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") - if not te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): - pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") - - split_sizes = [split_alignment * (i + 1) for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) - in_shape = (split_sizes.sum().item(), hidden_size) - recipe = make_recipe("mxfp8") - - x_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) - probs_base = torch.empty((in_shape[0],), device=device, dtype=dtype).uniform_(-0.25, 0.25) - dy_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) - fc1_ws_base = [ - torch.empty((2 * hidden_size, hidden_size), device=device, dtype=dtype).uniform_( - -0.25, 0.25 - ) - for _ in range(group_size) - ] - fc2_ws_base = [ - torch.empty((hidden_size, hidden_size), device=device, dtype=dtype).uniform_( - -0.25, 0.25 - ) - for _ in range(group_size) - ] - fc1_bs_base = ( - [ - torch.empty((2 * hidden_size,), device=device, dtype=dtype).uniform_(-0.5, 0.5) - for _ in range(group_size) - ] - if bias - else None - ) - fc2_bs_base = ( - [ - torch.empty((hidden_size,), device=device, dtype=dtype).uniform_(-0.5, 0.5) - for _ in range(group_size) - ] - if bias - else None - ) - def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]: - with te.quantized_model_init(enabled=True, recipe=recipe): - scaled_act = ( - te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - if activation == "scaled_swiglu" - else te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + fc1_ws_ref, fc1_ws_test = [], [] + fc1_bs_ref, fc1_bs_test = [], [] + fc2_ws_ref, fc2_ws_test = [], [] + fc2_bs_ref, fc2_bs_test = [], [] + for _ in range(group_size): + w1_ref, w1_test = make_reference_and_test_tensors( + (fc1_out_features, hidden_size), + min=-0.125, max=0.125, + quantization=quantization, + test_dtype=dtype, + test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), + ) + fc1_ws_ref.append(w1_ref) + fc1_ws_test.append(w1_test) + w2_ref, w2_test = make_reference_and_test_tensors( + (hidden_size, hidden_size), + min=-0.125, max=0.125, + quantization=quantization, + test_dtype=dtype, + test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) + fc2_ws_ref.append(w2_ref) + fc2_ws_test.append(w2_test) + if bias: + b1_ref, b1_test = make_reference_and_test_tensors( + (fc1_out_features,), + min=-0.5, max=0.5, + test_dtype=dtype, + test_device=device, + ) + fc1_bs_ref.append(b1_ref) + fc1_bs_test.append(b1_test) + b2_ref, b2_test = make_reference_and_test_tensors( + (hidden_size,), + min=-0.5, max=0.5, + test_dtype=dtype, + test_device=device, + ) + fc2_bs_ref.append(b2_ref) + fc2_bs_test.append(b2_test) + else: + fc1_bs_ref.append(None) + fc1_bs_test.append(None) + fc2_bs_ref.append(None) + fc2_bs_test.append(None) + + def _apply_activation(x: torch.Tensor) -> torch.Tensor: + if activation_is_glu and glu_interleave_size is not None: + x = x.reshape(-1, 2 * hidden_size // (2 * glu_interleave_size), 2, glu_interleave_size) + x = x.transpose(1, 2).reshape(-1, 2 * hidden_size) + if activation == "scaled_swiglu": + x1, x2 = x.chunk(2, dim=-1) + return torch.nn.functional.silu(x1) * x2 + if activation.startswith("scaled_clamped_qgeglu"): + x1, x2 = x.chunk(2, dim=-1) + lim = torch.tensor(geglu_limit, device=x1.device, dtype=x1.dtype) + x1c = torch.minimum(x1, lim) + x2c = torch.clamp(x2, -lim, lim) + return (x2c + geglu_offset) * (x1c * torch.sigmoid(geglu_alpha * x1c)) + if activation == "scaled_srelu": + return torch.nn.functional.relu(x).square() + raise ValueError(f"Unexpected activation ({activation})") + + # Reference implementation (float64 CPU PyTorch) + xs = torch.split(x_ref, split_sizes.tolist()) + probs = torch.split(probs_ref, split_sizes.tolist()) + ys = [] + for group_idx in range(group_size): + x = xs[group_idx] + fc1_out = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx]) + fc2_in = _apply_activation(fc1_out) * probs[group_idx].unsqueeze(-1) + y = torch.nn.functional.linear(fc2_in, fc2_ws_ref[group_idx]) + if bias: + y = y + fc2_bs_ref[group_idx] * probs[group_idx].unsqueeze(-1) + ys.append(y) + y_ref = torch.cat(ys) + y_ref.backward(dy_ref) + + # Construct TE module + recipe = make_recipe(quantization) + + def _make_scaled_act(): + if activation == "scaled_swiglu": + return te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_clamped_qgeglu_custom": + return te.ops.ScaledClampedQGeGLU( + glu_interleave_size=glu_interleave_size, + limit=geglu_limit, + alpha=geglu_alpha, + glu_linear_offset=geglu_offset, + ) + if activation.startswith("scaled_clamped_qgeglu"): + return te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_srelu": + return te.ops.ScaledSReLU() + raise ValueError(f"Unexpected activation ({activation})") + + with te.quantized_model_init(enabled=with_quantization, recipe=recipe): fc1 = te.ops.GroupedLinear( - group_size, - hidden_size, - 2 * hidden_size, - bias=bias, - device=device, - dtype=dtype, + group_size, hidden_size, fc1_out_features, + bias=bias, device=device, dtype=dtype, single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, ) fc2 = te.ops.GroupedLinear( - group_size, - hidden_size, - hidden_size, - bias=bias, - device=device, - dtype=dtype, + group_size, hidden_size, hidden_size, + bias=bias, device=device, dtype=dtype, single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, scale_bias=bias, ) - module = te.ops.Sequential(fc1, scaled_act, fc2) + module = te.ops.Sequential(fc1, _make_scaled_act(), fc2) + # Copy weights with torch.no_grad(): if single_grouped_weight: fc1_weights = fc1.weight.quantized_tensors @@ -2663,174 +2463,512 @@ def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]: fc2_weights = fc2.weight.split_into_quantized_tensors() for group_idx in range(group_size): if single_grouped_weight: - fc1_weights[group_idx].copy_(fc1_ws_base[group_idx]) - fc2_weights[group_idx].copy_(fc2_ws_base[group_idx]) + fc1_weights[group_idx].copy_(fc1_ws_test[group_idx]) + fc2_weights[group_idx].copy_(fc2_ws_test[group_idx]) else: - getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_base[group_idx]) - getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_base[group_idx]) + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) if bias: - getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_base[group_idx]) - getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_base[group_idx]) + if single_grouped_bias: + fc1_bparts = fc1.bias.split_into_quantized_tensors() + fc2_bparts = fc2.bias.split_into_quantized_tensors() + fc1_bparts[group_idx].reshape(-1).copy_(fc1_bs_test[group_idx]) + fc2_bparts[group_idx].reshape(-1).copy_(fc2_bs_test[group_idx]) + else: + getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) + getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) + if accumulate_into_main_grad: + main_grad_sentinel = 0.5 + if single_grouped_weight: + weight_params_for_main_grad = [fc1.weight, fc2.weight] + else: + weight_params_for_main_grad = [ + getattr(fc, f"weight{i}") for fc in (fc1, fc2) for i in range(group_size) + ] + MegatronTrainingHelper.init_main_grad_buffers( + weight_params_for_main_grad, + fill_value=main_grad_sentinel, + overwrite_main_grad=False, + ) + del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test - x = x_base.detach().clone().requires_grad_(True) - probs = probs_base.detach().clone().requires_grad_(True) - dy = dy_base.detach().clone() + # Forward and backward pass + with te.autocast(enabled=with_quantization, recipe=recipe): + fc2_extra = (split_sizes, probs_test) if bias else (split_sizes,) + y_test = module(x_test, split_sizes, probs_test, *fc2_extra) + y_test.backward(dy_test) + if delay_wgrad_compute: + fc1.backward_dw() + fc2.backward_dw() - with te.autocast(enabled=True, recipe=recipe): - fc2_extra = (split_sizes, probs) if bias else (split_sizes,) - y = module(x, split_sizes, probs, *fc2_extra) - y.backward(dy) + # Check for expected fusions + cudnn_frontend_supports_grouped_mlp = ( + _cudnn_frontend_supports_grouped_gemm_srelu() + if activation == "scaled_srelu" + else _cudnn_frontend_version_supported() + ) + expected_grouped_mlp_fusion = cudnn_frontend_supports_grouped_mlp and ( + ( + quantization == "mxfp8" + and dtype in (torch.bfloat16, torch.float16) + and ( + (not activation_is_glu and glu_interleave_size is None) + or (activation_is_glu and glu_interleave_size == 32) + ) + ) + or ( + quantization == "nvfp4_rht" + and dtype == torch.bfloat16 + and activation == "scaled_srelu" + and glu_interleave_size is None + ) + ) + if expected_grouped_mlp_fusion: + if activation_is_glu: + forward_cls = te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU + backward_cls = te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU + else: + forward_cls = te.ops.fused.ForwardGroupedMLP_CuTeGEMMUnary + backward_cls = te.ops.fused.BackwardGroupedMLP_CuTeGEMMDUnary + if forward_cls.is_supported(): + forward_ops = module._module_groups[0]._forward_ops + assert len(forward_ops) == 1 + assert isinstance(forward_ops[0][0], forward_cls) + if backward_cls is not None and backward_cls.is_supported(): + backward_ops = module._module_groups[0]._backward_ops + assert len(backward_ops) == 1 + assert isinstance(backward_ops[0][0], backward_cls) + + # Loose tols for sanity checking + tols = {"rtol": 0.125, "atol": 0.25} + if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): + tols = {"rtol": 0.25, "atol": 0.5} + + # Check values + assert_close(y_test, y_ref, **tols) + assert_close_grads(x_test, x_ref, **tols) + assert_close_grads(probs_test, probs_ref, **tols) + for group_idx in range(group_size): + if bias: + if single_grouped_bias: + assert_close(fc2.bias.grad[group_idx], fc2_bs_ref[group_idx].grad, **tols) + assert_close(fc1.bias.grad[group_idx], fc1_bs_ref[group_idx].grad, **tols) + else: + assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols) + assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols) + if not single_grouped_weight and not accumulate_into_main_grad: + assert_close_grads(getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols) + assert_close_grads(getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols) + fc1_w_ref_grad = torch.stack([w.grad for w in fc1_ws_ref], dim=0) + fc2_w_ref_grad = torch.stack([w.grad for w in fc2_ws_ref], dim=0) + if accumulate_into_main_grad: + fc1_expected = ( + [fc1_w_ref_grad + main_grad_sentinel] + if single_grouped_weight + else [g + main_grad_sentinel for g in fc1_w_ref_grad] + ) + fc2_expected = ( + [fc2_w_ref_grad + main_grad_sentinel] + if single_grouped_weight + else [g + main_grad_sentinel for g in fc2_w_ref_grad] + ) + MegatronTrainingHelper.verify_main_grad_accumulation( + weight_params_for_main_grad, + expected_main_grads=fc1_expected + fc2_expected, + **tols, + ) + elif single_grouped_weight: + assert_close(fc1.weight.grad, fc1_w_ref_grad, **tols) + assert_close(fc2.weight.grad, fc2_w_ref_grad, **tols) - forward_ops = module._module_groups[0]._forward_ops - backward_ops = module._module_groups[0]._backward_ops - assert len(forward_ops) == 1 - assert isinstance( - forward_ops[0][0], - te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, + + @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16)) + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu")) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_grouped_mlp_single_weight_numerics( + self, + *, + dtype: torch.dtype, + bias: bool, + activation: str, + device: torch.device = "cuda", + group_size: int = 4, + hidden_size: int = 256, + split_alignment: int = 256, + glu_interleave_size: int = 32, + ) -> None: + """single_grouped_weight=True/False should match exactly for fused MXFP8 grouped MLP.""" + + if not te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): + pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") + if not te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): + pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") + + split_sizes = [split_alignment * (i + 1) for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) + in_shape = (split_sizes.sum().item(), hidden_size) + recipe = make_recipe("mxfp8") + + x_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) + probs_base = torch.empty((in_shape[0],), device=device, dtype=dtype).uniform_(-0.25, 0.25) + dy_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) + fc1_ws_base = [ + torch.empty((2 * hidden_size, hidden_size), device=device, dtype=dtype).uniform_( + -0.25, 0.25 + ) + for _ in range(group_size) + ] + fc2_ws_base = [ + torch.empty((hidden_size, hidden_size), device=device, dtype=dtype).uniform_( + -0.25, 0.25 + ) + for _ in range(group_size) + ] + fc1_bs_base = ( + [ + torch.empty((2 * hidden_size,), device=device, dtype=dtype).uniform_(-0.5, 0.5) + for _ in range(group_size) + ] + if bias + else None ) - assert len(backward_ops) == 1 - assert isinstance( - backward_ops[0][0], - te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, + fc2_bs_base = ( + [ + torch.empty((hidden_size,), device=device, dtype=dtype).uniform_(-0.5, 0.5) + for _ in range(group_size) + ] + if bias + else None ) - if single_grouped_weight: - fc1_dw = fc1.weight.grad.detach().clone() - fc2_dw = fc2.weight.grad.detach().clone() - else: - fc1_dw = torch.stack( - [ - getattr(fc1, f"weight{group_idx}").grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, + def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]: + with te.quantized_model_init(enabled=True, recipe=recipe): + scaled_act = ( + te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_swiglu" + else te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + ) + fc1 = te.ops.GroupedLinear( + group_size, + hidden_size, + 2 * hidden_size, + bias=bias, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + ) + fc2 = te.ops.GroupedLinear( + group_size, + hidden_size, + hidden_size, + bias=bias, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + scale_bias=bias, + ) + module = te.ops.Sequential(fc1, scaled_act, fc2) + + with torch.no_grad(): + if single_grouped_weight: + fc1_weights = fc1.weight.quantized_tensors + if fc1_weights is None: + fc1_weights = fc1.weight.split_into_quantized_tensors() + fc2_weights = fc2.weight.quantized_tensors + if fc2_weights is None: + fc2_weights = fc2.weight.split_into_quantized_tensors() + for group_idx in range(group_size): + if single_grouped_weight: + fc1_weights[group_idx].copy_(fc1_ws_base[group_idx]) + fc2_weights[group_idx].copy_(fc2_ws_base[group_idx]) + else: + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_base[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_base[group_idx]) + if bias: + getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_base[group_idx]) + getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_base[group_idx]) + + x = x_base.detach().clone().requires_grad_(True) + probs = probs_base.detach().clone().requires_grad_(True) + dy = dy_base.detach().clone() + + with te.autocast(enabled=True, recipe=recipe): + fc2_extra = (split_sizes, probs) if bias else (split_sizes,) + y = module(x, split_sizes, probs, *fc2_extra) + y.backward(dy) + + forward_ops = module._module_groups[0]._forward_ops + backward_ops = module._module_groups[0]._backward_ops + assert len(forward_ops) == 1 + assert isinstance( + forward_ops[0][0], + te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, ) - fc2_dw = torch.stack( - [ - getattr(fc2, f"weight{group_idx}").grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, + assert len(backward_ops) == 1 + assert isinstance( + backward_ops[0][0], + te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, + ) + + if single_grouped_weight: + fc1_dw = fc1.weight.grad.detach().clone() + fc2_dw = fc2.weight.grad.detach().clone() + else: + fc1_dw = torch.stack( + [ + getattr(fc1, f"weight{group_idx}").grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, + ) + fc2_dw = torch.stack( + [ + getattr(fc2, f"weight{group_idx}").grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, + ) + + fc1_db = None + fc2_db = None + if bias: + fc1_db = torch.stack( + [ + getattr(fc1, f"bias{group_idx}").grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, + ) + fc2_db = torch.stack( + [ + getattr(fc2, f"bias{group_idx}").grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, + ) + + return ( + y.detach().clone(), + x.grad.detach().clone(), + probs.grad.detach().clone(), + fc1_dw, + fc2_dw, + fc1_db, + fc2_db, ) - fc1_db = None - fc2_db = None + ( + y_false, + dx_false, + dprobs_false, + fc1_dw_false, + fc2_dw_false, + fc1_db_false, + fc2_db_false, + ) = _run_case(False) + ( + y_true, + dx_true, + dprobs_true, + fc1_dw_true, + fc2_dw_true, + fc1_db_true, + fc2_db_true, + ) = _run_case(True) + + torch.testing.assert_close(y_false, y_true, rtol=0, atol=0) + torch.testing.assert_close(dx_false, dx_true, rtol=0, atol=0) + torch.testing.assert_close(dprobs_false, dprobs_true, rtol=0, atol=0) + torch.testing.assert_close(fc1_dw_false, fc1_dw_true, rtol=0, atol=0) + torch.testing.assert_close(fc2_dw_false, fc2_dw_true, rtol=0, atol=0) if bias: - fc1_db = torch.stack( - [ - getattr(fc1, f"bias{group_idx}").grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, + bias_tols = {"rtol": 0.05, "atol": 0.015625} + torch.testing.assert_close(fc1_db_false, fc1_db_true, **bias_tols) + torch.testing.assert_close(fc2_db_false, fc2_db_true, **bias_tols) + + + @pytest.mark.parametrize("single_grouped_weight", (False, True)) + @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) + @pytest.mark.parametrize("zero_out_wgrad", (False, True)) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_grouped_mlp_overwrite_main_grad( + self, + *, + single_grouped_weight: bool, + delay_wgrad_compute: bool, + zero_out_wgrad: bool, + dtype: torch.dtype = torch.bfloat16, + device: torch.device = "cuda", + group_size: int = 4, + hidden_size: int = 256, + split_alignment: int = 256, + glu_interleave_size: int = 32, + ) -> None: + """End-to-end check that the fused grouped-MLP backward writes the + wgrad into ``weight.main_grad`` correctly under the MegatronFSDP + ``overwrite_main_grad=True`` convention. + ``test_grouped_mlp`` already covers the standard Megatron-LM + ``fuse_wgrad_accumulation`` (DDP) path where the wgrad GEMM + *accumulates* into ``main_grad``. This test focuses exclusively on + the MegatronFSDP variant where the wgrad GEMM must *overwrite* + ``main_grad`` (because FSDP has already ReduceScattered the previous + accumulation), so ``main_grad`` after backward equals ``wgrad`` + regardless of the prior contents. + + Also exercises the MegatronFSDP ``zero_out_wgrad`` flag, which is + independent of ``main_grad`` and only controls whether the dummy + ``param.grad`` returned to autograd is zeroed (so downstream hooks + that read ``.grad`` don't see stale bytes from the cached dummy). + """ + + if not te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): + pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") + if not te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): + pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") + + recipe = make_recipe("mxfp8") + split_sizes = [split_alignment * (i + 1) for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) + in_shape = (split_sizes.sum().item(), hidden_size) + x_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) + probs_base = torch.empty((in_shape[0],), device=device, dtype=dtype).uniform_(-0.25, 0.25) + dy_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) + fc1_ws_base = [ + torch.empty((2 * hidden_size, hidden_size), device=device, dtype=dtype).uniform_( + -0.25, 0.25 ) - fc2_db = torch.stack( - [ - getattr(fc2, f"bias{group_idx}").grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, + for _ in range(group_size) + ] + fc2_ws_base = [ + torch.empty((hidden_size, hidden_size), device=device, dtype=dtype).uniform_( + -0.25, 0.25 ) + for _ in range(group_size) + ] + + def _build_module(*, accumulate_into_main_grad: bool): + with te.quantized_model_init(enabled=True, recipe=recipe): + fc1 = te.ops.GroupedLinear( + group_size, + hidden_size, + 2 * hidden_size, + bias=False, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, + ) + fc2 = te.ops.GroupedLinear( + group_size, + hidden_size, + hidden_size, + bias=False, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, + ) + scaled_act = te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + module = te.ops.Sequential(fc1, scaled_act, fc2) + + with torch.no_grad(): + if single_grouped_weight: + fc1_weights = ( + fc1.weight.quantized_tensors or fc1.weight.split_into_quantized_tensors() + ) + fc2_weights = ( + fc2.weight.quantized_tensors or fc2.weight.split_into_quantized_tensors() + ) + for group_idx in range(group_size): + fc1_weights[group_idx].copy_(fc1_ws_base[group_idx]) + fc2_weights[group_idx].copy_(fc2_ws_base[group_idx]) + else: + for group_idx in range(group_size): + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_base[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_base[group_idx]) + return module, fc1, fc2 + + def _weight_params(fc): + if single_grouped_weight: + return [fc.weight] + return [getattr(fc, f"weight{i}") for i in range(group_size)] + + def _run_backward(module, fc1, fc2): + x = x_base.detach().clone().requires_grad_(True) + probs = probs_base.detach().clone().requires_grad_(True) + with te.autocast(enabled=True, recipe=recipe): + y = module(x, split_sizes, probs, split_sizes) + y.backward(dy_base) + if delay_wgrad_compute: + fc1.backward_dw() + fc2.backward_dw() + + # Reference run: vanilla autograd, no Megatron protocol. + ref_module, ref_fc1, ref_fc2 = _build_module(accumulate_into_main_grad=False) + _run_backward(ref_module, ref_fc1, ref_fc2) + ref_fc1_grads = [wp.grad.detach().clone() for wp in _weight_params(ref_fc1)] + ref_fc2_grads = [wp.grad.detach().clone() for wp in _weight_params(ref_fc2)] + + # Test run: main_grad fusion with overwrite_main_grad=True (MegatronFSDP). + # NaN sentinel makes a missed write loud (would surface as NaN diff). + test_module, test_fc1, test_fc2 = _build_module(accumulate_into_main_grad=True) + for fc in (test_fc1, test_fc2): + MegatronTrainingHelper.init_main_grad_buffers( + _weight_params(fc), + fill_value=float("nan"), + overwrite_main_grad=True, + zero_out_wgrad=zero_out_wgrad, + ) + _run_backward(test_module, test_fc1, test_fc2) + + # main_grad must be overwritten to exactly the ref wgrad (bitwise: + # the wgrad GEMM is deterministic across the two runs because the + # quantized weights and inputs are identical). + MegatronTrainingHelper.verify_main_grad_accumulation( + _weight_params(test_fc1), expected_main_grads=ref_fc1_grads + ) + MegatronTrainingHelper.verify_main_grad_accumulation( + _weight_params(test_fc2), expected_main_grads=ref_fc2_grads + ) - return ( - y.detach().clone(), - x.grad.detach().clone(), - probs.grad.detach().clone(), - fc1_dw, - fc2_dw, - fc1_db, - fc2_db, - ) - - ( - y_false, - dx_false, - dprobs_false, - fc1_dw_false, - fc2_dw_false, - fc1_db_false, - fc2_db_false, - ) = _run_case(False) - ( - y_true, - dx_true, - dprobs_true, - fc1_dw_true, - fc2_dw_true, - fc1_db_true, - fc2_db_true, - ) = _run_case(True) - - torch.testing.assert_close(y_false, y_true, rtol=0, atol=0) - torch.testing.assert_close(dx_false, dx_true, rtol=0, atol=0) - torch.testing.assert_close(dprobs_false, dprobs_true, rtol=0, atol=0) - torch.testing.assert_close(fc1_dw_false, fc1_dw_true, rtol=0, atol=0) - torch.testing.assert_close(fc2_dw_false, fc2_dw_true, rtol=0, atol=0) - if bias: - bias_tols = {"rtol": 0.05, "atol": 0.015625} - torch.testing.assert_close(fc1_db_false, fc1_db_true, **bias_tols) - torch.testing.assert_close(fc2_db_false, fc2_db_true, **bias_tols) - - -@pytest.mark.parametrize("single_grouped_weight", (False, True)) -@pytest.mark.parametrize("delay_wgrad_compute", (False, True)) -@pytest.mark.parametrize("zero_out_wgrad", (False, True)) -@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) -def test_grouped_mlp_overwrite_main_grad( - *, - single_grouped_weight: bool, - delay_wgrad_compute: bool, - zero_out_wgrad: bool, - dtype: torch.dtype = torch.bfloat16, - device: torch.device = "cuda", - group_size: int = 4, - hidden_size: int = 256, - split_alignment: int = 256, - glu_interleave_size: int = 32, -) -> None: - """End-to-end check that the fused grouped-MLP backward writes the - wgrad into ``weight.main_grad`` correctly under the MegatronFSDP - ``overwrite_main_grad=True`` convention. - ``test_grouped_mlp`` already covers the standard Megatron-LM - ``fuse_wgrad_accumulation`` (DDP) path where the wgrad GEMM - *accumulates* into ``main_grad``. This test focuses exclusively on - the MegatronFSDP variant where the wgrad GEMM must *overwrite* - ``main_grad`` (because FSDP has already ReduceScattered the previous - accumulation), so ``main_grad`` after backward equals ``wgrad`` - regardless of the prior contents. - - Also exercises the MegatronFSDP ``zero_out_wgrad`` flag, which is - independent of ``main_grad`` and only controls whether the dummy - ``param.grad`` returned to autograd is zeroed (so downstream hooks - that read ``.grad`` don't see stale bytes from the cached dummy). - """ - if not te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): - pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") - if not te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): - pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") - - recipe = make_recipe("mxfp8") - split_sizes = [split_alignment * (i + 1) for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) - in_shape = (split_sizes.sum().item(), hidden_size) - x_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) - probs_base = torch.empty((in_shape[0],), device=device, dtype=dtype).uniform_(-0.25, 0.25) - dy_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) - fc1_ws_base = [ - torch.empty((2 * hidden_size, hidden_size), device=device, dtype=dtype).uniform_( - -0.25, 0.25 - ) - for _ in range(group_size) - ] - fc2_ws_base = [ - torch.empty((hidden_size, hidden_size), device=device, dtype=dtype).uniform_( - -0.25, 0.25 - ) - for _ in range(group_size) - ] - - def _build_module(*, accumulate_into_main_grad: bool): + @pytest.mark.parametrize("dtype", (torch.float32, torch.float16, torch.bfloat16)) + @pytest.mark.parametrize("single_grouped_weight", (False, True)) + @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) + @pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu")) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_grouped_mlp_cuda_graph_safe_mxfp8( + self, + *, + dtype: torch.dtype, + single_grouped_weight: bool, + accumulate_into_main_grad: bool, + activation: str, + device: torch.device = "cuda", + group_size: int = 4, + hidden_size: int = 256, + split_alignment: int = 256, + glu_interleave_size: int = 32, + token_padding: int = 2048, + ) -> None: + """Grouped MLP forward+backward should be CUDA graph capturable (MXFP8).""" + + if not te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): + pytest.skip("MXFP8 fused grouped MLP is not supported on this system") + if dtype not in (torch.bfloat16, torch.float16): + pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16") + + split_sizes = [split_alignment * (i + 1) for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) + # Pad the input tokens to validate the sync-free MOE + in_shape = (split_sizes.sum().item() + token_padding, hidden_size) + recipe = make_recipe("mxfp8") with te.quantized_model_init(enabled=True, recipe=recipe): fc1 = te.ops.GroupedLinear( group_size, @@ -2841,7 +2979,6 @@ def _build_module(*, accumulate_into_main_grad: bool): dtype=dtype, single_grouped_weight=single_grouped_weight, accumulate_into_main_grad=accumulate_into_main_grad, - delay_wgrad_compute=delay_wgrad_compute, ) fc2 = te.ops.GroupedLinear( group_size, @@ -2852,438 +2989,321 @@ def _build_module(*, accumulate_into_main_grad: bool): dtype=dtype, single_grouped_weight=single_grouped_weight, accumulate_into_main_grad=accumulate_into_main_grad, - delay_wgrad_compute=delay_wgrad_compute, ) - scaled_act = te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - module = te.ops.Sequential(fc1, scaled_act, fc2) + scaled_act = ( + te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_swiglu" + else te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + ) + module = te.ops.Sequential( + fc1, + scaled_act, + fc2, + ) - with torch.no_grad(): + def _init_main_grads(value: float = 0.0) -> None: + if not accumulate_into_main_grad: + return + with torch.no_grad(): + if single_grouped_weight: + if getattr(fc1.weight, "main_grad", None) is None: + fc1.weight.main_grad = torch.empty( + fc1.weight.size(), + device=device, + dtype=torch.float32, + ) + if getattr(fc2.weight, "main_grad", None) is None: + fc2.weight.main_grad = torch.empty( + fc2.weight.size(), + device=device, + dtype=torch.float32, + ) + fc1.weight.main_grad.fill_(value) + fc2.weight.main_grad.fill_(value) + else: + for group_idx in range(group_size): + fc1_weight = getattr(fc1, f"weight{group_idx}") + fc2_weight = getattr(fc2, f"weight{group_idx}") + if getattr(fc1_weight, "main_grad", None) is None: + fc1_weight.main_grad = torch.empty( + fc1_weight.size(), + device=device, + dtype=torch.float32, + ) + if getattr(fc2_weight, "main_grad", None) is None: + fc2_weight.main_grad = torch.empty( + fc2_weight.size(), + device=device, + dtype=torch.float32, + ) + fc1_weight.main_grad.fill_(value) + fc2_weight.main_grad.fill_(value) + + def _collect_main_grads() -> tuple[torch.Tensor, torch.Tensor]: if single_grouped_weight: - fc1_weights = ( - fc1.weight.quantized_tensors or fc1.weight.split_into_quantized_tensors() + fc1_main_grad = fc1.weight.main_grad.detach().clone() + fc2_main_grad = fc2.weight.main_grad.detach().clone() + else: + fc1_main_grad = torch.stack( + [ + getattr(fc1, f"weight{group_idx}").main_grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, ) - fc2_weights = ( - fc2.weight.quantized_tensors or fc2.weight.split_into_quantized_tensors() + fc2_main_grad = torch.stack( + [ + getattr(fc2, f"weight{group_idx}").main_grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, ) - for group_idx in range(group_size): - fc1_weights[group_idx].copy_(fc1_ws_base[group_idx]) - fc2_weights[group_idx].copy_(fc2_ws_base[group_idx]) - else: - for group_idx in range(group_size): - getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_base[group_idx]) - getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_base[group_idx]) - return module, fc1, fc2 - - def _weight_params(fc): - if single_grouped_weight: - return [fc.weight] - return [getattr(fc, f"weight{i}") for i in range(group_size)] - - def _run_backward(module, fc1, fc2): - x = x_base.detach().clone().requires_grad_(True) - probs = probs_base.detach().clone().requires_grad_(True) - with te.autocast(enabled=True, recipe=recipe): - y = module(x, split_sizes, probs, split_sizes) - y.backward(dy_base) - if delay_wgrad_compute: - fc1.backward_dw() - fc2.backward_dw() - - # Reference run: vanilla autograd, no Megatron protocol. - ref_module, ref_fc1, ref_fc2 = _build_module(accumulate_into_main_grad=False) - _run_backward(ref_module, ref_fc1, ref_fc2) - ref_fc1_grads = [wp.grad.detach().clone() for wp in _weight_params(ref_fc1)] - ref_fc2_grads = [wp.grad.detach().clone() for wp in _weight_params(ref_fc2)] - - # Test run: main_grad fusion with overwrite_main_grad=True (MegatronFSDP). - # NaN sentinel makes a missed write loud (would surface as NaN diff). - test_module, test_fc1, test_fc2 = _build_module(accumulate_into_main_grad=True) - for fc in (test_fc1, test_fc2): - MegatronTrainingHelper.init_main_grad_buffers( - _weight_params(fc), - fill_value=float("nan"), - overwrite_main_grad=True, - zero_out_wgrad=zero_out_wgrad, - ) - _run_backward(test_module, test_fc1, test_fc2) - - # main_grad must be overwritten to exactly the ref wgrad (bitwise: - # the wgrad GEMM is deterministic across the two runs because the - # quantized weights and inputs are identical). - MegatronTrainingHelper.verify_main_grad_accumulation( - _weight_params(test_fc1), expected_main_grads=ref_fc1_grads - ) - MegatronTrainingHelper.verify_main_grad_accumulation( - _weight_params(test_fc2), expected_main_grads=ref_fc2_grads - ) - - -@pytest.mark.parametrize("dtype", (torch.float32, torch.float16, torch.bfloat16)) -@pytest.mark.parametrize("single_grouped_weight", (False, True)) -@pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) -@pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu")) -@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) -def test_grouped_mlp_cuda_graph_safe_mxfp8( - *, - dtype: torch.dtype, - single_grouped_weight: bool, - accumulate_into_main_grad: bool, - activation: str, - device: torch.device = "cuda", - group_size: int = 4, - hidden_size: int = 256, - split_alignment: int = 256, - glu_interleave_size: int = 32, - token_padding: int = 2048, -) -> None: - """Grouped MLP forward+backward should be CUDA graph capturable (MXFP8).""" - - if not te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): - pytest.skip("MXFP8 fused grouped MLP is not supported on this system") - if dtype not in (torch.bfloat16, torch.float16): - pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16") - - split_sizes = [split_alignment * (i + 1) for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) - # Pad the input tokens to validate the sync-free MOE - in_shape = (split_sizes.sum().item() + token_padding, hidden_size) - recipe = make_recipe("mxfp8") - with te.quantized_model_init(enabled=True, recipe=recipe): - fc1 = te.ops.GroupedLinear( - group_size, - hidden_size, - 2 * hidden_size, - bias=False, - device=device, - dtype=dtype, - single_grouped_weight=single_grouped_weight, - accumulate_into_main_grad=accumulate_into_main_grad, - ) - fc2 = te.ops.GroupedLinear( - group_size, - hidden_size, - hidden_size, - bias=False, - device=device, - dtype=dtype, - single_grouped_weight=single_grouped_weight, - accumulate_into_main_grad=accumulate_into_main_grad, + return fc1_main_grad, fc2_main_grad + + static_split_sizes = split_sizes.clone() + + def train_step( + x: torch.Tensor, + probs: torch.Tensor, + dy: torch.Tensor, + out_buf: torch.Tensor, + *, + use_graphed: bool, + ) -> torch.Tensor: + with te.autocast(enabled=True, recipe=recipe): + out = ( + graphed_module(x, static_split_sizes, probs, static_split_sizes) + if use_graphed + else module(x, static_split_sizes, probs, static_split_sizes) + ) + out.backward(dy) + out_buf.copy_(out) + return out_buf + + _init_main_grads(0.0) + + static_x = torch.randn(in_shape, device=device, dtype=dtype, requires_grad=True) + static_probs = torch.randn((in_shape[0],), device=device, dtype=dtype, requires_grad=True) + static_dy = torch.randn(in_shape, device=device, dtype=dtype) + static_out_buf = torch.empty((in_shape[0], hidden_size), device=device, dtype=dtype) + + graphed_module = te.make_graphed_callables( + module, + (static_x, static_split_sizes, static_probs, static_split_sizes), + num_warmup_iters=3, + enabled=True, + recipe=recipe, ) - scaled_act = ( - te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - if activation == "scaled_swiglu" - else te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + + forward_ops = module._module_groups[0]._forward_ops + backward_ops = module._module_groups[0]._backward_ops + assert len(forward_ops) == 1 + assert isinstance( + forward_ops[0][0], + te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, ) - module = te.ops.Sequential( - fc1, - scaled_act, - fc2, + assert len(backward_ops) == 1 + assert isinstance( + backward_ops[0][0], + te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, ) - def _init_main_grads(value: float = 0.0) -> None: - if not accumulate_into_main_grad: - return + fresh_x = torch.randn_like(static_x) + fresh_probs = torch.randn_like(static_probs) + fresh_dy = torch.randn_like(static_dy) with torch.no_grad(): - if single_grouped_weight: - if getattr(fc1.weight, "main_grad", None) is None: - fc1.weight.main_grad = torch.empty( - fc1.weight.size(), - device=device, - dtype=torch.float32, - ) - if getattr(fc2.weight, "main_grad", None) is None: - fc2.weight.main_grad = torch.empty( - fc2.weight.size(), - device=device, - dtype=torch.float32, - ) - fc1.weight.main_grad.fill_(value) - fc2.weight.main_grad.fill_(value) - else: - for group_idx in range(group_size): - fc1_weight = getattr(fc1, f"weight{group_idx}") - fc2_weight = getattr(fc2, f"weight{group_idx}") - if getattr(fc1_weight, "main_grad", None) is None: - fc1_weight.main_grad = torch.empty( - fc1_weight.size(), - device=device, - dtype=torch.float32, - ) - if getattr(fc2_weight, "main_grad", None) is None: - fc2_weight.main_grad = torch.empty( - fc2_weight.size(), - device=device, - dtype=torch.float32, - ) - fc1_weight.main_grad.fill_(value) - fc2_weight.main_grad.fill_(value) + static_x.copy_(fresh_x) + static_probs.copy_(fresh_probs) + static_dy.copy_(fresh_dy) - def _collect_main_grads() -> tuple[torch.Tensor, torch.Tensor]: - if single_grouped_weight: - fc1_main_grad = fc1.weight.main_grad.detach().clone() - fc2_main_grad = fc2.weight.main_grad.detach().clone() + for param in module.parameters(): + param.grad = torch.zeros_like(param) + _init_main_grads(0.5) + if static_x.grad is not None: + static_x.grad.zero_() + if static_probs.grad is not None: + static_probs.grad.zero_() + + graph_out = ( + train_step(static_x, static_probs, static_dy, static_out_buf, use_graphed=True) + .detach() + .clone() + ) + torch.cuda.synchronize() + graph_dx = static_x.grad.detach().clone() + graph_dprobs = static_probs.grad.detach().clone() + if accumulate_into_main_grad: + graph_fc1_main_grad, graph_fc2_main_grad = _collect_main_grads() else: - fc1_main_grad = torch.stack( - [ - getattr(fc1, f"weight{group_idx}").main_grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, - ) - fc2_main_grad = torch.stack( - [ - getattr(fc2, f"weight{group_idx}").main_grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, - ) - return fc1_main_grad, fc2_main_grad + graph_param_grads = [param.grad.detach().clone() for param in module.parameters()] - static_split_sizes = split_sizes.clone() + for param in module.parameters(): + param.grad.zero_() + _init_main_grads(0.5) + static_x.grad.zero_() + static_probs.grad.zero_() - def train_step( - x: torch.Tensor, - probs: torch.Tensor, - dy: torch.Tensor, - out_buf: torch.Tensor, - *, - use_graphed: bool, - ) -> torch.Tensor: + expected_x = fresh_x.detach().clone().requires_grad_(True) + expected_probs = fresh_probs.detach().clone().requires_grad_(True) + expected_dy = fresh_dy.detach().clone() with te.autocast(enabled=True, recipe=recipe): - out = ( - graphed_module(x, static_split_sizes, probs, static_split_sizes) - if use_graphed - else module(x, static_split_sizes, probs, static_split_sizes) + expected_out = module( + expected_x, + static_split_sizes, + expected_probs, + static_split_sizes, ) - out.backward(dy) - out_buf.copy_(out) - return out_buf - - _init_main_grads(0.0) - - static_x = torch.randn(in_shape, device=device, dtype=dtype, requires_grad=True) - static_probs = torch.randn((in_shape[0],), device=device, dtype=dtype, requires_grad=True) - static_dy = torch.randn(in_shape, device=device, dtype=dtype) - static_out_buf = torch.empty((in_shape[0], hidden_size), device=device, dtype=dtype) - - graphed_module = te.make_graphed_callables( - module, - (static_x, static_split_sizes, static_probs, static_split_sizes), - num_warmup_iters=3, - enabled=True, - recipe=recipe, - ) - - forward_ops = module._module_groups[0]._forward_ops - backward_ops = module._module_groups[0]._backward_ops - assert len(forward_ops) == 1 - assert isinstance( - forward_ops[0][0], - te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, - ) - assert len(backward_ops) == 1 - assert isinstance( - backward_ops[0][0], - te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, - ) - - fresh_x = torch.randn_like(static_x) - fresh_probs = torch.randn_like(static_probs) - fresh_dy = torch.randn_like(static_dy) - with torch.no_grad(): - static_x.copy_(fresh_x) - static_probs.copy_(fresh_probs) - static_dy.copy_(fresh_dy) - - for param in module.parameters(): - param.grad = torch.zeros_like(param) - _init_main_grads(0.5) - if static_x.grad is not None: - static_x.grad.zero_() - if static_probs.grad is not None: - static_probs.grad.zero_() + expected_out.backward(expected_dy) - graph_out = ( - train_step(static_x, static_probs, static_dy, static_out_buf, use_graphed=True) - .detach() - .clone() - ) - torch.cuda.synchronize() - graph_dx = static_x.grad.detach().clone() - graph_dprobs = static_probs.grad.detach().clone() - if accumulate_into_main_grad: - graph_fc1_main_grad, graph_fc2_main_grad = _collect_main_grads() - else: - graph_param_grads = [param.grad.detach().clone() for param in module.parameters()] - - for param in module.parameters(): - param.grad.zero_() - _init_main_grads(0.5) - static_x.grad.zero_() - static_probs.grad.zero_() - - expected_x = fresh_x.detach().clone().requires_grad_(True) - expected_probs = fresh_probs.detach().clone().requires_grad_(True) - expected_dy = fresh_dy.detach().clone() - with te.autocast(enabled=True, recipe=recipe): - expected_out = module( - expected_x, - static_split_sizes, - expected_probs, - static_split_sizes, - ) - expected_out.backward(expected_dy) - - tols = dtype_tols(dtype) - assert_close(graph_out, expected_out, **tols) - assert_close(graph_dx, expected_x.grad, **tols) - assert_close(graph_dprobs, expected_probs.grad, **tols) - if accumulate_into_main_grad: - expected_fc1_main_grad, expected_fc2_main_grad = _collect_main_grads() - assert_close(graph_fc1_main_grad, expected_fc1_main_grad, **tols) - assert_close(graph_fc2_main_grad, expected_fc2_main_grad, **tols) - else: - for graph_grad, param in zip(graph_param_grads, module.parameters()): - assert_close(graph_grad, param.grad, **tols) - - -def test_grouped_gemm_quant_cute_matches_mxfp8_quantized() -> None: - if not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("Requires SM100+ for grouped GEMM quant kernel.") - - try: - from cudnn import grouped_gemm_quant_wrapper_sm100 # pylint: disable=no-name-in-module - except ImportError as exc: - pytest.skip(f"grouped_gemm_quant_wrapper_sm100 unavailable: {exc}") - - device = torch.device("cuda") - dtype = torch.bfloat16 if is_bf16_available() else torch.float16 - num_groups = 4 - m = 256 - n = 512 - k = 512 - total_m = num_groups * m - split_sizes = torch.full((num_groups,), m, device=device, dtype=torch.int64) - - q = MXFP8Quantizer(fp8_dtype=te.DType.kFloat8E4M3, rowwise=True, columnwise=False) - q.optimize_for_gemm = False - - torch.manual_seed(0) - a_full = torch.randn(total_m, k, device=device, dtype=dtype) - weights = [torch.randn(n, k, device=device, dtype=dtype) for _ in range(num_groups)] - - grouped_a = tex.group_quantize(a_full, q, num_groups, split_sizes) - a_groups = grouped_a.split_into_quantized_tensors() - b_groups = [q(w) for w in weights] - - # Reference GEMM on dequantized tensors. - ref = torch.empty((total_m, n), device=device, dtype=torch.float32) - start = 0 - for group_idx in range(num_groups): - end = start + m - a_deq = a_groups[group_idx].dequantize(dtype=torch.float32) - b_deq = b_groups[group_idx].dequantize(dtype=torch.float32) - ref[start:end, :] = a_deq @ b_deq.t() - start = end - ref = ref.to(dtype=torch.bfloat16).to(torch.float32) - - # Allocate empty input tensors needed for cuTE DSL kernel - padded_offsets = torch.tensor( - [m * (i + 1) for i in range(num_groups)], - dtype=torch.int32, - device=device, - ) - inputs = { - "a_tensor": torch.empty(1, total_m, k, dtype=torch.float8_e4m3fn, device=device).permute( - 1, 2, 0 - ), - "b_tensor": torch.empty(num_groups, n, k, dtype=torch.float8_e4m3fn, device=device).permute( - 1, 2, 0 - ), - "sfa_tensor": torch.empty( - 1, - total_m // 128, - k // 128, - 32, - 4, - 4, - dtype=torch.float8_e8m0fnu, - device=device, - ).permute(3, 4, 1, 5, 2, 0), - "sfb_tensor": torch.empty( - num_groups, - n // 128, - k // 128, - 32, - 4, - 4, - dtype=torch.float8_e8m0fnu, + tols = dtype_tols(dtype) + assert_close(graph_out, expected_out, **tols) + assert_close(graph_dx, expected_x.grad, **tols) + assert_close(graph_dprobs, expected_probs.grad, **tols) + if accumulate_into_main_grad: + expected_fc1_main_grad, expected_fc2_main_grad = _collect_main_grads() + assert_close(graph_fc1_main_grad, expected_fc1_main_grad, **tols) + assert_close(graph_fc2_main_grad, expected_fc2_main_grad, **tols) + else: + for graph_grad, param in zip(graph_param_grads, module.parameters()): + assert_close(graph_grad, param.grad, **tols) + + + def test_grouped_gemm_quant_cute_matches_mxfp8_quantized(self) -> None: + if not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Requires SM100+ for grouped GEMM quant kernel.") + + try: + from cudnn import grouped_gemm_quant_wrapper_sm100 # pylint: disable=no-name-in-module + except ImportError as exc: + pytest.skip(f"grouped_gemm_quant_wrapper_sm100 unavailable: {exc}") + + device = torch.device("cuda") + dtype = torch.bfloat16 if is_bf16_available() else torch.float16 + num_groups = 4 + m = 256 + n = 512 + k = 512 + total_m = num_groups * m + split_sizes = torch.full((num_groups,), m, device=device, dtype=torch.int64) + + q = MXFP8Quantizer(fp8_dtype=te.DType.kFloat8E4M3, rowwise=True, columnwise=False) + q.optimize_for_gemm = False + + torch.manual_seed(0) + a_full = torch.randn(total_m, k, device=device, dtype=dtype) + weights = [torch.randn(n, k, device=device, dtype=dtype) for _ in range(num_groups)] + + grouped_a = tex.group_quantize(a_full, q, num_groups, split_sizes) + a_groups = grouped_a.split_into_quantized_tensors() + b_groups = [q(w) for w in weights] + + # Reference GEMM on dequantized tensors. + ref = torch.empty((total_m, n), device=device, dtype=torch.float32) + start = 0 + for group_idx in range(num_groups): + end = start + m + a_deq = a_groups[group_idx].dequantize(dtype=torch.float32) + b_deq = b_groups[group_idx].dequantize(dtype=torch.float32) + ref[start:end, :] = a_deq @ b_deq.t() + start = end + ref = ref.to(dtype=torch.bfloat16).to(torch.float32) + + # Allocate empty input tensors needed for cuTE DSL kernel + padded_offsets = torch.tensor( + [m * (i + 1) for i in range(num_groups)], + dtype=torch.int32, device=device, - ).permute(3, 4, 1, 5, 2, 0), - "alpha_tensor": torch.empty(num_groups, dtype=torch.float32, device=device), - "prob_tensor": torch.empty(total_m, 1, 1, dtype=torch.float32, device=device), - "padded_offsets_tensor": padded_offsets, - } - # Overwrite inputs with quantized data/scales from MXFP8 quantizer. - a_data = grouped_a.rowwise_data.view(total_m, k).view(dtype=torch.float8_e4m3fn) - a_data = a_data.unsqueeze(0).permute(1, 2, 0).contiguous() - inputs["a_tensor"].copy_(a_data) - - a_scales = grouped_a.scale_inv.view(dtype=torch.float8_e8m0fnu) - a_scales = a_scales.view(1, total_m // 128, 4, 32, k // 128, 4) - a_scales = a_scales.permute(0, 1, 4, 3, 2, 5).contiguous() - a_scales = a_scales.permute(3, 4, 1, 5, 2, 0).contiguous() - inputs["sfa_tensor"].copy_(a_scales) - - b_data = torch.cat([w._rowwise_data.reshape(-1) for w in b_groups]) - b_data = b_data.view(dtype=torch.float8_e4m3fn) - b_data = b_data.view(num_groups, n, k).permute(1, 2, 0).contiguous() - inputs["b_tensor"].copy_(b_data) - - b_scales = torch.cat([w._rowwise_scale_inv for w in b_groups]) - b_scales = b_scales.view(dtype=torch.float8_e8m0fnu) - b_scales = b_scales.view(num_groups, n // 128, 4, 32, k // 128, 4) - b_scales = b_scales.permute(0, 1, 4, 3, 2, 5).contiguous() - b_scales = b_scales.permute(3, 4, 1, 5, 2, 0).contiguous() - inputs["sfb_tensor"].copy_(b_scales) - - inputs["alpha_tensor"].fill_(1.0) - inputs["prob_tensor"].fill_(1.0) - - cute_out = grouped_gemm_quant_wrapper_sm100( - a_tensor=inputs["a_tensor"], - b_tensor=inputs["b_tensor"], - sfa_tensor=inputs["sfa_tensor"], - sfb_tensor=inputs["sfb_tensor"], - padded_offsets=inputs["padded_offsets_tensor"], - alpha_tensor=inputs["alpha_tensor"], - norm_const_tensor=None, - prob_tensor=inputs["prob_tensor"], - acc_dtype=torch.float32, - d_dtype=torch.bfloat16, - cd_major="n", - sf_vec_size=32, - discrete_col_sfd=True, - current_stream=None, - ) - - if isinstance(cute_out, dict): - outputs = cute_out - else: - d_tensor, d_col_tensor, amax_tensor, sfd_row_tensor, sfd_col_tensor = cute_out - outputs = { - "d_tensor": d_tensor, - "d_col_tensor": d_col_tensor, - "amax_tensor": amax_tensor, - "sfd_row_tensor": sfd_row_tensor, - "sfd_col_tensor": sfd_col_tensor, + ) + inputs = { + "a_tensor": torch.empty(1, total_m, k, dtype=torch.float8_e4m3fn, device=device).permute( + 1, 2, 0 + ), + "b_tensor": torch.empty(num_groups, n, k, dtype=torch.float8_e4m3fn, device=device).permute( + 1, 2, 0 + ), + "sfa_tensor": torch.empty( + 1, + total_m // 128, + k // 128, + 32, + 4, + 4, + dtype=torch.float8_e8m0fnu, + device=device, + ).permute(3, 4, 1, 5, 2, 0), + "sfb_tensor": torch.empty( + num_groups, + n // 128, + k // 128, + 32, + 4, + 4, + dtype=torch.float8_e8m0fnu, + device=device, + ).permute(3, 4, 1, 5, 2, 0), + "alpha_tensor": torch.empty(num_groups, dtype=torch.float32, device=device), + "prob_tensor": torch.empty(total_m, 1, 1, dtype=torch.float32, device=device), + "padded_offsets_tensor": padded_offsets, } + # Overwrite inputs with quantized data/scales from MXFP8 quantizer. + a_data = grouped_a.rowwise_data.view(total_m, k).view(dtype=torch.float8_e4m3fn) + a_data = a_data.unsqueeze(0).permute(1, 2, 0).contiguous() + inputs["a_tensor"].copy_(a_data) + + a_scales = grouped_a.scale_inv.view(dtype=torch.float8_e8m0fnu) + a_scales = a_scales.view(1, total_m // 128, 4, 32, k // 128, 4) + a_scales = a_scales.permute(0, 1, 4, 3, 2, 5).contiguous() + a_scales = a_scales.permute(3, 4, 1, 5, 2, 0).contiguous() + inputs["sfa_tensor"].copy_(a_scales) + + b_data = torch.cat([w._rowwise_data.reshape(-1) for w in b_groups]) + b_data = b_data.view(dtype=torch.float8_e4m3fn) + b_data = b_data.view(num_groups, n, k).permute(1, 2, 0).contiguous() + inputs["b_tensor"].copy_(b_data) + + b_scales = torch.cat([w._rowwise_scale_inv for w in b_groups]) + b_scales = b_scales.view(dtype=torch.float8_e8m0fnu) + b_scales = b_scales.view(num_groups, n // 128, 4, 32, k // 128, 4) + b_scales = b_scales.permute(0, 1, 4, 3, 2, 5).contiguous() + b_scales = b_scales.permute(3, 4, 1, 5, 2, 0).contiguous() + inputs["sfb_tensor"].copy_(b_scales) + + inputs["alpha_tensor"].fill_(1.0) + inputs["prob_tensor"].fill_(1.0) + + cute_out = grouped_gemm_quant_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + norm_const_tensor=None, + prob_tensor=inputs["prob_tensor"], + acc_dtype=torch.float32, + d_dtype=torch.bfloat16, + cd_major="n", + sf_vec_size=32, + discrete_col_sfd=True, + current_stream=None, + ) - d_cute = outputs["d_tensor"] - if d_cute.dim() == 3: - d_cute = d_cute.squeeze(-1) - tols = dtype_tols(torch.bfloat16) - assert_close(d_cute[:total_m].float(), ref, **tols) + if isinstance(cute_out, dict): + outputs = cute_out + else: + d_tensor, d_col_tensor, amax_tensor, sfd_row_tensor, sfd_col_tensor = cute_out + outputs = { + "d_tensor": d_tensor, + "d_col_tensor": d_col_tensor, + "amax_tensor": amax_tensor, + "sfd_row_tensor": sfd_row_tensor, + "sfd_col_tensor": sfd_col_tensor, + } + + d_cute = outputs["d_tensor"] + if d_cute.dim() == 3: + d_cute = d_cute.squeeze(-1) + tols = dtype_tols(torch.bfloat16) + assert_close(d_cute[:total_m].float(), ref, **tols) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index d76fa6783a..0b0bd55a38 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -19,14 +19,16 @@ import transformer_engine from transformer_engine.common.recipe import Recipe import math -from transformer_engine.pytorch import InferenceParams, QuantizedTensor -from transformer_engine.pytorch import DType from transformer_engine.pytorch import ( + DType, Float8CurrentScalingQuantizer, Float8Quantizer, + InferenceParams, MXFP8Quantizer, NVFP4Quantizer, + QuantizedTensor, QuantizerRole, + is_bf16_available, ) from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends from transformer_engine.pytorch.attention.dot_product_attention.utils import ( From b40e26e7c4e68eae9ce21c7121d9d2dde8434e5a Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 10 Jun 2026 02:35:42 +0000 Subject: [PATCH 07/11] Refactor grouped linear test utilities Pull repeated grouped-linear test setup into shared helpers for environment toggles, split-size construction, grouped parameter copying, and grad collection. This keeps the module, ops, and grouped-MLP coverage aligned around the same single-grouped-parameter conventions instead of duplicating local ad hoc loops. Also make the grouped tensor path comparison look at the actual single grouped bias parameter when that mode is enabled, matching the idiom already used by the fusible ops tests. Signed-off-by: Tim Moon --- tests/pytorch/test_grouped_linear.py | 687 ++++++++++++++++----------- 1 file changed, 399 insertions(+), 288 deletions(-) diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index 8cece276c5..88f933a1e9 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -381,6 +381,189 @@ def _generate_random_numbers(n, total_sum): _FUSED_GROUPED_GEMM_ENV = "NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM" +def _force_legacy_grouped_linear_path(monkeypatch) -> None: + monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "0") + + +def _enable_fused_grouped_linear_path(monkeypatch) -> None: + monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1") + + +def _enable_single_grouped_param(monkeypatch) -> None: + monkeypatch.setenv("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "1") + + +def _enable_fused_grouped_mlp(monkeypatch) -> None: + monkeypatch.setenv("NVTE_CUTEDSL_FUSED_GROUPED_MLP", "1") + + +def _make_grouped_split_sizes( + group_size: int, + split_alignment: int, + *, + start: int = 0, + dtype: torch.dtype = torch.int, + device: torch.device | str = "cuda", +) -> torch.Tensor: + """Construct shuffled per-group token counts.""" + split_sizes = [split_alignment * (i + start) for i in range(group_size)] + random.shuffle(split_sizes) + return torch.tensor(split_sizes, dtype=dtype, device=device) + + +def _grouped_weight_params( + op: torch.nn.Module, + group_size: int, + *, + single_grouped_weight: bool, +) -> list[torch.nn.Parameter]: + """Extract weight parameters from grouped linear module or op.""" + if single_grouped_weight: + return [op.weight] + return [getattr(op, f"weight{i}") for i in range(group_size)] + + +def _grouped_bias_params( + op: torch.nn.Module, + group_size: int, + *, + single_grouped_bias: bool, +) -> list[torch.nn.Parameter]: + """Extract bias parameters from grouped linear module or op.""" + if single_grouped_bias: + return [op.bias] + return [getattr(op, f"bias{i}") for i in range(group_size)] + + +def _copy_grouped_linear_params( + op: torch.nn.Module, + weights: Sequence[torch.Tensor], + biases: Optional[Sequence[Optional[torch.Tensor]]] = None, + *, + single_grouped_weight: bool = False, + single_grouped_bias: bool = False, +) -> None: + """Copy values into grouped linear params""" + + # Copy into weights + if single_grouped_weight: + weight_parts = op.weight.quantized_tensors + if weight_parts is None: + weight_parts = op.weight.split_into_quantized_tensors() + for dst, src in zip(weight_parts, weights): + dst.copy_(src) + else: + for group_idx, weight in enumerate(weights): + getattr(op, f"weight{group_idx}").copy_(weight) + + # Copy into biases + if biases is None: + pass + elif single_grouped_bias: + bias_parts = op.bias.split_into_quantized_tensors() + for dst, src in zip(bias_parts, biases): + dst.reshape(-1).copy_(src) + else: + for group_idx, bias in enumerate(biases): + getattr(op, f"bias{group_idx}").copy_(bias) + + +def _fill_main_grads( + params: Sequence[torch.nn.Parameter], + value: float, + *, + device: torch.device | str, +) -> None: + """Construct param main_grad if needed and fill with value""" + with torch.no_grad(): + for param in params: + if getattr(param, "main_grad", None) is None: + param.main_grad = torch.empty(param.size(), device=device, dtype=torch.float32) + param.main_grad.fill_(value) + + +def _clone_grads(params: Sequence[torch.nn.Parameter]) -> list[torch.Tensor]: + return [param.grad.detach().clone() for param in params] + + +def _clone_main_grads(params: Sequence[torch.nn.Parameter]) -> list[torch.Tensor]: + return [param.main_grad.detach().clone() for param in params] + + +def _stack_cloned_attr(params: Sequence[torch.nn.Parameter], attr: str) -> torch.Tensor: + values = [getattr(param, attr).detach().clone() for param in params] + if len(values) == 1: + return values[0] + return torch.stack(values, dim=0) + + +def _make_scaled_grouped_mlp_activation( + activation: str, + *, + glu_interleave_size: Optional[int], + geglu_limit: float = 7.0, + geglu_alpha: float = 1.702, + geglu_offset: float = 1.0, +) -> torch.nn.Module: + if activation == "scaled_swiglu": + return te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_clamped_qgeglu_custom": + return te.ops.ScaledClampedQGeGLU( + glu_interleave_size=glu_interleave_size, + limit=geglu_limit, + alpha=geglu_alpha, + glu_linear_offset=geglu_offset, + ) + if activation.startswith("scaled_clamped_qgeglu"): + return te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_srelu": + return te.ops.ScaledSReLU() + raise ValueError(f"Unexpected activation ({activation})") + + +def _skip_invalid_grouped_mlp_case( + *, + activation: str, + activation_is_glu: bool, + bias: bool, + dtype: torch.dtype, + quantization: Optional[str], + single_grouped_weight: bool, + single_grouped_bias: bool, + glu_interleave_size: Optional[int], + device: torch.device | str, +) -> None: + with_quantization = quantization is not None + maybe_skip_quantization(quantization, device=device, dtype=dtype) + if single_grouped_weight and quantization != "mxfp8": + pytest.skip("single_grouped_weight is only supported for MXFP8 quantization") + if single_grouped_bias and not bias: + pytest.skip("single_grouped_bias requires bias=True") + if with_quantization and dtype not in (torch.bfloat16, torch.float16): + pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if not activation_is_glu and quantization not in ("mxfp8", "nvfp4", "nvfp4_rht"): + pytest.skip("Scaled unary grouped MLP is only supported with MXFP8 or NVFP4") + if not activation_is_glu and glu_interleave_size is not None: + pytest.skip("Unary activations do not use GLU interleaving") + if quantization == "nvfp4_4over6": + pytest.skip("NVFP4 4over6 grouped quantization is not supported") + if activation == "scaled_srelu" and quantization in ("nvfp4", "nvfp4_rht") and bias: + pytest.skip("NVFP4 SReLU grouped MLP coverage is limited to no-bias") + if quantization == "nvfp4_rht": + if activation == "scaled_swiglu" and (bias or glu_interleave_size != 32): + pytest.skip("NVFP4 RHT SwiGLU grouped MLP coverage is limited to no-bias") + if activation not in ("scaled_swiglu", "scaled_srelu"): + pytest.skip("NVFP4 RHT grouped MLP coverage is limited to SwiGLU and SReLU") + if ( + with_quantization + and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") + and activation.startswith("scaled_clamped_qgeglu") + and bias + ): + # TODO: ksivaman: Need to debug numerics for this case. + pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") + + class TestGroupedLinearModule: """Tests for te.GroupedLinear module API. @@ -388,11 +571,15 @@ class TestGroupedLinearModule: """ @pytest.fixture(autouse=True) - def _set_fused_grouped_gemm_env(self, monkeypatch): - monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "0") + def _use_legacy_grouped_linear_path(self, monkeypatch): + _force_legacy_grouped_linear_path(monkeypatch) yield monkeypatch.delenv(_FUSED_GROUPED_GEMM_ENV, raising=False) + @pytest.fixture(autouse=True) + def _use_single_grouped_param(self, monkeypatch): + _enable_single_grouped_param(monkeypatch) + @pytest.mark.parametrize("dtype", param_types, ids=str) @pytest.mark.parametrize("num_gemms", [1, 3, 6]) @@ -773,10 +960,12 @@ def _run_grouped_linear_path( single_grouped_bias=single_grouped_bias, ) with torch.no_grad(): - for i in range(num_gemms): - getattr(grouped_linear, f"weight{i}").copy_(weights[i]) - if bias: - getattr(grouped_linear, f"bias{i}").copy_(biases[i]) + _copy_grouped_linear_params( + grouped_linear, + weights, + biases if bias else None, + single_grouped_bias=single_grouped_bias, + ) # The fused path is the graph-safe path and accepts a CUDA tensor for split metadata. # The legacy path still expects Python split sections in several places. @@ -792,10 +981,12 @@ def _run_grouped_linear_path( grouped_linear.backward_dw() outputs = [y, x.grad] - for i in range(num_gemms): - outputs.append(getattr(grouped_linear, f"weight{i}").grad) - if bias: - outputs.append(getattr(grouped_linear, f"bias{i}").grad) + outputs.extend(getattr(grouped_linear, f"weight{i}").grad for i in range(num_gemms)) + if bias: + if single_grouped_bias: + outputs.append(grouped_linear.bias.grad) + else: + outputs.extend(getattr(grouped_linear, f"bias{i}").grad for i in range(num_gemms)) return _clone_outputs(outputs) @@ -882,7 +1073,7 @@ def test_grouped_linear_grouped_tensor_path_matches_legacy( for grouped_tensor_out, legacy_out in zip(outputs_grouped_tensor, outputs_legacy): assert grouped_tensor_out is not None assert legacy_out is not None - torch.testing.assert_close(grouped_tensor_out.float(), legacy_out.float(), **tols) + assert_close(grouped_tensor_out, legacy_out, **tols) @pytest.mark.parametrize( @@ -902,7 +1093,7 @@ def test_grouped_linear_fused_path_cuda_graph_safe(self, fp8_recipe, bias, monke if torch.cuda.get_device_capability() < (10, 0): pytest.skip("GroupedTensor grouped GEMM path requires SM100+") - monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1") + _enable_fused_grouped_linear_path(monkeypatch) FP8GlobalStateManager.reset() use_fp8 = fp8_recipe is not None @@ -991,11 +1182,11 @@ def _train_step(x, dy, out_buf, *, use_graphed): tols = dict(rtol=1e-2, atol=5e-3) if use_fp8: tols = dict(rtol=0.05, atol=0.05) - torch.testing.assert_close(graph_out.float(), expected_out.float(), **tols) - torch.testing.assert_close(graph_dx.float(), expected_x.grad.float(), **tols) + assert_close(graph_out, expected_out, **tols) + assert_close(graph_dx, expected_x.grad, **tols) for graph_grad, param in zip(graph_param_grads, grouped_linear.parameters()): assert param.grad is not None - torch.testing.assert_close(graph_grad.float(), param.grad.float(), **tols) + assert_close(graph_grad, param.grad, **tols) def _clone_outputs(outputs): @@ -1728,8 +1919,8 @@ class TestGroupedLinearOps: """Tests for te.ops.GroupedLinear (ops/fuser API).""" @pytest.fixture(autouse=True) - def _set_single_param_env(self, monkeypatch): - monkeypatch.setenv("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "1") + def _use_single_grouped_param(self, monkeypatch): + _enable_single_grouped_param(monkeypatch) @pytest.mark.parametrize("swizzle_type", ["mxfp8_rowwise", "mxfp8_columnwise", "nvfp4"]) def test_swizzle_scales_and_pack_ptrs_for_discrete_weights( @@ -1891,9 +2082,13 @@ def test_grouped_linear_cuda_graph_safe( pytest.skip("single_grouped_bias requires bias=True") # Split sizes (statically pinned for graph capture) - split_sizes = [split_alignment * (i + 1) for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + split_sizes = _make_grouped_split_sizes( + group_size, + split_alignment, + start=1, + dtype=torch.int, + device=device, + ) # Pad input tokens to validate the sync-free flow in_shape = (split_sizes.sum().item() + token_padding, in_features) out_shape = (in_shape[0], out_features) @@ -1911,30 +2106,19 @@ def test_grouped_linear_cuda_graph_safe( single_grouped_weight=single_grouped_weight, single_grouped_bias=single_grouped_bias, ) - - def _weight_params() -> list[torch.nn.Parameter]: - if single_grouped_weight: - return [op.weight] - return [getattr(op, f"weight{i}") for i in range(group_size)] - - def _bias_params() -> list[torch.nn.Parameter]: - if not bias: - return [] - if single_grouped_bias: - return [op.bias] - return [getattr(op, f"bias{i}") for i in range(group_size)] + weight_params = _grouped_weight_params( + op, + group_size, + single_grouped_weight=single_grouped_weight, + ) def _init_main_grads(value: float = 0.0) -> None: if not accumulate_into_main_grad: return - with torch.no_grad(): - for w in _weight_params(): - if getattr(w, "main_grad", None) is None: - w.main_grad = torch.empty(w.size(), device=device, dtype=torch.float32) - w.main_grad.fill_(value) + _fill_main_grads(weight_params, value, device=device) def _collect_main_grads() -> list[torch.Tensor]: - return [w.main_grad.detach().clone() for w in _weight_params()] + return _clone_main_grads(weight_params) def _zero_param_grads() -> None: for param in op.parameters(): @@ -2022,7 +2206,7 @@ def train_step( assert_close(graph_out, expected_out, **tols) assert_close(graph_dx, expected_x.grad, **tols) if accumulate_into_main_grad: - for g, w in zip(graph_main_grads, _weight_params()): + for g, w in zip(graph_main_grads, weight_params): assert_close(g, w.main_grad, **tols) else: for g, param in zip(graph_param_grads, op.parameters()): @@ -2039,7 +2223,7 @@ def train_step( @pytest.mark.parametrize("quantized_weight", (False, True)) @pytest.mark.parametrize("input_requires_grad", (False, True)) @pytest.mark.parametrize("weight_requires_grad", (False, True)) - def test_ops_grouped_linear( + def test_grouped_linear( self, *, group_size: int = 4, @@ -2060,9 +2244,12 @@ def test_ops_grouped_linear( """te.ops.GroupedLinear forward+backward accuracy""" # Split sizes - split_sizes = [split_alignment * i for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + split_sizes = _make_grouped_split_sizes( + group_size, + split_alignment, + dtype=torch.int, + device=device, + ) # Make input and weight shapes consistent out_features, in_features = weight_shape @@ -2153,22 +2340,13 @@ def test_ops_grouped_linear( single_grouped_bias=single_grouped_bias, ) with torch.no_grad(): - if single_grouped_weight: - op_weights = op.weight.quantized_tensors - if op_weights is None: - op_weights = op.weight.split_into_quantized_tensors() - if single_grouped_bias: - op_bias_parts = op.bias.split_into_quantized_tensors() - for group_idx in range(group_size): - if single_grouped_weight: - op_weights[group_idx].copy_(ws_test[group_idx]) - else: - getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx]) - if bias: - if single_grouped_bias: - op_bias_parts[group_idx].reshape(-1).copy_(bs_test[group_idx]) - else: - getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx]) + _copy_grouped_linear_params( + op, + ws_test, + bs_test if bias else None, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + ) del ws_test, bs_test for param in op.parameters(): param.requires_grad_(requires_grad=weight_requires_grad) @@ -2218,9 +2396,9 @@ class TestGroupedMLP: """Tests for grouped MLP patterns (te.ops.GroupedLinear + activation).""" @pytest.fixture(autouse=True) - def _set_envvars(self, monkeypatch): - monkeypatch.setenv("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "1") - monkeypatch.setenv("NVTE_CUTEDSL_FUSED_GROUPED_MLP", "1") + def _enable_grouped_mlp_envvars(self, monkeypatch): + _enable_single_grouped_param(monkeypatch) + _enable_fused_grouped_mlp(monkeypatch) @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", (torch.float32, torch.float16, torch.bfloat16)) @@ -2253,50 +2431,24 @@ def test_grouped_mlp( activation: str, ) -> None: """GroupedLinear + scaled activation + GroupedLinear""" - if dtype == torch.bfloat16 and not is_bf16_available(): - pytest.skip("BF16 not available") - - # Build activation op to determine GLU vs unary - if activation == "scaled_swiglu": - scaled_act_ref = te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - elif activation.startswith("scaled_clamped_qgeglu"): - scaled_act_ref = te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) - elif activation == "scaled_srelu": - scaled_act_ref = te.ops.ScaledSReLU() - else: - raise ValueError(f"Unexpected activation ({activation})") + scaled_act_ref = _make_scaled_grouped_mlp_activation( + activation, + glu_interleave_size=glu_interleave_size, + ) activation_is_glu = is_glu_activation(scaled_act_ref) - # Skip invalid configurations + _skip_invalid_grouped_mlp_case( + activation=activation, + activation_is_glu=activation_is_glu, + bias=bias, + dtype=dtype, + quantization=quantization, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + glu_interleave_size=glu_interleave_size, + device=device, + ) with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device, dtype=dtype) - if single_grouped_weight and quantization != "mxfp8": - pytest.skip("single_grouped_weight is only supported for MXFP8 quantization") - if single_grouped_bias and not bias: - pytest.skip("single_grouped_bias requires bias=True") - if with_quantization and dtype not in (torch.bfloat16, torch.float16): - pytest.skip("Quantized group GEMM is only supported with BF16/FP16") - if not activation_is_glu and quantization not in ("mxfp8", "nvfp4", "nvfp4_rht"): - pytest.skip("Scaled unary grouped MLP is only supported with MXFP8 or NVFP4") - if not activation_is_glu and glu_interleave_size is not None: - pytest.skip("Unary activations do not use GLU interleaving") - if quantization == "nvfp4_4over6": - pytest.skip("NVFP4 4over6 grouped quantization is not supported") - if activation == "scaled_srelu" and quantization in ("nvfp4", "nvfp4_rht") and bias: - pytest.skip("NVFP4 SReLU grouped MLP coverage is limited to no-bias") - if quantization == "nvfp4_rht": - if activation == "scaled_swiglu" and (bias or glu_interleave_size != 32): - pytest.skip("NVFP4 RHT SwiGLU grouped MLP coverage is limited to no-bias") - if activation not in ("scaled_swiglu", "scaled_srelu"): - pytest.skip("NVFP4 RHT grouped MLP coverage is limited to SwiGLU and SReLU") - if ( - with_quantization - and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") - and activation.startswith("scaled_clamped_qgeglu") - and bias - ): - # TODO: ksivaman: Need to debug numerics for this case. - pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size if activation == "scaled_clamped_qgeglu_custom": @@ -2305,9 +2457,12 @@ def test_grouped_mlp( geglu_limit, geglu_alpha, geglu_offset = 7.0, 1.702, 1.0 # Split sizes (one group intentionally empty to test the zero-token case) - split_sizes = [split_alignment * i for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + split_sizes = _make_grouped_split_sizes( + group_size, + split_alignment, + dtype=torch.int, + device=device, + ) in_shape = (split_sizes.sum().item(), hidden_size) out_shape = in_shape @@ -2416,22 +2571,6 @@ def _apply_activation(x: torch.Tensor) -> torch.Tensor: # Construct TE module recipe = make_recipe(quantization) - def _make_scaled_act(): - if activation == "scaled_swiglu": - return te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - if activation == "scaled_clamped_qgeglu_custom": - return te.ops.ScaledClampedQGeGLU( - glu_interleave_size=glu_interleave_size, - limit=geglu_limit, - alpha=geglu_alpha, - glu_linear_offset=geglu_offset, - ) - if activation.startswith("scaled_clamped_qgeglu"): - return te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) - if activation == "scaled_srelu": - return te.ops.ScaledSReLU() - raise ValueError(f"Unexpected activation ({activation})") - with te.quantized_model_init(enabled=with_quantization, recipe=recipe): fc1 = te.ops.GroupedLinear( group_size, hidden_size, fc1_out_features, @@ -2450,41 +2589,48 @@ def _make_scaled_act(): delay_wgrad_compute=delay_wgrad_compute, scale_bias=bias, ) - module = te.ops.Sequential(fc1, _make_scaled_act(), fc2) + module = te.ops.Sequential( + fc1, + _make_scaled_grouped_mlp_activation( + activation, + glu_interleave_size=glu_interleave_size, + geglu_limit=geglu_limit, + geglu_alpha=geglu_alpha, + geglu_offset=geglu_offset, + ), + fc2, + ) # Copy weights with torch.no_grad(): - if single_grouped_weight: - fc1_weights = fc1.weight.quantized_tensors - if fc1_weights is None: - fc1_weights = fc1.weight.split_into_quantized_tensors() - fc2_weights = fc2.weight.quantized_tensors - if fc2_weights is None: - fc2_weights = fc2.weight.split_into_quantized_tensors() - for group_idx in range(group_size): - if single_grouped_weight: - fc1_weights[group_idx].copy_(fc1_ws_test[group_idx]) - fc2_weights[group_idx].copy_(fc2_ws_test[group_idx]) - else: - getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) - getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) - if bias: - if single_grouped_bias: - fc1_bparts = fc1.bias.split_into_quantized_tensors() - fc2_bparts = fc2.bias.split_into_quantized_tensors() - fc1_bparts[group_idx].reshape(-1).copy_(fc1_bs_test[group_idx]) - fc2_bparts[group_idx].reshape(-1).copy_(fc2_bs_test[group_idx]) - else: - getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) - getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) + _copy_grouped_linear_params( + fc1, + fc1_ws_test, + fc1_bs_test if bias else None, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + ) + _copy_grouped_linear_params( + fc2, + fc2_ws_test, + fc2_bs_test if bias else None, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + ) if accumulate_into_main_grad: main_grad_sentinel = 0.5 - if single_grouped_weight: - weight_params_for_main_grad = [fc1.weight, fc2.weight] - else: - weight_params_for_main_grad = [ - getattr(fc, f"weight{i}") for fc in (fc1, fc2) for i in range(group_size) - ] + weight_params_for_main_grad = ( + _grouped_weight_params( + fc1, + group_size, + single_grouped_weight=single_grouped_weight, + ) + + _grouped_weight_params( + fc2, + group_size, + single_grouped_weight=single_grouped_weight, + ) + ) MegatronTrainingHelper.init_main_grad_buffers( weight_params_for_main_grad, fill_value=main_grad_sentinel, @@ -2605,9 +2751,13 @@ def test_grouped_mlp_single_weight_numerics( if not te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") - split_sizes = [split_alignment * (i + 1) for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) + split_sizes = _make_grouped_split_sizes( + group_size, + split_alignment, + start=1, + dtype=torch.int64, + device=device, + ) in_shape = (split_sizes.sum().item(), hidden_size) recipe = make_recipe("mxfp8") @@ -2645,10 +2795,9 @@ def test_grouped_mlp_single_weight_numerics( def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]: with te.quantized_model_init(enabled=True, recipe=recipe): - scaled_act = ( - te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - if activation == "scaled_swiglu" - else te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + scaled_act = _make_scaled_grouped_mlp_activation( + activation, + glu_interleave_size=glu_interleave_size, ) fc1 = te.ops.GroupedLinear( group_size, @@ -2672,23 +2821,18 @@ def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]: module = te.ops.Sequential(fc1, scaled_act, fc2) with torch.no_grad(): - if single_grouped_weight: - fc1_weights = fc1.weight.quantized_tensors - if fc1_weights is None: - fc1_weights = fc1.weight.split_into_quantized_tensors() - fc2_weights = fc2.weight.quantized_tensors - if fc2_weights is None: - fc2_weights = fc2.weight.split_into_quantized_tensors() - for group_idx in range(group_size): - if single_grouped_weight: - fc1_weights[group_idx].copy_(fc1_ws_base[group_idx]) - fc2_weights[group_idx].copy_(fc2_ws_base[group_idx]) - else: - getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_base[group_idx]) - getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_base[group_idx]) - if bias: - getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_base[group_idx]) - getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_base[group_idx]) + _copy_grouped_linear_params( + fc1, + fc1_ws_base, + fc1_bs_base if bias else None, + single_grouped_weight=single_grouped_weight, + ) + _copy_grouped_linear_params( + fc2, + fc2_ws_base, + fc2_bs_base if bias else None, + single_grouped_weight=single_grouped_weight, + ) x = x_base.detach().clone().requires_grad_(True) probs = probs_base.detach().clone().requires_grad_(True) @@ -2712,41 +2856,41 @@ def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]: te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, ) - if single_grouped_weight: - fc1_dw = fc1.weight.grad.detach().clone() - fc2_dw = fc2.weight.grad.detach().clone() - else: - fc1_dw = torch.stack( - [ - getattr(fc1, f"weight{group_idx}").grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, - ) - fc2_dw = torch.stack( - [ - getattr(fc2, f"weight{group_idx}").grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, - ) + fc1_dw = _stack_cloned_attr( + _grouped_weight_params( + fc1, + group_size, + single_grouped_weight=single_grouped_weight, + ), + "grad", + ) + fc2_dw = _stack_cloned_attr( + _grouped_weight_params( + fc2, + group_size, + single_grouped_weight=single_grouped_weight, + ), + "grad", + ) fc1_db = None fc2_db = None if bias: - fc1_db = torch.stack( - [ - getattr(fc1, f"bias{group_idx}").grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, + fc1_db = _stack_cloned_attr( + _grouped_bias_params( + fc1, + group_size, + single_grouped_bias=False, + ), + "grad", ) - fc2_db = torch.stack( - [ - getattr(fc2, f"bias{group_idx}").grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, + fc2_db = _stack_cloned_attr( + _grouped_bias_params( + fc2, + group_size, + single_grouped_bias=False, + ), + "grad", ) return ( @@ -2829,9 +2973,13 @@ def test_grouped_mlp_overwrite_main_grad( pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") recipe = make_recipe("mxfp8") - split_sizes = [split_alignment * (i + 1) for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) + split_sizes = _make_grouped_split_sizes( + group_size, + split_alignment, + start=1, + dtype=torch.int64, + device=device, + ) in_shape = (split_sizes.sum().item(), hidden_size) x_base = torch.empty(in_shape, device=device, dtype=dtype).uniform_(-0.25, 0.25) probs_base = torch.empty((in_shape[0],), device=device, dtype=dtype).uniform_(-0.25, 0.25) @@ -2877,26 +3025,24 @@ def _build_module(*, accumulate_into_main_grad: bool): module = te.ops.Sequential(fc1, scaled_act, fc2) with torch.no_grad(): - if single_grouped_weight: - fc1_weights = ( - fc1.weight.quantized_tensors or fc1.weight.split_into_quantized_tensors() - ) - fc2_weights = ( - fc2.weight.quantized_tensors or fc2.weight.split_into_quantized_tensors() - ) - for group_idx in range(group_size): - fc1_weights[group_idx].copy_(fc1_ws_base[group_idx]) - fc2_weights[group_idx].copy_(fc2_ws_base[group_idx]) - else: - for group_idx in range(group_size): - getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_base[group_idx]) - getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_base[group_idx]) + _copy_grouped_linear_params( + fc1, + fc1_ws_base, + single_grouped_weight=single_grouped_weight, + ) + _copy_grouped_linear_params( + fc2, + fc2_ws_base, + single_grouped_weight=single_grouped_weight, + ) return module, fc1, fc2 def _weight_params(fc): - if single_grouped_weight: - return [fc.weight] - return [getattr(fc, f"weight{i}") for i in range(group_size)] + return _grouped_weight_params( + fc, + group_size, + single_grouped_weight=single_grouped_weight, + ) def _run_backward(module, fc1, fc2): x = x_base.detach().clone().requires_grad_(True) @@ -2911,8 +3057,8 @@ def _run_backward(module, fc1, fc2): # Reference run: vanilla autograd, no Megatron protocol. ref_module, ref_fc1, ref_fc2 = _build_module(accumulate_into_main_grad=False) _run_backward(ref_module, ref_fc1, ref_fc2) - ref_fc1_grads = [wp.grad.detach().clone() for wp in _weight_params(ref_fc1)] - ref_fc2_grads = [wp.grad.detach().clone() for wp in _weight_params(ref_fc2)] + ref_fc1_grads = _clone_grads(_weight_params(ref_fc1)) + ref_fc2_grads = _clone_grads(_weight_params(ref_fc2)) # Test run: main_grad fusion with overwrite_main_grad=True (MegatronFSDP). # NaN sentinel makes a missed write loud (would surface as NaN diff). @@ -2963,9 +3109,13 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8( if dtype not in (torch.bfloat16, torch.float16): pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16") - split_sizes = [split_alignment * (i + 1) for i in range(group_size)] - random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) + split_sizes = _make_grouped_split_sizes( + group_size, + split_alignment, + start=1, + dtype=torch.int64, + device=device, + ) # Pad the input tokens to validate the sync-free MOE in_shape = (split_sizes.sum().item() + token_padding, hidden_size) recipe = make_recipe("mxfp8") @@ -2990,75 +3140,36 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8( single_grouped_weight=single_grouped_weight, accumulate_into_main_grad=accumulate_into_main_grad, ) - scaled_act = ( - te.ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) - if activation == "scaled_swiglu" - else te.ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + scaled_act = _make_scaled_grouped_mlp_activation( + activation, + glu_interleave_size=glu_interleave_size, ) module = te.ops.Sequential( fc1, scaled_act, fc2, ) + fc1_weight_params = _grouped_weight_params( + fc1, + group_size, + single_grouped_weight=single_grouped_weight, + ) + fc2_weight_params = _grouped_weight_params( + fc2, + group_size, + single_grouped_weight=single_grouped_weight, + ) def _init_main_grads(value: float = 0.0) -> None: if not accumulate_into_main_grad: return - with torch.no_grad(): - if single_grouped_weight: - if getattr(fc1.weight, "main_grad", None) is None: - fc1.weight.main_grad = torch.empty( - fc1.weight.size(), - device=device, - dtype=torch.float32, - ) - if getattr(fc2.weight, "main_grad", None) is None: - fc2.weight.main_grad = torch.empty( - fc2.weight.size(), - device=device, - dtype=torch.float32, - ) - fc1.weight.main_grad.fill_(value) - fc2.weight.main_grad.fill_(value) - else: - for group_idx in range(group_size): - fc1_weight = getattr(fc1, f"weight{group_idx}") - fc2_weight = getattr(fc2, f"weight{group_idx}") - if getattr(fc1_weight, "main_grad", None) is None: - fc1_weight.main_grad = torch.empty( - fc1_weight.size(), - device=device, - dtype=torch.float32, - ) - if getattr(fc2_weight, "main_grad", None) is None: - fc2_weight.main_grad = torch.empty( - fc2_weight.size(), - device=device, - dtype=torch.float32, - ) - fc1_weight.main_grad.fill_(value) - fc2_weight.main_grad.fill_(value) + _fill_main_grads(fc1_weight_params + fc2_weight_params, value, device=device) def _collect_main_grads() -> tuple[torch.Tensor, torch.Tensor]: - if single_grouped_weight: - fc1_main_grad = fc1.weight.main_grad.detach().clone() - fc2_main_grad = fc2.weight.main_grad.detach().clone() - else: - fc1_main_grad = torch.stack( - [ - getattr(fc1, f"weight{group_idx}").main_grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, - ) - fc2_main_grad = torch.stack( - [ - getattr(fc2, f"weight{group_idx}").main_grad.detach().clone() - for group_idx in range(group_size) - ], - dim=0, - ) - return fc1_main_grad, fc2_main_grad + return ( + _stack_cloned_attr(fc1_weight_params, "main_grad"), + _stack_cloned_attr(fc2_weight_params, "main_grad"), + ) static_split_sizes = split_sizes.clone() From f50d517597dc182919d95a8cc1e7f5f4da488871 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 10 Jun 2026 03:27:03 +0000 Subject: [PATCH 08/11] Remove te.GroupedLinear test case with discrete weight and grouped bias Test was inadvertently disabled, and enabling triggered test failures beyond the scope of this PR. Grouped tensor params are still highly experimental, and it was quite strange to only test grouped bias without grouped weights. Signed-off-by: Tim Moon --- qa/L0_pytorch_unittest/test.sh | 2 +- tests/pytorch/test_grouped_linear.py | 15 ++------------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 93490d6f81..3e7a76c617 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -29,7 +29,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_P python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_custom_recipe.xml $TE_PATH/tests/pytorch/test_custom_recipe.py || test_fail "test_custom_recipe.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_grouped_linear.xml $TE_PATH/tests/pytorch/test_grouped_linear.py || test_fail "test_grouped_linear.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_grouped_linear.xml $TE_PATH/tests/pytorch/test_grouped_linear.py || test_fail "test_grouped_linear.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index 88f933a1e9..82c30fa6f9 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -930,7 +930,6 @@ def _run_grouped_linear_path( bias: bool, fp8_model_params: bool, delay_wgrad_compute: bool, - single_grouped_bias: bool = False, x_base: torch.Tensor, dy: torch.Tensor, weights, @@ -957,14 +956,12 @@ def _run_grouped_linear_path( params_dtype=dtype, device="cuda", delay_wgrad_compute=delay_wgrad_compute, - single_grouped_bias=single_grouped_bias, ) with torch.no_grad(): _copy_grouped_linear_params( grouped_linear, weights, biases if bias else None, - single_grouped_bias=single_grouped_bias, ) # The fused path is the graph-safe path and accepts a CUDA tensor for split metadata. @@ -983,10 +980,7 @@ def _run_grouped_linear_path( outputs = [y, x.grad] outputs.extend(getattr(grouped_linear, f"weight{i}").grad for i in range(num_gemms)) if bias: - if single_grouped_bias: - outputs.append(grouped_linear.bias.grad) - else: - outputs.extend(getattr(grouped_linear, f"bias{i}").grad for i in range(num_gemms)) + outputs.extend(getattr(grouped_linear, f"bias{i}").grad for i in range(num_gemms)) return _clone_outputs(outputs) @@ -1001,13 +995,12 @@ def _run_grouped_linear_path( ], ids=["bf16", "mxfp8"], ) - @pytest.mark.parametrize("single_grouped_bias", all_boolean) @pytest.mark.parametrize("bias", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("delay_wgrad_compute", all_boolean) def test_grouped_linear_grouped_tensor_path_matches_legacy( self, - fp8_recipe, bias, fp8_model_params, delay_wgrad_compute, single_grouped_bias, monkeypatch + fp8_recipe, bias, fp8_model_params, delay_wgrad_compute, monkeypatch ): if torch.cuda.get_device_capability() < (10, 0): pytest.skip("GroupedTensor grouped GEMM path requires SM100+") @@ -1015,8 +1008,6 @@ def test_grouped_linear_grouped_tensor_path_matches_legacy( use_fp8 = fp8_recipe is not None if fp8_model_params and not use_fp8: pytest.skip("fp8_model_params requires FP8") - if single_grouped_bias and not bias: - pytest.skip("single_grouped_bias requires bias=True") dtype = torch.bfloat16 num_gemms = 3 @@ -1044,7 +1035,6 @@ def test_grouped_linear_grouped_tensor_path_matches_legacy( bias=bias, fp8_model_params=fp8_model_params, delay_wgrad_compute=delay_wgrad_compute, - single_grouped_bias=single_grouped_bias, x_base=x_base, dy=dy, weights=weights, @@ -1058,7 +1048,6 @@ def test_grouped_linear_grouped_tensor_path_matches_legacy( bias=bias, fp8_model_params=fp8_model_params, delay_wgrad_compute=delay_wgrad_compute, - single_grouped_bias=single_grouped_bias, x_base=x_base, dy=dy, weights=weights, From 4368cb97fad56741756d9b3c2c33abfa654c1724 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 10 Jun 2026 04:24:04 +0000 Subject: [PATCH 09/11] Reduce grouped MLP test cases Prioritize configs that trigger op fusion. Separate parametrized cases for advanced Mcore integrations. Signed-off-by: Tim Moon --- tests/pytorch/test_grouped_linear.py | 109 ++++++++++++++++----------- 1 file changed, 63 insertions(+), 46 deletions(-) diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index 82c30fa6f9..88eb7f1c81 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -166,21 +166,13 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> if torch.cuda.get_device_capability() == (9, 0): use_cutlass_grouped_gemm.append(True) -_grouped_mlp_quantization_list: list = [None] +_quantization_list: list = [None] if fp8_available: - _grouped_mlp_quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) + _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: - _grouped_mlp_quantization_list.append("mxfp8") + _quantization_list.append("mxfp8") if nvfp4_available: - _grouped_mlp_quantization_list.extend(("nvfp4", "nvfp4_4over6", "nvfp4_rht")) - -_ops_quantization_list: list = [None] -if fp8_available: - _ops_quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) -if mxfp8_available: - _ops_quantization_list.append("mxfp8") -if nvfp4_available: - _ops_quantization_list.extend(("nvfp4", "nvfp4_4over6")) + _quantization_list.extend(("nvfp4", "nvfp4_4over6", "nvfp4_rht")) class TorchGroupedLinearWithPadding(nn.Module): @@ -2207,7 +2199,7 @@ def train_step( @pytest.mark.parametrize("single_grouped_bias", (False, True)) @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", param_types, ids=str) - @pytest.mark.parametrize("quantization", _ops_quantization_list) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_compute", (False, True)) @pytest.mark.parametrize("quantized_weight", (False, True)) @pytest.mark.parametrize("input_requires_grad", (False, True)) @@ -2390,14 +2382,8 @@ def _enable_grouped_mlp_envvars(self, monkeypatch): _enable_fused_grouped_mlp(monkeypatch) @pytest.mark.parametrize("bias", (False, True)) - @pytest.mark.parametrize("dtype", (torch.float32, torch.float16, torch.bfloat16)) - @pytest.mark.parametrize("quantization", _grouped_mlp_quantization_list) - @pytest.mark.parametrize("glu_interleave_size", (None, 32)) + @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("single_grouped_weight", (False, True)) - @pytest.mark.parametrize("single_grouped_bias", (False, True)) - @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) - @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) - @pytest.mark.parametrize("hidden_size", (128, 256)) @pytest.mark.parametrize( "activation", ("scaled_swiglu", "scaled_clamped_qgeglu", "scaled_clamped_qgeglu_custom", "scaled_srelu"), @@ -2406,25 +2392,27 @@ def test_grouped_mlp( self, *, group_size: int = 4, - hidden_size: int, + hidden_size: int = 256, bias: bool, - dtype: torch.dtype, + dtype: torch.dtype = torch.bfloat16, quantization: Optional[str], single_grouped_weight: bool, - single_grouped_bias: bool, - accumulate_into_main_grad: bool, + accumulate_into_main_grad: bool = False, device: torch.device = "cuda", split_alignment: int = 256, - glu_interleave_size: Optional[int], - delay_wgrad_compute: bool, + delay_wgrad_compute: bool = False, activation: str, ) -> None: """GroupedLinear + scaled activation + GroupedLinear""" - scaled_act_ref = _make_scaled_grouped_mlp_activation( - activation, - glu_interleave_size=glu_interleave_size, + + # Grouped MLP fused op requires GLU interleaving + activation_is_glu = activation in ( + "scaled_swiglu", "scaled_clamped_qgeglu", "scaled_clamped_qgeglu_custom" ) - activation_is_glu = is_glu_activation(scaled_act_ref) + glu_interleave_size = 32 if activation_is_glu else None + + # Enable grouped bias if weights are grouped + single_grouped_bias = bias and single_grouped_weight _skip_invalid_grouped_mlp_case( activation=activation, @@ -2526,7 +2514,7 @@ def test_grouped_mlp( fc2_bs_test.append(None) def _apply_activation(x: torch.Tensor) -> torch.Tensor: - if activation_is_glu and glu_interleave_size is not None: + if glu_interleave_size is not None: x = x.reshape(-1, 2 * hidden_size // (2 * glu_interleave_size), 2, glu_interleave_size) x = x.transpose(1, 2).reshape(-1, 2 * hidden_size) if activation == "scaled_swiglu": @@ -2636,29 +2624,31 @@ def _apply_activation(x: torch.Tensor) -> torch.Tensor: fc1.backward_dw() fc2.backward_dw() - # Check for expected fusions - cudnn_frontend_supports_grouped_mlp = ( - _cudnn_frontend_supports_grouped_gemm_srelu() - if activation == "scaled_srelu" - else _cudnn_frontend_version_supported() - ) - expected_grouped_mlp_fusion = cudnn_frontend_supports_grouped_mlp and ( - ( - quantization == "mxfp8" - and dtype in (torch.bfloat16, torch.float16) + # Determine whether op fusion is expected + is_fusion_expected = False + if quantization == "mxfp8": + is_fusion_expected = ( + dtype in (torch.bfloat16, torch.float16) and ( (not activation_is_glu and glu_interleave_size is None) or (activation_is_glu and glu_interleave_size == 32) ) ) - or ( - quantization == "nvfp4_rht" - and dtype == torch.bfloat16 + if quantization == "nvfp4_rht": + is_fusion_expected = ( + dtype == torch.bfloat16 and activation == "scaled_srelu" and glu_interleave_size is None ) - ) - if expected_grouped_mlp_fusion: + if is_fusion_expected: + is_fusion_expected = ( + _cudnn_frontend_supports_grouped_gemm_srelu() + if activation == "scaled_srelu" + else _cudnn_frontend_version_supported() + ) + + # Check that fusion is applied if expected + if is_fusion_expected: if activation_is_glu: forward_cls = te.ops.fused.ForwardGroupedMLP_CuTeGEMMGLU backward_cls = te.ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU @@ -2716,6 +2706,33 @@ def _apply_activation(x: torch.Tensor) -> torch.Tensor: assert_close(fc1.weight.grad, fc1_w_ref_grad, **tols) assert_close(fc2.weight.grad, fc2_w_ref_grad, **tols) + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("quantization", ("mxfp8", "nvfp4_rht")) + @pytest.mark.parametrize("single_grouped_weight", (False, True)) + @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) + @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) + @pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_srelu")) + def test_grouped_mlp_mcore_integrations( + self, + *, + bias: bool, + quantization: Optional[str], + single_grouped_weight: bool, + accumulate_into_main_grad: bool, + delay_wgrad_compute: bool, + activation: str, + ) -> None: + """Grouped MLP with advanced Mcore integrations""" + if not (accumulate_into_main_grad or delay_wgrad_compute): + pytest.skip("Repeated test case in test_grouped_mlp") + self.test_grouped_mlp( + bias=bias, + quantization=quantization, + single_grouped_weight=single_grouped_weight, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, + activation=activation, + ) @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16)) @pytest.mark.parametrize("bias", (False, True)) From 168a9ef9d9cf661904182e2a81303b3efd5a4859 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jun 2026 04:50:32 +0000 Subject: [PATCH 10/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_fusible_ops.py | 17 +-- tests/pytorch/test_grouped_linear.py | 203 +++++++++++++++------------ tests/pytorch/utils.py | 8 +- 3 files changed, 128 insertions(+), 100 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 9134749fb4..cc7eaee711 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -75,6 +75,7 @@ _quantization_list.append("nvfp4") _quantization_list.append("nvfp4_4over6") + @pytest.fixture(autouse=True, scope="function") def _reset_rng_states_per_test(): """Restore torch, CUDA, and Python ``random`` before each test in this module.""" @@ -3493,18 +3494,10 @@ def test_grouped_mlp( assert_close_grads(probs_test, probs_ref, **tols) for group_idx in range(group_size): if bias: - assert_close_grads( - getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols - ) - assert_close_grads( - getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols - ) - assert_close_grads( - getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols - ) - assert_close_grads( - getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols - ) + assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols) + assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols) + assert_close_grads(getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols) + assert_close_grads(getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols) class TestCustomOps: diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index 88eb7f1c81..433610d27a 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -572,7 +572,6 @@ def _use_legacy_grouped_linear_path(self, monkeypatch): def _use_single_grouped_param(self, monkeypatch): _enable_single_grouped_param(monkeypatch) - @pytest.mark.parametrize("dtype", param_types, ids=str) @pytest.mark.parametrize("num_gemms", [1, 3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @@ -612,7 +611,8 @@ def test_grouped_linear_accuracy( if recipe is not None and recipe.nvfp4(): if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): pytest.skip( - f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + f"Input dtype {dtype} not supported for NVFP4 Recipe" + f" {recipe.__class__.__name__}" ) with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -646,9 +646,13 @@ def test_grouped_linear_accuracy( # Share params with torch.no_grad(): for i in range(num_gemms): - sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) + sequential_linear[i].weight = Parameter( + getattr(grouped_linear, f"weight{i}").clone() + ) if bias: - sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + sequential_linear[i].bias = Parameter( + getattr(grouped_linear, f"bias{i}").clone() + ) if fuse_wgrad_accumulation: weight_i = getattr(grouped_linear, f"weight{i}") weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) @@ -684,7 +688,6 @@ def test_grouped_linear_accuracy( # cuBLAS implementation should be bit-wise match torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - @pytest.mark.skipif( torch.cuda.get_device_capability() != (9, 0), reason="Only enable CUTLASS grouped gemm on Hopper", @@ -720,7 +723,6 @@ def test_grouped_linear_accuracy_cutlass( use_cutlass=True, ) - @pytest.mark.parametrize("dtype", param_types, ids=str) @pytest.mark.parametrize("num_gemms", [3]) @pytest.mark.parametrize("bs", [1]) @@ -761,7 +763,8 @@ def test_grouped_linear_accuracy_save_original_input( if recipe is not None and recipe.nvfp4(): if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): pytest.skip( - f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + f"Input dtype {dtype} not supported for NVFP4 Recipe" + f" {recipe.__class__.__name__}" ) with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -795,9 +798,13 @@ def test_grouped_linear_accuracy_save_original_input( # Share params with torch.no_grad(): for i in range(num_gemms): - sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) + sequential_linear[i].weight = Parameter( + getattr(grouped_linear, f"weight{i}").clone() + ) if bias: - sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + sequential_linear[i].bias = Parameter( + getattr(grouped_linear, f"bias{i}").clone() + ) if fuse_wgrad_accumulation: weight_i = getattr(grouped_linear, f"weight{i}") weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) @@ -830,7 +837,6 @@ def test_grouped_linear_accuracy_save_original_input( for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - @pytest.mark.parametrize("save_original_input", [False, True]) @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("num_gemms", [3, 6]) @@ -866,7 +872,8 @@ def test_padding_grouped_linear_accuracy( if recipe is not None and recipe.nvfp4(): if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): pytest.skip( - f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + f"Input dtype {dtype} not supported for NVFP4 Recipe" + f" {recipe.__class__.__name__}" ) with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -913,7 +920,6 @@ def test_padding_grouped_linear_accuracy( for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - @staticmethod def _run_grouped_linear_path( *, @@ -975,7 +981,6 @@ def _run_grouped_linear_path( outputs.extend(getattr(grouped_linear, f"bias{i}").grad for i in range(num_gemms)) return _clone_outputs(outputs) - @pytest.mark.parametrize( "fp8_recipe", [ @@ -991,8 +996,7 @@ def _run_grouped_linear_path( @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("delay_wgrad_compute", all_boolean) def test_grouped_linear_grouped_tensor_path_matches_legacy( - self, - fp8_recipe, bias, fp8_model_params, delay_wgrad_compute, monkeypatch + self, fp8_recipe, bias, fp8_model_params, delay_wgrad_compute, monkeypatch ): if torch.cuda.get_device_capability() < (10, 0): pytest.skip("GroupedTensor grouped GEMM path requires SM100+") @@ -1056,7 +1060,6 @@ def test_grouped_linear_grouped_tensor_path_matches_legacy( assert legacy_out is not None assert_close(grouped_tensor_out, legacy_out, **tols) - @pytest.mark.parametrize( "fp8_recipe", [ @@ -1267,7 +1270,6 @@ def test_grouped_gemm(self, shape, dtype, layout, accumulate, use_cutlass, monke else: torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) - @pytest.mark.skipif( torch.cuda.get_device_capability() != (9, 0), reason="Only enable CUTLASS grouped gemm on Hopper", @@ -1314,7 +1316,6 @@ def test_grouped_gemm_cutlass_empty_groups(self, layout, monkeypatch): for tensor in out: torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0) - @staticmethod def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: data = grouped_tensor.rowwise_data @@ -1328,7 +1329,6 @@ def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tens data[offset : offset + numel].copy_(tensor.reshape(-1)) offset += numel - @staticmethod def _make_grouped_tensor_from_splits( m_sizes: List[int], @@ -1348,7 +1348,6 @@ def _make_grouped_tensor_from_splits( dtype=dtype, ) - @staticmethod def _make_grouped_tensor_uniform( num_tensors: int, @@ -1368,7 +1367,6 @@ def _make_grouped_tensor_uniform( dtype=dtype, ) - @staticmethod def _apply_grouped_bias_ref( base_outs: List[torch.Tensor], @@ -1390,7 +1388,6 @@ def _apply_grouped_bias_ref( offset += ms return out - @pytest.mark.parametrize( "z, m, n, k", [ @@ -1404,7 +1401,9 @@ def _apply_grouped_bias_ref( @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) @pytest.mark.parametrize("accumulate", [False, True]) @pytest.mark.parametrize("use_bias_scale", [False, True]) - def test_grouped_gemm_grouped_tensor(self, z, m, n, k, case, layout, accumulate, use_bias_scale) -> None: + def test_grouped_gemm_grouped_tensor( + self, z, m, n, k, case, layout, accumulate, use_bias_scale + ) -> None: if torch.cuda.get_device_capability() < (9, 0): pytest.skip("Grouped GEMM requires Hopper (SM90) or newer.") if torch.cuda.get_device_capability() < (10, 0): @@ -1444,7 +1443,9 @@ def test_grouped_gemm_grouped_tensor(self, z, m, n, k, case, layout, accumulate, if layout == "NN": out_ref = [torch.matmul(B[i].float(), A[i].float()) for i in range(z)] else: # layout == "TN" - out_ref = [torch.matmul(B[i].float(), A[i].transpose(0, 1).float()) for i in range(z)] + out_ref = [ + torch.matmul(B[i].float(), A[i].transpose(0, 1).float()) for i in range(z) + ] if accumulate: out_ref = [out[i].float() + o for i, o in enumerate(out_ref)] @@ -1475,29 +1476,43 @@ def test_grouped_gemm_grouped_tensor(self, z, m, n, k, case, layout, accumulate, grouped_bias = None if layout == "TN": grouped_A = ( - self._make_grouped_tensor_uniform(z, n, k, device, dtype) if case != "discrete_in" else A + self._make_grouped_tensor_uniform(z, n, k, device, dtype) + if case != "discrete_in" + else A ) # weight grouped_B = self._make_grouped_tensor_from_splits(m_sizes, k, device, dtype) # input if case != "discrete_out": - grouped_out = self._make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # output + grouped_out = self._make_grouped_tensor_from_splits( + m_sizes, n, device, dtype + ) # output grouped_out_bias = self._make_grouped_tensor_from_splits(m_sizes, n, device, dtype) - grouped_out_no_bias = self._make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + grouped_out_no_bias = self._make_grouped_tensor_from_splits( + m_sizes, n, device, dtype + ) elif layout == "NN": grouped_A = ( - self._make_grouped_tensor_uniform(z, n, k, device, dtype) if case != "discrete_in" else A + self._make_grouped_tensor_uniform(z, n, k, device, dtype) + if case != "discrete_in" + else A ) # weight - grouped_B = self._make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output + grouped_B = self._make_grouped_tensor_from_splits( + m_sizes, n, device, dtype + ) # grad_output if case != "discrete_out": grouped_out = self._make_grouped_tensor_from_splits(m_sizes, k, device, dtype) grouped_out_bias = self._make_grouped_tensor_from_splits(m_sizes, k, device, dtype) - grouped_out_no_bias = self._make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + grouped_out_no_bias = self._make_grouped_tensor_from_splits( + m_sizes, k, device, dtype + ) else: # layout == "NT" grouped_A = ( self._make_grouped_tensor_from_splits(m_sizes, k, device, dtype) if case != "discrete_in" else A ) # input - grouped_B = self._make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output + grouped_B = self._make_grouped_tensor_from_splits( + m_sizes, n, device, dtype + ) # grad_output if case != "discrete_out": grouped_out = self._make_grouped_tensor_uniform(z, n, k, device, dtype) # wgrad grouped_out_bias = self._make_grouped_tensor_uniform(z, n, k, device, dtype) @@ -1552,7 +1567,6 @@ def test_grouped_gemm_grouped_tensor(self, z, m, n, k, case, layout, accumulate, for o, o_ref in zip(out_grouped_bias, out_grouped_manual_bias): torch.testing.assert_close(o, o_ref, **tols) - @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) @pytest.mark.parametrize("accumulate", [False, True]) @pytest.mark.parametrize("quant_type", ["bf16", "mxfp8"]) @@ -1652,7 +1666,9 @@ def _make_zero_tokens_grouped_tensor(logical_last_dim, is_a): ) out_result = ( - grouped_out if isinstance(grouped_out, list) else grouped_out.split_into_quantized_tensors() + grouped_out + if isinstance(grouped_out, list) + else grouped_out.split_into_quantized_tensors() ) for i in range(z): if out_result[i].numel() == 0: @@ -1662,7 +1678,6 @@ def _make_zero_tokens_grouped_tensor(logical_last_dim, is_a): else: torch.testing.assert_close(out_result[i], torch.zeros_like(out_result[i])) - @staticmethod def _make_grouped_tensor_quantized_mxfp8( tensors: List[torch.Tensor], @@ -1694,10 +1709,11 @@ def _make_grouped_tensor_quantized_mxfp8( if is_weight: first_dims = None else: - first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device) + first_dims = torch.tensor( + [t.shape[0] for t in tensors], dtype=torch.int64, device=device + ) return tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims) - @staticmethod def _per_tensor_quantize_mxfp8( tensors: List[torch.Tensor], @@ -1715,7 +1731,6 @@ def _per_tensor_quantize_mxfp8( ) return [quantizer(t) for t in tensors] - @pytest.mark.parametrize( "shape", [ @@ -1730,8 +1745,7 @@ def _per_tensor_quantize_mxfp8( @pytest.mark.parametrize("case", ["no_discrete", "discrete_in", "discrete_out"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_grouped_gemm_grouped_tensor_mxfp8( - self, - shape, accumulate, layout: str, case: str, dtype: torch.dtype + self, shape, accumulate, layout: str, case: str, dtype: torch.dtype ) -> None: if tex.get_cublasLt_version() < 130300: pytest.skip("Grouped GEMM requires cuBLAS 13.3+.") @@ -1821,7 +1835,6 @@ def test_grouped_gemm_grouped_tensor_mxfp8( for o, o_ref in zip(out_grouped, out_ref): torch.testing.assert_close(o, o_ref, **tols) - @pytest.mark.parametrize( "shape", [ @@ -2021,7 +2034,6 @@ def test_swizzle_scales_and_pack_ptrs_for_discrete_weights( expected_swizzled_scales_buffer, ) - @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16)) @pytest.mark.parametrize( "quantization", @@ -2193,7 +2205,6 @@ def train_step( for g, param in zip(graph_param_grads, op.parameters()): assert_close(g, param.grad, **tols) - @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) @pytest.mark.parametrize("single_grouped_weight", (False, True)) @pytest.mark.parametrize("single_grouped_bias", (False, True)) @@ -2407,7 +2418,9 @@ def test_grouped_mlp( # Grouped MLP fused op requires GLU interleaving activation_is_glu = activation in ( - "scaled_swiglu", "scaled_clamped_qgeglu", "scaled_clamped_qgeglu_custom" + "scaled_swiglu", + "scaled_clamped_qgeglu", + "scaled_clamped_qgeglu_custom", ) glu_interleave_size = 32 if activation_is_glu else None @@ -2446,21 +2459,24 @@ def test_grouped_mlp( # Reference tensors: float64 CPU; test tensors: target dtype on CUDA x_ref, x_test = make_reference_and_test_tensors( in_shape, - min=-0.25, max=0.25, + min=-0.25, + max=0.25, quantization=quantization, test_dtype=dtype, test_device=device, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, - min=-0.25, max=0.25, + min=-0.25, + max=0.25, test_dtype=dtype, test_device=device, requires_grad=False, ) probs_ref, probs_test = make_reference_and_test_tensors( (in_shape[0],), - min=0.1, max=1.0, + min=0.1, + max=1.0, test_dtype=dtype, test_device=device, ) @@ -2472,7 +2488,8 @@ def test_grouped_mlp( for _ in range(group_size): w1_ref, w1_test = make_reference_and_test_tensors( (fc1_out_features, hidden_size), - min=-0.125, max=0.125, + min=-0.125, + max=0.125, quantization=quantization, test_dtype=dtype, test_device=device, @@ -2482,7 +2499,8 @@ def test_grouped_mlp( fc1_ws_test.append(w1_test) w2_ref, w2_test = make_reference_and_test_tensors( (hidden_size, hidden_size), - min=-0.125, max=0.125, + min=-0.125, + max=0.125, quantization=quantization, test_dtype=dtype, test_device=device, @@ -2493,7 +2511,8 @@ def test_grouped_mlp( if bias: b1_ref, b1_test = make_reference_and_test_tensors( (fc1_out_features,), - min=-0.5, max=0.5, + min=-0.5, + max=0.5, test_dtype=dtype, test_device=device, ) @@ -2501,7 +2520,8 @@ def test_grouped_mlp( fc1_bs_test.append(b1_test) b2_ref, b2_test = make_reference_and_test_tensors( (hidden_size,), - min=-0.5, max=0.5, + min=-0.5, + max=0.5, test_dtype=dtype, test_device=device, ) @@ -2515,7 +2535,9 @@ def test_grouped_mlp( def _apply_activation(x: torch.Tensor) -> torch.Tensor: if glu_interleave_size is not None: - x = x.reshape(-1, 2 * hidden_size // (2 * glu_interleave_size), 2, glu_interleave_size) + x = x.reshape( + -1, 2 * hidden_size // (2 * glu_interleave_size), 2, glu_interleave_size + ) x = x.transpose(1, 2).reshape(-1, 2 * hidden_size) if activation == "scaled_swiglu": x1, x2 = x.chunk(2, dim=-1) @@ -2536,7 +2558,9 @@ def _apply_activation(x: torch.Tensor) -> torch.Tensor: ys = [] for group_idx in range(group_size): x = xs[group_idx] - fc1_out = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx]) + fc1_out = torch.nn.functional.linear( + x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx] + ) fc2_in = _apply_activation(fc1_out) * probs[group_idx].unsqueeze(-1) y = torch.nn.functional.linear(fc2_in, fc2_ws_ref[group_idx]) if bias: @@ -2550,16 +2574,24 @@ def _apply_activation(x: torch.Tensor) -> torch.Tensor: with te.quantized_model_init(enabled=with_quantization, recipe=recipe): fc1 = te.ops.GroupedLinear( - group_size, hidden_size, fc1_out_features, - bias=bias, device=device, dtype=dtype, + group_size, + hidden_size, + fc1_out_features, + bias=bias, + device=device, + dtype=dtype, single_grouped_weight=single_grouped_weight, single_grouped_bias=single_grouped_bias, accumulate_into_main_grad=accumulate_into_main_grad, delay_wgrad_compute=delay_wgrad_compute, ) fc2 = te.ops.GroupedLinear( - group_size, hidden_size, hidden_size, - bias=bias, device=device, dtype=dtype, + group_size, + hidden_size, + hidden_size, + bias=bias, + device=device, + dtype=dtype, single_grouped_weight=single_grouped_weight, single_grouped_bias=single_grouped_bias, accumulate_into_main_grad=accumulate_into_main_grad, @@ -2596,17 +2628,14 @@ def _apply_activation(x: torch.Tensor) -> torch.Tensor: ) if accumulate_into_main_grad: main_grad_sentinel = 0.5 - weight_params_for_main_grad = ( - _grouped_weight_params( - fc1, - group_size, - single_grouped_weight=single_grouped_weight, - ) - + _grouped_weight_params( - fc2, - group_size, - single_grouped_weight=single_grouped_weight, - ) + weight_params_for_main_grad = _grouped_weight_params( + fc1, + group_size, + single_grouped_weight=single_grouped_weight, + ) + _grouped_weight_params( + fc2, + group_size, + single_grouped_weight=single_grouped_weight, ) MegatronTrainingHelper.init_main_grad_buffers( weight_params_for_main_grad, @@ -2627,12 +2656,9 @@ def _apply_activation(x: torch.Tensor) -> torch.Tensor: # Determine whether op fusion is expected is_fusion_expected = False if quantization == "mxfp8": - is_fusion_expected = ( - dtype in (torch.bfloat16, torch.float16) - and ( - (not activation_is_glu and glu_interleave_size is None) - or (activation_is_glu and glu_interleave_size == 32) - ) + is_fusion_expected = dtype in (torch.bfloat16, torch.float16) and ( + (not activation_is_glu and glu_interleave_size is None) + or (activation_is_glu and glu_interleave_size == 32) ) if quantization == "nvfp4_rht": is_fusion_expected = ( @@ -2679,11 +2705,19 @@ def _apply_activation(x: torch.Tensor) -> torch.Tensor: assert_close(fc2.bias.grad[group_idx], fc2_bs_ref[group_idx].grad, **tols) assert_close(fc1.bias.grad[group_idx], fc1_bs_ref[group_idx].grad, **tols) else: - assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols) - assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols) + assert_close_grads( + getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols + ) + assert_close_grads( + getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols + ) if not single_grouped_weight and not accumulate_into_main_grad: - assert_close_grads(getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols) - assert_close_grads(getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols) + assert_close_grads( + getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols + ) + assert_close_grads( + getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols + ) fc1_w_ref_grad = torch.stack([w.grad for w in fc1_ws_ref], dim=0) fc2_w_ref_grad = torch.stack([w.grad for w in fc2_ws_ref], dim=0) if accumulate_into_main_grad: @@ -2938,7 +2972,6 @@ def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]: torch.testing.assert_close(fc1_db_false, fc1_db_true, **bias_tols) torch.testing.assert_close(fc2_db_false, fc2_db_true, **bias_tols) - @pytest.mark.parametrize("single_grouped_weight", (False, True)) @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) @pytest.mark.parametrize("zero_out_wgrad", (False, True)) @@ -3088,7 +3121,6 @@ def _run_backward(module, fc1, fc2): _weight_params(test_fc2), expected_main_grads=ref_fc2_grads ) - @pytest.mark.parametrize("dtype", (torch.float32, torch.float16, torch.bfloat16)) @pytest.mark.parametrize("single_grouped_weight", (False, True)) @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) @@ -3284,7 +3316,6 @@ def train_step( for graph_grad, param in zip(graph_param_grads, module.parameters()): assert_close(graph_grad, param.grad, **tols) - def test_grouped_gemm_quant_cute_matches_mxfp8_quantized(self) -> None: if not mxfp8_available: pytest.skip(reason_for_no_mxfp8) @@ -3334,12 +3365,12 @@ def test_grouped_gemm_quant_cute_matches_mxfp8_quantized(self) -> None: device=device, ) inputs = { - "a_tensor": torch.empty(1, total_m, k, dtype=torch.float8_e4m3fn, device=device).permute( - 1, 2, 0 - ), - "b_tensor": torch.empty(num_groups, n, k, dtype=torch.float8_e4m3fn, device=device).permute( - 1, 2, 0 - ), + "a_tensor": torch.empty( + 1, total_m, k, dtype=torch.float8_e4m3fn, device=device + ).permute(1, 2, 0), + "b_tensor": torch.empty( + num_groups, n, k, dtype=torch.float8_e4m3fn, device=device + ).permute(1, 2, 0), "sfa_tensor": torch.empty( 1, total_m // 128, diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 0b0bd55a38..343538e65e 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -42,8 +42,12 @@ fp8_available, reason_for_no_fp8 = transformer_engine.pytorch.is_fp8_available(return_reason=True) -mxfp8_available, reason_for_no_mxfp8 = transformer_engine.pytorch.is_mxfp8_available(return_reason=True) -nvfp4_available, reason_for_no_nvfp4 = transformer_engine.pytorch.is_nvfp4_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = transformer_engine.pytorch.is_mxfp8_available( + return_reason=True +) +nvfp4_available, reason_for_no_nvfp4 = transformer_engine.pytorch.is_nvfp4_available( + return_reason=True +) def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: From 90586fb753c01ea6429966fae6f39777925e7e78 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 10 Jun 2026 05:11:52 +0000 Subject: [PATCH 11/11] Review suggestion from @greptile-apps Signed-off-by: Tim Moon --- tests/pytorch/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 343538e65e..6c6404d071 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -5,6 +5,7 @@ from __future__ import annotations import logging +import math import os import random import subprocess @@ -18,7 +19,6 @@ import transformer_engine from transformer_engine.common.recipe import Recipe -import math from transformer_engine.pytorch import ( DType, Float8CurrentScalingQuantizer,