Skip to content

Commit 91bb9cf

Browse files
timmoon10codex
andauthored
[PyTorch] Refactor grouped MLP into joint forward-backward fused op (#3117)
* Refactor grouped MLP into joint fused op Consolidate the experimental grouped MLP forward and backward CuTe DSL fusions into a single joint fused operation. Move grouped-MLP-specific helper logic out of ops/_common.py and update tests to assert the joint forward/backward fusion object. Co-authored-by: Codex <codex@openai.com> Signed-off-by: Tim Moon <tmoon@nvidia.com> * Review suggestions from @greptile-apps Also fix linter warnings. Signed-off-by: Tim Moon <tmoon@nvidia.com> --------- Signed-off-by: Tim Moon <tmoon@nvidia.com> Co-authored-by: Codex <codex@openai.com>
1 parent 3976a68 commit 91bb9cf

6 files changed

Lines changed: 2215 additions & 2469 deletions

File tree

tests/pytorch/test_fusible_ops.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import transformer_engine.common.recipe
2121
import transformer_engine.pytorch as te
2222
import transformer_engine.pytorch.ops as te_ops
23-
from transformer_engine.pytorch.ops._common import (
23+
from transformer_engine.pytorch.ops.fused.grouped_mlp import (
2424
_cudnn_frontend_supports_grouped_gemm_srelu,
2525
_cudnn_frontend_version_supported,
2626
is_glu_activation,
@@ -4028,25 +4028,19 @@ def _make_module():
40284028
)
40294029
if expected_grouped_mlp_fusion:
40304030
if activation_is_glu:
4031-
forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU
4032-
backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU
4031+
fused_cls = te_ops.fused.GroupedMLP_CuTeGEMMGLU
40334032
else:
4034-
forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMUnary
4035-
backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDUnary
4036-
if forward_cls.is_supported():
4033+
fused_cls = te_ops.fused.GroupedMLP_CuTeGEMMUnary
4034+
if fused_cls.is_supported():
40374035
forward_ops = module._module_groups[0]._forward_ops
4038-
assert len(forward_ops) == 1
4039-
assert isinstance(
4040-
forward_ops[0][0],
4041-
forward_cls,
4042-
)
4043-
if backward_cls is not None and backward_cls.is_supported():
40444036
backward_ops = module._module_groups[0]._backward_ops
4037+
assert len(forward_ops) == 1
40454038
assert len(backward_ops) == 1
40464039
assert isinstance(
4047-
backward_ops[0][0],
4048-
backward_cls,
4040+
forward_ops[0][0],
4041+
fused_cls,
40494042
)
4043+
assert backward_ops[0][0] is forward_ops[0][0]
40504044

40514045
# Loose tols for sanity checking
40524046
tols = {"rtol": 0.125, "atol": 0.25}
@@ -4130,10 +4124,8 @@ def test_grouped_mlp_single_weight_numerics(
41304124
) -> None:
41314125
"""single_grouped_weight=True/False should match exactly for fused MXFP8 grouped MLP."""
41324126

4133-
if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported():
4134-
pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system")
4135-
if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported():
4136-
pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system")
4127+
if not te_ops.fused.GroupedMLP_CuTeGEMMGLU.is_supported():
4128+
pytest.skip("MXFP8 fused grouped MLP is not supported on this system")
41374129

41384130
split_sizes = [split_alignment * (i + 1) for i in range(group_size)]
41394131
random.shuffle(split_sizes)
@@ -4234,13 +4226,14 @@ def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]:
42344226
assert len(forward_ops) == 1
42354227
assert isinstance(
42364228
forward_ops[0][0],
4237-
te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU,
4229+
te_ops.fused.GroupedMLP_CuTeGEMMGLU,
42384230
)
42394231
assert len(backward_ops) == 1
42404232
assert isinstance(
42414233
backward_ops[0][0],
4242-
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU,
4234+
te_ops.fused.GroupedMLP_CuTeGEMMGLU,
42434235
)
4236+
assert backward_ops[0][0] is forward_ops[0][0]
42444237

42454238
if single_grouped_weight:
42464239
fc1_dw = fc1.weight.grad.detach().clone()
@@ -4352,10 +4345,8 @@ def test_grouped_mlp_overwrite_main_grad(
43524345
that read ``.grad`` don't see stale bytes from the cached dummy).
43534346
"""
43544347

4355-
if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported():
4356-
pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system")
4357-
if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported():
4358-
pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system")
4348+
if not te_ops.fused.GroupedMLP_CuTeGEMMGLU.is_supported():
4349+
pytest.skip("MXFP8 fused grouped MLP is not supported on this system")
43594350

43604351
recipe = make_recipe("mxfp8")
43614352
split_sizes = [split_alignment * (i + 1) for i in range(group_size)]
@@ -4486,7 +4477,7 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8(
44864477
) -> None:
44874478
"""Grouped MLP forward+backward should be CUDA graph capturable (MXFP8)."""
44884479

4489-
if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported():
4480+
if not te_ops.fused.GroupedMLP_CuTeGEMMGLU.is_supported():
44904481
pytest.skip("MXFP8 fused grouped MLP is not supported on this system")
44914482
if dtype not in (torch.bfloat16, torch.float16):
44924483
pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16")
@@ -4628,13 +4619,14 @@ def train_step(
46284619
assert len(forward_ops) == 1
46294620
assert isinstance(
46304621
forward_ops[0][0],
4631-
te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU,
4622+
te_ops.fused.GroupedMLP_CuTeGEMMGLU,
46324623
)
46334624
assert len(backward_ops) == 1
46344625
assert isinstance(
46354626
backward_ops[0][0],
4636-
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU,
4627+
te_ops.fused.GroupedMLP_CuTeGEMMGLU,
46374628
)
4629+
assert backward_ops[0][0] is forward_ops[0][0]
46384630

46394631
fresh_x = torch.randn_like(static_x)
46404632
fresh_probs = torch.randn_like(static_probs)

0 commit comments

Comments
 (0)