[PyTorch] Refactor grouped MLP into joint forward-backward fused op#3117
Conversation
Consolidate the experimental grouped MLP forward and backward CuTe DSL fusions into a single joint fused operation. Move grouped-MLP-specific helper logic out of ops/_common.py and update tests to assert the joint forward/backward fusion object. Co-authored-by: Codex <codex@openai.com> Signed-off-by: Tim Moon <tmoon@nvidia.com>
Greptile SummaryThis PR consolidates the previously separate
Confidence Score: 5/5This is a well-scoped refactoring with no functional changes to kernel dispatch or gradient computation; the only net-new logic is the sliding-window fusion helper and the joint context-saving strategy, both of which are straightforward. The joint forward-backward design is internally consistent: all state is saved in fc1_ctx, fc2_ctx attributes are set correctly, and the test explicitly asserts the same-object invariant. The sliding-window algorithm in fuse_grouped_mlp_ops handles all edge cases correctly. The helper functions migrated from _common.py are byte-for-byte identical to their originals. No correctness issues were found. No files require special attention. Important Files Changed
Sequence DiagramsequenceDiagram
participant Fuser
participant GroupedMLP_CuTeGEMMBase as _GroupedMLP_CuTeGEMMBase (joint op)
participant CuDNN as cuDNN FE Kernels
Note over Fuser,CuDNN: Forward pass
Fuser->>GroupedMLP_CuTeGEMMBase: fuser_forward(basic_op_ctxs, input_)
GroupedMLP_CuTeGEMMBase->>GroupedMLP_CuTeGEMMBase: quantize FC1/FC2 inputs and weights
GroupedMLP_CuTeGEMMBase->>CuDNN: grouped_gemm_activation_kernel(FC1 GEMM + GLU/SReLU)
CuDNN-->>GroupedMLP_CuTeGEMMBase: fc2_x (activation output)
GroupedMLP_CuTeGEMMBase->>CuDNN: grouped_gemm_quant_kernel(FC2 GEMM)
CuDNN-->>GroupedMLP_CuTeGEMMBase: fc2_out
GroupedMLP_CuTeGEMMBase->>GroupedMLP_CuTeGEMMBase: fc1_ctx.save_for_backward(split_sizes, grouped_fc1_x, weights, activation_in, scales, grouped_fc2_x, fc2_weights)
GroupedMLP_CuTeGEMMBase-->>Fuser: fc2_out
Note over Fuser,CuDNN: Backward pass (same object)
Fuser->>GroupedMLP_CuTeGEMMBase: fuser_backward(basic_op_ctxs, grad_output)
GroupedMLP_CuTeGEMMBase->>GroupedMLP_CuTeGEMMBase: restore saved tensors from fc1_ctx
GroupedMLP_CuTeGEMMBase->>CuDNN: grouped_gemm_dactivation_kernel(FC2 dgrad + dActivation)
CuDNN-->>GroupedMLP_CuTeGEMMBase: fc1_dy, grad_scales
GroupedMLP_CuTeGEMMBase->>CuDNN: grouped_gemm_wgrad_kernel(FC2 wgrad)
CuDNN-->>GroupedMLP_CuTeGEMMBase: fc2_dW
GroupedMLP_CuTeGEMMBase->>CuDNN: grouped_gemm_quant_kernel(FC1 dgrad)
CuDNN-->>GroupedMLP_CuTeGEMMBase: grad_input
GroupedMLP_CuTeGEMMBase->>CuDNN: grouped_gemm_wgrad_kernel(FC1 wgrad)
CuDNN-->>GroupedMLP_CuTeGEMMBase: fc1_dW
GroupedMLP_CuTeGEMMBase-->>Fuser: grad_input, [fc1_grad_params, (), fc2_grad_params]
Reviews (2): Last reviewed commit: "Review suggestions from @greptile-apps" | Re-trigger Greptile |
| import transformer_engine.pytorch.ops as te_ops | ||
| from transformer_engine.pytorch.ops._common import ( | ||
| from transformer_engine.pytorch.ops.fused.grouped_mlp import ( | ||
| _cudnn_frontend_supports_grouped_gemm_srelu, | ||
| _cudnn_frontend_version_supported, |
There was a problem hiding this comment.
Test imports module-private helpers directly from
grouped_mlp. These names (_cudnn_frontend_version_supported, _cudnn_frontend_supports_grouped_gemm_srelu) carry a leading _ signalling they are internal implementation details. Importing them from tests tightly couples the test suite to internal symbol names; a later rename or move would break the import with no deprecation period. Consider exposing a small public helper or adding a # noqa: PLC2701 comment with a brief rationale.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
There was a problem hiding this comment.
For better or for worse, these tests are tightly coupled with the implementation of the grouped MLP fused op and changing them is beyond the scope of this PR.
Also fix linter warnings. Signed-off-by: Tim Moon <tmoon@nvidia.com>
|
/te-ci pytorch |
Description
We've been pursuing some highly experimental fusions for the MoE grouped MLP block (added in #2769). These have hyperspecific optimizations, depend on bleeding edge cuDNN Frontend versions, and break the op fuser contract that fusions can be applied independently in the forward and backward (see #2981 (comment)). This PR refactors the grouped MLP fused ops as joint forward-backward fusions (see #3080), so that we can fully integrate the forward and backward impls without needing to fuss with compatibility with the basic ops. Also, by packing all grouped-MLP-specific logic in one file, we can treat it as a black box for experimental optimizations.
Type of change
Changes
Checklist: