Skip to content

Commit 2e830e4

Browse files
authored
feat(spinquant): add permutation & fix fuse layer norm for rms_norm (#282)
1 parent 43a322a commit 2e830e4

12 files changed

Lines changed: 575 additions & 99 deletions

File tree

angelslim/compressor/quant/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
dequantize_gemm,
2020
pack_weight_to_int8,
2121
pack_weight_to_int8_gpu,
22+
unpack_weight_omni,
2223
)
2324
from .quant_func import * # noqa: F401 F403
2425
from .sample_func import EMASampler, MultiStepSampler # noqa: F401

angelslim/compressor/quant/core/packing_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,25 @@ def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
4040
return iweights, izeros
4141

4242

43+
def unpack_weight_omni(qweight: torch.Tensor, save_bit: int = 4, pack_bit: int = 8):
    """Unpack signed ``save_bit``-wide integers stored in ``pack_bit``-wide words.

    Every element of ``qweight`` is read as ``pack_bit // save_bit`` slots,
    lowest bits first. Each slot is interpreted as a two's-complement number
    of width ``save_bit`` and the slots of one word are laid out next to each
    other along the last dimension.

    Args:
        qweight: Integer tensor of packed words with at least one dimension.
            It is converted to ``torch.int32`` before the bits are extracted.
        save_bit: Width in bits of each stored value (default 4).
        pack_bit: Number of bits of each word that hold stored values
            (default 8). Must be a multiple of ``save_bit``.

    Returns:
        A ``torch.int32`` tensor with the same leading dimensions as
        ``qweight`` and a last dimension of
        ``qweight.shape[-1] * (pack_bit // save_bit)``.

    Raises:
        AssertionError: If ``pack_bit`` is not divisible by ``save_bit``.
    """
    assert pack_bit % save_bit == 0, "pack_bit must be divisible by save_bit"
    mask = (1 << save_bit) - 1  # e.g. 0x0F for 4-bit
    sign_bit = 1 << (save_bit - 1)  # e.g. 0x08 for 4-bit
    shifts = torch.arange(0, pack_bit, save_bit, device=qweight.device)

    # Add a trailing axis so the shift amounts broadcast against every word,
    # whatever the rank of the input is.
    words = qweight.to(torch.int32).unsqueeze(-1)
    slots = torch.bitwise_right_shift(words, shifts).to(torch.int32)

    # Every slot, including the topmost one, is masked down to save_bit bits
    # and then sign-extended: a set sign bit means the value is negative, so
    # 2 ** save_bit is subtracted from it.
    slots = slots & mask
    slots = slots - ((slots & sign_bit) << 1)

    # Spell out the flattened width instead of using -1 so that tensors with
    # zero elements can still be reshaped.
    unpacked_width = qweight.shape[-1] * shifts.numel()
    return slots.reshape(*qweight.shape[:-1], unpacked_width)
60+
61+
4362
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
4463
reverse_order_tensor = torch.arange(
4564
iweights.shape[-1],

angelslim/compressor/quant/modules/helper_layer.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
reduce_block_padding,
3535
tensor_quant_dequant_fp8,
3636
tensor_quant_dequant_int,
37+
unpack_weight_omni,
3738
)
3839

3940

@@ -547,19 +548,21 @@ def __init__(
547548
super().__init__()
548549
self.quant_algo = quant_algo
549550
weight_scale = weight_scale.to(weight.device)
551+
self.group_size = group_size
550552
if "fp8" in quant_algo:
551553
if "w4a8" in self.quant_algo:
552554
max_value_group_wise = weight_scale.clone()
555+
# weight(bf16) -> fp8 -> int4
556+
# dweight(int4) -> fp8 -> weight(bf16)
553557
tensor_wise_scale = max_value_group_wise.max() / 448.0
554-
quant_weight, _ = quantize_weight_per_tensor_fp8(weight, tensor_wise_scale)
555-
new_weight_bf16 = quant_weight.to(torch.bfloat16) * tensor_wise_scale
556-
558+
new_weight_bf16 = weight
557559
new_weight_bf16_qdq = fake_quant_dequant(
558560
new_weight_bf16, method="groupwise", bits=4, group_size=group_size
559561
)
560562
quant_weight, _ = quantize_weight_int(
561563
new_weight_bf16_qdq, max_value_group_wise, bits=4
562564
)
565+
563566
quant_weight = pack_weight_to_int8(quant_weight)
564567
del new_weight_bf16_qdq, new_weight_bf16
565568
self.weight_scale_int4 = torch.nn.Parameter(
@@ -600,11 +603,30 @@ def forward(self, x):
600603
raise ValueError(f"Unsupported quantization algorithm: {self.quant_algo}")
601604

602605
if "fp8" in self.quant_algo:
606+
if "w4a8" in self.quant_algo:
607+
# unpack, save as int32
608+
weight = self.qweight.to(qinput.device)
609+
weight = unpack_weight_omni(weight, save_bit=4, pack_bit=8)
610+
weight_scale = self.weight_scale.to(qinput.device)
611+
612+
scale = (
613+
self.weight_scale_int4.float()
614+
.repeat_interleave(self.group_size, dim=-1)
615+
.to(qinput.device)
616+
) # (out,in)
617+
# dequant to bf16
618+
weight = weight * scale
619+
# quant to fp8
620+
weight, _ = quantize_weight_per_tensor_fp8(weight, weight_scale)
621+
# to fp8
622+
else:
623+
weight = self.weight.to(qinput.device)
624+
weight_scale = self.weight_scale.to(qinput.device)
603625
output = gemm_fp8(
604626
act=qinput,
605627
act_scale=self.input_scale,
606-
weight=self.weight,
607-
weight_scale=self.weight_scale,
628+
weight=weight,
629+
weight_scale=weight_scale,
608630
bias=self.bias,
609631
out_dtype=x.dtype,
610632
)

angelslim/compressor/quant/ptq.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ def convert(self):
162162
if "smooth" in self.quant_helpers:
163163
self.smooth.convert()
164164
self._convert()
165+
166+
self.transform_runner.convert()
165167
print_info("convert model done.")
166168

167169
def save(self, save_path: str):

angelslim/compressor/transform/factory.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ def run(self):
3131
def save(self):
    """No-op: performs no work and returns ``None``.

    NOTE(review): presumably a default hook that concrete transforms
    override; confirm against the subclasses.
    """
    return None
3333

34+
def convert(self):
    """No-op: performs no work and returns ``None``.

    NOTE(review): presumably a default hook that concrete transforms
    override; confirm against the subclasses.
    """
    return None
36+
3437

3538
class TransformFactory:
3639
"""Factory for creating TransformBase instances from config.

angelslim/compressor/transform/rotation/hadamard_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def matmul_hadUt(X):
101101
def random_hadamard_matrix(size, device):
    """Build a randomized Hadamard matrix and move it to ``device``.

    A vector of ``size`` random signs (each -1 or +1, float32) is placed on
    the diagonal of a square matrix, which is then passed through
    ``matmul_hadU``; the result is transferred to ``device``.
    """
    # See https://cornell-relaxml.github.io/quip-sharp/ ,
    # Section "Randomized Hadamard Transformation"
    coin_flips = torch.randint(low=0, high=2, size=(size,))
    # Map {0, 1} -> {-1, +1}.
    signs = coin_flips.to(torch.float32) * 2 - 1
    sign_diagonal = torch.diag(signs)
    return matmul_hadU(sign_diagonal).to(device)

0 commit comments

Comments
 (0)