Skip to content

Commit 0b8f87b

Browse files
committed
feat: add GLM5 MTP export support and optional rot preprocessing for NPU draft model.
1 parent 8eb6391 commit 0b8f87b

2 files changed

Lines changed: 39 additions & 12 deletions

File tree

tools/export_mtp.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ def get_mtp_layer_id(config, model_type):
6464
if not hasattr(config, "num_hidden_layers"):
6565
raise ValueError("'num_hidden_layers' not found in model config.")
6666

67-
# For DeepSeek V3/V3.2/R1 and GLM4, MTP layer is the last layer
68-
if model_type in ["deepseek_v3", "deepseek_v32", "glm4_moe"]:
67+
# For DeepSeek V3/V3.2/R1, GLM4 and GLM5, MTP layer is the last layer
68+
if model_type in ["deepseek_v3", "deepseek_v32", "glm4_moe", "glm_moe_dsa"]:
6969
return config.num_hidden_layers
7070

7171
raise ValueError(f"Unsupported model type for MTP export: {model_type}")
@@ -77,6 +77,7 @@ def get_mtp_model_type(model_type):
7777
"deepseek_v3": "deepseek_v3_mtp", # Used for V3 and R1
7878
"deepseek_v32": "deepseek_v32_mtp", # Used for V3.2
7979
"glm4_moe": "glm4_moe_mtp",
80+
"glm_moe_dsa": "glm_moe_dsa_mtp",
8081
}
8182
return mapping.get(model_type, f"{model_type}_mtp")
8283

@@ -87,6 +88,7 @@ def get_mtp_architecture(model_type):
8788
"deepseek_v3": "DeepseekMTPForCausalLM", # Used for V3 and R1
8889
"deepseek_v32": "DeepseekV32MtpForCausalLM", # Used for V3.2
8990
"glm4_moe": "Glm4MoeMtpForCausalLM",
91+
"glm_moe_dsa": "GlmMoeDsaMtpForCausalLM",
9092
}
9193
return mapping.get(model_type, "MtpForCausalLM")
9294

@@ -105,11 +107,8 @@ def update_and_save_config(config, output_dir, model_type):
105107
"quantization_config": "",
106108
}
107109

108-
# Model-specific updates
109-
if model_type == "deepseek_v3": # Used for V3 and R1
110-
updates["first_k_dense_replace"] = 0
111-
elif model_type == "deepseek_v32": # Used for V3.2
112-
updates["first_k_dense_replace"] = 0
110+
# Keep consistent with MTP exported config requirements.
111+
updates["first_k_dense_replace"] = 0
113112

114113
new_config.update(updates)
115114

@@ -120,7 +119,11 @@ def update_and_save_config(config, output_dir, model_type):
120119
def copy_non_safetensors_files(input_dir, output_dir):
121120
for filename in os.listdir(input_dir):
122121
src_file_path = os.path.join(input_dir, filename)
123-
if os.path.isfile(src_file_path) and not filename.endswith(".safetensors"):
122+
if (
123+
os.path.isfile(src_file_path)
124+
and not filename.endswith(".safetensors")
125+
and not filename.endswith(".safetensors.index.json")
126+
):
124127
dst_file_path = os.path.join(output_dir, filename)
125128
shutil.copy2(src_file_path, dst_file_path)
126129
print(f"All non-safetensors files have been copied to {output_dir}")
@@ -168,15 +171,17 @@ def export_mtp_layer_parameters(input_dir, output_dir, mtp_layer_id, model_type)
168171

169172
try:
170173
with safe_open(file_path, framework="pt") as f:
171-
matching_keys = [k for k in f.keys() if k.startswith(prefix)]
174+
matching_keys = [k for k in f.keys() if (k.startswith(prefix) or k == "rot.weight")]
172175

173176
if not matching_keys:
174177
print(f" No parameters starting with '{prefix}' found")
175178
continue
176179

