
Add layerwise calibration for large models#1251

Open
realAsma wants to merge 2 commits into main from
asma/ptq-large-models

Conversation

@realAsma
Contributor

@realAsma realAsma commented Apr 13, 2026

Summary

Adds layerwise (layer-by-layer) calibration for quantizing large models that don't fit entirely on GPU. The key changes:

  1. Rename: sequential_calibrate → layerwise_calibrate (use_sequential → use_layerwise, _seq_calib → _layerwise_calib) to better describe the algorithm.

  2. Performant layerwise calibration with skip/run/capture state machine: Each decoder layer is patched with a unified forward whose behavior is governed by a per-layer state:

    • skip — parameter-free _SkipLayer dummy replaces fully-calibrated layers so framework hooks (accelerate CPU-offload, FSDP2) skip materialization entirely
    • run — replays captured inputs through the just-calibrated layer with updated weights
    • capture — records (args, kwargs) and raises _EarlyStopForwardError to halt the forward pass early
    • persistent_materialization keeps the active layer on GPU for the entire calibration step
  3. Checkpoint save/resume — calibration of large models (hours-long for 100+ layer MoE models) can be interrupted and restarted from the last completed layer:

    • Save: after each layer is calibrated, its state_dict, quantizer state, output metadata, and next-layer inputs are saved to checkpoint_dir. Only one layer's checkpoint is written at a time (minimal I/O).
    • Resume: on restart, the checkpoint manifest is read to find the last completed layer. Layers 0..K-1 are set to skip mode using saved output_meta. Only their quantizer state and weights are restored once at the end of the calibration loop (not during), keeping the hot path fast.
    • This design means save is per-layer but restore is bulk — minimal overhead during calibration, full state recovery only when needed.

Example: NVFP4+GPTQ layerwise calibration on Nemotron-Nano-30B (52 layers, single NVIDIA RTX 6000 Ada 49GB — requires CPU offloading)

The model does not fit in GPU memory, so accelerate automatically offloads layers to CPU:

Initializing model from nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
Model does not fit to the GPU mem. We apply the following memory limit for calibration:
{0: 40327184384.0, 'cpu': 514466635776}
Some parameters are on the meta device because they were offloaded to the cpu.
Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM

Initial run (killed after layer 7):

Layerwise calibration: Found 52 transformer layers
Calibrating layer 1/52 | capture: [1]
Computing Hessians for 2 linear layers...
GPTQ time: 12.99s
Calibrating layer 2/52 | run: [1] | capture: [2]
Checkpoint: saved layer 0
Computing Hessians for 2 linear layers...
GPTQ time: 4.92s
Calibrating layer 3/52 | skip: 1 | run: [2] | capture: [3]
Checkpoint: saved layer 1
...
Calibrating layer 7/52 | skip: 5 | run: [6] | capture: [7]
Checkpoint: saved layer 5
Computing Hessians for 2 linear layers...
GPTQ time: 4.11s
Calibrating layer 8/52 | skip: 6 | run: [7] | capture: [8]
<killed>

Resumed run (picks up from layer 7, finishes all 52):

Layerwise calibration: Found 52 transformer layers
Checkpoint: resuming layerwise calibration from layer 6/52
Calibrating layer 7 (resumed)
Computing Hessians for 2 linear layers...
GPTQ time: 5.98s
Calibrating layer 8/52 | skip: 6 | run: [7] | capture: [8]
Checkpoint: saved layer 6
...
Calibrating layer 52/52 | skip: 50 | run: [51] | capture: [52]
Checkpoint: saved layer 50
GPTQ time: 4.78s
Checkpoint: saved layer 51 (final)
Checkpoint: restored 6 previously calibrated layers
Layerwise calibration completed
Quantized model exported to: output/nemotron_nano_30b_nvfp4_gptq_seq
Total time used 179.10s
GPU 0: Peak memory usage = 20.42 GB
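The per-layer save / bulk restore behavior visible in these logs can be sketched as follows; the file layout (state_dict.pt, manifest.txt, etc.) and helper names are illustrative assumptions, not the PR's exact on-disk format:

```python
import os

import torch
import torch.nn as nn


def save_layer_checkpoint(checkpoint_dir, layer_idx, layer, output_meta, next_inputs):
    """Write one layer's checkpoint; only this layer's files are touched (minimal I/O)."""
    d = os.path.join(checkpoint_dir, f"layer_{layer_idx}")
    os.makedirs(d, exist_ok=True)
    # Move everything to CPU so the checkpoint is device-independent.
    torch.save(
        {k: v.cpu() for k, v in layer.state_dict().items()},
        os.path.join(d, "state_dict.pt"),
    )
    torch.save(output_meta, os.path.join(d, "output_meta.pt"))
    torch.save(next_inputs, os.path.join(d, "next_inputs.pt"))
    # The manifest records the last fully completed layer for resume detection.
    with open(os.path.join(checkpoint_dir, "manifest.txt"), "w") as f:
        f.write(str(layer_idx))


def detect_resume_point(checkpoint_dir):
    """Return the index of the first layer that still needs calibration."""
    manifest = os.path.join(checkpoint_dir, "manifest.txt")
    if not os.path.isfile(manifest):
        return 0
    with open(manifest) as f:
        return int(f.read().strip()) + 1
```

On resume, layers before the detected index would be flipped to skip mode using their saved output metadata, and their weights restored in one bulk pass at the end of the calibration loop.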

TODO

  • Update CHANGELOG

Test plan

  • tests/unit/torch/quantization/test_layerwise_calibrate.py — unit tests for skip/swap/restore
  • tests/unit/torch/quantization/test_sequential_checkpoint.py — checkpoint save/resume correctness
  • tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py — CPU-offloaded layerwise + GPTQ + checkpoint resume
  • tests/gpu/torch/quantization/test_fsdp2.py — FSDP2 layerwise calibration

🤖 Generated with Claude Code

@copy-pr-bot

copy-pr-bot bot commented Apr 13, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@github-actions
Contributor

github-actions bot commented Apr 13, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1251/

Built to branch gh-pages at 2026-04-15 22:59 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@coderabbitai
Contributor

coderabbitai bot commented Apr 13, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Rename the calibration mode flag use_sequential → use_layerwise, add an optional checkpoint_dir, replace sequential calibration with a new layerwise calibrator (with per-layer checkpoints/resume), introduce a new layerwise activation collector, update accelerate/FSDP/device helpers, and add extensive tests and example helpers.

Changes

Cohort / File(s) / Summary

  • Config (modelopt/torch/quantization/config.py): Renamed use_sequential → use_layerwise on QuantizeAlgorithmConfig; added an optional checkpoint_dir field.
  • Calibration entrypoint & mode (modelopt/torch/quantization/model_calib.py, modelopt/torch/quantization/mode.py): Replaced sequential_calibrate with layerwise_calibrate; the wrapper now respects use_layerwise and checkpoint_dir, routes to layerwise_calibrate, and forwards checkpointing kwargs.
  • Layerwise implementation, new (modelopt/torch/quantization/utils/layerwise_calib.py): New module implementing LayerActivationCollector, per-layer modes (capture/run/skip/original), persistent-materialization helpers, a checkpoint manifest with per-layer snapshot/save/restore, and resume detection.
  • Removed legacy collector (modelopt/torch/quantization/utils/activation_collector.py): Deleted the old sequential LayerActivationCollector implementation.
  • Utils exports & imports (modelopt/torch/quantization/utils/__init__.py, modelopt/torch/quantization/plugins/huggingface.py, tests...): Switched imports to the layerwise_calib collector across utils, plugins, and tests.
  • Accelerate CPU-offload integration (modelopt/torch/quantization/plugins/accelerate.py): Relaxed weights_map validation, added _writeback_params_to_weights_map, and reworked weight_access_and_writeback_context to handle single-module and child-hook layouts with multi-param writeback and correct pre/post hooks.
  • FSDP2 / core utils (modelopt/torch/quantization/utils/core_utils.py, modelopt/torch/utils/network.py): Added _set_parameter, persistent_materialization, and _disable_fsdp_unshard_reshard; generalized FSDP2 parameter access/writeback across all named parameters; get_module_device now considers the accelerate hook execution_device.
  • Hessian/device tweak (modelopt/torch/quantization/utils/calib_utils.py): Force Hessian allocation on CPU when the module weight device is meta.
  • Examples & CLI helpers (examples/llm_ptq/hf_ptq.py, examples/llm_ptq/example_utils.py): Add needs_checkpoint_path_update and resolve_checkpoint_dir; normalize KV cfg; auto-resolve and print the checkpoint dir before quantization when applicable.
  • Dataset loop tweak (modelopt/torch/utils/dataset_utils.py): Temporarily disable model.config.use_cache during _forward_loop and restore it afterwards.
  • Tests (tests/...): Extensive test additions and updates: replace sequential → layerwise in tests; add many layerwise/checkpoint/resume/FSDP/accelerate integration tests and get_module_device unit tests.

Sequence Diagram(s)

sequenceDiagram
  participant Entrypoint as Calibration Entrypoint
  participant Model as Model
  participant Collector as LayerActivationCollector
  participant Forward as ForwardLoop
  participant Checkpoint as CheckpointStore
  participant GPTQ as GPTQ Updater

  Entrypoint->>Collector: attach/discover layers
  Entrypoint->>Checkpoint: detect_resume_point(checkpoint_dir)
  alt resume available
    Checkpoint-->>Collector: restore output_meta + next_inputs
  end
  loop for layer in start_layer..N
    Entrypoint->>Collector: set mode -> capture(layer)
    Entrypoint->>Forward: run forward (captures inputs / EarlyStop)
    Collector-->>Entrypoint: captured inputs
    Entrypoint->>Collector: set mode -> run(layer)
    Entrypoint->>Forward: replay captured inputs -> outputs
    Entrypoint->>Checkpoint: save(layer_weights, quantizer_state, output_meta, next_inputs)
    alt GPTQ enabled
      Entrypoint->>GPTQ: update_weights_for_layer(...)
    end
  end
  Entrypoint->>Checkpoint: full_restore(all_layers)
  Entrypoint->>Collector: unpatch and cleanup

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes


Important

Pre-merge checks failed

Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error, 1 warning)

  • Security Anti-Patterns (❌ Error): Four torch.load() calls in layerwise_calib.py lack the required inline security justification comments per SECURITY.md. Resolution: add inline security comments to the torch.load() calls at lines 578, 589, 613, and 620 explaining that the files are internally generated and trusted.
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 60.65%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): The title 'Add layerwise calibration for large models' clearly summarizes the main change, referring to the rename of sequential_calibrate to layerwise_calibrate and the addition of checkpoint save/resume support for large-model calibration.
  • Description Check (✅ Passed): Check skipped; CodeRabbit's high-level summary is enabled.

Comment @coderabbitai help to get the list of available commands and usage tips.

@realAsma realAsma force-pushed the asma/ptq-large-models branch 2 times, most recently from 8eabe76 to 6ec3721 on April 14, 2026 16:49
Comment thread modelopt/torch/quantization/plugins/accelerate.py
Comment thread modelopt/torch/quantization/plugins/accelerate.py
Comment thread modelopt/torch/quantization/utils/activation_collector.py
@codecov

codecov bot commented Apr 14, 2026

Codecov Report

❌ Patch coverage is 80.41237% with 95 lines in your changes missing coverage. Please review.
✅ Project coverage is 66.75%. Comparing base (361f7e3) to head (d14ccbb).

Files with missing lines (patch % / lines missing):
  • modelopt/torch/quantization/plugins/accelerate.py: 2.63%, 37 missing ⚠️
  • modelopt/torch/quantization/utils/core_utils.py: 23.07%, 30 missing ⚠️
  • ...delopt/torch/quantization/utils/layerwise_calib.py: 93.07%, 23 missing ⚠️
  • modelopt/torch/quantization/model_calib.py: 88.46%, 3 missing ⚠️
  • modelopt/torch/quantization/mode.py: 75.00%, 1 missing ⚠️
  • modelopt/torch/quantization/utils/calib_utils.py: 0.00%, 1 missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1251       +/-   ##
===========================================
+ Coverage   55.67%   66.75%   +11.08%     
===========================================
  Files         458      459        +1     
  Lines       48464    48808      +344     
===========================================
+ Hits        26982    32583     +5601     
+ Misses      21482    16225     -5257     
Flag coverage (Δ):
  • examples: 41.25% <22.26%> (+17.14%) ⬆️
  • gpu: 28.17% <19.58%> (+7.78%) ⬆️
  • unit: 52.17% <77.73%> (+0.13%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

Comment thread modelopt/torch/quantization/utils/activation_collector.py
Comment thread modelopt/torch/quantization/utils/layerwise_calib.py Outdated
Comment thread modelopt/torch/quantization/model_calib.py Outdated
Comment thread modelopt/torch/utils/network.py Outdated
@realAsma realAsma force-pushed the asma/ptq-large-models branch from 6ec3721 to 8af3655 on April 14, 2026 18:48
Comment thread modelopt/torch/utils/network.py Outdated
@realAsma realAsma force-pushed the asma/ptq-large-models branch from 8af3655 to 6280846 on April 14, 2026 19:21
Comment thread tests/unit/torch/quantization/test_sequential_calibrate.py Outdated
@realAsma realAsma marked this pull request as ready for review April 14, 2026 19:25
@realAsma realAsma requested review from a team as code owners April 14, 2026 19:25
Collaborator

@cjluo-nv cjluo-nv left a comment


This is a substantial PR (~1500 lines) that adds checkpoint save/resume for sequential calibration, extends support to FSDP2 and accelerate-offloaded models, and renames activation_collector.py → layerwise_calib.py. The changes are cohesive and well-tested (unit + GPU tests for checkpoint, resume, offload, and FSDP2 scenarios).

Key issues found:

  1. Removed guard on sequential calibration methods — The assertion restricting sequential calibration to max and gptq was removed without replacement. Methods like awq, smoothquant, and svdquant operate on the full model (not per-layer) and will break silently or produce incorrect results when used with use_sequential=True.

  2. weights_only=False security concern: torch.load(..., weights_only=False) is used for loading checkpoints, which can execute arbitrary code. While the checkpoints are locally generated, this is flagged by security scanners and should use weights_only=True where possible.

Minor observations:

  • PR size is above ~1000 lines but the changes are cohesive and hard to split
  • Good test coverage for the new functionality
  • The temporarily_remove_accelerate_hook rewrite is a nice improvement avoiding the init_hook pitfall
  • _writeback_params_to_weights_map properly handles all parameters (not just weight)
  • FSDP2 context manager correctly generalized to handle all DTensor parameters

Comment thread modelopt/torch/quantization/mode.py
Comment thread modelopt/torch/quantization/utils/layerwise_calib.py
Comment thread modelopt/torch/quantization/model_calib.py Outdated
Comment thread modelopt/torch/quantization/utils/layerwise_calib.py
Comment thread modelopt/torch/quantization/plugins/accelerate.py Outdated
Comment thread modelopt/torch/quantization/model_calib.py
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/model_calib.py (1)

1566-1632: ⚠️ Potential issue | 🔴 Critical

Add inline comments to torch.load(..., weights_only=False) calls in layerwise_calib.py.

Per SECURITY.md and the coding guidelines, torch.load(..., weights_only=False) must include an inline comment documenting why the file is internally-generated/trusted and safe to deserialize. Lines 545 and 555 in modelopt/torch/quantization/utils/layerwise_calib.py need this justification:

  • Line 545: Loading output_meta.pt
  • Line 555: Loading next_inputs.pt

Add a comment before each call explaining these checkpoint files are generated and managed internally by the sequential calibration process, confirming they are trusted sources.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/model_calib.py` around lines 1566 - 1632, Add
inline comments immediately before the two torch.load(..., weights_only=False)
calls in modelopt.torch.quantization.utils.layerwise_calib (look around the
_CheckpointState usage and the methods that load "output_meta.pt" and
"next_inputs.pt") stating that these checkpoint files ("output_meta.pt" and
"next_inputs.pt") are generated and managed internally by the sequential
calibration process, are not user-supplied, and therefore are trusted for safe
deserialization; locate the calls near methods that restore checkpoint state
(e.g., _CheckpointState.setup_resume / _CheckpointState.from_folder or any load
calls inside setup_resume/save) and add the short justification comment directly
above each torch.load call.
🧹 Nitpick comments (1)
modelopt/torch/quantization/utils/layerwise_calib.py (1)

591-630: LGTM!

The save method correctly:

  1. Uses enable_weight_access_and_writeback context for managed-weight frameworks
  2. Moves all data to CPU before storage
  3. Has a defensive fallback for missing output_meta (line 617-618)

The fallback creates dummy metadata if output_meta is None, which could mask state-machine bugs. Consider logging a warning in this case.

Optional: Add warning for missing output_meta
         output_meta = getattr(layer._seq_calib, "output_meta", None)
         if output_meta is None:
+            print_rank_0(
+                f"Warning: layer {layer_idx} has no output_meta; using fallback. "
+                "This may indicate the layer was not run in 'run' mode."
+            )
             output_meta = LayerActivationCollector._extract_output_meta(torch.zeros(1))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/utils/layerwise_calib.py` around lines 591 - 630,
In save (method save in layerwise_calib.py) add a warning when output_meta is
missing before calling LayerActivationCollector._extract_output_meta: detect if
getattr(layer._seq_calib, "output_meta", None) is None, log a warning (e.g.,
logger = logging.getLogger(__name__); logger.warning(...)) that includes
layer_idx and the layer identifier and states that dummy metadata is being
created, then proceed to call LayerActivationCollector._extract_output_meta;
this keeps behavior unchanged but surfaces the unexpected state-machine issue.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/utils/layerwise_calib.py`:
- Line 555: Add an inline comment next to the torch.load call that sets
weights_only=False (the line loading next_inputs from next_inputs_path)
explaining that this file is produced internally by _save_layer and therefore
may contain non-tensor objects from the model's forward pass which require
pickle; explicitly state that the file source is trusted and why using
weights_only=False is safe in this context to satisfy the security guideline.
- Around line 544-546: Add an inline comment immediately above the
torch.load(...) call that sets weights_only=False (the line assigning meta =
torch.load(...)) explaining that this is safe because output_meta.pt is produced
internally by this module's _save_layer function (so it is not user-supplied and
controlled), that the file may contain arbitrary Python objects under the
("other", output) metadata path and therefore requires pickle deserialization,
and that this trusted-origin justification satisfies the SECURITY.md requirement
for using weights_only=False.

---

Outside diff comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 1566-1632: Add inline comments immediately before the two
torch.load(..., weights_only=False) calls in
modelopt.torch.quantization.utils.layerwise_calib (look around the
_CheckpointState usage and the methods that load "output_meta.pt" and
"next_inputs.pt") stating that these checkpoint files ("output_meta.pt" and
"next_inputs.pt") are generated and managed internally by the sequential
calibration process, are not user-supplied, and therefore are trusted for safe
deserialization; locate the calls near methods that restore checkpoint state
(e.g., _CheckpointState.setup_resume / _CheckpointState.from_folder or any load
calls inside setup_resume/save) and add the short justification comment directly
above each torch.load call.

---

Nitpick comments:
In `@modelopt/torch/quantization/utils/layerwise_calib.py`:
- Around line 591-630: In save (method save in layerwise_calib.py) add a warning
when output_meta is missing before calling
LayerActivationCollector._extract_output_meta: detect if
getattr(layer._seq_calib, "output_meta", None) is None, log a warning (e.g.,
logger = logging.getLogger(__name__); logger.warning(...)) that includes
layer_idx and the layer identifier and states that dummy metadata is being
created, then proceed to call LayerActivationCollector._extract_output_meta;
this keeps behavior unchanged but surfaces the unexpected state-machine issue.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 6e6e083b-89b1-4ebe-9b11-a051411fcf87

📥 Commits

Reviewing files that changed from the base of the PR and between b6c6ec3 and 6280846.

📒 Files selected for processing (19)
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/mode.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/plugins/accelerate.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/utils/__init__.py
  • modelopt/torch/quantization/utils/activation_collector.py
  • modelopt/torch/quantization/utils/calib_utils.py
  • modelopt/torch/quantization/utils/core_utils.py
  • modelopt/torch/quantization/utils/layerwise_calib.py
  • modelopt/torch/utils/network.py
  • tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py
  • tests/gpu/torch/quantization/test_fsdp2.py
  • tests/gpu/torch/quantization/test_sequential_calibrate.py
  • tests/unit/torch/quantization/plugins/test_huggingface.py
  • tests/unit/torch/quantization/test_calib.py
  • tests/unit/torch/quantization/test_sequential_calibrate.py
  • tests/unit/torch/quantization/test_sequential_checkpoint.py
  • tests/unit/torch/quantization/test_utils.py
💤 Files with no reviewable changes (1)
  • modelopt/torch/quantization/utils/activation_collector.py

Comment thread modelopt/torch/quantization/utils/layerwise_calib.py
Comment thread modelopt/torch/quantization/utils/layerwise_calib.py
Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/unit/torch/quantization/test_sequential_calibrate.py (1)

585-590: Optional: guarantee cleanup with try/finally in restore test.

Use the same cleanup pattern as other tests so unpatch always runs if collection fails mid-test.

♻️ Suggested change
     collector = LayerActivationCollector(model)
     collector._patch_all_layers()
-    for layer in originals:
-        collector.get_input_activations(layer, forward_loop)
-    collector._unpatch_all_layers()
+    try:
+        for layer in originals:
+            collector.get_input_activations(layer, forward_loop)
+    finally:
+        collector._unpatch_all_layers()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/quantization/test_sequential_calibrate.py` around lines 585
- 590, The test currently calls collector._patch_all_layers(), runs collection,
and then calls collector._unpatch_all_layers() but does not guarantee cleanup if
collection fails; wrap the collection calls in a try/finally so that
_unpatch_all_layers() is always invoked even on exceptions. Specifically, after
calling LayerActivationCollector(model) and collector._patch_all_layers(),
perform the loop that calls collector.get_input_activations(layer, forward_loop)
over originals inside a try block and call collector._unpatch_all_layers() in
the finally block to ensure restoration.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@tests/unit/torch/quantization/test_sequential_calibrate.py`:
- Around line 585-590: The test currently calls collector._patch_all_layers(),
runs collection, and then calls collector._unpatch_all_layers() but does not
guarantee cleanup if collection fails; wrap the collection calls in a
try/finally so that _unpatch_all_layers() is always invoked even on exceptions.
Specifically, after calling LayerActivationCollector(model) and
collector._patch_all_layers(), perform the loop that calls
collector.get_input_activations(layer, forward_loop) over originals inside a try
block and call collector._unpatch_all_layers() in the finally block to ensure
restoration.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: b3aad426-8516-47cb-93d8-486cadc7717d

📥 Commits

Reviewing files that changed from the base of the PR and between 6280846 and 6515d4d.

📒 Files selected for processing (2)
  • tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py
  • tests/unit/torch/quantization/test_sequential_calibrate.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py

@realAsma realAsma force-pushed the asma/ptq-large-models branch from 6515d4d to 6a25fc2 on April 15, 2026 14:18
@realAsma realAsma changed the title from "Add checkpoint save/resume for sequential calibration" to "Add layerwise calibration for large models" on Apr 15, 2026
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

♻️ Duplicate comments (1)
modelopt/torch/quantization/utils/layerwise_calib.py (1)

547-560: ⚠️ Potential issue | 🔴 Critical

Document the trusted-source exception for these unsafe loads.

Both torch.load(..., weights_only=False) calls still need an inline comment explaining why they are safe in this specific path. Without that justification, this will keep failing the repo’s checkpointing/security rule.

Suggested fix
         for i in range(self.start_layer):
             d = _layer_dir(self.checkpoint_dir, i)
+            # weights_only=False is required here because output_meta.pt may contain
+            # non-tensor metadata written by _save_layer; this file is produced by
+            # this checkpointing flow and is not user-supplied.
             meta = torch.load(
                 os.path.join(d, "output_meta.pt"), map_location="cpu", weights_only=False
             )
@@
         next_inputs_path = os.path.join(d, "next_inputs.pt")
         if not os.path.isfile(next_inputs_path):
             raise FileNotFoundError(f"Cannot resume: next_inputs.pt missing for layer {last_ckpt}")
+        # weights_only=False is required here because next_inputs.pt may contain
+        # non-tensor captured inputs written by _save_layer; this file is produced
+        # by this checkpointing flow and is not user-supplied.
         next_inputs = torch.load(next_inputs_path, map_location="cpu", weights_only=False)

As per coding guidelines: Flag torch.load(..., weights_only=False) as CRITICAL security issue if no inline comment justifies why it is safe.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/utils/layerwise_calib.py` around lines 547 - 560,
Add an inline “trusted-source” justification comment next to the unsafe
torch.load(..., weights_only=False) calls in layerwise_calib resume logic:
explain that the files loaded by _layer_dir(...) (output_meta.pt and
next_inputs.pt) are generated internally by this process, validated earlier, and
therefore safe to deserialize; reference the specific calls where torch.load is
used (the loop that assigns layers[i]._layerwise_calib.output_meta after
invoking _remap_output_metadata_device and the subsequent load of next_inputs
from next_inputs_path) and include a short note about any prior validation or
provenance guarantees that ensure these checkpoint files are trusted.
🧹 Nitpick comments (2)
tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py (2)

150-157: Extract repeated offload+dispatch setup to one helper

The same device_map construction + load_checkpoint_and_dispatch(...) flow is duplicated in several tests. Centralizing this will reduce drift when offload policy changes.

♻️ Refactor sketch
+def _dispatch_with_offloaded_layers(config, tiny_llama_dir, cpu_layers=(0,)):
+    with init_empty_weights():
+        model = AutoModelForCausalLM.from_config(config)
+    device_map = {
+        n: 0
+        for n, _ in model.named_modules()
+        if "layers" not in n or n.split("layers.")[-1].isdigit()
+    }
+    for idx in cpu_layers:
+        device_map[f"model.layers.{idx}"] = "cpu"
+    return load_checkpoint_and_dispatch(model, tiny_llama_dir, device_map=device_map)

Then replace repeated blocks with calls like:

-model_ref = load_checkpoint_and_dispatch(model_ref, tiny_llama_dir, device_map=device_map)
+model_ref = _dispatch_with_offloaded_layers(config, tiny_llama_dir, cpu_layers=(0,))

Also applies to: 197-204, 215-223, 245-252, 338-345, 356-364

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py` around lines 150
- 157, Extract the repeated offload+dispatch setup into a single helper (e.g.,
create a function like build_and_dispatch_offload) that encapsulates
constructing the device_map (using model.named_modules() with the current
filtering logic and setting "model.layers.0" to "cpu") and calling
load_checkpoint_and_dispatch(model, tiny_llama_dir, device_map=device_map); then
replace each duplicated block (the instances around lines with device_map +
load_checkpoint_and_dispatch) with a call to that helper, passing the model (and
tiny_llama_dir if not global) so all tests share one canonical offload/dispatch
implementation.

451-454: Re-resolve linear in the final verification context

linear is captured in one enable_weight_access_and_writeback(...) context and reused in another. If materialization/writeback internals change object/parameter binding, this can become brittle.

🔧 Safer assertion pattern
-    with enable_weight_access_and_writeback(offloaded_layer, model):
-        assert torch.allclose(linear.weight, ref_weight + 1.0)
+    with enable_weight_access_and_writeback(offloaded_layer, model):
+        linear_after = next(m for m in offloaded_layer.modules() if isinstance(m, nn.Linear))
+        assert torch.allclose(linear_after.weight, ref_weight + 1.0)

Also applies to: 480-481

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py` around lines 451
- 454, The test captures the Linear module as `linear` inside one
enable_weight_access_and_writeback(offloaded_layer, model) context and then
reuses that object later, which can be brittle if materialization/writeback
changes bindings; in the final verification block re-resolve the
module/parameter instead of reusing the old reference by calling the same
resolver (e.g., next(m for m in offloaded_layer.modules() if isinstance(m,
nn.Linear)) ) inside the second enable_weight_access_and_writeback context and
then compare its .weight to the previously saved `ref_weight` to ensure the
assertion checks the current materialized parameter.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/config.py`:
- Around line 1230-1238: This accepts checkpoint_dir even when use_layerwise is
false; add a cross-field validator on the same Pydantic/dataclass model that
contains checkpoint_dir and use_layerwise (the class that uses ModeloptField for
checkpoint_dir) to raise a clear ValueError if checkpoint_dir is set while
use_layerwise is False. Implement the validator (root_validator or equivalent)
to inspect use_layerwise and checkpoint_dir and fail-fast with a descriptive
message like "checkpoint_dir is only valid when use_layerwise=True".

In `@modelopt/torch/quantization/plugins/accelerate.py`:
- Around line 53-66: The helper _writeback_params_to_weights_map only iterates
named_parameters and misses registered buffers; change it to iterate the
module.state_dict(keep_vars=True).items() so both parameters and buffers (e.g.,
_amax, _bias_value) get written back to align_hook.weights_map (handling
PrefixedDataset the same way as before). For each (name, tensor_var) skip
entries where tensor_var.device.type == "meta", build the key using
align_hook.weights_map.prefix when weights_map is a PrefixedDataset, and then
assign w_map[key] = tensor_var.data.to(w_map[key].device,
dtype=w_map[key].dtype) to preserve device/dtype.

In `@modelopt/torch/quantization/utils/layerwise_calib.py`:
- Around line 523-534: from_folder currently trusts the checkpoint manifest and
can resume a checkpoint created for a different model depth; update from_folder
(and the resume detection logic around detect_resume_point) to read the saved
num_layers from the checkpoint manifest in checkpoint_dir and compare it to the
provided num_layers, and if they differ fail fast (raise a clear exception or
log an error and return None) instead of proceeding to create cls(...,
start_layer=start); ensure you reference the manifest field (saved num_layers)
obtained by detect_resume_point or by loading the manifest file and perform the
comparison before calling cls(checkpoint_dir, num_layers, start_layer=start).

In `@modelopt/torch/utils/network.py`:
- Around line 104-111: In _get_execution_device_from_hook(), the code assumes
hook.execution_device can be passed directly to torch.device but accelerate may
supply an integer GPU ordinal; update the logic where dev is read (both for the
top-level hook and inside the for h in getattr(hook, "hooks", ()) loop) to check
isinstance(dev, int) and when true return torch.device("cuda", dev); otherwise
keep the existing behavior (e.g., return torch.device(dev) for strings or
torch.device objects). Ensure you handle None as before and reference the
variable name dev and the function _get_execution_device_from_hook to locate
where to change it.
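The requested change is a small branch on the device value. A torch-free mirror of that branching (in the real helper each return value would be a `torch.device`; plain tuples are used here so the logic is easy to test in isolation):

```python
def normalize_execution_device(dev):
    """Torch-free sketch of the branching proposed for _get_execution_device_from_hook."""
    if dev is None:
        return None  # hook carries no execution device
    if isinstance(dev, int):
        # accelerate may store a bare GPU ordinal, e.g. 0
        return ("cuda", dev)  # torch.device("cuda", dev) in the real helper
    # strings ("cpu", "cuda:1") and device objects keep the existing path,
    # i.e. torch.device(dev) in the real helper
    return ("device", dev)
```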

In `@tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py`:
- Around line 207-211: The test only rewrites manifest.json (manifest_path) to
simulate a crash, leaving checkpoint artifacts in ckpt_dir for layers >
last_completed_layer which can hide resume regressions; update the test in
tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py to remove or
truncate all layer artifacts whose layer indices are greater than the simulated
last_completed_layer (e.g., delete files/dirs in ckpt_dir corresponding to
layers 1..num_layers-1) after writing manifest.json, or instead generate the
partial checkpoint by interrupting the initial run so only artifacts up to layer
0 are created; apply the same change for the other occurrences around the blocks
at lines noted (the blocks using manifest_path, last_completed_layer, and
num_layers at the other two spots).

---

Duplicate comments:
In `@modelopt/torch/quantization/utils/layerwise_calib.py`:
- Around line 547-560: Add an inline “trusted-source” justification comment next
to the unsafe torch.load(..., weights_only=False) calls in layerwise_calib
resume logic: explain that the files loaded by _layer_dir(...) (output_meta.pt
and next_inputs.pt) are generated internally by this process, validated earlier,
and therefore safe to deserialize; reference the specific calls where torch.load
is used (the loop that assigns layers[i]._layerwise_calib.output_meta after
invoking _remap_output_metadata_device and the subsequent load of next_inputs
from next_inputs_path) and include a short note about any prior validation or
provenance guarantees that ensure these checkpoint files are trusted.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 1470960a-24d2-4f8a-9d07-833e322404a8

📥 Commits

Reviewing files that changed from the base of the PR and between 6515d4d and 6a25fc2.

📒 Files selected for processing (20)
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/mode.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/plugins/accelerate.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/utils/__init__.py
  • modelopt/torch/quantization/utils/activation_collector.py
  • modelopt/torch/quantization/utils/calib_utils.py
  • modelopt/torch/quantization/utils/core_utils.py
  • modelopt/torch/quantization/utils/layerwise_calib.py
  • modelopt/torch/utils/network.py
  • tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py
  • tests/gpu/torch/quantization/test_fsdp2.py
  • tests/gpu/torch/quantization/test_gptq.py
  • tests/gpu/torch/quantization/test_layerwise_calibrate.py
  • tests/unit/torch/quantization/plugins/test_huggingface.py
  • tests/unit/torch/quantization/test_calib.py
  • tests/unit/torch/quantization/test_layerwise_calibrate.py
  • tests/unit/torch/quantization/test_sequential_checkpoint.py
  • tests/unit/torch/quantization/test_utils.py
💤 Files with no reviewable changes (1)
  • modelopt/torch/quantization/utils/activation_collector.py
🚧 Files skipped from review as they are similar to previous changes (10)
  • tests/unit/torch/quantization/plugins/test_huggingface.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • tests/unit/torch/quantization/test_utils.py
  • tests/unit/torch/quantization/test_calib.py
  • modelopt/torch/quantization/mode.py
  • tests/gpu/torch/quantization/test_fsdp2.py
  • tests/unit/torch/quantization/test_sequential_checkpoint.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/utils/core_utils.py
  • modelopt/torch/quantization/utils/__init__.py

@realAsma realAsma requested a review from a team as a code owner April 15, 2026 18:14
Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (3)
modelopt/torch/quantization/plugins/accelerate.py (1)

52-64: ⚠️ Potential issue | 🟠 Major

Write back buffers too, not just parameters.

This helper only syncs named_parameters(). Quantization updates on offloaded modules also live in registered buffers, so those changes can be lost on the next rematerialization/offload cycle.

Suggested fix
-def _writeback_params_to_weights_map(module, align_hook):
-    """Write all non-meta parameters back to the hook's CPU weights_map."""
-    for name, param in module.named_parameters():
-        if param.device.type == "meta":
+def _writeback_params_to_weights_map(module, align_hook):
+    """Write all non-meta state back to the hook's CPU weights_map."""
+    for name, tensor in module.state_dict(keep_vars=True).items():
+        if tensor.device.type == "meta":
             continue
         if isinstance(align_hook.weights_map, PrefixedDataset):
             key = align_hook.weights_map.prefix + name
             w_map = align_hook.weights_map.dataset.state_dict
         else:
             w_map = align_hook.weights_map
             key = name
         if key in w_map:
-            w_map[key] = param.data.to(w_map[key].device, dtype=w_map[key].dtype)
+            w_map[key] = tensor.detach().to(w_map[key].device, dtype=w_map[key].dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/accelerate.py` around lines 52 - 64, The
helper _writeback_params_to_weights_map only writes named_parameters; update it
to also iterate named_buffers() and write non-meta buffers back to the same
weights_map using the same key logic (handle PrefixedDataset via
align_hook.weights_map.prefix and .dataset.state_dict), skip buffers on meta
device, and use buffer.data.to(w_map[key].device, dtype=w_map[key].dtype) when
assigning so buffer updates (e.g., quantization state) are preserved across
rematerialization/offload.
modelopt/torch/quantization/utils/layerwise_calib.py (2)

614-625: ⚠️ Potential issue | 🔴 Critical

Add inline justification for both weights_only=False loads.

These two loads are security-sensitive and will fail the repo rule without an inline comment explaining why pickle deserialization is required and why the files are trusted.

Suggested fix
         for i in range(self.start_layer):
             d = _layer_dir(self.checkpoint_dir, i)
+            # weights_only=False is required here because output_meta.pt is
+            # internally generated by _save_layer() in this module and may
+            # contain non-tensor Python values in the ("other", ...) metadata path.
             meta = torch.load(
                 os.path.join(d, "output_meta.pt"), map_location="cpu", weights_only=False
             )
             layer_device = get_module_device(layers[i])
             meta = _remap_output_metadata_device(meta, layer_device)
             layers[i]._layerwise_calib.output_meta = meta

         d = _layer_dir(self.checkpoint_dir, last_ckpt)
         next_inputs_path = os.path.join(d, "next_inputs.pt")
         if not os.path.isfile(next_inputs_path):
             raise FileNotFoundError(f"Cannot resume: next_inputs.pt missing for layer {last_ckpt}")
+        # weights_only=False is required here because next_inputs.pt is
+        # internally generated by _save_layer() and may include non-tensor
+        # forward inputs that require pickle deserialization.
         next_inputs = torch.load(next_inputs_path, map_location="cpu", weights_only=False)

As per coding guidelines: "Flag torch.load(..., weights_only=False) as CRITICAL security issue if no inline comment justifies why it is safe (e.g. confirming the file is internally-generated and not user-supplied)."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/utils/layerwise_calib.py` around lines 614 - 625,
Add inline comments next to both torch.load(...) calls explaining why
weights_only=False is required and why deserialization is safe: state these
files ("output_meta.pt" loaded into meta before _remap_output_metadata_device
and assigned to layers[i]._layerwise_calib.output_meta, and "next_inputs.pt"
loaded into next_inputs) are internally-generated by this checkpointing routine,
not user-supplied, and thus trusted for pickle deserialization; mention any
integrity guarantees (e.g., written by the same code/version and stored in a
controlled checkpoint_dir) and that removing weights_only=False would lose
needed non-parameter metadata—place the comments immediately above each
torch.load invocation.

588-599: ⚠️ Potential issue | 🟠 Major

Validate checkpoint depth before resuming.

from_folder() trusts manifest.json blindly. Reusing the same checkpoint_dir for a model with a different decoder-layer count can resume at the wrong layer and later restore incompatible state.

Suggested fix
     def from_folder(cls, checkpoint_dir: str | None, num_layers: int) -> _CheckpointState | None:
         """Create from folder. Detects resume point. Returns None if no checkpoint_dir."""
         if not checkpoint_dir:
             return None
         os.makedirs(checkpoint_dir, exist_ok=True)
         info = detect_resume_point(checkpoint_dir)
+        if info and info[1].get("num_layers") != num_layers:
+            raise ValueError(
+                f"Checkpoint in {checkpoint_dir!r} was created for "
+                f"{info[1].get('num_layers')} layers, but the current model has {num_layers}."
+            )
         start = info[0] if info else 0
         if start > 0:
             print_rank_0(
                 f"Checkpoint: resuming layerwise calibration from layer {start}/{num_layers}"
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/utils/layerwise_calib.py` around lines 588 - 599,
from_folder currently resumes purely from detect_resume_point and can continue
with an incompatible layer count; modify from_folder to validate checkpoint
depth by reading the checkpoint manifest (e.g., manifest.json under
checkpoint_dir) and extract the saved layer count, then compare it to the
provided num_layers: if they mismatch, print_rank_0 a clear error/warning
referencing checkpoint_dir and the manifest values and either refuse to resume
(return None) or clamp start to min(saved_start, num_layers) and set start_layer
accordingly; ensure you use the existing symbols from_folder,
detect_resume_point, start_layer, checkpoint_dir and cls(...) so the validation
happens before cls(...) is returned.
🧹 Nitpick comments (1)
modelopt/torch/quantization/utils/layerwise_calib.py (1)

681-684: Avoid repeated module scans in the per-layer save path.

enable_weight_access_and_writeback(layer, model) recomputes module lookup state every time this method runs. In a checkpointed layerwise loop that turns save into an avoidable O(N²) path on large models. Please cache dict(model.named_modules()) once and pass it here, like the restore path already does.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/utils/layerwise_calib.py` around lines 681 - 684,
The per-layer save path calls enable_weight_access_and_writeback(layer, model)
which rescans model modules each time and causes O(N²) behavior; cache
dict(model.named_modules()) once (as done in the restore path) and pass that
cached mapping into enable_weight_access_and_writeback to avoid repeated module
lookups when saving layers (update the call sites around
enable_weight_access_and_writeback in the block that moves weights via
_move_to_device and computes qstate via quantizer_state to accept and use the
precomputed named_modules map).
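The requested hoisting pattern, sketched with a hypothetical `save_layer` callback standing in for the real per-layer save path; the point is only that the module scan happens once, outside the loop:

```python
def calibrate_all_layers(model, layers, save_layer):
    """Sketch: hoist the O(N) module scan out of the per-layer save loop.

    save_layer is a hypothetical callback; the real save path would accept
    the precomputed name->module mapping instead of rescanning the model.
    """
    named_modules = dict(model.named_modules())  # computed once
    for idx, layer in enumerate(layers):
        save_layer(idx, layer, named_modules)  # cheap lookups, no rescan
```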

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 5c4cda75-b053-47ea-aa4d-59f64d16d881

📥 Commits

Reviewing files that changed from the base of the PR and between 6a25fc2 and c50c4a7.

📒 Files selected for processing (6)
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/quantization/mode.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/plugins/accelerate.py
  • modelopt/torch/quantization/utils/layerwise_calib.py
  • modelopt/torch/utils/dataset_utils.py

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (2)
modelopt/torch/quantization/utils/layerwise_calib.py (2)

557-563: ⚠️ Potential issue | 🟠 Major

Reject resumes from checkpoints with a different layer count.

from_folder() trusts manifest.json blindly. Reusing an old checkpoint directory against a model with different decoder depth can resume at the wrong layer and restore incompatible per-layer state.

Suggested fix
         os.makedirs(checkpoint_dir, exist_ok=True)
         info = detect_resume_point(checkpoint_dir)
+        if info and info[1].get("num_layers") != num_layers:
+            raise ValueError(
+                f"Checkpoint in {checkpoint_dir!r} was created for "
+                f"{info[1].get('num_layers')} layers, but the current model has {num_layers}."
+            )
         start = info[0] if info else 0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/utils/layerwise_calib.py` around lines 557 - 563,
from_folder() currently trusts manifest.json and may resume a checkpoint with a
different layer count (via detect_resume_point / start_layer), causing
incompatible per-layer state; update from_folder() to read the saved num_layers
from manifest.json (or stored metadata) and compare it to the provided
num_layers and only allow resuming when they match, otherwise log/raise an error
and force a fresh start (i.e., ignore the checkpoint_dir or set start_layer=0)
before returning cls(checkpoint_dir, num_layers, start_layer=start).

576-590: ⚠️ Potential issue | 🔴 Critical

Document both trusted pickle loads inline.

Both torch.load(..., weights_only=False) calls are missing the required trusted-origin justification. In this repo that is treated as a blocker for checkpoint deserialization code.

Suggested fix
         for i in range(self.start_layer):
             d = _layer_dir(self.checkpoint_dir, i)
+            # weights_only=False is required here: output_meta.pt is written by
+            # _save_layer in this module, is not user-supplied, and may contain
+            # non-tensor metadata under the ("other", ...) path.
             meta = torch.load(
                 os.path.join(d, "output_meta.pt"), map_location="cpu", weights_only=False
             )
@@
         next_inputs_path = os.path.join(d, "next_inputs.pt")
         if not os.path.isfile(next_inputs_path):
             raise FileNotFoundError(f"Cannot resume: next_inputs.pt missing for layer {last_ckpt}")
+        # weights_only=False is required here: next_inputs.pt is written by
+        # _save_layer in this module, is not user-supplied, and may contain
+        # non-tensor forward inputs that need pickle deserialization.
         next_inputs = torch.load(next_inputs_path, map_location="cpu", weights_only=False)
#!/bin/bash
rg -n -C2 'weights_only=False' modelopt/torch/quantization/utils/layerwise_calib.py

As per coding guidelines: "Flag torch.load(..., weights_only=False) as CRITICAL security issue if no inline comment justifies why it is safe."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/utils/layerwise_calib.py` around lines 576 - 590,
Both torch.load(..., weights_only=False) calls must include an inline
trusted-origin justification; update the two calls that load "output_meta.pt"
(inside the loop using _layer_dir(...), meta = torch.load(...)) and
"next_inputs.pt" (after locating d = _layer_dir(...), next_inputs =
torch.load(...)) to add short comments explaining why it's safe to deserialize
with weights_only=False (e.g., files are produced by our controlled
checkpointing process, not user-supplied, and validated elsewhere), referencing
the surrounding symbols layers, start_layer, last_ckpt, checkpoint_dir,
_layer_dir and _remap_output_metadata_device so reviewers can verify the trust
boundary. Ensure each comment is immediately adjacent to the torch.load call and
concise.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/llm_ptq/example_utils.py`:
- Around line 860-865: The helper needs_checkpoint_path_update currently calls
algorithm.get("checkpoint_dir") and will crash if quant_cfg["algorithm"] is not
a mapping; update it to first retrieve algorithm = quant_cfg.get("algorithm"),
then guard that algorithm is a dict (e.g., isinstance(algorithm, dict)) before
accessing .get, returning False for any non-dict/None values so the function
never raises on unexpected config shapes while still checking for
"checkpoint_dir".
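A sketch of the guarded helper as described; the function name comes from this PR and the exact config shapes are assumptions:

```python
def needs_checkpoint_path_update(quant_cfg):
    """Return True only when quant_cfg["algorithm"] is a mapping carrying
    a checkpoint_dir entry; tolerate strings (e.g. "max"), None, or absence."""
    algorithm = quant_cfg.get("algorithm")
    if not isinstance(algorithm, dict):
        return False  # non-dict shapes can never carry a checkpoint_dir
    return algorithm.get("checkpoint_dir") is not None
```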

In `@modelopt/torch/quantization/utils/layerwise_calib.py`:
- Around line 304-314: When setting up neighboring layer state in
_set_layer_states() (the block that accesses self._decoder_layers[layer_idx -
1]._layerwise_calib and mutates prev.mode, prev.cached_inputs, and
prev.collected_inputs), ensure you restore prev.collected_inputs,
prev.cached_inputs and prev.mode if forward_loop() (or any subsequent operation)
raises so the previous layer isn't left with emptied collected_inputs; wrap the
state changes in a try/except that reassigns prev.collected_inputs back from the
deque/cached state (or clears cached_inputs) and resets prev.mode to its
original value on exception so retries won't hit the empty-input guard.
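The save/mutate/restore pattern being requested, sketched against a minimal stand-in for the per-layer state object (the real one is `_layerwise_calib`; field names come from this review comment):

```python
from collections import deque


class LayerStateSketch:
    """Minimal stand-in for the per-layer _layerwise_calib state."""

    def __init__(self, inputs):
        self.mode = "skip"
        self.cached_inputs = None
        self.collected_inputs = deque(inputs)


def replay_with_rollback(prev, forward_loop):
    """Mutate the previous layer's state for one calibration step, restoring
    it if forward_loop raises so a retry does not hit the empty-input guard."""
    saved = (prev.mode, prev.cached_inputs, deque(prev.collected_inputs))
    prev.mode = "run"
    prev.cached_inputs = prev.collected_inputs
    prev.collected_inputs = deque()
    try:
        forward_loop()
    except Exception:
        # Roll back so the layer is not left with emptied collected_inputs.
        prev.mode, prev.cached_inputs, prev.collected_inputs = saved
        raise
```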


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: b9a52a60-07ea-4cc6-8e22-c5373f4a0cbc

📥 Commits

Reviewing files that changed from the base of the PR and between c50c4a7 and d2cd03c.

📒 Files selected for processing (3)
  • examples/llm_ptq/example_utils.py
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/quantization/utils/layerwise_calib.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/llm_ptq/hf_ptq.py

@realAsma realAsma requested a review from a team as a code owner April 15, 2026 18:53
@realAsma realAsma requested a review from shengliangxu April 15, 2026 18:53
This PR does three things:

1. Rename sequential_calibrate to layerwise_calibrate to better describe
   the layer-by-layer algorithm (use_sequential -> use_layerwise,
   _seq_calib -> _layerwise_calib).

2. Make layerwise calibration performant: persistent_materialization
   keeps the active layer on GPU for the entire calibration step,
   and _SkipLayer replaces fully-calibrated layers with parameter-free
   dummies so framework hooks (accelerate, FSDP2) skip materialization.

3. Add checkpoint save/resume so calibration of large models can be
   interrupted and restarted from the last completed layer.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>

Add layerwise calibration for large models

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>

Move checkpoint_dir helpers from library to examples/llm_ptq

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>

Rename layerwise config fields and enable layerwise on experts-only recipe

- use_layerwise -> layerwise, checkpoint_dir -> layerwise_checkpoint_dir
- Enable layerwise calibration + checkpointing on nvfp4_experts_only-fp8_kv recipe
- Add layerwise_checkpoint_dir to nvfp4_default-none_kv_gptq recipe

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>

Address PR review feedback for layerwise calibration

- Add inline security comments for all torch.load(weights_only=False) calls
- Replace bare assert with RuntimeError for unsupported offload hook layout
- Write back buffers (not just parameters) in _writeback_params_to_weights_map
- Add cross-field validator rejecting layerwise_checkpoint_dir without layerwise=True
- Validate num_layers mismatch on checkpoint resume
- Handle integer device ordinals in _get_execution_device_from_hook
- Clean up stale layer artifacts in partial-checkpoint tests
- Guard non-dict algorithm values in needs_checkpoint_path_update
- Add comment explaining dummy output_meta for last layer

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
@realAsma realAsma force-pushed the asma/ptq-large-models branch from 43d1888 to e0cda1b Compare April 15, 2026 19:58
Contributor Author

@realAsma realAsma left a comment


@sugunav14 Note on layerwise_calib.py: Most of this file is moved from modelopt/torch/quantization/utils/activation_collector.py (deleted in this PR). Git does not detect the rename because the file nearly doubled in size with new checkpoint/resume logic. To see the actual diff against the original, use:

git diff origin/main...HEAD -M10% -- modelopt/torch/quantization/utils/

This shows it as a rename with ~393 insertions and ~47 deletions, rather than a full 681-line new file + 335-line deletion.

Contributor

@meenchen meenchen left a comment


1. Design — Method Guard Removed

[QUESTION] Calibration method guard removed with TODO
The old code asserted that only max and gptq methods could be used with sequential calibration. This PR removes that assertion with a TODO: "add a method guard here." This means any calibration method (AWQ, SmoothQuant, etc.) can now be silently used with layerwise, even if it doesn't support per-layer invocation. What's the plan for re-adding validation?

2. Correctness — Checkpoint Corruption Recovery

[SUGGESTION] Manifest corruption restarts silently from layer 0
_read_manifest() returns None on corrupt/missing JSON, causing from_folder() to treat it as a fresh run. If a user's 50-layer calibration checkpoint is partially corrupted, they'd restart from scratch without knowing why. Add a warning when the manifest exists but can't be parsed — "Checkpoint manifest found but unreadable; starting from layer 0."

3. Correctness — _SkipLayer Proxy Masks AttributeErrors

[SUGGESTION] _SkipLayer __getattr__ catches all AttributeErrors

def __getattr__(self, name):
    try:
        return super().__getattr__(name)
    except AttributeError:
        return getattr(object.__getattribute__(self, "_original"), name)

If the original layer also lacks the attribute, the resulting error message references _original instead of the skip layer, which is confusing for debugging. Consider keeping the fallback to the original layer, but when that second lookup also fails, re-raising with context that names the skip layer.

4. Design — Distributed Checkpoint Blocked

[SUGGESTION] Multi-rank checkpointing blocked at runtime, should fail at config time
_CheckpointState.__init__ raises RuntimeError if dist.size() > 1. Users discover this only after model loading and initial setup. Validate this in the config validator alongside layerwise_checkpoint_dir — if running distributed, reject the config early.

The block_sizes config has mixed int and str keys which causes TypeError
when sort_keys=True is used in json.dumps for checkpoint dir hashing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Collaborator

@cjluo-nv cjluo-nv left a comment


All critical comments from previous reviews have been addressed:

Critical issues — all resolved:

  1. ✅ weights_only=False security: All 6 torch.load calls now have inline justification comments explaining the files are internally generated.
  2. ✅ _writeback_params_to_weights_map now uses state_dict(keep_vars=True) to write back both parameters and buffers.
  3. ✅ from_folder() validates num_layers mismatch between checkpoint manifest and current model.
  4. ✅ Config validator validate_layerwise_checkpoint_dir rejects layerwise_checkpoint_dir when layerwise=False.
  5. ✅ _get_execution_device_from_hook handles integer GPU ordinals with isinstance(dev, int).
  6. ✅ needs_checkpoint_path_update guards against non-dict algorithm values.
  7. ✅ Tests now clean up stale layer directories above last_completed_layer during crash simulation.
  8. ✅ weight_access_and_writeback_context now raises RuntimeError (instead of bare assert) for unsupported dual-hook layouts.
  9. ✅ Last layer's dummy output_meta has a clear comment explaining it's a placeholder.
  10. ✅ _layer_forward_loop uses default argument capture (_inputs=layer_inputs) for explicit binding.

Design decisions accepted by reviewers:

  • Author pushed back on restoring previous-layer state on exception (no retry mechanism exists), and CodeRabbit agreed.
  • Method guard removal: The layerwise field is on the base QuantizeAlgorithmConfig, so technically any method could set layerwise=True. However, the TODO is clearly marked, and in practice the typed config dispatch system means standard recipes won't hit this path for unsupported methods. This is acceptable risk with the TODO in place.

Code quality:

  • Excellent test coverage: unit tests for skip/swap/restore, checkpoint save/resume, GPU integration tests for CPU-offloaded models, FSDP2, GPTQ combinations.
  • The _SkipLayer pattern for replacing calibrated layers is clean and avoids framework hook overhead.
  • The persistent_materialization context manager is well-designed for both FSDP2 and accelerate.
  • The _forward_loop change to disable use_cache during calibration is a good correctness fix.
  • The calib_utils.py change to handle meta device in GPTQHelper.__init__ correctly supports CPU-offloaded models.

PR size: ~1900 lines is above the soft 1000-line guideline, but the changes are cohesive (rename + layerwise calibration + checkpoint + framework integration) and the author notes most of layerwise_calib.py is moved from the deleted activation_collector.py.
