@@ -316,10 +316,10 @@ def get_processor(
     return None


-def load_mtp_weights_if_needed(
+def load_mtp_weights(
     model: torch.nn.Module, model_path: str
 ) -> tuple[list[str], dict[str, torch.Tensor]]:
-    """Load MTP weights from separate safetensors if needed (e.g., GLM-4.7).
+    """Load MTP weights from the model checkpoint.

     Some models store additional layers in separate safetensors files with non-standard
     names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
@@ -335,6 +335,7 @@ def load_mtp_weights_if_needed(
         List of layer prefixes that were loaded from non-standard safetensors files.
         These layers should typically be excluded from quantization.
         Empty list if no additional weights were loaded.
+        Dictionary of MTP weights that were not loaded into the model state dict.
     """
     model_path = Path(model_path)
     index_file = model_path / "model.safetensors.index.json"
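
For readers skimming the diff, a minimal sketch of what a loader like this can do: consult the HF index to find safetensors files that `from_pretrained()` ignores, load whatever tensors match existing model parameters, and return the rest. The helper name `load_extra_safetensors`, the prefix derivation, and the `strict=False` strategy are assumptions for illustration, not this PR's exact implementation:

```python
# Hypothetical sketch, not this PR's exact code: load tensors from
# safetensors files that model.safetensors.index.json does not reference
# (e.g., mtp.safetensors), and report what could not be loaded.
import json
from pathlib import Path

import torch
from safetensors.torch import load_file


def load_extra_safetensors(
    model: torch.nn.Module, model_path: str
) -> tuple[list[str], dict[str, torch.Tensor]]:
    root = Path(model_path)
    index_file = root / "model.safetensors.index.json"

    # Shards listed in the index were already loaded by from_pretrained().
    indexed: set[str] = set()
    if index_file.exists():
        with open(index_file) as f:
            indexed = set(json.load(f)["weight_map"].values())

    model_keys = set(model.state_dict().keys())
    prefixes: set[str] = set()
    leftover: dict[str, torch.Tensor] = {}

    for st_file in root.glob("*.safetensors"):
        if st_file.name in indexed:
            continue  # standard shard, already handled
        tensors = load_file(st_file)
        # Load tensors the model has parameters for; keep the rest
        # (e.g., MTP weights absent from the architecture) for re-export.
        loadable = {k: v for k, v in tensors.items() if k in model_keys}
        leftover.update({k: v for k, v in tensors.items() if k not in model_keys})
        model.load_state_dict(loadable, strict=False)
        # Module prefix = key minus the trailing parameter name; a
        # simplification of whatever grouping the real code uses.
        prefixes.update(k.rsplit(".", 1)[0] for k in tensors)

    return sorted(prefixes), leftover
```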
@@ -565,14 +566,6 @@ def get_model(
     if device == "cuda" and not is_model_on_gpu(model):
         print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")

-    # Load any missing weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors)
-    # Store the MTP layer prefixes on the model for later exclusion from quantization
-    mtp_layer_prefixes, mtp_state_dict = load_mtp_weights_if_needed(model, ckpt_path)
-    if mtp_layer_prefixes:
-        model._mtp_layer_prefixes = mtp_layer_prefixes
-    if mtp_state_dict:
-        model._mtp_state_dict = mtp_state_dict
-
     return model

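Since this hunk removes the automatic call from `get_model()`, a pipeline that still wants the MTP weights attached would now do so explicitly. A hedged usage sketch: the `get_model()` arguments and whether callers should re-attach the `_mtp_*` attributes are assumptions, with the attribute names simply mirroring the removed lines above:

```python
# Hypothetical call site after this change: MTP loading is explicit
# rather than a side effect of get_model(). Not the PR's actual follow-up.
model = get_model(ckpt_path, device="cuda")
mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(model, ckpt_path)
if mtp_layer_prefixes:
    # Layers loaded from non-standard files; typically excluded from quantization.
    model._mtp_layer_prefixes = mtp_layer_prefixes
if mtp_state_dict:
    # Tensors with no matching model parameter, kept for later re-export.
    model._mtp_state_dict = mtp_state_dict
```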