Skip to content

Commit 7fa71bc

Browse files
authored
add diffusion fp8-per-token-sgl (#169)
1 parent c3b203e commit 7fa71bc

6 files changed

Lines changed: 138 additions & 7 deletions

File tree

angelslim/compressor/diffusion/quant/modules/linear.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
fp8_per_block_quant,
2020
fp8_per_tensor_quant,
2121
fp8_per_token_group_quant,
22+
fp8_per_token_quant_sgl,
2223
fp8_weight_only_gemm,
2324
)
2425

@@ -61,6 +62,10 @@ def forward(self, x):
6162
origin_shape = None
6263
x_2d = x.view(-1, x.shape[-1])
6364
qinput, x_scale = fp8_per_token_group_quant(x_2d, x_2d.shape[-1])
65+
elif self.quant_type == "fp8-per-token-sgl" and self.native_fp8_support:
66+
origin_shape = x.shape
67+
x_2d = x.view(-1, x.shape[-1])
68+
qinput, x_scale = fp8_per_token_quant_sgl(x_2d)
6469
elif self.quant_type == "fp8-per-block" and self.native_fp8_support:
6570
origin_shape = x.shape
6671
x = x.view(-1, x.shape[-1])
@@ -85,7 +90,11 @@ def forward(self, x):
8590
origin_shape=origin_shape,
8691
)
8792

88-
if self.quant_type == "fp8-per-token" and x.dim() == 3 and output.dim() == 2:
93+
if (
94+
self.quant_type in ["fp8-per-token", "fp8-per-token-sgl"]
95+
and x.dim() == 3
96+
and output.dim() == 2
97+
):
8998
output = output.unsqueeze(0)
9099

91100
return output

angelslim/compressor/diffusion/quant/ptq.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@
2323
from .modules import FP8DynamicLinear, FP8WeightOnlyLinear
2424
from .quant_func import (
2525
fp8_per_block_quant,
26+
fp8_per_channel_quant,
2627
fp8_per_tensor_quant,
2728
fp8_per_token_group_quant,
2829
)
2930
from .utils import (
3031
QuantType,
3132
_ensure_deep_gemm,
33+
_ensure_sgl_kernel,
3234
cleanup_memory,
3335
load_fp8_scales,
3436
load_quantized_model,
@@ -106,6 +108,10 @@ def _quantize_linear_weight(
106108
linear.weight, linear.weight.shape[-1]
107109
)
108110
weight_scale = weight_scale.t()
111+
elif self.quant_type == QuantType.FP8_PER_TOKEN_SGL:
112+
if self.native_fp8_support:
113+
_ensure_sgl_kernel()
114+
quant_weight, weight_scale = fp8_per_channel_quant(linear.weight)
109115
elif self.quant_type == QuantType.FP8_PER_BLOCK:
110116
if self.native_fp8_support:
111117
_ensure_deep_gemm()

angelslim/compressor/diffusion/quant/quant_func.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
fp8_per_block_quant_triton,
2222
fp8_per_token_group_quant_triton,
2323
)
24-
from .utils import QuantType, _ensure_deep_gemm
24+
from .utils import QuantType, _ensure_deep_gemm, _ensure_sgl_kernel
2525

2626
FP8_MAX = float(torch.finfo(torch.float8_e4m3fn).max)
2727
FP8_MIN = float(torch.finfo(torch.float8_e4m3fn).min)
@@ -87,6 +87,41 @@ def fp8_per_token_group_quant(
8787
)
8888

8989

