# GLM-4.7 MTP support (#792)

Commit 452c5a0
## What does this PR do?

**Type of change:** New feature

**Overview:** Enable the GLM-4.7 PTQ workflow, including loading the standalone MTP modules and exporting them as-is.

## Usage

```bash
python3 hf_ptq.py --pyt_ckpt_path /home/omniml_data_3/models/GLM-4.7 --qformat nvfp4_mlp_only --export_path /home/omniml_data_3/zhiyuc/checkpoints/GLM-4.7-NVFP4-0203 --trust_remote_code
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

## Summary by CodeRabbit

* **New Features**
  * Added quantization support for the GLM-4.7 model, with automatic handling of its specialized layer architecture.
  * Added image-text data calibration for Nemotron VL model quantization.
* **Documentation**
  * Updated the support matrix to reflect newly supported models and quantization features.

---

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
1 parent 944dd1a commit 452c5a0

File tree: 5 files changed, +146 −1 lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
```diff
@@ -17,6 +17,8 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add support for calibration data with multiple samples in ``npz`` format in the ONNX Autocast workflow.
 - Add ``--opset`` option to ONNX quantization CLI to specify the target opset version for the quantized model.
 - Add support for context parallelism in Eagle speculative decoding for huggingface and megatron core models.
+- Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is.
+- Add support for image-text data calibration in PTQ for Nemotron VL models.

 0.41 (2026-01-19)
 ^^^^^^^^^^^^^^^^^
```

examples/llm_ptq/README.md

Lines changed: 3 additions & 1 deletion
```diff
@@ -109,6 +109,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
 | QWen3 MOE, Next <sup>6</sup> | ✅ | - | - | - | ✅ |
 | QwQ | ✅ | - | - | - | ✅ |
 | DeepSeek V3, R1, V3.1, V3.2<sup>7</sup> | - | - | - | - | ✅ |
+| GLM-4.7<sup>8</sup> | ✅ | - | - | - | ✅ |
 | Kimi K2 | - | - | - | - | ✅ |
 | T5 | ✅ | ✅ | ✅ | ✅ | - |
 | Whisper | ✅ | ✅ | ✅ | ✅ | - |
@@ -121,7 +122,8 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
 > *<sup>4.</sup>For some models, KV cache quantization may result in a higher accuracy penalty.* \
 > *<sup>5.</sup>A selective set of the popular models are internally tested. The actual model support list may be longer. NVFP4 inference requires Blackwell GPUs and TensorRT-LLM v0.17 or later* \
 > *<sup>6.</sup>Some models currently support export to HF format only.* \
-> *<sup>7.</sup>[PTQ for DeepSeek](../deepseek/README.md)*
+> *<sup>7.</sup>[PTQ for DeepSeek](../deepseek/README.md)* \
+> *<sup>8.</sup>GLM-4.7 has MTP (Multi-Token Prediction) layers that are automatically loaded and excluded from quantization.*

 > *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead.*
```

examples/llm_ptq/example_utils.py

Lines changed: 109 additions & 0 deletions
```diff
@@ -16,7 +16,9 @@
 import copy
 import glob
 import inspect
+import json
 import os
+import re
 import shutil
 import sys
 import warnings
@@ -27,6 +29,7 @@
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from accelerate.utils import get_max_memory
+from safetensors.torch import load_file
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -314,6 +317,106 @@ def get_processor(
     return None


+def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[str]:
+    """Load MTP weights from separate safetensors files if needed (e.g., GLM-4.7).
+
+    Some models store additional layers in separate safetensors files with non-standard
+    names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
+    files even though they are referenced in model.safetensors.index.json.
+
+    This function detects such cases and explicitly loads the missing weights.
+
+    Args:
+        model: The loaded model that may be missing weights.
+        model_path: Path to the model directory.
+
+    Returns:
+        List of layer prefixes that were loaded from non-standard safetensors files.
+        These layers should typically be excluded from quantization.
+        Empty list if no additional weights were loaded.
+    """
+    model_path = Path(model_path)
+    index_file = model_path / "model.safetensors.index.json"
+    mtp_layer_prefixes: list[str] = []
+
+    if not index_file.exists():
+        return mtp_layer_prefixes
+
+    # Load the index to find all referenced safetensors files
+    with open(index_file) as f:
+        index = json.load(f)
+
+    # Find all unique safetensors files referenced
+    all_files = set(index["weight_map"].values())
+
+    # Find non-standard shard files (not matching the model-XXXXX-of-XXXXX.safetensors pattern)
+    standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
+    non_standard_files = [f for f in all_files if not standard_pattern.match(f)]
+
+    if not non_standard_files:
+        return mtp_layer_prefixes
+
+    # Check which non-standard files exist and have missing weights
+    model_state = model.state_dict()
+    total_loaded = 0
+
+    for filename in non_standard_files:
+        filepath = model_path / filename
+        if not filepath.exists():
+            continue
+
+        # Find keys that should be in this file
+        expected_keys = [k for k, v in index["weight_map"].items() if v == filename]
+
+        # Check which are missing from the model
+        missing_keys = [k for k in expected_keys if k not in model_state]
+
+        if not missing_keys:
+            # Even if the weights are loaded, record the layer prefixes for exclusion.
+            # Extract unique layer prefixes (e.g., "model.layers.92" from "model.layers.92.mlp.weight")
+            for key in expected_keys:
+                parts = key.split(".")
+                for i, part in enumerate(parts):
+                    if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
+                        prefix = ".".join(parts[: i + 2])  # e.g., "model.layers.92"
+                        if prefix not in mtp_layer_prefixes:
+                            mtp_layer_prefixes.append(prefix)
+                        break
+            continue
+
+        print(f"Loading {len(missing_keys)} missing weights from {filename}...")
+
+        # Extract unique layer prefixes for exclusion from quantization
+        for key in missing_keys:
+            parts = key.split(".")
+            for i, part in enumerate(parts):
+                if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
+                    prefix = ".".join(parts[: i + 2])  # e.g., "model.layers.92"
+                    if prefix not in mtp_layer_prefixes:
+                        mtp_layer_prefixes.append(prefix)
+                    break
+
+        # Load the weights to CPU first; load_state_dict will handle device placement
+        weights = load_file(str(filepath), device="cpu")
+        weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}
+
+        # Load into the model
+        missing, unexpected = model.load_state_dict(weights_to_load, strict=False)
+        total_loaded += len(weights_to_load)
+
+        if missing:
+            print(f"  Warning: {len(missing)} keys still missing after loading {filename}")
+
+    if total_loaded > 0:
+        print(f"✓ Successfully loaded {total_loaded} weights from non-standard safetensors files")
+
+    if mtp_layer_prefixes:
+        print(f"✓ Detected MTP layers to exclude from quantization: {mtp_layer_prefixes}")
+
+    return mtp_layer_prefixes
+
+
 def get_dtype(dtype):
     if dtype == "bf16":
         dtype = torch.bfloat16
@@ -473,6 +576,12 @@ def get_model(
     if device == "cuda" and not is_model_on_gpu(model):
         print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")

+    # Load any missing weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors).
+    # Store the MTP layer prefixes on the model for later exclusion from quantization.
+    mtp_layer_prefixes = load_mtp_weights_if_needed(model, ckpt_path)
+    if mtp_layer_prefixes:
+        model._mtp_layer_prefixes = mtp_layer_prefixes
+
     return model
```
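The shard-filtering regex and the prefix-extraction loop in `load_mtp_weights_if_needed` can be sketched in isolation. This is a minimal reconstruction for illustration; `extract_layer_prefixes` is a hypothetical name, not a helper defined in the diff:

```python
import re


def extract_layer_prefixes(keys):
    """Collapse weight names like "model.layers.92.mlp.weight" to "model.layers.92",
    deduplicated and in first-seen order, mirroring the loop in the diff above."""
    prefixes = []
    for key in keys:
        parts = key.split(".")
        for i, part in enumerate(parts):
            if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
                prefix = ".".join(parts[: i + 2])  # e.g., "model.layers.92"
                if prefix not in prefixes:
                    prefixes.append(prefix)
                break
    return prefixes


# Non-standard shards are any files in the index not named model-XXXXX-of-XXXXX.safetensors.
standard = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
shards = ["model-00001-of-00046.safetensors", "mtp.safetensors"]
print([f for f in shards if not standard.match(f)])  # ['mtp.safetensors']

keys = [
    "model.layers.92.input_layernorm.weight",
    "model.layers.92.mlp.gate_proj.weight",
    "model.layers.92.self_attn.q_proj.weight",
]
print(extract_layer_prefixes(keys))  # ['model.layers.92']
```

All weights of GLM-4.7's MTP layer thus collapse to the single prefix `model.layers.92`, which downstream code turns into one quantization-exclusion pattern.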

examples/llm_ptq/hf_ptq.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -31,6 +31,7 @@
     get_tokenizer,
     is_enc_dec,
     is_nemotron_vl,
+    load_mtp_weights_if_needed,
     run_nemotron_vl_preview,
 )
 from torch.utils.data import DataLoader
@@ -348,6 +349,12 @@ def load_model(args: argparse.Namespace):
         )
         calibration_only = True

+    # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode).
+    # Store the MTP layer prefixes on the model for later exclusion from quantization.
+    mtp_layer_prefixes = load_mtp_weights_if_needed(full_model, args.pyt_ckpt_path)
+    if mtp_layer_prefixes:
+        full_model._mtp_layer_prefixes = mtp_layer_prefixes
+
     model_type = get_model_type(full_model)

     device = full_model.device
@@ -878,6 +885,19 @@ def quantize_main(
         KV_QUANT_CFG_CHOICES,
     )

+    # Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92).
+    # These layers are typically speculative decoding layers that should be exported as-is.
+    mtp_layer_prefixes = getattr(full_model, "_mtp_layer_prefixes", None)
+    if mtp_layer_prefixes:
+        import copy
+
+        quant_cfg = copy.deepcopy(quant_cfg)
+        for prefix in mtp_layer_prefixes:
+            # Add an exclusion pattern for this MTP layer (e.g., "*layers.92*")
+            pattern = f"*{prefix.split('.')[-2]}.{prefix.split('.')[-1]}*"
+            quant_cfg["quant_cfg"][pattern] = {"enable": False}
+            print(f"Excluding MTP layer from quantization: {pattern}")
+
     if args.qformat in QUANT_CFG_CHOICES:
         mono_quantize(
             args,
```
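The config-patching step above deep-copies the quantization config and adds one disable rule per detected MTP layer. A minimal standalone sketch of that behavior (`exclude_mtp_layers` is a hypothetical helper name; the real code does this inline in `quantize_main`):

```python
import copy


def exclude_mtp_layers(quant_cfg, mtp_layer_prefixes):
    """Return a deep copy of quant_cfg with an {"enable": False} rule per MTP layer."""
    cfg = copy.deepcopy(quant_cfg)
    for prefix in mtp_layer_prefixes:
        # "model.layers.92" -> "*layers.92*", so every submodule of that layer matches
        pattern = f"*{prefix.split('.')[-2]}.{prefix.split('.')[-1]}*"
        cfg["quant_cfg"][pattern] = {"enable": False}
    return cfg


base = {"quant_cfg": {"*weight_quantizer": {"num_bits": 4}}}
patched = exclude_mtp_layers(base, ["model.layers.92"])
print(sorted(patched["quant_cfg"]))       # ['*layers.92*', '*weight_quantizer']
print("*layers.92*" in base["quant_cfg"])  # False: the shared base config is untouched
```

The `deepcopy` matters because the base quantization configs are module-level dicts shared across runs; mutating them in place would leak the exclusion into later invocations.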

modelopt/torch/export/unified_export_hf.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -710,6 +710,18 @@ def _export_transformers_checkpoint(

     quant_config = get_quant_config(model, is_modelopt_qlora=is_modelopt_qlora)

+    # Add MTP layer prefixes to exclude_modules if they were excluded from quantization.
+    # This ensures they appear in quantization_config["ignore"] in config.json.
+    mtp_layer_prefixes = getattr(model, "_mtp_layer_prefixes", None)
+    if mtp_layer_prefixes:
+        exclude_modules = quant_config["quantization"].setdefault("exclude_modules", [])
+        for prefix in mtp_layer_prefixes:
+            # Add a wildcard pattern to exclude all submodules under this MTP layer
+            pattern = f"{prefix}*"
+            if pattern not in exclude_modules:
+                exclude_modules.append(pattern)
+                print(f"Adding MTP layer to quantization_config ignore: {pattern}")
+
     # Process all quantized modules and export weights
     _process_quantized_modules(model, dtype, is_modelopt_qlora)
```
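The export-side change appends `<prefix>*` wildcards to the checkpoint's exclusion list, using `setdefault` so it works whether or not `exclude_modules` already exists. A minimal sketch under those assumptions (`add_mtp_ignore_patterns` is a hypothetical name for the inline logic):

```python
def add_mtp_ignore_patterns(quant_config, mtp_layer_prefixes):
    """Append "<prefix>*" wildcards to quantization.exclude_modules, deduplicated."""
    # setdefault creates the list on first use and reuses it afterwards
    exclude = quant_config["quantization"].setdefault("exclude_modules", [])
    for prefix in mtp_layer_prefixes:
        pattern = f"{prefix}*"
        if pattern not in exclude:
            exclude.append(pattern)
    return quant_config


cfg = {"quantization": {"quant_algo": "NVFP4"}}
add_mtp_ignore_patterns(cfg, ["model.layers.92"])
add_mtp_ignore_patterns(cfg, ["model.layers.92"])  # second call is a no-op
print(cfg["quantization"]["exclude_modules"])  # ['model.layers.92*']
```

The membership check keeps the call idempotent, so re-exporting the same model never duplicates ignore patterns in the emitted config.json.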
