Commit de55e8a
Bug fix: disable weight quantizer rotation after weight fold during vLLM fakequant export (#1143)
### What does this PR do?
Type of change: Bug fix
When `export_hf_vllm_fq_checkpoint` folds fake-quantized weights into
the HF checkpoint, it applies the weight quantizer's full forward pass —
including any Hadamard rotation — to produce `fake_quant(rotate(W))`.
The weight quantizer is then disabled in the saved ModelOpt state so
vLLM reload skips re-quantization.
However, if `fold_weight` is called after reload (e.g. in
`fakequant_worker.py`), `QuantModule.fold_weight` checks `fake_quant`
but not `is_enabled`, so it calls the disabled weight quantizer. With
`rotate=True`, rotation runs before the
`_disabled` early-return in `TensorQuantizer.forward`, corrupting the
already-folded weight with a second rotation.
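The failure mode can be illustrated with a minimal stand-in (not the real ModelOpt `TensorQuantizer`; the rotation and quantization here are placeholder transforms) where a hypothetical `forward` applies rotation before checking the `_disabled` early-return:

```python
class TensorQuantizer:
    """Minimal stand-in illustrating the bug: rotation is applied
    before the _disabled early-return is checked, so even a disabled
    quantizer mutates its input."""

    def __init__(self, rotate=False):
        self._disabled = False
        self._rotate = rotate

    def disable(self):
        self._disabled = True

    def forward(self, w):
        if self._rotate:
            # stand-in for a Hadamard rotation (any non-identity transform)
            w = w[1:] + w[:1]
        if self._disabled:
            return w  # quantization is skipped, but rotation already ran
        return [round(x) for x in w]  # stand-in for fake quantization


q = TensorQuantizer(rotate=True)
q.disable()                      # quantizer disabled in the saved state
folded = [1.0, 2.0, 3.0]         # weight already folded during export
refolded = q.forward(folded)     # fold_weight after reload still rotates
# refolded == [2.0, 3.0, 1.0]: the already-folded weight was rotated again
```

Because the second rotation composes with the one baked into the folded weight, the reloaded model no longer matches the exported checkpoint.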
This fix clears `_rotate` on weight quantizers alongside `disable()`
before saving the ModelOpt state. Both are restored on the in-memory
model after saving, so the export remains non-mutating.
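The save/clear/restore pattern behind the fix can be sketched as follows (a hypothetical helper, not the actual export code; attribute names mirror the description above):

```python
class TensorQuantizer:
    """Minimal stand-in for sketching the fix's save/clear/restore flow."""

    def __init__(self):
        self._disabled = False
        self._rotate = True

    def disable(self):
        self._disabled = True

    def enable(self):
        self._disabled = False


def export_with_rotation_cleared(quantizer, save_state):
    """Hypothetical helper: clear _rotate alongside disable() before the
    state is saved, so a later fold_weight on the reloaded model cannot
    re-rotate; then restore both flags so the export is non-mutating."""
    saved_rotate = quantizer._rotate
    quantizer.disable()
    quantizer._rotate = False
    try:
        save_state(quantizer)  # saved state: disabled, rotation cleared
    finally:
        quantizer._rotate = saved_rotate  # restore in-memory model
        quantizer.enable()


saved_states = []
q = TensorQuantizer()
export_with_rotation_cleared(
    q, lambda qz: saved_states.append((qz._disabled, qz._rotate))
)
# saved state has (_disabled=True, _rotate=False);
# the in-memory quantizer is back to its pre-export state
```

The `try`/`finally` keeps the restore unconditional, so even a failed save leaves the in-memory model unchanged.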
### Usage
### Testing
1. Add the following config to `modelopt/torch/quantization/config.py`:

```python
NVFP4_ROTATE_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "enable": True,
            "rotate": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "enable": True,
            "rotate": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": {
        "method": "max",
    },
}
```

2. Run PTQ with vLLM fakequant export, then serve the exported checkpoint:

```sh
cd <MODELOPT_PATH>/examples/vllm_serve
python ../llm_ptq/hf_ptq.py --pyt_ckpt_path Qwen/Qwen3-0.6B --quant_cfg NVFP4_ROTATE_CFG --dataset cnn_dailymail --export_path qwen3-rotate --trust_remote_code --batch_size 4 --calib_size 512 --vllm_fakequant_export
MODELOPT_STATE_PATH=qwen3-rotate/vllm_fq_modelopt_state.pth python vllm_serve_fakequant.py qwen3-rotate/ -tp 1 --served-model-name Qwen3-0.6B --host 0.0.0.0 --port 8001 --trust-remote-code --disable-custom-all-reduce --enforce-eager --gpu-memory-utilization 0.8 --max-model-len 32768
```
### Before your PR is "*Ready for review*"
Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)
and your commits are signed (`git commit -s -S`).
Make sure you read and follow the [Security Best
Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors)
(e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(...,
weights_only=False)`, `pickle`, etc.).
- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP
dependency, did you follow guidance in `CONTRIBUTING.md`: N/A
- Did you write any new necessary tests?: N/A
- Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?:
N/A
### Additional Information
## Summary by CodeRabbit
* **Bug Fixes**
* Preserve and restore rotation settings for quantized weights during
export so exported models keep correct quantization behavior.
* **Documentation**
* Updated changelog: moved vLLM fakequant reload entry to 0.44 (2026-05)
with reference to the vLLM serve example.
---------
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
File tree: 2 files changed (+19 −3 lines) under `modelopt/torch/export/plugins`.