Commit 20f5340 (1 parent: 5e43b2a)

    detect MTP, copy the original mtp.safetensors, update the index file

    Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>

2 files changed: 116 additions & 55 deletions
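
For context, a minimal sketch of the calibrate-then-export flow this commit supports. Only load_mtp_weights_if_needed and export_hf_checkpoint are touched by the diff below; quantize_model and the variable names are hypothetical stand-ins for the surrounding llm_ptq example pipeline:

    from transformers import AutoModelForCausalLM

    from example_utils import load_mtp_weights_if_needed  # examples/llm_ptq script
    from modelopt.torch.export.unified_export_hf import export_hf_checkpoint

    model = AutoModelForCausalLM.from_pretrained(model_path)  # MTP modules are not instantiated

    # Detect tensors referenced by the index but absent from the model's state_dict;
    # this also stashes the file info on model._mtp_files_info for export.
    mtp_prefixes = load_mtp_weights_if_needed(model, model_path)  # e.g. ["model.layers.92"]

    # quantize_model is a hypothetical placeholder for the PTQ entry point;
    # the detected prefixes should be excluded from quantization.
    model = quantize_model(model, exclude_modules=mtp_prefixes)

    # Export copies mtp.safetensors verbatim and merges its entries back into
    # model.safetensors.index.json, since the state_dict never held those tensors.
    export_hf_checkpoint(model, export_dir=export_dir)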

examples/llm_ptq/example_utils.py
Lines changed: 54 additions & 55 deletions
@@ -28,7 +28,6 @@
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from accelerate.utils import get_max_memory
-from safetensors.torch import load_file
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -316,32 +315,36 @@ def get_processor(
     return None
 
 
-def load_mtp_weights(
-    model: torch.nn.Module, model_path: str
-) -> tuple[list[str], dict[str, torch.Tensor]]:
-    """Load MTP weights from the model checkpoint.
+def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[str]:
+    """Detect MTP weights in separate safetensors files (e.g., GLM-4.7).
 
-    Some models store additional layers in separate safetensors files with non-standard
-    names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
-    files even though they're referenced in model.safetensors.index.json.
+    Some models store MTP (Multi-Token Prediction) layers in separate safetensors files
+    (e.g., mtp.safetensors) that are referenced in model.safetensors.index.json but
+    not loaded by HuggingFace transformers (because the model architecture doesn't
+    include these layers).
 
-    This function detects such cases and explicitly loads the missing weights.
+    This function:
+    1. Detects non-standard safetensors files with weights not in the model
+    2. Stores info about these files on the model for later export (model._mtp_files_info)
+    3. Returns the layer prefixes (e.g., ["model.layers.92"]) for quantization exclusion
+
+    Note: The weights are NOT loaded into the model (since the model architecture doesn't
+    support them), but we track them so they can be copied during export.
 
     Args:
-        model: The loaded model that may be missing weights
+        model: The loaded model
         model_path: Path to the model directory
 
     Returns:
-        List of layer prefixes that were loaded from non-standard safetensors files.
+        List of layer prefixes that contain MTP weights (e.g., ["model.layers.92"]).
         These layers should typically be excluded from quantization.
-        Empty list if no additional weights were loaded.
-        Dictionary of MTP weights that were not loaded into the model state dict.
+        Empty list if no MTP weights were found.
     """
     model_path = Path(model_path)
     index_file = model_path / "model.safetensors.index.json"
 
     if not index_file.exists():
-        return [], {}
+        return []
 
     # Load the index to find all referenced safetensors files
     index = json.load(open(index_file))
@@ -353,58 +356,54 @@ def load_mtp_weights(
         mtp_weight_map.setdefault(v, []).append(k)
 
     if not mtp_weight_map:
-        return [], {}
+        return []
 
-    def _extract_layer_prefixes(keys):
-        mtp_layer_prefixes = set()
-        for key in keys:
-            parts = key.split(".")
-            for i, part in enumerate(parts):
-                if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
-                    prefix = ".".join(parts[: i + 2])
-                    mtp_layer_prefixes.add(prefix)
-                    break
-
-        return mtp_layer_prefixes
-
-    # Flatten mtp_weight_map.values() (list of list of str) to a single list of str
-    mtp_keys = [k for keys in mtp_weight_map.values() for k in keys]
-    mtp_layer_prefixes = _extract_layer_prefixes(mtp_keys)
-
-    # Check which non-standard files exist and have missing weights
+    # Check which non-standard files exist and have weights not in the model
     model_state = model.state_dict()
-    total_loaded = 0
-
-    not_in_state_dict = {}
+    mtp_files_info = []  # Store info for export: [{source_path, filename, weight_map}]
+    mtp_layer_prefixes = []
 
-    for filename, mtp_keys in mtp_weight_map.items():
+    for filename in mtp_weight_map:
         filepath = model_path / filename
         if not filepath.exists():
             continue
 
-        print(f"Loading {len(mtp_keys)} mtp weights from {filename}...")
-        weights = load_file(str(filepath), device="cpu")
-        weights = {k: v for k, v in weights.items() if k in mtp_keys}
-        # Load the MTP weights to the model state dict
-        in_state_dict = {k: weights[k] for k in weights if k in model_state}
-        not_in_state_dict = not_in_state_dict | {
-            k: weights[k] for k in weights if k not in model_state
-        }
-
-        if in_state_dict:
-            model.load_state_dict(in_state_dict, strict=False)
-            total_loaded += len(in_state_dict)
-
-    if total_loaded > 0:
-        print(
-            f"✓ Successfully loaded {total_loaded} MTP weights, "
-            f"{len(not_in_state_dict)} MTP weights not in model.state_dict"
-        )
+        # Find keys that should be in this file
+        expected_keys = [k for k, v in index["weight_map"].items() if v == filename]
+
+        # Check which are missing from the model (i.e., model doesn't have these modules)
+        missing_keys = [k for k in expected_keys if k not in model_state]
+
+        # Extract layer prefixes from all expected keys
+        for key in expected_keys:
+            parts = key.split(".")
+            for i, part in enumerate(parts):
+                if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
+                    prefix = ".".join(parts[: i + 2])  # e.g., "model.layers.92"
+                    if prefix not in mtp_layer_prefixes:
+                        mtp_layer_prefixes.append(prefix)
+                    break
+
+        # If there are missing keys, the model architecture doesn't support these weights
+        # Store info for copying during export
+        if missing_keys:
+            file_weight_map = dict.fromkeys(expected_keys, filename)
+            mtp_files_info.append({
+                "source_path": str(filepath),
+                "filename": filename,
+                "weight_map": file_weight_map,
+            })
+            print(f"Found {len(expected_keys)} MTP weights in {filename} (will copy during export)")
+
+    # Store MTP file info on the model for use during export
+    if mtp_files_info:
+        model._mtp_files_info = mtp_files_info
+        print(f"✓ Stored {len(mtp_files_info)} MTP file(s) info for export")
 
     if mtp_layer_prefixes:
         print(f"✓ Detected MTP layers to exclude from quantization: {mtp_layer_prefixes}")
 
-    return list(mtp_layer_prefixes), not_in_state_dict
+    return mtp_layer_prefixes
 
 
 def get_dtype(dtype):
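
To make the detection concrete, here is what the new function would report for a hypothetical GLM-4.7-style checkpoint. The tensor names and paths are illustrative, not taken from a real index file:

    # model.safetensors.index.json (illustrative excerpt):
    #   "weight_map": {
    #     "model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
    #     "model.layers.92.shared_head.norm.weight": "mtp.safetensors",
    #     "model.layers.92.mlp.gate.weight": "mtp.safetensors"
    #   }

    prefixes = load_mtp_weights_if_needed(model, "/ckpts/glm-4.7")

    print(prefixes)
    # ["model.layers.92"] -- the layer whose keys live in the non-standard file

    print(model._mtp_files_info)
    # [{"source_path": "/ckpts/glm-4.7/mtp.safetensors",
    #   "filename": "mtp.safetensors",
    #   "weight_map": {"model.layers.92.shared_head.norm.weight": "mtp.safetensors", ...}}]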

modelopt/torch/export/unified_export_hf.py
Lines changed: 62 additions & 0 deletions
@@ -18,6 +18,7 @@
 import collections.abc
 import json
 import re
+import shutil
 import tempfile
 import warnings
 from builtins import ValueError
@@ -954,6 +955,64 @@ def _export_diffusers_checkpoint(
     print(f"Export complete. Saved to: {export_dir}")
 
 
+def _copy_mtp_files_if_needed(model: nn.Module, export_dir: Path) -> None:
+    """Copy MTP (Multi-Token Prediction) safetensors files if they exist.
+
+    Some models like GLM-4.7 have MTP layers stored in separate safetensors files
+    (e.g., mtp.safetensors) that aren't part of the model's state_dict because
+    HuggingFace Transformers doesn't create the corresponding modules.
+
+    This function copies those files to the export directory and updates the
+    model.safetensors.index.json to include the MTP weights.
+
+    Args:
+        model: The model being exported (may have _mtp_files_info attribute)
+        export_dir: The export directory path
+    """
+    mtp_files_info = getattr(model, "_mtp_files_info", None)
+    if not mtp_files_info:
+        return
+
+    export_dir = Path(export_dir)
+    index_file = export_dir / "model.safetensors.index.json"
+
+    # Load existing index if present
+    if index_file.exists():
+        with open(index_file) as f:
+            index_data = json.load(f)
+    else:
+        # Create a basic index structure if it doesn't exist
+        index_data = {"metadata": {}, "weight_map": {}}
+
+    # Copy each MTP file and update the index
+    for mtp_info in mtp_files_info:
+        source_path = Path(mtp_info["source_path"])
+        filename = mtp_info["filename"]
+        weight_map = mtp_info["weight_map"]
+
+        if not source_path.exists():
+            print(f"Warning: MTP source file not found: {source_path}")
+            continue
+
+        dest_path = export_dir / filename
+
+        # Copy the file
+        print(f"Copying MTP file: {filename}")
+        shutil.copy2(source_path, dest_path)
+
+        # Update the weight map in the index
+        for weight_name, file_name in weight_map.items():
+            index_data["weight_map"][weight_name] = file_name
+
+        print(f"✓ Copied {filename} with {len(weight_map)} weights")
+
+    # Write updated index
+    with open(index_file, "w") as f:
+        json.dump(index_data, f, indent=2)
+
+    print("✓ Updated model.safetensors.index.json with MTP weights")
+
+
 def export_hf_checkpoint(
     model: Any,
     dtype: torch.dtype | None = None,
@@ -1019,6 +1078,9 @@ def export_hf_checkpoint(
         save_modelopt_state=save_modelopt_state,
     )
 
+    # Copy MTP files if present (e.g., GLM-4.7 mtp.safetensors)
+    _copy_mtp_files_if_needed(model, export_dir)
+
     original_config = f"{export_dir}/config.json"
     config_data = {}
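
Under the same illustrative assumptions, the exported checkpoint ends up with the quantized shards written by the normal export path plus the untouched MTP file, and the rewritten index resolves MTP tensors to it:

    import json

    # export_dir/
    #   config.json
    #   model-00001-of-00002.safetensors   (quantized weights from the state_dict)
    #   mtp.safetensors                    (copied unchanged by _copy_mtp_files_if_needed)
    #   model.safetensors.index.json       (weight_map merged with the MTP entries)

    index = json.load(open(f"{export_dir}/model.safetensors.index.json"))
    assert index["weight_map"]["model.layers.92.mlp.gate.weight"] == "mtp.safetensors"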
