Skip to content

Commit ff947a8

Browse files
committed
[None][feat] Add MegaMoECuteDsl NVFP4 MoE backend
Introduces MegaMoECuteDsl, a fused-communication MoE backend that runs the ported MegaMoE NVFP4 CuteDSL kernel (Sm100MegaMoEKernel) on SM100/SM103. The kernel fuses dispatch + FC1 + SwiGLU + FC2 + combine in a single launch via the in-kernel NVLink dispatch barrier. Single-rank degenerate path and multi-rank EP path are both wired through a unified always-pad-to-max_tokens_per_rank staging contract. Components: - New backend tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_cute_dsl.py with build-time symmetric-memory provider rendezvous, local staging cache, FUSED_COMM scheduler binding, and quantize_input that pads SF columns to round_up(ceil(hidden/16), 4) to match the kernel TMA contract. - New torch custom op torch.ops.trtllm.cute_dsl_megamoe_nvfp4_blackwell registered in tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py. Hosts Sm100MegaMoENvfp4Runner (TunableRunner) with PARALLEL distributed tuning so every EP rank converges on the same compiled tactic for every chunk (required for the in-kernel dispatch barrier), MegaMoeSymmMemProvider with zero-init buffer, candidate tactic enumeration sweeping {static, atomic_counter} load balance modes, and a stricter IS_MEGAMOE_OP_AVAILABLE probe so half-installed cutlass-dsl wheels do not break the rest of custom_ops. The runner allocates local_workspace via torch.zeros and calls local_workspace.zero_() on every forward in addition to the existing shared_workspace.zero_(), so the cached buffer cannot feed garbage into the kernel's Int32 atomic counters (l1_arrival_count, fc1_done_counter, grid_sync_counter); a negative Int32 in any counter slot would otherwise make the in-kernel spin_wait (v >= positive_threshold) impossible to satisfy and hang the kernel at 100% SM / 0% memory bandwidth. The op returns None and the caller uses the in-place mutated combine_output directly, because torch.library forbids the return value from aliasing any mutated input. - Ported CuteDSL kernel package at tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ (16 files) with package-relative imports, plus blocked_scale.py extracted from upstream runner_fc12.py. - New NVFP4MegaMoECuteDslMethod quant method in tensorrt_llm/_torch/modules/fused_moe/quantization.py: 16-atom gate/up interleave + to_blocked swizzle + flattened-stack staging for mega_fc1_weight / mega_fc1_weight_sf / mega_fc2_weight / mega_fc2_weight_sf. Builds CPU shared-staging buffers for all four derived parameters and registers them with the load balancer through register_all_parameter_slot_and_to_fix_weight_fns so both static AND dynamic EPLB migrate the mega-format bytes atomically with the underlying NVFP4 raw weights + scales. - FusedCommMoEScheduler now calls backend.quantize_input on zero-token chunks too (DG backend updated in parallel) so each fused-comm backend owns its own empty-tensor layout. - create_moe.py factory and ConfigurableMoE allowlist updated; MoEDeveloperGuide adds Backend Capability Matrix entries, FUSED_COMM anti-patterns, and the autotuner tactic representation reference. Tests: - Drops the asymmetric MoeBackendType.MEGAMOE alias; both variants are spelled out as MEGAMOE_DEEPGEMM and MEGAMOE_CUTEDSL with should_skip_megamoe_deepgemm / should_skip_megamoe_cutedsl helpers. - New focused tests in test_moe_backend.py: kernel-package import, to_blocked roundtrip, can_implement positive/negative, quantize_input zero-token, multi-rank symm-provider gate, alpha-gate, sf-byte-width helper, atomic_counter sweep, dynamic-EPLB-now-supported, and a byte-equivalence test for the FC1 gate/up 16-atom interleave that catches any gate-vs-up swap. - test_moe_module.py adds factory + scheduler wiring coverage and updates the shared dist helper to recognise MEGAMOE_CUTEDSL. Hard gates documented in MEGAMOE_CUTEDSL_DESIGN.md: - Kernel ABI hard-codes per-expert alpha / fc2_input_scale to 1.0; v1 alpha gate in NVFP4MegaMoECuteDslMethod rejects non-1.0 checkpoint values until the upstream kernel ABI is extended. - Launch contract requires T == max_tokens_per_rank on every call; backend stages real T rows then pads topk_idx tail to -1 (dispatch_prep skips negative expert ids). - IS_MEGAMOE_OP_AVAILABLE strict probe protects the rest of custom_ops on partial cutlass-dsl installs. Known follow-ups: - per-slot to_blocked perf can be batched (high-volume models). - Real GPU E2E run is blocked on OCI worktree rebuild for ABI parity with the new C++ bindings. Designed and documented under tensorrt_llm/_torch/modules/fused_moe/mega_moe/MEGAMOE_CUTEDSL_DESIGN.md. Signed-off-by: xxi <xxi@nvidia.com>
1 parent c7e7fc5 commit ff947a8

40 files changed

Lines changed: 19612 additions & 120 deletions

