Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1755,6 +1755,11 @@ def layerwise_calibrate(
ckpt = _CheckpointState.from_folder(checkpoint_dir, num_layers)
start_layer = ckpt.start_layer if ckpt else 0

if ckpt and start_layer >= num_layers:
ckpt.full_restore(transformer_layers, model)
print_rank_0("Layerwise calibration completed (restored from checkpoint)")
return

input_getter = LayerActivationCollector(model)
input_getter._patch_all_layers(decoder_layers=transformer_layers)

Expand Down
41 changes: 25 additions & 16 deletions modelopt/torch/quantization/utils/layerwise_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,21 @@ def _patched_forward(self, *args, **kwargs):
f"Layer {info.name} is in 'run' mode but has no cached inputs to replay."
)
real_args, real_kwargs = info.cached_inputs.popleft()
# Move CPU-resident captured inputs back to the layer's device for replay.
device = get_module_device(self)
real_args = _move_to_device(real_args, device)
real_kwargs = _move_to_device(real_kwargs, device)
output = self._original_forward(*real_args, **real_kwargs)
info.output_meta = LayerActivationCollector._extract_output_meta(output)
return output

if info.mode == "capture":
info.collected_inputs.append((args, kwargs))
# Offload to CPU so the per-layer compute device doesn't OOM
# while accumulating thousands of batches; "run" moves back.
cpu = torch.device("cpu")
info.collected_inputs.append(
(_move_to_device(args, cpu), _move_to_device(kwargs, cpu))
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

Captured inputs are now CPU-resident, but _patched_forward only moves them back to the layer's device in the "run" branch. The actual calibration forward in layerwise_calibrate is _layer_forward_loop, which calls m(*args, **kwargs_input) while the layer is in "original" mode — no device move happens there. With HF accelerate hooks the per-module pre_forward silently re-bounces CPU args onto the right GPU, which is presumably why the GLM5.1 e2e run was clean. On a non-accelerate single-GPU model (params on CUDA, args on CPU) _original_forward will hit a device mismatch.

Consider either (a) moving tensors back to get_module_device(layer) inside _layer_forward_loop mirroring the run-mode code on line ~232, or (b) adding an explicit comment + assertion that this path requires accelerate-style hooks, so the assumption is at least documented.

raise _EarlyStopForwardError()
Comment on lines +230 to 245

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION Performance] The capture and run modes now do a synchronous host↔device transfer per batch, per layer. With N batches × L layers, that's 2·N·L blocking copies. Two cheap improvements:

  1. Pin the CPU buffer at capture time so the next .to(device) can be a true async DMA: t.pin_memory() on the CPU result (or call .to('cpu', non_blocking=False) and then .pin_memory()).
  2. Use non_blocking=True on the CPU→GPU side in the run branch so the copy can overlap with the previous batch's compute.

_move_to_device could grow a non_blocking flag and a pin-memory option (or a separate _offload_to_cpu helper), and the call site here would set it. Not blocking — the OOM fix is what matters — but worth doing if calibration time is a concern on the larger MoE models.


