
fix: handle accelerate CPU-offloaded models in FakeQuant export#1194

Open
sungsooha wants to merge 3 commits into NVIDIA:main from sungsooha:sungsooh/fix-offload-export

Conversation

@sungsooha
Contributor

@sungsooha sungsooha commented Apr 8, 2026

What does this PR do?

Type of change: Bug fix

When models are loaded with device_map="auto" and layers are offloaded to CPU via AlignDevicesHook, two issues arise during FakeQuant export:

  1. model.state_dict() returns meta tensors for offloaded layers (no actual data)
  2. model.save_pretrained(state_dict=clean_sd) is ignored by accelerate — it saves from internal state instead, leaking quantizer keys (input_quantizer._amax) into safetensors and preserving auto_map in config.json

This causes vLLM's weight loader to crash with KeyError on the unknown quantizer keys, and OSError when auto_map references custom Python files not present in the export directory.
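
The first symptom can be seen without a large model: a tensor on the meta device reports `is_meta` and carries no data. A minimal sketch (the helper name is illustrative, not code from this PR):

```python
import torch

def find_meta_keys(state_dict):
    """Return keys whose tensors live on the meta device (no actual data)."""
    return [k for k, v in state_dict.items() if v.is_meta]

# Hypothetical state dict mixing a resident tensor with an offloaded one.
sd = {
    "model.layers.0.weight": torch.zeros(4, 4),
    "model.layers.1.weight": torch.empty(4, 4, device="meta"),
}
print(find_meta_keys(sd))  # ['model.layers.1.weight']
```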

Affected models: Any model large enough to trigger CPU offloading during PTQ (e.g., DeepSeek-R1 671B on 8xH100).

Four fixes:

  1. _materialize_offloaded_weights(): Walk accelerate's AlignDevicesHook.weights_map to resolve meta tensors to actual CPU data before export.

  2. GPU hop in weight processing: Move CPU tensors to the quantizer's device before calling quantizer kernels (e.g., fp4_fake_quant_block) which require CUDA. Uses quantizer buffers (_amax) for device detection since NVFP4 quantizers have no parameters, only buffers.

  3. _save_clean_checkpoint(): Bypass save_pretrained() entirely for weight saving. Write safetensors directly from clean_sd via safetensors.torch.save_file() + split_torch_state_dict_into_shards(). Also strips auto_map from config.json (custom code files are not present in the export directory).

  4. FakeQuantWorker.compile_or_warm_up_model(): Fix return type from None to float and add missing return before super() call. Without this, the multiproc executor gets [None, None, ...] from collective_rpc and crashes with TypeError in max(compilation_times).
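
The materialization idea in fix 1 can be sketched as walking module hooks and swapping meta tensors for their offloaded CPU data. Accelerate attaches its hook under the `_hf_hook` attribute; the `FakeHook` class below stands in for `AlignDevicesHook`, and the function mirrors but is not the PR's implementation:

```python
import torch

def materialize_offloaded_weights(model, state_dict, meta_keys):
    """Replace meta tensors in state_dict with real data from offload hooks."""
    for name, module in model.named_modules():
        hook = getattr(module, "_hf_hook", None)          # accelerate's hook slot
        weights_map = getattr(hook, "weights_map", None)  # offloaded param storage
        if weights_map is None:
            continue
        for param_name in weights_map:
            full_key = f"{name}.{param_name}" if name else param_name
            if full_key in meta_keys and state_dict[full_key].is_meta:
                state_dict[full_key] = weights_map[param_name]

# Stand-in for accelerate's AlignDevicesHook, holding the offloaded CPU data.
class FakeHook:
    def __init__(self, weights_map):
        self.weights_map = weights_map

layer = torch.nn.Linear(2, 2)
layer._hf_hook = FakeHook({"weight": torch.ones(2, 2)})
model = torch.nn.Sequential(layer)  # named_modules() yields ("0", layer)
sd = {"0.weight": torch.empty(2, 2, device="meta")}
materialize_offloaded_weights(model, sd, ["0.weight"])
print(sd["0.weight"].is_meta)  # False
```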

Usage

No API changes. The fixes are internal to export_hf_vllm_fq_checkpoint() and FakeQuantWorker. Existing callers work without modification.

# Export works the same — offloaded models now produce clean checkpoints
from modelopt.torch.export.plugins.vllm_fakequant_hf import export_hf_vllm_fq_checkpoint

export_hf_vllm_fq_checkpoint(model, export_dir="/path/to/export")
# Output: clean safetensors (no quantizer keys), config.json (no auto_map),
#         vllm_fq_modelopt_state.pth (quantizer state)
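
The config-cleanup half of fix 3 can be sketched with the stdlib alone; the helper name and the `auto_map` entry below are illustrative, not the PR's code:

```python
import json
import tempfile
from pathlib import Path

def strip_auto_map(config_path):
    """Remove auto_map from config.json: the custom-code files it points at
    are not copied into the export directory, so loading would raise OSError."""
    path = Path(config_path)
    cfg = json.loads(path.read_text())
    cfg.pop("auto_map", None)
    path.write_text(json.dumps(cfg, indent=2))

# Demo on a throwaway config with a hypothetical auto_map entry.
tmp = Path(tempfile.mkdtemp()) / "config.json"
tmp.write_text(json.dumps({
    "model_type": "deepseek_v3",
    "auto_map": {"AutoModel": "modeling_custom.CustomModel"},
}))
strip_auto_map(tmp)
print(json.loads(tmp.read_text()))  # {'model_type': 'deepseek_v3'}
```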

Testing

Tested on H100 (dlcluster) with forced CPU offloading:

  • Model: Qwen3-0.6B with max_memory={0: "500MiB", "cpu": "32GiB"}
  • 254/311 parameters offloaded (meta tensors in state_dict())
  • Verified: no quantizer keys in safetensor files, no auto_map in config.json, valid vllm_fq_modelopt_state.pth (257KB)

Also validated end-to-end on DeepSeek-R1 671B (8xH100, v7.1 image) — PTQ export + vLLM FakeQuant serving successful.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ (no API changes, internal fix only)
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ (huggingface_hub.split_torch_state_dict_into_shards and safetensors.torch.save_file are existing dependencies)
  • Did you write any new necessary tests?: ❌ (manual GPU test — automated test requires multi-GPU with forced offloading, hard to set up in CI)
  • Did you update Changelog?: N/A (bug fix, not a new feature or API change)

Additional Information

  • Related to PR feat: parallelize fakequant export across GPUs via ThreadPoolExecutor #1177 (parallel FakeQuant export) which addresses the performance side
  • Root cause discovered during P0 NVFP4 sweep on DeepSeek-R1 671B
  • The compile_or_warm_up_model fix (item 4) is needed for single-node multiproc executor (TP>1) — Ray executor has a different code path that doesn't hit this bug

Summary by CodeRabbit

  • Bug Fixes

    • Better handling and materialization of offloaded (meta) weights during fake‑quant export to avoid missing tensors.
    • Improved export process to write clean, sharded checkpoints reliably and preserve configuration.
    • Ensured tensors are placed on appropriate CUDA devices during quantized export.
  • New Features

    • Worker warm‑up now returns a numeric compile/warm‑up metric for improved reporting and propagation.

When models are loaded with device_map="auto" and layers are offloaded
to CPU via AlignDevicesHook, model.state_dict() returns meta tensors
and model.save_pretrained(state_dict=clean_sd) is ignored by accelerate.

Four fixes:
1. _materialize_offloaded_weights(): resolve meta tensors from
   accelerate's AlignDevicesHook.weights_map before export.
2. GPU hop in weight processing: move CPU tensors to quantizer's
   device (quantizer kernels like fp4_fake_quant_block require CUDA).
   Uses quantizer buffers (amax) for device detection.
3. _save_clean_checkpoint(): bypass save_pretrained entirely, write
   safetensors directly via save_file() + split_torch_state_dict_into_shards().
   Also strips auto_map from config.json (custom code files not in export).
4. FakeQuantWorker.compile_or_warm_up_model: return float (not None)
   to fix multiproc executor TypeError in max(compilation_times).

Tested: Qwen3-0.6B on H100 with forced CPU offloading (500MiB GPU limit,
254/311 meta tensors). All checks passed — no quantizer keys in safetensors,
no auto_map in config.json.

Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com>
@sungsooha sungsooha requested review from a team as code owners April 8, 2026 05:16
@sungsooha sungsooha requested a review from meenchen April 8, 2026 05:16
@copy-pr-bot

copy-pr-bot bot commented Apr 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Apr 8, 2026

📝 Walkthrough


Updated FakeQuantWorker to return a float from its compile/warm-up flow; added helpers to materialize accelerate-offloaded (meta) tensors and to write clean sharded safetensors/config during VLLM fake-quant checkpoint export.

Changes

  • FakeQuant Worker — examples/vllm_serve/fakequant_worker.py: changed FakeQuantWorker.compile_or_warm_up_model() return type to float; now optionally runs _fakequant_run_prolog_worker(), then returns the numeric result from BaseWorker.compile_or_warm_up_model().
  • Weight Materialization & Checkpoint Export — modelopt/torch/export/plugins/vllm_fakequant_hf.py: added a module logger, _materialize_offloaded_weights() to replace meta tensors using Accelerate hooks, _save_clean_checkpoint() to write sharded safetensors and config.json, and updated export_hf_vllm_fq_checkpoint() to materialize offloaded weights, move target tensors to CUDA if needed, and save via the new clean-save flow.

Sequence Diagram(s)

sequenceDiagram
    participant Caller as Caller
    participant Export as export_hf_vllm_fq_checkpoint
    participant Detect as MetaTensorDetector
    participant Materialize as _materialize_offloaded_weights
    participant Quantizer as Quantizer/DeviceManager
    participant Save as _save_clean_checkpoint
    participant FS as FileSystem

    Caller->>Export: invoke export_hf_vllm_fq_checkpoint(model, ...)
    Export->>Export: state_dict = model.state_dict()
    Export->>Detect: scan state_dict for meta tensors
    alt meta tensors found
        Export->>Materialize: request materialization via accelerate weights_map
        Materialize->>Materialize: replace meta tensors with real tensors
        Materialize->>Export: return materialized state_dict
    end
    Export->>Quantizer: fold quantizers (may require CUDA)
    alt non-CUDA tensors need CUDA
        Quantizer->>Export: choose CUDA device and move tensors
    end
    Export->>Save: write clean_sd via _save_clean_checkpoint (shard & safetensors)
    Save->>FS: write shards and model.config (strip auto_map)
    Save->>Export: confirm saved
    Export->>Caller: return/export complete

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 66.67%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
✅ Passed checks (3 passed)
  • Description Check — ✅ Passed: check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed: the title 'fix: handle accelerate CPU-offloaded models in FakeQuant export' directly and specifically describes the main bug fix addressing CPU-offloaded model handling in FakeQuant export, which is the primary objective of this PR.
  • Security Anti-Patterns — ✅ Passed: PR adheres to all SECURITY.md requirements: torch.load uses weights_only=True, trust_remote_code defaults to False, no unsafe deserialization or code execution patterns detected.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/export/plugins/vllm_fakequant_hf.py (1)

248-252: ⚠️ Potential issue | 🟠 Major

Restore the quantizer state in a finally.

Once this block disables the weight quantizers, any exception from torch.save() or _save_clean_checkpoint() leaves the in-memory model partially mutated. Wrapping the disable/save/restore path in try/finally keeps a failed export from poisoning subsequent use of the model.

Also applies to: 282-291
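
The suggested pattern, sketched with stand-in quantizer objects (the real code would disable ModelOpt weight quantizers and call torch.save() / _save_clean_checkpoint() inside the try body; names here are illustrative):

```python
class FakeQuantizer:
    """Stand-in for a ModelOpt weight quantizer with enable/disable state."""
    def __init__(self):
        self.enabled = True
    def disable(self):
        self.enabled = False
    def enable(self):
        self.enabled = True

def save_with_restore(quantizers, save_fn):
    """Disable quantizers for the save; restore them even if saving raises."""
    disabled = []
    try:
        for q in quantizers:
            q.disable()
            disabled.append(q)
        save_fn()
    finally:
        for q in disabled:
            q.enable()

def failing_save():
    raise OSError("disk full")  # simulate a failure mid-export

qs = [FakeQuantizer(), FakeQuantizer()]
try:
    save_with_restore(qs, failing_save)
except OSError:
    pass
print(all(q.enabled for q in qs))  # True: restored despite the failure
```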

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py` around lines 248 - 252,
The disable/modify sequence for weight quantizers (calls to quantizer.disable(),
setting quantizer._rotate via disable_rotate(quantizer), and appending to
wqs_to_restore) must be wrapped with a try/finally so the original quantizer
state (orig_rotate and enabled state) is always restored even if torch.save() or
_save_clean_checkpoint() throws; update both the block around
quantizer.disable()/orig_rotate/quantizer._rotate and the similar block at
282-291 to perform the disable and then run save inside try, and restore
quantizer._rotate and re-enable the quantizer in the finally regardless of
exceptions, referencing quantizer, disable_rotate, wqs_to_restore, torch.save
and _save_clean_checkpoint to locate the code to change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py`:
- Around line 203-208: The code currently only moves weight w to a CUDA device
if quantizer.parameters()/buffers() contains a CUDA tensor; when TensorQuantizer
has no qtensors this misses a CUDA-only quantizer path and can still call
quantizer(w.float()) on CPU. Update the block around w and quantizer (symbols:
w, quantizer, TensorQuantizer) to, if qtensors is empty, find a sibling tensor
on the same module (e.g., iterate module.parameters()/buffers() for the first
is_cuda tensor) and use its device, or fall back to an explicit export device
variable (e.g., export_device or torch.device('cuda')) before calling quantizer;
ensure w = w.to(found_device) is applied only when needed.
- Around line 112-115: The code calls split_torch_state_dict_into_shards() and
save_file() on cpu_sd without resolving tied/shared tensors, which can cause
failures or duplicated large tensors; before sharding, detect aliasing in cpu_sd
(e.g., by comparing tensor.data_ptr()/storage()/is) and replace aliased entries
with a single canonical tensor reference (or remove duplicate keys mapping to
the same storage) so split_torch_state_dict_into_shards() and subsequent
save_file() operate on a deduplicated state dict; alternatively, delegate to the
HF helper used by save_pretrained() that already performs this alias
cleanup—perform this deduplication on cpu_sd prior to calling
split_torch_state_dict_into_shards() and use the resulting filename_to_tensors
mapping to build shard dicts for save_file().

---

Outside diff comments:
In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py`:
- Around line 248-252: The disable/modify sequence for weight quantizers (calls
to quantizer.disable(), setting quantizer._rotate via disable_rotate(quantizer),
and appending to wqs_to_restore) must be wrapped with a try/finally so the
original quantizer state (orig_rotate and enabled state) is always restored even
if torch.save() or _save_clean_checkpoint() throws; update both the block around
quantizer.disable()/orig_rotate/quantizer._rotate and the similar block at
282-291 to perform the disable and then run save inside try, and restore
quantizer._rotate and re-enable the quantizer in the finally regardless of
exceptions, referencing quantizer, disable_rotate, wqs_to_restore, torch.save
and _save_clean_checkpoint to locate the code to change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 22c6dd83-d33b-430f-8cbf-6526445753cf

📥 Commits

Reviewing files that changed from the base of the PR and between 82d96a6 and 3895d21.

📒 Files selected for processing (2)
  • examples/vllm_serve/fakequant_worker.py
  • modelopt/torch/export/plugins/vllm_fakequant_hf.py

@codecov

codecov bot commented Apr 8, 2026

Codecov Report

❌ Patch coverage is 5.40541% with 70 lines in your changes missing coverage. Please review.
✅ Project coverage is 59.34%. Comparing base (82d96a6) to head (8c189ac).
⚠️ Report is 1 commits behind head on main.

Files with missing lines | Patch % | Lines
modelopt/torch/export/plugins/vllm_fakequant_hf.py | 5.40% | 70 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1194      +/-   ##
==========================================
- Coverage   59.45%   59.34%   -0.11%     
==========================================
  Files         352      352              
  Lines       40343    40416      +73     
==========================================
- Hits        23987    23986       -1     
- Misses      16356    16430      +74     
Flag | Coverage Δ
unit | 54.92% <5.40%> (-0.11%) ⬇️

Flags with carried forward coverage won't be shown.


1. Fix ruff import ordering: move logger after all imports
2. Clone tensors in _save_clean_checkpoint to handle tied weights
   (safetensors rejects shared storage)
3. Robust GPU device fallback: check quantizer params/buffers,
   then parent module params (handles uninitialized quantizers)

Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com>
Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py`:
- Around line 167-178: After calling _materialize_offloaded_weights(model,
state_dict, meta_keys) recompute unresolved_meta = [k for k,v in
state_dict.items() if v.is_meta]; if unresolved_meta is non-empty and contains
any keys that are not quantizer-related (e.g. not containing "quant" or
"quantizer"), raise a RuntimeError listing unresolved_meta and a short message
mentioning that materialization failed and will break subsequent fake-quant
folding or _save_clean_checkpoint; reference the symbols meta_keys,
_materialize_offloaded_weights, state_dict, and _save_clean_checkpoint so the
error helps locate the problem.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: fccfaa7b-0635-4d84-8e1d-3888a156a198

📥 Commits

Reviewing files that changed from the base of the PR and between b85c4e0 and 8c189ac.

📒 Files selected for processing (1)
  • modelopt/torch/export/plugins/vllm_fakequant_hf.py

Comment on lines +110 to +118
# Move to CPU and clone to break shared storage (tied weights like lm_head/embed_tokens).
# safetensors rejects tensors that share underlying storage.
cpu_sd = {k: v.cpu().clone() for k, v in clean_sd.items()}

state_dict_split = split_torch_state_dict_into_shards(cpu_sd, max_shard_size="5GB")
for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items():
    shard = {k: cpu_sd[k] for k in tensor_keys}
    save_file(shard, str(export_dir / shard_file))
    logger.info("Saved shard: %s (%d tensors)", shard_file, len(shard))

⚠️ Potential issue | 🟠 Major

Shard before cloning tensors to CPU.

cpu_sd = {k: v.cpu().clone() ...} creates a second full copy of the checkpoint in host RAM before sharding. On the offload path, that doubles peak memory and can OOM the exact large-model exports this change is trying to unblock.

💡 Proposed fix
-    # Move to CPU and clone to break shared storage (tied weights like lm_head/embed_tokens).
-    # safetensors rejects tensors that share underlying storage.
-    cpu_sd = {k: v.cpu().clone() for k, v in clean_sd.items()}
-
-    state_dict_split = split_torch_state_dict_into_shards(cpu_sd, max_shard_size="5GB")
+    state_dict_split = split_torch_state_dict_into_shards(clean_sd, max_shard_size="5GB")
     for shard_file, tensor_keys in state_dict_split.filename_to_tensors.items():
-        shard = {k: cpu_sd[k] for k in tensor_keys}
+        # Move only the current shard to CPU to keep peak memory bounded.
+        shard = {k: clean_sd[k].cpu().clone() for k in tensor_keys}
         save_file(shard, str(export_dir / shard_file))
         logger.info("Saved shard: %s (%d tensors)", shard_file, len(shard))
@@
     logger.info(
         "Checkpoint saved: %d weights in %d shard(s)",
-        len(cpu_sd),
+        len(clean_sd),
         len(state_dict_split.filename_to_tensors),
     )

Also applies to: 136-140

Comment on lines +167 to +178

# Handle accelerate-offloaded models: state_dict() returns meta tensors
# for CPU/disk-offloaded layers. Materialize them from the offload hooks.
meta_keys = [k for k, v in state_dict.items() if v.is_meta]
if meta_keys:
    logger.info(
        "Found %d meta tensors in state_dict (accelerate offloading). "
        "Materializing from offload hooks...",
        len(meta_keys),
    )
    _materialize_offloaded_weights(model, state_dict, meta_keys)


⚠️ Potential issue | 🟠 Major

Fail fast if any offloaded tensors stay on meta.

_materialize_offloaded_weights() only logs misses. If a non-quantizer key is still meta here, the later fake-quant fold or _save_clean_checkpoint() will blow up with a much less actionable error. Please re-check the state dict immediately after materialization and raise with the unresolved keys.

💡 Proposed fix
     if meta_keys:
         logger.info(
             "Found %d meta tensors in state_dict (accelerate offloading). "
             "Materializing from offload hooks...",
             len(meta_keys),
         )
         _materialize_offloaded_weights(model, state_dict, meta_keys)
+        unresolved_meta_keys = [
+            k for k, v in state_dict.items() if v.is_meta and "quantizer" not in k
+        ]
+        if unresolved_meta_keys:
+            shown = ", ".join(unresolved_meta_keys[:10])
+            suffix = " ..." if len(unresolved_meta_keys) > 10 else ""
+            raise RuntimeError(f"Failed to materialize offloaded tensors: {shown}{suffix}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py` around lines 167 - 178,
After calling _materialize_offloaded_weights(model, state_dict, meta_keys)
recompute unresolved_meta = [k for k,v in state_dict.items() if v.is_meta]; if
unresolved_meta is non-empty and contains any keys that are not
quantizer-related (e.g. not containing "quant" or "quantizer"), raise a
RuntimeError listing unresolved_meta and a short message mentioning that
materialization failed and will break subsequent fake-quant folding or
_save_clean_checkpoint; reference the symbols meta_keys,
_materialize_offloaded_weights, state_dict, and _save_clean_checkpoint so the
error helps locate the problem.
