Merged
112 changes: 48 additions & 64 deletions examples/llm_ptq/example_utils.py
@@ -18,7 +18,6 @@
import inspect
import json
import os
import re
import shutil
import sys
import warnings
@@ -317,8 +316,10 @@ def get_processor(
return None


def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[str]:
"""Load MTP weights from separate safetensors if needed (e.g., GLM-4.7).
def load_mtp_weights(
model: torch.nn.Module, model_path: str
) -> tuple[list[str], dict[str, torch.Tensor]]:
"""Load MTP weights from the model checkpoint.

Some models store additional layers in separate safetensors files with non-standard
names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
@@ -334,87 +335,76 @@ def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[
List of layer prefixes that were loaded from non-standard safetensors files.
These layers should typically be excluded from quantization.
Empty list if no additional weights were loaded.
Dictionary of MTP weights that were not loaded into the model state dict.
"""
model_path = Path(model_path)
index_file = model_path / "model.safetensors.index.json"
mtp_layer_prefixes: list[str] = []

if not index_file.exists():
return mtp_layer_prefixes
return [], {}

# Load the index to find all referenced safetensors files
with open(index_file) as f:
index = json.load(f)

# Find all unique safetensors files referenced
all_files = set(index["weight_map"].values())

# Find non-standard shard files (not matching model-XXXXX-of-XXXXX.safetensors pattern)
standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
non_standard_files = [f for f in all_files if not standard_pattern.match(f)]
index = json.load(open(index_file))
weight_map = index["weight_map"]
# Find all files in weight_map whose key or value contains "mtp"
mtp_weight_map = {}
for k, v in weight_map.items():
if "mtp" in k or "mtp" in v:
mtp_weight_map.setdefault(v, []).append(k)

if not mtp_weight_map:
return [], {}

def _extract_layer_prefixes(keys):
mtp_layer_prefixes = set()
for key in keys:
parts = key.split(".")
for i, part in enumerate(parts):
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
prefix = ".".join(parts[: i + 2])
mtp_layer_prefixes.add(prefix)
break

if not non_standard_files:
return mtp_layer_prefixes

# Flatten mtp_weight_map.values() (list of list of str) to a single list of str
mtp_keys = [k for keys in mtp_weight_map.values() for k in keys]
mtp_layer_prefixes = _extract_layer_prefixes(mtp_keys)

# Check which non-standard files exist and have missing weights
model_state = model.state_dict()
total_loaded = 0

for filename in non_standard_files:
not_in_state_dict = {}

for filename, mtp_keys in mtp_weight_map.items():
filepath = model_path / filename
if not filepath.exists():
continue

# Find keys that should be in this file
expected_keys = [k for k, v in index["weight_map"].items() if v == filename]

# Check which are missing from the model
missing_keys = [k for k in expected_keys if k not in model_state]

if not missing_keys:
# Even if weights are loaded, record the layer prefixes for exclusion
# Extract unique layer prefixes (e.g., "model.layers.92" from "model.layers.92.mlp.weight")
for key in expected_keys:
# Extract layer prefix like "model.layers.92" or "layers.92"
parts = key.split(".")
for i, part in enumerate(parts):
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
Contributor: Did you also test GLM-4.7 with the changes?

Collaborator (author): Yes.
prefix = ".".join(parts[: i + 2]) # e.g., "model.layers.92"
if prefix not in mtp_layer_prefixes:
mtp_layer_prefixes.append(prefix)
break
continue

print(f"Loading {len(missing_keys)} missing weights from {filename}...")

# Extract unique layer prefixes for exclusion from quantization
for key in missing_keys:
parts = key.split(".")
for i, part in enumerate(parts):
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
prefix = ".".join(parts[: i + 2]) # e.g., "model.layers.92"
if prefix not in mtp_layer_prefixes:
mtp_layer_prefixes.append(prefix)
break

# Load the weights to CPU first, load_state_dict will handle device placement
print(f"Loading {len(mtp_keys)} mtp weights from {filename}...")
weights = load_file(str(filepath), device="cpu")
weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}

# Load into model
missing, unexpected = model.load_state_dict(weights_to_load, strict=False)
total_loaded += len(weights_to_load)
weights = {k: v for k, v in weights.items() if k in mtp_keys}
# Load the MTP weights to the model state dict
in_state_dict = {k: weights[k] for k in weights if k in model_state}
not_in_state_dict = not_in_state_dict | {
k: weights[k] for k in weights if k not in model_state
}

if missing:
print(f" Warning: {len(missing)} keys still missing after loading {filename}")
if in_state_dict:
model.load_state_dict(in_state_dict, strict=False)
total_loaded += len(in_state_dict)

if total_loaded > 0:
print(f"✓ Successfully loaded {total_loaded} weights from non-standard safetensors files")
print(
f"✓ Successfully loaded {total_loaded} MTP weights, "
f"{len(not_in_state_dict)} MTP weights not in model.state_dict"
)

if mtp_layer_prefixes:
print(f"✓ Detected MTP layers to exclude from quantization: {mtp_layer_prefixes}")

return mtp_layer_prefixes
return list(mtp_layer_prefixes), not_in_state_dict
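
As a side note, here is a minimal, self-contained sketch of the key-grouping and prefix-extraction logic in the hunk above. The `weight_map` contents are hypothetical sample data, not taken from a real checkpoint index:

```python
# Sketch of the mtp_weight_map grouping and layer-prefix extraction.
# The sample weight_map below is made up; real index files map
# parameter names to shard filenames.
weight_map = {
    "model.layers.91.mlp.weight": "model-00009-of-00009.safetensors",
    "model.layers.92.self_attn.q_proj.weight": "mtp.safetensors",
    "model.layers.92.mlp.gate.weight": "mtp.safetensors",
}

# Group keys whose name or file mentions "mtp" by shard file.
mtp_weight_map: dict[str, list[str]] = {}
for key, filename in weight_map.items():
    if "mtp" in key or "mtp" in filename:
        mtp_weight_map.setdefault(filename, []).append(key)

# Reduce each key to its layer prefix, e.g. "model.layers.92", so the
# whole layer can later be excluded from quantization.
prefixes = set()
for keys in mtp_weight_map.values():
    for key in keys:
        parts = key.split(".")
        for i, part in enumerate(parts):
            if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
                prefixes.add(".".join(parts[: i + 2]))
                break

print(mtp_weight_map)  # {'mtp.safetensors': ['model.layers.92...', ...]}
print(prefixes)        # {'model.layers.92'}
```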


def get_dtype(dtype):
@@ -576,12 +566,6 @@ def get_model(
if device == "cuda" and not is_model_on_gpu(model):
print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")

# Load any missing weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors)
# Store the MTP layer prefixes on the model for later exclusion from quantization
mtp_layer_prefixes = load_mtp_weights_if_needed(model, ckpt_path)
if mtp_layer_prefixes:
model._mtp_layer_prefixes = mtp_layer_prefixes

return model


16 changes: 9 additions & 7 deletions examples/llm_ptq/hf_ptq.py
@@ -31,7 +31,7 @@
get_tokenizer,
is_enc_dec,
is_nemotron_vl,
load_mtp_weights_if_needed,
load_mtp_weights,
run_nemotron_vl_preview,
)
from torch.utils.data import DataLoader
@@ -349,12 +349,6 @@ def load_model(args: argparse.Namespace):
)
calibration_only = True

# Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
# Store the MTP layer prefixes on the model for later exclusion from quantization
mtp_layer_prefixes = load_mtp_weights_if_needed(full_model, args.pyt_ckpt_path)
if mtp_layer_prefixes:
full_model._mtp_layer_prefixes = mtp_layer_prefixes

model_type = get_model_type(full_model)

device = full_model.device
@@ -632,9 +626,17 @@ def export_quantized(
"They will be set at deployment time."
)

# Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
# Store the MTP layer prefixes on the model for later exclusion from quantization
mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(full_model, args.pyt_ckpt_path)

if mtp_layer_prefixes:
full_model._mtp_layer_prefixes = mtp_layer_prefixes

export_hf_checkpoint(
full_model,
export_dir=export_path,
extra_state_dict=mtp_state_dict,
)
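
A hypothetical illustration of what the stored `_mtp_layer_prefixes` enable downstream: modules under an MTP layer can be filtered out of quantization with a plain prefix match. `is_mtp_module` is an invented helper for illustration, not part of this PR:

```python
# Hypothetical helper: decide whether a module name falls under one of
# the recorded MTP layer prefixes (e.g. "model.layers.92").
def is_mtp_module(name: str, prefixes: list[str]) -> bool:
    return any(name == p or name.startswith(p + ".") for p in prefixes)

prefixes = ["model.layers.92"]
print(is_mtp_module("model.layers.92.mlp.gate_proj", prefixes))  # True
print(is_mtp_module("model.layers.9.mlp.gate_proj", prefixes))   # False: no partial match on "92"
```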

# Copy custom model files (Python files and JSON configs) if trust_remote_code is used
6 changes: 5 additions & 1 deletion modelopt/torch/export/unified_export_hf.py
@@ -960,6 +960,7 @@ def export_hf_checkpoint(
export_dir: Path | str = tempfile.gettempdir(),
save_modelopt_state: bool = False,
components: list[str] | None = None,
extra_state_dict: dict[str, torch.Tensor] | None = None,
):
"""Export quantized HuggingFace model checkpoint (transformers or diffusers).

@@ -976,6 +977,7 @@
save_modelopt_state: Whether to save the modelopt state_dict.
components: Only used for diffusers pipelines. Optional list of component names
to export. If None, all quantized components are exported.
extra_state_dict: Extra state dictionary to add to the exported model.
"""
export_dir = Path(export_dir)
export_dir.mkdir(parents=True, exist_ok=True)
@@ -1012,7 +1014,9 @@

# Save model
model.save_pretrained(
export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
export_dir,
state_dict={**post_state_dict, **(extra_state_dict or {})},
save_modelopt_state=save_modelopt_state,
)
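
Since the merge is plain dict unpacking, keys from `extra_state_dict` take precedence on collision. A small sketch of the semantics, using toy string values in place of tensors:

```python
# Dict unpacking merges left to right, so a same-named key in
# extra_state_dict overrides the entry from post_state_dict.
post_state_dict = {"shared.weight": "quantized", "lm_head.weight": "quantized"}
extra_state_dict = {"shared.weight": "mtp-override", "model.layers.92.mlp.weight": "mtp"}
merged = {**post_state_dict, **(extra_state_dict or {})}
print(merged["shared.weight"])  # 'mtp-override'
print(len(merged))              # 3
```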

original_config = f"{export_dir}/config.json"