Skip to content

Commit df7ab63

Browse files
fix(prune): support HybridModel in mcore_minitron + Qwen3 fused-TE import
Pruning Megatron-LM HybridModel-based models (Nemotron-H et al.) under the modern Megatron-LM layout (HybridModel as the parent of MambaModel) silently produced an unloadable checkpoint:

- `_DynamicMCoreLanguageModel` only registered `GPTModel` and `MambaModel` in `SUPPORTED_MODELS`; `HybridModel` instances fell through the dynamic-space converter.
- `MCoreMinitronConfig.default_rules` likewise had no entry for HybridModel, so `convert_to_dynamic` ran `mod.freeze()` on the top-level model. That collapsed `hidden_size` and `num_layers` to a single choice each, so `_prune` skipped them — yielding a saved checkpoint with pruned per-layer dims but unpruned hidden/depth.

For GPT-family models (Qwen3) under `--export-default-te-spec`, the fused `TELayerNormColumnParallelLinear.layer_norm_weight` was never loaded from HF: the importer's fused-norm path was keyed on a single `fused_norm` rule (only Nemotron-H provided it, mapping a single HF norm tensor per layer). Standard transformer layers need separate attention vs MLP norm sources.

Changes:

- `nas/plugins/megatron.py`: register `HybridModel` in `SUPPORTED_MODELS` under a new `HAS_HYBRID` flag; have `_DynamicTEQKVLayerNormColumnParallelLinear` track `in_features` so TE's forward-time `inp_shape[-1] == in_features` assertion holds when hidden_size is pruned.
- `prune/plugins/mcore_minitron.py`: add HybridModel entry to `MCoreMinitronConfig.default_rules`, gated on `HAS_HYBRID`.
- `export/plugins/megatron_importer.py`: prefer per-context keys `fused_input_layernorm` / `fused_pre_mlp_layernorm`, fall back to legacy `fused_norm` for Nemotron-H back-compat.
- `export/plugins/mcore_qwen.py`: add the two fused-norm rules for Qwen3, mapping to `model.layers.{}.input_layernorm.weight` and `model.layers.{}.post_attention_layernorm.weight`.
- `utils/plugins/megatron_generate.py`: `.contiguous()` on the logits slice before `broadcast_from_last_pipeline_stage`, which asserts contiguity when SP pads seq_length up to a multiple of TP.
- `utils/plugins/megatron_mmlu.py`: accept a `mmlu_dataset` kwarg so callers can point at a local copy.

Consumer: Megatron-LM PR NVIDIA/Megatron-LM#4807

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent a451a2b commit df7ab63

6 files changed

Lines changed: 66 additions & 15 deletions

File tree

modelopt/torch/export/plugins/mcore_qwen.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,17 @@
3535
"output_layer": NameRemapping("lm_head.", COL_TP),
3636
# Attention
3737
"input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE),
38+
# Fused TE spec (TELayerNormColumnParallelLinear): the LayerNorm weight lives on
39+
# linear_qkv.layer_norm_weight, loaded directly from the HF norm tensor (no `.weight` suffix
40+
# appended since the value is a Parameter, not a sub-module).
41+
"fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
3842
"linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP),
3943
"linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP),
4044
"q_layernorm": NameRemapping("model.layers.{}.self_attn.q_norm.", REPLICATE),
4145
"k_layernorm": NameRemapping("model.layers.{}.self_attn.k_norm.", REPLICATE),
4246
# MLP
4347
"pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE),
48+
"fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
4449
"linear_fc1": GatedMLPMerging("model.layers.{}.mlp.", COL_TP),
4550
"linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj.", ROW_TP),
4651
# MoE

modelopt/torch/export/plugins/megatron_importer.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -599,16 +599,24 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
599599
)
600600

601601
# TE spec: input_layernorm is fused into linear_qkv (TELayerNormColumnParallelLinear).
602-
# Load the fused layer_norm_weight from the HF norm path.
602+
# Load the fused layer_norm_weight from the HF norm path. Prefer the explicit
603+
# per-norm key (needed for standard GPT models where attention and MLP fused norms
604+
# map to different HF tensors); fall back to `fused_norm` for Nemotron-H style
605+
# (one norm per layer, shared across attention/mlp/mamba slots).
603606
if (
604607
isinstance(layer.input_layernorm, IdentityOp)
605608
and hasattr(attention, "linear_qkv")
606609
and hasattr(attention.linear_qkv, "layer_norm_weight")
607-
and "fused_norm" in self.rules
608610
):
609-
self.rules["fused_norm"](
610-
attention.linear_qkv.layer_norm_weight, layer_id, is_mtp=is_mtp
611+
fused_key = (
612+
"fused_input_layernorm"
613+
if "fused_input_layernorm" in self.rules
614+
else "fused_norm"
611615
)
616+
if fused_key in self.rules:
617+
self.rules[fused_key](
618+
attention.linear_qkv.layer_norm_weight, layer_id, is_mtp=is_mtp
619+
)
612620

613621
if not isinstance(layer.pre_mlp_layernorm, IdentityOp):
614622
self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp)
@@ -707,16 +715,20 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
707715
self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp)
708716

