@@ -642,13 +642,18 @@ def query_megamoe_shared_workspace_bytes(
642642 expand_intermediate_size_per_partition : int ,
643643 max_tokens_per_rank : int ,
644644 tactic : Optional [Tuple ] = None ,
645+ apply_topk_in_fc1 : bool = True ,
646+ gate_up_clamp : Optional [float ] = None ,
645647 ) -> int :
646648 """Probe ``Sm100MegaMoEKernel.get_workspace_sizes()`` for the
647649 shared workspace byte count. The shared workspace size is
648- invariant across all candidate tactics (its regions depend only
649- on world_size / num_experts_per_rank / num_topk /
650- max_tokens_per_rank -- see _build_shared_region_specs in
651- megamoe_kernel.py), so we use the default tactic for the probe.
650+ invariant across all candidate tactics and across the codegen-time
651+ graph/clamp modes (its regions depend only on world_size /
652+ num_experts_per_rank / num_topk / max_tokens_per_rank -- see
653+ _build_shared_region_specs in megamoe_kernel.py), so we use the
654+ default tactic for the probe. ``apply_topk_in_fc1`` / ``gate_up_clamp``
655+ are still threaded so the probe kernel ctor signature is satisfied
656+ and matches the real build.
652657 """
653658 from ..cute_dsl_kernels .mega_moe_nvfp4 import import_kernel
654659
@@ -681,7 +686,10 @@ def query_megamoe_shared_workspace_bytes(
681686 num_topk = int (num_topk ),
682687 max_tokens_per_rank = int (max_tokens_per_rank ),
683688 hidden = int (hidden_size ),
684- fc2_in_kernel_topk_reduce = bool (tactic [5 ]),
689+ fc2_output_dtype = cutlass .BFloat16 ,
690+ in_kernel_fc2_reduce = bool (tactic [5 ]),
691+ apply_topk_in_fc1 = bool (apply_topk_in_fc1 ),
692+ gate_up_clamp = (None if gate_up_clamp is None else float (gate_up_clamp )),
685693 ** _LOCKED_KERNEL_KWARGS ,
686694 )
687695 _ , shared_bytes = probe .get_workspace_sizes ()
@@ -717,6 +725,8 @@ def __init__(
717725 expand_intermediate_size_per_partition : int ,
718726 max_tokens_per_rank : int ,
719727 output_dtype : torch .dtype ,
728+ apply_topk_in_fc1 : bool = True ,
729+ gate_up_clamp : Optional [float ] = None ,
720730 ) -> None :
721731 super ().__init__ ()
722732 if (sm_version := get_sm_version ()) not in (100 , 103 ):
@@ -745,6 +755,11 @@ def __init__(
745755 )
746756 self .max_tokens_per_rank = int (max_tokens_per_rank )
747757 self .output_dtype = output_dtype
758+ # Codegen-time graph/clamp modes. They change the generated
759+ # kernel, so they are part of ``unique_id`` (and therefore the
760+ # compile-cache key) -- never per-call runtime kwargs.
761+ self .apply_topk_in_fc1 = bool (apply_topk_in_fc1 )
762+ self .gate_up_clamp = None if gate_up_clamp is None else float (gate_up_clamp )
748763
749764 def unique_id (self ):
750765 return (
@@ -757,6 +772,8 @@ def unique_id(self):
757772 self .expand_intermediate_size_per_partition ,
758773 self .max_tokens_per_rank ,
759774 str (self .output_dtype ),
775+ self .apply_topk_in_fc1 ,
776+ self .gate_up_clamp ,
760777 )
761778
762779 def get_valid_tactics (
@@ -810,6 +827,17 @@ def _autotuner_inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.T
810827 if isinstance (topk_weights , torch .Tensor ):
811828 topk_weights .zero_ ()
812829
830+ # New per-expert scale inputs fc1_alpha(8) / fc2_alpha(9) /
831+ # fc1_norm_const(10) are inserted after fc2_weight_sf(7) and
832+ # before combine_output(11). Fill them with 1.0 (NOT zero):
833+ # the FC1/FC2 epilogues divide/scale by these and a zero
834+ # fc1_norm_const would make the fc1-out NVFP4 quant divide by
835+ # zero during fake autotune runs.
836+ for alpha_idx in (8 , 9 , 10 ):
837+ tensor = inputs [alpha_idx ]
838+ if isinstance (tensor , torch .Tensor ):
839+ tensor .fill_ (1.0 )
840+
813841 return inputs
814842
815843 def get_tuning_config (self ) -> TuningConfig :
@@ -838,7 +866,9 @@ def _num_tokens(shapes: List[torch.Size]) -> int:
838866 ConstraintSpec (1 , 0 , _num_tokens ), # activation_sf
839867 ConstraintSpec (2 , 0 , _num_tokens ), # topk_idx
840868 ConstraintSpec (3 , 0 , _num_tokens ), # topk_weights
841- ConstraintSpec (8 , 0 , _num_tokens ), # combine_output
869+ # combine_output moved from idx 8 -> 11 after inserting
870+ # fc1_alpha(8) / fc2_alpha(9) / fc1_norm_const(10).
871+ ConstraintSpec (11 , 0 , _num_tokens ), # combine_output
842872 ),
843873 inputs_pre_hook = self ._autotuner_inputs_pre_hook ,
844874 use_cold_l2_cache = True ,
@@ -887,11 +917,17 @@ def _build_kernel(self, tactic: Tuple):
887917 num_topk = self .num_topk ,
888918 max_tokens_per_rank = self .max_tokens_per_rank ,
889919 hidden = self .hidden_size ,
890- fc2_in_kernel_topk_reduce = bool (use_bf16_redg ),
920+ fc2_output_dtype = cutlass .BFloat16 ,
921+ in_kernel_fc2_reduce = bool (use_bf16_redg ),
922+ apply_topk_in_fc1 = self .apply_topk_in_fc1 ,
923+ gate_up_clamp = self .gate_up_clamp ,
891924 ** _LOCKED_KERNEL_KWARGS ,
892925 )
893926
894927 def _compile_or_get (self , tactic : Tuple , kernel , runtime_kwargs ):
928+ # ``unique_id()`` already carries apply_topk_in_fc1 / gate_up_clamp,
929+ # so the codegen-time graph/clamp modes are part of the cache key
930+ # without listing them again here.
895931 cache_key = (
896932 self .unique_id (),
897933 tuple (tactic [0 ]),
@@ -978,8 +1014,11 @@ def forward(
9781014 fc1_weight_sf ,
9791015 fc2_weight ,
9801016 fc2_weight_sf ,
1017+ fc1_alpha ,
1018+ fc2_alpha ,
1019+ fc1_norm_const ,
9811020 combine_output ,
982- ) = inputs [:9 ]
1021+ ) = inputs [:12 ]
9831022 assert peer_offsets is not None , (
9841023 "Sm100MegaMoENvfp4Runner.forward requires peer_offsets kwarg "
9851024 "(length = world_size); single-rank degenerate mode passes "
@@ -1037,6 +1076,12 @@ def forward(
10371076 fc1_weight_sf_cute = _to_cute (fc1_weight_sf )
10381077 fc2_weight_cute = _to_cute (fc2_weight )
10391078 fc2_weight_sf_cute = _to_cute (fc2_weight_sf )
1079+ # Per-expert fp32 scale tensors are 1-D ``(num_local_slots,)``;
1080+ # 4-byte alignment matches the fp32 element size (the kernel
1081+ # reads them as a plain fp32 vector, no 16-byte TMA tile).
1082+ fc1_alpha_cute = _to_cute (fc1_alpha , assumed_align = 4 )
1083+ fc2_alpha_cute = _to_cute (fc2_alpha , assumed_align = 4 )
1084+ fc1_norm_const_cute = _to_cute (fc1_norm_const , assumed_align = 4 )
10401085 combine_output_cute = _to_cute (combine_output )
10411086 local_workspace_cute = _to_cute (local_workspace )
10421087 shared_workspace_cute = _to_cute (shared_workspace )
@@ -1066,6 +1111,9 @@ def forward(
10661111 fc1_weight_sf = fc1_weight_sf_cute ,
10671112 fc2_weight = fc2_weight_cute ,
10681113 fc2_weight_sf = fc2_weight_sf_cute ,
1114+ fc1_alpha = fc1_alpha_cute ,
1115+ fc2_alpha = fc2_alpha_cute ,
1116+ fc1_norm_const = fc1_norm_const_cute ,
10691117 combine_output = combine_output_cute ,
10701118 local_workspace = local_workspace_cute ,
10711119 shared_workspace = shared_workspace_cute ,
@@ -1110,6 +1158,9 @@ def cute_dsl_megamoe_nvfp4_blackwell(
11101158 fc1_weight_sf : torch .Tensor ,
11111159 fc2_weight : torch .Tensor ,
11121160 fc2_weight_sf : torch .Tensor ,
1161+ fc1_alpha : torch .Tensor ,
1162+ fc2_alpha : torch .Tensor ,
1163+ fc1_norm_const : torch .Tensor ,
11131164 combine_output : torch .Tensor ,
11141165 shared_workspace : torch .Tensor ,
11151166 world_size : int ,
@@ -1121,6 +1172,8 @@ def cute_dsl_megamoe_nvfp4_blackwell(
11211172 expand_intermediate_size_per_partition : int ,
11221173 max_tokens_per_rank : int ,
11231174 peer_offsets : List [int ],
1175+ apply_topk_in_fc1 : bool = True ,
1176+ gate_up_clamp : Optional [float ] = None ,
11241177 ) -> None :
11251178 """Run the fused MegaMoE CuteDSL NVFP4 kernel.
11261179
@@ -1155,6 +1208,8 @@ def cute_dsl_megamoe_nvfp4_blackwell(
11551208 expand_intermediate_size_per_partition = expand_intermediate_size_per_partition ,
11561209 max_tokens_per_rank = max_tokens_per_rank ,
11571210 output_dtype = combine_output .dtype ,
1211+ apply_topk_in_fc1 = apply_topk_in_fc1 ,
1212+ gate_up_clamp = gate_up_clamp ,
11581213 )
11591214 inputs = [
11601215 activation ,
@@ -1165,6 +1220,9 @@ def cute_dsl_megamoe_nvfp4_blackwell(
11651220 fc1_weight_sf ,
11661221 fc2_weight ,
11671222 fc2_weight_sf ,
1223+ fc1_alpha ,
1224+ fc2_alpha ,
1225+ fc1_norm_const ,
11681226 combine_output ,
11691227 ]
11701228 tuner = AutoTuner .get ()
@@ -1193,6 +1251,9 @@ def _(
11931251 fc1_weight_sf : torch .Tensor ,
11941252 fc2_weight : torch .Tensor ,
11951253 fc2_weight_sf : torch .Tensor ,
1254+ fc1_alpha : torch .Tensor ,
1255+ fc2_alpha : torch .Tensor ,
1256+ fc1_norm_const : torch .Tensor ,
11961257 combine_output : torch .Tensor ,
11971258 shared_workspace : torch .Tensor ,
11981259 world_size : int ,
@@ -1204,5 +1265,7 @@ def _(
12041265 expand_intermediate_size_per_partition : int ,
12051266 max_tokens_per_rank : int ,
12061267 peer_offsets : List [int ],
1268+ apply_topk_in_fc1 : bool = True ,
1269+ gate_up_clamp : Optional [float ] = None ,
12071270 ) -> None :
12081271 return None
0 commit comments