Skip to content

Commit e0cda1b

Browse files
realAsma and claude committed
Add layerwise calibration for large models
This PR does three things: 1. Rename sequential_calibrate to layerwise_calibrate to better describe the layer-by-layer algorithm (use_sequential -> use_layerwise, _seq_calib -> _layerwise_calib). 2. Make layerwise calibration performant: persistent_materialization keeps the active layer on GPU for the entire calibration step, and _SkipLayer replaces fully-calibrated layers with parameter-free dummies so framework hooks (accelerate, FSDP2) skip materialization. 3. Add checkpoint save/resume so calibration of large models can be interrupted and restarted from the last completed layer. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: realAsma <akuriparambi@nvidia.com> Add layerwise calibration for large models Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: realAsma <akuriparambi@nvidia.com> Move checkpoint_dir helpers from library to examples/llm_ptq Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: realAsma <akuriparambi@nvidia.com> Rename layerwise config fields and enable layerwise on experts-only recipe - use_layerwise -> layerwise, checkpoint_dir -> layerwise_checkpoint_dir - Enable layerwise calibration + checkpointing on nvfp4_experts_only-fp8_kv recipe - Add layerwise_checkpoint_dir to nvfp4_default-none_kv_gptq recipe Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: realAsma <akuriparambi@nvidia.com> Address PR review feedback for layerwise calibration - Add inline security comments for all torch.load(weights_only=False) calls - Replace bare assert with RuntimeError for unsupported offload hook layout - Write back buffers (not just parameters) in _writeback_params_to_weights_map - Add cross-field validator rejecting layerwise_checkpoint_dir without layerwise=True - Validate num_layers mismatch on checkpoint resume - Handle integer device ordinals in _get_execution_device_from_hook - Clean up stale layer artifacts in 
partial-checkpoint tests - Guard non-dict algorithm values in needs_checkpoint_path_update - Add comment explaining dummy output_meta for last layer Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent 361f7e3 commit e0cda1b

26 files changed

+1915
-503
lines changed

examples/llm_ptq/example_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import copy
1717
import glob
18+
import hashlib
1819
import inspect
1920
import json
2021
import logging
@@ -854,3 +855,37 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod
854855
print(f"Successfully copied {len(copied_files)} custom model files to {export_path}")
855856
else:
856857
print("No custom model files found to copy")
858+
859+
860+
def needs_checkpoint_path_update(quant_cfg: dict) -> bool:
    """Check if quant_cfg has a layerwise_checkpoint_dir that should be auto-resolved to a unique subpath."""
    # A non-dict "algorithm" entry (e.g. a plain method string) carries no
    # checkpoint settings, so there is nothing to resolve.
    algorithm = quant_cfg.get("algorithm")
    return isinstance(algorithm, dict) and algorithm.get("layerwise_checkpoint_dir") is not None
866+
867+
868+
def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict:
    """Append a unique ``<model_name>_<config_hash>`` subdirectory to layerwise_checkpoint_dir.

    Allows a single recipe to be reused across models without checkpoint collisions.
    Must only be called when :func:`needs_checkpoint_path_update` returns True.
    """
    base_dir = quant_cfg["algorithm"]["layerwise_checkpoint_dir"]

    # Derive a filesystem-safe model name: HF-style "org/model" ids become
    # "org--model", while local (absolute or single-component) paths collapse
    # to their final path component.
    trimmed = model_path.rstrip("/")
    if "/" in trimmed and not os.path.isabs(trimmed):
        model_name = trimmed.replace("/", "--")
    else:
        model_name = Path(trimmed).name

    # Hash the *unresolved* config so the suffix is deterministic across runs
    # of the same recipe.
    serialized = json.dumps(quant_cfg, sort_keys=True, default=str).encode()
    config_hash = hashlib.sha256(serialized).hexdigest()[:8]

    # Return an updated deep copy; the caller's config is left untouched.
    resolved = copy.deepcopy(quant_cfg)
    resolved["algorithm"]["layerwise_checkpoint_dir"] = os.path.join(
        base_dir, f"{model_name}_{config_hash}"
    )
    return resolved

