Skip to content

Commit 67a2b94

Browse files
committed
claude feedback
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
1 parent 6f0feae commit 67a2b94

6 files changed

Lines changed: 233 additions & 93 deletions

File tree

modelopt/torch/quantization/model_calib.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,6 +1755,11 @@ def layerwise_calibrate(
17551755
ckpt = _CheckpointState.from_folder(checkpoint_dir, num_layers)
17561756
start_layer = ckpt.start_layer if ckpt else 0
17571757

1758+
if ckpt and start_layer >= num_layers:
1759+
ckpt.full_restore(transformer_layers, model)
1760+
print_rank_0("Layerwise calibration completed (restored from checkpoint)")
1761+
return
1762+
17581763
input_getter = LayerActivationCollector(model)
17591764
input_getter._patch_all_layers(decoder_layers=transformer_layers)
17601765

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -824,31 +824,6 @@ def _fake_quantize(self, inputs):
824824
getattr(self, "_onnx_quantizer_type", None),
825825
self._pass_through_bwd,
826826
)
827-
elif (
828-
self.block_sizes is not None
829-
and self._num_bits == (2, 1)
830-
and self.block_sizes.get("scale_bits") == (4, 3)
831-
):
832-
# Static NVFP4: plain TensorQuantizer should have been promoted to
833-
# NVFP4StaticQuantizer during MSE setup. For per-expert quantizers
834-
# in fused MoEs, promotion is gated on `_amax` having been set during
835-
# max_calibrate; experts not activated during max_calibrate stay
836-
# plain. MSE later sets a per-block `_amax`, so by the time forward
837-
# runs again the quantizer has a valid amax — dispatch to the static
838-
# NVFP4 fake-quant path here.
839-
if amax is not None:
840-
outputs = static_blockwise_fp4_fake_quant(
841-
inputs,
842-
amax,
843-
None, # global_amax — computed internally by the kernel
844-
True,
845-
inputs.dtype,
846-
self._pass_through_bwd,
847-
)
848-
else:
849-
# No amax at all (truly uncalibrated): pass through unchanged so
850-
# forward doesn't crash. Should not normally be reachable.
851-
outputs = inputs
852827
elif isinstance(self._num_bits, tuple):
853828
# Float-point quantization, e.g., FP8
854829
E, M = self._num_bits # noqa: N806
@@ -959,9 +934,7 @@ def set_quant_params(axis, block_reshape_size, padding, slices, amax_shape=None)
959934

960935
quant_axis = [i for i in range(len(quantize_axis)) if quantize_axis[i]]
961936

962-
slices = (
963-
None if all(s is None for s in slices) else [s if s else slice(None) for s in slices]
964-
)
937+
slices = None if all(s is None for s in slices) else [s or slice(None) for s in slices]
965938

966939
if all(p is None for p in paddings):
967940
paddings = None
@@ -970,7 +943,7 @@ def set_quant_params(axis, block_reshape_size, padding, slices, amax_shape=None)
970943
for padding in paddings:
971944
if not (new_paddings or padding):
972945
continue
973-
new_paddings.extend(padding if padding else (0, 0))
946+
new_paddings.extend(padding or (0, 0))
974947
paddings = tuple(reversed(new_paddings))
975948

976949
set_quant_params(quant_axis, reshape_size, paddings, slices)

