Skip to content

Commit 8f9b734

Browse files
committed
Add TEGroupedMLP export support for NemotronH models
Signed-off-by: James Shen <yueshen@nvidia.com>
1 parent a076e6c commit 8f9b734

3 files changed

Lines changed: 127 additions & 14 deletions

File tree

modelopt/torch/export/plugins/mcore_custom.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,18 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any]
115115
)
116116

117117

118+
class GroupedMLPSlicing(CustomModuleMapping):
    """A custom module mapping that slices fused grouped MLP weights into per-expert weights."""

    def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] | None = None):
        """Create a custom module mapping that slices grouped MLP weights for export.

        Args:
            target_name_or_prefix: HF-style name template for the exported weights.
            func_kwargs: Extra keyword arguments forwarded to the slicing function.
        """
        # NOTE(review): the previous signature used a mutable default (``{}``),
        # which is a single dict shared by every instance constructed without an
        # explicit ``func_kwargs``. Use a ``None`` sentinel and build a fresh
        # dict per call; callers that passed a dict explicitly are unaffected.
        super().__init__(
            func_name="grouped_mlp_slicing",
            target_name_or_prefix=target_name_or_prefix,
            func_kwargs={} if func_kwargs is None else func_kwargs,
        )
118130
class GatedMLPMerging(CustomModuleMapping):
119131
"""A custom module mapping that merges gate_proj and up_proj."""
120132

modelopt/torch/export/plugins/mcore_nemotron.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ROW_TP,
2525
CustomModuleMapping,
2626
GroupedMLPMerging,
27+
GroupedMLPSlicing,
2728
NameRemapping,
2829
QKVMerging,
2930
QKVSlicing,
@@ -125,6 +126,7 @@
125126
"conv1d": NameRemapping("backbone.layers.{}.mixer.conv1d."),
126127
"in_proj": NameRemapping("backbone.layers.{}.mixer.in_proj."),
127128
"out_proj": NameRemapping("backbone.layers.{}.mixer.out_proj."),
129+
"fused_norm": NameRemapping("backbone.layers.{}.norm.weight"),
128130
# Attention
129131
"input_layernorm": NameRemapping("backbone.layers.{}.norm."),
130132
"linear_qkv": QKVSlicing("backbone.layers.{}.mixer."),
@@ -147,6 +149,9 @@
147149
# Latent MoE
148150
"fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."),
149151
"fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."),
152+
# Grouped local experts (TEGroupedMLP: fused per-expert weights)
153+
"experts.linear_fc1": GroupedMLPSlicing("backbone.layers.{}.mixer.experts.{{}}.up_proj"),
154+
"experts.linear_fc2": GroupedMLPSlicing("backbone.layers.{}.mixer.experts.{{}}.down_proj"),
150155
# MTP
151156
"mtp.enorm": NameRemapping("mtp.layers.{}.enorm."),
152157
"mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm."),

modelopt/torch/export/unified_export_megatron.py

Lines changed: 110 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -285,14 +285,16 @@ def save_pretrained(
285285
self._hf_config.save_pretrained(save_directory)
286286
try:
287287
generation_config = transformers.GenerationConfig.from_pretrained(
288-
self._hf_pretrained_model_name
288+
self._hf_pretrained_model_name,
289+
trust_remote_code=self.trust_remote_code,
289290
)
290291
generation_config.save_pretrained(save_directory)
291292
except OSError:
292293
pass
293294
try:
294295
tokenizer = transformers.AutoTokenizer.from_pretrained(
295-
self._hf_pretrained_model_name
296+
self._hf_pretrained_model_name,
297+
trust_remote_code=self.trust_remote_code,
296298
)
297299
tokenizer.save_pretrained(save_directory)
298300
except OSError:
@@ -420,6 +422,13 @@ def _get_state_dict(self):
420422
def _get_transformer_layer_state_dict(self, layer, layer_id):
421423
if not isinstance(layer.input_layernorm, IdentityOp):
422424
self.rules["input_layernorm"](layer.input_layernorm, layer_id)
425+
elif (
426+
hasattr(layer.self_attention, "linear_qkv")
427+
and hasattr(layer.self_attention.linear_qkv, "layer_norm_weight")
428+
and layer.self_attention.linear_qkv.layer_norm_weight is not None
429+
and "fused_norm" in self.rules
430+
):
431+
self.rules["fused_norm"](layer.self_attention.linear_qkv.layer_norm_weight, layer_id)
423432

424433
if not isinstance(layer.self_attention, IdentityOp):
425434
if "MLASelfAttention" in str(type(layer.self_attention)):
@@ -458,6 +467,15 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):
458467

