Skip to content

Commit 2772b99

Browse files
authored
[TRTLLM-12950][feat] Add MegaMoECuteDsl NVFP4 MoE backend (#14608)
Signed-off-by: xxi <xxi@nvidia.com>
1 parent 0593968 commit 2772b99

46 files changed

Lines changed: 19831 additions & 162 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

.claude/skills/trtllm-moe-develop/SKILL.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,26 @@ Checklist:
268268
- Existing legacy `forward` methods can be read for compatibility context, but
269269
they are not the default pattern for new backend work.
270270

271+
### Imported Kernel ABI Checklist
272+
273+
When importing or wrapping an upstream kernel, derive the TRT-LLM adapter
274+
contract from the lowest-level kernel consumer. Comments, docs, design notes,
275+
and parameter names are useful hints, but they are not proof of the runtime ABI.
276+
277+
- Derive weight shape and layout from the kernel entrypoint, `make_layout`, TMA,
278+
MMA/GEMM transforms, and stride usage. Record required tensor shape, stride,
279+
physical storage layout, and boundary view layout.
280+
- Derive alpha and scale semantics from kernel consumption points. Trace where
281+
alpha, norm constants, block scales, activation scales, and weight scales are
282+
loaded and multiplied before deciding how upper layers compute or pack them.
283+
Treat weight bytes, block scales/SF, and global alpha/norm constants as
284+
separate contracts.
285+
- Design the upper-layer adapter from the kernel ABI upward. Map each kernel
286+
input/output to an adapter responsibility: storage tensor, view/transposition,
287+
dtype reinterpretation, padding, scale packing, workspace ownership,
288+
synchronization, and output reduction. Validate parity with upstream
289+
invocation dumps, not just final output.
290+
271291
### Quantization And Weights
272292

273293
Role:

.pre-commit-config.yaml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,25 @@ common-files: &common_files |
302302
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py |
303303
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py |
304304
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py |
305+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py |
306+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py |
307+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py |
308+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py |
309+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dynamic_mainloop.py |
310+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue_refactor.py |
311+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py |
312+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py |
313+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py |
314+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py |
315+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py |
316+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py |
317+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py |
318+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py |
319+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py |
320+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py |
321+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py |
322+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/token_comm.py |
323+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/topk_reduce.py |
305324
tensorrt_llm/_torch/cute_dsl_utils.py |
306325
tensorrt_llm/_torch/debug/__init__.py |
307326
tensorrt_llm/_torch/debug/debug_hook.py |
@@ -1658,6 +1677,25 @@ legacy-files: &legacy_files |
16581677
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py |
16591678
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py |
16601679
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py |
1680+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py |
1681+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py |
1682+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py |
1683+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py |
1684+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dynamic_mainloop.py |
1685+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue_refactor.py |
1686+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py |
1687+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py |
1688+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py |
1689+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py |
1690+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py |
1691+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py |
1692+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py |
1693+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py |
1694+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py |
1695+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py |
1696+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py |
1697+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/token_comm.py |
1698+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/topk_reduce.py |
16611699
tensorrt_llm/_torch/cute_dsl_utils.py |
16621700
tensorrt_llm/_torch/debug/__init__.py |
16631701
tensorrt_llm/_torch/debug/debug_hook.py |

legacy-files.txt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,25 @@ tensorrt_llm/_torch/cute_dsl_kernels/blackwell/__init__.py
294294
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py
295295
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py
296296
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py
297+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py
298+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py
299+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py
300+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py
301+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dynamic_mainloop.py
302+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue_refactor.py
303+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py
304+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py
305+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py
306+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py
307+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py
308+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py
309+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py
310+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py
311+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py
312+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py
313+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py
314+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/token_comm.py
315+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/topk_reduce.py
297316
tensorrt_llm/_torch/cute_dsl_utils.py
298317
tensorrt_llm/_torch/debug/__init__.py
299318
tensorrt_llm/_torch/debug/debug_hook.py

pyproject.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,25 @@ exclude = [
352352
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py",
353353
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py",
354354
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py",
355+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py",
356+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py",
357+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py",
358+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py",
359+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dynamic_mainloop.py",
360+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue_refactor.py",
361+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py",
362+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py",
363+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py",
364+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py",
365+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py",
366+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py",
367+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py",
368+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py",
369+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py",
370+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py",
371+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py",
372+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/token_comm.py",
373+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/topk_reduce.py",
355374
"tensorrt_llm/_torch/cute_dsl_utils.py",
356375
"tensorrt_llm/_torch/debug/__init__.py",
357376
"tensorrt_llm/_torch/debug/debug_hook.py",

ruff-legacy.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,25 @@ include = [
311311
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py",
312312
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py",
313313
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py",
314+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py",
315+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py",
316+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py",
317+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py",
318+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dynamic_mainloop.py",
319+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue_refactor.py",
320+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py",
321+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py",
322+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py",
323+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py",
324+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py",
325+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py",
326+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py",
327+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py",
328+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py",
329+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py",
330+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py",
331+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/token_comm.py",
332+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/topk_reduce.py",
314333
"tensorrt_llm/_torch/cute_dsl_utils.py",
315334
"tensorrt_llm/_torch/debug/__init__.py",
316335
"tensorrt_llm/_torch/debug/debug_hook.py",

tensorrt_llm/_torch/autotuner.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,8 +1511,7 @@ def _create_tensor_like(self, origin_tensor: torch.Tensor,
15111511
if dtype == torch.float4_e2m1fn_x2:
15121512
return (torch.rand(shapes, device=device) * 10 - 5).to(
15131513
torch.uint8).view(dtype)
1514-
else:
1515-
return (torch.rand(shapes, device=device) * 10 - 5).to(dtype)
1514+
return (torch.rand(shapes, device=device) * 10 - 5).to(dtype)
15161515

15171516
def _prepare_input_tensors(
15181517
self, profile: OptimizationProfile,

tensorrt_llm/_torch/custom_ops/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,18 @@ def inplace_slice_copy(dest: torch.Tensor, src: torch.Tensor, dim1_start: int,
5858
'cute_dsl_nvfp4_dense_gemm_swiglu_fp4out_blackwell',
5959
]
6060

61+
# The MegaMoE NVFP4 op's availability probe checks a strict superset of the IS_CUTLASS_DSL_AVAILABLE requirements
62+
# (cutlass.torch + cutlass._mlir + cute_nvgpu MMA atoms + the ported
63+
# CuteDSL kernel package). The cute_dsl_megamoe_custom_op module
64+
# sets ``IS_MEGAMOE_OP_AVAILABLE`` based on its own try/except probe;
65+
# importing the module is safe regardless of the result -- it just
66+
# logs and leaves ``IS_MEGAMOE_OP_AVAILABLE = False`` on partial
67+
# cutlass-dsl installs so callers can fall back via the factory.
68+
from .cute_dsl_megamoe_custom_op import IS_MEGAMOE_OP_AVAILABLE
69+
if IS_MEGAMOE_OP_AVAILABLE:
70+
from .cute_dsl_megamoe_custom_op import cute_dsl_megamoe_nvfp4_blackwell
71+
__all__ += ['cute_dsl_megamoe_nvfp4_blackwell']
72+
6173
if IS_CUDA_TILE_AVAILABLE:
6274
from .cuda_tile_custom_ops import (cuda_tile_rms_norm,
6375
cuda_tile_rms_norm_fuse_residual_)

0 commit comments

Comments
 (0)