Skip to content

Commit f238d93

Browse files
vLLM fakequant fold weight_quantizer for megatron export (#1246)
### What does this PR do? Type of change: Bug fix During Megatron→vLLM fakequant export (`export_mcore_gpt_to_hf_vllm_fq`), the `weight_quantizer` is now applied as fake-quantization (quantize + dequantize) directly to the exported weight tensor, and its amax is no longer saved to `quantizer_state.pth`. On reload, if `weight_quantizer` keys are absent from the checkpoint (because they were folded at export time), the corresponding quantizer modules are disabled. This change is especially useful when the `weight_quantizer` amax values are not synced across experts: folding lets each expert keep its own amax, which gives better accuracy. ### Usage ```python # Unchanged — export API is the same export_mcore_gpt_to_hf_vllm_fq(model, pretrained_model_name_or_path=..., export_dir=...) ``` ### Testing Step 1 — Quantize (run from Megatron-LM `examples/post_training/modelopt`): ```bash HF_MODEL_CKPT=<path/to/hf/weights> MLM_MODEL_SAVE=<quant-ckpt-name> \ bash quantize.sh <hf-model-id> NVFP4_DEFAULT_CFG ``` Step 2 — Export for vLLM fakequant: ```bash MLM_EXTRA_ARGS=--export-vllm-fq \ HF_MODEL_CKPT=<path/to/hf/weights> \ MLM_MODEL_CKPT=<quant-ckpt-name> \ EXPORT_DIR=<export-dir> \ bash export.sh <hf-model-id> ``` Step 3 — Serve (run from examples/vllm_serve): ```bash QUANT_CFG=NVFP4_DEFAULT_CFG \ QUANT_FILE_PATH=<export-dir>/quantizer_state.pth \ python3 vllm_serve_fakequant.py <export-dir> \ -tp 1 --served-model-name <model-name> \ --host 0.0.0.0 --port 8000 \ --trust-remote-code --enforce-eager \ --disable-custom-all-reduce \ --gpu-memory-utilization 0.8 ``` ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. 
avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: N/A - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A ### Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Better handling when loading checkpoints: missing weight-quantizer entries are validated and corresponding modules are disabled to avoid load failures. * **Improvements** * Export now folds enabled weight quantizers into exported weights when present and omits internal weight-quantizer tensors from the exported state to produce cleaner exports. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 9f8188d commit f238d93

File tree

3 files changed

+106
-17
lines changed

3 files changed

+106
-17
lines changed

examples/vllm_serve/fakequant_worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
def _fakequant_run_prolog_worker(self) -> None:
5151
trust_remote_code = os.environ.get("TRUST_REMOTE_CODE", "false").lower() == "true"
52+
5253
tokenizer = AutoTokenizer.from_pretrained(
5354
self.model_runner.model_config.tokenizer, trust_remote_code=trust_remote_code
5455
)

examples/vllm_serve/vllm_reload_utils.py

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,32 @@
3131
convert_to_quantized_model,
3232
restore_quantizer_state,
3333
)
34+
from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
3435
from modelopt.torch.quantization.utils import is_quantized
3536

3637

