Skip to content

Commit 6685adf

Merge main
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
2 parents 048296b + 452c5a0 commit 6685adf

9 files changed

Lines changed: 219 additions & 54 deletions


CHANGELOG.rst

Lines changed: 3 additions & 1 deletion
@@ -19,6 +19,8 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add support for context parallelism in Eagle speculative decoding for huggingface and megatron core models.
 - Add unified Hugging Face export support for diffusers pipelines/components.
 - Add LTX-2 and Wan2.2 (T2V) support in the diffusers quantization workflow.
+- Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is.
+- Add support for image-text data calibration in PTQ for Nemotron VL models.

 0.41 (2026-01-19)
 ^^^^^^^^^^^^^^^^^
@@ -228,7 +230,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add support for UNet ONNX quantization.
 - Enable ``concat_elimination`` pass by default to improve the performance of quantized ONNX models.
 - Enable Redundant Cast elimination pass by default in :meth:`moq.quantize <modelopt.onnx.quantization.quantize>`.
-- Add new attribute ``parallel_state`` to :class:`DynamicModule <modelopt.torch.opt.dynamic.DynamicModule>` to support distributed parallelism such as data parallel and tensor parallel.
+- Add new attribute ``parallel_state`` to :class:`QuantModule <modelopt.torch.quantization.nn.modules.quant_module.QuantModule>` to support distributed parallelism such as data parallel and tensor parallel.
 - Add MXFP8, NVFP4 quantized ONNX export support.
 - Add new example for torch quantization to ONNX for MXFP8, NVFP4 precision.

examples/llm_ptq/README.md

Lines changed: 3 additions & 1 deletion
@@ -109,6 +109,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
 | QWen3 MOE, Next <sup>6</sup> || - | - | - ||
 | QwQ || - | - | - ||
 | DeepSeek V3, R1, V3.1, V3.2<sup>7</sup> | - | - | - | - ||
+| GLM-4.7<sup>8</sup> || - | - | - ||
 | Kimi K2 | - | - | - | - ||
 | T5 ||||| - |
 | Whisper ||||| - |
@@ -121,7 +122,8 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
 > *<sup>4.</sup>For some models, KV cache quantization may result in a higher accuracy penalty.* \
 > *<sup>5.</sup>A selective set of the popular models are internally tested. The actual model support list may be longer. NVFP4 inference requires Blackwell GPUs and TensorRT-LLM v0.17 or later* \
 > *<sup>6.</sup>Some models currently support export to HF format only.* \
-> *<sup>7.</sup>[PTQ for DeepSeek](../deepseek/README.md)*
+> *<sup>7.</sup>[PTQ for DeepSeek](../deepseek/README.md)* \
+> *<sup>8.</sup>GLM-4.7 has MTP (Multi-Token Prediction) layers that are automatically loaded and excluded from quantization.*

 > *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead.*

examples/llm_ptq/example_utils.py

Lines changed: 109 additions & 0 deletions
@@ -16,7 +16,9 @@
 import copy
 import glob
 import inspect
+import json
 import os
+import re
 import shutil
 import sys
 import warnings
@@ -27,6 +29,7 @@
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from accelerate.utils import get_max_memory
+from safetensors.torch import load_file
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -314,6 +317,106 @@ def get_processor(
     return None


+def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[str]:
+    """Load MTP weights from separate safetensors if needed (e.g., GLM-4.7).
+
+    Some models store additional layers in separate safetensors files with non-standard
+    names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
+    files even though they're referenced in model.safetensors.index.json.
+
+    This function detects such cases and explicitly loads the missing weights.
+
+    Args:
+        model: The loaded model that may be missing weights
+        model_path: Path to the model directory
+
+    Returns:
+        List of layer prefixes that were loaded from non-standard safetensors files.
+        These layers should typically be excluded from quantization.
+        Empty list if no additional weights were loaded.
+    """
+    model_path = Path(model_path)
+    index_file = model_path / "model.safetensors.index.json"
+    mtp_layer_prefixes: list[str] = []
+
+    if not index_file.exists():
+        return mtp_layer_prefixes
+
+    # Load the index to find all referenced safetensors files
+    with open(index_file) as f:
+        index = json.load(f)
+
+    # Find all unique safetensors files referenced
+    all_files = set(index["weight_map"].values())
+
+    # Find non-standard shard files (not matching model-XXXXX-of-XXXXX.safetensors pattern)
+    standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
+    non_standard_files = [f for f in all_files if not standard_pattern.match(f)]
+
+    if not non_standard_files:
+        return mtp_layer_prefixes
+
+    # Check which non-standard files exist and have missing weights
+    model_state = model.state_dict()
+    total_loaded = 0
+
+    for filename in non_standard_files:
+        filepath = model_path / filename
+        if not filepath.exists():
+            continue
+
+        # Find keys that should be in this file
+        expected_keys = [k for k, v in index["weight_map"].items() if v == filename]
+
+        # Check which are missing from the model
+        missing_keys = [k for k in expected_keys if k not in model_state]
+
+        if not missing_keys:
+            # Even if weights are loaded, record the layer prefixes for exclusion
+            # Extract unique layer prefixes (e.g., "model.layers.92" from "model.layers.92.mlp.weight")
+            for key in expected_keys:
+                # Extract layer prefix like "model.layers.92" or "layers.92"
+                parts = key.split(".")
+                for i, part in enumerate(parts):
+                    if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
+                        prefix = ".".join(parts[: i + 2])  # e.g., "model.layers.92"
+                        if prefix not in mtp_layer_prefixes:
+                            mtp_layer_prefixes.append(prefix)
+                        break
+            continue
+
+        print(f"Loading {len(missing_keys)} missing weights from {filename}...")
+
+        # Extract unique layer prefixes for exclusion from quantization
+        for key in missing_keys:
+            parts = key.split(".")
+            for i, part in enumerate(parts):
+                if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
+                    prefix = ".".join(parts[: i + 2])  # e.g., "model.layers.92"
+                    if prefix not in mtp_layer_prefixes:
+                        mtp_layer_prefixes.append(prefix)
+                    break
+
+        # Load the weights to CPU first, load_state_dict will handle device placement
+        weights = load_file(str(filepath), device="cpu")
+        weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}
+
+        # Load into model
+        missing, unexpected = model.load_state_dict(weights_to_load, strict=False)
+        total_loaded += len(weights_to_load)
+
+        if missing:
+            print(f" Warning: {len(missing)} keys still missing after loading {filename}")
+
+    if total_loaded > 0:
+        print(f"✓ Successfully loaded {total_loaded} weights from non-standard safetensors files")
+
+    if mtp_layer_prefixes:
+        print(f"✓ Detected MTP layers to exclude from quantization: {mtp_layer_prefixes}")
+
+    return mtp_layer_prefixes
+
+
 def get_dtype(dtype):
     if dtype == "bf16":
         dtype = torch.bfloat16
@@ -473,6 +576,12 @@ def get_model(
     if device == "cuda" and not is_model_on_gpu(model):
         print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")

+    # Load any missing weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors)
+    # Store the MTP layer prefixes on the model for later exclusion from quantization
+    mtp_layer_prefixes = load_mtp_weights_if_needed(model, ckpt_path)
+    if mtp_layer_prefixes:
+        model._mtp_layer_prefixes = mtp_layer_prefixes
+
     return model
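
For orientation, here is a minimal sketch of how the new helper is intended to be called around model loading. The checkpoint path and the `from_pretrained` call below are illustrative assumptions, not part of this commit; in the commit itself the call happens inside `get_model()` and `load_model()`.

```python
# Hypothetical usage sketch of load_mtp_weights_if_needed (paths are placeholders).
from transformers import AutoModelForCausalLM

from example_utils import load_mtp_weights_if_needed  # helper added in this commit

ckpt_path = "/checkpoints/GLM-4.7"  # assumed local HF checkpoint with an mtp.safetensors shard
model = AutoModelForCausalLM.from_pretrained(ckpt_path, torch_dtype="auto", device_map="auto")

# Pull in any weights that from_pretrained() skipped and remember which layers to keep unquantized.
mtp_layer_prefixes = load_mtp_weights_if_needed(model, ckpt_path)
if mtp_layer_prefixes:
    # Mirrors what get_model()/load_model() do: stash the prefixes for the quantization and export steps.
    model._mtp_layer_prefixes = mtp_layer_prefixes
```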

examples/llm_ptq/hf_ptq.py

Lines changed: 20 additions & 0 deletions
@@ -31,6 +31,7 @@
     get_tokenizer,
     is_enc_dec,
     is_nemotron_vl,
+    load_mtp_weights_if_needed,
     run_nemotron_vl_preview,
 )
 from torch.utils.data import DataLoader
@@ -348,6 +349,12 @@ def load_model(args: argparse.Namespace):
         )
         calibration_only = True

+    # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
+    # Store the MTP layer prefixes on the model for later exclusion from quantization
+    mtp_layer_prefixes = load_mtp_weights_if_needed(full_model, args.pyt_ckpt_path)
+    if mtp_layer_prefixes:
+        full_model._mtp_layer_prefixes = mtp_layer_prefixes
+
     model_type = get_model_type(full_model)

     device = full_model.device
@@ -878,6 +885,19 @@ def quantize_main(
         KV_QUANT_CFG_CHOICES,
     )

+    # Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92)
+    # These layers are typically speculative decoding layers that should be exported as-is
+    mtp_layer_prefixes = getattr(full_model, "_mtp_layer_prefixes", None)
+    if mtp_layer_prefixes:
+        import copy
+
+        quant_cfg = copy.deepcopy(quant_cfg)
+        for prefix in mtp_layer_prefixes:
+            # Add exclusion pattern for this MTP layer (e.g., "*layers.92*")
+            pattern = f"*{prefix.split('.')[-2]}.{prefix.split('.')[-1]}*"
+            quant_cfg["quant_cfg"][pattern] = {"enable": False}
+            print(f"Excluding MTP layer from quantization: {pattern}")
+
     if args.qformat in QUANT_CFG_CHOICES:
         mono_quantize(
             args,
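
As a quick sanity check on the wildcard construction in the hunk above, the pattern derivation can be traced by hand; the prefix value below is only an example of what `load_mtp_weights_if_needed` might detect.

```python
# Worked example of the exclusion pattern built in quantize_main above.
prefix = "model.layers.92"  # example MTP layer prefix
pattern = f"*{prefix.split('.')[-2]}.{prefix.split('.')[-1]}*"
assert pattern == "*layers.92*"

# Disabling quantization for everything matching the pattern, as in the diff:
quant_cfg = {"quant_cfg": {}}  # stand-in for the real quantization config dict
quant_cfg["quant_cfg"][pattern] = {"enable": False}
```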

examples/megatron_bridge/README.md

Lines changed: 13 additions & 3 deletions
@@ -18,7 +18,7 @@ This directory contains examples of using Model Optimizer with [NeMo Megatron-Br

 Running these examples requires many additional dependencies to be installed (e.g., Megatron-Bridge, Megatron-core, etc.), hence we strongly recommend directly using the NeMo container (e.g., `nvcr.io/nvidia/nemo:26.02`) which has all the dependencies installed.

-To get the latest ModelOpt features and examples, you can mount your latest ModelOpt cloned repository to the container at `/opt/Model-Optimizer` or pull the latest changes once inside the docker container (`cd /opt/Model-Optimizer && git checkout main && git pull`).
+To get the latest ModelOpt features and examples, you can mount your latest ModelOpt cloned repository to the container at `/opt/Megatron-Bridge/3rdparty/Model-Optimizer` or pull the latest changes once inside the docker container (`cd /opt/Megatron-Bridge/3rdparty/Model-Optimizer && git checkout main && git pull`).

 ## Pruning

@@ -30,17 +30,27 @@ Example usage to prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while
 top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model.

 ```bash
-torchrun --nproc_per_node 2 /opt/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \
+torchrun --nproc_per_node 2 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \
     --hf_model_name_or_path Qwen/Qwen3-8B \
     --prune_target_params 6e9 \
     --hparams_to_skip num_attention_heads \
     --output_hf_path /tmp/Qwen3-8B-Pruned-6B
 ```

+Example usage for manually pruning to a specific architecture, using the following defaults:
+1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration.
+
+```bash
+torchrun --nproc_per_node 2 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py \
+    --hf_model_name_or_path Qwen/Qwen3-8B \
+    --prune_export_config '{"hidden_size": 3584, "ffn_hidden_size": 9216}' \
+    --output_hf_path /tmp/Qwen3-8B-Pruned-6B-manual
+```
+
 To see the full usage for advanced configurations, run:

 ```bash
-python /opt/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help
+python /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help
 ```

 > [!TIP]

examples/megatron_bridge/prune_minitron.py

Lines changed: 7 additions & 9 deletions
@@ -91,7 +91,7 @@ def get_args() -> argparse.Namespace:

     # Pruning parameters
     parser.add_argument(
-        "--prune_intermediate_checkpoint",
+        "--prune_intermediate_ckpt",
         type=str,
         default=None,
         help=(
@@ -169,18 +169,16 @@ def get_args() -> argparse.Namespace:
     args = parser.parse_args()

     # Post-process arguments
-    if args.prune_intermediate_checkpoint is None:
+    if args.prune_intermediate_ckpt is None:
         if args.output_megatron_path:
-            args.prune_intermediate_checkpoint = (
+            args.prune_intermediate_ckpt = (
                 f"{args.output_megatron_path}/modelopt_pruning_scores.pth"
             )
         elif args.output_hf_path:
-            args.prune_intermediate_checkpoint = (
-                f"{args.output_hf_path}/modelopt_pruning_scores.pth"
-            )
+            args.prune_intermediate_ckpt = f"{args.output_hf_path}/modelopt_pruning_scores.pth"
         print_rank_0(
             "No checkpoint provided to cache intermediate pruning scores. "
-            f"Setting to: {args.prune_intermediate_checkpoint}"
+            f"Setting to: {args.prune_intermediate_ckpt}"
         )

     if args.prune_export_config:
@@ -247,7 +245,7 @@ def main(args: argparse.Namespace):

     pruning_config = {
         "forward_loop": forward_loop,
-        "checkpoint": args.prune_intermediate_checkpoint,
+        "checkpoint": args.prune_intermediate_ckpt,
     }
     if args.prune_target_params is not None:
         # Restrict search space to a smaller set of candidates
@@ -377,8 +375,8 @@ def score_func_mmlu(m):


 if __name__ == "__main__":
-    args = get_args()
     dist.setup()
+    args = get_args()
     try:
         main(args)
     finally:

modelopt/torch/export/unified_export_hf.py

Lines changed: 12 additions & 0 deletions
@@ -710,6 +710,18 @@ def _export_transformers_checkpoint(

     quant_config = get_quant_config(model, is_modelopt_qlora=is_modelopt_qlora)

+    # Add MTP layer prefixes to exclude_modules if they were excluded from quantization
+    # This ensures they appear in quantization_config["ignore"] in config.json
+    mtp_layer_prefixes = getattr(model, "_mtp_layer_prefixes", None)
+    if mtp_layer_prefixes:
+        exclude_modules = quant_config["quantization"].setdefault("exclude_modules", [])
+        for prefix in mtp_layer_prefixes:
+            # Add wildcard pattern to exclude all submodules under this MTP layer
+            pattern = f"{prefix}*"
+            if pattern not in exclude_modules:
+                exclude_modules.append(pattern)
+                print(f"Adding MTP layer to quantization_config ignore: {pattern}")
+
     # Process all quantized modules and export weights
     _process_quantized_modules(model, dtype, is_modelopt_qlora)
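
Roughly, the exported side effect of this hunk is an extra wildcard entry in the checkpoint's quantization config; the values below are a simplified stand-in, not the real `get_quant_config` output.

```python
# Illustrative only: what the exclude_modules handling above adds for a detected MTP layer.
quant_config = {"quantization": {"quant_algo": "NVFP4"}}  # simplified stand-in config
mtp_layer_prefixes = ["model.layers.92"]  # example detected prefix

exclude_modules = quant_config["quantization"].setdefault("exclude_modules", [])
for prefix in mtp_layer_prefixes:
    pattern = f"{prefix}*"
    if pattern not in exclude_modules:
        exclude_modules.append(pattern)

assert quant_config["quantization"]["exclude_modules"] == ["model.layers.92*"]
```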

modelopt/torch/opt/dynamic.py

Lines changed: 0 additions & 39 deletions
@@ -31,7 +31,6 @@
 from torch.nn.parameter import Parameter

 from modelopt.torch.utils import get_unwrapped_name, is_channels_last, unwrap_model
-from modelopt.torch.utils.distributed import ParallelState
 from modelopt.torch.utils.network import bind_forward_method

 from .config import ModeloptBaseRule, RulesDict
@@ -359,14 +358,10 @@ class DynamicModule(nn.Module):
     should ensure only to expose ``hparams`` in the outermost class and handle other ``hparams``
     internally including ``hparams`` of child modules that are exposed on their own usually
     (e.g. block module implementations containing DynamicLinear).
-
-    In addition, the class also provides ``parallel_state`` attribute that can be used to access
-    the parallel state of the module.
     """

     # this is needed to store the special attributes for dynamic modules
     _dm_attribute_manager: _DMAttributeManager
-    _parallel_state: ParallelState

     def __init__(self, *args, **kwargs):
         """Initializing a dynamic module is not allowed!"""
@@ -657,10 +652,6 @@ def bind_forward_method_if_needed(self):
         # setup new hparams and dynamic attributes
         module._setup(**setup_kwargs)

-        # setup parallel state now that the module is converted
-        if module.parallel_state is None:
-            module._initialize_parallel_state()
-
         return module

     def _setup(self, **setup_kwargs: Any):
@@ -867,36 +858,6 @@ def original_cls(self) -> type[nn.Module]:
         """
         return self._get_dm_attribute_manager().og_cls

-    @property
-    def parallel_state(self) -> ParallelState | None:
-        """Return the parallel state of the dynamic module."""
-        return getattr(self, "_parallel_state", None)
-
-    @parallel_state.setter
-    def parallel_state(self, parallel_state: ParallelState):
-        """Set the parallel state of the dynamic module."""
-        assert isinstance(parallel_state, ParallelState), (
-            "parallel_state must be a ParallelState object!"
-        )
-        self._parallel_state = parallel_state
-
-    def _initialize_parallel_state(self):
-        """Initialize the parallel state of the dynamic module.
-
-        This method is called only if the `DynamicModule` does not have a `parallel_state` attribute
-        after `_setup` is called.
-        """
-        if torch.distributed.is_initialized():
-            warnings.warn(
-                f"Distributed training is initialized but no parallel_state is set for {type(self)}. "
-                "Using default parallel_state which has data_parallel_group set to the default process group and "
-                "tensor_parallel_group is unspecified. "
-                "If you are using tensor parallelism for this module, you should set the parallel_state "
-                "in its `_setup` method."
-            )
-
-        self.parallel_state = ParallelState(data_parallel_group=None)
-
     def get_original_cls_by_level(self, level: int = -1) -> type[nn.Module]:
         """Return the original class of the dynamic module.
