Skip to content

Commit 0cae4cb

Browse files
committed
Fix DeepSeek MoE detection and export mapping
Signed-off-by: Charles.J <jiangchangjian247@gmail.com>
1 parent b1f9f01 commit 0cae4cb

File tree

4 files changed

+150
-9
lines changed

4 files changed

+150
-9
lines changed

modelopt/torch/export/hf_config_map.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@
6464
(["dense_attention_every_n_layers"], "dense_attention_every_n_layers"), # Phi3-small
6565
(["gegelu_limit"], "gegelu_limit"), # Phi3-small
6666
(
67-
["num_local_experts", "moe_num_experts"],
67+
["num_local_experts", "moe_num_experts", "n_routed_experts"],
6868
"moe_num_experts",
69-
), # Mixture of Experts (Mixtral, DBRX)
69+
), # Mixture of Experts (Mixtral, DBRX, DeepSeek)
7070
(["num_experts_per_tok", "moe_top_k"], "moe_top_k"), # Mixture of Experts (Mixtral, DBRX)
7171
(["model_type"], "qwen_type"), # qwen
7272
(["lru_width"], "rnn_hidden_size"), # Recurrent Gemma

modelopt/torch/export/layer_utils.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ def get_experts_list(module: torch.nn.Module, model_type: str):
9898
]
9999
):
100100
linear_names = ["gate_proj", "down_proj", "up_proj"]
101+
elif "deepseek" in model_type:
102+
linear_names = ["gate_proj", "down_proj", "up_proj"]
101103
else:
102104
raise NotImplementedError(f" {model_type} not supported")
103105

