Skip to content

Commit 78f0d8a

Browse files
committed
update
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
1 parent 7a29dc0 commit 78f0d8a

1 file changed

Lines changed: 34 additions & 51 deletions

File tree

modelopt/torch/quantization/utils/calib_utils.py

Lines changed: 34 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@
3939
"""GPTQ helper and Hessian utilities for calibration."""
4040

4141
import math
42+
from contextlib import contextmanager
4243

4344
import torch
4445

@@ -47,30 +48,6 @@
4748
from modelopt.torch.utils.perf import get_used_gpu_mem_fraction
4849

4950

50-
def load_vector_lut_codebook(quantizer):
51-
"""Load vector LUT codebook and quantizer params from a weight_quantizer.
52-
53-
Returns:
54-
Tuple of (codebook, quant_block_size, scale_type).
55-
"""
56-
from luts import encode as luts_encode
57-
58-
extra_args = quantizer.backend_extra_args
59-
encode_format = quantizer.num_bits
60-
encode_path = extra_args.get("encode_path", "")
61-
if encode_path and not encode_path.endswith("/"):
62-
encode_path += "/"
63-
64-
if "sorted" in encode_format:
65-
cb = torch.load(encode_path + encode_format + ".pt", map_location="cpu")
66-
codebook = cb["sorted_values"].cuda().float()
67-
else:
68-
codebook, _ = luts_encode(encode_format, path=encode_path, norm=False, cuda=True)
69-
codebook = codebook.float()
70-
71-
return codebook, extra_args.get("block_sizes"), extra_args.get("scale_type")
72-
73-
7451
def update_hessian(input, hessian, n_samples):
7552
"""Update hessian matrix with new input samples using incremental formula.
7653
@@ -158,21 +135,33 @@ def free(self):
158135
self.weight = None
159136
self.h_inv = None
160137

138+
def is_vector_lut(self) -> bool:
139+
"""Check if this module's weight quantizer is configured for vector LUT quantization."""
140+
extra_args = getattr(self.module.weight_quantizer, "backend_extra_args", None)
141+
return bool(extra_args and extra_args.get("lut_type") == "vector_lut")
142+
143+
@contextmanager
144+
def weight_slice_scales(self, scales):
145+
"""Temporarily replace _psx_scales with per-slice scales, then restore."""
146+
quantizer = self.module.weight_quantizer
147+
original = quantizer._psx_scales
148+
quantizer._psx_scales = scales
149+
try:
150+
yield
151+
finally:
152+
quantizer._psx_scales = original
153+
161154
def update_weights(self, block_size, perc_damp):
162155
"""Run GPTQ blockwise weight update on this module.
163156
164157
Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update,
165158
logs MSE, and writes the result back to the module.
166159
"""
167-
backend_extra_args = getattr(self.module.weight_quantizer, "backend_extra_args", None)
168-
is_vector_lut = bool(
169-
backend_extra_args and backend_extra_args.get("lut_type") == "vector_lut"
170-
)
171160
hessian = self.hessian.to(self.module.weight.device)
172161
self.weight = self.module.weight.data.float().clone()
173162
self._prepare_hessian_inverse(hessian, perc_damp)
174163

175-
if is_vector_lut:
164+
if self.is_vector_lut():
176165
self._blockwise_vector_update(block_size)
177166
else:
178167
self._blockwise_update(block_size)
@@ -256,36 +245,28 @@ def _blockwise_update(self, block_size):
256245
)
257246

258247
def _blockwise_vector_update(self, block_size):
259-
"""GPTQ blockwise update for vector LUT quantizers.
248+
"""GPTQ blockwise update for vector quantizers.
260249
261-
Pre-computes scales once, then runs the standard GPTQ 3-loop
262-
with per-vector-group static quantization via clip_vector_prescaled.
250+
A single ``quantizer(weight)`` call computes and caches per-block
251+
scales (``_psx_scales``) on the quantizer via the backend's
252+
``static_scales`` path. The GPTQ loop then slices per-vector-group
253+
scales from ``_psx_scales`` for each sub-vector quantization call.
263254
"""
264255
import torch.nn.functional as F
265-
from luts import clip_vector_prescaled, clip_vector_scalesign_fast
266256

267-
codebook, quant_block_size, scale_type = load_vector_lut_codebook(
268-
self.module.weight_quantizer
257+
quantizer = self.module.weight_quantizer
258+
assert quantizer.backend_extra_args.get("static_scales", False), (
259+
"GPTQ vector update requires static_scales=True in backend_extra_args."
269260
)
270-
271-
# Get vector size from codebook
272-
vector_size = codebook.shape[1]
261+
vector_size = quantizer.backend_extra_args["vector_size"]
262+
quant_block_size = quantizer.backend_extra_args["block_sizes"]
273263

274264
assert self.weight is not None and self.h_inv is not None
275265
num_cols = self.weight.shape[1]
276266
assert block_size % quant_block_size == 0
277267

278-
# Pre-compute scales once outside the GPTQ loop
279-
_, scales = clip_vector_scalesign_fast(
280-
self.weight,
281-
codebook,
282-
quant_block_size,
283-
scale_type,
284-
scale_algo="max",
285-
sign_scale=True,
286-
return_scales=True,
287-
)
288-
scales_2d = scales.reshape(self.weight.shape[0], -1)
268+
# Compute and cache _psx_scales on the quantizer via the backend.
269+
quantizer(self.weight)
289270

290271
w = self.weight.clone()
291272
h_inv = self.h_inv
@@ -296,12 +277,14 @@ def _blockwise_vector_update(self, block_size):
296277

297278
for j in range(blk_start, blk_end, vector_size):
298279
d = min(vector_size, blk_end - j)
299-
s = scales_2d[:, j // quant_block_size].contiguous()
280+
s = quantizer._psx_scales[:, j // quant_block_size].contiguous()
300281

301282
sub = w[:, j : j + d].contiguous()
302283
if d < vector_size:
303284
sub = F.pad(sub, (0, vector_size - d))
304-
q_sub = clip_vector_prescaled(sub, codebook, s)
285+
286+
with self.weight_slice_scales(s):
287+
q_sub = quantizer(sub)
305288

306289
for k in range(d):
307290
col = j + k

0 commit comments

Comments (0)