Skip to content

Commit e736f64

Browse files
committed
[None][feat] MegaMoECuteDsl v2: per-expert NVFP4 scales + kernel re-port
Re-port v2 CuteDSL kernel package, remove v1 alpha==1 gate, add SwiGLU clamp, and thread per-expert fc1_norm_const / fc2_alpha derived from each expert's raw w2.input_scale so non-uniform NVFP4 checkpoints (e.g. DeepSeek-V4) compute correctly. Signed-off-by: xxi <xxi@nvidia.com>
1 parent 56e5043 commit e736f64

23 files changed

Lines changed: 8762 additions & 3151 deletions

tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -642,13 +642,18 @@ def query_megamoe_shared_workspace_bytes(
642642
expand_intermediate_size_per_partition: int,
643643
max_tokens_per_rank: int,
644644
tactic: Optional[Tuple] = None,
645+
apply_topk_in_fc1: bool = True,
646+
gate_up_clamp: Optional[float] = None,
645647
) -> int:
646648
"""Probe ``Sm100MegaMoEKernel.get_workspace_sizes()`` for the
647649
shared workspace byte count. The shared workspace size is
648-
invariant across all candidate tactics (its regions depend only
649-
on world_size / num_experts_per_rank / num_topk /
650-
max_tokens_per_rank -- see _build_shared_region_specs in
651-
megamoe_kernel.py), so we use the default tactic for the probe.
650+
invariant across all candidate tactics and across the codegen-time
651+
graph/clamp modes (its regions depend only on world_size /
652+
num_experts_per_rank / num_topk / max_tokens_per_rank -- see
653+
_build_shared_region_specs in megamoe_kernel.py), so we use the
654+
default tactic for the probe. ``apply_topk_in_fc1`` / ``gate_up_clamp``
655+
are still threaded so the probe kernel ctor signature is satisfied
656+
and matches the real build.
652657
"""
653658
from ..cute_dsl_kernels.mega_moe_nvfp4 import import_kernel
654659

@@ -681,7 +686,10 @@ def query_megamoe_shared_workspace_bytes(
681686
num_topk=int(num_topk),
682687
max_tokens_per_rank=int(max_tokens_per_rank),
683688
hidden=int(hidden_size),
684-
fc2_in_kernel_topk_reduce=bool(tactic[5]),
689+
fc2_output_dtype=cutlass.BFloat16,
690+
in_kernel_fc2_reduce=bool(tactic[5]),
691+
apply_topk_in_fc1=bool(apply_topk_in_fc1),
692+
gate_up_clamp=(None if gate_up_clamp is None else float(gate_up_clamp)),
685693
**_LOCKED_KERNEL_KWARGS,
686694
)
687695
_, shared_bytes = probe.get_workspace_sizes()
@@ -717,6 +725,8 @@ def __init__(
717725
expand_intermediate_size_per_partition: int,
718726
max_tokens_per_rank: int,
719727
output_dtype: torch.dtype,
728+
apply_topk_in_fc1: bool = True,
729+
gate_up_clamp: Optional[float] = None,
720730
) -> None:
721731
super().__init__()
722732
if (sm_version := get_sm_version()) not in (100, 103):
@@ -745,6 +755,11 @@ def __init__(
745755
)
746756
self.max_tokens_per_rank = int(max_tokens_per_rank)
747757
self.output_dtype = output_dtype
758+
# Codegen-time graph/clamp modes. They change the generated
759+
# kernel, so they are part of ``unique_id`` (and therefore the
760+
# compile-cache key) -- never per-call runtime kwargs.
761+
self.apply_topk_in_fc1 = bool(apply_topk_in_fc1)
762+
self.gate_up_clamp = None if gate_up_clamp is None else float(gate_up_clamp)
748763

749764
def unique_id(self):
750765
return (
@@ -757,6 +772,8 @@ def unique_id(self):
757772
self.expand_intermediate_size_per_partition,
758773
self.max_tokens_per_rank,
759774
str(self.output_dtype),
775+
self.apply_topk_in_fc1,
776+
self.gate_up_clamp,
760777
)
761778

