|
20 | 20 | import transformer_engine.common.recipe |
21 | 21 | import transformer_engine.pytorch as te |
22 | 22 | import transformer_engine.pytorch.ops as te_ops |
23 | | -from transformer_engine.pytorch.ops._common import ( |
| 23 | +from transformer_engine.pytorch.ops.fused.grouped_mlp import ( |
24 | 24 | _cudnn_frontend_supports_grouped_gemm_srelu, |
25 | 25 | _cudnn_frontend_version_supported, |
26 | 26 | is_glu_activation, |
@@ -4028,25 +4028,19 @@ def _make_module(): |
4028 | 4028 | ) |
4029 | 4029 | if expected_grouped_mlp_fusion: |
4030 | 4030 | if activation_is_glu: |
4031 | | - forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU |
4032 | | - backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU |
| 4031 | + fused_cls = te_ops.fused.GroupedMLP_CuTeGEMMGLU |
4033 | 4032 | else: |
4034 | | - forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMUnary |
4035 | | - backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDUnary |
4036 | | - if forward_cls.is_supported(): |
| 4033 | + fused_cls = te_ops.fused.GroupedMLP_CuTeGEMMUnary |
| 4034 | + if fused_cls.is_supported(): |
4037 | 4035 | forward_ops = module._module_groups[0]._forward_ops |
4038 | | - assert len(forward_ops) == 1 |
4039 | | - assert isinstance( |
4040 | | - forward_ops[0][0], |
4041 | | - forward_cls, |
4042 | | - ) |
4043 | | - if backward_cls is not None and backward_cls.is_supported(): |
4044 | 4036 | backward_ops = module._module_groups[0]._backward_ops |
| 4037 | + assert len(forward_ops) == 1 |
4045 | 4038 | assert len(backward_ops) == 1 |
4046 | 4039 | assert isinstance( |
4047 | | - backward_ops[0][0], |
4048 | | - backward_cls, |
| 4040 | + forward_ops[0][0], |
| 4041 | + fused_cls, |
4049 | 4042 | ) |
| 4043 | + assert backward_ops[0][0] is forward_ops[0][0] |
4050 | 4044 |
|
4051 | 4045 | # Loose tols for sanity checking |
4052 | 4046 | tols = {"rtol": 0.125, "atol": 0.25} |
@@ -4130,10 +4124,8 @@ def test_grouped_mlp_single_weight_numerics( |
4130 | 4124 | ) -> None: |
4131 | 4125 | """single_grouped_weight=True/False should match exactly for fused MXFP8 grouped MLP.""" |
4132 | 4126 |
|
4133 | | - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): |
4134 | | - pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") |
4135 | | - if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): |
4136 | | - pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") |
| 4127 | + if not te_ops.fused.GroupedMLP_CuTeGEMMGLU.is_supported(): |
| 4128 | + pytest.skip("MXFP8 fused grouped MLP is not supported on this system") |
4137 | 4129 |
|
4138 | 4130 | split_sizes = [split_alignment * (i + 1) for i in range(group_size)] |
4139 | 4131 | random.shuffle(split_sizes) |
@@ -4234,13 +4226,14 @@ def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]: |
4234 | 4226 | assert len(forward_ops) == 1 |
4235 | 4227 | assert isinstance( |
4236 | 4228 | forward_ops[0][0], |
4237 | | - te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, |
| 4229 | + te_ops.fused.GroupedMLP_CuTeGEMMGLU, |
4238 | 4230 | ) |
4239 | 4231 | assert len(backward_ops) == 1 |
4240 | 4232 | assert isinstance( |
4241 | 4233 | backward_ops[0][0], |
4242 | | - te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, |
| 4234 | + te_ops.fused.GroupedMLP_CuTeGEMMGLU, |
4243 | 4235 | ) |
| 4236 | + assert backward_ops[0][0] is forward_ops[0][0] |
4244 | 4237 |
|
4245 | 4238 | if single_grouped_weight: |
4246 | 4239 | fc1_dw = fc1.weight.grad.detach().clone() |
@@ -4352,10 +4345,8 @@ def test_grouped_mlp_overwrite_main_grad( |
4352 | 4345 | that read ``.grad`` don't see stale bytes from the cached dummy). |
4353 | 4346 | """ |
4354 | 4347 |
|
4355 | | - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): |
4356 | | - pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") |
4357 | | - if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): |
4358 | | - pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") |
| 4348 | + if not te_ops.fused.GroupedMLP_CuTeGEMMGLU.is_supported(): |
| 4349 | + pytest.skip("MXFP8 fused grouped MLP is not supported on this system") |
4359 | 4350 |
|
4360 | 4351 | recipe = make_recipe("mxfp8") |
4361 | 4352 | split_sizes = [split_alignment * (i + 1) for i in range(group_size)] |
@@ -4486,7 +4477,7 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8( |
4486 | 4477 | ) -> None: |
4487 | 4478 | """Grouped MLP forward+backward should be CUDA graph capturable (MXFP8).""" |
4488 | 4479 |
|
4489 | | - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): |
| 4480 | + if not te_ops.fused.GroupedMLP_CuTeGEMMGLU.is_supported(): |
4490 | 4481 | pytest.skip("MXFP8 fused grouped MLP is not supported on this system") |
4491 | 4482 | if dtype not in (torch.bfloat16, torch.float16): |
4492 | 4483 | pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16") |
@@ -4628,13 +4619,14 @@ def train_step( |
4628 | 4619 | assert len(forward_ops) == 1 |
4629 | 4620 | assert isinstance( |
4630 | 4621 | forward_ops[0][0], |
4631 | | - te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, |
| 4622 | + te_ops.fused.GroupedMLP_CuTeGEMMGLU, |
4632 | 4623 | ) |
4633 | 4624 | assert len(backward_ops) == 1 |
4634 | 4625 | assert isinstance( |
4635 | 4626 | backward_ops[0][0], |
4636 | | - te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, |
| 4627 | + te_ops.fused.GroupedMLP_CuTeGEMMGLU, |
4637 | 4628 | ) |
| 4629 | + assert backward_ops[0][0] is forward_ops[0][0] |
4638 | 4630 |
|
4639 | 4631 | fresh_x = torch.randn_like(static_x) |
4640 | 4632 | fresh_probs = torch.randn_like(static_probs) |
|
0 commit comments