Skip to content

Commit 43d1888

Browse files
realAsma and claude committed
Address PR review feedback for layerwise calibration
- Add inline security comments for all torch.load(weights_only=False) calls
- Replace bare assert with RuntimeError for unsupported offload hook layout
- Write back buffers (not just parameters) in _writeback_params_to_weights_map
- Add cross-field validator rejecting layerwise_checkpoint_dir without layerwise=True
- Validate num_layers mismatch on checkpoint resume
- Handle integer device ordinals in _get_execution_device_from_hook
- Clean up stale layer artifacts in partial-checkpoint tests
- Guard non-dict algorithm values in needs_checkpoint_path_update
- Add comment explaining dummy output_meta for last layer

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent 6f63f44 commit 43d1888

6 files changed

Lines changed: 63 additions & 18 deletions

File tree

examples/llm_ptq/example_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod
860860
def needs_checkpoint_path_update(quant_cfg: dict) -> bool:
861861
"""Check if quant_cfg has a layerwise_checkpoint_dir that should be auto-resolved to a unique subpath."""
862862
algorithm = quant_cfg.get("algorithm")
863-
if algorithm is None or isinstance(algorithm, str):
863+
if not isinstance(algorithm, dict):
864864
return False
865865
return algorithm.get("layerwise_checkpoint_dir") is not None
866866

modelopt/torch/quantization/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,6 +1237,16 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
12371237
),
12381238
)
12391239

1240+
@model_validator(mode="after")
1241+
def validate_layerwise_checkpoint_dir(self):
1242+
"""Raise if layerwise_checkpoint_dir is set but layerwise is False."""
1243+
if self.layerwise_checkpoint_dir is not None and not self.layerwise:
1244+
raise ValueError(
1245+
"layerwise_checkpoint_dir requires layerwise=True. "
1246+
"Set layerwise=True or remove layerwise_checkpoint_dir."
1247+
)
1248+
return self
1249+
12401250

12411251
class MaxCalibConfig(QuantizeAlgorithmConfig):
12421252
"""The config for max calibration algorithm.

modelopt/torch/quantization/plugins/accelerate.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ def _get_cpu_offload_hook(hook):
5050

5151

5252
def _writeback_params_to_weights_map(module, align_hook):
53-
"""Write all non-meta parameters back to the hook's CPU weights_map."""
54-
for name, param in module.named_parameters():
55-
if param.device.type == "meta":
53+
"""Write all non-meta parameters and buffers back to the hook's CPU weights_map."""
54+
for name, tensor in module.state_dict(keep_vars=True).items():
55+
if tensor.device.type == "meta":
5656
continue
5757
if isinstance(align_hook.weights_map, PrefixedDataset):
5858
key = align_hook.weights_map.prefix + name
@@ -61,7 +61,7 @@ def _writeback_params_to_weights_map(module, align_hook):
6161
w_map = align_hook.weights_map
6262
key = name
6363
if key in w_map:
64-
w_map[key] = param.data.to(w_map[key].device, dtype=w_map[key].dtype)
64+
w_map[key] = tensor.detach().to(w_map[key].device, dtype=w_map[key].dtype)
6565

6666

6767
@contextmanager
@@ -85,14 +85,15 @@ def weight_access_and_writeback_context(module):
8585
# Guard: the sub-module branch below is not reached when the parent has
8686
# an offload hook. Assert that no children also carry offload hooks,
8787
# which would require a combined writeback strategy.
88-
assert not any(
88+
if any(
8989
_get_cpu_offload_hook(mod._hf_hook)
9090
for mod in module.modules()
9191
if mod is not module and hasattr(mod, "_hf_hook")
92-
), (
93-
"Both the module and one of its sub-modules have CPU-offload hooks. "
94-
"weight_access_and_writeback_context does not support this layout yet."
95-
)
92+
):
93+
raise RuntimeError(
94+
"Both the module and one of its sub-modules have CPU-offload hooks. "
95+
"weight_access_and_writeback_context does not support this layout yet."
96+
)
9697
align_hook.pre_forward(module)
9798
align_hook.offload = False
9899
try:

