Skip to content

Commit 8fe1efd

Browse files
authored
add int4_gptaq (#49)
1 parent fa7304e commit 8fe1efd

14 files changed

Lines changed: 455 additions & 7 deletions

File tree

angelslim/compressor/quant/core/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def __init__(self, config, global_config=None):
138138
self.hidden_size = global_config.hidden_size
139139
self.model_arch_type = global_config.model_arch_type
140140
self.low_memory = config.quantization.low_memory
141-
elif "int4_gptq" in self.quant_algo:
141+
elif "int4_gptq" in self.quant_algo or "int4_gptaq" in self.quant_algo:
142142
self.act_observer = None
143143
self.weight_observer = None
144144
self.kv_cache_observer = None

angelslim/compressor/quant/modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .awq.awq import AWQ # noqa: F401
1616
from .fp8.fp8 import FP8 # noqa: F401
1717
from .fp8.lepto_fp8 import LeptoFP8 # noqa: F401
18+
from .gptq.gptaq_module import GPTAQModule # noqa: F401
1819
from .gptq.gptq import GPTQ # noqa: F401
1920
from .gptq.gptq_module import GPTQModule # noqa: F401
2021
from .helper_layer import GPTQQuantLinear # noqa: F401
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
import time
17+
18+
import torch
19+
20+
from .....utils import get_tensor_item, print_info
21+
from ...core import compute_scales_with_zero
22+
23+
__all__ = ["GPTAQModule"]
24+
25+
26+
class GPTAQModule:
27+
def __init__(self, layer, quant_bits=4):
28+
"""
29+
GPTAQ quantization wrapper for neural network layers.
30+
31+
Args:
32+
layer: Full-precision torch.nn.Module to quantize (Linear)
33+
quant_bits: Quantization bitwidth (2-8 bits, default=4)
34+
"""
35+
super(GPTAQModule, self).__init__()
36+
self.layer = layer
37+
self.dev = self.layer.weight.device
38+
self.w = layer.weight.data.clone()
39+
self.rows = self.w.shape[0]
40+
self.columns = self.w.shape[1]
41+
self.h = torch.zeros((self.columns, self.columns), device=self.dev)
42+
self.dXXT = torch.zeros((self.columns, self.columns), device=self.dev)
43+
self.nsamples = 0
44+
self.quant_bits = quant_bits
45+
46+
def add_batch(self, inp, out, native_inp):
47+
if len(inp.shape) == 4:
48+
inp = inp[0, 0, :, :]
49+
native_inp = native_inp[0, 0, :, :]
50+
inp = inp.squeeze()
51+
native_inp = native_inp.squeeze()
52+
if len(inp.shape) == 2:
53+
inp = inp.unsqueeze(0)
54+
native_inp = native_inp.unsqueeze(0)
55+
tmp = inp.shape[0]
56+
if len(inp.shape) == 3:
57+
inp = inp.reshape((-1, inp.shape[-1]))
58+
native_inp = native_inp.reshape((-1, native_inp.shape[-1]))
59+
inp = inp.t()
60+
native_inp = native_inp.t()
61+
self.h *= self.nsamples / (self.nsamples + tmp)
62+
self.dXXT *= self.nsamples / (self.nsamples + tmp)
63+
self.nsamples += tmp
64+
inp = math.sqrt(2 / self.nsamples) * inp.float()
65+
self.h += inp.matmul(inp.t())
66+
native_inp = math.sqrt(2 / self.nsamples) * native_inp
67+
self.dXXT += (native_inp - inp).matmul(inp.t())
68+
69+
def fasterquant(
70+
self,
71+
blocksize=128,
72+
percdamp=0.01,
73+
group_size=-1,
74+
actorder=True,
75+
sym=True,
76+
):
77+
w_weight = self.w.float()
78+
79+
tick = time.time()
80+
81+
hessian = self.h
82+
if torch.isnan(hessian).any():
83+
print_info("[error] Hessian contains nan!")
84+
exit()
85+
self.h.detach().cpu()
86+
del self.h
87+
dead = torch.diag(hessian) == 0
88+
hessian[dead, dead] = 1
89+
w_weight[:, dead] = 0
90+
self.dXXT[:, dead] = 0
91+
92+
g_idx = []
93+
scale = []
94+
zero = []
95+
now_idx = 1
96+
static_groups = True
97+
98+
if static_groups:
99+
for i in range(0, self.columns, group_size):
100+
weight_scale, weight_zero = compute_scales_with_zero(
101+
w_weight[:, i : (i + group_size)], bits=self.quant_bits, sym=sym
102+
)
103+
scale.append(weight_scale)
104+
zero.append(weight_zero)
105+
106+
if actorder:
107+
perm = torch.argsort(torch.diag(hessian), descending=True)
108+
w_weight = w_weight[:, perm]
109+
hessian = hessian[perm][:, perm]
110+
self.dXXT = self.dXXT[perm][:, perm]
111+
invperm = torch.argsort(perm)
112+
113+
losses = torch.zeros_like(w_weight)
114+
q_weight = torch.zeros_like(w_weight)
115+
116+
while 1 > percdamp > 0:
117+
try:
118+
damp = percdamp * torch.mean(torch.diag(hessian))
119+
diag = torch.arange(self.columns, device=self.dev)
120+
hessian[diag, diag] += damp
121+
hessian = torch.linalg.cholesky(hessian)
122+
hessian = torch.cholesky_inverse(hessian)
123+
hessian = torch.linalg.cholesky(hessian, upper=True)
124+
hinv = hessian
125+
break
126+
except torch._C._LinAlgError as e:
127+
print_info(e)
128+
print_info(f"Cholesky failed with percdamp={percdamp:.5f}")
129+
percdamp += 0.01
130+
131+
P = ((self.dXXT @ hinv.T).triu(diagonal=1)) @ hinv
132+
del self.dXXT
133+
134+
for i1 in range(0, self.columns, blocksize):
135+
i2 = min(i1 + blocksize, self.columns)
136+
count = i2 - i1
137+
138+
w1 = w_weight[:, i1:i2].clone()
139+
q1 = torch.zeros_like(w1)
140+
err1 = torch.zeros_like(w1)
141+
losses1 = torch.zeros_like(w1)
142+
hinv1 = hinv[i1:i2, i1:i2]
143+
P1 = P[i1:i2, i1:i2]
144+
145+
for i in range(count):
146+
w = w1[:, i]
147+
d = hinv1[i, i]
148+
149+
if group_size != -1:
150+
if not static_groups:
151+
if (i1 + i) % group_size == 0:
152+
weight_scale, weight_zero = compute_scales_with_zero(
153+
w_weight[:, (i1 + i) : (i1 + i + group_size)],
154+
bits=self.quant_bits,
155+
sym=sym,
156+
)
157+
158+
if ((i1 + i) // group_size) - now_idx == -1:
159+
scale.append(weight_scale)
160+
zero.append(weight_zero)
161+
now_idx += 1
162+
else:
163+
idx = i1 + i
164+
if actorder:
165+
idx = perm[idx]
166+
weight_scale = scale[idx // group_size]
167+
weight_zero = zero[idx // group_size]
168+
169+
maxq = torch.tensor(2**self.quant_bits - 1)
170+
q = torch.clamp(
171+
torch.round(w.unsqueeze(1) / weight_scale) + weight_zero, 0, maxq
172+
)
173+
q = weight_scale * (q - weight_zero)
174+
q = q.flatten()
175+
q1[:, i] = q
176+
losses1[:, i] = (w - q) ** 2 / d**2
177+
178+
err = (w - q) / d
179+
w1[:, i:] -= err.unsqueeze(1).matmul(
180+
hinv1[i, i:].unsqueeze(0)
181+
) - w.unsqueeze(1).matmul(P1[i, i:].unsqueeze(0))
182+
err1[:, i] = err
183+
184+
q_weight[:, i1:i2] = q1
185+
losses[:, i1:i2] = losses1 / 2
186+
187+
w_weight[:, i2:] -= err1.matmul(hinv[i1:i2, i2:]) - w1.matmul(P[i1:i2, i2:])
188+
189+
torch.cuda.synchronize()
190+
print_info(f" duration: {(time.time() - tick)}")
191+
print_info(f" avg loss: {torch.sum(losses).item() / self.nsamples}")
192+
193+
group_size = group_size if group_size != -1 else self.columns
194+
if static_groups and actorder:
195+
g_idx = [perm[i] // group_size for i in range(self.columns)]
196+
else:
197+
g_idx = [i // group_size for i in range(self.columns)]
198+
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=q_weight.device)
199+
if actorder:
200+
q_weight = q_weight[:, invperm]
201+
g_idx = g_idx[invperm]
202+
203+
norm_loss = torch.norm(
204+
q_weight.reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
205+
- self.layer.weight.data
206+
)
207+
all_norm_loss = [norm_loss]
208+
209+
print_info(" self.layer.weight: {}, {}".format(q_weight.shape, q_weight.sum()))
210+
print_info(f" norm loss: {list(map(get_tensor_item, all_norm_loss))}")
211+
212+
self.layer.weight.data.copy_(
213+
q_weight.reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
214+
)
215+
216+
if scale == []:
217+
scale = weight_scale
218+
zero = torch.zeros_like(weight_scale)
219+
scale = torch.cat(scale, dim=1)
220+
zero = torch.cat(zero, dim=1)
221+
losses = losses.cpu()
222+
q_weight = q_weight.cpu()
223+
w_weight = w_weight.cpu()
224+
hessian = hessian.cpu()
225+
hinv = hinv.cpu()
226+
del losses, q_weight, w_weight, hessian, hinv, P
227+
self.w = self.w.cpu()
228+
del self.w
229+
torch.cuda.empty_cache()
230+
return scale, zero, g_idx
231+
232+
def free(self):
233+
self.h = None
234+
self.w = None
235+
self.losses = None
236+
torch.cuda.empty_cache()

angelslim/compressor/quant/modules/gptq/gptq.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .....utils import print_info
2424
from ...modules.catcher import Catcher
2525
from ...modules.helper_layer import GPTQQuantLinear
26+
from .gptaq_module import GPTAQModule
2627
from .gptq_module import GPTQModule
2728

2829
__all__ = ["GPTQ"]
@@ -51,6 +52,8 @@ def __init__(
5152
self.dtype = next(iter(self.layers.parameters())).dtype
5253
self.quantizers = {}
5354
self.gptq = {}
55+
self.quant_algo = self.model.quant_config.quant_algo
56+
self.native_inp_caches = {}
5457

5558
@torch.no_grad()
5659
def run(self, dataloader):
@@ -86,6 +89,8 @@ def run(self, dataloader):
8689
torch.cuda.empty_cache()
8790

8891
outs = torch.zeros_like(inps)
92+
if "gptaq" in self.quant_algo:
93+
native_inps = inps.clone().detach()
8994
# begin the gptq process
9095
print_info("Ready.")
9196

@@ -96,18 +101,61 @@ def run(self, dataloader):
96101
subset = self._find_layers(layer)
97102
print_info("subset:{}".format(subset))
98103
self.gptq = {}
104+
if "gptaq" in self.quant_algo:
105+
self.native_inp_caches = {}
99106
print_info("GPTQMoe start layer {}".format(i))
100107
for name in subset:
101108
if name in self.ignore_layers:
102109
continue
103-
self.gptq[name] = GPTQModule(subset[name], quant_bits=self.quant_bits)
110+
if "gptaq" in self.quant_algo:
111+
self.native_inp_caches[name] = []
112+
self.gptq[name] = GPTAQModule(
113+
subset[name], quant_bits=self.quant_bits
114+
)
115+
else:
116+
self.gptq[name] = GPTQModule(
117+
subset[name], quant_bits=self.quant_bits
118+
)
119+
120+
def pre_process_fwd_hook(layer_name):
121+
def tmp(_, inp, out):
122+
self.native_inp_caches[layer_name] += [inp[0].data]
123+
del inp, out
124+
125+
return tmp
104126

105127
def add_batch(layer_name):
106128
def tmp(_, inp, out):
107-
self.gptq[layer_name].add_batch(inp[0].data, out.data)
129+
if "gptaq" in self.quant_algo:
130+
native_inp = self.native_inp_caches[layer_name].pop(0)
131+
self.gptq[layer_name].add_batch(
132+
inp[0].data, out.data, native_inp
133+
)
134+
else:
135+
self.gptq[layer_name].add_batch(inp[0].data, out.data)
108136

109137
return tmp
110138

139+
if "gptaq" in self.quant_algo:
140+
native_handles = []
141+
for name in self.native_inp_caches:
142+
native_handles.append(
143+
subset[name].register_forward_hook(pre_process_fwd_hook(name))
144+
)
145+
146+
# being native hook
147+
for j in range(nsamples):
148+
with torch.no_grad():
149+
outs[j, :, :] = layer(
150+
hidden_states=native_inps[j, :, :].unsqueeze(0),
151+
**layer_kwargs,
152+
)[0].squeeze(1)
153+
native_inps = outs
154+
155+
print_info("Native HOOK Step{}".format(j))
156+
for h in native_handles:
157+
h.remove()
158+
111159
handles = []
112160
for name in self.gptq:
113161
handles.append(subset[name].register_forward_hook(add_batch(name)))

angelslim/compressor/quant/ptq.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(self, model, slim_config=None):
4545
self.ptq_hook = PTQHook(self.quant_model)
4646
self.ptq_hook.apply_hook()
4747

48-
if "gptq" in self.quant_algo:
48+
if "gptq" in self.quant_algo or "gptaq" in self.quant_algo:
4949
max_seq_length = self.quant_model.quant_config.max_seq_length
5050
hidden_size = self.quant_model.quant_config.hidden_size
5151
self.gptq = GPTQ(
@@ -105,7 +105,7 @@ def __init__(self, model, slim_config=None):
105105
)
106106

107107
def calibrate(self, dataloader):
108-
if "gptq" in self.quant_algo:
108+
if "gptq" in self.quant_algo or "gptaq" in self.quant_algo:
109109
self.gptq.run(dataloader)
110110
elif "awq" in self.quant_algo:
111111
self.awq.run(dataloader)
@@ -123,7 +123,7 @@ def convert(self):
123123
Saves scales and inserts QDQ modules.
124124
"""
125125
print_info("Start convert model...")
126-
if "gptq" in self.quant_algo:
126+
if "gptq" in self.quant_algo or "gptaq" in self.quant_algo:
127127
self.gptq.convert()
128128
elif "awq" in self.quant_algo:
129129
self.awq.convert()
@@ -166,7 +166,7 @@ def save(self, save_path: str):
166166
)
167167

168168
print_info("Start save PTQ ckpt to: {}".format(save_path))
169-
if "gptq" in self.quant_algo:
169+
if "gptq" in self.quant_algo or "gptaq" in self.quant_algo:
170170
self.gptq.save(save_path)
171171
elif "awq" in self.quant_algo:
172172
self.awq.save(save_path)

angelslim/engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"int4_awq": default_compress_config.default_int4_awq_config(),
3333
"int4_gptq": default_compress_config.default_int4_gptq_config(),
3434
"w4a8_fp8": default_compress_config.default_w4a8_fp8_static_config(),
35+
"int4_gptaq": default_compress_config.default_int4_gptaq_config(),
3536
}
3637

3738

angelslim/models/base_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ def model_forward(self, dataloader, **kwargs):
256256
if (
257257
"gptq" in self.quant_config.quant_algo
258258
or "awq" in self.quant_config.quant_algo
259+
or "gptaq" in self.quant_config.quant_algo
259260
):
260261
device = "cuda:0"
261262
else:

0 commit comments

Comments
 (0)