Skip to content

Commit 5ddcbe7

Browse files
committed
Support full TE spec for NemotronH HF-to-Megatron import
Signed-off-by: James Shen <yueshen@nvidia.com>
1 parent ae69d5d commit 5ddcbe7

File tree

2 files changed: +55 additions, −5 deletions

modelopt/torch/export/plugins/mcore_nemotron.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,25 @@
5858
"D": NameRemapping("backbone.layers.{}.mixer.D", REPLICATE),
5959
"dt_bias": NameRemapping("backbone.layers.{}.mixer.dt_bias", REPLICATE),
6060
"conv1d": NameRemapping("backbone.layers.{}.mixer.conv1d.", REPLICATE),
61-
"in_proj": NameRemapping("backbone.layers.{}.mixer.in_proj.", COL_TP),
61+
# mapping layer_norm_weight to None tells _name_remapping to skip it;
62+
# the fused layer_norm_weight is loaded separately via the "fused_norm" rule.
63+
"in_proj": NameRemapping(
64+
"backbone.layers.{}.mixer.in_proj.", COL_TP | {"mapping": {"layer_norm_weight": None}}
65+
),
6266
"out_proj": NameRemapping("backbone.layers.{}.mixer.out_proj.", ROW_TP),
6367
# Attention
6468
"input_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE),
6569
"linear_qkv": QKVMerging("backbone.layers.{}.mixer.", COL_TP),
6670
"linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj.", ROW_TP),
6771
# MLP
6872
"pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE),
69-
"linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj.", COL_TP),
73+
"linear_fc1": NameRemapping(
74+
"backbone.layers.{}.mixer.up_proj.", COL_TP | {"mapping": {"layer_norm_weight": None}}
75+
),
7076
"linear_fc2": NameRemapping("backbone.layers.{}.mixer.down_proj.", ROW_TP),
77+
# Fused layer norm: loads the HF norm weight into fused TELayerNormColumnParallelLinear
78+
# modules (in_proj, linear_qkv, linear_fc1) when using TE spec.
79+
"fused_norm": NameRemapping("backbone.layers.{}.norm.weight"),
7180
# MoE
7281
"router": NameRemapping(
7382
"backbone.layers.{}.mixer.gate.", {"mapping": {"expert_bias": "e_score_correction_bias"}}
@@ -92,12 +101,14 @@
92101
"mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}),
93102
"mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}),
94103
"mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}),
95-
# Grouped local experts in MTP
104+
# Grouped local experts (used for TEGroupedMLP in both decoder and MTP layers).
105+
# The prefix uses "backbone" for regular decoder layers; when called from MTP
106+
# context (is_mtp=True), _grouped_mlp_merging replaces "backbone" with "mtp".
96107
"experts.linear_fc1": GroupedMLPMerging(
97-
"mtp.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}
108+
"backbone.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP
98109
),
99110
"experts.linear_fc2": GroupedMLPMerging(
100-
"mtp.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}
111+
"backbone.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP
101112
),
102113
}
103114

modelopt/torch/export/plugins/megatron_importer.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,12 @@ def _name_remapping(
200200
state_dict[key] = val
201201
else:
202202
source_key = mapping.get(key, key)
203+
# A mapping value of None means "skip this key" (keep existing value).
204+
# This is used for fused TE modules where layer_norm_weight is loaded
205+
# separately from a different HF path.
206+
if source_key is None:
207+
state_dict[key] = val
208+
continue
203209
# For bias tensors in ROW_TP layers, don't use parallel config to avoid sharding
204210
# since bias should always be replicated, not sharded
205211
if (
@@ -537,6 +543,15 @@ def _import_mamba_layer(self, layer, layer_id, layer_pbar):
537543
self.rules["in_proj"](layer.mixer.in_proj, layer_id)
538544
self.rules["out_proj"](layer.mixer.out_proj, layer_id)
539545

546+
# TE spec: layer norm is fused into in_proj (TELayerNormColumnParallelLinear).
547+
# Load the fused layer_norm_weight from the HF norm path.
548+
if (
549+
isinstance(layer.norm, IdentityOp)
550+
and hasattr(layer.mixer.in_proj, "layer_norm_weight")
551+
and "fused_norm" in self.rules
552+
):
553+
self.rules["fused_norm"](layer.mixer.in_proj.layer_norm_weight, layer_id)
554+
540555
def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = False):
541556
if not isinstance(layer.input_layernorm, IdentityOp):
542557
self.rules["input_layernorm"](layer.input_layernorm, layer_id, is_mtp=is_mtp)
@@ -578,6 +593,18 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
578593
attention.core_attention.softmax_offset, layer_id, is_mtp=is_mtp
579594
)
580595

596+
# TE spec: input_layernorm is fused into linear_qkv (TELayerNormColumnParallelLinear).
597+
# Load the fused layer_norm_weight from the HF norm path.
598+
if (
599+
isinstance(layer.input_layernorm, IdentityOp)
600+
and hasattr(attention, "linear_qkv")
601+
and hasattr(attention.linear_qkv, "layer_norm_weight")
602+
and "fused_norm" in self.rules
603+
):
604+
self.rules["fused_norm"](
605+
attention.linear_qkv.layer_norm_weight, layer_id, is_mtp=is_mtp
606+
)
607+
581608
if not isinstance(layer.pre_mlp_layernorm, IdentityOp):
582609
self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp)
583610

@@ -671,6 +698,18 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
671698
self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id, is_mtp=is_mtp)
672699
self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp)
673700

701+
# TE spec: pre_mlp_layernorm is fused into linear_fc1
702+
# (TELayerNormColumnParallelLinear).
703+
# Load the fused layer_norm_weight from the HF norm path.
704+
if (
705+
isinstance(layer.pre_mlp_layernorm, IdentityOp)
706+
and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
707+
and "fused_norm" in self.rules
708+
):
709+
self.rules["fused_norm"](
710+
layer.mlp.linear_fc1.layer_norm_weight, layer_id, is_mtp=is_mtp
711+
)
712+
674713
def _import_state_dict(self):
675714
model = self.model
676715
layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm)

0 commit comments

Comments (0)