vLLM fakequant fold weight_quantizer for megatron export #1246

kinjalpatel27 merged 10 commits into main from
Conversation
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
📝 Walkthrough

Changes modify quantizer weight handling in the vLLM export and reload workflows. The export plugin now folds enabled weight quantizers into the exported weights and skips exporting weight-quantizer tensors. The reload utility validates missing quantizer keys and disables the corresponding modules for weight quantizers to prevent folding during checkpoint loading. One whitespace-only change was also included.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Exporter
    participant Module
    participant Quantizer
    participant CPU
    participant CUDA

    Exporter->>Module: inspect weight, check weight_quantizer
    alt weight_quantizer exists and enabled
        Exporter->>Quantizer: move quantizer buffers -> CUDA (if weight on CPU and CUDA available)
        Exporter->>Quantizer: apply quantizer(weight) under no_grad
        Quantizer-->>Exporter: quantized weight
        Exporter->>Quantizer: restore quantizer buffers to original devices
        Exporter->>CPU: move quantized weight -> CPU and store for export
    else no enabled weight_quantizer
        Exporter->>CPU: move module.weight -> CPU and store for export
    end
    Exporter->>Exporter: emit quantizer state excluding items with "weight_quantizer" in name
```
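The export-side flow above can be sketched roughly as follows. This is a hedged sketch, not the actual plugin code: `FakeQuantizer` is a stand-in for modelopt's real quantizer type, and the `weight_quantizer` attribute with an `enabled` flag is an assumption mirroring the diagram.

```python
import torch


class FakeQuantizer(torch.nn.Module):
    """Stand-in weight quantizer: symmetric int8-style quantize + dequantize."""

    def __init__(self, amax: float, enabled: bool = True):
        super().__init__()
        self.register_buffer("amax", torch.tensor(amax))
        self.enabled = enabled

    def forward(self, w: torch.Tensor) -> torch.Tensor:
        scale = self.amax / 127.0
        return torch.clamp(torch.round(w / scale), -127, 127) * scale


def fold_weight_for_export(module: torch.nn.Module) -> torch.Tensor:
    """Return the weight to export, folding in an enabled weight_quantizer if present."""
    wq = getattr(module, "weight_quantizer", None)
    if wq is not None and getattr(wq, "enabled", False):
        with torch.no_grad():
            return wq(module.weight).cpu()
    return module.weight.detach().cpu()


def export_quantizer_state(state: dict) -> dict:
    """Skip weight_quantizer entries: their effect is already folded into the weights."""
    return {k: v for k, v in state.items() if "weight_quantizer" not in k}
```

Because the weight is exported post-fake-quant, the quantizer's amax no longer needs to travel alongside it in the exported quantizer state.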
```mermaid
sequenceDiagram
    participant Loader
    participant Checkpoint
    participant Model
    participant Module

    Loader->>Checkpoint: read checkpoint quantizer keys
    Loader->>Model: iterate model.named_modules()
    alt missing checkpoint key contains "weight_quantizer" and maps to module path
        Loader->>Module: call module.disable() if available
    else missing key unrelated or cannot map
        Loader-->>Loader: raise ValueError
    end
    Loader->>Model: continue loading remaining state
```
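The reload-side flow above can be sketched as plain Python. Key formats, the module-path mapping, and `DummyQuantizer` are hypothetical illustrations, not the actual `vllm_reload_utils` implementation.

```python
class DummyQuantizer:
    """Stand-in for a quantizer module exposing disable()."""

    def __init__(self):
        self.enabled = True

    def disable(self):
        self.enabled = False


def handle_missing_quantizer_keys(missing_keys, named_modules):
    """Disable weight quantizers whose state was folded at export; fail loudly otherwise."""
    for key in missing_keys:
        if "weight_quantizer" not in key:
            # A missing non-weight-quantizer key means a real checkpoint mismatch.
            raise ValueError(f"Unexpected missing quantizer key: {key}")
        # Map e.g. "layers.0.weight_quantizer._amax" -> "layers.0.weight_quantizer".
        module_path = key[: key.index("weight_quantizer") + len("weight_quantizer")]
        module = named_modules.get(module_path)
        if module is None:
            raise ValueError(f"Cannot map missing key {key} to a module")
        if hasattr(module, "disable"):
            module.disable()
```

A minimal usage: `handle_missing_quantizer_keys(["m.weight_quantizer._amax"], {"m.weight_quantizer": DummyQuantizer()})` disables that quantizer so it will not fold again during checkpoint loading.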
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 3 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@           Coverage Diff           @@
## main #1246 +/- ##
==========================================
+ Coverage 75.61% 76.56% +0.94%
==========================================
Files 459 459
Lines 48597 48613 +16
==========================================
+ Hits 36747 37220 +473
+ Misses 11850 11393 -457
```

Flags with carried forward coverage won't be shown.
🧹 Nitpick comments (2)
examples/vllm_serve/vllm_reload_utils.py (2)
462-472: Consider filtering the checkpoint count symmetrically for a clearer comparison.

The comparison at line 467 compares `checkpoint_quant_count` (all quantizer keys) against `model_non_wq_quant_count` (which excludes weight_quantizer keys). This works correctly when the checkpoint has no weight_quantizer keys (the expected behavior), but the asymmetry could cause a spurious warning if a checkpoint unexpectedly contains weight_quantizer entries. For symmetry and clarity, consider filtering the checkpoint count the same way:
♻️ Suggested change for symmetric comparison

```diff
-    checkpoint_quant_count = len(checkpoint_quant_keys)
-
-    # Ensure counts match (excluding weight quantizer keys, which may be absent when weights
-    # were pre-folded at export)
-    model_non_wq_quant_count = sum(1 for k in model_quant_keys if "weight_quantizer" not in k)
-    if checkpoint_quant_count != model_non_wq_quant_count:
+    # Ensure counts match (excluding weight quantizer keys, which may be absent when weights
+    # were pre-folded at export)
+    checkpoint_non_wq_quant_count = sum(
+        1 for k in checkpoint_quant_keys if "weight_quantizer" not in k
+    )
+    model_non_wq_quant_count = sum(1 for k in model_quant_keys if "weight_quantizer" not in k)
+    if checkpoint_non_wq_quant_count != model_non_wq_quant_count:
         warnings.warn(
-            f"Mismatch in quantizer state key counts: checkpoint has {checkpoint_quant_count} "
+            f"Mismatch in quantizer state key counts: checkpoint has {checkpoint_non_wq_quant_count} "
             f"quant keys but model has {model_non_wq_quant_count} non-weight quantizer state keys. "
             f"This can happen if the model is using PP."
         )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/vllm_reload_utils.py` around lines 462-472: The comparison currently uses checkpoint_quant_count (len(checkpoint_quant_keys)) vs model_non_wq_quant_count (sum over model_quant_keys excluding "weight_quantizer"), which is asymmetric; change the checkpoint side to filter out weight quantizer keys the same way (e.g., compute checkpoint_non_wq_quant_count = sum(1 for k in checkpoint_quant_keys if "weight_quantizer" not in k) and compare that to model_non_wq_quant_count) so the warning only fires when non-weight quantizer key counts truly differ; update references to checkpoint_quant_count in the surrounding logic to use the new filtered count.
463-463: Remove trailing whitespace.

Line 463 appears to have trailing whitespace. Ruff will flag this.
🧹 Fix

```diff
-    
+
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/vllm_reload_utils.py` at line 463: Remove the trailing whitespace flagged by Ruff in the vllm_reload_utils.py module by deleting the extra spaces at the end of the blank line around the reload utilities area (the stray whitespace on that empty line); ensure the line contains no trailing spaces so the file passes linting.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 4d77ebdd-19dd-4f19-b343-71baeb996154
📒 Files selected for processing (1)
examples/vllm_serve/vllm_reload_utils.py
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/vllm_serve/vllm_reload_utils.py`:
- Around line 440-463: The loop that looks for weight_quantizer module suffixes
can silently ignore keys like "model.weight_quantizer_v2._amax" because wq_i may
be None; update the logic in the block processing model_quant_keys (the code
that builds missing_wq_module_paths using wq_i) to not silently skip when wq_i
is None: either raise a ValueError naming the offending key or emit a clear
warning via logging (include the key and explain it didn’t match the expected
"weight_quantizer" module suffix), then proceed to add valid module paths to
missing_wq_module_paths and continue to the later loop that calls
module.disable() on modules found by model.named_modules().
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 4625d7a6-f374-4790-ab93-d358d6aa3a27
📒 Files selected for processing (2)
examples/vllm_serve/vllm_reload_utils.py
modelopt/torch/export/plugins/vllm_fakequant_megatron.py
meenchen left a comment
1. Correctness — Stricter Missing Key Handling May Break PP
[QUESTION] Removed count-mismatch warning may regress pipeline parallelism
The old code warned when checkpoint and model quantizer key counts didn't match (e.g., PP splits) and proceeded. The new code raises ValueError for any missing non-weight-quantizer key. If PP causes missing input_quantizer keys in the checkpoint (model has more layers than this rank's checkpoint), this would now error where it previously warned. Has this been tested with PP > 1?
2. Correctness — module.disable() without type check
[SUGGESTION] No isinstance guard before calling disable()
```python
for name, module in model.named_modules():
    if name in missing_wq_module_paths and hasattr(module, "disable"):
        module.disable()
```

The `hasattr(module, "disable")` check is a reasonable duck-type guard, but if any non-TensorQuantizer module at that path happens to have a `disable()` method, it would be called unintentionally. Consider adding `isinstance(module, TensorQuantizer)` for safety.
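A guarded variant of the loop might look like the sketch below. Here `TensorQuantizer` is a stand-in class, not the real modelopt type, and the unrelated `OtherModule` exists only to show what the guard protects against.

```python
class TensorQuantizer:
    """Stand-in for the real quantizer type with a disable() method."""

    def __init__(self):
        self.enabled = True

    def disable(self):
        self.enabled = False


class OtherModule:
    """Unrelated module whose disable() must NOT be called by the loop."""

    def disable(self):
        raise RuntimeError("should not be disabled")


def disable_missing_weight_quantizers(named_modules, missing_wq_module_paths):
    """Disable only true quantizers at the missing paths, using an isinstance guard."""
    for name, module in named_modules:
        if name in missing_wq_module_paths and isinstance(module, TensorQuantizer):
            module.disable()
```

With this guard, an `OtherModule` sitting at a matching path is simply skipped instead of having its `disable()` invoked.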
3. Design — Device juggling in _get_quantized_state
[SUGGESTION] Temporary CUDA lift for quantizer buffers is fragile
When weight is on CPU, buffers are moved to CUDA, the quantizer runs, then buffers are restored. If the quantizer forward internally creates new buffers or registers state, those won't be moved back. The finally block only restores buffers captured before the forward. Consider a comment noting this assumption, or use quantizer.to(quant_device) / quantizer.to(orig_device) for a cleaner round-trip.
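The cleaner round trip suggested above can be sketched as follows. This is an illustrative helper, not the actual `_get_quantized_state` code; `DemoClampQuantizer` is a toy module used only to exercise the device handling.

```python
import torch


def quantize_weight_round_trip(quantizer: torch.nn.Module, weight: torch.Tensor) -> torch.Tensor:
    """Run the quantizer on `weight`, moving the whole module with .to() both ways."""
    orig_device = next((b.device for b in quantizer.buffers()), weight.device)
    # Lift to CUDA only when the weight lives on CPU and a GPU is present.
    run_device = (
        torch.device("cuda")
        if weight.device.type == "cpu" and torch.cuda.is_available()
        else weight.device
    )
    try:
        quantizer.to(run_device)
        with torch.no_grad():
            out = quantizer(weight.to(run_device))
    finally:
        # Module.to() moves every parameter and buffer currently registered,
        # including any registered during forward, so no per-buffer bookkeeping.
        quantizer.to(orig_device)
    return out.cpu()


class DemoClampQuantizer(torch.nn.Module):
    """Toy quantizer with one buffer, just to exercise the round trip."""

    def __init__(self, amax: float = 1.0):
        super().__init__()
        self.register_buffer("amax", torch.tensor(amax))

    def forward(self, w: torch.Tensor) -> torch.Tensor:
        return torch.clamp(w, -float(self.amax), float(self.amax))
```

On a CPU-only machine `run_device` stays on CPU and both `.to()` calls are no-ops, so the helper degrades gracefully.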
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Addressed it in 3e2ec17. Verified it works with PP>1
Addressed it in 3e2ec17
Addressed it in 3e2ec17
The loading logic changes seem to be mainly for disabling weight quantizers. Since disabling weight quantizers seems to be a common use case, a cleaner way may be to directly add an env
What does this PR do?
Type of change: Bug fix
During Megatron→vLLM fakequant export (`export_mcore_gpt_to_hf_vllm_fq`), the `weight_quantizer` is now applied as fake quantization (quantize + dequantize) directly to the exported weight tensor, and its amax is no longer saved to `quantizer_state.pth`. On reload, if `weight_quantizer` keys are absent from the checkpoint (because they were folded at export time), the corresponding quantizer modules are disabled.

This change is especially useful when amax values across experts are not synced for the `weight_quantizer`: folding allows each expert's `weight_quantizer` to keep a different amax for better accuracy.

Usage
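As a toy numeric illustration of the per-expert amax point above (plain Python, not modelopt code): a small-magnitude expert loses far less precision when fake-quantized with its own amax than with an amax shared across all experts.

```python
def fake_quant(w, amax, num_bits=8):
    """Symmetric integer fake quantization: quantize to the grid, then dequantize."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = amax / qmax
    return [max(-qmax, min(qmax, round(x / scale))) * scale for x in w]


def max_abs_error(w, amax):
    """Worst-case elementwise error introduced by fake quantization at this amax."""
    return max(abs(a, ) if False else abs(a - b) for a, b in zip(w, fake_quant(w, amax)))


expert_a = [0.9, -0.7, 0.4]      # large-magnitude expert
expert_b = [0.02, -0.015, 0.01]  # small-magnitude expert

shared_amax = max(abs(x) for x in expert_a + expert_b)  # amax synced across experts
own_amax = max(abs(x) for x in expert_b)                # per-expert amax

# Per-expert amax gives expert_b a much finer quantization grid.
assert max_abs_error(expert_b, own_amax) < max_abs_error(expert_b, shared_amax)
```

With the shared amax (0.9), expert_b's grid step is about 0.007, comparable to its own weights; with its own amax (0.02) the step shrinks by roughly 45x.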
Testing
Step 1 — Quantize (run from Megatron-LM `examples/post_training/modelopt`):

Step 2 — Export for vLLM fakequant:
Step 3 — Serve (run from examples/vllm_serve):
Before your PR is "Ready for review"
Make sure you read and follow the Contributor guidelines and your commits are signed (`git commit -s -S`).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

CONTRIBUTING.md: N/A

Additional Information
Summary by CodeRabbit
Bug Fixes
Improvements