27 changes: 16 additions & 11 deletions tools/export_mtp.py
@@ -64,8 +64,8 @@ def get_mtp_layer_id(config, model_type):
     if not hasattr(config, "num_hidden_layers"):
         raise ValueError("'num_hidden_layers' not found in model config.")

-    # For DeepSeek V3/V3.2/R1 and GLM4, MTP layer is the last layer
-    if model_type in ["deepseek_v3", "deepseek_v32", "glm4_moe"]:
+    # For DeepSeek V3/V3.2/R1, GLM4 and GLM5, MTP layer is the last layer
+    if model_type in ["deepseek_v3", "deepseek_v32", "glm4_moe", "glm_moe_dsa"]:
         return config.num_hidden_layers

     raise ValueError(f"Unsupported model type for MTP export: {model_type}")
@@ -77,6 +77,7 @@ def get_mtp_model_type(model_type):
         "deepseek_v3": "deepseek_v3_mtp",  # Used for V3 and R1
         "deepseek_v32": "deepseek_v32_mtp",  # Used for V3.2
         "glm4_moe": "glm4_moe_mtp",
+        "glm_moe_dsa": "glm_moe_dsa_mtp",
     }
     return mapping.get(model_type, f"{model_type}_mtp")
@@ -87,6 +88,7 @@ def get_mtp_architecture(model_type):
         "deepseek_v3": "DeepseekMTPForCausalLM",  # Used for V3 and R1
         "deepseek_v32": "DeepseekV32MtpForCausalLM",  # Used for V3.2
         "glm4_moe": "Glm4MoeMtpForCausalLM",
+        "glm_moe_dsa": "GlmMoeDsaMtpForCausalLM",
     }
     return mapping.get(model_type, "MtpForCausalLM")
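To make the two lookups concrete, here is a minimal standalone sketch of how they resolve for the new GLM5 model type (function bodies taken from the hunks above; the surrounding script is not shown, so treat this as an illustration rather than the full tool):

```python
def get_mtp_model_type(model_type):
    mapping = {
        "deepseek_v3": "deepseek_v3_mtp",  # Used for V3 and R1
        "deepseek_v32": "deepseek_v32_mtp",  # Used for V3.2
        "glm4_moe": "glm4_moe_mtp",
        "glm_moe_dsa": "glm_moe_dsa_mtp",
    }
    return mapping.get(model_type, f"{model_type}_mtp")

def get_mtp_architecture(model_type):
    mapping = {
        "deepseek_v3": "DeepseekMTPForCausalLM",  # Used for V3 and R1
        "deepseek_v32": "DeepseekV32MtpForCausalLM",  # Used for V3.2
        "glm4_moe": "Glm4MoeMtpForCausalLM",
        "glm_moe_dsa": "GlmMoeDsaMtpForCausalLM",
    }
    return mapping.get(model_type, "MtpForCausalLM")

# GLM5 (glm_moe_dsa) now resolves explicitly:
assert get_mtp_model_type("glm_moe_dsa") == "glm_moe_dsa_mtp"
assert get_mtp_architecture("glm_moe_dsa") == "GlmMoeDsaMtpForCausalLM"
```

Note that the model-type fallback `f"{model_type}_mtp"` would have produced "glm_moe_dsa_mtp" anyway; the load-bearing change is the architecture entry, since the generic fallback "MtpForCausalLM" would not match the registered GlmMoeDsaMtpForCausalLM class.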

Expand All @@ -105,11 +107,8 @@ def update_and_save_config(config, output_dir, model_type):
"quantization_config": "",
}

# Model-specific updates
if model_type == "deepseek_v3": # Used for V3 and R1
updates["first_k_dense_replace"] = 0
elif model_type == "deepseek_v32": # Used for V3.2
updates["first_k_dense_replace"] = 0
# Keep consistent with MTP exported config requirements.
updates["first_k_dense_replace"] = 0

new_config.update(updates)

@@ -120,7 +119,11 @@ def copy_non_safetensors_files(input_dir, output_dir):
 def copy_non_safetensors_files(input_dir, output_dir):
     for filename in os.listdir(input_dir):
         src_file_path = os.path.join(input_dir, filename)
-        if os.path.isfile(src_file_path) and not filename.endswith(".safetensors"):
+        if (
+            os.path.isfile(src_file_path)
+            and not filename.endswith(".safetensors")
+            and not filename.endswith(".safetensors.index.json")
+        ):
             dst_file_path = os.path.join(output_dir, filename)
             shutil.copy2(src_file_path, dst_file_path)
     print(f"All non-safetensors files have been copied to {output_dir}")
@@ -168,15 +171,17 @@ def export_mtp_layer_parameters(input_dir, output_dir, mtp_layer_id, model_type):

     try:
         with safe_open(file_path, framework="pt") as f:
-            matching_keys = [k for k in f.keys() if k.startswith(prefix)]
+            matching_keys = [k for k in f.keys() if (k.startswith(prefix) or k == "rot.weight")]

             if not matching_keys:
                 print(f" No parameters starting with '{prefix}' found")
                 continue

             for key in matching_keys:
                 # Handle special keys that should be at model level
-                if any(special in key for special in ["embed_tokens", "shared_head", "enorm", "hnorm", "eh_proj"]):
+                if key == "rot.weight":
+                    new_key = "model.rot.weight"
+                elif any(special in key for special in ["embed_tokens", "shared_head", "enorm", "hnorm", "eh_proj"]):
                     new_key = key.replace(prefix, "model")
                 else:
                     # Map to layer 0 for MTP model
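The remapping now has three cases. A sketch of the rename logic follows; note that `prefix` is built off-screen in the tool, so its exact form (assumed here to be `f"model.layers.{mtp_layer_id}."`, based on how `key.replace(prefix, "model")` is used) and the truncated else-branch body are assumptions:

```python
MTP_LAYER_ID = 61  # hypothetical num_hidden_layers; the real value comes from config
prefix = f"model.layers.{MTP_LAYER_ID}."  # assumed; built off-screen in the tool

def remap_key(key):
    # 1. The standalone rotation weight moves to the model level.
    if key == "rot.weight":
        return "model.rot.weight"
    # 2. MTP-specific modules live directly under `model.` in the export.
    if any(s in key for s in ["embed_tokens", "shared_head", "enorm", "hnorm", "eh_proj"]):
        return key.replace(prefix, "model")
    # 3. Everything else becomes layer 0 of the single-layer MTP model
    #    (the truncated else-branch above; exact code is an assumption).
    return key.replace(prefix, "model.layers.0.")

assert remap_key("rot.weight") == "model.rot.weight"
assert remap_key(f"{prefix}enorm.weight") == "model.enorm.weight"
assert remap_key(f"{prefix}self_attn.q_proj.weight") == "model.layers.0.self_attn.q_proj.weight"
```

The `k == "rot.weight"` addition to `matching_keys` is needed because the rot weight sits at the checkpoint root, outside the `model.layers.<id>.` prefix the exporter otherwise filters on.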
@@ -238,7 +243,7 @@ def export_mtp_layer_parameters(input_dir, output_dir, mtp_layer_id, model_type):
         "--model-type",
         type=str,
         default=None,
-        help="Model type (deepseek_v3, deepseek_v32, glm4_moe). If not specified, will auto-detect. Note: DeepSeek V3 and R1 use 'deepseek_v3', V3.2 uses 'deepseek_v32'.",
+        help="Model type (deepseek_v3, deepseek_v32, glm4_moe, glm_moe_dsa). If not specified, will auto-detect. Note: DeepSeek V3 and R1 use 'deepseek_v3', V3.2 uses 'deepseek_v32'.",
     )
     args = parser.parse_args()
@@ -356,7 +356,9 @@ void NpuDeepseekV32DecoderLayerImpl::initialize_attention_parameters(
   param.kvLoraRank = args.kv_lora_rank();
   param.softmaxScale = sm_scale_;
   // not support in glm_moe_dsa
-  if (quantize_type_ == "w8a8_dynamic" && args.model_type() != "glm_moe_dsa") {
+  bool is_glm_moe_dsa =
+      args.model_type().find("glm_moe_dsa") != std::string::npos;
+  if (quantize_type_ == "w8a8_dynamic" && !is_glm_moe_dsa) {
     param.enableMlaPreprocess = true;
   } else {
     param.enableMlaPreprocess = false;
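The switch from exact comparison to substring search matters because the MTP variant registers as glm_moe_dsa_mtp (see the mapping in export_mtp.py above), which the old `!= "glm_moe_dsa"` test would not have caught. A Python sketch of the predicate:

```python
def mla_preprocess_enabled(quantize_type, model_type):
    # Mirrors the C++ logic: substring search instead of equality, so both
    # "glm_moe_dsa" and "glm_moe_dsa_mtp" disable the MLA preprocess path.
    is_glm_moe_dsa = "glm_moe_dsa" in model_type
    return quantize_type == "w8a8_dynamic" and not is_glm_moe_dsa

assert mla_preprocess_enabled("w8a8_dynamic", "deepseek_v32")
assert not mla_preprocess_enabled("w8a8_dynamic", "glm_moe_dsa")
assert not mla_preprocess_enabled("w8a8_dynamic", "glm_moe_dsa_mtp")  # missed by the old equality check
```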
4 changes: 1 addition & 3 deletions xllm/models/llm/npu/glm5_moe.h
@@ -317,9 +317,7 @@ REGISTER_MODEL_ARGS(glm_moe_dsa, [&] {
   LOAD_ARG_OR(rope_theta, "rope_theta", 1000000.0f);
   LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false);

-  LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
-    return 256;  // args->qk_nope_head_dim() + args->qk_rope_head_dim();
-  });
+  SET_ARG(head_dim, args->qk_nope_head_dim() + args->qk_rope_head_dim());
   LOAD_ARG_OR_FUNC(
       rotary_dim, "rotary_dim", [&] { return args->qk_rope_head_dim(); });
4 changes: 1 addition & 3 deletions xllm/models/llm/npu/glm5_moe_mtp.h
@@ -101,9 +101,7 @@ REGISTER_MODEL_ARGS(glm_moe_dsa_mtp, [&] {
   LOAD_ARG_OR(rope_theta, "rope_theta", 1000000.0f);
   LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false);

-  LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
-    return 256;  // args->qk_nope_head_dim() + args->qk_rope_head_dim();
-  });
+  SET_ARG(head_dim, args->qk_nope_head_dim() + args->qk_rope_head_dim());
   LOAD_ARG_OR_FUNC(
       rotary_dim, "rotary_dim", [&] { return args->qk_rope_head_dim(); });

24 changes: 23 additions & 1 deletion xllm/models/llm/npu/mtp_model_base.h
@@ -68,6 +68,7 @@ class MtpModelImplBase : public torch::nn::Module {
     // MTP extra module
     eh_proj_ =
         register_module("eh_proj", layer::NpuColumnParallelLinear(context));
+    rot_ = register_module("rot", layer::NpuColumnParallelLinear(context));
     enorm_ = register_module("enorm", layer::NpuRMSNorm(context));
     hnorm_ = register_module("hnorm", layer::NpuRMSNorm(context));
     final_norm_ = register_module("final_norm", layer::NpuRMSNorm(context));
@@ -104,7 +105,11 @@ class MtpModelImplBase : public torch::nn::Module {
     if (input_embedding.defined()) {
       h = input_embedding;
     }
-    torch::Tensor hnorm = hnorm_(h, 0);
+    torch::Tensor hnorm_input = h;
+    if (enable_rot_) {
+      hnorm_input = rot_(hnorm_input, /*nodeId=*/0);
+    }
+    torch::Tensor hnorm = hnorm_(hnorm_input, 0);
     CHECK_EQ(enorm.dim(), hnorm.dim());
     CHECK_EQ(enorm.size(0), hnorm.size(0));
     h = torch::cat({enorm, hnorm}, /*dim=*/-1);
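A shape-level torch sketch of the updated input mixing, assuming the usual MTP pattern of normalizing embeddings and hidden states before eh_proj. Module names mirror the C++ members; the real layers are NPU column-parallel ops and the downstream use of the projected tensor sits outside the visible hunk:

```python
import torch

def rms_norm(x, eps=1e-6):
    # Stand-in for NpuRMSNorm (no learned scale, for brevity).
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

hidden, tokens = 16, 4
enable_rot = True  # flipped at load time when rot.weight is present

rot = torch.nn.Linear(hidden, hidden, bias=False)          # rot_
eh_proj = torch.nn.Linear(2 * hidden, hidden, bias=False)  # eh_proj_

embeds = torch.randn(tokens, hidden)  # token embeddings for the MTP step
h = torch.randn(tokens, hidden)       # hidden states from the main model

enorm = rms_norm(embeds)                      # enorm_
hnorm_input = rot(h) if enable_rot else h     # new: optional rot before hnorm_
hnorm = rms_norm(hnorm_input)                 # hnorm_
h = eh_proj(torch.cat([enorm, hnorm], dim=-1))
assert h.shape == (tokens, hidden)
```

Because `enable_rot_` defaults to false, the forward path is byte-for-byte the old one for checkpoints that ship no rot weight.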
@@ -196,6 +201,15 @@ class MtpModelImplBase : public torch::nn::Module {

   // load the weight from the checkpoint
   virtual void load_state_dict(const StateDict& state_dict) {
+    if (state_dict.get_tensor("rot.weight").defined()) {
+      if (!enable_rot_) {
+        LOG(INFO) << "Detected rot.weight in MTP weights, enabling optional "
+                     "rot linear before hnorm.";
+      }
+      enable_rot_ = true;
+      rot_->load_state_dict(state_dict.get_dict_with_prefix("rot."));
+    }
+
     // call each layer's load_state_dict function
     for (int i = 0; i < layers_.size(); i++) {
       layers_[i]->load_state_dict(
@@ -213,6 +227,9 @@ class MtpModelImplBase : public torch::nn::Module {
       layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) +
                                         ".");
     }
+    if (enable_rot_) {
+      rot_->verify_loaded_weights(prefix + "rot.");
+    }
     eh_proj_->verify_loaded_weights(prefix + "eh_proj.");
     enorm_->verify_loaded_weights(prefix + "enorm.");
     hnorm_->verify_loaded_weights(prefix + "hnorm.");
@@ -223,6 +240,9 @@ class MtpModelImplBase : public torch::nn::Module {
     for (int i = 0; i < layers_.size(); i++) {
       layers_[i]->merge_loaded_weights();
     }
+    if (enable_rot_) {
+      rot_->merge_loaded_weights();
+    }
     eh_proj_->merge_loaded_weights();
     enorm_->merge_loaded_weights();
     hnorm_->merge_loaded_weights();
@@ -250,6 +270,7 @@ class MtpModelImplBase : public torch::nn::Module {
   layer::AttentionMask attn_mask_;

   // MTP extra modules
+  layer::NpuColumnParallelLinear rot_{nullptr};
   layer::NpuColumnParallelLinear eh_proj_{nullptr};
   layer::NpuRMSNorm enorm_{nullptr};
   layer::NpuRMSNorm hnorm_{nullptr};
@@ -259,6 +280,7 @@ class MtpModelImplBase : public torch::nn::Module {
   std::vector<DecoderLayerType> layers_;

   bool layer_forward_interrupted_ = false;
+  bool enable_rot_ = false;

  private:
   std::string model_type_;