459468
if not isinstance(layer.pre_mlp_layernorm, IdentityOp):
460469
self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id)
470+
elif (
471+
not isinstance(layer.mlp, IdentityOp)
472+
and "MoE" not in str(type(layer.mlp))
473+
and hasattr(layer.mlp, "linear_fc1")
474+
and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
475+
and layer.mlp.linear_fc1.layer_norm_weight is not None
476+
and "fused_norm" in self.rules
477+
):
478+
self.rules["fused_norm"](layer.mlp.linear_fc1.layer_norm_weight, layer_id)
461479

462480
if not isinstance(layer.mlp, IdentityOp):
463481
if "MoE" in str(type(layer.mlp)):
@@ -473,22 +491,30 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):
473491
self.rules["shared_experts.linear_fc2"](
474492
layer.mlp.shared_experts.linear_fc2, layer_id
475493
)
476-
if not self.rules.get("use_packed_local_experts", False):
477-
for expert_id, expert in enumerate(layer.mlp.experts.local_experts):
494+
if hasattr(layer.mlp.experts, "local_experts"):
495+
if not self.rules.get("use_packed_local_experts", False):
496+
for expert_id, expert in enumerate(layer.mlp.experts.local_experts):
497+
self.rules["local_experts.linear_fc1"](
498+
expert.linear_fc1, layer_id, expert_id
499+
)
500+
self.rules["local_experts.linear_fc2"](
501+
expert.linear_fc2, layer_id, expert_id
502+
)
503+
else:
504+
# For llama 4, in hf unified checkpoint, all local experts share one scale
478505
self.rules["local_experts.linear_fc1"](
479-
expert.linear_fc1, layer_id, expert_id
506+
layer.mlp.experts.local_experts, layer_id
480507
)
481508
self.rules["local_experts.linear_fc2"](
482-
expert.linear_fc2, layer_id, expert_id
509+
layer.mlp.experts.local_experts, layer_id
483510
)
484-
else:
485-
# For llama 4, in hf unified checkpoint, all local experts share one scale
486-
self.rules["local_experts.linear_fc1"](
487-
layer.mlp.experts.local_experts, layer_id
488-
)
489-
self.rules["local_experts.linear_fc2"](
490-
layer.mlp.experts.local_experts, layer_id
491-
)
511+
elif "experts.linear_fc1" in self.rules:
512+
# TEGroupedMLP: experts use fused grouped GEMM with a single
513+
# linear_fc1/linear_fc2 for all experts (no local_experts attribute).
514+
# Uses "experts.linear_fc1" rule (GroupedMLPMerging) instead of
515+
# "local_experts.linear_fc1" which expects per-expert iteration.
516+
self.rules["experts.linear_fc1"](layer.mlp.experts.linear_fc1, layer_id)
517+
self.rules["experts.linear_fc2"](layer.mlp.experts.linear_fc2, layer_id)
492518
else:
493519
self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id)
494520
self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id)
@@ -529,6 +555,14 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]:
529555
def _get_mamba_layer_state_dict(self, layer, layer_id):
530556
if not isinstance(layer.norm, IdentityOp):
531557
self.rules["norm"](layer.norm, layer_id)
558+
elif (
559+
isinstance(layer.norm, IdentityOp)
560+
and hasattr(layer.mixer.in_proj, "layer_norm_weight")
561+
and layer.mixer.in_proj.layer_norm_weight is not None
562+
and "fused_norm" in self.rules
563+
):
564+
# TE spec: norm is fused into in_proj (QuantTELayerNormColumnParallelLinear).
565+
self.rules["fused_norm"](layer.mixer.in_proj.layer_norm_weight, layer_id)
532566

