
Commit 8482ac0

Add layerwise calibration for large models
Introduces layerwise calibration to enable PTQ on models that do not fit in GPU memory, plus supporting infrastructure:

- New modelopt/torch/quantization/utils/layerwise_calib.py with layer-by-layer calibration and per-mode opt-out
- Disk offloading support in enable_weight_access_and_writeback
- Memory-efficient inplace fakequant export with disk offload
- Meta device detection in layerwise restore
- Fix meta tensor crash when exporting offloaded vLLM fakequant checkpoints
- Fix json.dumps sort_keys error with mixed int/str keys in quant_cfg
- Rename test_sequential_calibrate -> test_layerwise_calibrate (unit + gpu)
- Remove obsolete activation_collector.py

Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent dc7ad66 commit 8482ac0
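The new layerwise_calib.py module itself is not among the hunks shown below, so the following is only a toy sketch of the layer-by-layer scheme described in the commit message and in the `layerwise` config field: each layer is calibrated on inputs that already reflect the calibration of the preceding layers, and only one layer needs to be resident at a time. All names here are illustrative and not the module's actual API.

```python
import torch
import torch.nn as nn


def layerwise_calibrate(layers: nn.ModuleList, calib_inputs: list[torch.Tensor]) -> None:
    """Toy layer-by-layer calibration loop (illustrative only)."""
    for i, layer in enumerate(layers):
        # Capture layer i's inputs by running the already-calibrated prefix [0, i);
        # in the real feature the prefix additionally runs in its quantized form.
        captured = []
        with torch.no_grad():
            for x in calib_inputs:
                for prev in layers[:i]:
                    x = prev(x)
                captured.append(x)
        # Stand-in "calibration" of layer i on its captured inputs (e.g. a max amax).
        amax = torch.stack([x.abs().amax() for x in captured]).amax()
        layer.register_buffer("calib_amax", amax)


layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))
layerwise_calibrate(layers, [torch.randn(4, 8) for _ in range(2)])
print([layer.calib_amax.item() for layer in layers])
```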

27 files changed: 2345 additions & 581 deletions

examples/llm_ptq/example_utils.py

Lines changed: 33 additions & 0 deletions
@@ -15,6 +15,7 @@
 
 import copy
 import glob
+import hashlib
 import inspect
 import json
 import logging
@@ -854,3 +855,35 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod
         print(f"Successfully copied {len(copied_files)} custom model files to {export_path}")
     else:
         print("No custom model files found to copy")
+
+
+def needs_checkpoint_path_update(quant_cfg: dict) -> bool:
+    """Check if quant_cfg has a layerwise_checkpoint_dir that should be auto-resolved to a unique subpath."""
+    algorithm = quant_cfg.get("algorithm")
+    if not isinstance(algorithm, dict):
+        return False
+    return algorithm.get("layerwise_checkpoint_dir") is not None
+
+
+def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict:
+    """Append a unique ``<model_name>_<config_hash>`` subdirectory to layerwise_checkpoint_dir.
+
+    Allows a single recipe to be reused across models without checkpoint collisions.
+    Must only be called when :func:`needs_checkpoint_path_update` returns True.
+    """
+    algorithm = quant_cfg["algorithm"]
+    base_dir = algorithm["layerwise_checkpoint_dir"]
+
+    name = model_path.rstrip("/")
+    if "/" in name and not os.path.isabs(name):
+        name = name.replace("/", "--")
+    else:
+        name = Path(name).name
+
+    config_hash = hashlib.sha256(json.dumps(quant_cfg, default=str).encode()).hexdigest()[:8]
+
+    quant_cfg = copy.deepcopy(quant_cfg)
+    quant_cfg["algorithm"]["layerwise_checkpoint_dir"] = os.path.join(
+        base_dir, f"{name}_{config_hash}"
+    )
+    return quant_cfg
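As a usage sketch (not part of the diff): given a recipe whose algorithm section sets `layerwise_checkpoint_dir`, the two helpers above rewrite it to a per-model subdirectory. The surrounding keys (`method`, `layerwise`) are illustrative assumptions; the helpers only read `quant_cfg["algorithm"]["layerwise_checkpoint_dir"]`.

```python
# Usage sketch only — assumes the two helpers above are importable from example_utils.
from example_utils import needs_checkpoint_path_update, resolve_checkpoint_dir

quant_cfg = {
    "algorithm": {
        "method": "max",            # illustrative key
        "layerwise": True,          # illustrative key
        "layerwise_checkpoint_dir": "/tmp/layerwise_ckpts",
    },
}

if needs_checkpoint_path_update(quant_cfg):
    quant_cfg = resolve_checkpoint_dir(quant_cfg, "meta-llama/Llama-3.1-8B")
    # e.g. /tmp/layerwise_ckpts/meta-llama--Llama-3.1-8B_<8-char-config-hash>
    print(quant_cfg["algorithm"]["layerwise_checkpoint_dir"])
```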

examples/llm_ptq/hf_ptq.py

Lines changed: 14 additions & 3 deletions
@@ -34,6 +34,8 @@
     is_enc_dec,
     is_nemotron_vl,
     load_mtp_weights,
+    needs_checkpoint_path_update,
+    resolve_checkpoint_dir,
     run_nemotron_vl_preview,
 )
 from torch.utils.data import DataLoader
@@ -91,8 +93,9 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
     for i, entry in enumerate(quant_cfg):
         if entry.get("quantizer_name") != "*[kv]_bmm_quantizer":
             continue
-        assert isinstance(entry.get("cfg", {}), dict)
-        quant_cfg[i] = {**entry, "cfg": {**entry.get("cfg", {}), "use_constant_amax": True}}
+        cfg = entry.get("cfg") or {}
+        assert isinstance(cfg, dict)
+        quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}}
         break
 
 
@@ -760,7 +763,9 @@ def export_quantized(
     # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
     # Store the MTP layer prefixes on the model for later exclusion from quantization
     if args.vllm_fakequant_export:
-        export_hf_vllm_fq_checkpoint(full_model, export_dir=export_path)
+        export_hf_vllm_fq_checkpoint(
+            full_model, export_dir=export_path, inplace_mem_efficient=True
+        )
     else:
         mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(
             full_model, args.pyt_ckpt_path
@@ -1105,6 +1110,12 @@ def quantize_main(
     quant_cfg = copy.deepcopy(quant_cfg)
    _set_kv_cache_constant_amax(quant_cfg["quant_cfg"])
 
+    if needs_checkpoint_path_update(quant_cfg):
+        quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path)
+        print(
+            f"Auto-resolved layerwise_checkpoint_dir: {quant_cfg['algorithm']['layerwise_checkpoint_dir']}"
+        )
+
     if args.qformat in QUANT_CFG_CHOICES:
         mono_quantize(
             args,
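The cfg handling change in `_set_kv_cache_constant_amax` guards against recipe entries that carry an explicit `"cfg": None`: with the old `entry.get("cfg", {})`, the `{}` fallback only applies when the key is missing entirely. A small illustration with a hypothetical entry:

```python
entry = {"quantizer_name": "*[kv]_bmm_quantizer", "cfg": None}

old_cfg = entry.get("cfg", {})    # None: the default only applies when the key is absent
new_cfg = entry.get("cfg") or {}  # {}: also covers an explicit None value

assert old_cfg is None and new_cfg == {}
print({**entry, "cfg": {**new_cfg, "use_constant_amax": True}})
```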

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 135 additions & 54 deletions
@@ -24,6 +24,8 @@
 from modelopt.torch.quantization.conversion import quantizer_state
 from modelopt.torch.quantization.nn import QuantModule, TensorQuantizer
 from modelopt.torch.quantization.utils import get_quantizer_state_dict
+from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback
+from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector
 from modelopt.torch.utils import get_unwrapped_name
 
 __all__ = ["export_hf_vllm_fq_checkpoint"]
@@ -38,9 +40,75 @@ def disable_rotate(quantizer: TensorQuantizer):
     return False
 
 
+def _fakequant_module_weights(
+    module: nn.Module,
+    module_name: str,
+    model: nn.Module,
+    state_dict: dict | None,
+    input_quantizers_folded_pqs: set,
+    fakequant_weights: set,
+    inplace: bool,
+):
+    """Apply fake-quant to a single QuantModule's weights.
+
+    When ``inplace=False``, reads/writes weights from/to ``state_dict``.
+    When ``inplace=True``, modifies the module's weight parameters directly.
+    """
+    if not isinstance(module, QuantModule):
+        return
+    for attr_name, quantizer in module.named_children():
+        if not (
+            attr_name.endswith("weight_quantizer")
+            and isinstance(quantizer, TensorQuantizer)
+            and quantizer.fake_quant
+            and quantizer.is_enabled
+        ):
+            continue
+        weight_name = attr_name.removesuffix("_quantizer")
+        prefix = f"{module_name}." if module_name else ""
+        sd_key = f"{prefix}{weight_name}"
+        assert sd_key not in fakequant_weights, f"Weight {sd_key} has already been fakequantized"
+
+        if inplace:
+            w = getattr(module, weight_name)
+            w_quant = quantizer(w.float()).to(w.dtype)
+        else:
+            assert state_dict is not None
+            if sd_key not in state_dict:
+                continue
+            w = state_dict[sd_key]
+            w_quant = quantizer(w.float()).to(w.dtype)
+
+        # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
+        # Only valid when input_quantizer does NOT fake-quant activations. If it does
+        # fake_quant(x*s), the non-linearity prevents folding s into W.
+        inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
+        if hasattr(module, inp_attr):
+            inp_q = getattr(module, inp_attr)
+            if (
+                hasattr(inp_q, "_pre_quant_scale")
+                and inp_q._pre_quant_scale is not None
+                and inp_q._disabled
+            ):
+                scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
+                w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
+                inp_q_key = get_unwrapped_name(
+                    f"{module_name}.{inp_attr}" if module_name else inp_attr, model
+                )
+                input_quantizers_folded_pqs.add(inp_q_key)
+
+        if inplace:
+            w.data.copy_(w_quant)
+        else:
+            assert state_dict is not None
+            state_dict[sd_key] = w_quant.cpu()
+        fakequant_weights.add(sd_key)
+
+
 def export_hf_vllm_fq_checkpoint(
     model: nn.Module,
     export_dir: Path | str,
+    inplace_mem_efficient: bool = False,
 ):
     """Export quantized HF weights + ``vllm_fq_modelopt_state.pth`` for vLLM fake-quant reload.
 
@@ -53,62 +121,66 @@ def export_hf_vllm_fq_checkpoint(
     Args:
         model: In-memory quantized model.
        export_dir: Output dir for HF files and ``vllm_fq_modelopt_state.pth``.
+        inplace_mem_efficient: When True, applies fake-quant inplace one decoder layer at
+            a time using ``enable_weight_access_and_writeback``, avoiding full state
+            dict materialization. This is destructive — model weights are permanently
+            modified and weight quantizers are not re-enabled after export.
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
 
     # Step 1: Build the folded HF state dict.
-    # model.state_dict() returns detached copies of all tensors, so model
-    # parameters are never modified. Apply each weight quantizer's fake-quant
-    # to the corresponding weight tensor in the copy.
-    state_dict = model.state_dict()
     fakequant_weights = set()
-    input_quantizers_folded_pqs = (
-        set()
-    )  # keys for input_quantizers where pre_quant_scale was folded
+    input_quantizers_folded_pqs = set()
     with torch.inference_mode():
-        for module_name, module in model.named_modules():
-            if not isinstance(module, QuantModule):
-                continue
-            for attr_name, quantizer in module.named_children():
-                if not (
-                    attr_name.endswith("weight_quantizer")
-                    and isinstance(quantizer, TensorQuantizer)
-                    and quantizer.fake_quant
-                    and quantizer.is_enabled
-                ):
+        if inplace_mem_efficient:
+            # Inplace path: iterate decoder layers, one offload<->onload per layer.
+            decoder_layers = LayerActivationCollector.get_decoder_layers(model)
+            assert decoder_layers is not None, (
+                "inplace_mem_efficient=True requires a model with discoverable decoder layers"
+            )
+            for name, module in model.named_modules():
+                if module not in decoder_layers:
                     continue
-                weight_name = attr_name.removesuffix("_quantizer")
-                prefix = f"{module_name}." if module_name else ""
-                sd_key = f"{prefix}{weight_name}"
-                assert sd_key not in fakequant_weights, (
-                    f"Weight {sd_key} has already been fakequantized"
-                )
-                if sd_key in state_dict:
-                    w = state_dict[sd_key]
-                    w_quant = quantizer(w.float()).to(w.dtype).cpu()
-                    # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
-                    # Only valid when input_quantizer does NOT fake-quant activations. If it does
-                    # fake_quant(x*s), the non-linearity prevents folding s into W.
-                    inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
-                    if hasattr(module, inp_attr):
-                        inp_q = getattr(module, inp_attr)
-                        if (
-                            hasattr(inp_q, "_pre_quant_scale")
-                            and inp_q._pre_quant_scale is not None
-                            and inp_q._disabled
-                        ):
-                            scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
-                            w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
-                            inp_q_key = get_unwrapped_name(
-                                f"{module_name}.{inp_attr}" if module_name else inp_attr, model
-                            )
-                            input_quantizers_folded_pqs.add(inp_q_key)
-                    state_dict[sd_key] = w_quant
-                    fakequant_weights.add(sd_key)
-
-    # Filter quantizer tensors out for a clean HF checkpoint.
-    clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k}
+                with enable_weight_access_and_writeback(module, module):
+                    for sub_name, sub_mod in module.named_modules():
+                        full_name = f"{name}.{sub_name}" if sub_name else name
+                        _fakequant_module_weights(
+                            sub_mod,
+                            full_name,
+                            model,
+                            None,
+                            input_quantizers_folded_pqs,
+                            fakequant_weights,
+                            inplace=True,
+                        )
+            # Meta tensors for offloaded weights (free); offload maps now have
+            # fakequanted values via writeback.
+            state_dict = model.state_dict()
+        else:
+            # Default path: full state_dict copy, fakequant into the copy.
+            state_dict = model.state_dict()
+            for module_name, module in model.named_modules():
+                with enable_weight_access_and_writeback(module, model):
+                    _fakequant_module_weights(
+                        module,
+                        module_name,
+                        model,
+                        state_dict,
+                        input_quantizers_folded_pqs,
+                        fakequant_weights,
+                        inplace=False,
+                    )
+
+    if inplace_mem_efficient:
+        # Let save_pretrained build its own state_dict so offloaded params go through
+        # its module_map / get_state_dict_from_offload path (modeling_utils.py:3967+).
+        # Passing state_dict= bypasses that path and crashes on meta tensors.
+        quantizer_keys = [k for k in state_dict if "quantizer" in k]
+        clean_sd = None
+    else:
+        clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k}
+        quantizer_keys = None
 
     # Step 2: Disable weight quantizers, save modelopt state + quantizer state
     # dict, then re-enable. The _disabled=True flag is captured in modelopt_state
@@ -161,9 +233,18 @@ def export_hf_vllm_fq_checkpoint(
     modelopt_state["modelopt_state_weights"] = quantizer_state_dict
     torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")
 
-    # Step 3: Save HF weights using the pre-built folded state dict.
-    model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
-
-    for wq, orig_rotate in wqs_to_restore:
-        wq.enable()
-        wq._rotate = orig_rotate
+    # Step 3: Save HF weights.
+    if inplace_mem_efficient:
+        prev_ignore = getattr(model, "_keys_to_ignore_on_save", None)
+        model._keys_to_ignore_on_save = quantizer_keys
+        try:
+            model.save_pretrained(export_dir, save_modelopt_state=False)
+        finally:
+            model._keys_to_ignore_on_save = prev_ignore
+    else:
+        model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
+
+    if not inplace_mem_efficient:
+        for wq, orig_rotate in wqs_to_restore:
+            wq.enable()
+            wq._rotate = orig_rotate
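The pre_quant_scale fold in `_fakequant_module_weights` relies on the linear-layer identity `(x * s) @ W.T == x @ (W * s[None, :]).T`, which only holds because the (disabled) input quantizer no longer fake-quantizes `x * s`. A quick numeric check of that identity (standalone sketch, not part of the diff):

```python
import torch

# Weight shaped (out_features, in_features), as in nn.Linear; w stands in for fake_quant(W).
torch.manual_seed(0)
x = torch.randn(4, 16)          # activations
w = torch.randn(32, 16)         # fake-quantized weight
s = torch.rand(16) + 0.5        # pre_quant_scale over the input dimension

lhs = (x * s) @ w.t()           # scale applied to activations at runtime
rhs = x @ (w * s[None, :]).t()  # scale folded into the exported weight
assert torch.allclose(lhs, rhs, atol=1e-5)
```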

modelopt/torch/quantization/config.py

Lines changed: 24 additions & 4 deletions
@@ -1217,16 +1217,36 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
         ),
     )
 
-    use_sequential: bool = ModeloptField(
+    layerwise: bool = ModeloptField(
         default=False,
-        title="Enable sequential layer-by-layer calibration.",
+        title="Enable layerwise (layer-by-layer) calibration.",
         description=(
-            "If True, the calibration algorithm is applied sequentially to each decoder block. "
-            "Each layer's inputs are captured via a single forward pass that reflects the "
+            "If True, the calibration algorithm is applied layer by layer. "
+            "Each layer's inputs are captured via a forward pass that reflects the "
             "quantization of all preceding layers, incurring O(N) forward passes for N layers."
         ),
     )
 
+    layerwise_checkpoint_dir: str | None = ModeloptField(
+        default=None,
+        title="Checkpoint directory for layerwise calibration.",
+        description=(
+            "If set together with layerwise=True, per-layer checkpoints are saved to this "
+            "directory during calibration. On restart, calibration resumes from the last "
+            "completed layer."
+        ),
+    )
+
+    @model_validator(mode="after")
+    def validate_layerwise_checkpoint_dir(self):
+        """Raise if layerwise_checkpoint_dir is set but layerwise is False."""
+        if self.layerwise_checkpoint_dir is not None and not self.layerwise:
+            raise ValueError(
+                "layerwise_checkpoint_dir requires layerwise=True. "
+                "Set layerwise=True or remove layerwise_checkpoint_dir."
+            )
+        return self
+
 
 class MaxCalibConfig(QuantizeAlgorithmConfig):
     """The config for max calibration algorithm.
