Skip to content

Commit 2a0d11c

Browse files
author
gavinlee
committed
feat:add SpinQuant offline rotation and integrate with PTQ pipeline
- Add angelslim/compressor/transform/ package: - TransformBase abstract class and TransformFactory with @register decorator - SpinQuant implementation: R1/R2/R4 offline Hadamard rotation fused into weights - SpinQuantMapping for LLaMA/Qwen layer name resolution - fuse_ln_linear, center_embeddings utilities; hadamard_utils - Integrate transform into PTQ: TransformFactory.create() + run() is called before quantization in PTQ.__init__() - Extend config_parser: add TransformConfig, FullConfig.transform_config, SlimConfigParser support for optional transform: YAML section - Add Engine.prepare_compressor(transform_config=) passthrough and lm_eval() - Add tools/run_transform_offline.py for standalone transform + save - Add configs/qwen3/spinquant/ with SpinQuant + fp8_static / int4_awq examples
1 parent 6c91c69 commit 2a0d11c

19 files changed

Lines changed: 98374 additions & 68 deletions

angelslim/compressor/quant/ptq.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from ...utils import find_parent_layer_and_sub_name, print_info
2424
from ..compressor_factory import CompressorFactory
25+
from ..transform import TransformFactory
2526
from .core import PTQHook
2627
from .modules import AWQ, FP8, GPTQ, INT8, NVFP4, W4A8INT8, LeptoFP8, SmoothQuant
2728

