Commit c50c4a7

realAsma and claude committed

Add layerwise calibration for large models

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>

1 parent 6a25fc2 | commit c50c4a7

6 files changed: 106 additions & 18 deletions

examples/llm_ptq/hf_ptq.py

Lines changed: 12 additions & 2 deletions
@@ -91,8 +91,9 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
     for i, entry in enumerate(quant_cfg):
         if entry.get("quantizer_name") != "*[kv]_bmm_quantizer":
             continue
-        assert isinstance(entry.get("cfg", {}), dict)
-        quant_cfg[i] = {**entry, "cfg": {**entry.get("cfg", {}), "use_constant_amax": True}}
+        cfg = entry.get("cfg") or {}
+        assert isinstance(cfg, dict)
+        quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}}
         break


@@ -1104,6 +1105,15 @@ def quantize_main(
     quant_cfg = copy.deepcopy(quant_cfg)
     _set_kv_cache_constant_amax(quant_cfg["quant_cfg"])

+    from modelopt.torch.quantization.utils.layerwise_calib import (
+        needs_checkpoint_path_update,
+        resolve_checkpoint_dir,
+    )
+
+    if needs_checkpoint_path_update(quant_cfg):
+        quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path)
+        print(f"Auto-resolved checkpoint_dir: {quant_cfg['algorithm']['checkpoint_dir']}")
+
     if args.qformat in QUANT_CFG_CHOICES:
         mono_quantize(
             args,
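Net effect in quantize_main: a recipe that hard-codes a single checkpoint_dir can be reused across models, because the directory is specialized per model and per config before quantization starts. A minimal usage sketch of the new helpers; the recipe contents, method name, and model id below are illustrative, not taken from a real run:

    from modelopt.torch.quantization.utils.layerwise_calib import (
        needs_checkpoint_path_update,
        resolve_checkpoint_dir,
    )

    quant_cfg = {
        "quant_cfg": {},  # illustrative; a real recipe carries quantizer entries here
        "algorithm": {"method": "max", "checkpoint_dir": "/tmp/layerwise_ckpts"},
    }

    if needs_checkpoint_path_update(quant_cfg):  # True: algorithm is a dict with checkpoint_dir set
        quant_cfg = resolve_checkpoint_dir(quant_cfg, "meta-llama/Llama-3.1-8B")
        print(quant_cfg["algorithm"]["checkpoint_dir"])
        # e.g. /tmp/layerwise_ckpts/meta-llama--Llama-3.1-8B_1a2b3c4d (hash suffix varies)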

modelopt/torch/quantization/mode.py

Lines changed: 1 addition & 0 deletions
@@ -239,6 +239,7 @@ def wrapped_calib_func(

     if func is not None:
         if layerwise:
+            # TODO: add a method guard here; not all calib methods support per-layer invocation
             if forward_loop is None:
                 raise ValueError("forward_loop is required for calibration but got None.")
             # Wrap with layerwise processing
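The TODO flags that the layerwise wrapper currently dispatches any calibration function per layer. One possible shape for such a guard, purely as a hypothetical sketch; the set name, its contents, and the method parameter are illustrative, not ModelOpt API:

    # Hypothetical sketch only; identifiers below are illustrative.
    LAYERWISE_COMPATIBLE_METHODS = {"max", "smoothquant"}  # assumed method names

    def _check_layerwise_supported(method: str) -> None:
        if method not in LAYERWISE_COMPATIBLE_METHODS:
            raise ValueError(
                f"Calibration method {method!r} does not support layerwise invocation."
            )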

modelopt/torch/quantization/model_calib.py

Lines changed: 4 additions & 3 deletions
@@ -28,7 +28,10 @@
 from tqdm import tqdm

 from modelopt.torch.opt.searcher import ForwardLoop
-from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector
+from modelopt.torch.quantization.utils.layerwise_calib import (
+    LayerActivationCollector,
+    _CheckpointState,
+)
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
@@ -1569,8 +1572,6 @@ def layerwise_calibrate(
     are saved after each layer completes. On restart, calibration resumes from
     the last completed layer.
     """
-    from modelopt.torch.quantization.utils.layerwise_calib import _CheckpointState
-
     checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None)

     if forward_loop is None:

modelopt/torch/quantization/plugins/accelerate.py

Lines changed: 2 additions & 3 deletions
@@ -34,9 +34,8 @@
 def _get_cpu_offload_hook(hook):
     if isinstance(hook, AlignDevicesHook) and hook.offload and hook.weights_map is not None:
         assert len(hook.weights_map) > 0
-        if (
-            isinstance(hook.weights_map, PrefixedDataset)
-            and hook.weights_map.prefix + "weight" not in hook.weights_map.dataset.state_dict
+        if isinstance(hook.weights_map, PrefixedDataset) and not any(
+            k.startswith(hook.weights_map.prefix) for k in hook.weights_map.dataset.state_dict
         ):
             raise NotImplementedError(
                 "This layer could be offloaded to disk. We don't support this yet."

modelopt/torch/quantization/utils/layerwise_calib.py

Lines changed: 66 additions & 1 deletion
@@ -33,6 +33,7 @@
 import torch
 import torch.nn as nn

+from modelopt.torch.utils import distributed as dist
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.network import (
     bind_forward_method,
@@ -77,6 +78,16 @@ def __init__(self, original: nn.Module):
         object.__setattr__(self, "_original", original)
         self._layerwise_calib = _LayerCalibState(mode="skip")

+    def __getattr__(self, name: str):
+        # Proxy non-special attribute lookups to the original layer so that
+        # parent-model code that accesses layer-level attributes (e.g.,
+        # NemotronH's ``block_type``) still works when the layer is replaced
+        # with a _SkipLayer.
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(object.__getattribute__(self, "_original"), name)
+
     def forward(self, *args, **kwargs):
         return LayerActivationCollector._zeros_from_meta(
             self._original._layerwise_calib.output_meta
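This override makes the skip wrapper transparent to attribute reads: registered parameters, buffers, and submodules resolve through nn.Module first, and anything else falls through to the wrapped layer. A minimal stand-in demonstrating the lookup order; these classes are illustrative, not the actual _SkipLayer:

    import torch.nn as nn

    class _Proxy(nn.Module):
        def __init__(self, original: nn.Module):
            super().__init__()
            object.__setattr__(self, "_original", original)  # bypass module registration

        def __getattr__(self, name: str):
            try:
                # nn.Module.__getattr__ resolves registered params/buffers/submodules
                return super().__getattr__(name)
            except AttributeError:
                # fall through to the wrapped layer (e.g., a block_type attribute)
                return getattr(object.__getattribute__(self, "_original"), name)

    class Block(nn.Module):
        block_type = "mamba"

    assert _Proxy(Block()).block_type == "mamba"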
@@ -315,7 +326,13 @@ def _log_layer_summary(self, layer_idx: int):
             mode = layer._layerwise_calib.mode
             if mode in ("skip", "run", "capture"):
                 groups.setdefault(mode, []).append(i + 1)
-        parts = [f"{mode}: {groups[mode]}" for mode in ("skip", "run", "capture") if mode in groups]
+
+        parts = []
+        for mode in ("skip", "run", "capture"):
+            if mode not in groups:
+                continue
+            ids = groups[mode]
+            parts.append(f"{mode}: {len(ids)}" if mode == "skip" else f"{mode}: {ids}")
         print_rank_0(f"Calibrating layer {layer_idx + 1}/{n} | {' | '.join(parts)}")

     @torch.no_grad()
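Skipped layers are now reported as a count rather than a full index list, which keeps the log line short for deep models while still enumerating the interesting run/capture layers. A standalone rendering of the same string logic with made-up groups:

    groups = {"skip": list(range(1, 29)), "run": [29], "capture": [30]}
    parts = []
    for mode in ("skip", "run", "capture"):
        if mode not in groups:
            continue
        ids = groups[mode]
        parts.append(f"{mode}: {len(ids)}" if mode == "skip" else f"{mode}: {ids}")
    print(f"Calibrating layer 30/48 | {' | '.join(parts)}")
    # Calibrating layer 30/48 | skip: 28 | run: [29] | capture: [30]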
@@ -489,6 +506,42 @@ def _save_layer(
         _write_manifest(checkpoint_dir, idx, num_layers)


+def needs_checkpoint_path_update(quant_cfg: dict) -> bool:
+    """Check if quant_cfg has a checkpoint_dir that should be auto-resolved to a unique subpath."""
+    algorithm = quant_cfg.get("algorithm")
+    if algorithm is None or isinstance(algorithm, str):
+        return False
+    return algorithm.get("checkpoint_dir") is not None
+
+
+def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict:
+    """Append a unique ``<model_name>_<config_hash>`` subdirectory to checkpoint_dir.
+
+    Allows a single recipe to be reused across models without checkpoint collisions.
+    Must only be called when :func:`needs_checkpoint_path_update` returns True.
+    """
+    import copy
+    import hashlib
+    from pathlib import Path
+
+    algorithm = quant_cfg["algorithm"]
+    base_dir = algorithm["checkpoint_dir"]
+
+    name = model_path.rstrip("/")
+    if "/" in name and not os.path.isabs(name):
+        name = name.replace("/", "--")
+    else:
+        name = Path(name).name
+
+    config_hash = hashlib.sha256(
+        json.dumps(quant_cfg, sort_keys=True, default=str).encode()
+    ).hexdigest()[:8]
+
+    quant_cfg = copy.deepcopy(quant_cfg)
+    quant_cfg["algorithm"]["checkpoint_dir"] = os.path.join(base_dir, f"{name}_{config_hash}")
+    return quant_cfg
+
+
 def detect_resume_point(checkpoint_dir: str) -> tuple[int, dict] | None:
     """Detect where to resume from an existing checkpoint directory.

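resolve_checkpoint_dir mangles the model path differently for Hugging Face-style ids and local filesystem paths. An isolated illustration of the two branches; the _mangle helper is named here for demonstration only:

    import os
    from pathlib import Path

    def _mangle(model_path: str) -> str:
        name = model_path.rstrip("/")
        if "/" in name and not os.path.isabs(name):
            return name.replace("/", "--")  # HF hub id: keep the org, one path segment
        return Path(name).name              # local path: keep only the leaf directory

    assert _mangle("meta-llama/Llama-3.1-8B") == "meta-llama--Llama-3.1-8B"
    assert _mangle("/data/models/llama-8b/") == "llama-8b"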
@@ -512,9 +565,21 @@ class _CheckpointState:

     Handles both saving per-layer checkpoints during calibration and
     restoring from a previous partial run.
+
+    .. todo::
+        Support distributed checkpoint save/restore for FSDP2:
+        use ``torch.distributed.checkpoint`` (or save only from rank 0 + barrier)
+        and broadcast restored state to all ranks during resume.
     """

     def __init__(self, checkpoint_dir: str, num_layers: int, start_layer: int = 0):
+        if dist.is_initialized() and dist.size() > 1:
+            raise RuntimeError(
+                "Layerwise calibration checkpointing is not supported in "
+                "multi-process distributed jobs (e.g. FSDP2). "
+                "Use single-process calibration or disable checkpointing."
+            )
+
         self.checkpoint_dir = checkpoint_dir
         self.num_layers = num_layers
         self.start_layer = start_layer
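The constructor guard refuses to checkpoint under multi-rank jobs until the FSDP2 TODO above is addressed, since exactly one process must own the checkpoint directory. Roughly the same check expressed directly against torch.distributed; ModelOpt's dist helpers wrap this, so the sketch is an approximation, not the committed code:

    import torch.distributed as torch_dist

    def _assert_single_process() -> None:
        # Mirrors the guard in _CheckpointState.__init__.
        if (
            torch_dist.is_available()
            and torch_dist.is_initialized()
            and torch_dist.get_world_size() > 1
        ):
            raise RuntimeError(
                "Layerwise calibration checkpointing is not supported in "
                "multi-process distributed jobs (e.g. FSDP2)."
            )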

modelopt/torch/utils/dataset_utils.py

Lines changed: 21 additions & 9 deletions
@@ -594,16 +594,28 @@ def _forward_loop(
         dataloader: DataLoader containing the batched input data
         allowed_non_tensor_keys: Set of key names whose values may be non-tensor types
     """
-    with torch.no_grad():
-        is_enc_dec = model_type_is_enc_dec(model)
-        infer_method = model.generate if is_enc_dec else model.forward
-        max_working_batch_size = None  # Initialize max working batch size as None
+    # Disable KV caching during calibration; it is unnecessary overhead and causes
+    # correctness issues with hybrid Mamba/attention models whose cache state is mutated
+    # in-place (e.g., NemotronH).
+    config = getattr(model, "config", None)
+    prev_use_cache = getattr(config, "use_cache", None)
+    if config is not None and prev_use_cache is not None:
+        config.use_cache = False

-        for _, data in enumerate(tqdm(dataloader)):
-            # Process batch and update max working batch size
-            max_working_batch_size = _process_batch(
-                data, infer_method, max_working_batch_size, allowed_non_tensor_keys
-            )
+    try:
+        with torch.no_grad():
+            is_enc_dec = model_type_is_enc_dec(model)
+            infer_method = model.generate if is_enc_dec else model.forward
+            max_working_batch_size = None  # Initialize max working batch size as None
+
+            for _, data in enumerate(tqdm(dataloader)):
+                # Process batch and update max working batch size
+                max_working_batch_size = _process_batch(
+                    data, infer_method, max_working_batch_size, allowed_non_tensor_keys
+                )
+    finally:
+        if config is not None and prev_use_cache is not None:
+            config.use_cache = prev_use_cache


 def create_forward_loop(
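The use_cache handling follows a stash/flip/restore pattern wrapped around the whole loop, so the model config is restored even when a batch raises. The same pattern in isolation; the config and model classes below are stand-ins for a transformers model, not part of the codebase:

    class _Cfg:  # stand-in for a transformers PretrainedConfig
        use_cache = True

    class _Model:  # stand-in model carrying a .config
        config = _Cfg()

    model = _Model()
    prev_use_cache = getattr(getattr(model, "config", None), "use_cache", None)
    if prev_use_cache is not None:
        model.config.use_cache = False
    try:
        pass  # calibration forward passes run here
    finally:
        if prev_use_cache is not None:
            model.config.use_cache = prev_use_cache

    assert model.config.use_cache is True  # restored even on error paths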
