Commit 2d868d3
Performant layerwise calibration for large models (#1251)
## Summary

Adds **performant layerwise calibration** for quantizing large models (e.g. DeepSeek-R1 671B) that don't fit entirely on GPU. ([Example commands](#example-commands))

1. **Performant calibration for large models** — Each decoder layer is moved from CPU/disk to GPU (accelerate) or unsharded (FSDP2) **only once** and kept on GPU for the entire calibration step. Previously, every calibration batch triggered weight transfer for every layer — O(num_batches) weight movements per layer. Now it is O(1) per layer. This also means you can **increase batch size** since only one layer's weights occupy GPU at a time — e.g. DeepSeek-R1 on a single node (8×80GB) with `batch_size=16` and `gpu_max_mem_percentage=0.5`.
2. **Checkpoint save/resume** — Saves progress after each layer, so jobs that exceed cluster time limits (e.g. 4-hour Slurm windows for 100+ layer MoE models) can resume from the last completed layer.
3. **Rename** — `sequential_calibrate` → `layerwise_calibrate` for clarity.

### Design details

The existing layerwise state machine (skip/run/capture) already processes one layer at a time, but skip-mode layers still kept their parameters in the ModuleList — so frameworks transferred all weights every forward pass. This PR adds:

- **`_SkipLayer`**: replaces fully-calibrated layers with a parameter-free dummy in the ModuleList, so framework hooks have nothing to transfer
- **`persistent_materialization`**: keeps the active layer on GPU for the entire calibration step, avoiding repeated offload/reload cycles

Checkpoint save is per-layer; restore is bulk — quantizer state and weights for layers 0..K-1 are restored once at the end of calibration, keeping the hot path fast.
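As a rough sketch of the `_SkipLayer` idea (a hypothetical minimal version; the helper name `swap_in_skip_layer` is illustrative, not the modelopt API):

```python
import torch.nn as nn


class _SkipLayer(nn.Module):
    """Parameter-free stand-in for a fully calibrated decoder layer.

    Because it holds no parameters, accelerate/FSDP2 hooks have nothing
    to offload or transfer when the ModuleList is traversed.
    """

    def forward(self, *args, **kwargs):
        # Skip-mode layers are never executed during layerwise calibration;
        # the active layer runs from its captured inputs instead.
        raise RuntimeError("_SkipLayer must not be executed during calibration")


def swap_in_skip_layer(layers: nn.ModuleList, idx: int) -> nn.Module:
    """Replace layers[idx] with a parameter-free dummy; return the original for bulk restore."""
    original = layers[idx]
    layers[idx] = _SkipLayer()
    return original
```

After the swap, a framework hook iterating the ModuleList finds zero parameters at the skipped index, so no weight movement is triggered; the returned original is kept aside for the bulk restore at the end of calibration.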
### Example commands

**Qwen3-8B** (NVFP4+GPTQ, single GPU):

```bash
python hf_ptq.py \
    --pyt_ckpt_path Qwen/Qwen3-8B \
    --recipe nvfp4_gptq_sequential.yaml \
    --calib_size 64 \
    --batch_size 16 \
    --dataset cnn_dailymail \
    --export_path outputs/qwen3_8b_nvfp4_gptq_seq \
    --gpu_max_mem_percentage 0.5 \
    --use_seq_device_map \
    --vllm_fakequant_export
```

**DeepSeek-R1** (NVFP4 experts-only + FP8 KV, 8×80GB):

```bash
python hf_ptq.py \
    --model unsloth/DeepSeek-R1-0528-BF16 \
    --recipe ../../modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml \
    --dataset cnn_dailymail \
    --batch_size 16 \
    --calib_size 64 \
    --calib_seq 512 \
    --gpu_max_mem_percentage 0.5 \
    --use_seq_device_map \
    --trust_remote_code \
    --export_path output/DeepSeek-R1-BF16-nvfp4-experts-only-fp8-kv \
    --vllm_fakequant_export
```

### Example: NVFP4+GPTQ layerwise calibration on Qwen3-8B (36 layers, single GPU — 20 GB peak)

**Initial run** (killed after layer 11):

```
Layerwise calibration: Found 36 transformer layers
Calibrating layer 1/36 | capture: [1]
Computing Hessians for 7 linear layers...
GPTQ time: 51.39s
Calibrating layer 2/36 | run: [1] | capture: [2]
Checkpoint: saved layer 0
GPTQ time: 50.06s
Calibrating layer 3/36 | skip: 1 | run: [2] | capture: [3]
Checkpoint: saved layer 1
...
Calibrating layer 12/36 | skip: 10 | run: [11] | capture: [12]
Checkpoint: saved layer 10
<killed>
```

**Resumed run** (picks up from layer 11, finishes all 36):

```
Layerwise calibration: Found 36 transformer layers
Checkpoint: resuming layerwise calibration from layer 11/36
Calibrating layer 12 (resumed)
GPTQ time: 51.45s
Calibrating layer 13/36 | skip: 11 | run: [12] | capture: [13]
Checkpoint: saved layer 11
...
Calibrating layer 36/36 | skip: 34 | run: [35] | capture: [36]
Checkpoint: saved layer 34
GPTQ time: 50.33s
Checkpoint: saved layer 35 (final)
Checkpoint: restored 11 previously calibrated layers
Layerwise calibration completed
Quantized model exported to: outputs/qwen3_8b_nvfp4_gptq_seq
GPU 0: Peak memory usage = 20.42 GB
```

## TODO

- [ ] Update CHANGELOG

## Test plan

- `tests/unit/torch/quantization/test_layerwise_calibrate.py` — unit tests for skip/swap/restore
- `tests/unit/torch/quantization/test_sequential_checkpoint.py` — checkpoint save/resume correctness
- `tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py` — CPU-offloaded layerwise + GPTQ + checkpoint resume
- `tests/gpu/torch/quantization/test_fsdp2.py` — FSDP2 layerwise calibration

### Verified

- [x] Qwen3-8B: layerwise calibration + checkpoint save/restore + fakequantized checkpoint export + vLLM serve
- [x] DeepSeek-R1: checkpoint resume tested
- [x] DeepSeek-R1: fakequantized checkpoint export verified

---------

Signed-off-by: realAsma <akuriparambi@nvidia.com>
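The per-layer save / resume flow shown in the logs above can be sketched with the standard library only (a hypothetical helper for illustration, not the modelopt implementation):

```python
import json
import os


def run_layerwise_calibration(layers, calibrate_layer, ckpt_dir):
    """Sketch of per-layer checkpoint save with resume-from-last-completed-layer.

    Progress is recorded after each layer, so a job killed mid-run (e.g. by a
    Slurm time limit) restarts from the first uncalibrated layer, not from scratch.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    progress_file = os.path.join(ckpt_dir, "progress.json")

    start = 0
    if os.path.exists(progress_file):
        with open(progress_file) as f:
            start = json.load(f)["last_completed"] + 1
        print(f"Checkpoint: resuming layerwise calibration from layer {start + 1}/{len(layers)}")

    for i in range(start, len(layers)):
        calibrate_layer(layers[i])  # e.g. GPTQ / amax calibration for this layer
        # Persist progress only after the layer fully completes.
        with open(progress_file, "w") as f:
            json.dump({"last_completed": i}, f)
        print(f"Checkpoint: saved layer {i}")
```

In the real implementation the checkpoint also carries quantizer state and weights, restored in bulk at the end; this sketch only shows the resume bookkeeping.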
1 parent dc7ad66 · commit 2d868d3 · 29 files changed: +2467 −582 lines

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -15,6 +15,7 @@ Changelog
 - Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml>`_ for more details.
 - Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.
 - [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
+- Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml>`_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml>`_ for usage.
 
 
 **Backward Breaking Changes**
```

examples/llm_ptq/example_utils.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -15,6 +15,7 @@
 
 import copy
 import glob
+import hashlib
 import inspect
 import json
 import logging
@@ -854,3 +855,35 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod
         print(f"Successfully copied {len(copied_files)} custom model files to {export_path}")
     else:
         print("No custom model files found to copy")
+
+
+def needs_checkpoint_path_update(quant_cfg: dict) -> bool:
+    """Check if quant_cfg has a layerwise_checkpoint_dir that should be auto-resolved to a unique subpath."""
+    algorithm = quant_cfg.get("algorithm")
+    if not isinstance(algorithm, dict):
+        return False
+    return algorithm.get("layerwise_checkpoint_dir") is not None
+
+
+def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict:
+    """Append a unique ``<model_name>_<config_hash>`` subdirectory to layerwise_checkpoint_dir.
+
+    Allows a single recipe to be reused across models without checkpoint collisions.
+    Must only be called when :func:`needs_checkpoint_path_update` returns True.
+    """
+    algorithm = quant_cfg["algorithm"]
+    base_dir = algorithm["layerwise_checkpoint_dir"]
+
+    name = model_path.rstrip("/")
+    if "/" in name and not os.path.isabs(name):
+        name = name.replace("/", "--")
+    else:
+        name = Path(name).name
+
+    config_hash = hashlib.sha256(json.dumps(quant_cfg, default=str).encode()).hexdigest()[:8]
+
+    quant_cfg = copy.deepcopy(quant_cfg)
+    quant_cfg["algorithm"]["layerwise_checkpoint_dir"] = os.path.join(
+        base_dir, f"{name}_{config_hash}"
+    )
+    return quant_cfg
```

examples/llm_ptq/hf_ptq.py

Lines changed: 14 additions & 3 deletions
```diff
@@ -34,6 +34,8 @@
     is_enc_dec,
     is_nemotron_vl,
     load_mtp_weights,
+    needs_checkpoint_path_update,
+    resolve_checkpoint_dir,
     run_nemotron_vl_preview,
 )
 from torch.utils.data import DataLoader
@@ -91,8 +93,9 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
     for i, entry in enumerate(quant_cfg):
         if entry.get("quantizer_name") != "*[kv]_bmm_quantizer":
             continue
-        assert isinstance(entry.get("cfg", {}), dict)
-        quant_cfg[i] = {**entry, "cfg": {**entry.get("cfg", {}), "use_constant_amax": True}}
+        cfg = entry.get("cfg") or {}
+        assert isinstance(cfg, dict)
+        quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}}
         break
 
 
@@ -760,7 +763,9 @@ def export_quantized(
     # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
     # Store the MTP layer prefixes on the model for later exclusion from quantization
     if args.vllm_fakequant_export:
-        export_hf_vllm_fq_checkpoint(full_model, export_dir=export_path)
+        export_hf_vllm_fq_checkpoint(
+            full_model, export_dir=export_path, inplace_mem_efficient=True
+        )
     else:
         mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(
             full_model, args.pyt_ckpt_path
@@ -1105,6 +1110,12 @@ def quantize_main(
     quant_cfg = copy.deepcopy(quant_cfg)
     _set_kv_cache_constant_amax(quant_cfg["quant_cfg"])
 
+    if needs_checkpoint_path_update(quant_cfg):
+        quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path)
+        print(
+            f"Auto-resolved layerwise_checkpoint_dir: {quant_cfg['algorithm']['layerwise_checkpoint_dir']}"
+        )
+
     if args.qformat in QUANT_CFG_CHOICES:
         mono_quantize(
             args,
```
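The `_set_kv_cache_constant_amax` change in the diff above guards against an entry with an explicit `cfg: None`, which the old `entry.get("cfg", {})` form does not catch, since `dict.get`'s default only applies when the key is absent:

```python
entry = {"quantizer_name": "*[kv]_bmm_quantizer", "cfg": None}

# dict.get's default is used only for a MISSING key, not for an explicit None:
assert entry.get("cfg", {}) is None

# The fixed form normalizes None to an empty dict before merging:
cfg = entry.get("cfg") or {}
assert cfg == {}
patched = {**entry, "cfg": {**cfg, "use_constant_amax": True}}
assert patched["cfg"] == {"use_constant_amax": True}
```

With the old code, `{**None, ...}` would raise a `TypeError` (or the `isinstance` assert would fire) whenever a recipe serialized the quantizer entry with a null `cfg`.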

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 135 additions & 54 deletions
```diff
@@ -24,6 +24,8 @@
 from modelopt.torch.quantization.conversion import quantizer_state
 from modelopt.torch.quantization.nn import QuantModule, TensorQuantizer
 from modelopt.torch.quantization.utils import get_quantizer_state_dict
+from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback
+from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector
 from modelopt.torch.utils import get_unwrapped_name
 
 __all__ = ["export_hf_vllm_fq_checkpoint"]
@@ -38,9 +40,75 @@ def disable_rotate(quantizer: TensorQuantizer):
     return False
 
 
+def _fakequant_module_weights(
+    module: nn.Module,
+    module_name: str,
+    model: nn.Module,
+    state_dict: dict | None,
+    input_quantizers_folded_pqs: set,
+    fakequant_weights: set,
+    inplace: bool,
+):
+    """Apply fake-quant to a single QuantModule's weights.
+
+    When ``inplace=False``, reads/writes weights from/to ``state_dict``.
+    When ``inplace=True``, modifies the module's weight parameters directly.
+    """
+    if not isinstance(module, QuantModule):
+        return
+    for attr_name, quantizer in module.named_children():
+        if not (
+            attr_name.endswith("weight_quantizer")
+            and isinstance(quantizer, TensorQuantizer)
+            and quantizer.fake_quant
+            and quantizer.is_enabled
+        ):
+            continue
+        weight_name = attr_name.removesuffix("_quantizer")
+        prefix = f"{module_name}." if module_name else ""
+        sd_key = f"{prefix}{weight_name}"
+        assert sd_key not in fakequant_weights, f"Weight {sd_key} has already been fakequantized"
+
+        if inplace:
+            w = getattr(module, weight_name)
+            w_quant = quantizer(w.float()).to(w.dtype)
+        else:
+            assert state_dict is not None
+            if sd_key not in state_dict:
+                continue
+            w = state_dict[sd_key]
+            w_quant = quantizer(w.float()).to(w.dtype)
+
+        # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
+        # Only valid when input_quantizer does NOT fake-quant activations. If it does
+        # fake_quant(x*s), the non-linearity prevents folding s into W.
+        inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
+        if hasattr(module, inp_attr):
+            inp_q = getattr(module, inp_attr)
+            if (
+                hasattr(inp_q, "_pre_quant_scale")
+                and inp_q._pre_quant_scale is not None
+                and inp_q._disabled
+            ):
+                scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
+                w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
+                inp_q_key = get_unwrapped_name(
+                    f"{module_name}.{inp_attr}" if module_name else inp_attr, model
+                )
+                input_quantizers_folded_pqs.add(inp_q_key)
+
+        if inplace:
+            w.data.copy_(w_quant)
+        else:
+            assert state_dict is not None
+            state_dict[sd_key] = w_quant.cpu()
+        fakequant_weights.add(sd_key)
+
+
 def export_hf_vllm_fq_checkpoint(
     model: nn.Module,
     export_dir: Path | str,
+    inplace_mem_efficient: bool = False,
 ):
     """Export quantized HF weights + ``vllm_fq_modelopt_state.pth`` for vLLM fake-quant reload.
@@ -53,62 +121,66 @@ def export_hf_vllm_fq_checkpoint(
     Args:
         model: In-memory quantized model.
         export_dir: Output dir for HF files and ``vllm_fq_modelopt_state.pth``.
+        inplace_mem_efficient: When True, applies fake-quant inplace one decoder layer at
+            a time using ``enable_weight_access_and_writeback``, avoiding full state
+            dict materialization. This is destructive — model weights are permanently
+            modified and weight quantizers are not re-enabled after export.
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
 
     # Step 1: Build the folded HF state dict.
-    # model.state_dict() returns detached copies of all tensors, so model
-    # parameters are never modified. Apply each weight quantizer's fake-quant
-    # to the corresponding weight tensor in the copy.
-    state_dict = model.state_dict()
     fakequant_weights = set()
-    input_quantizers_folded_pqs = (
-        set()
-    )  # keys for input_quantizers where pre_quant_scale was folded
+    input_quantizers_folded_pqs = set()
     with torch.inference_mode():
-        for module_name, module in model.named_modules():
-            if not isinstance(module, QuantModule):
-                continue
-            for attr_name, quantizer in module.named_children():
-                if not (
-                    attr_name.endswith("weight_quantizer")
-                    and isinstance(quantizer, TensorQuantizer)
-                    and quantizer.fake_quant
-                    and quantizer.is_enabled
-                ):
+        if inplace_mem_efficient:
+            # Inplace path: iterate decoder layers, one offload<->onload per layer.
+            decoder_layers = LayerActivationCollector.get_decoder_layers(model)
+            assert decoder_layers is not None, (
+                "inplace_mem_efficient=True requires a model with discoverable decoder layers"
+            )
+            for name, module in model.named_modules():
+                if module not in decoder_layers:
                     continue
-                weight_name = attr_name.removesuffix("_quantizer")
-                prefix = f"{module_name}." if module_name else ""
-                sd_key = f"{prefix}{weight_name}"
-                assert sd_key not in fakequant_weights, (
-                    f"Weight {sd_key} has already been fakequantized"
-                )
-                if sd_key in state_dict:
-                    w = state_dict[sd_key]
-                    w_quant = quantizer(w.float()).to(w.dtype).cpu()
-                    # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
-                    # Only valid when input_quantizer does NOT fake-quant activations. If it does
-                    # fake_quant(x*s), the non-linearity prevents folding s into W.
-                    inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
-                    if hasattr(module, inp_attr):
-                        inp_q = getattr(module, inp_attr)
-                        if (
-                            hasattr(inp_q, "_pre_quant_scale")
-                            and inp_q._pre_quant_scale is not None
-                            and inp_q._disabled
-                        ):
-                            scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
-                            w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
-                            inp_q_key = get_unwrapped_name(
-                                f"{module_name}.{inp_attr}" if module_name else inp_attr, model
-                            )
-                            input_quantizers_folded_pqs.add(inp_q_key)
-                    state_dict[sd_key] = w_quant
-                    fakequant_weights.add(sd_key)
-
-    # Filter quantizer tensors out for a clean HF checkpoint.
-    clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k}
+                with enable_weight_access_and_writeback(module, module):
+                    for sub_name, sub_mod in module.named_modules():
+                        full_name = f"{name}.{sub_name}" if sub_name else name
+                        _fakequant_module_weights(
+                            sub_mod,
+                            full_name,
+                            model,
+                            None,
+                            input_quantizers_folded_pqs,
+                            fakequant_weights,
+                            inplace=True,
+                        )
+            # Meta tensors for offloaded weights (free); offload maps now have
+            # fakequanted values via writeback.
+            state_dict = model.state_dict()
+        else:
+            # Default path: full state_dict copy, fakequant into the copy.
+            state_dict = model.state_dict()
+            for module_name, module in model.named_modules():
+                with enable_weight_access_and_writeback(module, model):
+                    _fakequant_module_weights(
+                        module,
+                        module_name,
+                        model,
+                        state_dict,
+                        input_quantizers_folded_pqs,
+                        fakequant_weights,
+                        inplace=False,
+                    )
+
+    if inplace_mem_efficient:
+        # Let save_pretrained build its own state_dict so offloaded params go through
+        # its module_map / get_state_dict_from_offload path (modeling_utils.py:3967+).
+        # Passing state_dict= bypasses that path and crashes on meta tensors.
+        quantizer_keys = [k for k in state_dict if "quantizer" in k]
+        clean_sd = None
+    else:
+        clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k}
+        quantizer_keys = None
 
     # Step 2: Disable weight quantizers, save modelopt state + quantizer state
     # dict, then re-enable. The _disabled=True flag is captured in modelopt_state
@@ -161,9 +233,18 @@ def export_hf_vllm_fq_checkpoint(
     modelopt_state["modelopt_state_weights"] = quantizer_state_dict
     torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")
 
-    # Step 3: Save HF weights using the pre-built folded state dict.
-    model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
-
-    for wq, orig_rotate in wqs_to_restore:
-        wq.enable()
-        wq._rotate = orig_rotate
+    # Step 3: Save HF weights.
+    if inplace_mem_efficient:
+        prev_ignore = getattr(model, "_keys_to_ignore_on_save", None)
+        model._keys_to_ignore_on_save = quantizer_keys
+        try:
+            model.save_pretrained(export_dir, save_modelopt_state=False)
+        finally:
+            model._keys_to_ignore_on_save = prev_ignore
+    else:
+        model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
+
+    if not inplace_mem_efficient:
+        for wq, orig_rotate in wqs_to_restore:
+            wq.enable()
+            wq._rotate = orig_rotate
```
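The folding identity relied on in `_fakequant_module_weights` above, `(x*s) @ fake_quant(W) = x @ (fake_quant(W)*s)`, valid only when the input quantizer is disabled so no nonlinearity sits between the scale and the matmul, can be checked numerically:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 4)    # batch of activations
W = torch.randn(3, 4)    # linear weight, shape [out_features, in_features]
s = torch.rand(4) + 0.5  # per-input-channel pre_quant_scale

# What the runtime would compute with the scale applied to activations:
scaled_input = (x * s) @ W.t()

# Equivalent result with the scale folded into the weight columns instead,
# matching `w_quant * scale[None, :]` in the exporter:
folded_weight = x @ (W * s[None, :]).t()

assert torch.allclose(scaled_input, folded_weight, atol=1e-5)
```

If the input quantizer were active, the runtime would compute `fake_quant(x*s) @ W.t()`, and the non-linearity of `fake_quant` breaks this equality, which is exactly why the exporter folds only when `inp_q._disabled` holds.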

modelopt/torch/quantization/config.py

Lines changed: 24 additions & 4 deletions
```diff
@@ -1217,16 +1217,36 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
         ),
     )
 
-    use_sequential: bool = ModeloptField(
+    layerwise: bool = ModeloptField(
         default=False,
-        title="Enable sequential layer-by-layer calibration.",
+        title="Enable layerwise (layer-by-layer) calibration.",
         description=(
-            "If True, the calibration algorithm is applied sequentially to each decoder block. "
-            "Each layer's inputs are captured via a single forward pass that reflects the "
+            "If True, the calibration algorithm is applied layer by layer. "
+            "Each layer's inputs are captured via a forward pass that reflects the "
             "quantization of all preceding layers, incurring O(N) forward passes for N layers."
         ),
     )
 
+    layerwise_checkpoint_dir: str | None = ModeloptField(
+        default=None,
+        title="Checkpoint directory for layerwise calibration.",
+        description=(
+            "If set together with layerwise=True, per-layer checkpoints are saved to this "
+            "directory during calibration. On restart, calibration resumes from the last "
+            "completed layer."
+        ),
+    )
+
+    @model_validator(mode="after")
+    def validate_layerwise_checkpoint_dir(self):
+        """Raise if layerwise_checkpoint_dir is set but layerwise is False."""
+        if self.layerwise_checkpoint_dir is not None and not self.layerwise:
+            raise ValueError(
+                "layerwise_checkpoint_dir requires layerwise=True. "
+                "Set layerwise=True or remove layerwise_checkpoint_dir."
+            )
+        return self
+
 
 class MaxCalibConfig(QuantizeAlgorithmConfig):
     """The config for max calibration algorithm.
```

0 commit comments