@@ -36,14 +37,24 @@ def __init__(self, model, slim_config=None):
3637
model(nn.Moudle, required): the model to be quant.
3738
slim_config(dict, required): the configuration for quantization.
3839
- compress_config: the configuration for compression.
40+
- transform_config: the configuration for transform.
3941
- global_config: the global configuration for the model.
4042
"""
4143
self.quant_model = model
4244
# init ptq config of model
4345
self.quant_model.init_ptq(slim_config)
4446
self.absolute_model_path = slim_config["global_config"].absolute_model_path
4547
self.quant_algo = self.quant_model.quant_config.quant_algo
48+
49+
# init transform
50+
# TODO(gavinlee) will be deprecated, and move to transform, now only for smoothquant
4651
self.quant_helpers = self.quant_model.quant_config.quant_helpers
52+
53+
# create transform, for example, smoothquant
54+
self.trasform_runner = TransformFactory.create(self.quant_model, slim_config)
55+
# trasform first, then run quantization
56+
self.trasform_runner.run()
57+
4758
if "fp8" in self.quant_algo or "int8" in self.quant_algo or "nvfp4" in self.quant_algo:
4859
# Add ptq observer hook
4960
self.ptq_hook = PTQHook(self.quant_model)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .base import TransformBase
2+
from .factory import TransformFactory
3+
from .rotation.spin import SpinQuant
4+
5+
__all__ = ["TransformBase", "TransformFactory", "SpinQuant"]
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from abc import ABC, abstractmethod
16+
17+
__all__ = ["TransformBase"]
18+
19+
20+
class TransformBase(ABC):
21+
"""Abstract base class for model weight transforms (e.g. SpinQuant).
22+
23+
Subclasses must implement `run()`. The lifecycle is:
24+
1. TransformFactory.create(quant_model, quant_config) -> TransformBase
25+
2. transform.run() - apply transform (PTQ: fuse into weights)
26+
3. transform.convert() - fuse hooks into weights after QAT training (optional)
27+
4. transform.save() - save transformed model (optional)
28+
"""
29+
30+
def __init__(self, quant_model, quant_config):
31+
self.quant_model = quant_model
32+
self.config = quant_config
33+
34+
@abstractmethod
35+
def run(self):
36+
"""Apply the transform to the model weights."""
37+
38+
def convert(self, **kwargs):
39+
"""Fuse online rotation hooks into weights after QAT training.
40+
41+
Override in subclasses that support QAT mode.
42+
"""
43+
raise NotImplementedError(f"{type(self).__name__} does not implement convert()")
44+
45+
def save(self, save_path: str = None):
46+
"""Save the transformed model.
47+
48+
Override in subclasses to implement actual saving logic.
49+
"""
50+
raise NotImplementedError(f"{type(self).__name__} does not implement save()")
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .base import TransformBase
16+
17+
__all__ = ["TransformFactory"]
18+
19+
20+
class _NoOpTransform(TransformBase):
21+
"""No-op transform returned when slim_config has no transform_config."""
22+
23+
def __init__(self, quant_model, slim_config=None):
24+
# slim_config may be a dict (PTQ path), skip TransformBase.__init__ attribute assignment
25+
self.quant_model = quant_model
26+
self.config = slim_config
27+
28+
def run(self):
29+
pass
30+
31+
32+
class TransformFactory:
33+
"""Factory for creating TransformBase instances from config.
34+
35+
Usage
36+
-----
37+
transform = TransformFactory.create(slim_model, slim_config)
38+
transform.run()
39+
40+
The transform name is read from ``slim_config.transform_config["name"]``,
41+
which corresponds to the ``transform.name`` field in the YAML config:
42+
43+
transform:
44+
name: SpinQuant
45+
spin_config: ...
46+
47+
Registering a new transform
48+
---------------------------
49+
@TransformFactory.register("MyTransform")
50+
class MyTransform(TransformBase):
51+
...
52+
"""
53+
54+
_registry: dict[str, type[TransformBase]] = {}
55+
56+
@classmethod
57+
def create(cls, quant_model, slim_config) -> TransformBase:
58+
"""Instantiate a transform from slim_config.
59+
60+
Args:
61+
quant_model: The wrapped slim model.
62+
slim_config: Config object with a ``transform_config`` dict containing ``"name"``.
63+
64+
Returns:
65+
An unrun TransformBase instance. Call ``.run()`` to apply the transform.
66+
67+
Raises:
68+
ValueError: If transform name is missing or not registered.
69+
"""
70+
# slim_config may be a dict (PTQ path) or an object with attributes (transform path)
71+
if isinstance(slim_config, dict):
72+
transform_config = slim_config.get("transform_config")
73+
else:
74+
transform_config = getattr(slim_config, "transform_config", None)
75+
76+
if not transform_config:
77+
return _NoOpTransform(quant_model, slim_config)
78+
79+
name = (
80+
transform_config.get("name")
81+
if isinstance(transform_config, dict)
82+
else getattr(transform_config, "name", None)
83+
)
84+
if not name:
85+
return _NoOpTransform(quant_model, slim_config)
86+
87+
if name not in cls._registry:
88+
available = list(cls._registry.keys())
89+
raise ValueError(f"Unknown transform '{name}'. Available: {available}")
90+
91+
return cls._registry[name](quant_model, slim_config)
92+
93+
@classmethod
94+
def register(cls, name: str):
95+
"""Decorator to register a TransformBase subclass under the given name.
96+
97+
Args:
98+
name: The string key used in YAML ``transform.name``.
99+
100+
Example:
101+
@TransformFactory.register("MyTransform")
102+
class MyTransform(TransformBase):
103+
...
104+
"""
105+
106+
def decorator(cls_):
107+
if not issubclass(cls_, TransformBase):
108+
raise TypeError(f"{cls_.__name__} must be a subclass of TransformBase")
109+
cls._registry[name] = cls_
110+
return cls_
111+
112+
return decorator
113+
114+
@classmethod
115+
def list_transforms(cls) -> list[str]:
116+
"""Return names of all registered transforms."""
117+
return list(cls._registry.keys())
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .spin import SpinQuant
2+
3+
__all__ = ["SpinQuant"]
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# coding=utf-8
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# This code is based on QuaRot(https://github.com/spcl/QuaRot/tree/main/quarot).
9+
# Licensed under Apache License 2.0.
10+
11+
import typing
12+
13+
import torch
14+
15+
16+
@torch.no_grad()
17+
def center_embeddings(embedding: torch.nn.Module):
18+
"""
19+
Shift each embedding to have a mean of zero
20+
21+
:param embedding: embedding module containing embeddings to center
22+
"""
23+
if not hasattr(embedding, "weight"):
24+
raise ValueError(f"Cannot fuse norm of type {type(embedding)}")
25+
26+
weight_dtype = embedding.weight.dtype
27+
weight = embedding.weight.to(torch.float64)
28+
new_weight = weight - weight.mean(dim=-1, keepdim=True)
29+
new_weight = new_weight.to(weight_dtype)
30+
embedding.weight.data = new_weight
31+
32+
33+
# [TODO] check this function correct or not
34+
@torch.no_grad()
35+
def bake_mean_into_linear(linear: torch.nn.Linear) -> None:
36+
"""
37+
This function takes a linear layer and subtracts the means from the
38+
weights and biases. This will result in the linear layer performing
39+
the mean substitution which is usually done inside layernorm.
40+
"""
41+
linear_dtype = linear.weight.dtype
42+
W_ = linear.weight.data.double()
43+
linear.weight.data = W_ - W_.mean(dim=-2, keepdim=True)
44+
linear.weight.data = linear.weight.data.to(linear_dtype)
45+
if linear.bias is not None:
46+
b_ = linear.bias.data.double()
47+
linear.bias.data = b_ - b_.mean()
48+
linear.bias.data = linear.bias.data.to(linear_dtype)
49+
50+
51+
@torch.no_grad()
52+
def fuse_ln_linear(
53+
layernorm: torch.nn.Module, linear_layers: typing.Iterable[torch.nn.Linear]
54+
) -> None:
55+
"""
56+
fuse the linear operations in Layernorm into the adjacent linear blocks.
57+
"""
58+
for linear in linear_layers:
59+
linear_dtype = linear.weight.dtype
60+
61+
# Calculating new weight and bias
62+
W_ = linear.weight.data.double()
63+
linear.weight.data = (W_ * layernorm.weight.double()).to(linear_dtype)
64+
65+
if hasattr(layernorm, "bias"):
66+
if linear.bias is None:
67+
linear.bias = torch.nn.Parameter(
68+
torch.zeros(linear.out_features, dtype=torch.float64)
69+
)
70+
linear.bias.data = linear.bias.data.double() + torch.matmul(
71+
W_, layernorm.bias.double()
72+
)
73+
linear.bias.data = linear.bias.data.to(linear_dtype)
74+
75+
if hasattr(layernorm, "bias"):
76+
layernorm.bias.data = torch.zeros_like(layernorm.bias.data)
77+
layernorm.weight.data = torch.ones_like(layernorm.weight.data)

0 commit comments

Comments
 (0)