Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions examples/llm_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import copy
import glob
import hashlib
import inspect
import json
import logging
Expand Down Expand Up @@ -854,3 +855,35 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod
print(f"Successfully copied {len(copied_files)} custom model files to {export_path}")
else:
print("No custom model files found to copy")


def needs_checkpoint_path_update(quant_cfg: dict) -> bool:
    """Return True when ``quant_cfg`` carries a layerwise_checkpoint_dir to auto-resolve.

    The setting is only honored when the ``algorithm`` entry is itself a dict;
    a plain string algorithm name (or a missing entry) carries no checkpoint
    settings and yields False.
    """
    algorithm = quant_cfg.get("algorithm")
    return isinstance(algorithm, dict) and algorithm.get("layerwise_checkpoint_dir") is not None


def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict:
    """Append a unique ``<model_name>_<config_hash>`` subdirectory to layerwise_checkpoint_dir.

    Allows a single recipe to be reused across models without checkpoint collisions.
    Must only be called when :func:`needs_checkpoint_path_update` returns True.

    Args:
        quant_cfg: Quantization config whose ``algorithm`` dict contains a
            non-None ``layerwise_checkpoint_dir``. The input is not modified.
        model_path: HF repo id (e.g. ``org/model``) or local model path, used
            to derive a filesystem-safe model name.

    Returns:
        A deep copy of ``quant_cfg`` with the resolved checkpoint directory.
    """
    algorithm = quant_cfg["algorithm"]
    base_dir = algorithm["layerwise_checkpoint_dir"]

    # Derive a filesystem-safe model name. A relative path containing "/" is
    # treated as an HF repo id and flattened with "--" (HF cache convention);
    # absolute paths keep only their last component.
    # NOTE(review): a relative *local* path like "models/foo" is also flattened
    # to "models--foo" — still unique, but confirm this is the intended name.
    name = model_path.rstrip("/")
    if "/" in name and not os.path.isabs(name):
        name = name.replace("/", "--")
    else:
        name = Path(name).name

    # sort_keys makes the hash independent of dict insertion order, so the
    # same logical config always resolves to the same checkpoint subdirectory
    # (required for checkpoint resume to actually find prior checkpoints).
    config_hash = hashlib.sha256(
        json.dumps(quant_cfg, sort_keys=True, default=str).encode()
    ).hexdigest()[:8]

    quant_cfg = copy.deepcopy(quant_cfg)
    quant_cfg["algorithm"]["layerwise_checkpoint_dir"] = os.path.join(
        base_dir, f"{name}_{config_hash}"
    )
    return quant_cfg
17 changes: 14 additions & 3 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
is_enc_dec,
is_nemotron_vl,
load_mtp_weights,
needs_checkpoint_path_update,
resolve_checkpoint_dir,
run_nemotron_vl_preview,
)
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -91,8 +93,9 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
for i, entry in enumerate(quant_cfg):
if entry.get("quantizer_name") != "*[kv]_bmm_quantizer":
continue
assert isinstance(entry.get("cfg", {}), dict)
quant_cfg[i] = {**entry, "cfg": {**entry.get("cfg", {}), "use_constant_amax": True}}
cfg = entry.get("cfg") or {}
assert isinstance(cfg, dict)
quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}}
break


Expand Down Expand Up @@ -759,7 +762,9 @@ def export_quantized(
# Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
# Store the MTP layer prefixes on the model for later exclusion from quantization
if args.vllm_fakequant_export:
export_hf_vllm_fq_checkpoint(full_model, export_dir=export_path)
export_hf_vllm_fq_checkpoint(
full_model, export_dir=export_path, inplace_mem_efficient=True
)
else:
mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(
full_model, args.pyt_ckpt_path
Expand Down Expand Up @@ -1104,6 +1109,12 @@ def quantize_main(
quant_cfg = copy.deepcopy(quant_cfg)
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])

if needs_checkpoint_path_update(quant_cfg):
quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path)
print(
f"Auto-resolved layerwise_checkpoint_dir: {quant_cfg['algorithm']['layerwise_checkpoint_dir']}"
)

