1 change: 1 addition & 0 deletions examples/vllm_serve/fakequant_worker.py
@@ -48,6 +48,7 @@

def _fakequant_run_prolog_worker(self) -> None:
trust_remote_code = os.environ.get("TRUST_REMOTE_CODE", "false").lower() == "true"

tokenizer = AutoTokenizer.from_pretrained(
self.model_runner.model_config.tokenizer,
trust_remote_code=trust_remote_code,
33 changes: 21 additions & 12 deletions examples/vllm_serve/vllm_reload_utils.py
@@ -14,7 +14,6 @@
# limitations under the License.

import re
-import warnings
from collections import defaultdict
from collections.abc import Callable
from typing import Any
@@ -438,20 +437,30 @@ def load_state_dict_from_path(
for key in checkpoint_quant_keys:
if key not in model_quant_keys:
print(f"Key {key} not found in model state dict, but exists in checkpoint")
# For weight quantizers absent from the checkpoint the weights were already fake-quantized
# at export time (amax folded into weights). Disable those quantizers so that fold_weight
# is a no-op for them. Any other missing quantizer key is still an error.
missing_wq_module_paths: set[str] = set()
for key in model_quant_keys:
if key not in checkpoint_quant_keys:
raise ValueError(f"Key {key} not found in checkpoint state dict, but exists in model")

-checkpoint_quant_count = len(checkpoint_quant_keys)
-model_quant_count = len(model_quant_keys)
if "weight_quantizer" in key:
# State dict keys continue past the submodule (e.g. ...weight_quantizer._amax).
# named_modules() names stop at the weight_quantizer module; strip the suffix.
parts = key.split(".")
wq_i = next(
(i for i, p in enumerate(parts) if p.endswith("weight_quantizer")),
None,
)
if wq_i is not None:
missing_wq_module_paths.add(".".join(parts[: wq_i + 1]))
else:
raise ValueError(
f"Key {key} not found in checkpoint state dict, but exists in model"
)

-# Ensure counts match
-if checkpoint_quant_count != model_quant_count:
-warnings.warn(
-f"Mismatch in quantizer state key counts: checkpoint has {checkpoint_quant_count} "
-f"quant keys but model has {model_quant_count} quantizer state keys. "
-f"This can happen if the model is using PP."
-)
for name, module in model.named_modules():
if name in missing_wq_module_paths and hasattr(module, "disable"):
module.disable()

# Update quant values
saved_quant_dict = process_state_dict_for_tp(saved_quant_dict, current_state_dict)
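The new handling above turns a quantizer key that is missing from the checkpoint into the module path that named_modules() reports, then disables that module. The following is a minimal, self-contained sketch of that mapping in plain PyTorch; DummyWeightQuantizer and the toy model are hypothetical stand-ins for the real modelopt quantizer modules and are not part of this change.

import torch.nn as nn


class DummyWeightQuantizer(nn.Module):
    """Hypothetical stand-in for a quantizer module that exposes disable()."""

    def __init__(self) -> None:
        super().__init__()
        self.enabled = True

    def disable(self) -> None:
        self.enabled = False


# Toy model: a linear layer with an attached weight_quantizer submodule.
model = nn.Sequential(nn.Linear(4, 4))
model[0].weight_quantizer = DummyWeightQuantizer()

# A quantizer state-dict key present in the model but absent from the checkpoint.
missing_key = "0.weight_quantizer._amax"

# Strip everything after the weight_quantizer component to recover the module path.
parts = missing_key.split(".")
wq_i = next(i for i, p in enumerate(parts) if p.endswith("weight_quantizer"))
module_path = ".".join(parts[: wq_i + 1])  # -> "0.weight_quantizer"

# Disable the matching module so folding its (missing) amax becomes a no-op.
for name, module in model.named_modules():
    if name == module_path and hasattr(module, "disable"):
        module.disable()

assert not model[0].weight_quantizer.enabled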
35 changes: 33 additions & 2 deletions modelopt/torch/export/plugins/vllm_fakequant_megatron.py
@@ -117,6 +117,10 @@ def _get_quantized_state(
) -> tuple[dict[str, torch.Tensor], str, int]:
"""Return a state_dict, quantization format, and block_size of the module.

The weight_quantizer is folded into the weight via fake-quantization
(quantize + dequantize), and its amax is not exported. The vLLM fakequant
reload path is expected to disable the weight quantizer when the amax is absent.

Args:
module: The target module to perform real quantization.
dtype: The default data type.
@@ -133,14 +137,41 @@
block_size = 0

if hasattr(module, "weight") and module.weight is not None:
-weight = module.weight.to(dtype).cpu()
-name_to_value["weight"] = weight
weight = module.weight.to(dtype)
# Fold the weight_quantizer into the weight by applying fake-quantization
# (quantize then dequantize). The weight_quantizer amax is not exported;
# the vLLM fakequant reload path disables the weight quantizer when absent.
weight_quantizer = getattr(module, "weight_quantizer", None)
if weight_quantizer is not None and weight_quantizer.is_enabled:
with torch.no_grad():
# Some quantizers (e.g. NVFP4) require CUDA. If the model landed on
# CPU after TP gather, lift to the current CUDA device for the forward
# pass, then restore buffer devices.
quant_device = (
torch.device("cuda", torch.cuda.current_device())
if weight.device.type == "cpu" and torch.cuda.is_available()
else weight.device
)
buf_devices = [(buf, buf.device) for buf in weight_quantizer.buffers()]
for buf, _ in buf_devices:
buf.data = buf.data.to(quant_device)
try:
weight = weight_quantizer(weight.to(quant_device)).to(dtype)
finally:
for buf, orig_device in buf_devices:
buf.data = buf.data.to(orig_device)
name_to_value["weight"] = weight.cpu()
else:
return name_to_value, qformat, block_size

if hasattr(module, "bias") and module.bias is not None:
name_to_value["bias"] = module.bias.to(dtype).cpu()

# Only save input/output quantizer state; weight_quantizer amax is not exported
# since it has been folded into the weight above.
for name, param in get_quantizer_state_dict(module).items():
if "weight_quantizer" in name:
continue
for key, value in param.items():
name_to_value[name + "." + key] = value.to(dtype).cpu()
return name_to_value, qformat, block_size
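The docstring and comments above describe folding the weight_quantizer into the exported weight by fake-quantizing it (quantize then dequantize), so that no weight_quantizer amax needs to be stored. The snippet below is a generic sketch of that idea using simple symmetric per-tensor int8-style rounding; it is an illustration only and does not reproduce the actual quantizer math used here (e.g. NVFP4 block formats).

import torch


def fold_fake_quant(weight: torch.Tensor, amax: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """Illustrative symmetric fake-quantization: quantize, then dequantize."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = amax / qmax
    q = torch.clamp(torch.round(weight / scale), -qmax, qmax)
    return q * scale


weight = torch.randn(16, 16)
amax = weight.abs().amax()
folded = fold_fake_quant(weight, amax)

# The exported checkpoint stores `folded` directly; because no weight_quantizer amax
# is saved, the reload path must disable the weight quantizer to avoid quantizing twice.
assert torch.allclose(weight, folded, atol=(amax / 127).item())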
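The try/finally block above temporarily moves the quantizer's buffers to a CUDA device and restores them afterwards. As a follow-on sketch, the same pattern can be wrapped in a small context manager; buffers_on_device is a hypothetical helper written for illustration, not an API used by this PR.

from contextlib import contextmanager

import torch


@contextmanager
def buffers_on_device(module: torch.nn.Module, device: torch.device):
    """Temporarily move a module's buffers to `device`, then restore their original devices."""
    original = [(buf, buf.device) for buf in module.buffers()]
    try:
        for buf, _ in original:
            buf.data = buf.data.to(device)
        yield
    finally:
        for buf, orig_device in original:
            buf.data = buf.data.to(orig_device)


# Usage sketch, assuming `weight_quantizer` and `weight` as in the hunk above:
# device = torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() else weight.device
# with buffers_on_device(weight_quantizer, device), torch.no_grad():
#     weight = weight_quantizer(weight.to(device))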