examples/llm_ptq/hf_ptq.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
is_enc_dec,
3535
is_nemotron_vl,
3636
load_mtp_weights,
37+
needs_checkpoint_path_update,
38+
resolve_checkpoint_dir,
3739
run_nemotron_vl_preview,
3840
)
3941
from torch.utils.data import DataLoader
@@ -91,8 +93,9 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
9193
for i, entry in enumerate(quant_cfg):
9294
if entry.get("quantizer_name") != "*[kv]_bmm_quantizer":
9395
continue
94-
assert isinstance(entry.get("cfg", {}), dict)
95-
quant_cfg[i] = {**entry, "cfg": {**entry.get("cfg", {}), "use_constant_amax": True}}
96+
cfg = entry.get("cfg") or {}
97+
assert isinstance(cfg, dict)
98+
quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}}
9699
break
97100

98101

@@ -1104,6 +1107,12 @@ def quantize_main(
11041107
quant_cfg = copy.deepcopy(quant_cfg)
11051108
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])
11061109

1110+
if needs_checkpoint_path_update(quant_cfg):
1111+
quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path)
1112+
print(
1113+
f"Auto-resolved layerwise_checkpoint_dir: {quant_cfg['algorithm']['layerwise_checkpoint_dir']}"
1114+
)
1115+
11071116
if args.qformat in QUANT_CFG_CHOICES:
11081117
mono_quantize(
11091118
args,

modelopt/torch/quantization/config.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,16 +1217,36 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
12171217
),
12181218
)
12191219

