|
47 | 47 | from modelopt.torch.utils.perf import get_used_gpu_mem_fraction |
48 | 48 |
|
49 | 49 |
|
def load_vector_lut_codebook(quantizer):
    """Load the vector-LUT codebook and quantizer params from a weight_quantizer.

    Args:
        quantizer: A weight quantizer exposing ``backend_extra_args`` (dict)
            and ``num_bits`` (here used as the encode-format name string).

    Returns:
        Tuple of ``(codebook, quant_block_size, scale_type)`` where
        ``codebook`` is a float CUDA tensor and the other two come straight
        from ``backend_extra_args`` (``block_sizes`` / ``scale_type``).
    """
    from luts import encode as luts_encode

    args = quantizer.backend_extra_args
    # NOTE(review): num_bits is overloaded to carry the encode-format name
    # (a string) for vector-LUT quantizers — confirm against the quantizer API.
    fmt = quantizer.num_bits
    path = args.get("encode_path", "")
    # Normalize to a trailing "/" so plain string concatenation below works.
    if path and not path.endswith("/"):
        path = path + "/"

    if "sorted" not in fmt:
        # Codebook is generated on the fly by the luts backend.
        codebook, _ = luts_encode(fmt, path=path, norm=False, cuda=True)
        codebook = codebook.float()
    else:
        # Pre-sorted codebooks are persisted as "<path><fmt>.pt" checkpoints.
        loaded = torch.load(f"{path}{fmt}.pt", map_location="cpu")
        codebook = loaded["sorted_values"].cuda().float()

    return codebook, args.get("block_sizes"), args.get("scale_type")
| 72 | + |
| 73 | + |
50 | 74 | def update_hessian(input, hessian, n_samples): |
51 | 75 | """Update hessian matrix with new input samples using incremental formula. |
52 | 76 |
|
@@ -140,11 +164,18 @@ def update_weights(self, block_size, perc_damp): |
140 | 164 | Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update, |
141 | 165 | logs MSE, and writes the result back to the module. |
142 | 166 | """ |
| 167 | + backend_extra_args = getattr(self.module.weight_quantizer, "backend_extra_args", None) |
| 168 | + is_vector_lut = bool( |
| 169 | + backend_extra_args and backend_extra_args.get("lut_type") == "vector_lut" |
| 170 | + ) |
143 | 171 | hessian = self.hessian.to(self.module.weight.device) |
144 | 172 | self.weight = self.module.weight.data.float().clone() |
145 | 173 | self._prepare_hessian_inverse(hessian, perc_damp) |
146 | 174 |
|
147 | | - self._blockwise_update(block_size) |
| 175 | + if is_vector_lut: |
| 176 | + self._blockwise_vector_update(block_size) |
| 177 | + else: |
| 178 | + self._blockwise_update(block_size) |
148 | 179 |
|
149 | 180 | self._print_mse_error(hessian) |
150 | 181 | self.module.weight.data = self.weight.reshape(self.module.weight.shape).to( |
@@ -224,6 +255,64 @@ def _blockwise_update(self, block_size): |
224 | 255 | errs, self.h_inv[block_start:block_end, block_end:], alpha=-1 |
225 | 256 | ) |
226 | 257 |
|
    def _blockwise_vector_update(self, block_size):
        """GPTQ blockwise update for vector LUT quantizers.

        Mirrors the scalar ``_blockwise_update`` 3-loop, but quantizes
        ``vector_size`` adjacent columns at a time against a shared codebook
        (via ``clip_vector_prescaled``) instead of one column at a time.
        Per-group scales are computed once up front from the pre-update
        weights and held fixed for the whole pass.

        Args:
            block_size: Number of columns per outer GPTQ block; must be a
                multiple of the quantizer's ``quant_block_size`` (asserted).

        Side effects:
            Writes quantized values into ``self.weight`` column by column;
            ``self.h_inv`` is read but not modified.
        """
        import torch.nn.functional as F
        from luts import clip_vector_prescaled, clip_vector_scalesign_fast

        codebook, quant_block_size, scale_type = load_vector_lut_codebook(
            self.module.weight_quantizer
        )

        # Get vector size from codebook
        # (each codebook entry is a row vector of length `vector_size`).
        vector_size = codebook.shape[1]

        assert self.weight is not None and self.h_inv is not None
        num_cols = self.weight.shape[1]
        assert block_size % quant_block_size == 0

        # Pre-compute scales once outside the GPTQ loop
        _, scales = clip_vector_scalesign_fast(
            self.weight,
            codebook,
            quant_block_size,
            scale_type,
            scale_algo="max",
            sign_scale=True,
            return_scales=True,
        )
        # NOTE(review): presumably scales_2d is (rows, num_cols // quant_block_size),
        # one scale per quant block per row — confirm the return shape of
        # clip_vector_scalesign_fast.
        scales_2d = scales.reshape(self.weight.shape[0], -1)

        # `w` accumulates the GPTQ error-compensated weights; `self.weight`
        # receives the final quantized values.
        w = self.weight.clone()
        h_inv = self.h_inv

        for blk_start in range(0, num_cols, block_size):
            blk_end = min(blk_start + block_size, num_cols)
            # Per-column scaled errors for this block, used for the lazy
            # update of the columns to the right of the block.
            errs = torch.zeros_like(w[:, blk_start:blk_end])

            for j in range(blk_start, blk_end, vector_size):
                d = min(vector_size, blk_end - j)
                # NOTE(review): scale index j // quant_block_size is constant
                # across the group only if vector groups never straddle a
                # quant-block boundary (quant_block_size % vector_size == 0)
                # — not asserted here; confirm.
                s = scales_2d[:, j // quant_block_size].contiguous()

                # Snapshot the group BEFORE the inner per-column updates:
                # quantization is per vector group, not per column.
                sub = w[:, j : j + d].contiguous()
                if d < vector_size:
                    # Zero-pad a ragged tail group up to the codebook width.
                    sub = F.pad(sub, (0, vector_size - d))
                q_sub = clip_vector_prescaled(sub, codebook, s)

                # Standard GPTQ error propagation, one column at a time.
                for k in range(d):
                    col = j + k
                    self.weight[:, col] = q_sub[:, k]
                    # err = (w - q) / Hinv[col, col]; dividing here lets the
                    # rank-1 update below use the raw Hinv row.
                    err = (w[:, col] - q_sub[:, k]) / h_inv[col, col]
                    errs[:, col - blk_start] = err
                    # In-place rank-1 update of the remaining columns of this
                    # block (the slice is a view, so addr_ mutates w).
                    w[:, col:blk_end].addr_(err, h_inv[col, col:blk_end], alpha=-1)

            # Lazy batched update of all columns right of the block.
            if blk_end < num_cols:
                w[:, blk_end:] -= errs @ h_inv[blk_start:blk_end, blk_end:]
227 | 316 | def _print_mse_error(self, hessian): |
228 | 317 | """Log Hessian-weighted relative MSE between ``self.weight`` and original weights.""" |
229 | 318 | w_orig = self.module.weight.float() |
|
0 commit comments