
Commit d773356

fix(prune): Minitron HybridModel + GPT-family fused-TE-spec import/export (#1518)
## Summary

Split out of #1501 so the `pack=True` calibration packing change can land independently. This PR carries the pruning and export-side fixes.

**Pruning bug fixes**

- Register `HybridModel` (parent of `MambaModel` in modern Megatron-LM) under a new `HAS_HYBRID` flag so `mcore_minitron` actually prunes Nemotron-H et al. Previously `HybridModel` instances fell through `convert_to_dynamic`, got `freeze()`-ed (collapsing `hidden_size` / `num_layers` to a single choice), and produced unloadable saved checkpoints with mixed pruned/unpruned dims.
- Replace the `isinstance(MambaModel)` gate in `_get_hybrid_pattern_key` with attribute-presence detection so both `MambaModel` (still using `hybrid_override_pattern`) and plain `HybridModel` (`hybrid_layer_pattern`) are handled uniformly.
- Track `in_features` as a dynamic attribute on `_DynamicTEQKVLayerNormColumnParallelLinear` so TE's forward-time `inp_shape[-1] == in_features` assertion holds when `hidden_size` is pruned.
- Dedupe the MambaModel / HybridModel divisor dict into `_HYBRID_DIVISORS`.

**Fused-TE-spec import/export for GPT-family**

- Importer: prefer per-context keys (`fused_input_layernorm`, `fused_pre_mlp_layernorm`); fall back to legacy `fused_norm` for Nemotron-H back-compat. **Raise `KeyError`** when a fused-TE model has neither rule registered: the branch only fires when the model uses fused `TELayerNormColumnParallelLinear`, so a missing rule is unambiguously a plugin misconfig that would otherwise ship a chance-accuracy checkpoint.
- Exporter: mirror the same fallback chain in `_get_fused_norm_weight` so GPT-family models round-trip cleanly back to HF.
- Add the new rules to the Qwen3, Qwen2.5, Llama, Llama4 (MoE-only, only `fused_input_layernorm`), DeepSeek, and GptOss (MoE-only, only `fused_input_layernorm`) import and export mappings.
- Preserve TE `_extra_state` from the existing module state dict (don't blank to `None`) at both call sites in the importer.

**Misc**

- `megatron_prefill`: call `.contiguous()` on the logits slice before `broadcast_from_last_pipeline_stage`; the broadcast asserts contiguity, which fails when SP pads `seq_length` to a multiple of TP.
- `megatron_mmlu`: accept an `mmlu_dataset` kwarg so callers can point at a local copy of `cais/mmlu`.
- `warn_rank_0`: auto-bump `stacklevel` by 1 inside the wrapper so callers' warnings point at user code, not at the wrapper frame.
- `tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml`: bump `mmlu_lower_bound` 0.68 → 0.75 (validated end-to-end with the fused-norm import fix).
- CHANGELOG: bug-fix entry for the importer; date correction on the 0.44 entry.

## Consumer Megatron-LM PR

NVIDIA/Megatron-LM#4807: `prune.py` / `mmlu.py` consume these APIs and currently ship inline workarounds (WARs) against released 0.44. Once 0.45 ships and the modelopt pin is bumped, those WARs collapse to one-liners.

Related: #1501 (calibration packing).

## Summary by CodeRabbit

* **Bug Fixes**
  * Importer/exporter now correctly load fused LayerNorm weights for GPT-family models, preferring context-specific fused keys with a legacy fallback.
* **New Features**
  * Hybrid Mamba/HybridModel support added for pruning/NAS workflows.
  * MMLU evaluation accepts a customizable dataset path (default: "cais/mmlu").
* **Improvements**
  * Extended export/import mappings and state handling across DeepSeek, GPT, Llama, Qwen; ensured last-stage logits are contiguous.
* **Documentation**
  * Updated changelog entry and release date adjustment.

---------

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 8f1529a commit d773356
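The attribute-presence detection described for `_get_hybrid_pattern_key` in the commit message can be illustrated with a minimal sketch. This is not the ModelOpt implementation: the helper `hybrid_pattern_key`, the order in which the attributes are checked, the placeholder pattern strings, and the `SimpleNamespace` stand-ins are all assumptions made for illustration. Only the two attribute names come from the commit message.

```python
from types import SimpleNamespace

# Attribute names from the commit message; checking them in this order is an assumption.
PATTERN_ATTRS = ("hybrid_override_pattern", "hybrid_layer_pattern")


def hybrid_pattern_key(model):
    """Return the first pattern attribute that is present and not None, else None."""
    for name in PATTERN_ATTRS:
        if getattr(model, name, None) is not None:
            return name
    return None


# Hypothetical stand-ins: no Megatron classes are imported, so no isinstance gate is needed.
mamba_like = SimpleNamespace(hybrid_override_pattern="M*-")
hybrid_like = SimpleNamespace(hybrid_layer_pattern="M*-")
dense_like = SimpleNamespace(hidden_size=4096)

for model in (mamba_like, hybrid_like, dense_like):
    print(hybrid_pattern_key(model))
```

Because the check depends only on which attribute is set, a new subclass that keeps either attribute name would be handled without touching the detection code.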

13 files changed

Lines changed: 199 additions & 70 deletions
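The `warn_rank_0` item in the commit message (bumping `stacklevel` by 1 inside the wrapper) relies on standard `warnings.warn` behavior. The sketch below shows the idea with a hypothetical `warn_wrapper`; it is an assumption-laden stand-in, omits the rank-0 check entirely, and is not the ModelOpt function.

```python
import warnings


def warn_wrapper(message, category=UserWarning, stacklevel=1):
    """Hypothetical stand-in for a warning wrapper (the rank-0 check is omitted)."""
    # +1 skips this wrapper's own frame, so stacklevel=1 means "the wrapper's caller".
    warnings.warn(message, category, stacklevel=stacklevel + 1)


def user_code():
    warn_wrapper("this option is deprecated")


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    user_code()

record = caught[0]
call_line = user_code.__code__.co_firstlineno + 1  # the line that calls warn_wrapper
print(record.lineno == call_line)
```

Without the `+ 1`, the recorded line would be the `warnings.warn` call inside the wrapper, which is the same for every caller and therefore not useful for locating the offending code.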


CHANGELOG.rst

Lines changed: 5 additions & 1 deletion
@@ -26,7 +26,11 @@ Changelog
 - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``.
 - Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress.
 
-0.44 (2026-05-18)
+**Bug Fixes**
+
+- Fix Megatron-Core HF importer to load fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) under ``--export-default-te-spec``. Importer now prefers per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (fallback ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, post-prune MMLU sat at chance.
+
+0.44 (2026-05-14)
 ^^^^^^^^^^^^^^^^^
 
 **New Features**

modelopt/torch/export/plugins/mcore_deepseek.py

Lines changed: 9 additions & 0 deletions
@@ -43,6 +43,10 @@
     "linear_kv_up_proj": NameRemapping("model.layers.{}.self_attn.kv_b_proj."),
     "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
     "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."),
+    # Fused TE spec (mirrors the import side). MLA has no linear_qkv so
+    # fused_input_layernorm is inert today; fused_pre_mlp_layernorm reaches dense layers.
+    "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
+    "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
     # MLP for dense layers
     "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."),
     "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."),
@@ -88,6 +92,11 @@
     "output_layer": NameRemapping("lm_head.", COL_TP),
     # Per-layer
     "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE),
+    # Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale.
+    # MLA has no linear_qkv so fused_input_layernorm is inert for DeepSeek today; included
+    # for parity in case a future spec fuses the layernorm into a Q/KV projection.
+    "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
+    "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
     "linear_q_proj": NameRemapping("model.layers.{}.self_attn.q_proj.", COL_TP),
     "linear_q_down_proj": NameRemapping("model.layers.{}.self_attn.q_a_proj.", REPLICATE),
     "linear_q_layernorm": NameRemapping("model.layers.{}.self_attn.q_a_layernorm.", REPLICATE),

modelopt/torch/export/plugins/mcore_gptoss.py

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,8 @@
 gptoss_causal_lm_export: dict[str, CustomModuleMapping | bool] = {
     "word_embeddings": NameRemapping("model.embed_tokens."),
     "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."),
+    # MoE-only on MLP side, so fused_pre_mlp_layernorm path is unreachable.
+    "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
     "linear_qkv": QKVSlicing("model.layers.{}.self_attn."),
     "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
     "softmax_offset": NameRemapping("model.layers.{}.self_attn.sinks"),
@@ -52,6 +54,10 @@
 gptoss_causal_lm_import: dict[str, CustomModuleMapping | bool] = {
     "word_embeddings": NameRemapping("model.embed_tokens.", COL_TP),
     "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE),
+    # Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale.
+    # gpt-oss is MoE-only on the MLP side (no layer.mlp.linear_fc1), so the importer's
+    # fused_pre_mlp_layernorm path is unreachable; only fused_input_layernorm is wired.
+    "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
     "linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP),
     "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP),
     "softmax_offset": NameRemapping("model.layers.{}.self_attn.sinks", COL_TP),

modelopt/torch/export/plugins/mcore_llama.py

Lines changed: 11 additions & 0 deletions
@@ -37,11 +37,13 @@
 llama_causal_lm_export: dict[str, CustomModuleMapping] = {
     "word_embeddings": NameRemapping("model.embed_tokens."),
     "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."),
+    "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
     "linear_qkv": QKVSlicing("model.layers.{}.self_attn."),
     "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
     # KV cache quant export
     "core_attention": SelfAttentionScaling("model.layers.{}.self_attn."),
     "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."),
+    "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
     "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."),
     "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."),
     "final_layernorm": NameRemapping("model.norm."),
@@ -51,6 +53,8 @@
 llama4_causal_lm_export: dict[str, CustomModuleMapping | bool] = {
     "word_embeddings": NameRemapping("language_model.model.embed_tokens."),
     "input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm."),
+    # MoE-only on MLP side, so fused_pre_mlp_layernorm path is unreachable.
+    "fused_input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm.weight"),
     # self_attn
     "linear_qkv": QKVSlicing("language_model.model.layers.{}.self_attn."),
     "linear_proj": NameRemapping("language_model.model.layers.{}.self_attn.o_proj."),
@@ -150,9 +154,12 @@
 llama_causal_lm_import: dict[str, CustomModuleMapping] = {
     "word_embeddings": NameRemapping("model.embed_tokens.", COL_TP),
     "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE),
+    # Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale.
+    "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
     "linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP),
     "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP),
     "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE),
+    "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
     "linear_fc1": GatedMLPMerging("model.layers.{}.mlp.", COL_TP),
     "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj.", ROW_TP),
     "final_layernorm": NameRemapping("model.norm.", REPLICATE),
@@ -162,6 +169,10 @@
 llama4_causal_lm_import: dict[str, CustomModuleMapping | bool] = {
     "word_embeddings": NameRemapping("language_model.model.embed_tokens.", COL_TP),
     "input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm.", REPLICATE),
+    # Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale.
+    # Llama4 is MoE-only on the MLP side (no layer.mlp.linear_fc1), so the importer's
+    # fused_pre_mlp_layernorm path is unreachable; only fused_input_layernorm is wired.
+    "fused_input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm.weight"),
     "linear_qkv": QKVMerging("language_model.model.layers.{}.self_attn.", COL_TP),
     "linear_proj": NameRemapping("language_model.model.layers.{}.self_attn.o_proj.", ROW_TP),
     "pre_mlp_layernorm": NameRemapping(

modelopt/torch/export/plugins/mcore_qwen.py

Lines changed: 11 additions & 0 deletions
@@ -35,12 +35,17 @@
     "output_layer": NameRemapping("lm_head.", COL_TP),
     # Attention
     "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE),
+    # Fused TE spec (TELayerNormColumnParallelLinear): the LayerNorm weight lives on
+    # linear_qkv.layer_norm_weight, loaded directly from the HF norm tensor (no `.weight` suffix
+    # appended since the value is a Parameter, not a sub-module).
+    "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
     "linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP),
     "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP),
     "q_layernorm": NameRemapping("model.layers.{}.self_attn.q_norm.", REPLICATE),
     "k_layernorm": NameRemapping("model.layers.{}.self_attn.k_norm.", REPLICATE),
     # MLP
     "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE),
+    "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
     "linear_fc1": GatedMLPMerging("model.layers.{}.mlp.", COL_TP),
     "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj.", ROW_TP),
     # MoE
@@ -56,12 +61,14 @@
     "output_layer": NameRemapping("lm_head."),
     # Attention
     "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."),
+    "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
     "linear_qkv": QKVSlicing("model.layers.{}.self_attn."),
     "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
     "q_layernorm": NameRemapping("model.layers.{}.self_attn.q_norm."),
     "k_layernorm": NameRemapping("model.layers.{}.self_attn.k_norm."),
     # MLP
     "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."),
+    "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
     "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."),
     "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."),
     # MoE
@@ -76,10 +83,12 @@
     "output_layer": NameRemapping("lm_head.", COL_TP),
     # Attention
     "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE),
+    "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
     "linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP),
     "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP),
     # MLP
     "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE),
+    "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
     "linear_fc1": GatedMLPMerging("model.layers.{}.mlp.", COL_TP),
     "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj.", ROW_TP),
 }
@@ -90,10 +99,12 @@
     "output_layer": NameRemapping("lm_head."),
     # Attention
     "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."),
+    "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
     "linear_qkv": QKVSlicing("model.layers.{}.self_attn."),
     "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
     # MLP
     "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."),
+    "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
     "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."),
     "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."),
 }
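The comment in the first `mcore_qwen.py` hunk explains why the fused rules carry a full tensor key ending in `.weight` while the unfused rules carry a prefix ending in `.`. The sketch below illustrates that distinction with two hypothetical helpers, `key_for_submodule` and `key_for_parameter`. It makes no claim about how `NameRemapping` is implemented internally and only shows how the two template styles can resolve to the same HF tensor name.

```python
def key_for_submodule(prefix_template: str, layer_id: int, param_name: str) -> str:
    """Rule targets a sub-module: the template ends in '.', the parameter name is appended."""
    return prefix_template.format(layer_id) + param_name


def key_for_parameter(key_template: str, layer_id: int) -> str:
    """Rule targets a Parameter: the template already names the tensor, nothing is appended."""
    return key_template.format(layer_id)


# Templates copied from the mappings above; layer 3 is an arbitrary example.
unfused = key_for_submodule("model.layers.{}.input_layernorm.", 3, "weight")
fused = key_for_parameter("model.layers.{}.input_layernorm.weight", 3)

print(unfused)
print(fused)
```

Under this reading, both rule styles read the same checkpoint tensor; what differs is whether the destination is a norm module or the `layer_norm_weight` parameter of the fused linear layer.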

modelopt/torch/export/plugins/megatron_importer.py

Lines changed: 61 additions & 13 deletions
@@ -238,8 +238,9 @@ def _gated_mlp_merging(
         else:
             prefix = prefix.replace("model", "mtp")
 
-        weight = module.state_dict().get("weight", None)
-        weight_scale = module.state_dict().get("weight_quantizer._scale", None)
+        module_state_dict = module.state_dict()
+        weight = module_state_dict.get("weight", None)
+        weight_scale = module_state_dict.get("weight_quantizer._scale", None)
 
         state_dict = {}
 

@@ -273,6 +274,16 @@ def _gated_mlp_merging(
         else:
             state_dict["weight"] = tensor.to(self.dtype).to(device=weight.device)
 
+        # Preserve the fused LayerNorm weight + TE _extra_state already on the module so
+        # the strict load_state_dict below doesn't fail for TELayerNormColumnParallelLinear
+        # (fused under --export-default-te-spec). The actual HF norm tensor is loaded
+        # separately via the `fused_pre_mlp_layernorm` rule.
+        layer_norm_weight = module_state_dict.get("layer_norm_weight", None)
+        if layer_norm_weight is not None:
+            state_dict["layer_norm_weight"] = layer_norm_weight
+        if "_extra_state" in module_state_dict:
+            state_dict["_extra_state"] = module_state_dict["_extra_state"]
+
         module.load_state_dict(state_dict)
 
     def _grouped_mlp_merging(
def _grouped_mlp_merging(
@@ -433,7 +444,13 @@ def _qkv_merging(
433444
layer_norm_weight = module_state_dict.get("layer_norm_weight", None)
434445
if layer_norm_weight is not None:
435446
state_dict["layer_norm_weight"] = layer_norm_weight
436-
state_dict["_extra_state"] = None # for TE modules require _extra_state key
447+
# Preserve the TE metadata struct (FP8 amax history, recipe version, etc.) —
448+
# `load_state_dict(..., strict=True)` requires the key, but blanking it could
449+
# zero out per-module FP8 bookkeeping on TE versions that populate it. Only
450+
# forward through when the source actually has it, to avoid adding an
451+
# unexpected `_extra_state=None` to TE variants that don't.
452+
if "_extra_state" in module_state_dict:
453+
state_dict["_extra_state"] = module_state_dict["_extra_state"]
437454

438455
module.load_state_dict(state_dict)
439456
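The `_extra_state` handling in the `_gated_mlp_merging` and `_qkv_merging` hunks above comes down to one rule: forward the key only when the module already has it. The sketch below models that rule with plain dictionaries standing in for `module.state_dict()`. The helpers `build_state_dict` and `strict_key_mismatch` are hypothetical, and the key comparison is a simplified stand-in for a strict `load_state_dict`, not PyTorch's or Transformer Engine's actual logic.

```python
def build_state_dict(existing: dict, new_weight) -> dict:
    """Assemble a state dict for a strict load, forwarding fused-norm entries if present."""
    state = {"weight": new_weight}
    layer_norm_weight = existing.get("layer_norm_weight")
    if layer_norm_weight is not None:
        state["layer_norm_weight"] = layer_norm_weight
    # Forward `_extra_state` only when the module already exposes it; never invent the key.
    if "_extra_state" in existing:
        state["_extra_state"] = existing["_extra_state"]
    return state


def strict_key_mismatch(existing: dict, state: dict) -> set:
    """Keys a strict load would report as missing or unexpected (simplified model)."""
    return set(existing) ^ set(state)


# Placeholder values; a real module would hold tensors and TE-serialized metadata.
fused = {"weight": [0.0], "layer_norm_weight": [1.0], "_extra_state": b"te-metadata"}
plain = {"weight": [0.0]}

print(strict_key_mismatch(fused, build_state_dict(fused, [2.0])))
print(strict_key_mismatch(plain, build_state_dict(plain, [2.0])))
print(build_state_dict(fused, [2.0])["_extra_state"])
```

In this model, always writing `_extra_state = None` would add an unexpected key for the `plain` case and discard the existing metadata for the `fused` case, which are the two hazards the diff comments describe.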

@@ -599,14 +616,32 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
                 )
 
             # TE spec: input_layernorm is fused into linear_qkv (TELayerNormColumnParallelLinear).
-            # Load the fused layer_norm_weight from the HF norm path.
+            # Prefer the per-context key (`fused_input_layernorm`); fall back to the legacy
+            # single-key `fused_norm` for Nemotron-H style (one norm shared across slots).
+            # Missing both is a plugin misconfig — raise rather than silently random-init.
             if (
                 isinstance(layer.input_layernorm, IdentityOp)
                 and hasattr(attention, "linear_qkv")
                 and hasattr(attention.linear_qkv, "layer_norm_weight")
-                and "fused_norm" in self.rules
             ):
-                self.rules["fused_norm"](
+                fused_key = (
+                    "fused_input_layernorm"
+                    if "fused_input_layernorm" in self.rules
+                    else "fused_norm"
+                )
+                if fused_key not in self.rules:
+                    # Branch only fires when model uses fused TELayerNormColumnParallelLinear,
+                    # so missing rule is unambiguously a plugin misconfiguration; raise so it
+                    # doesn't silently ship a chance-accuracy checkpoint.
+                    raise KeyError(
+                        f"{self.arch} uses fused TELayerNormColumnParallelLinear for "
+                        "attention but neither `fused_input_layernorm` nor legacy "
+                        "`fused_norm` is in its import mapping; `linear_qkv.layer_norm_weight` "
+                        "would be left at random init. Add "
+                        '`fused_input_layernorm: NameRemapping("...input_layernorm.weight")` '
+                        f"to the {self.arch} import mapping."
+                    )
+                self.rules[fused_key](
                     attention.linear_qkv.layer_norm_weight, layer_id, is_mtp=is_mtp
                 )
 

@@ -707,14 +742,27 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
                 self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp)
 
                 # TE spec: pre_mlp_layernorm is fused into linear_fc1
-                # (TELayerNormColumnParallelLinear).
-                # Load the fused layer_norm_weight from the HF norm path.
-                if (
-                    isinstance(layer.pre_mlp_layernorm, IdentityOp)
-                    and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
-                    and "fused_norm" in self.rules
+                # (TELayerNormColumnParallelLinear). See input_layernorm path above for the
+                # rule-key fallback rationale.
+                if isinstance(layer.pre_mlp_layernorm, IdentityOp) and hasattr(
+                    layer.mlp.linear_fc1, "layer_norm_weight"
                 ):
-                    self.rules["fused_norm"](
+                    fused_key = (
+                        "fused_pre_mlp_layernorm"
+                        if "fused_pre_mlp_layernorm" in self.rules
+                        else "fused_norm"
+                    )
+                    if fused_key not in self.rules:
+                        raise KeyError(
+                            f"{self.arch} uses fused TELayerNormColumnParallelLinear for "
+                            "MLP but neither `fused_pre_mlp_layernorm` nor legacy "
+                            "`fused_norm` is in its import mapping; "
+                            "`linear_fc1.layer_norm_weight` would be left at random init. "
+                            "Add `fused_pre_mlp_layernorm: NameRemapping("
+                            '"...post_attention_layernorm.weight")` '
+                            f"to the {self.arch} import mapping."
+                        )
+                    self.rules[fused_key](
                         layer.mlp.linear_fc1.layer_norm_weight, layer_id, is_mtp=is_mtp
                     )
 
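Both importer branches above follow the same resolution order: per-context key first, legacy `fused_norm` second, `KeyError` otherwise. A minimal standalone sketch of that order follows. The function `resolve_fused_key`, the placeholder rule tables, and the `ExampleArch` name are hypothetical and exist only to make the control flow easy to test in isolation; the real importer inlines this logic and calls the selected rule.

```python
def resolve_fused_key(rules: dict, preferred: str, arch: str, legacy: str = "fused_norm") -> str:
    """Pick the per-context rule key, fall back to the legacy key, else raise KeyError."""
    if preferred in rules:
        return preferred
    if legacy in rules:
        return legacy
    raise KeyError(
        f"{arch}: neither `{preferred}` nor legacy `{legacy}` is in the import mapping"
    )


# Hypothetical rule tables; the values are placeholders, not real mapping objects.
gpt_rules = {"fused_input_layernorm": "rule-a", "fused_pre_mlp_layernorm": "rule-b"}
legacy_rules = {"fused_norm": "rule-c"}

print(resolve_fused_key(gpt_rules, "fused_input_layernorm", "ExampleArch"))
print(resolve_fused_key(legacy_rules, "fused_pre_mlp_layernorm", "ExampleArch"))

try:
    resolve_fused_key({}, "fused_input_layernorm", "ExampleArch")
except KeyError as err:
    print("KeyError:", err)
```

Raising instead of skipping is the design choice the diff comments argue for: a silently skipped rule would leave `layer_norm_weight` at its initial value and only surface later as degraded accuracy.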