@@ -150,6 +152,33 @@ def check_model_compatibility(module_list: list[nn.Module]) -> tuple[bool, bool,
150152

151153
def get_transformer_layers(model: nn.Module) -> list[nn.Module]:
152154
"""Returns the root module of the transformer model."""
155+
if "Megatron" in type(model).__name__:
156+
if hasattr(model, "model") and "GPTModel" in type(model.model).__name__:
157+
# NEMO mcore models can be handled with the following branch.
158+
model = model.model
159+
160+
# NEMO non mcore models, we need to find the language_model module first.
161+
children = [model]
162+
language_model = None
163+
while children and not language_model:
164+
next_children = []
165+
for child in children:
166+
if type(child).__name__ == "TransformerLanguageModel":
167+
language_model = child
168+
break
169+
next_children.extend(list(child.children()))
170+
children = next_children
171+
if language_model:
172+
warn("Warning: this is an old NEMO checkpoint format and will be deprecated soon.")
173+
layers = list(language_model.embedding.children()) + list(
174+
language_model.encoder.children()
175+
)
176+
177+
if hasattr(language_model, "output_layer"):
178+
layers.append(language_model.output_layer)
179+
180+
return layers
181+
153182
if "GPTModel" in type(model).__name__:
154183
# mcore models
155184
layers = []
@@ -298,14 +327,20 @@ def is_mlp(module: nn.Module) -> bool:
298327
return any(key in type(module).__name__.upper() for key in ("MLP", "T5DENSE"))
299328

300329

330+
def _is_deepseek_moe_name(module_name: str) -> bool:
331+
return "deepseek" in module_name and "moe" in module_name
332+
333+
301334
def is_moe(module: nn.Module) -> bool:
    """Returns whether the module is an MOE layer."""
    cls_name = type(module).__name__.lower()
    # Common MoE naming conventions are detected automatically.
    if cls_name.endswith("sparsemoeblock") or "moelayer" in cls_name:
        return True
    # DeepSeek MoE blocks: match by name plus router ("gate") and expert attributes,
    # so plain DeepSeek MLP modules are not misclassified.
    has_router_parts = hasattr(module, "gate") and hasattr(module, "experts")
    if _is_deepseek_moe_name(cls_name) and has_router_parts:
        return True
    # Fallback: explicit matches for non-standard MoE class names.
    return any(token in cls_name for token in ("arcticmoe", "dbrxffn", "gptossmoe"))
309344

310345

311346
def is_quantlinear(module: nn.Module) -> bool:
@@ -358,7 +393,7 @@ def build_qkv(
358393
num_kv_heads = ext_config.num_kv_heads
359394

360395
if "ColumnParallelLinear" in type(qkv_module).__name__:
361-
# For Megatron-core model, num_kv_heads/num_attention_heads is the first dimension of QKV
396+
# For NEMO model, num_kv_heads/num_attention_heads is the first dimension of QKV
362397
model_metadata_config["head_is_first_dim"] = True
363398

364399
qkv_weight = qkv_module.weight
@@ -965,14 +1000,17 @@ def module_match_name_list(module, name_list):
9651000
"""
9661001
return any(name.lower() in type(module).__name__.lower() for name in name_list)
9671002

968-
if module_match_name_list(
1003+
module_name = type(module).__name__.lower()
1004+
1005+
if _is_deepseek_moe_name(module_name):
1006+
return ["gate_proj", "down_proj", "up_proj"]
1007+
elif module_match_name_list(
9691008
module,
9701009
[
9711010
"Qwen2MoeSparseMoeBlock",
9721011
"Qwen3MoeSparseMoeBlock",
9731012
"Qwen3NextSparseMoeBlock",
9741013
"Qwen3_5MoeSparseMoeBlock",
975-
"DeepseekMoE",
9761014
],
9771015
):
9781016
return ["gate_proj", "down_proj", "up_proj"]
@@ -1455,7 +1493,7 @@ def _set_layer_config_from_metaconfig(layer_config, metaconfig):
14551493
if k in metaconfig:
14561494
setattr(layer_config, name, metaconfig[k])
14571495

1458-
# MCore use "rope" as an alias for "rope_gpt_neox"
1496+
# MCore / NeMo use "rope" as an alias for "rope_gpt_neox"
14591497
if layer_config.position_embedding_type == "rope":
14601498
layer_config.position_embedding_type = "rope_gpt_neox"
14611499

modelopt/torch/export/quant_utils.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,12 +1216,30 @@ def _update_svdquant(modules, new_pre_quant_scale):
12161216
# Mathematical equivalence:
12171217
# Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
12181218
# After: o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
1219-
(["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
1219+
(
1220+
[
1221+
"LlamaAttention",
1222+
"Qwen3Attention",
1223+
"Qwen3MoeAttention",
1224+
"DeepseekV2Attention",
1225+
"DeepseekV3Attention",
1226+
],
1227+
("v_proj", "o_proj"),
1228+
),
12201229
# MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension
12211230
# Mathematical equivalence:
12221231
# Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
12231232
# After: down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
1224-
(["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
1233+
(
1234+
[
1235+
"LlamaMLP",
1236+
"Qwen3MLP",
1237+
"Qwen3MoeMLP",
1238+
"DeepseekV2MLP",
1239+
"DeepseekV3MLP",
1240+
],
1241+
("up_proj", "down_proj"),
1242+
),
12251243
]
12261244

12271245

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import torch
17+
import torch.nn as nn
18+
19+
from modelopt.torch.export.hf_config_map import HF_CONFIG_MAP
20+
from modelopt.torch.export.layer_utils import get_expert_linear_names, get_experts_list, is_moe
21+
from modelopt.torch.export.quant_utils import PQS_FUSE_MODULE_MAPPING
22+
23+
24+
class _FakeDeepseekExpert(nn.Module):
25+
def __init__(self):
26+
super().__init__()
27+
self.gate_proj = nn.Linear(8, 16, bias=False)
28+
self.down_proj = nn.Linear(16, 8, bias=False)
29+
self.up_proj = nn.Linear(8, 16, bias=False)
30+
31+
32+
class _FakeDeepseekGate(nn.Module):
33+
def __init__(self, num_experts=2):
34+
super().__init__()
35+
self.top_k = 1
36+
self.n_routed_experts = num_experts
37+
self.gating_dim = 8
38+
self.weight = nn.Parameter(torch.empty(num_experts, 8))
39+
nn.init.normal_(self.weight)
40+
41+
42+
class DeepseekV3MoE(nn.Module):
    """Minimal stand-in mirroring HF DeepseekV3MoE: a router gate, routed experts,
    and a shared-experts branch — the attributes the export code inspects."""

    def __init__(self, num_experts=2):
        super().__init__()
        self.gate = _FakeDeepseekGate(num_experts)
        routed = [_FakeDeepseekExpert() for _ in range(num_experts)]
        self.experts = nn.ModuleList(routed)
        self.shared_experts = _FakeDeepseekExpert()
48+
49+
50+
def test_is_moe_detects_deepseek_v3_moe():
    """A DeepSeek-V3 style block (gate + experts attributes) is recognized as MoE."""
    moe_block = DeepseekV3MoE()
    assert is_moe(moe_block)
52+
53+
54+
def test_get_expert_linear_names_for_deepseek_v3():
    """DeepSeek experts expose the gate/down/up projection linear names."""
    expected = ["gate_proj", "down_proj", "up_proj"]
    assert get_expert_linear_names(DeepseekV3MoE()) == expected
56+
57+
58+
def test_get_experts_list_for_deepseek_model_type():
    """get_experts_list groups gate/down/up linears per routed expert for deepseek."""
    moe = DeepseekV3MoE(num_experts=3)

    groups = get_experts_list(moe, "deepseekv3forcausallm")

    assert len(groups) == 3
    for group in groups:
        assert len(group) == 3
    # Spot-check identity: the grouped linears are the experts' own modules.
    assert groups[0][0] is moe.experts[0].gate_proj
    assert groups[1][1] is moe.experts[1].down_proj
    assert groups[2][2] is moe.experts[2].up_proj
68+
69+
70+
def test_hf_config_map_supports_deepseek_num_experts():
    """HF_CONFIG_MAP must translate DeepSeek's n_routed_experts to moe_num_experts."""
    matches = [
        input_names
        for input_names, output_name in HF_CONFIG_MAP
        if output_name == "moe_num_experts" and "n_routed_experts" in input_names
    ]
    assert matches
75+
76+
77+
def test_prequant_fuse_mapping_covers_deepseek_v3():
    """The pre-quant-scale fuse mapping must cover DeepseekV3 attention and MLP."""

    def _covered(wanted_module, wanted_pair):
        # One-line helper: is (module name, linear pair) present in the mapping?
        return any(
            wanted_module in module_names and linear_pair == wanted_pair
            for module_names, linear_pair in PQS_FUSE_MODULE_MAPPING
        )

    assert _covered("DeepseekV3Attention", ("v_proj", "o_proj"))
    assert _covered("DeepseekV3MLP", ("up_proj", "down_proj"))

0 commit comments

Comments
 (0)