
Commit 33d4d27

jenchen13 and danielkorzekwa authored and committed
Latent MOE & Repeated MTP support for NemotronH; fix KV cache quant export (#830)
## What does this PR do?

**Type of change:** New feature and bug fix

**Overview:** Support Latent MoE and repeated MTP for NemotronH models:

- Enable latent MoE modules during Megatron import/export.
- Fix KV cache quantization export: remove the old `qkv_layer.output_quantizer` export and replace it with the proper `k/v_bmm_quantizer` logic.
- Improvements to EP amax sync.
- Support repeated MTP import/export for NemotronH models (only BF16 export for MTP for now).

## Usage

```python
# Add a code snippet demonstrating how to use this
```

## Testing

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

## Summary by CodeRabbit

## Release Notes

* **New Features**
  * Added support for grouped MLP and self-attention scaling operations in model export workflows
  * Enhanced model parallel training capabilities with improved component mapping
  * Expanded quantization configuration handling with dynamic module exclusion across distributed ranks
  * Improved support for additional transformer engine components
* **Refactor**
  * Reorganized internal export and import logic for improved maintainability and specialist model architecture support

---------

Signed-off-by: jenchen13 <jennifchen@nvidia.com>
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
Signed-off-by: Jenny Chen <jennifchen@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
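The overview also mentions improvements to EP amax sync. For context, here is a minimal sketch of what synchronizing a quantizer's amax across an expert-parallel group generally looks like; `ep_group` and the `amax` attribute are assumptions for illustration, not ModelOpt's actual implementation.

```python
# Minimal sketch (not ModelOpt's implementation): sync a quantizer's amax
# across an expert-parallel (EP) group so every rank derives the same scale.
import torch
import torch.distributed as dist


def sync_amax_across_ep(quantizer, ep_group=None):
    """Take the element-wise max of `quantizer.amax` over all EP ranks."""
    if quantizer is None or getattr(quantizer, "amax", None) is None:
        return
    amax = quantizer.amax.detach().clone()
    # Every rank ends up with the largest observed amax, so the exported
    # quantization scale is identical regardless of expert placement.
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=ep_group)
    quantizer.amax.copy_(amax)
```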
1 parent 9b84e13 commit 33d4d27

11 files changed

Lines changed: 849 additions & 521 deletions


modelopt/torch/export/plugins/mcore_custom.py

Lines changed: 24 additions & 0 deletions
```diff
@@ -103,6 +103,18 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any]
         )
 
 
+class GroupedMLPMerging(CustomModuleMapping):
+    """A custom module mapping that merges up_proj and down_proj for Grouped MLP."""
+
+    def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}):
+        """Create a custom module mapping that merges up_proj and down_proj for Grouped MLP."""
+        super().__init__(
+            func_name="grouped_mlp_merging",
+            target_name_or_prefix=target_name_or_prefix,
+            func_kwargs=func_kwargs,
+        )
+
+
 class GatedMLPMerging(CustomModuleMapping):
     """A custom module mapping that merges gate_proj and up_proj."""
 
@@ -127,6 +139,18 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any]
         )
 
 
+class SelfAttentionScaling(CustomModuleMapping):
+    """A custom module mapping that scales self attention."""
+
+    def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}):
+        """Create a custom module mapping that scales self attention."""
+        super().__init__(
+            func_name="self_attention_scaling",
+            target_name_or_prefix=target_name_or_prefix,
+            func_kwargs=func_kwargs,
+        )
+
+
 class GatedMLPSlicing(CustomModuleMapping):
     """A custom module mapping that slices gate_proj and up_proj."""
 
```
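The new classes follow the existing pattern in `mcore_custom.py`: each mapping only records *which* conversion to run (`func_name`) and *where* to write the result (`target_name_or_prefix`). Below is a simplified, self-contained stand-in for that pattern; the dispatch table and handler are hypothetical, not the actual exporter code.

```python
# Simplified stand-in (not the actual plugin code): a mapping bundles a
# func_name and a target prefix; an exporter dispatches on func_name.
from typing import Any


class CustomModuleMapping:
    def __init__(self, func_name: str, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] | None = None):
        self.func_name = func_name
        self.target_name_or_prefix = target_name_or_prefix
        self.func_kwargs = func_kwargs or {}


class GroupedMLPMerging(CustomModuleMapping):
    """Merges per-expert up_proj/down_proj weights into one grouped tensor."""

    def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] | None = None):
        super().__init__("grouped_mlp_merging", target_name_or_prefix, func_kwargs)


# Hypothetical dispatch table: the exporter picks the handler by func_name.
HANDLERS = {
    "grouped_mlp_merging": lambda prefix, **kw: print(f"merge expert weights under {prefix}"),
}

mapping = GroupedMLPMerging("mtp.layers.{}.mixer.experts.{{}}.up_proj")
HANDLERS[mapping.func_name](mapping.target_name_or_prefix.format(0), **mapping.func_kwargs)
```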
modelopt/torch/export/plugins/mcore_llama.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -30,6 +30,7 @@
     PackNameRemapping,
     QKVMerging,
     QKVSlicing,
+    SelfAttentionScaling,
     UnpackNameRemapping,
 )
 
@@ -38,6 +39,8 @@
     "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."),
     "linear_qkv": QKVSlicing("model.layers.{}.self_attn."),
     "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
+    # KV cache quant export
+    "core_attention": SelfAttentionScaling("model.layers.{}.self_attn."),
     "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."),
     "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."),
     "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."),
```

modelopt/torch/export/plugins/mcore_nemotron.py

Lines changed: 27 additions & 1 deletion
```diff
@@ -23,9 +23,11 @@
     ROW_ETP,
     ROW_TP,
     CustomModuleMapping,
+    GroupedMLPMerging,
     NameRemapping,
     QKVMerging,
     QKVSlicing,
+    SelfAttentionScaling,
 )
 
 # Example on adding a new CausalLM.
@@ -35,6 +37,7 @@
     "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."),
     "linear_qkv": QKVSlicing("model.layers.{}.self_attn."),
     "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
+    "core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."),
     "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."),
     # NemotronForCausalLM is using square-relu where no gated handle is needed.
     "linear_fc1": NameRemapping("model.layers.{}.mlp.up_proj."),
@@ -81,9 +84,23 @@
     "shared_experts.linear_fc2": NameRemapping(
         "backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP
     ),
+    # Latent MoE
+    "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE),
+    "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE),
+    # Repeated MTP module
+    "mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", {"is_mtp": True}),
+    "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}),
+    "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}),
+    "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}),
+    # Grouped local experts in MTP
+    "experts.linear_fc1": GroupedMLPMerging(
+        "mtp.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}
+    ),
+    "experts.linear_fc2": GroupedMLPMerging(
+        "mtp.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}
+    ),
 }
 
-
 nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = {
     "word_embeddings": NameRemapping("backbone.embeddings."),
     "final_norm": NameRemapping("backbone.norm_f."),
@@ -101,6 +118,7 @@
     "input_layernorm": NameRemapping("backbone.layers.{}.norm."),
     "linear_qkv": QKVSlicing("backbone.layers.{}.mixer."),
     "linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj."),
+    "core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."),
     # MLP
     "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm."),
     "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj."),
@@ -115,4 +133,12 @@
     "shared_experts.linear_fc2": NameRemapping(
         "backbone.layers.{}.mixer.shared_experts.down_proj."
     ),
+    # Latent MoE
+    "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."),
+    "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."),
+    # MTP
+    "mtp.enorm": NameRemapping("mtp.layers.{}.enorm."),
+    "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm."),
+    "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj."),
+    "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm."),
 }
```
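One detail worth noting in the repeated MTP and grouped-expert entries is the doubled braces: `{}` is filled with the MTP layer index first, while `{{}}` survives that pass so the expert index can be filled later. A quick illustration of the two-level formatting (the actual formatting happens inside the export plugins, so treat this as assumed usage):

```python
# Two-level placeholder resolution for repeated MTP layers and grouped experts.
prefix = "mtp.layers.{}.mixer.experts.{{}}.down_proj"

# First pass fills the repeated-MTP layer index...
per_layer = prefix.format(1)      # "mtp.layers.1.mixer.experts.{}.down_proj"
# ...the escaped "{{}}" becomes "{}", so a second pass can fill the expert index.
per_expert = per_layer.format(7)  # "mtp.layers.1.mixer.experts.7.down_proj"

print(per_layer, per_expert, sep="\n")
```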