90+
def fp8_per_channel_quant(weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
91+
"""
92+
Per-channel FP8 weight quantization (E4M3 format)
93+
94+
Args:
95+
weight: Original weight tensor with shape [out_features, in_features]
96+
97+
Returns:
98+
weight_quant: Quantized weight [out_features, in_features], dtype=float8_e4m3fn
99+
weight_scale: Scale factors [out_features, 1], dtype=float32
100+
"""
101+
abs_max = torch.abs(weight).amax(dim=1, keepdim=True) # [out_features, 1]
102+
103+
weight_scale = abs_max / FP8_MAX
104+
weight_scale = torch.clamp(weight_scale, min=1e-12)
105+
106+
weight_scaled = (weight / weight_scale).clamp(min=FP8_MIN, max=FP8_MAX)
107+
weight_quant = weight_scaled.to(torch.float8_e4m3fn)
108+
109+
return weight_quant, weight_scale.float()
110+
111+
112+
def fp8_per_token_quant_sgl(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    FP8 quantization of a 2-D activation tensor through sgl_kernel.

    Args:
        x: Input tensor; it is unpacked as ``(m, k)``, so it must be 2-D.

    Returns:
        input_tensor_quant: Buffer of shape [m, k], dtype=float8_e4m3fn
        input_tensor_scale: Buffer of shape [m, 1], dtype=float32

        Both buffers are passed to ``sgl_kernel.sgl_per_token_quant_fp8`` as
        output arguments (presumably filled in place with the quantized values
        and one scale per row - confirm against the sgl_kernel API).

    Raises:
        ImportError: Propagated from ``_ensure_sgl_kernel`` when sgl_kernel is
            not installed.
    """
    m, k = x.shape
    _sgl_kernel = _ensure_sgl_kernel()

    # Allocate on the input's own device. A bare "cuda" resolves to the current
    # default GPU, which may differ from x.device in multi-GPU setups.
    input_tensor_quant = torch.empty(
        (m, k), dtype=torch.float8_e4m3fn, device=x.device, requires_grad=False
    )
    input_tensor_scale = torch.empty(
        (m, 1), dtype=torch.float32, device=x.device, requires_grad=False
    )

    _sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
    return input_tensor_quant, input_tensor_scale
123+
124+
90125
# pure torch implementation of block-wise FP8 quantization on cpu
91126
def fp8_per_block_quant_torch(
92127
x: torch.Tensor, block_size: int = 128
@@ -260,6 +295,35 @@ def fp8_weight_only_gemm(A, B, B_scale, bias, out_dtype):
260295
return output
261296

262297

298+
def fp8_gemm_sgl_token(A, A_scale, B, B_scale, out_dtype, bias):
    """GEMM function for FP8 per-token-sgl quantization using sgl-kernel.

    Args:
        A: Input activation tensor
        A_scale: Scale tensor for input activations
        B: Weight tensor; it is passed to the kernel transposed (``B.t()``)
        B_scale: Scale tensor for weights; cast to float32 before the call
        out_dtype: Output data type.
        bias: Optional bias tensor

    Returns:
        torch.Tensor: Result of the GEMM operation.

    Raises:
        ImportError: Propagated from ``_ensure_sgl_kernel`` when sgl_kernel is
            not installed.
    """
    _sgl_kernel = _ensure_sgl_kernel()

    # The kernel call's return value is the result, so no output buffer is
    # pre-allocated here (the previous one was overwritten and never used).
    return _sgl_kernel.fp8_scaled_mm(
        A,
        B.t(),
        A_scale,
        B_scale.float(),
        out_dtype,
        bias=bias,
    )
325+
326+
263327
def fp8_gemm(
264328
A: torch.Tensor,
265329
A_scale: torch.Tensor,
@@ -300,6 +364,9 @@ def fp8_gemm(
300364
if quant_type in (QuantType.FP8_PER_TENSOR, QuantType.FP8_PER_TOKEN):
301365
# Use torch native fp8 GEMM for per-tensor and per-token fp8 quantization
302366
return fp8_gemm_torch_tensor_token(A, A_scale, B, B_scale, out_dtype, bias)
367+
elif quant_type == QuantType.FP8_PER_TOKEN_SGL:
368+
# Use sgl-kernel for per-token-sgl fp8 quantization
369+
return fp8_gemm_sgl_token(A, A_scale, B, B_scale, out_dtype, bias)
303370
elif quant_type == QuantType.FP8_PER_BLOCK:
304371
# Use deepgemm accelerated blockwise fp8 GEMM
305372
return fp8_gemm_deepgemm_block(
@@ -324,7 +391,8 @@ def fp8_gemm(
324391
f"\n native_fp8_support={native_fp8_support}.\n"
325392
"Supported combinations:\n"
326393
" - native_fp8_support=True, "
327-
"quant_type in [fp8-per-tensor, fp8-per-token, fp8-per-block]\n"
394+
"quant_type in [fp8-per-tensor, fp8-per-token,"
395+
" fp8-per-block, fp8-per-token-sgl]\n"
328396
" - native_fp8_support=False, "
329397
"quant_type in [fp8-per-tensor, fp8-per-block]"
330398
)

angelslim/compressor/diffusion/quant/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
QuantType,
1818
_compile_pattern,
1919
_ensure_deep_gemm,
20+
_ensure_sgl_kernel,
2021
cleanup_memory,
2122
replace_module,
2223
should_quantize_layer,
@@ -32,4 +33,5 @@
3233
"cleanup_memory",
3334
"replace_module",
3435
"should_quantize_layer",
36+
"_ensure_sgl_kernel",
3537
]

angelslim/compressor/diffusion/quant/utils/utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"should_quantize_layer",
2626
"_compile_pattern",
2727
"_ensure_deep_gemm",
28+
"_ensure_sgl_kernel",
2829
"QuantType",
2930
]
3031

@@ -34,11 +35,13 @@ class QuantType:
3435
FP8_PER_TOKEN = "fp8-per-token"
3536
FP8_PER_BLOCK = "fp8-per-block"
3637
FP8_PER_TENSOR_WEIGHT_ONLY = "fp8-per-tensor-weight-only"
38+
FP8_PER_TOKEN_SGL = "fp8-per-token-sgl"
3739
VALID_TYPES = [
3840
FP8_PER_TENSOR,
3941
FP8_PER_TOKEN,
4042
FP8_PER_BLOCK,
4143
FP8_PER_TENSOR_WEIGHT_ONLY,
44+
FP8_PER_TOKEN_SGL,
4245
]
4346

4447
@classmethod
@@ -171,3 +174,28 @@ def _ensure_deep_gemm():
171174
"native_fp8_support, but was not found. Please install deep_gemm first."
172175
)
173176
) from e
177+
178+
179+
_sgl_kernel_cached = None
180+
181+
182+
def _ensure_sgl_kernel():
    """
    Import sgl_kernel on first use and return the module.

    The imported module is stored in the module-level ``_sgl_kernel_cached``
    variable, so later calls return it without running the import again.

    Raises:
        ImportError: If sgl_kernel cannot be imported; the original error is
            chained as the cause.
    """
    global _sgl_kernel_cached

    if _sgl_kernel_cached is None:
        try:
            import sgl_kernel as _module
        except ImportError as e:
            message = (
                "sgl_kernel is required for 'fp8-per-token-sgl' quantization with "
                "native_fp8_support, but was not found. Please install sgl_kernel first"
            )
            raise ImportError(message) from e
        _sgl_kernel_cached = _module

    return _sgl_kernel_cached

docs/source/features/diffusion/quantization.md

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,35 @@ AngelSlim 支持以下五种 FP8 量化策略:
1010
- **fp8-per-tensor-weight-only**:仅对权重量化(权重:FP8,激活仍为 BF16/FP16),适合对精度有更高要求的场景
1111
- **fp8-per-block**:支持 per-block 量化,适用于 NVIDIA Hopper (SM90+) 架构,block_size目前只支持128
1212
- **fp8-per-token**:精细的 per-token 量化,对多样输入有更强适应性
13+
- **fp8-per-token-sgl**:基于 SGL kernel 的 per-token 量化,使用优化的 CUDA kernel 实现更高效的 per-token 量化和矩阵乘法运算
14+
15+
## 可选依赖安装
16+
### deep_gemm(用于 fp8-per-block)
17+
18+
`fp8-per-block` 量化在启用 `native_fp8_support` 时需要安装 `deep_gemm`
19+
20+
```shell
21+
git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git
22+
cd DeepGEMM
23+
./develop.sh
24+
./install.sh
25+
```
26+
27+
### sgl_kernel(用于 fp8-per-token-sgl)
28+
29+
`fp8-per-token-sgl` 量化需要安装 `sgl_kernel`
30+
31+
```shell
32+
pip install sgl-kernel==0.3.18
33+
```
1334

1435
## 配置
1536

1637
DynamicDiTQuantizer 类提供灵活的配置选项,您可以通过以下参数自定义量化行为:
1738

1839
### 构造函数参数
1940

20-
- `quant_type`(str):量化类型,可选值 "fp8-per-tensor"、"fp8-per-tensor-weight-only"、"fp8-per-block"、"fp8-per-token"
41+
- `quant_type`(str):量化类型,可选值 "fp8-per-tensor"、"fp8-per-tensor-weight-only"、"fp8-per-block"、"fp8-per-token"、"fp8-per-token-sgl"
2142
- `include_patterns`(List[str|re.Pattern], 可选):指定需要量化的层名称模式,支持字符串或正则表达式
2243
- `exclude_patterns`(List[str|re.Pattern], 可选):指定需要排除的层名称模式,支持字符串或正则表达式
2344
- `layer_filter`(Callable, 可选):自定义层筛选函数(高级自定义场景专用)
@@ -166,6 +187,3 @@ quantizer.export_quantized_weight(pipe.transformer, save_path="/path/to/save/qua
166187
- `fp8_scales.safetensors`:FP8 缩放因子文件
167188

168189
导出后可通过上述"加载预量化模型和缩放因子"的方式加载使用。
169-
170-
171-

0 commit comments

Comments
 (0)