Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
e0bdc73
Add layerwise calibration for large models
realAsma Apr 17, 2026
5658381
Drop default layerwise_checkpoint_dir from max PTQ recipe
realAsma Apr 17, 2026
c9b7f82
Add offload test for vllm fakequant export; CHANGELOG entry
realAsma Apr 17, 2026
73980e5
Restructure CHANGELOG layerwise calibration entry
realAsma Apr 17, 2026
29c43ff
using replace_function
kinjalpatel27 Apr 6, 2026
07f1e8b
fixing issues
kinjalpatel27 Apr 8, 2026
bb81659
minor
kinjalpatel27 Apr 8, 2026
bbf3bf6
updated name
kinjalpatel27 Apr 8, 2026
d862bca
addressed comments
kinjalpatel27 Apr 10, 2026
2fa7817
support for prequant scale when input quantizer is enabled
kinjalpatel27 Apr 7, 2026
e0e9565
Updated fakequant export for AWQ smoothing
kinjalpatel27 Apr 7, 2026
280a3f1
fixing issues
kinjalpatel27 Apr 8, 2026
a92cc8a
fixed failures
kinjalpatel27 Apr 13, 2026
ba7610c
minor
kinjalpatel27 Apr 14, 2026
62e6dee
minor
kinjalpatel27 Apr 14, 2026
f11ac58
cleanup
kinjalpatel27 Apr 14, 2026
bd8dbcf
minor
kinjalpatel27 Apr 15, 2026
c6cf4e2
minor
kinjalpatel27 Apr 15, 2026
32f7d8d
added support for GQA
kinjalpatel27 Apr 15, 2026
fe90104
addressed comments
kinjalpatel27 Apr 16, 2026
2c9297b
addressed comments
kinjalpatel27 Apr 16, 2026
86837f6
fix test
kinjalpatel27 Apr 16, 2026
36c3fe4
added MoE test for vllm fakequant
kinjalpatel27 Apr 17, 2026
bbc78dc
minor
kinjalpatel27 Apr 17, 2026
1829ee7
fixes for inplace
kinjalpatel27 Apr 17, 2026
a719ae2
minor
kinjalpatel27 Apr 17, 2026
6be72a8
minor
kinjalpatel27 Apr 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Changelog
- Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml>`_ for more details.
- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.
- [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
- Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml>`_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml>`_ for usage.

**Backward Breaking Changes**

Expand Down
33 changes: 33 additions & 0 deletions examples/llm_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import copy
import glob
import hashlib
import inspect
import json
import logging
Expand Down Expand Up @@ -854,3 +855,35 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod
print(f"Successfully copied {len(copied_files)} custom model files to {export_path}")
else:
print("No custom model files found to copy")


def needs_checkpoint_path_update(quant_cfg: dict) -> bool:
    """Return True when ``quant_cfg`` carries a layerwise_checkpoint_dir to auto-resolve.

    The "algorithm" entry may be a plain string (e.g. an algorithm name) rather
    than a dict; only a dict can carry ``layerwise_checkpoint_dir``.
    """
    algo = quant_cfg.get("algorithm")
    return isinstance(algo, dict) and algo.get("layerwise_checkpoint_dir") is not None


def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict:
    """Append a unique ``<model_name>_<config_hash>`` subdirectory to layerwise_checkpoint_dir.

    Allows a single recipe to be reused across models without checkpoint collisions.
    Must only be called when :func:`needs_checkpoint_path_update` returns True.
    The input ``quant_cfg`` is not mutated; a deep copy with the resolved path
    is returned.
    """
    base_dir = quant_cfg["algorithm"]["layerwise_checkpoint_dir"]

    # Derive a filesystem-safe model identifier: HF-style repo ids ("org/model")
    # become "org--model", while local paths collapse to their final component.
    trimmed = model_path.rstrip("/")
    if os.path.isabs(trimmed) or "/" not in trimmed:
        model_name = Path(trimmed).name
    else:
        model_name = trimmed.replace("/", "--")

    # Hash the full (unmodified) config so different recipes never collide.
    # NOTE(review): json.dumps without sort_keys makes the hash depend on key
    # insertion order — presumably stable for these recipes, but confirm.
    serialized = json.dumps(quant_cfg, default=str).encode()
    digest = hashlib.sha256(serialized).hexdigest()[:8]

    resolved = copy.deepcopy(quant_cfg)
    resolved["algorithm"]["layerwise_checkpoint_dir"] = os.path.join(
        base_dir, f"{model_name}_{digest}"
    )
    return resolved
17 changes: 14 additions & 3 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
is_enc_dec,
is_nemotron_vl,
load_mtp_weights,
needs_checkpoint_path_update,
resolve_checkpoint_dir,
run_nemotron_vl_preview,
)
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -91,8 +93,9 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
for i, entry in enumerate(quant_cfg):
if entry.get("quantizer_name") != "*[kv]_bmm_quantizer":
continue
assert isinstance(entry.get("cfg", {}), dict)
quant_cfg[i] = {**entry, "cfg": {**entry.get("cfg", {}), "use_constant_amax": True}}
cfg = entry.get("cfg") or {}
assert isinstance(cfg, dict)
quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}}
break


Expand Down Expand Up @@ -759,7 +762,9 @@ def export_quantized(
# Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
# Store the MTP layer prefixes on the model for later exclusion from quantization
if args.vllm_fakequant_export:
export_hf_vllm_fq_checkpoint(full_model, export_dir=export_path)
export_hf_vllm_fq_checkpoint(
full_model, export_dir=export_path, inplace_mem_efficient=True
)
else:
mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(
full_model, args.pyt_ckpt_path
Expand Down Expand Up @@ -1104,6 +1109,12 @@ def quantize_main(
quant_cfg = copy.deepcopy(quant_cfg)
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])

if needs_checkpoint_path_update(quant_cfg):
quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path)
print(
f"Auto-resolved layerwise_checkpoint_dir: {quant_cfg['algorithm']['layerwise_checkpoint_dir']}"
)

if args.qformat in QUANT_CFG_CHOICES:
mono_quantize(
args,
Expand Down
5 changes: 2 additions & 3 deletions examples/vllm_serve/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,5 @@ QUANT_CFG=<quant_cfg> QUANT_FILE_PATH=<quantizer_state.pth> python vllm_serve_fa
## Known Problems

1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won’t align).
2. AWQ reload is not supported yet
Comment thread
kinjalpatel27 marked this conversation as resolved.
3. KV cache quantization export and reload is not supported in MCore yet.
4. **`NVFP4_KV_CFG` and `NVFP4_AFFINE_KV_CFG` require `--enforce-eager`**; these configs use a dynamic-block Triton kernel for KV-cache quantization that is incompatible with CUDA graph capture (the kernel grid is computed from Python-level tensor shapes, which get baked in at capture time). Without `--enforce-eager`, the captured grid will be wrong for different batch sizes, producing incorrect outputs.
2. KV cache quantization export and reload is not supported in MCore yet.
3. **`NVFP4_KV_CFG` and `NVFP4_AFFINE_KV_CFG` require `--enforce-eager`**; these configs use a dynamic-block Triton kernel for KV-cache quantization that is incompatible with CUDA graph capture (the kernel grid is computed from Python-level tensor shapes, which get baked in at capture time). Without `--enforce-eager`, the captured grid will be wrong for different batch sizes, producing incorrect outputs.
53 changes: 39 additions & 14 deletions examples/vllm_serve/fakequant_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


import os
import warnings
from typing import Any

import torch
Expand All @@ -26,13 +27,16 @@
convert_modelopt_state_to_vllm,
load_state_dict_from_path,
restore_from_modelopt_state_vllm,
shard_pre_quant_scale_for_tp,
)

import modelopt.torch.quantization as mtq
from modelopt.torch.export.plugins.vllm_fakequant_hf import is_weight_quantizer_state_key
from modelopt.torch.quantization.plugins.vllm import (
disable_compilation,
post_restore_vllm_parallel_linears,
)
from modelopt.torch.utils import safe_load
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader

quant_config: dict[str, Any] = {
Expand Down Expand Up @@ -61,28 +65,48 @@ def _fakequant_run_prolog_worker(self) -> None:
model = model.unwrap()
if quant_config["modelopt_state_path"]:
print(f"Loading modelopt state from {quant_config['modelopt_state_path']}")
# Load on CPU to avoid failures when the checkpoint was saved from a different
# GPU mapping
modelopt_state = torch.load(
quant_config["modelopt_state_path"], weights_only=True, map_location="cpu"
)
# Load on CPU to avoid failures when the checkpoint was saved from a different GPU mapping.
modelopt_state = safe_load(quant_config["modelopt_state_path"], map_location="cpu")
modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
map_fun = (
self.model_runner.model.hf_to_vllm_mapper.apply_dict
if hasattr(self.model_runner.model, "hf_to_vllm_mapper")
else None
)
# convert modelopt state to vllm format
modelopt_state = convert_modelopt_state_to_vllm(modelopt_state, map_fun=map_fun)
# restore model from modelopt state
restore_from_modelopt_state_vllm(model, modelopt_state)

if modelopt_weights is not None:
# convert quantizer state values to vllm format
modelopt_weights = convert_dict_to_vllm(modelopt_weights, map_fun=map_fun)
mtq.utils.set_quantizer_state_dict(model, modelopt_weights)
# set_quantizer_state_dict does not invoke modelopt_post_restore (unlike restore_quantizer_state).
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
from modelopt.torch.quantization.nn import TensorQuantizer
from modelopt.torch.utils import get_unwrapped_name

loaded_keys = {
get_unwrapped_name(n, model)
for n, m in model.named_modules()
if isinstance(m, TensorQuantizer)
}
# Same namespace as ``loaded_keys``: checkpoint keys may include DDP/FSDP
# prefixes that ``convert_dict_to_vllm`` does not strip.
pqs_in_weights = {
get_unwrapped_name(k, model)
for k, v in modelopt_weights.items()
if isinstance(v, dict) and "_pre_quant_scale" in v
}
unmatched_pqs = pqs_in_weights - loaded_keys
if unmatched_pqs:
sample = sorted(unmatched_pqs)[:20]
warnings.warn(
f"{len(unmatched_pqs)} checkpoint pre_quant_scale key(s) have no "
f"matching TensorQuantizer in the model (showing up to 20): {sample}",
stacklevel=2,
)
# set_quantizer_state_dict does not run modelopt_post_restore (unlike restore_quantizer_state).
post_restore_vllm_parallel_linears(model)
# Must follow post_restore: shard_pre_quant_scale_for_tp uses weight H_in vs pqs length.
shard_pre_quant_scale_for_tp(model)

else:
if quant_config["quant_file_path"]:
Expand All @@ -101,15 +125,13 @@ def _fakequant_run_prolog_worker(self) -> None:

quant_cfg = get_quant_config(quant_config, model)

# quantize model
with disable_compilation(model):
print("Quantizing model...")
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)

quantizer_file_path = quant_config["quant_file_path"]
if quantizer_file_path:
# Get amax and other quantizer state from the quantizer file
# this can be used with Megatron-LM exported model using export_mcore_gpt_to_hf_vllm_fq
self.model_runner._dummy_run(1)
current_state_dict = load_state_dict_from_path(self, quantizer_file_path, model)
model.load_state_dict(current_state_dict)

Expand All @@ -122,8 +144,11 @@ def _fakequant_run_prolog_worker(self) -> None:

Comment thread
kinjalpatel27 marked this conversation as resolved.
mtq.fold_weight(model)
for name, module in model.named_modules():
if name.endswith("weight_quantizer"):
assert not module.is_enabled, f"quantizer {name} is still enabled"
if is_weight_quantizer_state_key(name) and module.is_enabled:
raise RuntimeError(
f"Weight quantizer {name!r} is still enabled after fold_weight — "
"double-quantization would corrupt activations."
)


class FakeQuantWorker(BaseWorker):
Expand Down
Loading
Loading