modelopt/torch/quantization/utils/layerwise_calib.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,13 @@ def from_folder(cls, checkpoint_dir: str | None, num_layers: int) -> _Checkpoint
555555
return None
556556
os.makedirs(checkpoint_dir, exist_ok=True)
557557
info = detect_resume_point(checkpoint_dir)
558+
if info is not None:
559+
manifest_num_layers = info[1].get("num_layers")
560+
if manifest_num_layers is not None and manifest_num_layers != num_layers:
561+
raise ValueError(
562+
f"Checkpoint num_layers mismatch: manifest has {manifest_num_layers} "
563+
f"but model has {num_layers}. Use a fresh checkpoint directory."
564+
)
558565
start = info[0] if info else 0
559566
if start > 0:
560567
print_rank_0(
@@ -575,6 +582,7 @@ def setup_resume(self, layers: nn.ModuleList) -> list | None:
575582

576583
for i in range(self.start_layer):
577584
d = _layer_dir(self.checkpoint_dir, i)
585+
# weights_only=False is safe: file is internally generated by _save_layer, not user-supplied
578586
meta = torch.load(
579587
os.path.join(d, "output_meta.pt"), map_location="cpu", weights_only=False
580588
)
@@ -586,6 +594,7 @@ def setup_resume(self, layers: nn.ModuleList) -> list | None:
586594
next_inputs_path = os.path.join(d, "next_inputs.pt")
587595
if not os.path.isfile(next_inputs_path):
588596
raise FileNotFoundError(f"Cannot resume: next_inputs.pt missing for layer {last_ckpt}")
597+
# weights_only=False is safe: file is internally generated by _save_layer, not user-supplied
589598
next_inputs = torch.load(next_inputs_path, map_location="cpu", weights_only=False)
590599
resume_device = get_module_device(layers[self.start_layer])
591600
next_inputs = _move_to_device(next_inputs, resume_device)
@@ -610,14 +619,20 @@ def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None:
610619
# Restore quantizer state first: may promote TensorQuantizer to
611620
# NVFP4StaticQuantizer, changing module structure that load_state_dict
612621
# expects.
613-
qstate = torch.load(os.path.join(d, "quantizer_state.pt"), map_location=layer_device)
622+
# weights_only=False is safe: file is internally generated by _save_layer, not user-supplied
623+
qstate = torch.load(
624+
os.path.join(d, "quantizer_state.pt"), map_location=layer_device, weights_only=False
625+
)
614626
restore_quantizer_state(layer, dummy_config, {"quantizer_state": qstate})
615627

616628
# Load weights inside the framework's access context so that
617629
# managed-weight frameworks (accelerate CPU offload, FSDP2) sync
618630
# their internal state with the restored parameters.
619631
with enable_weight_access_and_writeback(layer, model, name_to_module):
620-
weights = torch.load(os.path.join(d, "weights.pt"), map_location=layer_device)
632+
# weights_only=False is safe: file is internally generated by _save_layer, not user-supplied
633+
weights = torch.load(
634+
os.path.join(d, "weights.pt"), map_location=layer_device, weights_only=False
635+
)
621636
layer.load_state_dict(weights, strict=False)
622637

623638
print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers")
@@ -649,6 +664,8 @@ def save(
649664

650665
output_meta = getattr(layer._layerwise_calib, "output_meta", None)
651666
if output_meta is None:
667+
# Placeholder for the last layer: output_meta is never used for skip mode
668+
# since there is no subsequent layer that needs a correctly shaped dummy output.
652669
output_meta = LayerActivationCollector._extract_output_meta(torch.zeros(1))
653670

654671
_save_layer(

modelopt/torch/utils/network.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,12 @@ def _get_execution_device_from_hook(module: nn.Module) -> torch.device | None:
103103

104104
dev = getattr(hook, "execution_device", None)
105105
if dev is not None:
106-
return torch.device(dev)
106+
return torch.device("cuda", dev) if isinstance(dev, int) else torch.device(dev)
107107

108108
for h in getattr(hook, "hooks", ()):
109109
dev = getattr(h, "execution_device", None)
110110
if dev is not None:
111-
return torch.device(dev)
111+
return torch.device("cuda", dev) if isinstance(dev, int) else torch.device(dev)
112112

113113
return None
114114

tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import copy
1717
import json
1818
import os
19+
import shutil
1920

2021
import pytest
2122
import torch
@@ -29,6 +30,7 @@
2930
enable_weight_access_and_writeback,
3031
is_quantized_linear,
3132
)
33+
from modelopt.torch.quantization.utils.layerwise_calib import _layer_dir
3234

3335

3436
@pytest.mark.parametrize(
@@ -204,10 +206,15 @@ def test_sequential_checkpoint_resume_cpu_offloaded(tmp_path, quant_cfg):
204206
mtq.quantize(model_ref, seq_ckpt_cfg, lambda model: model(inputs))
205207
output_ref = model_ref(inputs)
206208

207-
# Simulate crash after layer 0 by truncating the manifest
209+
# Simulate crash after layer 0 by truncating the manifest and removing later layers
210+
last_completed_layer = 0
208211
manifest_path = os.path.join(ckpt_dir, "manifest.json")
209212
with open(manifest_path, "w") as f:
210-
json.dump({"last_completed_layer": 0, "num_layers": num_layers}, f)
213+
json.dump({"last_completed_layer": last_completed_layer, "num_layers": num_layers}, f)
214+
for i in range(last_completed_layer + 1, num_layers):
215+
d = _layer_dir(ckpt_dir, i)
216+
if os.path.isdir(d):
217+
shutil.rmtree(d)
211218

212219
# Resume from a fresh CPU-offloaded model
213220
with init_empty_weights():
@@ -257,9 +264,14 @@ def _make_multi_offload_model():
257264
output_ref = model_ref(inputs)
258265

259266
# Simulate crash after layer 0
267+
last_completed_layer = 0
260268
manifest_path = os.path.join(ckpt_dir, "manifest.json")
261269
with open(manifest_path, "w") as f:
262-
json.dump({"last_completed_layer": 0, "num_layers": num_layers}, f)
270+
json.dump({"last_completed_layer": last_completed_layer, "num_layers": num_layers}, f)
271+
for i in range(last_completed_layer + 1, num_layers):
272+
d = _layer_dir(ckpt_dir, i)
273+
if os.path.isdir(d):
274+
shutil.rmtree(d)
263275

264276
# Resume from fresh model with same offload layout
265277
model_resumed = _make_multi_offload_model()
@@ -346,9 +358,14 @@ def test_sequential_gptq_checkpoint_resume_cpu_offloaded(tmp_path):
346358
output_ref = model_ref(inputs)
347359

348360
# Simulate crash after layer 0
361+
last_completed_layer = 0
349362
manifest_path = os.path.join(ckpt_dir, "manifest.json")
350363
with open(manifest_path, "w") as f:
351-
json.dump({"last_completed_layer": 0, "num_layers": num_layers}, f)
364+
json.dump({"last_completed_layer": last_completed_layer, "num_layers": num_layers}, f)
365+
for i in range(last_completed_layer + 1, num_layers):
366+
d = _layer_dir(ckpt_dir, i)
367+
if os.path.isdir(d):
368+
shutil.rmtree(d)
352369

353370
# Resume from fresh CPU-offloaded model
354371
with init_empty_weights():

0 commit comments

Comments (0)