Skip to content

Commit 8af3655

Browse files
realAsma authored and claude committed
Add checkpoint save/resume for sequential calibration of large models
Per-layer checkpoints allow sequential calibration to resume from the last completed layer after a crash or preemption. Also extends sequential calibration to work with FSDP2 and accelerate CPU-offloaded models. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent b6c6ec3 commit 8af3655

File tree

19 files changed

+1499
-408
lines changed

19 files changed

+1499
-408
lines changed

modelopt/torch/quantization/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,16 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
12271227
),
12281228
)
12291229

1230+
checkpoint_dir: str | None = ModeloptField(
1231+
default=None,
1232+
title="Checkpoint directory for sequential calibration.",
1233+
description=(
1234+
"If set together with use_sequential=True, per-layer checkpoints are saved to this "
1235+
"directory during calibration. On restart, calibration resumes from the last "
1236+
"completed layer."
1237+
),
1238+
)
1239+
12301240

12311241
class MaxCalibConfig(QuantizeAlgorithmConfig):
12321242
"""The config for max calibration algorithm.

modelopt/torch/quantization/mode.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def wrapped_calib_func(
223223
kwargs = config.model_dump()
224224
method = kwargs.pop("method")
225225
sequential = kwargs.pop("use_sequential", False)
226+
checkpoint_dir = kwargs.pop("checkpoint_dir", None)
226227
if method is not None and "awq" in method:
227228
# For backward compatibility
228229
kwargs["algorithm"] = method
@@ -240,14 +241,12 @@ def wrapped_calib_func(
240241
if sequential:
241242
if forward_loop is None:
242243
raise ValueError("forward_loop is required for calibration but got None.")
243-
assert method in ["max", "gptq"], (
244-
f"Sequential calibration currently only supports max and gptq calibration, got {method}"
245-
)
246244
# Wrap with sequential processing
247245
sequential_calibrate(
248246
model,
249247
forward_loop=forward_loop,
250248
calib_func=func,
249+
checkpoint_dir=checkpoint_dir,
251250
**kwargs,
252251
)
253252
else:

modelopt/torch/quantization/model_calib.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from tqdm import tqdm
2929

3030
from modelopt.torch.opt.searcher import ForwardLoop
31-
from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector
31+
from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector
3232
from modelopt.torch.utils import print_rank_0
3333
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
3434
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
@@ -1563,7 +1563,15 @@ def sequential_calibrate(
15631563
Runs the full model forward per layer but patches decoder layers with a
15641564
skip / run / capture strategy so that inter-layer logic in parent modules
15651565
(e.g. mask construction) executes naturally without model-specific hooks.
1566+
1567+
If ``checkpoint_dir`` is passed (via ``calib_kwargs``), per-layer checkpoints
1568+
are saved after each layer completes. On restart, calibration resumes from
1569+
the last completed layer.
15661570
"""
1571+
from modelopt.torch.quantization.utils.layerwise_calib import _CheckpointState
1572+
1573+
checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None)
1574+
15671575
if forward_loop is None:
15681576
raise ValueError(
15691577
"forward_loop must not be None for sequential calibration. "
@@ -1577,27 +1585,52 @@ def sequential_calibrate(
15771585
"Sequential calibration requires a model with identifiable transformer layers."
15781586
)
15791587

1580-
print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers")
1588+
num_layers = len(transformer_layers)
1589+
print_rank_0(f"Sequential calibration: Found {num_layers} transformer layers")
1590+
1591+
ckpt = _CheckpointState.from_folder(checkpoint_dir, num_layers)
1592+
start_layer = ckpt.start_layer if ckpt else 0
15811593

15821594
input_getter = LayerActivationCollector(model)
15831595
input_getter._patch_all_layers(decoder_layers=transformer_layers)
15841596

1597+
resumed_inputs = ckpt.setup_resume(transformer_layers) if ckpt and start_layer > 0 else None
1598+
15851599
try:
1586-
for layer_idx, layer in enumerate(transformer_layers):
1587-
print_rank_0(f"Calibrating layer {layer_idx + 1}/{len(transformer_layers)}")
1588-
layer_inputs = input_getter.get_input_activations(layer, forward_loop)
1600+
# Bootstrap: get first layer's inputs (or use resumed inputs).
1601+
layer_inputs = input_getter.get_first_layer_inputs(
1602+
start_layer, resumed_inputs, forward_loop
1603+
)
15891604

1590-
def _layer_forward_loop(m, _inputs=layer_inputs):
1591-
for args, kwargs_input in _inputs:
1605+
for layer_idx in range(start_layer, num_layers):
1606+
layer = transformer_layers[layer_idx]
1607+
1608+
def _layer_forward_loop(m):
1609+
for args, kwargs_input in layer_inputs:
15921610
m(*args, **kwargs_input)
15931611

15941612
calib_func(layer, _layer_forward_loop, **calib_kwargs)
15951613

1614+
# Run one more forward to get next layer's inputs and set
1615+
# output_meta on the just-calibrated layer (via "run" mode).
1616+
is_last = layer_idx + 1 >= num_layers
1617+
if not is_last:
1618+
next_inputs = input_getter.cache_outputs_for_next_layer_calib(layer, forward_loop)
1619+
else:
1620+
next_inputs = None
1621+
1622+
if ckpt:
1623+
ckpt.save(layer_idx, layer, model, transformer_layers, next_inputs)
1624+
15961625
del layer_inputs
15971626
torch.cuda.empty_cache()
1627+
layer_inputs = next_inputs
15981628
finally:
15991629
input_getter._unpatch_all_layers()
16001630

1631+
if ckpt:
1632+
ckpt.full_restore(transformer_layers, model)
1633+
16011634
print_rank_0("Sequential calibration completed")
16021635

16031636

@@ -1663,8 +1696,10 @@ def gptq(
16631696
handle.cleanup()
16641697

16651698
print_rank_0("Updating weights using GPTQ algorithm...")
1699+
name_to_module = dict(model.named_modules())
16661700
for handle in gptq_handles.values():
1667-
handle.update_weights(block_size, perc_damp)
1701+
with enable_weight_access_and_writeback(handle.module, model, name_to_module):
1702+
handle.update_weights(block_size, perc_damp)
16681703
handle.free()
16691704
del gptq_handles
16701705

modelopt/torch/quantization/plugins/accelerate.py

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
def _get_cpu_offload_hook(hook):
3535
if isinstance(hook, AlignDevicesHook) and hook.offload and hook.weights_map is not None:
36-
assert "weight" in hook.weights_map
36+
assert len(hook.weights_map) > 0
3737
if (
3838
isinstance(hook.weights_map, PrefixedDataset)
3939
and hook.weights_map.prefix + "weight" not in hook.weights_map.dataset.state_dict
@@ -50,32 +50,79 @@ def _get_cpu_offload_hook(hook):
5050
return None
5151

5252

53+
def _writeback_params_to_weights_map(module, align_hook):
54+
"""Write all non-meta parameters back to the hook's CPU weights_map."""
55+
for name, param in module.named_parameters():
56+
if param.device.type == "meta":
57+
continue
58+
if isinstance(align_hook.weights_map, PrefixedDataset):
59+
key = align_hook.weights_map.prefix + name
60+
w_map = align_hook.weights_map.dataset.state_dict
61+
else:
62+
w_map = align_hook.weights_map
63+
key = name
64+
if key in w_map:
65+
w_map[key] = param.data.to(w_map[key].device, dtype=w_map[key].dtype)
66+
67+
5368
@contextmanager
5469
def weight_access_and_writeback_context(module):
55-
"""Context manager for weight access and writeback for modules managed by accelerate."""
70+
"""Context manager for weight access and writeback for modules managed by accelerate.
71+
72+
Handles two cases:
73+
1. **Single-module**: the module's own ``_hf_hook`` is an offload hook.
74+
2. **Sub-module**: the module's hook is non-offloading, but its children have
75+
offload hooks (common with ``SequentialHook`` on sub-modules placed by
76+
``load_checkpoint_and_dispatch``).
77+
78+
For the sub-module case, ``pre_forward`` is skipped on sub-modules whose weights
79+
are already materialized (not on meta). This allows the context manager to be
80+
used as a pure writeback after weight-modifying algorithms.
81+
"""
5682
assert hasattr(module, "_hf_hook")
5783
align_hook = _get_cpu_offload_hook(module._hf_hook)
5884

5985
if align_hook:
60-
# Accelerate uses AlignDevicesHook to offload weights to CPU/Disk and then reload them in the forward pass
61-
# The CPU/Disk offloaded weights are managed by PrefixDataset and OffloadedWeightsLoader
62-
# See https://github.com/huggingface/accelerate/blame/f48d95c4939b281505a45b3d6e0bf554b65cc1ea/src/accelerate/utils/offload.py#L104-L141
63-
# TODO: Add support for disk-offloaded models if needed (they will be really slow, hence low priority)
64-
65-
# This will load the weights from CPU state_dict and move it to the GPU from meta device
86+
# Guard: the sub-module branch below is not reached when the parent has
87+
# an offload hook. Assert that no children also carry offload hooks,
88+
# which would require a combined writeback strategy.
89+
assert not any(
90+
_get_cpu_offload_hook(mod._hf_hook)
91+
for mod in module.modules()
92+
if mod is not module and hasattr(mod, "_hf_hook")
93+
), (
94+
"Both the module and one of its sub-modules have CPU-offload hooks. "
95+
"weight_access_and_writeback_context does not support this layout yet."
96+
)
6697
align_hook.pre_forward(module)
98+
try:
99+
yield
100+
finally:
101+
_writeback_params_to_weights_map(module, align_hook)
102+
align_hook.post_forward(module, None)
103+
return
104+
105+
materialized: list[tuple[torch.nn.Module, AlignDevicesHook, bool]] = []
106+
for mod in module.modules():
107+
if mod is module or not hasattr(mod, "_hf_hook"):
108+
continue
109+
hook = _get_cpu_offload_hook(mod._hf_hook)
110+
if hook is None:
111+
continue
112+
# Only call pre_forward if weights need materializing; already-materialized
113+
# weights would be overwritten with stale CPU state_dict values.
114+
needs_materialize = any(p.device.type == "meta" for p in mod.parameters())
115+
if needs_materialize:
116+
hook.pre_forward(mod)
117+
materialized.append((mod, hook, needs_materialize))
118+
67119
try:
68120
yield
69121
finally:
70-
if align_hook:
71-
# Update the weight in the CPU state_dict
72-
if isinstance(align_hook.weights_map, PrefixedDataset):
73-
key = align_hook.weights_map.prefix + "weight"
74-
w_map = align_hook.weights_map.dataset.state_dict
75-
else:
76-
key, w_map = "weight", align_hook.weights_map
77-
w_map[key] = module.weight.data.to(w_map[key].device, dtype=w_map[key].dtype)
78-
align_hook.post_forward(module, None)
122+
for mod, hook, was_materialized in materialized:
123+
_writeback_params_to_weights_map(mod, hook)
124+
if was_materialized:
125+
hook.post_forward(mod, None)
79126

80127

81128
@contextmanager

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from ..nn.modules.quant_linear import _QuantLinear
4040
from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE
4141
from ..utils import replace_function, sync_moe_expert_amax
42-
from ..utils.activation_collector import LayerActivationCollector
42+
from ..utils.layerwise_calib import LayerActivationCollector
4343
from .attention import register_attention_for_kv_quant
4444
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin
4545

modelopt/torch/quantization/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
# ruff: noqa: F405
1717
"""Quantization utilities."""
1818

19-
from .activation_collector import LayerActivationCollector
2019
from .core_utils import *
20+
from .layerwise_calib import LayerActivationCollector
2121

2222
__all__ = [
2323
"EXPORT_MODE",

0 commit comments

Comments
 (0)