38+
def _union_quantizer_keys_across_ranks(local_quantizer_keys: list[str]) -> set[str]:
39+
"""Union of quantizer key strings from every rank (same file on all ranks → identical to local)."""
40+
local = set(local_quantizer_keys)
41+
if not torch.distributed.is_available() or not torch.distributed.is_initialized():
42+
return local
43+
if torch.distributed.get_world_size() <= 1:
44+
return local
45+
try:
46+
world_size = torch.distributed.get_world_size()
47+
gathered: list[list[str]] = [[] for _ in range(world_size)]
48+
torch.distributed.all_gather_object(gathered, list(local_quantizer_keys))
49+
out: set[str] = set()
50+
for g in gathered:
51+
out.update(g)
52+
return out
53+
except Exception as e:
54+
warnings.warn(
55+
f"Could not all_gather quantizer key lists across ranks ({e}); using this rank's keys only."
56+
)
57+
return local
58+
59+
3760
def _values_equal(v1: Any, v2: Any) -> bool:
3861
"""Compare values, handling dicts with tensors."""
3962
if isinstance(v1, dict) and isinstance(v2, dict):
@@ -285,7 +308,7 @@ def filter_modelopt_state_quantizer_state_for_model(
285308
model: Model with quantizers (must already be converted)
286309
"""
287310
from modelopt.torch.quantization.conversion import quantizer_state
288-
from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
311+
from modelopt.torch.quantization.nn import TensorQuantizer
289312
from modelopt.torch.utils import get_unwrapped_name
290313

291314
model_qstate = quantizer_state(model)
@@ -435,24 +458,51 @@ def load_state_dict_from_path(
435458
# Count quant keys in checkpoint and model
436459
checkpoint_quant_keys = [key for key in saved_quant_dict if "quantizer" in key]
437460
model_quant_keys = [key for key in current_state_dict if "quantizer" in key]
438-
for key in checkpoint_quant_keys:
439-
if key not in model_quant_keys:
440-
print(f"Key {key} not found in model state dict, but exists in checkpoint")
461+
ckpt_key_set = set(checkpoint_quant_keys)
462+
global_ckpt_key_set = _union_quantizer_keys_across_ranks(checkpoint_quant_keys)
463+
# For weight quantizers absent from the checkpoint the weights were already fake-quantized
464+
# at export time (amax folded into weights). Disable those quantizers so that fold_weight
465+
# is a no-op for them. Non-weight keys missing on this rank but present on another rank's
466+
# shard are omitted from global_missing (all_gather union of key strings).
467+
missing_wq_module_paths: set[str] = set()
468+
global_missing_non_wq: list[str] = []
441469
for key in model_quant_keys:
442-
if key not in checkpoint_quant_keys:
443-
raise ValueError(f"Key {key} not found in checkpoint state dict, but exists in model")
444-
445-
checkpoint_quant_count = len(checkpoint_quant_keys)
446-
model_quant_count = len(model_quant_keys)
447-
448-
# Ensure counts match
449-
if checkpoint_quant_count != model_quant_count:
470+
if key in ckpt_key_set:
471+
continue
472+
if "weight_quantizer" in key:
473+
# Per-rank shard: only disable using this rank's checkpoint contents.
474+
parts = key.split(".")
475+
weight_quantizer_index = next(
476+
(i for i, p in enumerate(parts) if p.endswith("weight_quantizer")),
477+
None,
478+
)
479+
if weight_quantizer_index is not None:
480+
missing_wq_module_paths.add(".".join(parts[: weight_quantizer_index + 1]))
481+
else:
482+
raise ValueError(
483+
f"Missing checkpoint key {key!r} looks like a weight quantizer, but no path "
484+
"component ends with 'weight_quantizer'; cannot map to a module to disable."
485+
)
486+
elif key not in global_ckpt_key_set:
487+
global_missing_non_wq.append(key)
488+
489+
if global_missing_non_wq:
490+
keys = sorted(global_missing_non_wq)
491+
n = len(keys)
492+
sample, rest = keys[:8], n - 8
450493
warnings.warn(
451-
f"Mismatch in quantizer state key counts: checkpoint has {checkpoint_quant_count} "
452-
f"quant keys but model has {model_quant_count} quantizer state keys. "
453-
f"This can happen if the model is using PP."
494+
f"{n} quantizer key(s) missing from every rank's checkpoint (after all_gather):"
495+
f"{sample}{' ... (+{rest} more)' if rest > 0 else ''}"
454496
)
455497

498+
for name, module in model.named_modules():
499+
if (
500+
name in missing_wq_module_paths
501+
and isinstance(module, TensorQuantizer)
502+
and hasattr(module, "disable")
503+
):
504+
module.disable()
505+
456506
# Update quant values
457507
saved_quant_dict = process_state_dict_for_tp(saved_quant_dict, current_state_dict)
458508
for key, value in saved_quant_dict.items():

modelopt/torch/export/plugins/vllm_fakequant_megatron.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ def _get_quantized_state(
117117
) -> tuple[dict[str, torch.Tensor], str, int]:
118118
"""Return a state_dict, quantization format, and block_size of the module.
119119
120+
The weight_quantizer is folded into the weight via fake-quantization
121+
(quantize + dequantize), and its amax is not exported. The vLLM fakequant
122+
reload path is expected to disable the weight quantizer when the amax is absent.
123+
120124
Args:
121125
module: The target module to perform real quantization.
122126
dtype: The default data type.
@@ -133,14 +137,48 @@ def _get_quantized_state(
133137
block_size = 0
134138

135139
if hasattr(module, "weight") and module.weight is not None:
136-
weight = module.weight.to(dtype).cpu()
137-
name_to_value["weight"] = weight
140+
weight = module.weight.to(dtype)
141+
# Fold the weight_quantizer into the weight by applying fake-quantization
142+
# (quantize then dequantize). The weight_quantizer amax is not exported;
143+
# the vLLM fakequant reload path disables the weight quantizer when absent.
144+
weight_quantizer = getattr(module, "weight_quantizer", None)
145+
if weight_quantizer is not None and weight_quantizer.is_enabled:
146+
with torch.no_grad():
147+
# NVFP4-like kernels may need CUDA; if weights are CPU after gather, run on
148+
# CUDA then ``weight_quantizer.to`` back (full module round-trip).
149+
quant_device = (
150+
torch.device("cuda", torch.cuda.current_device())
151+
if weight.device.type == "cpu" and torch.cuda.is_available()
152+
else weight.device
153+
)
154+
# TensorQuantizer does not expose nn.Module.device (custom __getattr__).
155+
param_device = next(weight_quantizer.parameters(), None)
156+
buf_device = next(weight_quantizer.buffers(), None)
157+
wq_dev = (
158+
param_device.device
159+
if param_device is not None
160+
else (buf_device.device if buf_device is not None else torch.device("cpu"))
161+
)
162+
need_move = wq_dev != quant_device
163+
if need_move:
164+
weight_quantizer.to(quant_device)
165+
try:
166+
weight = weight_quantizer(weight.to(quant_device)).to(dtype)
167+
finally:
168+
if need_move:
169+
weight_quantizer.to(wq_dev)
170+
name_to_value["weight"] = weight.cpu()
138171
else:
139172
return name_to_value, qformat, block_size
140173

141174
if hasattr(module, "bias") and module.bias is not None:
142175
name_to_value["bias"] = module.bias.to(dtype).cpu()
176+
177+
# Only save input/output quantizer state; weight_quantizer amax is not exported
178+
# since it has been folded into the weight above.
143179
for name, param in get_quantizer_state_dict(module).items():
180+
if "weight_quantizer" in name:
181+
continue
144182
for key, value in param.items():
145183
name_to_value[name + "." + key] = value.to(dtype).cpu()
146184
return name_to_value, qformat, block_size

0 commit comments

Comments
 (0)