
Commit de55e8a

Bug fix: disable weight quantizer rotation after weight fold during vLLM fakequant export (#1143)
### What does this PR do?

**Type of change:** Bug fix

When `export_hf_vllm_fq_checkpoint` folds fake-quantized weights into the HF checkpoint, it applies the weight quantizer's full forward pass, including any Hadamard rotation, to produce `fake_quant(rotate(W))`. The weight quantizer is then disabled in the saved ModelOpt state so the vLLM reload skips re-quantization. However, if `fold_weight` is called after reload (e.g. in `fakequant_worker.py`), `QuantModule.fold_weight` checks `fake_quant` but not `is_enabled`, so it calls the disabled weight quantizer. With `rotate=True`, rotation runs before the `_disabled` early-return in `TensorQuantizer.forward`, corrupting the already-folded weight with a second rotation.

This fix clears `_rotate` on weight quantizers alongside `disable()` before saving the ModelOpt state. Both are restored on the in-memory model after saving, so the export remains non-mutating.

### Testing

Add the following config to `modelopt/torch/quantization/config.py`:

```python
NVFP4_ROTATE_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "enable": True,
            "rotate": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "enable": True,
            "rotate": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": {
        "method": "max",
    },
}
```

Then run:

```bash
cd <MODELOPT_PATH>/examples/vllm_serve

python ../llm_ptq/hf_ptq.py --pyt_ckpt_path Qwen/Qwen3-0.6B --quant_cfg NVFP4_ROTATE_CFG --dataset cnn_dailymail --export_path qwen3-rotate --trust_remote_code --batch_size 4 --calib_size 512 --vllm_fakequant_export

MODELOPT_STATE_PATH=qwen3-rotate/vllm_fq_modelopt_state.pth python vllm_serve_fakequant.py qwen3-rotate/ -tp 1 --served-model-name Qwen3-0.6B --host 0.0.0.0 --port 8001 --trust-remote-code --disable-custom-all-reduce --enforce-eager --gpu-memory-utilization 0.8 --max-model-len 32768
```

### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A
- Did you write any new necessary tests?: N/A
- Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A

## Summary by CodeRabbit

* **Bug Fixes**
  * Preserve and restore rotation settings for quantized weights during export so exported models keep correct quantization behavior.
* **Documentation**
  * Updated changelog: moved the vLLM fakequant reload entry to 0.44 (2026-05) with a reference to the vLLM serve example.

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
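The failure mode described above can be sketched in miniature. This is a toy illustration with hypothetical names, not ModelOpt's real classes: the point is only that when rotation is applied before the disabled check, "disabling" the quantizer is not enough to make its forward a no-op, while clearing the rotation flag is.

```python
class ToyQuantizer:
    """Toy stand-in for a quantizer whose forward rotates before the disabled check."""

    def __init__(self):
        self._disabled = False
        self._rotate = True

    def forward(self, w):
        if self._rotate:
            w = [-x for x in w]  # stand-in for a Hadamard rotation
        if self._disabled:
            return w  # too late: the rotation above has already fired
        return [float(round(x)) for x in w]  # stand-in for fake quantization


folded = [1.0, -2.0]          # weight already folded (rotated + quantized) at export
q = ToyQuantizer()
q._disabled = True            # export disables the quantizer before saving state...
buggy = q.forward(folded)     # ...but forward still rotates: [-1.0, 2.0], corrupted
q._rotate = False             # the fix: clear rotation alongside disable()
fixed = q.forward(folded)     # now a true no-op: [1.0, -2.0], unchanged
```

With both flags cleared, re-running `fold_weight` on reload leaves the exported weight intact instead of rotating it a second time.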
1 parent 2ae407c commit de55e8a

File tree

2 files changed: +19 -3 lines changed

CHANGELOG.rst (1 addition, 1 deletion)

```diff
@@ -10,6 +10,7 @@ NVIDIA Model Optimizer Changelog
 - Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
 - Add skip-softmax skipping to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
 - Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml>`_ for more details.
+- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.

 **Bug Fixes**

@@ -48,7 +49,6 @@ NVIDIA Model Optimizer Changelog
 - Add support for Nemotron-3 (NemotronHForCausalLM) model quantization and support for NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules.
 - Add support for block-granular RHT for non-power-of-2 dimensions.
 - Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes.
-- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.

 **Deprecations**
```

modelopt/torch/export/plugins/vllm_fakequant_hf.py (18 additions, 2 deletions)

```diff
@@ -20,6 +20,7 @@
 import torch.nn as nn

 import modelopt.torch.opt as mto
+from modelopt.torch.quantization.config import RotateConfig
 from modelopt.torch.quantization.conversion import quantizer_state
 from modelopt.torch.quantization.nn import QuantModule, TensorQuantizer
 from modelopt.torch.quantization.utils import get_quantizer_state_dict
@@ -28,6 +29,15 @@
 __all__ = ["export_hf_vllm_fq_checkpoint"]


+def disable_rotate(quantizer: TensorQuantizer):
+    """Return a disabled copy of the quantizer's ``_rotate`` field, preserving its type."""
+    if isinstance(quantizer._rotate, RotateConfig):
+        return RotateConfig(enable=False)
+    if isinstance(quantizer._rotate, dict):  # backward compat: old checkpoints stored a dict
+        return dict(quantizer._rotate, enable=False)
+    return False
+
+
 def export_hf_vllm_fq_checkpoint(
     model: nn.Module,
     export_dir: Path | str,
@@ -104,6 +114,8 @@ def export_hf_vllm_fq_checkpoint(
     # dict, then re-enable. The _disabled=True flag is captured in modelopt_state
     # so that on vLLM reload weight quantizers stay off while input/output/
     # attention quantizers remain active.
+    # Rotation is also cleared: the weight was already folded with rotation applied,
+    # so if fold_weight is called on reload it must not re-rotate the exported weight.
     wqs_to_restore = []
     for _, module in model.named_modules():
         if isinstance(module, QuantModule):
@@ -114,7 +126,10 @@ def export_hf_vllm_fq_checkpoint(
                 and quantizer.is_enabled
             ):
                 quantizer.disable()
-                wqs_to_restore.append(quantizer)
+                orig_rotate = quantizer._rotate
+                if quantizer.rotate_is_enabled:
+                    quantizer._rotate = disable_rotate(quantizer)
+                wqs_to_restore.append((quantizer, orig_rotate))

     quantizer_state_dict = get_quantizer_state_dict(model)
     for key in list(quantizer_state_dict):
@@ -149,5 +164,6 @@ def export_hf_vllm_fq_checkpoint(
     # Step 3: Save HF weights using the pre-built folded state dict.
     model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)

-    for wq in wqs_to_restore:
+    for wq, orig_rotate in wqs_to_restore:
         wq.enable()
+        wq._rotate = orig_rotate
```
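The stash-mutate-save-restore dance in this diff is an instance of a general pattern: temporarily alter object state so the serialized snapshot differs from the live model, then put everything back so the export is non-mutating. A minimal generic sketch of that pattern as a context manager (illustrative names only, not ModelOpt's API):

```python
from contextlib import contextmanager


@contextmanager
def rotation_disabled(quantizers):
    """Temporarily clear each quantizer's rotation flag, restoring it on exit."""
    saved = [(q, q._rotate) for q in quantizers]  # stash originals first
    try:
        for q in quantizers:
            q._rotate = False  # state the saved snapshot should capture
        yield
    finally:
        for q, orig in saved:
            q._rotate = orig  # restore even if saving raised


class FakeQuantizer:
    """Tiny stand-in carrying only the field the sketch manipulates."""

    def __init__(self):
        self._rotate = True


qs = [FakeQuantizer(), FakeQuantizer()]
with rotation_disabled(qs):
    # Inside the block: rotation is off, so a save here captures the
    # "do not re-rotate on reload" state.
    assert all(not q._rotate for q in qs)
# Outside: the in-memory objects are back to their original state.
assert all(q._rotate for q in qs)
```

The `try/finally` guarantees the restore runs even when saving fails, which is the same property the PR preserves by re-enabling quantizers and reassigning `_rotate` after `save_pretrained`.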