1220-
use_sequential: bool = ModeloptField(
1220+
layerwise: bool = ModeloptField(
12211221
default=False,
1222-
title="Enable sequential layer-by-layer calibration.",
1222+
title="Enable layerwise (layer-by-layer) calibration.",
12231223
description=(
1224-
"If True, the calibration algorithm is applied sequentially to each decoder block. "
1224+
"If True, the calibration algorithm is applied to each decoder layer independently. "
12251225
"Each layer's inputs are captured via a single forward pass that reflects the "
12261226
"quantization of all preceding layers, incurring O(N) forward passes for N layers."
12271227
),
12281228
)
12291229

1230+
layerwise_checkpoint_dir: str | None = ModeloptField(
1231+
default=None,
1232+
title="Checkpoint directory for layerwise calibration.",
1233+
description=(
1234+
"If set together with layerwise=True, per-layer checkpoints are saved to this "
1235+
"directory during calibration. On restart, calibration resumes from the last "
1236+
"completed layer."
1237+
),
1238+
)
1239+
1240+
@model_validator(mode="after")
1241+
def validate_layerwise_checkpoint_dir(self):
1242+
"""Raise if layerwise_checkpoint_dir is set but layerwise is False."""
1243+
if self.layerwise_checkpoint_dir is not None and not self.layerwise:
1244+
raise ValueError(
1245+
"layerwise_checkpoint_dir requires layerwise=True. "
1246+
"Set layerwise=True or remove layerwise_checkpoint_dir."
1247+
)
1248+
return self
1249+
12301250

12311251
class MaxCalibConfig(QuantizeAlgorithmConfig):
12321252
"""The config for max calibration algorithm.

modelopt/torch/quantization/mode.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@
6060
from .model_calib import (
6161
awq,
6262
gptq,
63+
layerwise_calibrate,
6364
local_hessian_calibrate,
6465
max_calibrate,
6566
mse_calibrate,
66-
sequential_calibrate,
6767
smoothquant,
6868
svdquant,
6969
)
@@ -222,7 +222,8 @@ def wrapped_calib_func(
222222
"""
223223
kwargs = config.model_dump()
224224
method = kwargs.pop("method")
225-
sequential = kwargs.pop("use_sequential", False)
225+
layerwise = kwargs.pop("layerwise", False)
226+
checkpoint_dir = kwargs.pop("layerwise_checkpoint_dir", None)
226227
if method is not None and "awq" in method:
227228
# For backward compatibility
228229
kwargs["algorithm"] = method
@@ -237,17 +238,16 @@ def wrapped_calib_func(
237238
module._moe_calib_experts_ratio = moe_calib_experts_ratio
238239

239240
if func is not None:
240-
if sequential:
241+
if layerwise:
242+
# TODO: add a method guard here — not all calib methods support per-layer invocation
241243
if forward_loop is None:
242244
raise ValueError("forward_loop is required for calibration but got None.")
243-
assert method in ["max", "gptq"], (
244-
f"Sequential calibration currently only supports max and gptq calibration, got {method}"
245-
)
246-
# Wrap with sequential processing
247-
sequential_calibrate(
245+
# Wrap with layerwise processing
246+
layerwise_calibrate(
248247
model,
249248
forward_loop=forward_loop,
250249
calib_func=func,
250+
checkpoint_dir=checkpoint_dir,
251251
**kwargs,
252252
)
253253
else:

modelopt/torch/quantization/model_calib.py

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@
2828
from tqdm import tqdm
2929

3030
from modelopt.torch.opt.searcher import ForwardLoop
31-
from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector
31+
from modelopt.torch.quantization.utils.layerwise_calib import (
32+
LayerActivationCollector,
33+
_CheckpointState,
34+
)
3235
from modelopt.torch.utils import print_rank_0
3336
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
3437
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
@@ -44,6 +47,7 @@
4447
is_quantized_column_parallel_linear,
4548
is_quantized_linear,
4649
is_quantized_row_parallel_linear,
50+
persistent_materialization,
4751
promote_nvfp4_static_quantizers,
4852
quantizer_attr_names,
4953
reduce_amax,
@@ -53,9 +57,9 @@
5357

5458
__all__ = [
5559
"awq",
60+
"layerwise_calibrate",
5661
"local_hessian_calibrate",
5762
"max_calibrate",
58-
"sequential_calibrate",
5963
"smoothquant",
6064
"svdquant",
6165
]
@@ -1552,53 +1556,85 @@ def postprocess(module, name):
15521556

15531557

15541558
@torch.no_grad()
def layerwise_calibrate(
    model: nn.Module,
    forward_loop: ForwardLoop,
    calib_func: Callable,
    **calib_kwargs,
):
    """Layerwise calibration - a layer-by-layer calibration algorithm.

    Runs the full model forward per layer but patches decoder layers with a
    skip / run / capture strategy so that inter-layer logic in parent modules
    (e.g. mask construction) executes naturally without model-specific hooks.

    If ``checkpoint_dir`` is passed (via ``calib_kwargs``), per-layer checkpoints
    are saved after each layer completes. On restart, calibration resumes from
    the last completed layer.

    Args:
        model: The full model containing identifiable decoder/transformer layers.
        forward_loop: Callable that replays the calibration dataset through the model.
        calib_func: Calibration algorithm applied to each layer in turn.
        **calib_kwargs: Extra keyword arguments forwarded to ``calib_func``;
            ``checkpoint_dir`` is consumed here and not forwarded.

    Raises:
        ValueError: If ``forward_loop`` is None or no transformer layers are found.
    """
    # checkpoint_dir travels through calib_kwargs so calib_func never sees it.
    checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None)

    if forward_loop is None:
        raise ValueError(
            "forward_loop must not be None for layerwise calibration. "
            "Please provide a valid forward_loop callable."
        )

    transformer_layers = LayerActivationCollector.get_decoder_layers(model)
    if transformer_layers is None or len(transformer_layers) == 0:
        raise ValueError(
            "Could not find transformer layers in model. "
            "Layerwise calibration requires a model with identifiable transformer layers."
        )

    num_layers = len(transformer_layers)
    print_rank_0(f"Layerwise calibration: Found {num_layers} transformer layers")

    # Resume point: _CheckpointState.from_folder returns None (falsy) when
    # checkpointing is disabled or no checkpoint exists.
    ckpt = _CheckpointState.from_folder(checkpoint_dir, num_layers)
    start_layer = ckpt.start_layer if ckpt else 0

    input_getter = LayerActivationCollector(model)
    input_getter._patch_all_layers(decoder_layers=transformer_layers)

    resumed_inputs = ckpt.setup_resume(transformer_layers) if ckpt and start_layer > 0 else None

    try:
        # Bootstrap: get first layer's inputs (or use resumed inputs).
        layer_inputs = input_getter.get_first_layer_inputs(
            start_layer, resumed_inputs, forward_loop
        )

        for layer_idx in range(start_layer, num_layers):
            layer = transformer_layers[layer_idx]
            # Progress log restored from the pre-rename implementation:
            # layerwise runs can take hours on large models.
            print_rank_0(f"Calibrating layer {layer_idx + 1}/{num_layers}")

            # Default-arg binding captures the *current* layer_inputs; a plain
            # closure would see the rebound variable of a later iteration.
            def _layer_forward_loop(m, _inputs=layer_inputs):
                for args, kwargs_input in _inputs:
                    m(*args, **kwargs_input)

            # Keep the active layer materialized on-device for the whole
            # calibration step instead of per-forward offload churn.
            with persistent_materialization(layer):
                calib_func(layer, _layer_forward_loop, **calib_kwargs)

            # Run one more forward to get next layer's inputs and set
            # output_meta on the just-calibrated layer (via "run" mode).
            is_last = layer_idx + 1 >= num_layers
            if not is_last:
                next_inputs = input_getter.cache_outputs_for_next_layer_calib(layer, forward_loop)
            else:
                next_inputs = None

            if ckpt:
                ckpt.save(layer_idx, layer, model, transformer_layers, next_inputs)

            # Free captured activations before the next layer's capture pass.
            del layer_inputs
            torch.cuda.empty_cache()
            layer_inputs = next_inputs  # noqa: F841 (used in next iteration's closure)
    finally:
        input_getter._unpatch_all_layers()

    if ckpt:
        ckpt.full_restore(transformer_layers, model)

    print_rank_0("Layerwise calibration completed")
16021638

16031639

16041640
@torch.no_grad()
@@ -1610,12 +1646,12 @@ def gptq(
16101646
):
16111647
"""GPTQ quantization.
16121648
1613-
Works in two modes depending on ``use_sequential`` in the config:
1649+
Works in two modes depending on ``layerwise`` in the config:
16141650
1615-
* **Sequential** (``use_sequential=True``): ``sequential_calibrate`` calls this
1651+
* **Layerwise** (``layerwise=True``): ``layerwise_calibrate`` calls this
16161652
function once per decoder layer with updated activations, producing more
16171653
accurate Hessian estimates.
1618-
* **Non-sequential** (``use_sequential=False``): called once on the full model.
1654+
* **Non-layerwise** (``layerwise=False``): called once on the full model.
16191655
All layers are quantized in parallel from the original activations.
16201656
16211657
Per-module steps:
@@ -1628,7 +1664,7 @@ def gptq(
16281664
16291665
Args:
16301666
model: The module to quantize — either the full model or a single decoder
1631-
layer when invoked by ``sequential_calibrate``.
1667+
layer when invoked by ``layerwise_calibrate``.
16321668
forward_loop: Callable that replays calibration inputs through *model*.
16331669
perc_damp: Percentage of avg Hessian diagonal for damping (default: 0.01).
16341670
block_size: Block size for GPTQ weight update.
@@ -1663,8 +1699,10 @@ def gptq(
16631699
handle.cleanup()
16641700

16651701
print_rank_0("Updating weights using GPTQ algorithm...")
1702+
name_to_module = dict(model.named_modules())
16661703
for handle in gptq_handles.values():
1667-
handle.update_weights(block_size, perc_damp)
1704+
with enable_weight_access_and_writeback(handle.module, model, name_to_module):
1705+
handle.update_weights(block_size, perc_damp)
16681706
handle.free()
16691707
del gptq_handles
16701708

0 commit comments

Comments
 (0)