Skip to content

Commit e8243e9

Browse files
committed
update
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
1 parent 78f0d8a commit e8243e9

2 files changed

Lines changed: 21 additions & 77 deletions

File tree

modelopt/torch/quantization/model_calib.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
reduce_amax,
5050
weight_attr_names,
5151
)
52-
from .utils.calib_utils import GPTQHelper
52+
from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper
5353

5454
__all__ = [
5555
"awq",
@@ -1663,7 +1663,12 @@ def gptq(
16631663
print_rank_0("No quantized linear layers found, skipping GPTQ")
16641664
return
16651665

1666-
gptq_handles = {name: GPTQHelper(m, name, offload_to_cpu=True) for name, m in quantized_layers}
1666+
def _make_gptq_handle(name, m):
1667+
backend = getattr(m.weight_quantizer, "backend", None)
1668+
cls = _GPTQ_HELPER_REGISTRY.get(backend, GPTQHelper)
1669+
return cls(m, name, offload_to_cpu=True)
1670+
1671+
gptq_handles = {name: _make_gptq_handle(name, m) for name, m in quantized_layers}
16671672
for handle in gptq_handles.values():
16681673
handle.setup()
16691674

modelopt/torch/quantization/utils/calib_utils.py

Lines changed: 14 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
"""GPTQ helper and Hessian utilities for calibration."""
4040

4141
import math
42-
from contextlib import contextmanager
4342

4443
import torch
4544

@@ -135,22 +134,6 @@ def free(self):
135134
self.weight = None
136135
self.h_inv = None
137136

138-
def is_vector_lut(self) -> bool:
139-
"""Check if this module's weight quantizer is configured for vector LUT quantization."""
140-
extra_args = getattr(self.module.weight_quantizer, "backend_extra_args", None)
141-
return bool(extra_args and extra_args.get("lut_type") == "vector_lut")
142-
143-
@contextmanager
144-
def weight_slice_scales(self, scales):
145-
"""Temporarily replace _psx_scales with per-slice scales, then restore."""
146-
quantizer = self.module.weight_quantizer
147-
original = quantizer._psx_scales
148-
quantizer._psx_scales = scales
149-
try:
150-
yield
151-
finally:
152-
quantizer._psx_scales = original
153-
154137
def update_weights(self, block_size, perc_damp):
155138
"""Run GPTQ blockwise weight update on this module.
156139
@@ -160,12 +143,7 @@ def update_weights(self, block_size, perc_damp):
160143
hessian = self.hessian.to(self.module.weight.device)
161144
self.weight = self.module.weight.data.float().clone()
162145
self._prepare_hessian_inverse(hessian, perc_damp)
163-
164-
if self.is_vector_lut():
165-
self._blockwise_vector_update(block_size)
166-
else:
167-
self._blockwise_update(block_size)
168-
146+
self._blockwise_update(block_size)
169147
self._print_mse_error(hessian)
170148
self.module.weight.data = self.weight.reshape(self.module.weight.shape).to(
171149
self.module.weight.data.dtype
@@ -244,62 +222,23 @@ def _blockwise_update(self, block_size):
244222
errs, self.h_inv[block_start:block_end, block_end:], alpha=-1
245223
)
246224

247-
def _blockwise_vector_update(self, block_size):
248-
"""GPTQ blockwise update for vector quantizers.
249-
250-
A single ``quantizer(weight)`` call computes and caches per-block
251-
scales (``_psx_scales``) on the quantizer via the backend's
252-
``static_scales`` path. The GPTQ loop then slices per-vector-group
253-
scales from ``_psx_scales`` for each sub-vector quantization call.
254-
"""
255-
import torch.nn.functional as F
256-
257-
quantizer = self.module.weight_quantizer
258-
assert quantizer.backend_extra_args.get("static_scales", False), (
259-
"GPTQ vector update requires static_scales=True in backend_extra_args."
260-
)
261-
vector_size = quantizer.backend_extra_args["vector_size"]
262-
quant_block_size = quantizer.backend_extra_args["block_sizes"]
263-
264-
assert self.weight is not None and self.h_inv is not None
265-
num_cols = self.weight.shape[1]
266-
assert block_size % quant_block_size == 0
267-
268-
# Compute and cache _psx_scales on the quantizer via the backend.
269-
quantizer(self.weight)
270-
271-
w = self.weight.clone()
272-
h_inv = self.h_inv
273-
274-
for blk_start in range(0, num_cols, block_size):
275-
blk_end = min(blk_start + block_size, num_cols)
276-
errs = torch.zeros_like(w[:, blk_start:blk_end])
277-
278-
for j in range(blk_start, blk_end, vector_size):
279-
d = min(vector_size, blk_end - j)
280-
s = quantizer._psx_scales[:, j // quant_block_size].contiguous()
281-
282-
sub = w[:, j : j + d].contiguous()
283-
if d < vector_size:
284-
sub = F.pad(sub, (0, vector_size - d))
285-
286-
with self.weight_slice_scales(s):
287-
q_sub = quantizer(sub)
288-
289-
for k in range(d):
290-
col = j + k
291-
self.weight[:, col] = q_sub[:, k]
292-
err = (w[:, col] - q_sub[:, k]) / h_inv[col, col]
293-
errs[:, col - blk_start] = err
294-
w[:, col:blk_end].addr_(err, h_inv[col, col:blk_end], alpha=-1)
295-
296-
if blk_end < num_cols:
297-
w[:, blk_end:] -= errs @ h_inv[blk_start:blk_end, blk_end:]
298-
299225
def _print_mse_error(self, hessian):
300226
"""Log Hessian-weighted relative MSE between ``self.weight`` and original weights."""
301227
w_orig = self.module.weight.float()
302228
delta = self.weight - w_orig
303229
mse = (delta).mm(hessian).mul(delta).mean() / (w_orig.mm(hessian).mul(w_orig).mean() + 1e-6)
304230
suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else ""
305231
print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}")
232+
233+
234+
_GPTQ_HELPER_REGISTRY: dict[str, type[GPTQHelper]] = {}
235+
236+
237+
def register_gptq_helper(backend: str, factory: type[GPTQHelper]) -> None:
238+
"""Register a :class:`GPTQHelper` subclass for a quantizer backend.
239+
240+
When :func:`modelopt.torch.quantization.model_calib.gptq` encounters a
241+
module whose ``weight_quantizer.backend`` matches ``backend``, it will
242+
construct ``factory`` instead of the default ``GPTQHelper``.
243+
"""
244+
_GPTQ_HELPER_REGISTRY[backend] = factory

0 commit comments

Comments
 (0)