modelopt/torch/quantization/utils/core_utils.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -945,21 +945,26 @@ def update_quant_cfg_with_kv_cache_quant(
945945
def promote_nvfp4_static_quantizers(model: nn.Module) -> int:
946946
"""Convert eligible TensorQuantizers to NVFP4StaticQuantizer in-place.
947947
948-
After max calibration sets per-block amax values, NVFP4 static quantizers
949-
need to be promoted so they use the two-level scaling path (global amax +
950-
per-block amax) instead of the generic E4M3 path.
948+
Promotion is purely a class swap based on the static-NVFP4 *format*; it does
949+
not require ``_amax`` to be set. Quantizers without ``_amax`` (e.g. MoE
950+
experts that received no calibration tokens) still get promoted so that any
951+
later forward — once MSE or bootstrap populates ``_amax`` — dispatches via
952+
the subclass's two-level scaling path instead of the parent's generic E4M3.
951953
952954
Returns the number of quantizers converted.
953955
"""
954956
from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer
955957

956958
converted = 0
957959
for _name, module in list(model.named_modules()):
958-
if isinstance(module, TensorQuantizer) and not module._disabled:
959-
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
960-
if module.is_nvfp4_static:
961-
initial_amax = module._amax.clone().detach()
962-
global_amax = reduce_amax(initial_amax, axis=None)
963-
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
964-
converted += 1
960+
if not isinstance(module, TensorQuantizer) or module._disabled:
961+
continue
962+
if module._calibrator is None or module._dynamic:
963+
continue
964+
if not module.is_nvfp4_static or isinstance(module, NVFP4StaticQuantizer):
965+
continue
966+
amax = getattr(module, "_amax", None)
967+
global_amax = reduce_amax(amax.detach(), axis=None) if amax is not None else None
968+
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
969+
converted += 1
965970
return converted

modelopt/torch/quantization/utils/layerwise_calib.py

Lines changed: 18 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -227,22 +227,20 @@ def _patched_forward(self, *args, **kwargs):
227227
f"Layer {info.name} is in 'run' mode but has no cached inputs to replay."
228228
)
229229
real_args, real_kwargs = info.cached_inputs.popleft()
230-
if (
231-
real_args
232-
and isinstance(real_args[0], torch.Tensor)
233-
and real_args[0].device.type == "cpu"
234-
):
235-
device = get_module_device(self)
236-
real_args = _move_to_device(real_args, device)
237-
real_kwargs = _move_to_device(real_kwargs, device)
230+
# Captured inputs are stored on CPU (see "capture" branch); move
231+
# back to the layer's device for replay. `_move_to_device` is a
232+
# no-op for tensors already on `device`.
233+
device = get_module_device(self)
234+
real_args = _move_to_device(real_args, device)
235+
real_kwargs = _move_to_device(real_kwargs, device)
238236
output = self._original_forward(*real_args, **real_kwargs)
239237
info.output_meta = LayerActivationCollector._extract_output_meta(output)
240238
return output
241239

242240
if info.mode == "capture":
243241
# Offload captured inputs to CPU at append time. For early layers
244-
# on a single GPU (e.g. layer 02 on GPU 0 with seq_device_map),
245-
# accumulating thousands of batches' worth of (bs × seq × hidden)
242+
# on a single GPU (e.g. layer 0-2 on GPU 0 with seq_device_map),
243+
# accumulating thousands of batches' worth of (bs x seq x hidden)
246244
# activations on-device saturates that GPU during the capture loop
247245
# and OOMs before _set_layer_states gets a chance to move them.
248246
# The "run" branch already handles CPU-resident inputs (see the
@@ -333,11 +331,8 @@ def _set_layer_states(self, layer_idx: int):
333331
"was called for every preceding layer in order."
334332
)
335333
prev.mode = "run"
336-
cpu = torch.device("cpu")
337-
prev.cached_inputs = deque(
338-
(_move_to_device(args, cpu), _move_to_device(kwargs, cpu))
339-
for args, kwargs in prev.collected_inputs
340-
)
334+
# Inputs are already CPU-resident at capture time (see _patched_forward).
335+
prev.cached_inputs = deque(prev.collected_inputs)
341336
prev.collected_inputs = []
342337

