Skip to content

Commit 1045f38

Browse files
[None][feat] Qwen3-Next MTP (#11370)
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
1 parent 6efc283 commit 1045f38

File tree

6 files changed

+1135
-619
lines changed

6 files changed

+1135
-619
lines changed

tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def preprocess_weights(self, weights: dict) -> dict:
5656
config = self.config.pretrained_config
5757
tp_size = self.config.mapping.tp_size
5858
tp_rank = self.config.mapping.tp_rank
59+
mtp_layer_offset = config.num_hidden_layers
5960

6061
if self.config.mapping.enable_attention_dp:
6162
tp_size = 1
@@ -66,10 +67,28 @@ def preprocess_weights(self, weights: dict) -> dict:
6667
linear_key_dim = config.linear_key_head_dim * config.linear_num_key_heads # 16 * 128
6768
linear_value_dim = config.linear_value_head_dim * config.linear_num_value_heads # 32 * 128
6869

70+
mtp_mapping = {
71+
"mtp.fc": "fc",
72+
"mtp.norm": "shared_head.norm",
73+
"mtp.pre_fc_norm_embedding": "pre_fc_norm_embedding",
74+
"mtp.pre_fc_norm_hidden": "pre_fc_norm_hidden",
75+
}
76+
6977
new_weights = {}
7078
for name, _ in weights.items():
7179
key = name
7280

81+
if key.startswith("mtp.layers."):
82+
_, _, mtp_layer_idx, module_name = key.split(".", 3)
83+
key = (f"model.layers.{mtp_layer_offset + int(mtp_layer_idx)}."
84+
f"{module_name}")
85+
elif key.startswith("mtp."):
86+
for mtp_prefix, trtllm_name in mtp_mapping.items():
87+
if key.startswith(mtp_prefix):
88+
suffix = key[len(mtp_prefix):]
89+
key = f"model.layers.{mtp_layer_offset}.{trtllm_name}{suffix}"
90+
break
91+
7392
if "A_log" in key:
7493
w = split(weights[name], tp_size, tp_rank)
7594
w = w.to(torch.float32)

0 commit comments

Comments
 (0)