762779
def get_valid_tactics(
@@ -810,6 +827,17 @@ def _autotuner_inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.T
810827
if isinstance(topk_weights, torch.Tensor):
811828
topk_weights.zero_()
812829

830+
# New per-expert scale inputs fc1_alpha(8) / fc2_alpha(9) /
831+
# fc1_norm_const(10) are inserted after fc2_weight_sf(7) and
832+
# before combine_output(11). Fill them with 1.0 (NOT zero):
833+
# the FC1/FC2 epilogues divide/scale by these and a zero
834+
# fc1_norm_const would make the fc1-out NVFP4 quant divide by
835+
# zero during fake autotune runs.
836+
for alpha_idx in (8, 9, 10):
837+
tensor = inputs[alpha_idx]
838+
if isinstance(tensor, torch.Tensor):
839+
tensor.fill_(1.0)
840+
813841
return inputs
814842

815843
def get_tuning_config(self) -> TuningConfig:
@@ -838,7 +866,9 @@ def _num_tokens(shapes: List[torch.Size]) -> int:
838866
ConstraintSpec(1, 0, _num_tokens), # activation_sf
839867
ConstraintSpec(2, 0, _num_tokens), # topk_idx
840868
ConstraintSpec(3, 0, _num_tokens), # topk_weights
841-
ConstraintSpec(8, 0, _num_tokens), # combine_output
869+
# combine_output moved from idx 8 -> 11 after inserting
870+
# fc1_alpha(8) / fc2_alpha(9) / fc1_norm_const(10).
871+
ConstraintSpec(11, 0, _num_tokens), # combine_output
842872
),
843873
inputs_pre_hook=self._autotuner_inputs_pre_hook,
844874
use_cold_l2_cache=True,
@@ -887,11 +917,17 @@ def _build_kernel(self, tactic: Tuple):
887917
num_topk=self.num_topk,
888918
max_tokens_per_rank=self.max_tokens_per_rank,
889919
hidden=self.hidden_size,
890-
fc2_in_kernel_topk_reduce=bool(use_bf16_redg),
920+
fc2_output_dtype=cutlass.BFloat16,
921+
in_kernel_fc2_reduce=bool(use_bf16_redg),
922+
apply_topk_in_fc1=self.apply_topk_in_fc1,
923+
gate_up_clamp=self.gate_up_clamp,
891924
**_LOCKED_KERNEL_KWARGS,
892925
)
893926