343338
cur = self._decoder_layers[layer_idx]._layerwise_calib
@@ -534,9 +529,6 @@ def _save_layer(
534529
torch.save(output_meta, os.path.join(d, "output_meta.pt"))
535530
if next_inputs is not None:
536531
torch.save(next_inputs, os.path.join(d, "next_inputs.pt"))
537-
amax_state = {k: v for k, v in weights.items() if "_amax" in k}
538-
if amax_state:
539-
torch.save(amax_state, os.path.join(d, "quantizer_amaxes.pt"))
540532
_write_manifest(checkpoint_dir, idx, num_layers)
541533

542534

@@ -635,17 +627,8 @@ def setup_resume(self, layers: nn.ModuleList) -> list | None:
635627
# Keep on CPU — _patched_forward's run mode moves each entry to device on pop.
636628
return next_inputs
637629

638-
def full_restore(
639-
self, layers: nn.ModuleList, model: nn.Module, restore_weights: bool = True
640-
) -> None:
641-
"""Restore weights and quantizer state for layers 0..K-1 after the calibration loop.
642-
643-
Args:
644-
restore_weights: If False, skip reloading ``weights.pt`` and load only the
645-
``_amax`` values (from ``quantizer_amaxes.pt`` or filtered from ``weights.pt``).
646-
Set to False for calibration algorithms (max, MSE) that never modify weights
647-
to avoid re-reading gigabytes of unchanged expert weights from disk.
648-
"""
630+
def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None:
631+
"""Restore weights and quantizer state for layers 0..K-1 after the calibration loop."""
649632
from modelopt.torch.quantization.config import QuantizeConfig
650633
from modelopt.torch.quantization.conversion import restore_quantizer_state
651634
from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback
@@ -671,31 +654,13 @@ def full_restore(
671654
map_location="cpu",
672655
weights_only=False,
673656
)
657+
weights = torch.load(
658+
os.path.join(d, "weights.pt"),
659+
map_location="cpu",
660+
weights_only=False,
661+
)
674662
restore_quantizer_state(layer, dummy_config, {"quantizer_state": qstate})
675-
if restore_weights:
676-
weights = torch.load(
677-
os.path.join(d, "weights.pt"),
678-
map_location="cpu",
679-
weights_only=False,
680-
)
681-
layer.load_state_dict(weights, strict=False, assign=False)
682-
else:
683-
# Load only _amax entries — skip gigabytes of unchanged expert weights.
684-
# Use map_location="cpu" to get fresh CPU tensors (no storage_offset).
685-
# _export_fused_experts moves _amax to the weight device on demand.
686-
amax_path = os.path.join(d, "quantizer_amaxes.pt")
687-
if os.path.exists(amax_path):
688-
amaxes = torch.load(amax_path, map_location="cpu", weights_only=False)
689-
else:
690-
# Legacy checkpoint: filter _amax entries from the full weights.pt.
691-
weights = torch.load(
692-
os.path.join(d, "weights.pt"),
693-
map_location="cpu",
694-
weights_only=False,
695-
)
696-
amaxes = {k: v for k, v in weights.items() if "_amax" in k}
697-
if amaxes:
698-
layer.load_state_dict(amaxes, strict=False, assign=True)
663+
layer.load_state_dict(weights, strict=False, assign=False)
699664

700665
print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers")
701666

tests/unit/torch/quantization/test_layerwise_calibrate.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@
2525
import modelopt.torch.quantization as mtq
2626
from modelopt.torch.quantization.model_calib import layerwise_calibrate
2727
from modelopt.torch.quantization.nn import TensorQuantizer
28-
from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector, _SkipLayer
28+
from modelopt.torch.quantization.utils.layerwise_calib import (
29+
LayerActivationCollector,
30+
_CheckpointState,
31+
_SkipLayer,
32+
detect_resume_point,
33+
)
2934

3035

3136
class _DecoderBlock(nn.Module):
@@ -719,3 +724,98 @@ def test_mtq_quantize_layerwise_raises_for_unsupported_algorithm():
719724
config,
720725
forward_loop=lambda m: m(torch.randint(0, 32, (2, 8))),
721726
)
727+
728+
729+
# Checkpoint resume + capture-time CPU offload
730+
731+
732+
def test_collected_inputs_are_cpu_at_capture(monkeypatch):
    """Inputs gathered in capture mode must be CPU-resident.

    This pins the OOM-prevention invariant: if captured activations stayed on
    the layer's compute device, thousands of batches would pile up there
    before the run-mode transition had a chance to move them.
    """
    _register_test_discoverer(monkeypatch)
    model = _SimpleTwoLayerModel(dim=8)
    collector = LayerActivationCollector(model)

    def run_single_batch(m):
        # One forward pass is enough to populate the first layer's inputs.
        m(torch.randn(2, 8))

    collector._patch_all_layers()
    try:
        captured = collector.get_input_activations(model.layers[0], run_single_batch)
    finally:
        # Always restore the original forwards, even if capture raised.
        collector._unpatch_all_layers()

    first_args, _first_kwargs = captured[0]
    assert first_args[0].device.type == "cpu", "captured tensor must be CPU-resident"
