fix(prune): support HybridModel in mcore_minitron + Qwen3 fused-TE import
Pruning Megatron-LM HybridModel-based models (Nemotron-H et al.) under
the modern Megatron-LM layout (HybridModel as the parent of MambaModel)
silently produced an unloadable checkpoint:
- `_DynamicMCoreLanguageModel` only registered `GPTModel` and `MambaModel`
in `SUPPORTED_MODELS`, so `HybridModel` instances fell through the
dynamic-space converter without being converted.
- `MCoreMinitronConfig.default_rules` likewise had no entry for
HybridModel, so `convert_to_dynamic` ran `mod.freeze()` on the
top-level model. That collapsed `hidden_size` and `num_layers` to a
single choice each, so `_prune` skipped them, yielding a saved
checkpoint with pruned per-layer dims but unpruned hidden size and depth.
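
The freeze-then-skip failure mode above can be illustrated with a minimal
sketch. `Hparam`, `prune`, and the sizes below are hypothetical stand-ins
chosen for illustration, not the ModelOpt classes or a real model config;
the sketch only assumes that freezing reduces a hyperparameter to a single
choice and that pruning skips anything with a single choice.

```python
from dataclasses import dataclass


@dataclass
class Hparam:
    """Hypothetical searchable hyperparameter: current value plus allowed choices."""

    active: int
    choices: list

    def freeze(self):
        # Freezing keeps only the current value as a choice.
        self.choices = [self.active]

    @property
    def is_configurable(self):
        return len(self.choices) > 1


def prune(hparams, targets):
    """Apply each target only where the hparam still has more than one choice."""
    applied = {}
    for name, value in targets.items():
        hp = hparams[name]
        if not hp.is_configurable:
            continue  # frozen: skipped without any error
        if value not in hp.choices:
            raise ValueError(f"{value} is not a valid choice for {name}")
        hp.active = value
        applied[name] = value
    return applied


def make_hparams():
    # Illustrative sizes only.
    return {
        "hidden_size": Hparam(4096, [2048, 3072, 4096]),
        "num_layers": Hparam(52, [26, 39, 52]),
    }


targets = {"hidden_size": 3072, "num_layers": 39}

# Model type with a rule entry: the top-level hparams stay searchable.
registered = make_hparams()
applied_registered = prune(registered, targets)

# Model type with no rule entry: the top-level module is frozen first.
frozen = make_hparams()
for hp in frozen.values():
    hp.freeze()
applied_frozen = prune(frozen, targets)
```

In the frozen case `applied_frozen` is empty and `hidden_size` keeps its
original value, which matches the symptom described above: no error is
raised, and the width and depth targets are simply not applied.
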
For GPT-family models (Qwen3) under `--export-default-te-spec`, the
fused `TELayerNormColumnParallelLinear.layer_norm_weight` was never
loaded from HF: the importer's fused-norm path was keyed on a single
`fused_norm` rule (only Nemotron-H provided it, mapping a single HF
norm tensor per layer). Standard transformer layers need separate
attention vs MLP norm sources.
Changes:
- `nas/plugins/megatron.py`: register `HybridModel` in
`SUPPORTED_MODELS` under a new `HAS_HYBRID` flag; have
`_DynamicTEQKVLayerNormColumnParallelLinear` track `in_features` so
TE's forward-time `inp_shape[-1] == in_features` assertion holds when
hidden_size is pruned.
- `prune/plugins/mcore_minitron.py`: add HybridModel entry to
`MCoreMinitronConfig.default_rules`, gated on `HAS_HYBRID`.
- `export/plugins/megatron_importer.py`: prefer per-context keys
`fused_input_layernorm` / `fused_pre_mlp_layernorm`, fall back to
legacy `fused_norm` for Nemotron-H back-compat.
- `export/plugins/mcore_qwen.py`: add the two fused-norm rules for
Qwen3, mapping to `model.layers.{}.input_layernorm.weight` and
`model.layers.{}.post_attention_layernorm.weight`.
- `utils/plugins/megatron_generate.py`: call `.contiguous()` on the logits
slice before `broadcast_from_last_pipeline_stage`, which asserts
contiguity when SP pads seq_length up to a multiple of TP.
- `utils/plugins/megatron_mmlu.py`: accept an `mmlu_dataset` kwarg so
callers can point at a local copy.
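
The per-context lookup with legacy fallback in the importer change can be
sketched as follows. This is a simplified illustration, not the importer's
code: `resolve_fused_norm_rule` is a hypothetical helper, rules are
modelled as plain name templates rather than the importer's rule objects,
and the legacy template is a placeholder, not the real Nemotron-H tensor
name.

```python
LEGACY_KEY = "fused_norm"


def resolve_fused_norm_rule(rules, context_key):
    """Return (key, rule), preferring the per-context key over the legacy one.

    Returns None when neither key is present.
    """
    for key in (context_key, LEGACY_KEY):
        if key in rules:
            return key, rules[key]
    return None


# Qwen3-style rules: separate HF sources for the attention and MLP norms.
qwen3_rules = {
    "fused_input_layernorm": "model.layers.{}.input_layernorm.weight",
    "fused_pre_mlp_layernorm": "model.layers.{}.post_attention_layernorm.weight",
}

# Legacy-style rules: a single norm source per layer (placeholder template).
legacy_rules = {"fused_norm": "placeholder.layers.{}.norm.weight"}

attn_key, attn_rule = resolve_fused_norm_rule(qwen3_rules, "fused_input_layernorm")
mlp_key, mlp_rule = resolve_fused_norm_rule(qwen3_rules, "fused_pre_mlp_layernorm")
legacy_key, legacy_rule = resolve_fused_norm_rule(legacy_rules, "fused_input_layernorm")
missing = resolve_fused_norm_rule({}, "fused_input_layernorm")
```

Checking the per-context key first keeps rule sets that only define
`fused_norm` working unchanged, while rule sets that define both
per-context keys load the attention and MLP norms from separate tensors.
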
Consumer: Megatron-LM PR NVIDIA/Megatron-LM#4807
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>