894927
def _compile_or_get(self, tactic: Tuple, kernel, runtime_kwargs):
928+
# ``unique_id()`` already carries apply_topk_in_fc1 / gate_up_clamp,
929+
# so the codegen-time graph/clamp modes are part of the cache key
930+
# without listing them again here.
895931
cache_key = (
896932
self.unique_id(),
897933
tuple(tactic[0]),
@@ -978,8 +1014,11 @@ def forward(
9781014
fc1_weight_sf,
9791015
fc2_weight,
9801016
fc2_weight_sf,
1017+
fc1_alpha,
1018+
fc2_alpha,
1019+
fc1_norm_const,
9811020
combine_output,
982-
) = inputs[:9]
1021+
) = inputs[:12]
9831022
assert peer_offsets is not None, (
9841023
"Sm100MegaMoENvfp4Runner.forward requires peer_offsets kwarg "
9851024
"(length = world_size); single-rank degenerate mode passes "
@@ -1037,6 +1076,12 @@ def forward(
10371076
fc1_weight_sf_cute = _to_cute(fc1_weight_sf)
10381077
fc2_weight_cute = _to_cute(fc2_weight)
10391078
fc2_weight_sf_cute = _to_cute(fc2_weight_sf)
1079+
# Per-expert fp32 scale tensors are 1-D ``(num_local_slots,)``;
1080+
# 4-byte alignment matches the fp32 element size (the kernel
1081+
# reads them as a plain fp32 vector, no 16-byte TMA tile).
1082+
fc1_alpha_cute = _to_cute(fc1_alpha, assumed_align=4)
1083+
fc2_alpha_cute = _to_cute(fc2_alpha, assumed_align=4)
1084+
fc1_norm_const_cute = _to_cute(fc1_norm_const, assumed_align=4)
10401085
combine_output_cute = _to_cute(combine_output)
10411086
local_workspace_cute = _to_cute(local_workspace)
10421087
shared_workspace_cute = _to_cute(shared_workspace)
@@ -1066,6 +1111,9 @@ def forward(
10661111
fc1_weight_sf=fc1_weight_sf_cute,
10671112
fc2_weight=fc2_weight_cute,
10681113
fc2_weight_sf=fc2_weight_sf_cute,
1114+
fc1_alpha=fc1_alpha_cute,
1115+
fc2_alpha=fc2_alpha_cute,
1116+
fc1_norm_const=fc1_norm_const_cute,
10691117
combine_output=combine_output_cute,
10701118
local_workspace=local_workspace_cute,
10711119
shared_workspace=shared_workspace_cute,
@@ -1110,6 +1158,9 @@ def cute_dsl_megamoe_nvfp4_blackwell(
11101158
fc1_weight_sf: torch.Tensor,
11111159
fc2_weight: torch.Tensor,
11121160
fc2_weight_sf: torch.Tensor,
1161+
fc1_alpha: torch.Tensor,
1162+
fc2_alpha: torch.Tensor,
1163+
fc1_norm_const: torch.Tensor,
11131164
combine_output: torch.Tensor,
11141165
shared_workspace: torch.Tensor,
11151166
world_size: int,
@@ -1121,6 +1172,8 @@ def cute_dsl_megamoe_nvfp4_blackwell(
11211172
expand_intermediate_size_per_partition: int,
11221173
max_tokens_per_rank: int,
11231174
peer_offsets: List[int],
1175+
apply_topk_in_fc1: bool = True,
1176+
gate_up_clamp: Optional[float] = None,
11241177
) -> None:
11251178
"""Run the fused MegaMoE CuteDSL NVFP4 kernel.
11261179
@@ -1155,6 +1208,8 @@ def cute_dsl_megamoe_nvfp4_blackwell(
11551208
expand_intermediate_size_per_partition=expand_intermediate_size_per_partition,
11561209
max_tokens_per_rank=max_tokens_per_rank,
11571210
output_dtype=combine_output.dtype,
1211+
apply_topk_in_fc1=apply_topk_in_fc1,
1212+
gate_up_clamp=gate_up_clamp,
11581213
)
11591214
inputs = [
11601215
activation,
@@ -1165,6 +1220,9 @@ def cute_dsl_megamoe_nvfp4_blackwell(
11651220
fc1_weight_sf,
11661221
fc2_weight,
11671222
fc2_weight_sf,
1223+
fc1_alpha,
1224+
fc2_alpha,
1225+
fc1_norm_const,
11681226
combine_output,
11691227
]
11701228
tuner = AutoTuner.get()
@@ -1193,6 +1251,9 @@ def _(
11931251
fc1_weight_sf: torch.Tensor,
11941252
fc2_weight: torch.Tensor,
11951253
fc2_weight_sf: torch.Tensor,
1254+
fc1_alpha: torch.Tensor,
1255+
fc2_alpha: torch.Tensor,
1256+
fc1_norm_const: torch.Tensor,
11961257
combine_output: torch.Tensor,
11971258
shared_workspace: torch.Tensor,
11981259
world_size: int,
@@ -1204,5 +1265,7 @@ def _(
12041265
expand_intermediate_size_per_partition: int,
12051266
max_tokens_per_rank: int,
12061267
peer_offsets: List[int],
1268+
apply_topk_in_fc1: bool = True,
1269+
gate_up_clamp: Optional[float] = None,
12071270
) -> None:
12081271
return None

tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
"from_blocked",
5050
"import_kernel",
5151
"import_sym_buffer_host",
52+
"import_topk_reduce",
5253
"stack_byte_reinterpretable_tensors",
5354
"to_blocked",
5455
]
@@ -81,3 +82,20 @@ def import_sym_buffer_host():
8182
# SymBufferHost lives at module scope as a factory; the upstream API
8283
# constructs the per-world-size variant inside sym_buffer.py.
8384
return sym_buffer
85+
86+
87+
def import_topk_reduce():
88+
"""Lazily import the standalone CuteDSL top-k reduce kernel API.
89+
90+
Returns ``(compile_topk_reduce, launch_compiled_topk_reduce)`` from
91+
:mod:`.topk_reduce` (mirrors :func:`import_kernel`). The reduce kernel
92+
is only needed by the opt-in transformers graph
93+
(``apply_topk_in_fc1=False``); the deepgemm-default route reduces on
94+
the host via ``combine_output.sum(dim=1)`` and never imports it. Like
95+
``import_kernel`` this stays lazy so non-SM100 / no-cutlass-dsl
96+
environments can import the backend for capability probing without
97+
pulling the heavyweight CuteDSL symbols.
98+
"""
99+
from .topk_reduce import compile_topk_reduce, launch_compiled_topk_reduce
100+
101+
return compile_topk_reduce, launch_compiled_topk_reduce

tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,7 @@ def enrich_work_tile_info(
238238
- fc1 tiles peek the dispatch->fc1 ``fc1_ready_counter`` at the
239239
same slot index but with ``valid_tokens_in_tile`` as threshold
240240
(per-tile dynamic). This branch only emits when
241-
``self.fc1_ready_counter_ptr is not None`` (MegaMoE mode). See
242-
fc12_integrate_comm.md §4.
241+
``self.fc1_ready_counter_ptr is not None`` (MegaMoE mode).
243242
"""
244243
# Invalid tiles keep (None_ | 0); do not index an arbitrary counter slot.
245244
is_valid = base_work.is_valid_tile
@@ -250,7 +249,8 @@ def enrich_work_tile_info(
250249
# pull) and fc2 release-add (fc1 epi) target the per-task-tile
251250
# counter slot indexed by ``cumulative_token_block_count +
252251
# tile_n_idx``.
253-
counter_slot = base_work.cumulative_token_block_count + base_work.tile_n_idx
252+
counter_slot = (base_work.cumulative_token_block_count +
253+
base_work.tile_n_idx)
254254
is_fc1 = base_work.phase == Int32(int(BlockPhase.Linear1))
255255
is_fc2 = base_work.phase == Int32(int(BlockPhase.Linear2))
256256

0 commit comments

Comments
 (0)