754+
755+
756+
def test_detect_resume_point_returns_num_layers_when_complete(tmp_path):
    """A manifest whose last layer is complete resumes at ``num_layers``, not ``None``."""
    import os

    from modelopt.torch.quantization.utils.layerwise_calib import _write_manifest

    ckpt_dir = str(tmp_path / "ckpt")
    state = _CheckpointState(ckpt_dir, num_layers=3)
    os.makedirs(ckpt_dir, exist_ok=True)
    # Layer index 2 is the final layer of a 3-layer model.
    _write_manifest(ckpt_dir, last_completed_layer=2, num_layers=3)

    resume = detect_resume_point(ckpt_dir)
    assert resume is not None
    start, _manifest = resume
    assert start == state.num_layers == 3
771+
772+
773+
def test_layerwise_calibrate_early_returns_on_completed_checkpoint(monkeypatch, tmp_path):
    """A second run against a completed checkpoint must never call forward_loop."""
    _register_test_discoverer(monkeypatch)
    torch.manual_seed(0)

    # First pass: calibrate a model end to end so a complete checkpoint is written.
    model = _SimpleTransformerModel(n_layers=2, dim=16)
    calib_data = [torch.randint(0, 32, (2, 8))]
    ckpt_dir = str(tmp_path / "ckpt")

    def make_config():
        # Build a fresh config object for each quantize call.
        return _int8_layerwise_config(
            {"method": "max", "layerwise": True, "layerwise_checkpoint_dir": ckpt_dir}
        )

    def first_pass_forward(m):
        return [m(batch) for batch in calib_data]

    config = make_config()
    mtq.quantize(model, config, forward_loop=first_pass_forward)

    # Second pass: same checkpoint directory, brand-new model instance.
    fresh = _SimpleTransformerModel(n_layers=2, dim=16)
    config2 = make_config()

    invocations = 0

    def counting_forward(m):
        nonlocal invocations
        invocations += 1
        m(calib_data[0])

    mtq.quantize(fresh, config2, forward_loop=counting_forward)
    assert invocations == 0, "completed checkpoint must skip the calibration forward loop"
802+
803+
804+
def test_layerwise_calibrate_resumes_from_partial_checkpoint(monkeypatch, tmp_path):
    """A manifest marking layer 0 of 2 as complete yields resume point ``start=1``.

    Only ``detect_resume_point`` is exercised; no calibration is run and no
    layers are restored. The assertion pins that a partially completed
    manifest produces a concrete ``(start, manifest)`` tuple, with ``start``
    one past the last completed layer.
    """
    import os

    from modelopt.torch.quantization.utils.layerwise_calib import _write_manifest

    _register_test_discoverer(monkeypatch)

    # Only the manifest is written by hand. No per-layer checkpoint directory
    # is created, so this test must not be extended to load layer state
    # without first adding those files.
    ckpt_dir = str(tmp_path / "ckpt")
    os.makedirs(ckpt_dir, exist_ok=True)
    _write_manifest(ckpt_dir, last_completed_layer=0, num_layers=2)

    result = detect_resume_point(ckpt_dir)
    assert result == (1, {"last_completed_layer": 0, "num_layers": 2})

0 commit comments

Comments
 (0)