@@ -64,8 +64,8 @@ def get_mtp_layer_id(config, model_type):
     if not hasattr(config, "num_hidden_layers"):
         raise ValueError("'num_hidden_layers' not found in model config.")

-    # For DeepSeek V3/V3.2/R1 and GLM4, MTP layer is the last layer
-    if model_type in ["deepseek_v3", "deepseek_v32", "glm4_moe"]:
+    # For DeepSeek V3/V3.2/R1, GLM4 and GLM5, MTP layer is the last layer
+    if model_type in ["deepseek_v3", "deepseek_v32", "glm4_moe", "glm_moe_dsa"]:
         return config.num_hidden_layers

     raise ValueError(f"Unsupported model type for MTP export: {model_type}")
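
For context, a minimal sketch of what this resolves to, using a made-up config object and layer count rather than a real checkpoint:

# Decoder layers occupy indices 0..num_hidden_layers-1 in the source checkpoint,
# so the MTP layer sits one past them, at index num_hidden_layers.
from types import SimpleNamespace

config = SimpleNamespace(num_hidden_layers=92)  # hypothetical layer count
assert get_mtp_layer_id(config, "glm_moe_dsa") == 92
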
@@ -77,6 +77,7 @@ def get_mtp_model_type(model_type):
7777 "deepseek_v3" : "deepseek_v3_mtp" , # Used for V3 and R1
7878 "deepseek_v32" : "deepseek_v32_mtp" , # Used for V3.2
7979 "glm4_moe" : "glm4_moe_mtp" ,
80+ "glm_moe_dsa" : "glm_moe_dsa_mtp" ,
8081 }
8182 return mapping .get (model_type , f"{ model_type } _mtp" )
8283
@@ -87,6 +88,7 @@ def get_mtp_architecture(model_type):
8788 "deepseek_v3" : "DeepseekMTPForCausalLM" , # Used for V3 and R1
8889 "deepseek_v32" : "DeepseekV32MtpForCausalLM" , # Used for V3.2
8990 "glm4_moe" : "Glm4MoeMtpForCausalLM" ,
91+ "glm_moe_dsa" : "GlmMoeDsaMtpForCausalLM" ,
9092 }
9193 return mapping .get (model_type , "MtpForCausalLM" )
9294
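Both lookup tables follow the same naming convention, and each falls back to a generic default for model types that are not listed. A quick sketch of the expected behavior:

assert get_mtp_model_type("glm_moe_dsa") == "glm_moe_dsa_mtp"
assert get_mtp_architecture("glm_moe_dsa") == "GlmMoeDsaMtpForCausalLM"

# Unlisted model types fall through to the generic defaults:
assert get_mtp_model_type("some_future_moe") == "some_future_moe_mtp"
assert get_mtp_architecture("some_future_moe") == "MtpForCausalLM"
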
@@ -105,11 +107,8 @@ def update_and_save_config(config, output_dir, model_type):
105107 "quantization_config" : "" ,
106108 }
107109
108- # Model-specific updates
109- if model_type == "deepseek_v3" : # Used for V3 and R1
110- updates ["first_k_dense_replace" ] = 0
111- elif model_type == "deepseek_v32" : # Used for V3.2
112- updates ["first_k_dense_replace" ] = 0
110+ # Keep consistent with MTP exported config requirements.
111+ updates ["first_k_dense_replace" ] = 0
113112
114113 new_config .update (updates )
115114
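Forcing first_k_dense_replace to 0 now applies uniformly instead of per model type. In DeepSeek-style MoE configs this field counts how many leading layers are dense rather than MoE; because the exported MTP layer is remapped to index 0, a nonzero value would make a loader build it as a dense layer. A simplified sketch of that decision (not the actual loader code, and ignoring moe_layer_freq):

def is_moe_layer(layer_idx: int, first_k_dense_replace: int) -> bool:
    # DeepSeek-style rule: layers before first_k_dense_replace are dense.
    return layer_idx >= first_k_dense_replace

assert is_moe_layer(0, first_k_dense_replace=0)      # exported MTP layer stays MoE
assert not is_moe_layer(0, first_k_dense_replace=1)  # would wrongly become dense
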
@@ -120,7 +119,11 @@ def update_and_save_config(config, output_dir, model_type):
 def copy_non_safetensors_files(input_dir, output_dir):
     for filename in os.listdir(input_dir):
         src_file_path = os.path.join(input_dir, filename)
-        if os.path.isfile(src_file_path) and not filename.endswith(".safetensors"):
+        if (
+            os.path.isfile(src_file_path)
+            and not filename.endswith(".safetensors")
+            and not filename.endswith(".safetensors.index.json")
+        ):
             dst_file_path = os.path.join(output_dir, filename)
             shutil.copy2(src_file_path, dst_file_path)
     print(f"All non-safetensors files have been copied to {output_dir}")
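
The second suffix check is needed because the sharded-weights index file does not end with ".safetensors", so the old condition copied it and left the export with a stale index pointing at shards that were never copied. A standalone sketch of the filter:

def should_copy(filename: str) -> bool:
    # Skip both the weight shards and the index that describes them.
    return not filename.endswith(".safetensors") and not filename.endswith(".safetensors.index.json")

assert should_copy("config.json")
assert not should_copy("model-00001-of-00163.safetensors")
assert not should_copy("model.safetensors.index.json")
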
@@ -168,15 +171,17 @@ def export_mtp_layer_parameters(input_dir, output_dir, mtp_layer_id, model_type)

         try:
             with safe_open(file_path, framework="pt") as f:
-                matching_keys = [k for k in f.keys() if k.startswith(prefix)]
+                matching_keys = [k for k in f.keys() if (k.startswith(prefix) or k == "rot.weight")]

                 if not matching_keys:
                     print(f"  No parameters starting with '{prefix}' found")
                     continue

                 for key in matching_keys:
                     # Handle special keys that should be at model level
-                    if any(special in key for special in ["embed_tokens", "shared_head", "enorm", "hnorm", "eh_proj"]):
+                    if key == "rot.weight":
+                        new_key = "model.rot.weight"
+                    elif any(special in key for special in ["embed_tokens", "shared_head", "enorm", "hnorm", "eh_proj"]):
                         new_key = key.replace(prefix, "model")
                     else:
                         # Map to layer 0 for MTP model
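
The remapping keeps shared tensors under model. and moves the MTP layer itself to layer 0; rot.weight is stored at the top level of the glm_moe_dsa checkpoint with no layer prefix, so it is matched by exact name. A sketch of the full mapping, assuming prefix is the "model.layers.{mtp_layer_id}" string built earlier in the function and that the truncated else branch does the layer-0 replace (both assumptions, since neither appears in this hunk):

prefix = "model.layers.92"  # hypothetical mtp_layer_id of 92

def remap(key: str) -> str:
    if key == "rot.weight":
        return "model.rot.weight"                 # top-level tensor, matched by exact name
    if any(s in key for s in ["embed_tokens", "shared_head", "enorm", "hnorm", "eh_proj"]):
        return key.replace(prefix, "model")       # shared tensors move to model level
    return key.replace(prefix, "model.layers.0")  # the MTP layer becomes layer 0

assert remap("rot.weight") == "model.rot.weight"
assert remap("model.layers.92.enorm.weight") == "model.enorm.weight"
assert remap("model.layers.92.self_attn.q_proj.weight") == "model.layers.0.self_attn.q_proj.weight"
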
@@ -238,7 +243,7 @@ def export_mtp_layer_parameters(input_dir, output_dir, mtp_layer_id, model_type)
238243 "--model-type" ,
239244 type = str ,
240245 default = None ,
241- help = "Model type (deepseek_v3, deepseek_v32, glm4_moe). If not specified, will auto-detect. Note: DeepSeek V3 and R1 use 'deepseek_v3', V3.2 uses 'deepseek_v32'." ,
246+ help = "Model type (deepseek_v3, deepseek_v32, glm4_moe, glm_moe_dsa ). If not specified, will auto-detect. Note: DeepSeek V3 and R1 use 'deepseek_v3', V3.2 uses 'deepseek_v32'." ,
242247 )
243248 args = parser .parse_args ()
244249