709717
# TE spec: pre_mlp_layernorm is fused into linear_fc1
710-
# (TELayerNormColumnParallelLinear).
711-
# Load the fused layer_norm_weight from the HF norm path.
712-
if (
713-
isinstance(layer.pre_mlp_layernorm, IdentityOp)
714-
and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
715-
and "fused_norm" in self.rules
718+
# (TELayerNormColumnParallelLinear). See input_layernorm path above for the
719+
# rule-key fallback rationale.
720+
if isinstance(layer.pre_mlp_layernorm, IdentityOp) and hasattr(
721+
layer.mlp.linear_fc1, "layer_norm_weight"
716722
):
717-
self.rules["fused_norm"](
718-
layer.mlp.linear_fc1.layer_norm_weight, layer_id, is_mtp=is_mtp
723+
fused_key = (
724+
"fused_pre_mlp_layernorm"
725+
if "fused_pre_mlp_layernorm" in self.rules
726+
else "fused_norm"
719727
)
728+
if fused_key in self.rules:
729+
self.rules[fused_key](
730+
layer.mlp.linear_fc1.layer_norm_weight, layer_id, is_mtp=is_mtp
731+
)
720732

721733
def _import_state_dict(self):
722734
model = self.model

modelopt/torch/nas/plugins/megatron.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,18 @@
7979
except ImportError:
8080
HAS_MAMBA = False
8181

82+
# Newer Megatron-LM makes HybridModel the parent of MambaModel and instantiates Nemotron-H et al.
83+
# as plain HybridModel. Register that as well so the dynamic-space converter can build
84+
# a configurable search space on hybrid models.
85+
try:
86+
from megatron.core.models.hybrid.hybrid_model import HybridModel
87+
88+
SUPPORTED_MODELS[HybridModel] = "megatron.core.models.hybrid.HybridModel"
89+
90+
HAS_HYBRID = True
91+
except ImportError:
92+
HAS_HYBRID = False
93+
8294
__all__ = ["get_te_mamba_stack_spec"]
8395

8496

@@ -394,6 +406,9 @@ def _setup(self, *, num_attention_heads: NumAttentionHeadsHp, hidden_size: Trace
394406
lambda mod, val: (num_attention_heads.active + 2 * mod.config.num_query_groups)
395407
* mod.config.kv_channels,
396408
)
409+
# in_features must track input_size so TE's forward-time inp_shape[-1] == in_features
410+
# assertion holds when hidden_size is pruned.
411+
self._register_dynamic_attribute("in_features", lambda mod, val: mod.input_size)
397412
self._register_dynamic_attribute("weight", self._get_weight)
398413
# TE stores a zero-length tensor (not None) when bias=False; only register if non-empty
399414
if hasattr(self, "bias") and self.bias is not None and self.bias.numel() > 0:

modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656

5757
from modelopt.torch.nas.conversion import NASModeRegistry
5858
from modelopt.torch.nas.plugins.megatron import (
59+
HAS_HYBRID,
5960
HAS_MAMBA,
6061
SUPPORTED_MODELS,
6162
_DynamicMambaLayer,
@@ -756,6 +757,19 @@ def _compute_candidate_metrics(self, ss_config: dict, max_num_layers: int) -> di
756757
if HAS_MAMBA
757758
else {}
758759
),
760+
**(
761+
{
762+
"megatron.core.models.hybrid.HybridModel": {
763+
"hidden_size_divisor": 256,
764+
"ffn_hidden_size_divisor": 512,
765+
"mamba_head_dim_divisor": 8,
766+
"num_moe_experts_divisor": 8,
767+
"num_layers_divisor": 2,
768+
},
769+
}
770+
if HAS_HYBRID
771+
else {}
772+
),
759773
},
760774
doc='Configuration for the ``"mcore_minitron"`` mode.',
761775
),

modelopt/torch/utils/plugins/megatron_generate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ def megatron_prefill(
150150
)
151151
send_to_next_pipeline_rank(output.to(dtype=pp_dtype))
152152

153-
logits = output[:, :seq_length, :].detach() if pp_last else None
153+
# .contiguous() is required because the slice is a view with the padded stride; the broadcast
154+
# below asserts contiguity when SP pads seq_length up to a multiple of TP.
155+
logits = output[:, :seq_length, :].detach().contiguous() if pp_last else None
154156

155157
if model.config.bf16:
156158
logits_dtype = torch.bfloat16

modelopt/torch/utils/plugins/megatron_mmlu.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def megatron_mmlu(
6060
few_shots: int = 0,
6161
fraction: float = 0.05,
6262
batch_size: int = 1,
63+
mmlu_dataset: str = "cais/mmlu",
6364
) -> float:
6465
"""Evaluate the model on MMLU using log-likelihood scoring over batched prefill passes.
6566
@@ -73,6 +74,8 @@ def megatron_mmlu(
7374
few_shots: The number of few-shot examples to use.
7475
fraction: The fraction of the test set to evaluate on.
7576
batch_size: Number of examples to process in one forward pass.
77+
mmlu_dataset: HF dataset name or local MMLU dataset path passed to `datasets.load_dataset`.
78+
Defaults to ``cais/mmlu``.
7679
"""
7780
print_rank_0(
7881
f"\nMMLU ({fraction * 100}%, {few_shots}-shot, Batch Size: {batch_size}) evaluation started...\n"
@@ -104,8 +107,8 @@ def _generate_prompt(test_example, dev_examples, few_shots=0):
104107

105108
# Load all subjects in two dataset calls instead of 2x num_subjects calls.
106109
# The "all" config includes a "subject" field for per-subject reporting.
107-
test_dataset = load_dataset("cais/mmlu", "all", split="test")
108-
dev_dataset = load_dataset("cais/mmlu", "all", split="dev") if few_shots > 0 else None
110+
test_dataset = load_dataset(mmlu_dataset, "all", split="test")
111+
dev_dataset = load_dataset(mmlu_dataset, "all", split="dev") if few_shots > 0 else None
109112

110113
# Group dev examples by subject for few-shot prompt construction.
111114
dev_by_subject: dict = {}

0 commit comments

Comments
 (0)