You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Introduces MegaMoECuteDsl, a fused-communication MoE backend that runs the ported MegaMoE NVFP4 CuteDSL kernel (Sm100MegaMoEKernel) on SM100/SM103. The kernel fuses dispatch + FC1 + SwiGLU + FC2 + combine in a single launch via the in-kernel NVLink dispatch barrier. Single-rank degenerate path and multi-rank EP path are both wired through a unified always-pad-to-max_tokens_per_rank staging contract.
Components:
- New backend tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_cute_dsl.py with build-time symmetric-memory provider rendezvous, local staging cache, FUSED_COMM scheduler binding, and quantize_input that pads SF columns to round_up(ceil(hidden/16), 4) to match the kernel TMA contract.
- New torch custom op torch.ops.trtllm.cute_dsl_megamoe_nvfp4_blackwell registered in tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py. Hosts Sm100MegaMoENvfp4Runner (TunableRunner) with PARALLEL distributed tuning so every EP rank converges on the same compiled tactic for every chunk (required for the in-kernel dispatch barrier), MegaMoeSymmMemProvider with zero-init buffer, candidate tactic enumeration sweeping {static, atomic_counter} load balance modes, and a stricter IS_MEGAMOE_OP_AVAILABLE probe so half-installed cutlass-dsl wheels do not break the rest of custom_ops. The runner allocates local_workspace via torch.zeros and calls local_workspace.zero_() on every forward in addition to the existing shared_workspace.zero_(), so the cached buffer cannot feed garbage into the kernel's Int32 atomic counters (l1_arrival_count, fc1_done_counter, grid_sync_counter); a negative Int32 in any counter slot would otherwise make the in-kernel spin_wait (v >= positive_threshold) impossible to satisfy and hang the kernel at 100% SM / 0% memory bandwidth. The op returns None and the caller uses the in-place mutated combine_output directly, because torch.library forbids the return value from aliasing any mutated input.
- Ported CuteDSL kernel package at tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ (16 files) with package-relative imports, plus blocked_scale.py extracted from upstream runner_fc12.py.
- New NVFP4MegaMoECuteDslMethod quant method in tensorrt_llm/_torch/modules/fused_moe/quantization.py: 16-atom gate/up interleave + to_blocked swizzle + flattened-stack staging for mega_fc1_weight / mega_fc1_weight_sf / mega_fc2_weight / mega_fc2_weight_sf. Builds CPU shared-staging buffers for all four derived parameters and registers them with the load balancer through register_all_parameter_slot_and_to_fix_weight_fns so both static AND dynamic EPLB migrate the mega-format bytes atomically with the underlying NVFP4 raw weights + scales.
- FusedCommMoEScheduler now calls backend.quantize_input on zero-token chunks too (DG backend updated in parallel) so each fused-comm backend owns its own empty-tensor layout.
- create_moe.py factory and ConfigurableMoE allowlist updated; MoEDeveloperGuide adds Backend Capability Matrix entries, FUSED_COMM anti-patterns, and the autotuner tactic representation reference.
Tests:
- Drops the asymmetric MoeBackendType.MEGAMOE alias; both variants are spelled out as MEGAMOE_DEEPGEMM and MEGAMOE_CUTEDSL with should_skip_megamoe_deepgemm / should_skip_megamoe_cutedsl helpers.
- New focused tests in test_moe_backend.py: kernel-package import, to_blocked roundtrip, can_implement positive/negative, quantize_input zero-token, multi-rank symm-provider gate, alpha-gate, sf-byte-width helper, atomic_counter sweep, dynamic-EPLB-now-supported, and a byte-equivalence test for the FC1 gate/up 16-atom interleave that catches any gate-vs-up swap.
- test_moe_module.py adds factory + scheduler wiring coverage and updates the shared dist helper to recognise MEGAMOE_CUTEDSL.
Hard gates documented in MEGAMOE_CUTEDSL_DESIGN.md:
- Kernel ABI hard-codes per-expert alpha / fc2_input_scale to 1.0; v1 alpha gate in NVFP4MegaMoECuteDslMethod rejects non-1.0 checkpoint values until the upstream kernel ABI is extended.
- Launch contract requires T == max_tokens_per_rank on every call; backend stages real T rows then pads topk_idx tail to -1 (dispatch_prep skips negative expert ids).
- IS_MEGAMOE_OP_AVAILABLE strict probe protects the rest of custom_ops on partial cutlass-dsl installs.
Known follow-ups:
- per-slot to_blocked perf can be batched (high-volume models).
- Real GPU E2E run is blocked on OCI worktree rebuild for ABI parity with the new C++ bindings.
Designed and documented under tensorrt_llm/_torch/modules/fused_moe/mega_moe/MEGAMOE_CUTEDSL_DESIGN.md.
Signed-off-by: xxi <xxi@nvidia.com>
0 commit comments