Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion angelslim/compressor/quant/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(self, config, global_config=None):
self.hidden_size = global_config.hidden_size
self.model_arch_type = global_config.model_arch_type
self.low_memory = config.quantization.low_memory
elif "int4_gptq" in self.quant_algo:
elif "int4_gptq" in self.quant_algo or "int4_gptaq" in self.quant_algo:
self.act_observer = None
self.weight_observer = None
self.kv_cache_observer = None
Expand Down
1 change: 1 addition & 0 deletions angelslim/compressor/quant/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .awq.awq import AWQ # noqa: F401
from .fp8.fp8 import FP8 # noqa: F401
from .fp8.lepto_fp8 import LeptoFP8 # noqa: F401
from .gptq.gptaq_module import GPTAQModule # noqa: F401
from .gptq.gptq import GPTQ # noqa: F401
from .gptq.gptq_module import GPTQModule # noqa: F401
from .helper_layer import GPTQQuantLinear # noqa: F401
Expand Down
236 changes: 236 additions & 0 deletions angelslim/compressor/quant/modules/gptq/gptaq_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# Copyright 2025 Tencent Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import time

import torch

from .....utils import get_tensor_item, print_info
from ...core import compute_scales_with_zero

__all__ = ["GPTAQModule"]


class GPTAQModule:
    """GPTAQ quantization wrapper for a single linear layer.

    The wrapper accumulates two statistics over calibration batches
    (``add_batch``) and then quantizes the layer weight column by column
    with error compensation (``fasterquant``):

    * ``h``    -- running average of ``2 * X @ X.T`` for the layer inputs ``X``.
    * ``dXXT`` -- running average of ``2 * (X_native - X) @ X.T`` where
      ``X_native`` is the input passed as ``native_inp``.
    """

    def __init__(self, layer, quant_bits=4):
        """
        Args:
            layer: Full-precision torch.nn.Module to quantize (Linear). Its
                ``weight`` must be 2-D: (out_features, in_features).
            quant_bits: Quantization bitwidth (2-8 bits, default=4)
        """
        super().__init__()
        self.layer = layer
        self.dev = self.layer.weight.device
        # Work on a private copy; the layer weight is only overwritten at the
        # very end of ``fasterquant``.
        self.w = layer.weight.data.clone()
        self.rows = self.w.shape[0]
        self.columns = self.w.shape[1]
        self.h = torch.zeros((self.columns, self.columns), device=self.dev)
        self.dXXT = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.quant_bits = quant_bits

    def add_batch(self, inp, out, native_inp):
        """Fold one calibration batch into ``h`` and ``dXXT``.

        Args:
            inp: Input seen by the layer in the current forward pass.
            out: Layer output (unused; kept for forward-hook compatibility).
            native_inp: Reference input for the same batch; must have the
                same shape as ``inp``.
        """
        if len(inp.shape) == 4:
            inp = inp[0, 0, :, :]
            native_inp = native_inp[0, 0, :, :]
        inp = inp.squeeze()
        native_inp = native_inp.squeeze()
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
            native_inp = native_inp.unsqueeze(0)
        batch = inp.shape[0]
        if len(inp.shape) == 3:
            inp = inp.reshape((-1, inp.shape[-1]))
            native_inp = native_inp.reshape((-1, native_inp.shape[-1]))
        inp = inp.t()
        native_inp = native_inp.t()

        # Running average: down-weight what has been accumulated so far.
        total = self.nsamples + batch
        self.h *= self.nsamples / total
        self.dXXT *= self.nsamples / total
        self.nsamples = total

        # Both operands are promoted to float32 *before* scaling so that the
        # two statistics are accumulated at the same precision.
        factor = math.sqrt(2 / self.nsamples)
        inp = factor * inp.float()
        native_inp = factor * native_inp.float()
        self.h += inp.matmul(inp.t())
        self.dXXT += (native_inp - inp).matmul(inp.t())

    def fasterquant(
        self,
        blocksize=128,
        percdamp=0.01,
        group_size=-1,
        actorder=True,
        sym=True,
    ):
        """Quantize the wrapped layer's weight in place.

        Scales/zero-points are computed per column group from the original
        weights *before* any reordering or error compensation (static groups).

        Args:
            blocksize: Number of columns processed per block.
            percdamp: Diagonal damping as a fraction of the mean Hessian
                diagonal; must lie in (0, 1). It is raised by 0.01 each time
                the Cholesky factorization fails.
            group_size: Columns per quantization group; ``-1`` means a single
                group spanning every column.
            actorder: If True, process columns in descending order of the
                Hessian diagonal.
            sym: Forwarded to ``compute_scales_with_zero``.

        Returns:
            ``(scale, zero, g_idx)``: per-group scales and zero-points
            concatenated along dim 1, and an int32 tensor mapping each
            column to its group index.

        Raises:
            ValueError: If ``group_size`` or ``percdamp`` is out of range.
            RuntimeError: If the Hessian contains NaN or cannot be factorized.
        """
        if group_size != -1 and group_size <= 0:
            raise ValueError(
                f"group_size must be -1 or a positive integer, got {group_size}"
            )
        if not 0 < percdamp < 1:
            raise ValueError(f"percdamp must be in (0, 1), got {percdamp}")
        # -1 is shorthand for "one group covering all input columns".
        eff_group = self.columns if group_size == -1 else group_size

        w_weight = self.w.float()

        tick = time.time()

        hessian = self.h
        if torch.isnan(hessian).any():
            print_info("[error] Hessian contains nan!")
            raise RuntimeError("Hessian contains nan!")
        del self.h
        # Columns that never received any input: neutralize them.
        dead = torch.diag(hessian) == 0
        hessian[dead, dead] = 1
        w_weight[:, dead] = 0
        self.dXXT[:, dead] = 0

        # Static groups: one (scale, zero) pair per group of original columns.
        # NOTE(review): each pair is assumed to broadcast against a
        # (rows, 1) column -- confirm against compute_scales_with_zero.
        scale = []
        zero = []
        for start in range(0, self.columns, eff_group):
            weight_scale, weight_zero = compute_scales_with_zero(
                w_weight[:, start : start + eff_group],
                bits=self.quant_bits,
                sym=sym,
            )
            scale.append(weight_scale)
            zero.append(weight_zero)

        if actorder:
            perm = torch.argsort(torch.diag(hessian), descending=True)
            w_weight = w_weight[:, perm]
            hessian = hessian[perm][:, perm]
            self.dXXT = self.dXXT[perm][:, perm]
            invperm = torch.argsort(perm)
            # Group of the *original* column sitting at each processing slot;
            # a single host transfer instead of one sync per column.
            col_groups = (perm // eff_group).tolist()
        else:
            col_groups = [col // eff_group for col in range(self.columns)]

        losses = torch.zeros_like(w_weight)
        q_weight = torch.zeros_like(w_weight)

        hinv = None
        diag = torch.arange(self.columns, device=self.dev)
        while 0 < percdamp < 1:
            try:
                # Damping accumulates on the diagonal across retries.
                damp = percdamp * torch.mean(torch.diag(hessian))
                hessian[diag, diag] += damp
                # Intermediates get their own names so that a failure in a
                # later step never leaves ``hessian`` half-processed.
                chol = torch.linalg.cholesky(hessian)
                inverse = torch.cholesky_inverse(chol)
                hinv = torch.linalg.cholesky(inverse, upper=True)
                break
            except torch._C._LinAlgError as e:
                print_info(e)
                print_info(f"Cholesky failed with percdamp={percdamp:.5f}")
                percdamp += 0.01
        if hinv is None:
            raise RuntimeError(
                "Cholesky factorization of the Hessian failed for every "
                "damping value below 1"
            )

        # Correction term built from dXXT; strictly upper-triangular part only.
        proj = ((self.dXXT @ hinv.T).triu(diagonal=1)) @ hinv
        del self.dXXT

        maxq = 2**self.quant_bits - 1

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            w1 = w_weight[:, i1:i2].clone()
            q1 = torch.zeros_like(w1)
            err1 = torch.zeros_like(w1)
            losses1 = torch.zeros_like(w1)
            hinv1 = hinv[i1:i2, i1:i2]
            proj1 = proj[i1:i2, i1:i2]

            for i in range(count):
                w = w1[:, i]
                d = hinv1[i, i]

                group = col_groups[i1 + i]
                weight_scale = scale[group]
                weight_zero = zero[group]

                # Quantize-dequantize the current column.
                q = torch.clamp(
                    torch.round(w.unsqueeze(1) / weight_scale) + weight_zero,
                    0,
                    maxq,
                )
                q = weight_scale * (q - weight_zero)
                q = q.flatten()
                q1[:, i] = q
                losses1[:, i] = (w - q) ** 2 / d**2

                # Propagate the quantization error (and the dXXT correction)
                # to the not-yet-quantized columns of this block.
                err = (w - q) / d
                w1[:, i:] -= err.unsqueeze(1).matmul(
                    hinv1[i, i:].unsqueeze(0)
                ) - w.unsqueeze(1).matmul(proj1[i, i:].unsqueeze(0))
                err1[:, i] = err

            q_weight[:, i1:i2] = q1
            losses[:, i1:i2] = losses1 / 2

            # Same propagation for every column after this block.
            w_weight[:, i2:] -= err1.matmul(hinv[i1:i2, i2:]) - w1.matmul(
                proj[i1:i2, i2:]
            )

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print_info(f" duration: {(time.time() - tick)}")
        print_info(
            f" avg loss: {torch.sum(losses).item() / max(self.nsamples, 1)}"
        )

        g_idx = torch.tensor(col_groups, dtype=torch.int32, device=q_weight.device)
        if actorder:
            # Undo the activation-order permutation.
            q_weight = q_weight[:, invperm]
            g_idx = g_idx[invperm]

        quantized = q_weight.reshape(self.layer.weight.shape).type_as(
            self.layer.weight.data
        )
        norm_loss = torch.norm(quantized - self.layer.weight.data)
        all_norm_loss = [norm_loss]

        print_info(" self.layer.weight: {}, {}".format(q_weight.shape, q_weight.sum()))
        print_info(f" norm loss: {list(map(get_tensor_item, all_norm_loss))}")

        self.layer.weight.data.copy_(quantized)

        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)

        # Drop the large work buffers before releasing cached device memory.
        del losses, q_weight, w_weight, hessian, hinv, proj, quantized
        del self.w
        torch.cuda.empty_cache()
        return scale, zero, g_idx

    def free(self):
        """Release every tensor held by the wrapper."""
        self.h = None
        self.dXXT = None
        self.w = None
        self.losses = None
        torch.cuda.empty_cache()
52 changes: 50 additions & 2 deletions angelslim/compressor/quant/modules/gptq/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .....utils import print_info
from ...modules.catcher import Catcher
from ...modules.helper_layer import GPTQQuantLinear
from .gptaq_module import GPTAQModule
from .gptq_module import GPTQModule

__all__ = ["GPTQ"]
Expand Down Expand Up @@ -51,6 +52,8 @@ def __init__(
self.dtype = next(iter(self.layers.parameters())).dtype
self.quantizers = {}
self.gptq = {}
self.quant_algo = self.model.quant_config.quant_algo
self.native_inp_caches = {}

@torch.no_grad()
def run(self, dataloader):
Expand Down Expand Up @@ -86,6 +89,8 @@ def run(self, dataloader):
torch.cuda.empty_cache()

outs = torch.zeros_like(inps)
if "gptaq" in self.quant_algo:
native_inps = inps.clone().detach()
# begin the gptq process
print_info("Ready.")

Expand All @@ -96,18 +101,61 @@ def run(self, dataloader):
subset = self._find_layers(layer)
print_info("subset:{}".format(subset))
self.gptq = {}
if "gptaq" in self.quant_algo:
self.native_inp_caches = {}
print_info("GPTQMoe start layer {}".format(i))
for name in subset:
if name in self.ignore_layers:
continue
self.gptq[name] = GPTQModule(subset[name], quant_bits=self.quant_bits)
if "gptaq" in self.quant_algo:
self.native_inp_caches[name] = []
self.gptq[name] = GPTAQModule(
subset[name], quant_bits=self.quant_bits
)
else:
self.gptq[name] = GPTQModule(
subset[name], quant_bits=self.quant_bits
)

def pre_process_fwd_hook(layer_name):
def tmp(_, inp, out):
self.native_inp_caches[layer_name] += [inp[0].data]
del inp, out

return tmp

def add_batch(layer_name):
def tmp(_, inp, out):
self.gptq[layer_name].add_batch(inp[0].data, out.data)
if "gptaq" in self.quant_algo:
native_inp = self.native_inp_caches[layer_name].pop(0)
self.gptq[layer_name].add_batch(
inp[0].data, out.data, native_inp
)
else:
self.gptq[layer_name].add_batch(inp[0].data, out.data)

return tmp

if "gptaq" in self.quant_algo:
native_handles = []
for name in self.native_inp_caches:
native_handles.append(
subset[name].register_forward_hook(pre_process_fwd_hook(name))
)

# begin native hook
for j in range(nsamples):
with torch.no_grad():
outs[j, :, :] = layer(
hidden_states=native_inps[j, :, :].unsqueeze(0),
**layer_kwargs,
)[0].squeeze(1)
native_inps = outs

print_info("Native HOOK Step{}".format(j))
for h in native_handles:
h.remove()

handles = []
for name in self.gptq:
handles.append(subset[name].register_forward_hook(add_batch(name)))
Expand Down
8 changes: 4 additions & 4 deletions angelslim/compressor/quant/ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, model, slim_config=None):
self.ptq_hook = PTQHook(self.quant_model)
self.ptq_hook.apply_hook()

if "gptq" in self.quant_algo:
if "gptq" in self.quant_algo or "gptaq" in self.quant_algo:
max_seq_length = self.quant_model.quant_config.max_seq_length
hidden_size = self.quant_model.quant_config.hidden_size
self.gptq = GPTQ(
Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(self, model, slim_config=None):
)

def calibrate(self, dataloader):
if "gptq" in self.quant_algo:
if "gptq" in self.quant_algo or "gptaq" in self.quant_algo:
self.gptq.run(dataloader)
elif "awq" in self.quant_algo:
self.awq.run(dataloader)
Expand All @@ -123,7 +123,7 @@ def convert(self):
Saves scales and inserts QDQ modules.
"""
print_info("Start convert model...")
if "gptq" in self.quant_algo:
if "gptq" in self.quant_algo or "gptaq" in self.quant_algo:
self.gptq.convert()
elif "awq" in self.quant_algo:
self.awq.convert()
Expand Down Expand Up @@ -166,7 +166,7 @@ def save(self, save_path: str):
)

print_info("Start save PTQ ckpt to: {}".format(save_path))
if "gptq" in self.quant_algo:
if "gptq" in self.quant_algo or "gptaq" in self.quant_algo:
self.gptq.save(save_path)
elif "awq" in self.quant_algo:
self.awq.save(save_path)
Expand Down
1 change: 1 addition & 0 deletions angelslim/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"int4_awq": default_compress_config.default_int4_awq_config(),
"int4_gptq": default_compress_config.default_int4_gptq_config(),
"w4a8_fp8": default_compress_config.default_w4a8_fp8_static_config(),
"int4_gptaq": default_compress_config.default_int4_gptaq_config(),
}


Expand Down
1 change: 1 addition & 0 deletions angelslim/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def model_forward(self, dataloader, **kwargs):
if (
"gptq" in self.quant_config.quant_algo
or "awq" in self.quant_config.quant_algo
or "gptaq" in self.quant_config.quant_algo
):
device = "cuda:0"
else:
Expand Down
Loading