177180
for key in matching_keys:
178181
# Handle special keys that should be at model level
179-
if any(special in key for special in ["embed_tokens", "shared_head", "enorm", "hnorm", "eh_proj"]):
182+
if key == "rot.weight":
183+
new_key = "model.rot.weight"
184+
elif any(special in key for special in ["embed_tokens", "shared_head", "enorm", "hnorm", "eh_proj"]):
180185
new_key = key.replace(prefix, "model")
181186
else:
182187
# Map to layer 0 for MTP model
@@ -238,7 +243,7 @@ def export_mtp_layer_parameters(input_dir, output_dir, mtp_layer_id, model_type)
238243
"--model-type",
239244
type=str,
240245
default=None,
241-
help="Model type (deepseek_v3, deepseek_v32, glm4_moe). If not specified, will auto-detect. Note: DeepSeek V3 and R1 use 'deepseek_v3', V3.2 uses 'deepseek_v32'.",
246+
help="Model type (deepseek_v3, deepseek_v32, glm4_moe, glm_moe_dsa). If not specified, will auto-detect. Note: DeepSeek V3 and R1 use 'deepseek_v3', V3.2 uses 'deepseek_v32'.",
242247
)
243248
args = parser.parse_args()
244249

xllm/models/llm/npu/mtp_model_base.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class MtpModelImplBase : public torch::nn::Module {
6868
// MTP extra module
6969
eh_proj_ =
7070
register_module("eh_proj", layer::NpuColumnParallelLinear(context));
71+
rot_ = register_module("rot", layer::NpuColumnParallelLinear(context));
7172
enorm_ = register_module("enorm", layer::NpuRMSNorm(context));
7273
hnorm_ = register_module("hnorm", layer::NpuRMSNorm(context));
7374
final_norm_ = register_module("final_norm", layer::NpuRMSNorm(context));
@@ -104,7 +105,11 @@ class MtpModelImplBase : public torch::nn::Module {
104105
if (input_embedding.defined()) {
105106
h = input_embedding;
106107
}
107-
torch::Tensor hnorm = hnorm_(h, 0);
108+
torch::Tensor hnorm_input = h;
109+
if (enable_rot_) {
110+
hnorm_input = rot_(hnorm_input, /*nodeId=*/0);
111+
}
112+
torch::Tensor hnorm = hnorm_(hnorm_input, 0);
108113
CHECK_EQ(enorm.dim(), hnorm.dim());
109114
CHECK_EQ(enorm.size(0), hnorm.size(0));
110115
h = torch::cat({enorm, hnorm}, /*dim=*/-1);
@@ -196,6 +201,15 @@ class MtpModelImplBase : public torch::nn::Module {
196201

197202
// load the weight from the checkpoint
198203
virtual void load_state_dict(const StateDict& state_dict) {
204+
if (state_dict.get_tensor("rot.weight").defined()) {
205+
if (!enable_rot_) {
206+
LOG(INFO) << "Detected rot.weight in MTP weights, enable optional rot "
207+
"linear before hnorm.";
208+
}
209+
enable_rot_ = true;
210+
rot_->load_state_dict(state_dict.get_dict_with_prefix("rot."));
211+
}
212+
199213
// call each layer's load_state_dict function
200214
for (int i = 0; i < layers_.size(); i++) {
201215
layers_[i]->load_state_dict(
@@ -213,6 +227,9 @@ class MtpModelImplBase : public torch::nn::Module {
213227
layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) +
214228
".");
215229
}
230+
if (enable_rot_) {
231+
rot_->verify_loaded_weights(prefix + "rot.");
232+
}
216233
eh_proj_->verify_loaded_weights(prefix + "eh_proj.");
217234
enorm_->verify_loaded_weights(prefix + "enorm.");
218235
hnorm_->verify_loaded_weights(prefix + "hnorm.");
@@ -223,6 +240,9 @@ class MtpModelImplBase : public torch::nn::Module {
223240
for (int i = 0; i < layers_.size(); i++) {
224241
layers_[i]->merge_loaded_weights();
225242
}
243+
if (enable_rot_) {
244+
rot_->merge_loaded_weights();
245+
}
226246
eh_proj_->merge_loaded_weights();
227247
enorm_->merge_loaded_weights();
228248
hnorm_->merge_loaded_weights();
@@ -250,6 +270,7 @@ class MtpModelImplBase : public torch::nn::Module {
250270
layer::AttentionMask attn_mask_;
251271

252272
// MTP extra modules
273+
layer::NpuColumnParallelLinear rot_{nullptr};
253274
layer::NpuColumnParallelLinear eh_proj_{nullptr};
254275
layer::NpuRMSNorm enorm_{nullptr};
255276
layer::NpuRMSNorm hnorm_{nullptr};
@@ -259,6 +280,7 @@ class MtpModelImplBase : public torch::nn::Module {
259280
std::vector<DecoderLayerType> layers_;
260281

261282
bool layer_forward_interrupted_ = false;
283+
bool enable_rot_ = false;
262284

263285
private:
264286
std::string model_type_;

0 commit comments

Comments (0)