if args.qformat in QUANT_CFG_CHOICES:
mono_quantize(
args,
Expand Down
162 changes: 114 additions & 48 deletions modelopt/torch/export/plugins/vllm_fakequant_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from modelopt.torch.quantization.conversion import quantizer_state
from modelopt.torch.quantization.nn import QuantModule, TensorQuantizer
from modelopt.torch.quantization.utils import get_quantizer_state_dict
from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback
from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector
from modelopt.torch.utils import get_unwrapped_name

__all__ = ["export_hf_vllm_fq_checkpoint"]
Expand All @@ -38,9 +40,75 @@ def disable_rotate(quantizer: TensorQuantizer):
return False


def _fakequant_module_weights(
    module: nn.Module,
    module_name: str,
    model: nn.Module,
    state_dict: dict | None,
    input_quantizers_folded_pqs: set,
    fakequant_weights: set,
    inplace: bool,
) -> None:
    """Apply fake-quant to a single QuantModule's weights.

    When ``inplace=False``, reads/writes weights from/to ``state_dict``.
    When ``inplace=True``, modifies the module's weight parameters directly.

    Args:
        module: Candidate module; anything that is not a ``QuantModule`` is skipped.
        module_name: Dotted name of ``module`` inside ``model`` ("" for the root).
        model: Full model; used only to unwrap quantizer key names via
            ``get_unwrapped_name``.
        state_dict: Target state dict for the non-inplace path. Must be non-None
            when ``inplace=False``; ignored when ``inplace=True``.
        input_quantizers_folded_pqs: Out-param. Collects unwrapped keys of input
            quantizers whose ``pre_quant_scale`` was folded into the weight.
        fakequant_weights: Out-param. Tracks state-dict keys already processed so
            no weight is fake-quantized twice (enforced by assertion below).
        inplace: Selects the destructive in-place path vs. the state-dict path.
    """
    if not isinstance(module, QuantModule):
        return
    # Only direct children named "*weight_quantizer" that are enabled fake-quant
    # TensorQuantizers are applied; disabled or real-quant quantizers are skipped.
    for attr_name, quantizer in module.named_children():
        if not (
            attr_name.endswith("weight_quantizer")
            and isinstance(quantizer, TensorQuantizer)
            and quantizer.fake_quant
            and quantizer.is_enabled
        ):
            continue
        # "weight_quantizer" -> "weight"; also handles prefixed variants
        # such as "gate_weight_quantizer" -> "gate_weight".
        weight_name = attr_name.removesuffix("_quantizer")
        prefix = f"{module_name}." if module_name else ""
        sd_key = f"{prefix}{weight_name}"
        assert sd_key not in fakequant_weights, f"Weight {sd_key} has already been fakequantized"

        if inplace:
            w = getattr(module, weight_name)
            # Quantize in float32 for precision, then cast back to storage dtype.
            w_quant = quantizer(w.float()).to(w.dtype)
        else:
            assert state_dict is not None
            if sd_key not in state_dict:
                # Weight absent from the state dict (e.g. shared/tied weight
                # emitted under another key) — nothing to rewrite here.
                continue
            w = state_dict[sd_key]
            w_quant = quantizer(w.float()).to(w.dtype)

        # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
        # Only valid when input_quantizer does NOT fake-quant activations. If it does
        # fake_quant(x*s), the non-linearity prevents folding s into W.
        inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
        if hasattr(module, inp_attr):
            inp_q = getattr(module, inp_attr)
            if (
                hasattr(inp_q, "_pre_quant_scale")
                and inp_q._pre_quant_scale is not None
                and inp_q._disabled
            ):
                # NOTE(review): assumes W is (out_features, in_features) and the
                # squeezed scale is 1D over in_features — confirm for non-Linear
                # QuantModules before relying on this fold.
                scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
                w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
                inp_q_key = get_unwrapped_name(
                    f"{module_name}.{inp_attr}" if module_name else inp_attr, model
                )
                input_quantizers_folded_pqs.add(inp_q_key)

        if inplace:
            # Destructive: overwrite the parameter storage with fake-quanted values.
            w.data.copy_(w_quant)
        else:
            assert state_dict is not None
            # Move to CPU so the folded state dict does not pin GPU memory.
            state_dict[sd_key] = w_quant.cpu()
        fakequant_weights.add(sd_key)


def export_hf_vllm_fq_checkpoint(
model: nn.Module,
export_dir: Path | str,
inplace_mem_efficient: bool = False,
):
"""Export quantized HF weights + ``vllm_fq_modelopt_state.pth`` for vLLM fake-quant reload.

Expand All @@ -53,59 +121,56 @@ def export_hf_vllm_fq_checkpoint(
Args:
model: In-memory quantized model.
export_dir: Output dir for HF files and ``vllm_fq_modelopt_state.pth``.
inplace_mem_efficient: When True, applies fake-quant inplace one decoder layer at
a time using ``enable_weight_access_and_writeback``, avoiding full state
dict materialization. This is destructive — model weights are permanently
modified and weight quantizers are not re-enabled after export.
"""
export_dir = Path(export_dir)
export_dir.mkdir(parents=True, exist_ok=True)

# Step 1: Build the folded HF state dict.
# model.state_dict() returns detached copies of all tensors, so model
# parameters are never modified. Apply each weight quantizer's fake-quant
# to the corresponding weight tensor in the copy.
state_dict = model.state_dict()
fakequant_weights = set()
input_quantizers_folded_pqs = (
set()
) # keys for input_quantizers where pre_quant_scale was folded
input_quantizers_folded_pqs = set()
with torch.inference_mode():
for module_name, module in model.named_modules():
if not isinstance(module, QuantModule):
continue
for attr_name, quantizer in module.named_children():
if not (
attr_name.endswith("weight_quantizer")
and isinstance(quantizer, TensorQuantizer)
and quantizer.fake_quant
and quantizer.is_enabled
):
if inplace_mem_efficient:
# Inplace path: iterate decoder layers, one offload<->onload per layer.
decoder_layers = LayerActivationCollector.get_decoder_layers(model)
assert decoder_layers is not None, (
"inplace_mem_efficient=True requires a model with discoverable decoder layers"
)
for name, module in model.named_modules():
if module not in decoder_layers:
continue
weight_name = attr_name.removesuffix("_quantizer")
prefix = f"{module_name}." if module_name else ""
sd_key = f"{prefix}{weight_name}"
assert sd_key not in fakequant_weights, (
f"Weight {sd_key} has already been fakequantized"
)
if sd_key in state_dict:
w = state_dict[sd_key]
w_quant = quantizer(w.float()).to(w.dtype).cpu()
# Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
# Only valid when input_quantizer does NOT fake-quant activations. If it does
# fake_quant(x*s), the non-linearity prevents folding s into W.
inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
if hasattr(module, inp_attr):
inp_q = getattr(module, inp_attr)
if (
hasattr(inp_q, "_pre_quant_scale")
and inp_q._pre_quant_scale is not None
and inp_q._disabled
):
scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
inp_q_key = get_unwrapped_name(
f"{module_name}.{inp_attr}" if module_name else inp_attr, model
)
input_quantizers_folded_pqs.add(inp_q_key)
state_dict[sd_key] = w_quant
fakequant_weights.add(sd_key)
with enable_weight_access_and_writeback(module, module):
for sub_name, sub_mod in module.named_modules():
full_name = f"{name}.{sub_name}" if sub_name else name
_fakequant_module_weights(
sub_mod,
full_name,
model,
None,
input_quantizers_folded_pqs,
fakequant_weights,
inplace=True,
)
# Meta tensors for offloaded weights (free); offload maps now have
# fakequanted values via writeback.
state_dict = model.state_dict()
else:
# Default path: full state_dict copy, fakequant into the copy.
state_dict = model.state_dict()
for module_name, module in model.named_modules():
with enable_weight_access_and_writeback(module, model):
_fakequant_module_weights(
module,
module_name,
model,
state_dict,
input_quantizers_folded_pqs,
fakequant_weights,
inplace=False,
)

# Filter quantizer tensors out for a clean HF checkpoint.
clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k}
Expand Down Expand Up @@ -164,6 +229,7 @@ def export_hf_vllm_fq_checkpoint(
# Step 3: Save HF weights using the pre-built folded state dict.
model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)