533567
self.rules["mixer_norm"](layer.mixer.norm, layer_id)
534568
self.rules["A_log"](layer.mixer.A_log, layer_id)
@@ -655,6 +689,7 @@ def _custom_mapping_to_lambda(mapping):
655689
"qkv_slicing": self._qkv_slicing,
656690
"self_attention_scaling": self._self_attention_scaling,
657691
"gated_mlp_slicing": self._gated_mlp_slicing,
692+
"grouped_mlp_slicing": self._grouped_mlp_slicing,
658693
"pack_name_remapping": self._pack_name_remapping,
659694
"pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss,
660695
}
@@ -855,6 +890,67 @@ def _gated_mlp_slicing(
855890
self._state_dict[gate_proj_key] = val.detach().clone()
856891
self._state_dict[up_proj_key] = val.detach().clone()
857892

893+
def _grouped_mlp_slicing(self, module, prefix, parallel_config=None):
    """Export TEGroupedMLP weights by splitting per-expert weights into individual HF weights.

    TEGroupedMLP (via TEGroupedLinear) stores weights as weight0, weight1, ..., weight{N-1}
    in its state_dict, where each weight{i} corresponds to one expert. This method extracts
    quantization state from the module, then iterates over experts and saves each expert's
    weight (and scales if quantized) under the HF-style per-expert prefix.

    This is the reverse of _grouped_mlp_merging in the importer.

    Args:
        module: Grouped-linear module holding ``num_gemms`` fused expert weights
            (``weight0`` ... ``weight{N-1}``).
        prefix: HF name template containing a ``{}`` placeholder for the expert index.
        parallel_config: Unused here; presumably kept for signature parity with the
            other handlers dispatched from the rules table — TODO confirm.
    """
    num_experts = module.num_gemms

    # TEGroupedLinear doesn't have module.weight (it has weight0, weight1, ...).
    # Temporarily assign weight = weight0 so _get_quantized_state can extract
    # qformat, scales, and input_scale from the module's quantizers.
    has_weight = hasattr(module, "weight")
    if not has_weight:
        module.weight = module.weight0
    try:
        name_to_value, qformat, block_size = self._get_quantized_state(
            module, self.dtype, prefix=prefix
        )
        weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
        # Drop the borrowed "weight" entry so the replication loop at the bottom
        # does not re-emit weight0's tensor under every expert's prefix.
        name_to_value.pop("weight", None)
    finally:
        # Always undo the temporary attribute, even if state extraction raised.
        if not has_weight and hasattr(module, "weight"):
            delattr(module, "weight")

    state_dict = module.state_dict()

    for expert_id in range(num_experts):
        # Substitute the expert index into the HF name template and terminate
        # it with "." so plain key names can be appended below.
        expert_prefix = prefix.format(expert_id) + "."
        weight_key = f"weight{expert_id}"

        if weight_key not in state_dict:
            raise ValueError(f"Missing expected TEGroupedMLP expert weight: {weight_key}")

        # Cast to the export dtype and move off-device before storing.
        weight = state_dict[weight_key].to(self.dtype).cpu()

        if weight_scale is None:
            # Unquantized export path: store the plain (cast) weight.
            self._state_dict[expert_prefix + "weight"] = weight
        else:
            # Quantized export path: compress the weight with the shared scales
            # and emit the per-expert weight_scale alongside it.
            self._state_dict[expert_prefix + "weight"] = to_quantized_weight(
                weight,
                weight_scale,
                qformat,
                weight_scale_2,
                block_size,
            )
            self._state_dict[expert_prefix + "weight_scale"] = weight_scale.detach().clone()

        if weight_scale_2 is not None:
            self._state_dict[expert_prefix + "weight_scale_2"] = weight_scale_2.detach().clone()

    # Replicate any remaining extracted quantizer state (e.g. input_scale) to
    # every expert's prefix; "output_scale" is deliberately not exported.
    for key, val in name_to_value.items():
        if key == "output_scale":
            continue
        for expert_id in range(num_experts):
            expert_prefix = prefix.format(expert_id) + "."
            self._state_dict[expert_prefix + key] = val.detach().clone()
953+
858954
def _qkv_slicing(
859955
self,
860956
module,

0 commit comments

Comments
 (0)