
Commit 010b220

vLLM fakequant export update for AWQ checkpoint (#1242)
### What does this PR do?

**Type of change:** Bug fix

Enables end-to-end AWQ checkpoint export and reload in the vLLM fake-quant serving path (`MODELOPT_STATE_PATH`). Previously, the `input_quantizer` used an incorrect `pre_quant_scale`, especially with grouped quantizers such as `qkv_proj`, where it simply took the first `input_quantizer.pre_quant_scale`. This PR adds `_resmooth_experts_for_export`, which non-mutatively averages `pre_quant_scale` across MoE experts and unifies the input `_amax`; this is required because vLLM uses a single input quantizer per expert group. It also adds `merge_amax_tensors_for_group` (element-wise max for same-shape tensors, `cat` for GQA, scalar-max fallback), replacing the scalar-collapsing `torch.stack().max()` that dropped per-channel `_amax` structure.

### Usage

```python
# Export AWQ checkpoint from HF model
from modelopt.torch.export.plugins.vllm_fakequant_hf import export_hf_vllm_fq_checkpoint

export_hf_vllm_fq_checkpoint(model, export_dir="./awq_vllm_checkpoint")
```

### Testing

**Step 1 — Export the quantized checkpoint:**

```bash
python examples/llm_ptq/hf_ptq.py \
    --pyt_ckpt_path <MODEL_PATH> \
    --recipe <AWQ_RECIPE> \
    --calib_size 512 \
    --export_path <EXPORT_DIR> \
    --vllm_fakequant_export
```

This produces `<EXPORT_DIR>/vllm_fq_modelopt_state.pth`, now including the averaged per-expert `pre_quant_scale` and unified `_amax`.

**Step 2 — Serve via the vLLM fakequant worker:**

```bash
MODELOPT_STATE_PATH=<EXPORT_DIR>/vllm_fq_modelopt_state.pth \
python examples/vllm_serve/vllm_serve_fakequant.py \
    <EXPORT_DIR> --tensor-parallel-size <TP>
```

Tested quantization configurations:

```
FP8_DEFAULT_CFG
FP8_DEFAULT_CFG (input_q disabled)
INT8_SMOOTHQUANT_CFG
INT8_WEIGHT_ONLY_CFG
NVFP4_DEFAULT_CFG
NVFP4_AWQ_LITE_CFG
INT4_AWQ_CFG
NVFP4_AWQ_CFG
NVFP4_DEFAULT_CFG (input_q disabled)
```

### Before your PR is "*Ready for review*"

Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and that your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: N/A
- Did you write any new necessary tests?: N/A
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A

### Additional Information

## Summary by CodeRabbit

* **New Features**
  * Added Nemotron-style MoE export support and group-aware AWQ resmoothing with optional requantization during export.
  * Improved handling for shared-input / expert groups and tensor-parallel sharding of pre-quantization scales.
* **Bug Fixes**
  * Removed the AWQ reload limitation from known issues; improved checkpoint validation and safer save/load behavior.
  * Better detection and handling of enabled weight quantizers and clearer warnings for mismatched checkpoint keys.

---------

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
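The merge and resmoothing rules described in the PR text can be sketched in a few lines. The helpers below are hypothetical standalone re-implementations for illustration only, not the library's `merge_amax_tensors_for_group` or `_resmooth_experts_for_export`; the smoothing convention assumed (input multiplied by the scale, stored weight divided by it) is an assumption, not taken from the source.

```python
import torch


def merge_amax_tensors_for_group(amaxes: list[torch.Tensor]) -> torch.Tensor:
    """Sketch of the group _amax merge rule.

    - Same shape: element-wise max keeps per-channel structure.
    - Same rank, different shapes (e.g. GQA q/k/v): concatenate along dim 0,
      matching the fused projection layout.
    - Anything else: fall back to a single scalar max.
    """
    shapes = {tuple(a.shape) for a in amaxes}
    if len(shapes) == 1:
        return torch.stack(amaxes).amax(dim=0)  # element-wise max
    if len({a.dim() for a in amaxes}) == 1:
        try:
            return torch.cat(amaxes, dim=0)  # GQA-style concat
        except RuntimeError:
            pass
    return torch.stack([a.max() for a in amaxes]).max()  # scalar fallback


def resmooth_experts_for_export(pre_quant_scales, weights):
    """Sketch of per-expert resmoothing: average pre_quant_scale across the
    expert group, then rescale each expert weight so the smoothed matmul is
    unchanged (assuming input * s is paired with stored weight W / s, so the
    stored weight moves from W / s_e to W / s_avg = (W / s_e) * s_e / s_avg).
    """
    avg_scale = torch.stack(pre_quant_scales).mean(dim=0)
    new_weights = [w * (s_e / avg_scale) for s_e, w in zip(pre_quant_scales, weights)]
    return avg_scale, new_weights
```

Averaging is non-mutating in the same spirit as the PR's helper: it returns new tensors instead of editing the experts in place, so the original quantizer state survives a failed export.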
1 parent 26ae8da commit 010b220

File tree

8 files changed

+911
-161
lines changed


examples/vllm_serve/README.md

Lines changed: 2 additions & 3 deletions
```diff
@@ -98,6 +98,5 @@ QUANT_CFG=<quant_cfg> QUANT_FILE_PATH=<quantizer_state.pth> python vllm_serve_fa
 ## Known Problems
 
 1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won’t align).
-2. AWQ reload is not supported yet
-3. KV cache quantization export and reload is not supported in MCore yet.
-4. **`NVFP4_KV_CFG` and `NVFP4_AFFINE_KV_CFG` require `--enforce-eager`**; these configs use a dynamic-block Triton kernel for KV-cache quantization that is incompatible with CUDA graph capture (the kernel grid is computed from Python-level tensor shapes, which get baked in at capture time). Without `--enforce-eager`, the captured grid will be wrong for different batch sizes, producing incorrect outputs.
+2. KV cache quantization export and reload is not supported in MCore yet.
+3. **`NVFP4_KV_CFG` and `NVFP4_AFFINE_KV_CFG` require `--enforce-eager`**; these configs use a dynamic-block Triton kernel for KV-cache quantization that is incompatible with CUDA graph capture (the kernel grid is computed from Python-level tensor shapes, which get baked in at capture time). Without `--enforce-eager`, the captured grid will be wrong for different batch sizes, producing incorrect outputs.
```

examples/vllm_serve/fakequant_worker.py

Lines changed: 39 additions & 14 deletions
```diff
@@ -15,6 +15,7 @@
 
 
 import os
+import warnings
 from typing import Any
 
 import torch
@@ -26,13 +27,16 @@
     convert_modelopt_state_to_vllm,
     load_state_dict_from_path,
     restore_from_modelopt_state_vllm,
+    shard_pre_quant_scale_for_tp,
 )
 
 import modelopt.torch.quantization as mtq
+from modelopt.torch.export.plugins.vllm_fakequant_hf import is_weight_quantizer_state_key
 from modelopt.torch.quantization.plugins.vllm import (
     disable_compilation,
     post_restore_vllm_parallel_linears,
 )
+from modelopt.torch.utils import safe_load
 from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
 
 quant_config: dict[str, Any] = {
@@ -61,28 +65,48 @@ def _fakequant_run_prolog_worker(self) -> None:
     model = model.unwrap()
     if quant_config["modelopt_state_path"]:
         print(f"Loading modelopt state from {quant_config['modelopt_state_path']}")
-        # Load on CPU to avoid failures when the checkpoint was saved from a different
-        # GPU mapping
-        modelopt_state = torch.load(
-            quant_config["modelopt_state_path"], weights_only=True, map_location="cpu"
-        )
+        # Load on CPU to avoid failures when the checkpoint was saved from a different GPU mapping.
+        modelopt_state = safe_load(quant_config["modelopt_state_path"], map_location="cpu")
         modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
         map_fun = (
             self.model_runner.model.hf_to_vllm_mapper.apply_dict
             if hasattr(self.model_runner.model, "hf_to_vllm_mapper")
             else None
         )
-        # convert modelopt state to vllm format
         modelopt_state = convert_modelopt_state_to_vllm(modelopt_state, map_fun=map_fun)
-        # restore model from modelopt state
         restore_from_modelopt_state_vllm(model, modelopt_state)
 
         if modelopt_weights is not None:
-            # convert quantizer state values to vllm format
             modelopt_weights = convert_dict_to_vllm(modelopt_weights, map_fun=map_fun)
             mtq.utils.set_quantizer_state_dict(model, modelopt_weights)
-            # set_quantizer_state_dict does not invoke modelopt_post_restore (unlike restore_quantizer_state).
+            if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+                from modelopt.torch.quantization.nn import TensorQuantizer
+                from modelopt.torch.utils import get_unwrapped_name
+
+                loaded_keys = {
+                    get_unwrapped_name(n, model)
+                    for n, m in model.named_modules()
+                    if isinstance(m, TensorQuantizer)
+                }
+                # Same namespace as ``loaded_keys``: checkpoint keys may include DDP/FSDP
+                # prefixes that ``convert_dict_to_vllm`` does not strip.
+                pqs_in_weights = {
+                    get_unwrapped_name(k, model)
+                    for k, v in modelopt_weights.items()
+                    if isinstance(v, dict) and "_pre_quant_scale" in v
+                }
+                unmatched_pqs = pqs_in_weights - loaded_keys
+                if unmatched_pqs:
+                    sample = sorted(unmatched_pqs)[:20]
+                    warnings.warn(
+                        f"{len(unmatched_pqs)} checkpoint pre_quant_scale key(s) have no "
+                        f"matching TensorQuantizer in the model (showing up to 20): {sample}",
+                        stacklevel=2,
+                    )
+            # set_quantizer_state_dict does not run modelopt_post_restore (unlike restore_quantizer_state).
             post_restore_vllm_parallel_linears(model)
+            # Must follow post_restore: shard_pre_quant_scale_for_tp uses weight H_in vs pqs length.
+            shard_pre_quant_scale_for_tp(model)
 
     else:
         if quant_config["quant_file_path"]:
@@ -101,15 +125,13 @@ def _fakequant_run_prolog_worker(self) -> None:
 
         quant_cfg = get_quant_config(quant_config, model)
 
-        # quantize model
         with disable_compilation(model):
             print("Quantizing model...")
             mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
 
         quantizer_file_path = quant_config["quant_file_path"]
         if quantizer_file_path:
-            # Get amax and other quantizer state from the quantizer file
-            # this can be used with Megatron-LM exported model using export_mcore_gpt_to_hf_vllm_fq
+            self.model_runner._dummy_run(1)
             current_state_dict = load_state_dict_from_path(self, quantizer_file_path, model)
             model.load_state_dict(current_state_dict)
 
@@ -122,8 +144,11 @@ def _fakequant_run_prolog_worker(self) -> None:
 
     mtq.fold_weight(model)
     for name, module in model.named_modules():
-        if name.endswith("weight_quantizer"):
-            assert not module.is_enabled, f"quantizer {name} is still enabled"
+        if is_weight_quantizer_state_key(name) and module.is_enabled:
+            raise RuntimeError(
+                f"Weight quantizer {name!r} is still enabled after fold_weight — "
+                "double-quantization would corrupt activations."
+            )
 
 
 class FakeQuantWorker(BaseWorker):
```
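The new `shard_pre_quant_scale_for_tp` call in the hunk above is ordered after post-restore because it compares the local weight's input dimension against the stored scale length. A minimal sketch of that idea, using a hypothetical standalone helper rather than the library's actual implementation:

```python
import torch


def shard_pre_quant_scale(pqs: torch.Tensor, weight: torch.Tensor,
                          tp_rank: int, tp_size: int) -> torch.Tensor:
    """Sketch: cut a checkpoint-wide pre_quant_scale down to this TP rank's slice.

    After tensor parallelism, an input-sharded linear only owns
    H_in / tp_size input channels, so a full-length scale from the checkpoint
    must be sliced to match. The length comparison against the local weight is
    why this must run after post-restore, once weights have their sharded shapes.
    """
    h_in = weight.shape[1]  # local (possibly sharded) input dimension
    if pqs.numel() == h_in:
        return pqs  # already matches: no TP, or already sharded
    assert pqs.numel() == h_in * tp_size, "unexpected pre_quant_scale length"
    return pqs[tp_rank * h_in : (tp_rank + 1) * h_in]
```

The real code walks the model's parallel linears and mutates each module's quantizer in place; this sketch only shows the per-tensor slicing rule.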
