Skip to content

Commit d101aba

Browse files
committed
GPTQ vector and unit test
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
1 parent 73be810 commit d101aba

2 files changed

Lines changed: 634 additions & 1 deletion

File tree

modelopt/torch/quantization/utils/calib_utils.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,30 @@
4747
from modelopt.torch.utils.perf import get_used_gpu_mem_fraction
4848

4949

50+
def load_vector_lut_codebook(quantizer):
    """Load vector LUT codebook and quantizer params from a weight_quantizer.

    Args:
        quantizer: Weight quantizer carrying ``backend_extra_args`` (dict) and
            ``num_bits`` (the LUT encode-format string).

    Returns:
        Tuple of (codebook, quant_block_size, scale_type).
    """
    from luts import encode as luts_encode

    extra_args = quantizer.backend_extra_args
    encode_format = quantizer.num_bits

    # Normalize the (possibly empty) lookup path so it ends with a separator.
    encode_path = extra_args.get("encode_path", "")
    if encode_path and not encode_path.endswith("/"):
        encode_path = encode_path + "/"

    if "sorted" not in encode_format:
        # Generate the codebook on the fly via the luts package.
        codebook, _ = luts_encode(encode_format, path=encode_path, norm=False, cuda=True)
        codebook = codebook.float()
    else:
        # A pre-sorted codebook is stored as a checkpoint next to the encode path.
        checkpoint = torch.load(encode_path + encode_format + ".pt", map_location="cpu")
        codebook = checkpoint["sorted_values"].cuda().float()

    return codebook, extra_args.get("block_sizes"), extra_args.get("scale_type")
5074
def update_hessian(input, hessian, n_samples):
5175
"""Update hessian matrix with new input samples using incremental formula.
5276
@@ -140,11 +164,18 @@ def update_weights(self, block_size, perc_damp):
140164
Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update,
141165
logs MSE, and writes the result back to the module.
142166
"""
167+
backend_extra_args = getattr(self.module.weight_quantizer, "backend_extra_args", None)
168+
is_vector_lut = bool(
169+
backend_extra_args and backend_extra_args.get("lut_type") == "vector_lut"
170+
)
143171
hessian = self.hessian.to(self.module.weight.device)
144172
self.weight = self.module.weight.data.float().clone()
145173
self._prepare_hessian_inverse(hessian, perc_damp)
146174

147-
self._blockwise_update(block_size)
175+
if is_vector_lut:
176+
self._blockwise_vector_update(block_size)
177+
else:
178+
self._blockwise_update(block_size)
148179

149180
self._print_mse_error(hessian)
150181
self.module.weight.data = self.weight.reshape(self.module.weight.shape).to(
@@ -224,6 +255,64 @@ def _blockwise_update(self, block_size):
224255
errs, self.h_inv[block_start:block_end, block_end:], alpha=-1
225256
)
226257

258+
def _blockwise_vector_update(self, block_size):
    """GPTQ blockwise update for vector LUT quantizers.

    Pre-computes scales once, then runs the standard GPTQ 3-loop
    with per-vector-group static quantization via clip_vector_prescaled.

    Args:
        block_size: Number of columns per outer GPTQ block; must be a
            multiple of the quantizer's ``quant_block_size`` (asserted below).
    """
    import torch.nn.functional as F

    from luts import clip_vector_prescaled, clip_vector_scalesign_fast

    codebook, quant_block_size, scale_type = load_vector_lut_codebook(
        self.module.weight_quantizer
    )

    # Get vector size from codebook
    # (each codebook entry spans `vector_size` consecutive weight columns).
    vector_size = codebook.shape[1]

    assert self.weight is not None and self.h_inv is not None
    num_cols = self.weight.shape[1]
    assert block_size % quant_block_size == 0

    # Pre-compute scales once outside the GPTQ loop
    _, scales = clip_vector_scalesign_fast(
        self.weight,
        codebook,
        quant_block_size,
        scale_type,
        scale_algo="max",
        sign_scale=True,
        return_scales=True,
    )
    # Reshape so scales can be indexed as (row, quant-block-index).
    # assumes `scales` holds one value per (row, quant block) — TODO confirm
    # against clip_vector_scalesign_fast's return layout.
    scales_2d = scales.reshape(self.weight.shape[0], -1)

    # `w` is the working copy that absorbs propagated errors; `self.weight`
    # receives the final quantized values column by column.
    w = self.weight.clone()
    h_inv = self.h_inv

    for blk_start in range(0, num_cols, block_size):
        blk_end = min(blk_start + block_size, num_cols)
        errs = torch.zeros_like(w[:, blk_start:blk_end])

        # Quantize one codebook-vector-wide group of columns at a time.
        for j in range(blk_start, blk_end, vector_size):
            d = min(vector_size, blk_end - j)
            # NOTE(review): indexing scales by j // quant_block_size uses the
            # group's first column only — assumes a vector group never
            # straddles a quant-block boundary (i.e. quant_block_size is a
            # multiple of vector_size), which is not asserted here. Confirm.
            s = scales_2d[:, j // quant_block_size].contiguous()

            sub = w[:, j : j + d].contiguous()
            if d < vector_size:
                # Zero-pad a ragged tail group up to the codebook vector width.
                sub = F.pad(sub, (0, vector_size - d))
            q_sub = clip_vector_prescaled(sub, codebook, s)

            # Standard GPTQ inner loop: commit each quantized column, then
            # propagate its Hessian-weighted error into every not-yet-final
            # column of this block. Statement order matters: `w` is mutated
            # in place while `q_sub` stays fixed for the whole vector group.
            for k in range(d):
                col = j + k
                self.weight[:, col] = q_sub[:, k]
                err = (w[:, col] - q_sub[:, k]) / h_inv[col, col]
                errs[:, col - blk_start] = err
                # Rank-1 correction: subtract the outer product of `err` and
                # the h_inv row segment from the remaining block columns
                # (addr_ writes through the view into `w`).
                w[:, col:blk_end].addr_(err, h_inv[col, col:blk_end], alpha=-1)

        # Propagate this block's accumulated errors to all later columns.
        if blk_end < num_cols:
            w[:, blk_end:] -= errs @ h_inv[blk_start:blk_end, blk_end:]
227316
def _print_mse_error(self, hessian):
228317
"""Log Hessian-weighted relative MSE between ``self.weight`` and original weights."""
229318
w_orig = self.module.weight.float()

0 commit comments

Comments
 (0)