Skip to content

Commit 8eabe76

Browse files
realAsma and claude committed
Add checkpoint save/resume for sequential calibration of large models
Per-layer checkpoints allow sequential calibration to resume from the last completed layer after a crash or preemption. Also extends sequential calibration to work with FSDP2 and accelerate CPU-offloaded models. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent 14b78ae commit 8eabe76

File tree

16 files changed

+1196
-102
lines changed

16 files changed

+1196
-102
lines changed

modelopt/torch/quantization/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,16 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
12271227
),
12281228
)
12291229

1230+
checkpoint_dir: str | None = ModeloptField(
1231+
default=None,
1232+
title="Checkpoint directory for sequential calibration.",
1233+
description=(
1234+
"If set together with use_sequential=True, per-layer checkpoints are saved to this "
1235+
"directory during calibration. On restart, calibration resumes from the last "
1236+
"completed layer."
1237+
),
1238+
)
1239+
12301240

12311241
class MaxCalibConfig(QuantizeAlgorithmConfig):
12321242
"""The config for max calibration algorithm.

modelopt/torch/quantization/mode.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def wrapped_calib_func(
223223
kwargs = config.model_dump()
224224
method = kwargs.pop("method")
225225
sequential = kwargs.pop("use_sequential", False)
226+
checkpoint_dir = kwargs.pop("checkpoint_dir", None)
226227
if method is not None and "awq" in method:
227228
# For backward compatibility
228229
kwargs["algorithm"] = method
@@ -240,14 +241,12 @@ def wrapped_calib_func(
240241
if sequential:
241242
if forward_loop is None:
242243
raise ValueError("forward_loop is required for calibration but got None.")
243-
assert method in ["max", "gptq"], (
244-
f"Sequential calibration currently only supports max and gptq calibration, got {method}"
245-
)
246244
# Wrap with sequential processing
247245
sequential_calibrate(
248246
model,
249247
forward_loop=forward_loop,
250248
calib_func=func,
249+
checkpoint_dir=checkpoint_dir,
251250
**kwargs,
252251
)
253252
else:

modelopt/torch/quantization/model_calib.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from tqdm import tqdm
2929

3030
from modelopt.torch.opt.searcher import ForwardLoop
31-
from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector
31+
from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector
3232
from modelopt.torch.utils import print_rank_0
3333
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
3434
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
@@ -1563,7 +1563,15 @@ def sequential_calibrate(
15631563
Runs the full model forward per layer but patches decoder layers with a
15641564
skip / run / capture strategy so that inter-layer logic in parent modules
15651565
(e.g. mask construction) executes naturally without model-specific hooks.
1566+
1567+
If ``checkpoint_dir`` is passed (via ``calib_kwargs``), per-layer checkpoints
1568+
are saved after each layer completes. On restart, calibration resumes from
1569+
the last completed layer.
15661570
"""
1571+
from modelopt.torch.quantization.utils.layerwise_calib import _CheckpointState
1572+
1573+
checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None)
1574+
15671575
if forward_loop is None:
15681576
raise ValueError(
15691577
"forward_loop must not be None for sequential calibration. "
@@ -1577,30 +1585,77 @@ def sequential_calibrate(
15771585
"Sequential calibration requires a model with identifiable transformer layers."
15781586
)
15791587

1580-
print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers")
1588+
num_layers = len(transformer_layers)
1589+
print_rank_0(f"Sequential calibration: Found {num_layers} transformer layers")
1590+
1591+
ckpt = _CheckpointState.from_folder(checkpoint_dir, num_layers)
1592+
start_layer = ckpt.start_layer if ckpt else 0
15811593

15821594
input_getter = LayerActivationCollector(model)
15831595
input_getter._patch_all_layers(decoder_layers=transformer_layers)
15841596

1597+
resumed_inputs = ckpt.setup_resume(transformer_layers) if ckpt and start_layer > 0 else None
1598+
15851599
try:
1586-
for layer_idx, layer in enumerate(transformer_layers):
1587-
print_rank_0(f"Calibrating layer {layer_idx + 1}/{len(transformer_layers)}")
1588-
layer_inputs = input_getter.get_input_activations(layer, forward_loop)
1600+
for layer_idx, layer in enumerate(list(transformer_layers)):
1601+
if layer_idx < start_layer:
1602+
continue
1603+
1604+
layer_inputs = _get_layer_inputs(
1605+
layer_idx, start_layer, resumed_inputs, layer, input_getter, forward_loop
1606+
)
1607+
if ckpt:
1608+
ckpt.save_prev(transformer_layers, layer_inputs)
15891609

15901610
def _layer_forward_loop(m, _inputs=layer_inputs):
15911611
for args, kwargs_input in _inputs:
15921612
m(*args, **kwargs_input)
15931613

15941614
calib_func(layer, _layer_forward_loop, **calib_kwargs)
15951615

1616+
if ckpt:
1617+
ckpt.stash(layer_idx, layer, model)
1618+
15961619
del layer_inputs
15971620
torch.cuda.empty_cache()
1621+
1622+
if ckpt:
1623+
ckpt.save_last(transformer_layers)
15981624
finally:
15991625
input_getter._unpatch_all_layers()
16001626

1627+
if ckpt:
1628+
ckpt.full_restore(transformer_layers, model)
1629+
16011630
print_rank_0("Sequential calibration completed")
16021631

16031632

1633+
def _get_layer_inputs(
1634+
layer_idx: int,
1635+
start_layer: int,
1636+
resumed_inputs: list | None,
1637+
layer: nn.Module,
1638+
input_getter: LayerActivationCollector,
1639+
forward_loop: ForwardLoop,
1640+
) -> list:
1641+
"""Get inputs for a layer, using resumed_inputs for the first resumed layer."""
1642+
if layer_idx == start_layer and resumed_inputs is not None:
1643+
print_rank_0(f"Calibrating layer {layer_idx + 1} (resumed)")
1644+
# Manually set skip mode on all already-calibrated layers (output_meta
1645+
# was loaded by setup_resume). Don't call _set_layer_states which
1646+
# assumes the normal sequential progression with collected_inputs.
1647+
assert input_getter._decoder_layers is not None
1648+
for i in range(start_layer):
1649+
input_getter._swap_to_dummy(i)
1650+
# Seed collected_inputs so the next _set_layer_states call can
1651+
# transition this layer to "run" mode.
1652+
layer._seq_calib.collected_inputs = resumed_inputs
1653+
layer._seq_calib.mode = "original"
1654+
return resumed_inputs
1655+
1656+
return input_getter.get_input_activations(layer, forward_loop)
1657+
1658+
16041659
@torch.no_grad()
16051660
def gptq(
16061661
model: nn.Module,
@@ -1663,8 +1718,10 @@ def gptq(
16631718
handle.cleanup()
16641719

16651720
print_rank_0("Updating weights using GPTQ algorithm...")
1721+
name_to_module = dict(model.named_modules())
16661722
for handle in gptq_handles.values():
1667-
handle.update_weights(block_size, perc_damp)
1723+
with enable_weight_access_and_writeback(handle.module, model, name_to_module):
1724+
handle.update_weights(block_size, perc_damp)
16681725
handle.free()
16691726
del gptq_handles
16701727

modelopt/torch/quantization/plugins/accelerate.py

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
def _get_cpu_offload_hook(hook):
3535
if isinstance(hook, AlignDevicesHook) and hook.offload and hook.weights_map is not None:
36-
assert "weight" in hook.weights_map
36+
assert len(hook.weights_map) > 0
3737
if (
3838
isinstance(hook.weights_map, PrefixedDataset)
3939
and hook.weights_map.prefix + "weight" not in hook.weights_map.dataset.state_dict
@@ -50,32 +50,68 @@ def _get_cpu_offload_hook(hook):
5050
return None
5151

5252

53+
def _writeback_params_to_weights_map(module, align_hook):
54+
"""Write all non-meta parameters back to the hook's CPU weights_map."""
55+
for name, param in module.named_parameters():
56+
if param.device.type == "meta":
57+
continue
58+
if isinstance(align_hook.weights_map, PrefixedDataset):
59+
key = align_hook.weights_map.prefix + name
60+
w_map = align_hook.weights_map.dataset.state_dict
61+
else:
62+
w_map = align_hook.weights_map
63+
key = name
64+
if key in w_map:
65+
w_map[key] = param.data.to(w_map[key].device, dtype=w_map[key].dtype)
66+
67+
5368
@contextmanager
5469
def weight_access_and_writeback_context(module):
55-
"""Context manager for weight access and writeback for modules managed by accelerate."""
70+
"""Context manager for weight access and writeback for modules managed by accelerate.
71+
72+
Handles two cases:
73+
1. **Single-module**: the module's own ``_hf_hook`` is an offload hook.
74+
2. **Sub-module**: the module's hook is non-offloading, but its children have
75+
offload hooks (common with ``SequentialHook`` on sub-modules placed by
76+
``load_checkpoint_and_dispatch``).
77+
78+
For the sub-module case, ``pre_forward`` is skipped on sub-modules whose weights
79+
are already materialized (not on meta). This allows the context manager to be
80+
used as a pure writeback after weight-modifying algorithms.
81+
"""
5682
assert hasattr(module, "_hf_hook")
5783
align_hook = _get_cpu_offload_hook(module._hf_hook)
5884

5985
if align_hook:
60-
# Accelerate uses AlignDevicesHook to offload weights to CPU/Disk and then reload them in the forward pass
61-
# The CPU/Disk offloaded weights are managed by PrefixDataset and OffloadedWeightsLoader
62-
# See https://github.com/huggingface/accelerate/blame/f48d95c4939b281505a45b3d6e0bf554b65cc1ea/src/accelerate/utils/offload.py#L104-L141
63-
# TODO: Add support for disk-offloaded models if needed (they will be really slow, hence low priority)
64-
65-
# This will load the weights from CPU state_dict and move it to the GPU from meta device
6686
align_hook.pre_forward(module)
87+
try:
88+
yield
89+
finally:
90+
_writeback_params_to_weights_map(module, align_hook)
91+
align_hook.post_forward(module, None)
92+
return
93+
94+
materialized: list[tuple[torch.nn.Module, AlignDevicesHook, bool]] = []
95+
for mod in module.modules():
96+
if mod is module or not hasattr(mod, "_hf_hook"):
97+
continue
98+
hook = _get_cpu_offload_hook(mod._hf_hook)
99+
if hook is None:
100+
continue
101+
# Only call pre_forward if weights need materializing; already-materialized
102+
# weights would be overwritten with stale CPU state_dict values.
103+
needs_materialize = any(p.device.type == "meta" for p in mod.parameters())
104+
if needs_materialize:
105+
hook.pre_forward(mod)
106+
materialized.append((mod, hook, needs_materialize))
107+
67108
try:
68109
yield
69110
finally:
70-
if align_hook:
71-
# Update the weight in the CPU state_dict
72-
if isinstance(align_hook.weights_map, PrefixedDataset):
73-
key = align_hook.weights_map.prefix + "weight"
74-
w_map = align_hook.weights_map.dataset.state_dict
75-
else:
76-
key, w_map = "weight", align_hook.weights_map
77-
w_map[key] = module.weight.data.to(w_map[key].device, dtype=w_map[key].dtype)
78-
align_hook.post_forward(module, None)
111+
for mod, hook, was_materialized in materialized:
112+
_writeback_params_to_weights_map(mod, hook)
113+
if was_materialized:
114+
hook.post_forward(mod, None)
79115

80116

81117
@contextmanager

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from ..nn.modules.quant_linear import _QuantLinear
4040
from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE
4141
from ..utils import replace_function, sync_moe_expert_amax
42-
from ..utils.activation_collector import LayerActivationCollector
42+
from ..utils.layerwise_calib import LayerActivationCollector
4343
from .attention import register_attention_for_kv_quant
4444
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin
4545

modelopt/torch/quantization/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
# ruff: noqa: F405
1717
"""Quantization utilities."""
1818

19-
from .activation_collector import LayerActivationCollector
2019
from .core_utils import *
20+
from .layerwise_calib import LayerActivationCollector
2121

2222
__all__ = [
2323
"EXPORT_MODE",

modelopt/torch/quantization/utils/calib_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __init__(self, module, name, offload_to_cpu=False):
9696
self.name = name
9797
in_features = module.weight.shape[-1]
9898
device = module.weight.device
99-
if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65:
99+
if device.type == "meta" or (offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65):
100100
device = "cpu"
101101
self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device)
102102
self.n_samples = 0

modelopt/torch/quantization/utils/core_utils.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -423,47 +423,70 @@ def _get_enclosing_fsdp_module(
423423
return root_model
424424

425425

426+
def _set_parameter(module: nn.Module, name: str, value: nn.Parameter):
427+
"""Set a parameter on a module by dotted name (e.g. ``self_attn.q_proj.weight``)."""
428+
parts = name.rsplit(".", 1)
429+
if len(parts) == 2:
430+
parent = module.get_submodule(parts[0])
431+
attr = parts[1]
432+
else:
433+
parent = module
434+
attr = name
435+
parent._parameters[attr] = value
436+
437+
426438
@contextmanager
427439
def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn.Module):
428440
"""Context manager for FSDP2 weight access and writeback.
429441
430-
Note this context will gather the weight across FSDP/HSDP shards. If TP is implemented with DTensor,
431-
the weight will be a local tensor of the TP DTensor under this context.
442+
Gathers sharded DTensor parameters across FSDP/HSDP shards so they can be
443+
read or modified. Works for both leaf modules (single ``weight``) and
444+
composite modules like decoder layers (all ``named_parameters``).
445+
446+
If TP is implemented with DTensor, the weight will be a local tensor of the
447+
TP DTensor under this context.
432448
"""
433449
assert isinstance(root_model, torch.distributed.fsdp.FSDPModule), "We only support FSDP2"
434450

435451
assert not hasattr(module, "_hf_hook"), "We dont support FSDP2 with HF accelerate hooks"
436-
assert isinstance(module.weight, torch.distributed.tensor.DTensor)
437452
fsdp_module = _get_enclosing_fsdp_module(module, root_model)
438453
assert fsdp_module is not None, "Module is not wrapped by FSDP"
439454
fsdp_device_mesh = _get_fsdp2_mesh(fsdp_module)
440455
fsdp_dim = fsdp_device_mesh.ndim
441456

442-
original_placements = module.weight.placements
443-
original_device_mesh = module.weight.device_mesh
444-
original_weight = module.weight
445-
# Assuming the first fsdp_dim dimensions are for FSDP/HSDP, we only collect the tensor over FSDP/HSDP dimension,
446-
# the TP will be handled by the TP reduction.
447-
if fsdp_dim != original_device_mesh.ndim:
448-
assert fsdp_device_mesh.mesh_dim_names == original_device_mesh.mesh_dim_names[:fsdp_dim], (
449-
"FSDP2 mesh should be a slice of DTesnor's device mesh."
457+
# Collect all DTensor parameters, replacing them with local replicated copies.
458+
originals: dict[str, tuple] = {}
459+
for name, param in module.named_parameters():
460+
if not isinstance(param, torch.distributed.tensor.DTensor):
461+
continue
462+
original_placements = param.placements
463+
original_device_mesh = param.device_mesh
464+
if fsdp_dim != original_device_mesh.ndim:
465+
assert (
466+
fsdp_device_mesh.mesh_dim_names == original_device_mesh.mesh_dim_names[:fsdp_dim]
467+
), "FSDP2 mesh should be a slice of DTensor's device mesh."
468+
collected = param.redistribute(
469+
placements=[Replicate()] * fsdp_dim + list(original_placements[fsdp_dim:]),
470+
device_mesh=original_device_mesh,
450471
)
451-
452-
weight_collected = original_weight.redistribute(
453-
placements=[Replicate()] * fsdp_dim + list(original_placements[fsdp_dim:]),
454-
device_mesh=original_device_mesh,
455-
)
456-
new_weight = nn.Parameter(weight_collected.to_local())
457-
module._parameters["weight"] = new_weight
472+
originals[name] = (param, collected, original_placements, original_device_mesh)
473+
_set_parameter(module, name, nn.Parameter(collected.to_local()))
458474

459475
yield
460476

461-
original_weight.to_local().data.copy_(
462-
weight_collected.redistribute(
463-
placements=original_placements, device_mesh=original_device_mesh
464-
).to_local()
465-
)
466-
module._parameters["weight"] = original_weight
477+
# Write back and restore original DTensor parameters.
478+
for name, (
479+
original_param,
480+
collected,
481+
original_placements,
482+
original_device_mesh,
483+
) in originals.items():
484+
original_param.to_local().data.copy_(
485+
collected.redistribute(
486+
placements=original_placements, device_mesh=original_device_mesh
487+
).to_local()
488+
)
489+
_set_parameter(module, name, original_param)
467490

468491

469492
@contextmanager

0 commit comments

Comments (0)