.pre-commit-config.yaml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,22 @@ common-files: &common_files |
302302
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py |
303303
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py |
304304
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py |
305+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/config.py |
306+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py |
307+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py |
308+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dispatch_kernel.py |
309+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue.py |
310+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py |
311+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py |
312+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py |
313+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py |
314+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py |
315+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py |
316+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py |
317+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py |
318+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py |
319+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py |
320+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py |
305321
tensorrt_llm/_torch/cute_dsl_utils.py |
306322
tensorrt_llm/_torch/debug/__init__.py |
307323
tensorrt_llm/_torch/debug/debug_hook.py |
@@ -1658,6 +1674,22 @@ legacy-files: &legacy_files |
16581674
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py |
16591675
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py |
16601676
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py |
1677+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/config.py |
1678+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py |
1679+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py |
1680+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dispatch_kernel.py |
1681+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue.py |
1682+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py |
1683+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py |
1684+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py |
1685+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py |
1686+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py |
1687+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py |
1688+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py |
1689+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py |
1690+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py |
1691+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py |
1692+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py |
16611693
tensorrt_llm/_torch/cute_dsl_utils.py |
16621694
tensorrt_llm/_torch/debug/__init__.py |
16631695
tensorrt_llm/_torch/debug/debug_hook.py |

legacy-files.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,22 @@ tensorrt_llm/_torch/cute_dsl_kernels/blackwell/__init__.py
294294
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py
295295
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py
296296
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py
297+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/config.py
298+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py
299+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py
300+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dispatch_kernel.py
301+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue.py
302+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py
303+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py
304+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py
305+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py
306+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py
307+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py
308+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py
309+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py
310+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py
311+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py
312+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py
297313
tensorrt_llm/_torch/cute_dsl_utils.py
298314
tensorrt_llm/_torch/debug/__init__.py
299315
tensorrt_llm/_torch/debug/debug_hook.py

pyproject.toml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,22 @@ exclude = [
352352
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py",
353353
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py",
354354
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py",
355+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/config.py",
356+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py",
357+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py",
358+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dispatch_kernel.py",
359+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue.py",
360+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py",
361+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py",
362+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py",
363+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py",
364+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py",
365+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py",
366+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py",
367+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py",
368+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py",
369+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py",
370+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py",
355371
"tensorrt_llm/_torch/cute_dsl_utils.py",
356372
"tensorrt_llm/_torch/debug/__init__.py",
357373
"tensorrt_llm/_torch/debug/debug_hook.py",

ruff-legacy.toml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,22 @@ include = [
311311
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py",
312312
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py",
313313
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py",
314+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/config.py",
315+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py",
316+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py",
317+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dispatch_kernel.py",
318+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue.py",
319+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py",
320+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py",
321+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py",
322+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py",
323+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py",
324+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py",
325+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py",
326+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py",
327+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py",
328+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py",
329+
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py",
314330
"tensorrt_llm/_torch/cute_dsl_utils.py",
315331
"tensorrt_llm/_torch/debug/__init__.py",
316332
"tensorrt_llm/_torch/debug/debug_hook.py",

tensorrt_llm/_torch/autotuner.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1468,8 +1468,15 @@ def _create_tensor_like(self, origin_tensor: torch.Tensor,
14681468
if dtype == torch.float4_e2m1fn_x2:
14691469
return (torch.rand(shapes, device=device) * 10 - 5).to(
14701470
torch.uint8).view(dtype)
1471-
else:
1472-
return (torch.rand(shapes, device=device) * 10 - 5).to(dtype)
1471+
if dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
1472+
# PyTorch's direct ``.to(float8_*)`` cast can trip on certain
1473+
# GPU/driver combinations (illegal memory access during the
1474+
# cast kernel). Bridge through ``uint8`` like the FP4 branch
1475+
# above. Backends that need real FP8 numerics during
1476+
# autotuning should set up their own warmup data.
1477+
return (torch.rand(shapes, device=device) * 10 - 5).to(
1478+
torch.uint8).view(dtype)
1479+
return (torch.rand(shapes, device=device) * 10 - 5).to(dtype)
14731480

14741481
def _prepare_input_tensors(
14751482
self, profile: OptimizationProfile,

tensorrt_llm/_torch/custom_ops/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,18 @@
4949
'cute_dsl_nvfp4_dense_gemm_swiglu_fp4out_blackwell',
5050
]
5151

52+
# MegaMoE NVFP4 op probes a strict superset of IS_CUTLASS_DSL_AVAILABLE
53+
# (cutlass.torch + cutlass._mlir + cute_nvgpu MMA atoms + the ported
54+
# CuteDSL kernel package). The cute_dsl_megamoe_custom_op module
55+
# sets ``IS_MEGAMOE_OP_AVAILABLE`` based on its own try/except probe;
56+
# importing the module is safe regardless of the result -- it just
57+
# logs and leaves ``IS_MEGAMOE_OP_AVAILABLE = False`` on partial
58+
# cutlass-dsl installs so callers can fall back via the factory.
59+
from .cute_dsl_megamoe_custom_op import IS_MEGAMOE_OP_AVAILABLE
60+
if IS_MEGAMOE_OP_AVAILABLE:
61+
from .cute_dsl_megamoe_custom_op import cute_dsl_megamoe_nvfp4_blackwell
62+
__all__ += ['cute_dsl_megamoe_nvfp4_blackwell']
63+
5264
if IS_CUDA_TILE_AVAILABLE:
5365
from .cuda_tile_custom_ops import (cuda_tile_rms_norm,
5466
cuda_tile_rms_norm_fuse_residual_)

0 commit comments

Comments
 (0)