return self._original_forward(*args, **kwargs)
Expand Down Expand Up @@ -433,6 +442,8 @@ def cache_outputs_for_next_layer_calib(

next_layer = self._decoder_layers[next_idx]
with persistent_materialization(layer):
# Free cached-but-unused GPU memory left over from the previous layer's calibration.
torch.cuda.empty_cache()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.cuda.empty_cache() usually does not make your own PyTorch program able to use more memory, because PyTorch could already reuse that cached memory internally.

you could trigger a manual garbage collection to delete any unused tensors

Suggested change
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()

return self.get_input_activations(next_layer, forward_loop)


Expand Down Expand Up @@ -512,10 +523,9 @@ def _save_layer(


def detect_resume_point(checkpoint_dir: str) -> tuple[int, dict] | None:
"""Detect where to resume from an existing checkpoint directory.
"""Return ``(start_layer, manifest)`` for an existing checkpoint, else ``None``.

Returns ``(start_layer, manifest)`` if there is work to resume,
or ``None`` if the directory is empty, corrupt, or calibration was already complete.
``start_layer == num_layers`` signals a fully-completed run.
"""
manifest = _read_manifest(checkpoint_dir)
if manifest is None:
Expand All @@ -524,8 +534,6 @@ def detect_resume_point(checkpoint_dir: str) -> tuple[int, dict] | None:
total = manifest.get("num_layers")
if last is None or total is None:
return None
if last + 1 >= total:
return None
return (last + 1, manifest)
Comment on lines 525 to 537

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] Now that last + 1 == num_layers is a valid (completed) state, the only remaining defensive guard is malformed values. If last_completed_layer is ever corrupt (e.g., negative or > num_layers), the manifest passes from_folder's num_layers-match check and the early-return path (model_calib.py:1758) calls full_restore, which iterates range(start_layer) and trips a FileNotFoundError on a missing layer_XXXX/ directory — confusing failure mode for the user.

A two-line bound check here would fail fast with a clear "checkpoint corrupt" message:

if not isinstance(last, int) or not isinstance(total, int):
    return None
if last < -1 or last >= total:
    return None
return (last + 1, manifest)

(CodeRabbit raised the same point — flagging here because with this PR's "completed" semantics it becomes more relevant: prior to this PR, the last + 1 >= total short-circuit hid the malformed case behind "no resume"; now it routes into restore.)



Expand Down Expand Up @@ -568,7 +576,9 @@ def from_folder(cls, checkpoint_dir: str | None, num_layers: int) -> _Checkpoint
f"but model has {num_layers}. Use a fresh checkpoint directory."
)
start = info[0] if info else 0
if start > 0:
if start >= num_layers:
print_rank_0(f"Checkpoint: all {num_layers} layers already calibrated")
elif start > 0:
print_rank_0(
f"Checkpoint: resuming layerwise calibration from layer {start}/{num_layers}"
)
Expand Down Expand Up @@ -601,8 +611,7 @@ def setup_resume(self, layers: nn.ModuleList) -> list | None:
raise FileNotFoundError(f"Cannot resume: next_inputs.pt missing for layer {last_ckpt}")
# weights_only=False is safe: file is internally generated by _save_layer, not user-supplied
next_inputs = torch.load(next_inputs_path, map_location="cpu", weights_only=False)
resume_device = get_module_device(layers[self.start_layer])
next_inputs = _move_to_device(next_inputs, resume_device)
# Keep on CPU — _patched_forward's run mode moves each entry to device on pop.
return next_inputs

def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None:
Expand All @@ -620,23 +629,23 @@ def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None:
layer = layers[i]
d = _layer_dir(self.checkpoint_dir, i)

# Resolve layer_device and load inside the context so params are
# materialized — otherwise get_module_device can return meta.
# Load inside the context so params are materialized — otherwise
# get_module_device can return meta.
with enable_weight_access_and_writeback(layer, model, name_to_module):
layer_device = get_module_device(layer)
# weights_only=False is safe: files are internally generated by _save_layer
# Load to CPU to avoid serialized-view storage_offset hazards on later clone/deepcopy.
# weights_only=False is safe: files are internally generated by _save_layer.
qstate = torch.load(
os.path.join(d, "quantizer_state.pt"),
map_location=layer_device,
map_location="cpu",
weights_only=False,
)
weights = torch.load(
os.path.join(d, "weights.pt"),
map_location=layer_device,
map_location="cpu",
weights_only=False,
)
restore_quantizer_state(layer, dummy_config, {"quantizer_state": qstate})
layer.load_state_dict(weights, strict=False, assign=True)
layer.load_state_dict(weights, strict=False, assign=False)
Comment thread
realAsma marked this conversation as resolved.

print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers")

Expand Down
53 changes: 53 additions & 0 deletions tests/unit/torch/quantization/test_layerwise_calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,3 +719,56 @@ def test_mtq_quantize_layerwise_raises_for_unsupported_algorithm():
config,
forward_loop=lambda m: m(torch.randint(0, 32, (2, 8))),
)


# Checkpoint resume + capture-time CPU offload


def test_collected_inputs_are_cpu_at_capture(monkeypatch):
"""Captured inputs must be CPU-resident — the OOM-prevention invariant."""
_register_test_discoverer(monkeypatch)
model = _SimpleTwoLayerModel(dim=8)
collector = LayerActivationCollector(model)

def forward_loop(m):
m(torch.randn(2, 8))

collector._patch_all_layers()
try:
inputs = collector.get_input_activations(model.layers[0], forward_loop)
finally:
collector._unpatch_all_layers()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

This asserts the capture-side invariant but doesn't exercise the calibration replay. The OOM fix's correctness for the calib_func path (CPU-stashed inputs eventually reaching a GPU layer's forward) isn't covered — both layerwise e2e tests run on CPU-only models so a CPU-arg / GPU-weight mismatch would never surface. A small GPU smoke test (or even a CPU test that asserts inputs are device-correct at the moment _layer_forward_loop invokes the layer) would lock in the invariant the PR description claims.

args, _ = inputs[0]
assert args[0].device.type == "cpu", "captured tensor must be CPU-resident"


def test_layerwise_calibrate_early_returns_on_completed_checkpoint(monkeypatch, tmp_path):
"""Fully-completed checkpoint must short-circuit calibration: no forward_loop calls.

Indirectly covers ``detect_resume_point`` returning ``(num_layers, manifest)``
for a completed run — if it returned ``None``, the loop would re-run and
forward_loop would be invoked.
"""
_register_test_discoverer(monkeypatch)
torch.manual_seed(0)
ckpt_dir = str(tmp_path / "ckpt")
config = _int8_layerwise_config(
{"method": "max", "layerwise": True, "layerwise_checkpoint_dir": ckpt_dir}
)
calib_data = [torch.randint(0, 32, (2, 8))]

# First run writes a complete checkpoint.
model = _SimpleTransformerModel(n_layers=2, dim=16)
mtq.quantize(model, config, forward_loop=lambda m: [m(b) for b in calib_data])

# Second run against the same dir must skip the calibration forward loop.
call_count = {"n": 0}

def counting_forward(m):
call_count["n"] += 1
m(calib_data[0])

fresh = _SimpleTransformerModel(n_layers=2, dim=16)
mtq.quantize(fresh, config, forward_loop=counting_forward)
assert call_count["n"] == 0
Loading