[PyTorch] Refactor grouped MLP into joint forward-backward fused op #3117

Merged
timmoon10 merged 2 commits into NVIDIA:main from timmoon10:tmoon/refactor-grouped-mlp-op
Jun 11, 2026

@timmoon10

Member

Description

We've been pursuing some highly experimental fusions for the MoE grouped MLP block (added in #2769). These have hyperspecific optimizations, depend on bleeding-edge cuDNN Frontend versions, and break the op fuser contract that fusions can be applied independently in the forward and backward passes (see #2981 (comment)). This PR refactors the grouped MLP fused ops as joint forward-backward fusions (see #3080), so that we can fully integrate the forward and backward implementations without needing to maintain compatibility with the basic ops. Also, by packing all grouped-MLP-specific logic in one file, we can treat it as a black box for experimental optimizations.
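To make the idea of a joint forward-backward fusion concrete, here is a minimal sketch in plain Python. It is not the Transformer Engine implementation: `ToyContext` and `ToyJointMLP` are hypothetical names, the "tensors" are scalars, and the `(op, basic_op_idxs)` list layout is only assumed to resemble what the op fuser tracks. The point is that a single object owns both passes, so the backward can rely on whatever the forward chose to save.

```python
class ToyContext:
    """Hypothetical stand-in for a per-basic-op autograd context."""

    def __init__(self):
        self.saved = ()

    def save_for_backward(self, *values):
        self.saved = values


class ToyJointMLP:
    """Toy joint op for y = w2 * relu(w1 * x) on scalars (not the TE class)."""

    def __init__(self, w1, w2):
        self.w1, self.w2 = w1, w2

    def fuser_forward(self, ctxs, x):
        fc1_ctx = ctxs[0]
        h = self.w1 * x  # "FC1"
        a = max(h, 0.0)  # activation
        # All state needed by the backward pass lives in the first context.
        fc1_ctx.save_for_backward(x, h, a)
        return self.w2 * a  # "FC2"

    def fuser_backward(self, ctxs, grad_out):
        x, h, a = ctxs[0].saved
        grad_w2 = grad_out * a
        grad_h = grad_out * self.w2 if h > 0 else 0.0
        grad_w1 = grad_h * x
        grad_x = grad_h * self.w1
        # One entry of parameter grads per basic op: FC1, activation, FC2.
        return grad_x, [(grad_w1,), (), (grad_w2,)]


op = ToyJointMLP(w1=2.0, w2=3.0)
ctxs = [ToyContext(), ToyContext(), ToyContext()]

# Assumed bookkeeping: the same object is registered for both passes.
forward_ops = [(op, [0, 1, 2])]
backward_ops = [(op, [0, 1, 2])]

y = forward_ops[0][0].fuser_forward(ctxs, 1.5)
grad_x, grad_params = backward_ops[0][0].fuser_backward(ctxs, 1.0)
```

Because the forward and backward entries point at the same object, there is no need for the saved state to stay compatible with what an unfused basic op would expect.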

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Refactor grouped MLP into a joint forward-backward fused op
  • Consolidate grouped-MLP-specific logic into the same source file as the grouped MLP op

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Consolidate the experimental grouped MLP forward and backward CuTe DSL fusions into a single joint fused operation. Move grouped-MLP-specific helper logic out of ops/_common.py and update tests to assert the joint forward/backward fusion object.

Co-authored-by: Codex <codex@openai.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@greptile-apps

greptile-apps Bot commented Jun 10, 2026

Contributor

Greptile Summary

This PR consolidates the previously separate ForwardGroupedMLP_CuTeGEMM* and BackwardGroupedMLP_CuTeGEMM* fused ops into a single joint _GroupedMLP_CuTeGEMMBase class (with GroupedMLP_CuTeGEMMGLU and GroupedMLP_CuTeGEMMUnary subclasses), and co-locates all grouped-MLP-specific helpers that were scattered in _common.py and the two deleted files into the new grouped_mlp.py.

  • The joint op stores all forward state in fc1_ctx.save_for_backward and dispatches cuDNN FE kernels in both fuser_forward and fuser_backward via the same object, satisfying the backward_ops[0][0] is forward_ops[0][0] invariant now asserted in the tests.
  • All grouped-MLP-specific helpers (_cudnn_frontend_version_*, NVFP4 grouping utils, _compute_grad_params, sliding-window fuse_grouped_mlp_ops, etc.) are self-contained in grouped_mlp.py, enabling future experimental optimizations without affecting the rest of the op fuser framework.
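As an illustration of what a sliding-window fusion pass can look like, the sketch below scans a list of ops and replaces each run that matches a fixed pattern with one fused op. This is a generic pattern-replacement routine written for this note, not the code of fuse_grouped_mlp_ops; every class and function name in it (`ToyGroupedLinear`, `ToyActivation`, `ToyBias`, `ToyFused`, `fuse_window`) is hypothetical.

```python
class ToyGroupedLinear:
    pass


class ToyActivation:
    pass


class ToyBias:
    pass


class ToyFused:
    """Hypothetical fused op that remembers the ops it replaced."""

    def __init__(self, ops):
        self.ops = list(ops)


def fuse_window(ops, pattern, make_fused):
    """Replace every non-overlapping run of ops matching `pattern`."""
    out = []
    i = 0
    n = len(pattern)
    while i < len(ops):
        window = ops[i : i + n]
        matches = len(window) == n and all(
            isinstance(op, op_type) for op, op_type in zip(window, pattern)
        )
        if matches:
            out.append(make_fused(window))
            i += n  # jump past the fused run
        else:
            out.append(ops[i])
            i += 1  # slide the window by one op
    return out


pattern = (ToyGroupedLinear, ToyActivation, ToyGroupedLinear)
ops = [
    ToyBias(),
    ToyGroupedLinear(), ToyActivation(), ToyGroupedLinear(),
    ToyGroupedLinear(), ToyActivation(), ToyGroupedLinear(),
    ToyActivation(),
]
fused = fuse_window(ops, pattern, ToyFused)
fused_names = [type(op).__name__ for op in fused]
```

Ops that do not take part in a match pass through unchanged, and a trailing partial window is left unfused.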

Confidence Score: 5/5

This is a well-scoped refactoring with no functional changes to kernel dispatch or gradient computation; the only net-new logic is the sliding-window fusion helper and the joint context-saving strategy, both of which are straightforward.

The joint forward-backward design is internally consistent: all state is saved in fc1_ctx, fc2_ctx attributes are set correctly, and the test explicitly asserts the same-object invariant. The sliding-window algorithm in fuse_grouped_mlp_ops handles all edge cases correctly. The helper functions migrated from _common.py are byte-for-byte identical to their originals. No correctness issues were found.

No files require special attention.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/pytorch/ops/fused/grouped_mlp.py | New 2193-line file that merges forward and backward grouped MLP logic into a single joint fused op; contains correct sliding-window fusion, quantizer handling, cuDNN kernel dispatch, and context management. |
| transformer_engine/pytorch/ops/_common.py | Removes ~360 lines of grouped-MLP-specific helpers (version checks, NVFP4 grouping helpers, etc.) that were migrated to grouped_mlp.py; remaining helpers are generic and unchanged. |
| transformer_engine/pytorch/ops/fused/__init__.py | Replaces imports of separate ForwardGroupedMLP_CuTeGEMM* and BackwardGroupedMLP_CuTeGEMM* classes with the unified GroupedMLP_CuTeGEMMGLU and GroupedMLP_CuTeGEMMUnary; registration is now handled inside grouped_mlp.py. |
| transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | Deleted entirely; logic absorbed into grouped_mlp.py as the fuser_backward method of _GroupedMLP_CuTeGEMMBase. |
| transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Deleted entirely; logic absorbed into grouped_mlp.py as the fuser_forward method of _GroupedMLP_CuTeGEMMBase. |
| tests/pytorch/test_fusible_ops.py | Updates test imports and structural assertions to reflect the new joint fused-op model; adds identity check (backward_ops[0][0] is forward_ops[0][0]) to verify the joint fusion invariant. |

Sequence Diagram

sequenceDiagram
    participant Fuser
    participant GroupedMLP_CuTeGEMMBase as _GroupedMLP_CuTeGEMMBase (joint op)
    participant CuDNN as cuDNN FE Kernels

    Note over Fuser,CuDNN: Forward pass
    Fuser->>GroupedMLP_CuTeGEMMBase: fuser_forward(basic_op_ctxs, input_)
    GroupedMLP_CuTeGEMMBase->>GroupedMLP_CuTeGEMMBase: quantize FC1/FC2 inputs and weights
    GroupedMLP_CuTeGEMMBase->>CuDNN: grouped_gemm_activation_kernel(FC1 GEMM + GLU/SReLU)
    CuDNN-->>GroupedMLP_CuTeGEMMBase: fc2_x (activation output)
    GroupedMLP_CuTeGEMMBase->>CuDNN: grouped_gemm_quant_kernel(FC2 GEMM)
    CuDNN-->>GroupedMLP_CuTeGEMMBase: fc2_out
    GroupedMLP_CuTeGEMMBase->>GroupedMLP_CuTeGEMMBase: fc1_ctx.save_for_backward(split_sizes, grouped_fc1_x, weights, activation_in, scales, grouped_fc2_x, fc2_weights)
    GroupedMLP_CuTeGEMMBase-->>Fuser: fc2_out

    Note over Fuser,CuDNN: Backward pass (same object)
    Fuser->>GroupedMLP_CuTeGEMMBase: fuser_backward(basic_op_ctxs, grad_output)
    GroupedMLP_CuTeGEMMBase->>GroupedMLP_CuTeGEMMBase: restore saved tensors from fc1_ctx
    GroupedMLP_CuTeGEMMBase->>CuDNN: grouped_gemm_dactivation_kernel(FC2 dgrad + dActivation)
    CuDNN-->>GroupedMLP_CuTeGEMMBase: fc1_dy, grad_scales
    GroupedMLP_CuTeGEMMBase->>CuDNN: grouped_gemm_wgrad_kernel(FC2 wgrad)
    CuDNN-->>GroupedMLP_CuTeGEMMBase: fc2_dW
    GroupedMLP_CuTeGEMMBase->>CuDNN: grouped_gemm_quant_kernel(FC1 dgrad)
    CuDNN-->>GroupedMLP_CuTeGEMMBase: grad_input
    GroupedMLP_CuTeGEMMBase->>CuDNN: grouped_gemm_wgrad_kernel(FC1 wgrad)
    CuDNN-->>GroupedMLP_CuTeGEMMBase: fc1_dW
    GroupedMLP_CuTeGEMMBase-->>Fuser: grad_input, [fc1_grad_params, (), fc2_grad_params]
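
The order of the four backward GEMMs in the diagram can be checked against a small reference computation. The sketch below is a pure-Python toy, assuming a ReLU activation, row-major nested lists, weights stored as [in, out] matrices, and non-empty groups; none of these choices are taken from grouped_mlp.py, and the helper names are hypothetical. Rows of the input are split by split_sizes and each group uses its own pair of weights.

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def split_rows(x, split_sizes):
    groups, start = [], 0
    for n in split_sizes:
        groups.append(x[start : start + n])
        start += n
    return groups


def grouped_mlp_forward(x, split_sizes, w1s, w2s):
    ys, saved = [], []
    for xg, w1, w2 in zip(split_rows(x, split_sizes), w1s, w2s):
        h = matmul(xg, w1)                               # FC1 GEMM
        a = [[max(v, 0.0) for v in row] for row in h]    # activation (ReLU here)
        ys.extend(matmul(a, w2))                         # FC2 GEMM
        saved.append((xg, h, a))
    return ys, saved


def grouped_mlp_backward(dy, split_sizes, w1s, w2s, saved):
    dx, dw1s, dw2s = [], [], []
    groups = zip(split_rows(dy, split_sizes), w1s, w2s, saved)
    for dyg, w1, w2, (xg, h, a) in groups:
        # 1. FC2 dgrad followed by the activation backward
        da = matmul(dyg, transpose(w2))
        dh = [[g if v > 0 else 0.0 for g, v in zip(gr, hr)] for gr, hr in zip(da, h)]
        # 2. FC2 wgrad
        dw2s.append(matmul(transpose(a), dyg))
        # 3. FC1 dgrad
        dx.extend(matmul(dh, transpose(w1)))
        # 4. FC1 wgrad
        dw1s.append(matmul(transpose(xg), dh))
    return dx, dw1s, dw2s


x = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]
split_sizes = [2, 1]  # first two rows go to expert 0, last row to expert 1
w1s = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, -1.0], [1.0, 1.0]]]
w2s = [[[1.0], [1.0]], [[1.0], [2.0]]]

y, saved = grouped_mlp_forward(x, split_sizes, w1s, w2s)
dx, dw1s, dw2s = grouped_mlp_backward([[1.0], [1.0], [1.0]], split_sizes, w1s, w2s, saved)
```

In the fused op these steps are dispatched to cuDNN FE kernels over all groups at once; the per-group Python loop here only spells out the arithmetic each kernel is responsible for.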

Reviews (2): Last reviewed commit: "Review suggestions from @greptile-apps"

Comment thread transformer_engine/pytorch/ops/fused/grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/grouped_mlp.py
Comment on lines 22 to 25
 import transformer_engine.pytorch.ops as te_ops
-from transformer_engine.pytorch.ops._common import (
+from transformer_engine.pytorch.ops.fused.grouped_mlp import (
     _cudnn_frontend_supports_grouped_gemm_srelu,
     _cudnn_frontend_version_supported,

Contributor

P2 Test imports module-private helpers directly from grouped_mlp. These names (_cudnn_frontend_version_supported, _cudnn_frontend_supports_grouped_gemm_srelu) carry a leading _ signalling they are internal implementation details. Importing them from tests tightly couples the test suite to internal symbol names; a later rename or move would break the import with no deprecation period. Consider exposing a small public helper or adding a # noqa: PLC2701 comment with a brief rationale.

Member Author

For better or for worse, these tests are tightly coupled with the implementation of the grouped MLP fused op and changing them is beyond the scope of this PR.

Also fix linter warnings.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10

Member Author

/te-ci pytorch

@vthumbe1503 vthumbe1503 left a comment

Collaborator

LGTM

@timmoon10 timmoon10 merged commit 91bb9cf into NVIDIA:main Jun 11, 2026
21 of 25 checks passed
@timmoon10 timmoon10 deleted the tmoon/refactor-grouped-mlp-op branch June 11, 2026 20:24