
Commit a5fd5b2

yueshen2016 and danielkorzekwa
authored and committed
[OMNIML-3232] Support full TE spec for NemotronH HF-to-Megatron import (#884)
## What does this PR do?

**Type of change:** new feature

**Overview:** Enable full TE spec support for NemotronH (Mamba hybrid) models during HF-to-Megatron weight import via `import_mcore_gpt_from_hf`.

Previously, importing HF weights into a Megatron model built with the full TE spec (`TELayerNormColumnParallelLinear`, `TEGroupedMLP`, etc.) failed for NemotronH models due to two issues:

1. **Grouped expert prefix bug**: The `experts.linear_fc1/fc2` import rules had a hard-coded `mtp.layers.{}` prefix, which was only correct for MTP layers. When regular decoder MoE layers use `TEGroupedMLP` (via the full TE spec), the importer generated incorrect HF keys (e.g., `mtp.layers.27.mixer.experts.0.up_proj.weight` instead of `backbone.layers.27.mixer.experts.0.up_proj.weight`).
2. **Fused layer norm loading**: In the full TE spec, layer norms are fused into `TELayerNormColumnParallelLinear` modules as `layer_norm_weight`. The importer's `_name_remapping` would crash trying to load `layer_norm_weight` from a non-existent HF path (e.g., `backbone.layers.X.mixer.in_proj.layer_norm_weight`), when the actual HF norm weight lives at `backbone.layers.X.norm.weight`.

### Changes

**`mcore_nemotron.py`**:

- Fixed the grouped expert prefix from `mtp.layers.{}` to `backbone.layers.{}`. The `_grouped_mlp_merging` function already handles the `backbone` → `mtp` replacement when `is_mtp=True`, so both decoder and MTP layers work correctly.
- Added `mapping={"layer_norm_weight": None}` to the `in_proj` and `linear_fc1` rules to skip `layer_norm_weight` during `_name_remapping` (it is loaded separately via `fused_norm`).
- Added a `fused_norm` rule (`NameRemapping("backbone.layers.{}.norm.weight")`) to load HF norm weights into fused TE modules.

**`megatron_importer.py`**:

- Added a `source_key is None` check in `_name_remapping` to skip keys mapped to `None` in the mapping dict (keeps the existing value instead of crashing on a missing HF key).
- Added fused norm loading in `_import_mamba_layer`: after loading `in_proj`, loads `layer_norm_weight` from HF via the `fused_norm` rule when `layer.norm` is `IdentityOp`.
- Added fused norm loading in `_import_transformer_layer`: loads `layer_norm_weight` into `linear_qkv` (when `input_layernorm` is `IdentityOp`) and into `linear_fc1` (when `pre_mlp_layernorm` is `IdentityOp`).

## Usage

The full TE spec is enabled via the `--full-te-spec` flag on the Megatron-LM side (separate PR). On the ModelOpt side, no user-facing changes are needed -- the import rules automatically handle both local spec and full TE spec models.

```bash
# Convert HF checkpoint to Megatron with full TE spec (megatron-lm side)
unset MLM_MODEL_CKPT && export MLM_MODEL_SAVE=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16_mlm && export HF_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
export PP=2
export MLM_EXTRA_ARGS="--full-te-spec"
bash convert.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

# Quantize the converted checkpoint (megatron-lm side)
export MLM_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16_mlm
export MLM_MODEL_SAVE=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-fp8_mlm
export HF_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
export PP=2 && export TP=4 && export EP=4 && export ETP=1
export MLM_EXTRA_ARGS="--full-te-spec"
bash quantize.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 FP8_DEFAULT_CFG

# Generate
export PP=2 && export TP=4 && export EP=4 && export ETP=1
export MLM_EXTRA_ARGS="--full-te-spec"
export MLM_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-fp8_mlm && ./generate.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

# MMLU
export PP=2 && export TP=4 && export EP=4 && export ETP=1
export MLM_EXTRA_ARGS="--full-te-spec"
export MLM_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-fp8_mlm && export MLM_EXTRA_ARGS="--fraction 0.05 --disable-tqdm" && ./mmlu.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
```

## Testing

- Tested end-to-end: HF → Megatron conversion → FP8 quantization → inference (generate) → MMLU evaluation with Nemotron-3-Nano-30B-A3B-BF16.
- Verified the resulting model structure matches Megatron-Bridge's TE spec output (`TELayerNormColumnParallelLinear`, `TEGroupedMLP`, `IdentityOp` norms, etc.).
- Verified the quantized model produces coherent text generation outputs.
- Verified backward compatibility: all changes are no-ops for existing local-spec pipelines (guarded by `IdentityOp` checks, `hasattr` checks, and `"fused_norm" in self.rules` checks).

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes -- all changes are guarded by conditions that only activate for full TE spec models. Local-spec models follow the exact same code paths as before.
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

## Additional Information

Companion megatron-lm changes (separate PR):

- `megatron/core/post_training/modelopt/mamba/model_specs.py`: Added a `use_full_te_spec` parameter to return the canonical `mamba_stack_spec` from `mamba_layer_specs.py`.
- `megatron/post_training/model_builder.py`: Passes `use_full_te_spec=args.full_te_spec` to `get_mamba_stack_modelopt_spec`.
- `megatron/post_training/arguments.py`: Added the `--full-te-spec` CLI flag.
- `examples/post_training/modelopt/convert_model.py`: Skip the `moe_grouped_gemm=False` override when `--full-te-spec` is set.

## Summary by CodeRabbit

- **New Features**
  - Added support for loading fused normalization weights during model import.
- **Bug Fixes**
  - Improved weight mapping logic to correctly skip redundant layer norm weights in specialized model architectures.
- **Refactor**
  - Reorganized expert model parallel configuration paths for better compatibility with mixed parallel processing settings.

Signed-off-by: James Shen <yueshen@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
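The `source_key is None` convention described above can be illustrated with a minimal sketch. This is a hypothetical, simplified stand-in for the importer's `_name_remapping` logic (the function name `remap_state_dict` and the plain-dict interface are illustrative, not the actual ModelOpt API): a mapping value of `None` means "skip this key", keeping the existing Megatron value instead of looking up a non-existent HF tensor.

```python
def remap_state_dict(megatron_state: dict, hf_state: dict, mapping: dict) -> dict:
    """Simplified sketch of the None-mapping skip behavior."""
    out = {}
    for key, val in megatron_state.items():
        source_key = mapping.get(key, key)
        if source_key is None:
            # Fused TE modules: layer_norm_weight is loaded separately via the
            # "fused_norm" rule, so keep the existing value instead of crashing.
            out[key] = val
            continue
        out[key] = hf_state[source_key]
    return out

megatron = {"weight": "w_old", "layer_norm_weight": "ln_old"}
hf = {"weight": "w_hf"}  # note: no layer_norm_weight entry in the HF checkpoint
result = remap_state_dict(megatron, hf, {"layer_norm_weight": None})
# "weight" is taken from HF; "layer_norm_weight" keeps its existing value.
```

Without the `None` check, the lookup `hf_state["layer_norm_weight"]` would raise `KeyError`, which mirrors the crash this PR fixes for full TE spec models.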
1 parent 47ced39 commit a5fd5b2

File tree: 2 files changed (+55, −5 lines)


modelopt/torch/export/plugins/mcore_nemotron.py

Lines changed: 16 additions & 5 deletions
```diff
@@ -58,16 +58,25 @@
     "D": NameRemapping("backbone.layers.{}.mixer.D", REPLICATE),
     "dt_bias": NameRemapping("backbone.layers.{}.mixer.dt_bias", REPLICATE),
     "conv1d": NameRemapping("backbone.layers.{}.mixer.conv1d.", REPLICATE),
-    "in_proj": NameRemapping("backbone.layers.{}.mixer.in_proj.", COL_TP),
+    # Mapping layer_norm_weight to None tells _name_remapping to skip it;
+    # the fused layer_norm_weight is loaded separately via the "fused_norm" rule.
+    "in_proj": NameRemapping(
+        "backbone.layers.{}.mixer.in_proj.", COL_TP | {"mapping": {"layer_norm_weight": None}}
+    ),
     "out_proj": NameRemapping("backbone.layers.{}.mixer.out_proj.", ROW_TP),
     # Attention
     "input_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE),
     "linear_qkv": QKVMerging("backbone.layers.{}.mixer.", COL_TP),
     "linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj.", ROW_TP),
     # MLP
     "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE),
-    "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj.", COL_TP),
+    "linear_fc1": NameRemapping(
+        "backbone.layers.{}.mixer.up_proj.", COL_TP | {"mapping": {"layer_norm_weight": None}}
+    ),
     "linear_fc2": NameRemapping("backbone.layers.{}.mixer.down_proj.", ROW_TP),
+    # Fused layer norm: loads the HF norm weight into fused TELayerNormColumnParallelLinear
+    # modules (in_proj, linear_qkv, linear_fc1) when using the TE spec.
+    "fused_norm": NameRemapping("backbone.layers.{}.norm.weight"),
     # MoE
     "router": NameRemapping(
         "backbone.layers.{}.mixer.gate.", {"mapping": {"expert_bias": "e_score_correction_bias"}}
@@ -92,12 +101,14 @@
     "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}),
     "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}),
     "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}),
-    # Grouped local experts in MTP
+    # Grouped local experts (used for TEGroupedMLP in both decoder and MTP layers).
+    # The prefix uses "backbone" for regular decoder layers; when called from MTP
+    # context (is_mtp=True), _grouped_mlp_merging replaces "backbone" with "mtp".
     "experts.linear_fc1": GroupedMLPMerging(
-        "mtp.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}
+        "backbone.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP
     ),
     "experts.linear_fc2": GroupedMLPMerging(
-        "mtp.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}
+        "backbone.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP
     ),
 }
```
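The prefix handling in the hunk above can be sketched as follows. This is a hypothetical helper, not the actual `_grouped_mlp_merging` implementation: the rule's template uses `backbone` by default, and the `backbone` → `mtp` substitution happens only when importing MTP layers, so a single rule serves both decoder and MTP contexts. Note the escaped `{{}}` in the template: the first `format` fills the layer index, the second fills the expert index.

```python
def expert_hf_key(template: str, layer_id: int, expert_id: int, is_mtp: bool = False) -> str:
    """Build the HF checkpoint key for one grouped expert (illustrative sketch)."""
    # For MTP layers, swap the "backbone" prefix for "mtp" before formatting.
    prefix = template.replace("backbone", "mtp") if is_mtp else template
    # First format fills the layer index; the escaped {{}} survives as {} and
    # is filled by the second format with the expert index.
    return prefix.format(layer_id).format(expert_id)

TEMPLATE = "backbone.layers.{}.mixer.experts.{{}}.up_proj"
print(expert_hf_key(TEMPLATE, 27, 0))              # regular decoder MoE layer
print(expert_hf_key(TEMPLATE, 0, 0, is_mtp=True))  # MTP layer
```

With the old hard-coded `mtp.layers.{}` prefix, the decoder case would have produced the wrong `mtp.layers.27.…` key described in the PR overview.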

modelopt/torch/export/plugins/megatron_importer.py

Lines changed: 39 additions & 0 deletions
```diff
@@ -200,6 +200,12 @@ def _name_remapping(
             state_dict[key] = val
         else:
             source_key = mapping.get(key, key)
+            # A mapping value of None means "skip this key" (keep existing value).
+            # This is used for fused TE modules where layer_norm_weight is loaded
+            # separately from a different HF path.
+            if source_key is None:
+                state_dict[key] = val
+                continue
             # For bias tensors in ROW_TP layers, don't use parallel config to avoid sharding
             # since bias should always be replicated, not sharded
             if (
@@ -537,6 +543,15 @@ def _import_mamba_layer(self, layer, layer_id, layer_pbar):
         self.rules["in_proj"](layer.mixer.in_proj, layer_id)
         self.rules["out_proj"](layer.mixer.out_proj, layer_id)

+        # TE spec: layer norm is fused into in_proj (TELayerNormColumnParallelLinear).
+        # Load the fused layer_norm_weight from the HF norm path.
+        if (
+            isinstance(layer.norm, IdentityOp)
+            and hasattr(layer.mixer.in_proj, "layer_norm_weight")
+            and "fused_norm" in self.rules
+        ):
+            self.rules["fused_norm"](layer.mixer.in_proj.layer_norm_weight, layer_id)
+
     def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = False):
         if not isinstance(layer.input_layernorm, IdentityOp):
             self.rules["input_layernorm"](layer.input_layernorm, layer_id, is_mtp=is_mtp)
@@ -578,6 +593,18 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
                 attention.core_attention.softmax_offset, layer_id, is_mtp=is_mtp
             )

+        # TE spec: input_layernorm is fused into linear_qkv (TELayerNormColumnParallelLinear).
+        # Load the fused layer_norm_weight from the HF norm path.
+        if (
+            isinstance(layer.input_layernorm, IdentityOp)
+            and hasattr(attention, "linear_qkv")
+            and hasattr(attention.linear_qkv, "layer_norm_weight")
+            and "fused_norm" in self.rules
+        ):
+            self.rules["fused_norm"](
+                attention.linear_qkv.layer_norm_weight, layer_id, is_mtp=is_mtp
+            )
+
         if not isinstance(layer.pre_mlp_layernorm, IdentityOp):
             self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp)

@@ -671,6 +698,18 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
             self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id, is_mtp=is_mtp)
             self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp)

+            # TE spec: pre_mlp_layernorm is fused into linear_fc1
+            # (TELayerNormColumnParallelLinear).
+            # Load the fused layer_norm_weight from the HF norm path.
+            if (
+                isinstance(layer.pre_mlp_layernorm, IdentityOp)
+                and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
+                and "fused_norm" in self.rules
+            ):
+                self.rules["fused_norm"](
+                    layer.mlp.linear_fc1.layer_norm_weight, layer_id, is_mtp=is_mtp
+                )
+
     def _import_state_dict(self):
         model = self.model
         layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm)
```
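The backward-compatibility claim in the PR (all changes are no-ops for local-spec pipelines) rests on the three-part guard used in each of the hunks above. The sketch below isolates that guard with stand-in classes (`FusedLinear` and `PlainLinear` are placeholders, not real Megatron or TE types): with the local spec, the norm is a real module and the fused path never triggers; with the full TE spec, the norm is `IdentityOp` and the linear module carries `layer_norm_weight`.

```python
class IdentityOp:  # placeholder for megatron.core.transformer's IdentityOp
    pass

class FusedLinear:  # stand-in for a fused TELayerNormColumnParallelLinear
    layer_norm_weight = "ln"

class PlainLinear:  # stand-in for a local-spec linear without a fused norm
    pass

def needs_fused_norm(norm, linear, rules) -> bool:
    """True only when the TE-spec fused-norm path should load the weight."""
    return (
        isinstance(norm, IdentityOp)          # norm was replaced by an identity
        and hasattr(linear, "layer_norm_weight")  # linear carries the fused norm
        and "fused_norm" in rules             # an import rule exists for it
    )

rules = {"fused_norm": object()}
assert needs_fused_norm(IdentityOp(), FusedLinear(), rules)   # full TE spec: load
assert not needs_fused_norm(object(), PlainLinear(), rules)   # local spec: no-op
```

All three conditions must hold before `self.rules["fused_norm"]` runs, which is why existing local-spec imports follow exactly the same code paths as before this change.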