for wq, orig_rotate in wqs_to_restore:
wq.enable()
wq._rotate = orig_rotate
if not inplace_mem_efficient:
for wq, orig_rotate in wqs_to_restore:
wq.enable()
wq._rotate = orig_rotate
28 changes: 24 additions & 4 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,16 +1217,36 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
),
)

use_sequential: bool = ModeloptField(
layerwise: bool = ModeloptField(
default=False,
title="Enable sequential layer-by-layer calibration.",
title="Enable layerwise (layer-by-layer) calibration.",
description=(
"If True, the calibration algorithm is applied sequentially to each decoder block. "
"Each layer's inputs are captured via a single forward pass that reflects the "
"If True, the calibration algorithm is applied layer by layer. "
"Each layer's inputs are captured via a forward pass that reflects the "
"quantization of all preceding layers, incurring O(N) forward passes for N layers."
),
)

layerwise_checkpoint_dir: str | None = ModeloptField(
default=None,
title="Checkpoint directory for layerwise calibration.",
description=(
"If set together with layerwise=True, per-layer checkpoints are saved to this "
"directory during calibration. On restart, calibration resumes from the last "
"completed layer."
),
)

@model_validator(mode="after")
def validate_layerwise_checkpoint_dir(self):
"""Raise if layerwise_checkpoint_dir is set but layerwise is False."""
if self.layerwise_checkpoint_dir is not None and not self.layerwise:
raise ValueError(
"layerwise_checkpoint_dir requires layerwise=True. "
"Set layerwise=True or remove layerwise_checkpoint_dir."
)
return self


class MaxCalibConfig(QuantizeAlgorithmConfig):
"""The config for max calibration algorithm.
Expand Down
17 changes: 9 additions & 8 deletions modelopt/torch/quantization/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@
from .model_calib import (
awq,
gptq,
layerwise_calibrate,
local_hessian_calibrate,
max_calibrate,
mse_calibrate,
sequential_calibrate,
smoothquant,
svdquant,
)
Expand Down Expand Up @@ -222,7 +222,8 @@ def wrapped_calib_func(
"""
kwargs = config.model_dump()
method = kwargs.pop("method")
sequential = kwargs.pop("use_sequential", False)
layerwise = kwargs.pop("layerwise", False)
checkpoint_dir = kwargs.pop("layerwise_checkpoint_dir", None)
if method is not None and "awq" in method:
# For backward compatibility
kwargs["algorithm"] = method
Expand All @@ -237,17 +238,17 @@ def wrapped_calib_func(
module._moe_calib_experts_ratio = moe_calib_experts_ratio

if func is not None:
if sequential:
if layerwise:
# All currently implemented PTQ algorithms support layerwise calibration;
# future algorithms that need full-model context must add a guard here.
if forward_loop is None:
raise ValueError("forward_loop is required for calibration but got None.")
Comment thread
realAsma marked this conversation as resolved.
assert method in ["max", "gptq"], (
f"Sequential calibration currently only supports max and gptq calibration, got {method}"
)
# Wrap with sequential processing
sequential_calibrate(
# Wrap with layerwise processing
layerwise_calibrate(
model,
forward_loop=forward_loop,
calib_func=func,
checkpoint_dir=checkpoint_dir,
**kwargs,
)
else:
Expand Down
Loading
Loading