|
| 1 | +# Copyright 2025 Tencent Inc. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import math |
| 16 | +import time |
| 17 | + |
| 18 | +import torch |
| 19 | + |
| 20 | +from .....utils import get_tensor_item, print_info |
| 21 | +from ...core import compute_scales_with_zero |
| 22 | + |
| 23 | +__all__ = ["GPTAQModule"] |
| 24 | + |
| 25 | + |
class GPTAQModule:
    """GPTAQ quantization wrapper for a single layer.

    The wrapper accumulates two ``columns x columns`` statistics from
    calibration batches (``h`` and ``dXXT``, see :meth:`add_batch`) and then
    quantizes the layer weight column by column in :meth:`fasterquant`,
    writing the de-quantized result back into ``layer.weight``.
    """

    def __init__(self, layer, quant_bits=4):
        """
        GPTAQ quantization wrapper for neural network layers.

        Args:
            layer: Full-precision torch.nn.Module to quantize (Linear)
            quant_bits: Quantization bitwidth (2-8 bits, default=4)
        """
        super().__init__()
        self.layer = layer
        self.dev = self.layer.weight.device
        # Private copy of the weight; fasterquant works on this copy and only
        # overwrites layer.weight once quantization has finished.
        self.w = layer.weight.data.clone()
        self.rows = self.w.shape[0]
        self.columns = self.w.shape[1]
        # Running second-moment matrix of the layer inputs.
        self.h = torch.zeros((self.columns, self.columns), device=self.dev)
        # Running cross term between (native_inp - inp) and inp.
        self.dXXT = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.quant_bits = quant_bits

    def add_batch(self, inp, out, native_inp):
        """Fold one calibration batch into the running ``h`` / ``dXXT`` averages.

        Args:
            inp: Input seen by this layer. 4-D inputs are reduced to
                ``inp[0, 0]``; the last dimension is treated as the feature
                dimension (presumably ``self.columns`` -- confirm with caller).
            out: Unused; kept for hook-signature compatibility.
            native_inp: Reference input with the same shape as ``inp``
                (presumably the activation of the unquantized model --
                confirm with caller).
        """
        if len(inp.shape) == 4:
            inp = inp[0, 0, :, :]
            native_inp = native_inp[0, 0, :, :]
        inp = inp.squeeze()
        native_inp = native_inp.squeeze()
        if inp.dim() == 1:
            # A single token was squeezed down to a bare feature vector;
            # restore the token axis so it is not mistaken for the batch size.
            inp = inp.unsqueeze(0)
            native_inp = native_inp.unsqueeze(0)
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
            native_inp = native_inp.unsqueeze(0)
        tmp = inp.shape[0]
        if len(inp.shape) == 3:
            inp = inp.reshape((-1, inp.shape[-1]))
            native_inp = native_inp.reshape((-1, native_inp.shape[-1]))
        inp = inp.t()
        native_inp = native_inp.t()

        # Re-weight the old averages before adding the new batch.
        self.h *= self.nsamples / (self.nsamples + tmp)
        self.dXXT *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp

        factor = math.sqrt(2 / self.nsamples)
        # Both operands are promoted to float32 *before* scaling so that
        # half-precision inputs are not rounded differently from each other.
        inp = factor * inp.float()
        native_inp = factor * native_inp.float()
        self.h += inp.matmul(inp.t())
        self.dXXT += (native_inp - inp).matmul(inp.t())

    def _inverse_cholesky(self, hessian, percdamp):
        """Return the upper Cholesky factor of the inverse of the damped Hessian.

        ``percdamp * mean(diag(hessian))`` is added to the diagonal of
        ``hessian`` in place. If a factorization fails, ``percdamp`` is raised
        by 0.01 and another (cumulative) damping step is applied.

        Raises:
            ValueError: If ``percdamp`` is (or grows) outside ``(0, 1)``
                before a factorization succeeds.
        """
        diag = torch.arange(hessian.shape[0], device=hessian.device)
        while 1 > percdamp > 0:
            try:
                damp = percdamp * torch.mean(torch.diag(hessian))
                hessian[diag, diag] += damp
                # Intermediates get their own names: `hessian` must stay
                # intact so that a retry starts from the real matrix.
                factor = torch.linalg.cholesky(hessian)
                inverse = torch.cholesky_inverse(factor)
                del factor
                return torch.linalg.cholesky(inverse, upper=True)
            except torch.linalg.LinAlgError as e:
                print_info(e)
                print_info(f"Cholesky failed with percdamp={percdamp:.5f}")
                percdamp += 0.01
        raise ValueError(
            "Could not factorize the Hessian: percdamp must stay within (0, 1), "
            f"got {percdamp}"
        )

    def fasterquant(
        self,
        blocksize=128,
        percdamp=0.01,
        group_size=-1,
        actorder=True,
        sym=True,
    ):
        """Quantize the wrapped layer's weight and write it back in place.

        Args:
            blocksize: Number of columns processed per block.
            percdamp: Initial diagonal damping, as a fraction of the mean
                Hessian diagonal; must lie in ``(0, 1)``.
            group_size: Columns per quantization group; ``-1`` means one
                group spanning all columns.
            actorder: If True, process columns in descending order of the
                Hessian diagonal.
            sym: Forwarded to ``compute_scales_with_zero``.

        Returns:
            Tuple ``(scale, zero, g_idx)``: per-group scales and zero points
            concatenated along dim 1, and an int32 tensor mapping each column
            to its group index.

        Raises:
            ValueError: If the Hessian contains NaN, ``group_size`` is
                invalid, or the Hessian cannot be factorized.

        Note:
            ``self.h``, ``self.dXXT`` and ``self.w`` are deleted by this call.
        """
        if group_size == -1:
            # One group over the whole row; otherwise no scales are ever
            # computed and the column loop has nothing to quantize with.
            group_size = self.columns
        if group_size <= 0:
            raise ValueError(f"group_size must be -1 or positive, got {group_size}")

        w_weight = self.w.float()

        tick = time.time()

        hessian = self.h
        if torch.isnan(hessian).any():
            print_info("[error] Hessian contains nan!")
            raise ValueError("Hessian contains nan!")
        del self.h

        # Columns that never received any input: neutralize them.
        dead = torch.diag(hessian) == 0
        hessian[dead, dead] = 1
        w_weight[:, dead] = 0
        self.dXXT[:, dead] = 0

        scale = []
        zero = []
        static_groups = True

        if static_groups:
            # Group parameters are fixed from the original weight before any
            # error compensation (and before any column reordering).
            for start in range(0, self.columns, group_size):
                weight_scale, weight_zero = compute_scales_with_zero(
                    w_weight[:, start : (start + group_size)],
                    bits=self.quant_bits,
                    sym=sym,
                )
                scale.append(weight_scale)
                zero.append(weight_zero)

        perm = None
        invperm = None
        perm_list = None
        if actorder:
            perm = torch.argsort(torch.diag(hessian), descending=True)
            w_weight = w_weight[:, perm]
            hessian = hessian[perm][:, perm]
            self.dXXT = self.dXXT[perm][:, perm]
            invperm = torch.argsort(perm)
            # Host-side copy: avoids indexing a device tensor once per column.
            perm_list = perm.tolist()

        losses = torch.zeros_like(w_weight)
        q_weight = torch.zeros_like(w_weight)

        hinv = self._inverse_cholesky(hessian, percdamp)
        del hessian

        P = ((self.dXXT @ hinv.T).triu(diagonal=1)) @ hinv
        del self.dXXT

        maxq = 2**self.quant_bits - 1

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            w1 = w_weight[:, i1:i2].clone()
            q1 = torch.zeros_like(w1)
            err1 = torch.zeros_like(w1)
            losses1 = torch.zeros_like(w1)
            hinv1 = hinv[i1:i2, i1:i2]
            P1 = P[i1:i2, i1:i2]

            for i in range(count):
                w = w1[:, i]
                d = hinv1[i, i]
                col = i1 + i

                if static_groups:
                    # Look up the group of the column's original position.
                    src = perm_list[col] if actorder else col
                    weight_scale = scale[src // group_size]
                    weight_zero = zero[src // group_size]
                elif col % group_size == 0:
                    # Dynamic groups: recompute at the first column of a group.
                    weight_scale, weight_zero = compute_scales_with_zero(
                        w_weight[:, col : (col + group_size)],
                        bits=self.quant_bits,
                        sym=sym,
                    )
                    scale.append(weight_scale)
                    zero.append(weight_zero)

                q = torch.clamp(
                    torch.round(w.unsqueeze(1) / weight_scale) + weight_zero, 0, maxq
                )
                q = weight_scale * (q - weight_zero)
                q = q.flatten()
                q1[:, i] = q
                losses1[:, i] = (w - q) ** 2 / d**2

                err = (w - q) / d
                # Propagate the quantization error (hinv term) and the
                # dXXT-derived correction (P term) to the remaining columns of
                # this block. The right-hand side is fully evaluated before
                # the in-place update touches `w`.
                w1[:, i:] -= err.unsqueeze(1).matmul(
                    hinv1[i, i:].unsqueeze(0)
                ) - w.unsqueeze(1).matmul(P1[i, i:].unsqueeze(0))
                err1[:, i] = err

            q_weight[:, i1:i2] = q1
            losses[:, i1:i2] = losses1 / 2

            # Same two corrections, applied to all columns after this block.
            w_weight[:, i2:] -= err1.matmul(hinv[i1:i2, i2:]) - w1.matmul(P[i1:i2, i2:])

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print_info(f" duration: {(time.time() - tick)}")
        # max(..., 1) keeps the report from dividing by zero when no
        # calibration batch was ever added.
        print_info(f" avg loss: {torch.sum(losses).item() / max(self.nsamples, 1)}")

        if static_groups and actorder:
            g_idx = torch.div(perm, group_size, rounding_mode="floor")
        else:
            g_idx = torch.div(
                torch.arange(self.columns, device=q_weight.device),
                group_size,
                rounding_mode="floor",
            )
        g_idx = g_idx.to(dtype=torch.int32, device=q_weight.device)
        if actorder:
            # Undo the column permutation.
            q_weight = q_weight[:, invperm]
            g_idx = g_idx[invperm]

        quantized = q_weight.reshape(self.layer.weight.shape).type_as(
            self.layer.weight.data
        )
        norm_loss = torch.norm(quantized - self.layer.weight.data)

        print_info(" self.layer.weight: {}, {}".format(q_weight.shape, q_weight.sum()))
        print_info(f" norm loss: {[get_tensor_item(norm_loss)]}")

        self.layer.weight.data.copy_(quantized)

        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)

        # Drop the large temporaries before asking the allocator to release.
        del losses, q_weight, w_weight, quantized, hinv, P
        del self.w
        torch.cuda.empty_cache()
        return scale, zero, g_idx

    def free(self):
        """Release every tensor held by the wrapper."""
        self.h = None
        self.dXXT = None
        self.w = None
        self.losses = None
        torch.cuda.empty_cache()
0 commit comments