2828import transformers
2929from accelerate import infer_auto_device_map , init_empty_weights
3030from accelerate .utils import get_max_memory
31- from safetensors .torch import load_file
3231from transformers import (
3332 AutoConfig ,
3433 AutoModelForCausalLM ,
@@ -316,32 +315,36 @@ def get_processor(
316315 return None
317316
318317
319- def load_mtp_weights (
320- model : torch .nn .Module , model_path : str
321- ) -> tuple [list [str ], dict [str , torch .Tensor ]]:
322- """Load MTP weights from the model checkpoint.
318+ def load_mtp_weights_if_needed (model : torch .nn .Module , model_path : str ) -> list [str ]:
319+ """Detect MTP weights in separate safetensors files (e.g., GLM-4.7).
323320
324- Some models store additional layers in separate safetensors files with non-standard
325- names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
326- files even though they're referenced in model.safetensors.index.json.
321+ Some models store MTP (Multi-Token Prediction) layers in separate safetensors files
322+ (e.g., mtp.safetensors) that are referenced in model.safetensors.index.json but
323+ not loaded by HuggingFace transformers (because the model architecture doesn't
324+ include these layers).
327325
328- This function detects such cases and explicitly loads the missing weights.
326+ This function:
327+ 1. Detects non-standard safetensors files with weights not in the model
328+ 2. Stores info about these files on the model for later export (model._mtp_files_info)
329+ 3. Returns the layer prefixes (e.g., ["model.layers.92"]) for quantization exclusion
330+
331+ Note: The weights are NOT loaded into the model (since the model architecture doesn't
332+ support them), but we track them so they can be copied during export.
329333
330334 Args:
331- model: The loaded model that may be missing weights
335+ model: The loaded model
332336 model_path: Path to the model directory
333337
334338 Returns:
335- List of layer prefixes that were loaded from non-standard safetensors files .
339+ List of layer prefixes that contain MTP weights (e.g., ["model.layers.92"]) .
336340 These layers should typically be excluded from quantization.
337- Empty list if no additional weights were loaded.
338- Dictionary of MTP weights that were not loaded into the model state dict.
341+ Empty list if no MTP weights were found.
339342 """
340343 model_path = Path (model_path )
341344 index_file = model_path / "model.safetensors.index.json"
342345
343346 if not index_file .exists ():
344- return [], {}
347+ return []
345348
346349 # Load the index to find all referenced safetensors files
347350 index = json .load (open (index_file ))
@@ -353,58 +356,54 @@ def load_mtp_weights(
353356 mtp_weight_map .setdefault (v , []).append (k )
354357
355358 if not mtp_weight_map :
356- return [], {}
359+ return []
357360
358- def _extract_layer_prefixes (keys ):
359- mtp_layer_prefixes = set ()
360- for key in keys :
361- parts = key .split ("." )
362- for i , part in enumerate (parts ):
363- if part == "layers" and i + 1 < len (parts ) and parts [i + 1 ].isdigit ():
364- prefix = "." .join (parts [: i + 2 ])
365- mtp_layer_prefixes .add (prefix )
366- break
367-
368- return mtp_layer_prefixes
369-
370- # Flatten mtp_weight_map.values() (list of list of str) to a single list of str
371- mtp_keys = [k for keys in mtp_weight_map .values () for k in keys ]
372- mtp_layer_prefixes = _extract_layer_prefixes (mtp_keys )
373-
374- # Check which non-standard files exist and have missing weights
361+ # Check which non-standard files exist and have weights not in the model
375362 model_state = model .state_dict ()
376- total_loaded = 0
377-
378- not_in_state_dict = {}
363+ mtp_files_info = [] # Store info for export: [{source_path, filename, weight_map}]
364+ mtp_layer_prefixes = []
379365
380- for filename , mtp_keys in mtp_weight_map . items () :
366+ for filename in mtp_weight_map :
381367 filepath = model_path / filename
382368 if not filepath .exists ():
383369 continue
384370
385- print (f"Loading { len (mtp_keys )} mtp weights from { filename } ..." )
386- weights = load_file (str (filepath ), device = "cpu" )
387- weights = {k : v for k , v in weights .items () if k in mtp_keys }
388- # Load the MTP weights to the model state dict
389- in_state_dict = {k : weights [k ] for k in weights if k in model_state }
390- not_in_state_dict = not_in_state_dict | {
391- k : weights [k ] for k in weights if k not in model_state
392- }
393-
394- if in_state_dict :
395- model .load_state_dict (in_state_dict , strict = False )
396- total_loaded += len (in_state_dict )
397-
398- if total_loaded > 0 :
399- print (
400- f"✓ Successfully loaded { total_loaded } MTP weights, "
401- f"{ len (not_in_state_dict )} MTP weights not in model.state_dict"
402- )
371+ # Find keys that should be in this file
372+ expected_keys = [k for k , v in index ["weight_map" ].items () if v == filename ]
373+
374+ # Check which are missing from the model (i.e., model doesn't have these modules)
375+ missing_keys = [k for k in expected_keys if k not in model_state ]
376+
377+ # Extract layer prefixes from all expected keys
378+ for key in expected_keys :
379+ parts = key .split ("." )
380+ for i , part in enumerate (parts ):
381+ if part == "layers" and i + 1 < len (parts ) and parts [i + 1 ].isdigit ():
382+ prefix = "." .join (parts [: i + 2 ]) # e.g., "model.layers.92"
383+ if prefix not in mtp_layer_prefixes :
384+ mtp_layer_prefixes .append (prefix )
385+ break
386+
387+ # If there are missing keys, the model architecture doesn't support these weights
388+ # Store info for copying during export
389+ if missing_keys :
390+ file_weight_map = dict .fromkeys (expected_keys , filename )
391+ mtp_files_info .append ({
392+ "source_path" : str (filepath ),
393+ "filename" : filename ,
394+ "weight_map" : file_weight_map ,
395+ })
396+ print (f"Found { len (expected_keys )} MTP weights in { filename } (will copy during export)" )
397+
398+ # Store MTP file info on the model for use during export
399+ if mtp_files_info :
400+ model ._mtp_files_info = mtp_files_info
401+ print (f"✓ Stored { len (mtp_files_info )} MTP file(s) info for export" )
403402
404403 if mtp_layer_prefixes :
405404 print (f"✓ Detected MTP layers to exclude from quantization: { mtp_layer_prefixes } " )
406405
407- return list ( mtp_layer_prefixes ), not_in_state_dict
406+ return mtp_layer_prefixes
408407
409408
410409def get_dtype (dtype ):