Skip to content

Commit 1a5efa0

Browse files
GGgary666garygugong
andauthored
init commit to refactor dit quant (#90)
Co-authored-by: garygugong <garygugong@tencent.com>
1 parent 06a687d commit 1a5efa0

15 files changed

Lines changed: 1259 additions & 0 deletions

File tree

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# AngelSlim Diffusion Model Compression
2+
3+
AngelSlim offers flexible and efficient tools for compressing Diffusion Transformer (DiT) diffusion models. The quantization utilities are modular and easy to integrate into custom inference pipelines.
4+
5+
## Quick Start: FP8 Quantization for Diffusion Models
6+
7+
```python
8+
import torch
9+
from diffusers import FluxPipeline
10+
from angelslim.compressor.diffusion import DynamicDiTQuantizer
11+
12+
# Load DiT pipeline with bfloat16 to reduce memory usage
13+
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
14+
15+
# Supported quantization types: "fp8-per-tensor", "fp8-per-block", "fp8-per-token"
16+
# If you want to use "fp8-per-block" + DeepGEMM on NVIDIA Hopper (SM90+) devices,
17+
# please refer to https://github.com/deepseek-ai/DeepGEMM for installation instructions.
18+
quantizer = DynamicDiTQuantizer(quant_type="fp8-per-tensor")
19+
quantizer.quantize(pipe.transformer)
20+
21+
pipe.to("cuda")
22+
23+
# Run pipeline with FP8-quantized transformer
24+
image = pipe(
25+
"A cat holding a sign that says hello world",
26+
guidance_scale=0.0,
27+
num_inference_steps=4,
28+
max_sequence_length=256,
29+
generator=torch.Generator("cuda").manual_seed(0)
30+
).images[0]
31+
image.save("flux-schnell_fp8_per_tensor.png")
32+
```
33+
34+
## Customizable Quantization Layer Selection
35+
36+
AngelSlim provides fine-grained control over which layers are quantized. You can specify inclusion and exclusion patterns as substrings or regular expressions.
37+
38+
```python
39+
from angelslim.compressor.diffusion import DynamicDiTQuantizer
40+
41+
# Option 1: Default filtering (quantizes common linear layers)
42+
quantizer = DynamicDiTQuantizer(quant_type="fp8-per-tensor")
43+
44+
# Option 2: String-based include/exclude patterns
45+
quantizer = DynamicDiTQuantizer(
46+
quant_type="fp8-per-tensor",
47+
include_patterns=["linear", "attention"],
48+
exclude_patterns=["embed", "norm"]
49+
)
50+
51+
# Option 3: Regex pattern matching (auto-detected)
52+
quantizer = DynamicDiTQuantizer(
53+
quant_type="fp8-per-tensor",
54+
include_patterns=[r".*\.linear\d+", r".*\.attn.*"],
55+
exclude_patterns=[r".*embed.*"]
56+
)
57+
58+
# Option 4: Mix of strings and regex for flexible rules
59+
quantizer = DynamicDiTQuantizer(
60+
quant_type="fp8-per-tensor",
61+
include_patterns=["linear", r".*\.attn.*"],
62+
exclude_patterns=["embed", r".*norm.*"]
63+
)
64+
```
65+
66+
For more details on customizing quantization behavior, see the API documentation.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .quant import * # noqa: F401 F403
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .fp8_gemm import fp8_gemm_triton_block
16+
17+
__all__ = ["fp8_gemm_triton_block"]
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import torch
16+
import triton
17+
import triton.language as tl
18+
19+
# modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
20+
fp8_gemm_configs = [
21+
triton.Config(
22+
{"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128},
23+
num_stages=num_stages,
24+
num_warps=8,
25+
)
26+
for block_m in [16, 32, 64]
27+
for block_n in [32, 64, 128]
28+
for num_stages in [3, 4, 5, 6]
29+
]
30+
31+
32+
@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
33+
@triton.jit
34+
def _fp8_gemm_triton_block_kernel(
35+
a_ptr,
36+
b_ptr,
37+
c_ptr,
38+
a_s_ptr,
39+
b_s_ptr,
40+
M,
41+
N: tl.constexpr,
42+
K: tl.constexpr,
43+
BLOCK_SIZE_M: tl.constexpr,
44+
BLOCK_SIZE_N: tl.constexpr,
45+
BLOCK_SIZE_K: tl.constexpr,
46+
):
47+
"""
48+
Performs a matrix multiplication operation on FP8 matrices with scaling factors.
49+
"""
50+
pid_m = tl.program_id(axis=0)
51+
pid_n = tl.program_id(axis=1)
52+
k = tl.cdiv(K, BLOCK_SIZE_K)
53+
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
54+
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
55+
offs_k = tl.arange(0, BLOCK_SIZE_K)
56+
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
57+
b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
58+
a_s_ptrs = a_s_ptr + offs_m * k
59+
b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
60+
61+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
62+
for i in range(k):
63+
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
64+
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
65+
a_s = tl.load(a_s_ptrs)
66+
b_s = tl.load(b_s_ptrs)
67+
68+
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
69+
a_ptrs += BLOCK_SIZE_K
70+
b_ptrs += BLOCK_SIZE_K
71+
a_s_ptrs += 1
72+
b_s_ptrs += 1
73+
c = accumulator.to(c_ptr.dtype.element_ty)
74+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
75+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
76+
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
77+
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
78+
tl.store(c_ptrs, c, mask=mask)
79+
80+
81+
# triton fp8 gemm for fp8 per-block weight & fp8 per-group activation
82+
# modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
83+
def fp8_gemm_triton_block(
84+
a: torch.Tensor,
85+
a_s: torch.Tensor,
86+
b: torch.Tensor,
87+
b_s: torch.Tensor,
88+
out_dtype=torch.bfloat16,
89+
bias=None,
90+
) -> torch.Tensor:
91+
"""
92+
Perform a matrix multiplication using FP8 precision.
93+
"""
94+
assert a.is_contiguous() and b.is_contiguous()
95+
assert a_s.is_contiguous() and b_s.is_contiguous()
96+
K = a.size(-1)
97+
M = a.numel() // K
98+
N = b.size(0)
99+
c = a.new_empty(*a.size()[:-1], N, dtype=out_dtype)
100+
101+
def grid(meta):
102+
return (
103+
triton.cdiv(M, meta["BLOCK_SIZE_M"]),
104+
triton.cdiv(N, meta["BLOCK_SIZE_N"]),
105+
)
106+
107+
_fp8_gemm_triton_block_kernel[grid](a, b, c, a_s, b_s, M, N, K)
108+
109+
if bias is not None:
110+
c += bias
111+
112+
return c
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .fp8_per_block import fp8_per_block_quant_triton
16+
from .fp8_per_token_group import fp8_per_token_group_quant_triton
17+
18+
__all__ = ["fp8_per_token_group_quant_triton", "fp8_per_block_quant_triton"]
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Tuple
16+
17+
import torch
18+
import triton
19+
import triton.language as tl
20+
21+
22+
# https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
23+
@triton.jit
24+
def _fp8_per_block_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
25+
"""Quantizes FP32 tensor to FP8 format using block-wise quantization."""
26+
pid_m = tl.program_id(axis=0)
27+
pid_n = tl.program_id(axis=1)
28+
n = tl.cdiv(N, BLOCK_SIZE)
29+
30+
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31+
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
32+
offs = offs_m[:, None] * N + offs_n[None, :]
33+
34+
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
35+
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
36+
max_val = tl.max(tl.abs(x))
37+
scale = max_val / 448.0
38+
scale = tl.where(max_val == 0.0, 1.0, scale)
39+
y = x / scale
40+
y = y.to(y_ptr.dtype.element_ty)
41+
42+
tl.store(y_ptr + offs, y, mask=mask)
43+
tl.store(s_ptr + pid_m * n + pid_n, scale)
44+
45+
46+
# triton implementation
47+
# for weight quantization on gpu
48+
def fp8_per_block_quant_triton(
49+
x: torch.Tensor, block_size: int = 128
50+
) -> Tuple[torch.Tensor, torch.Tensor]:
51+
"""
52+
Quantizes a FP32 2D tensor to FP8 (E4M3FN) using block-wise quantization.
53+
For each (block_size x block_size) block:
54+
- scale = max(abs(block)) / 448.0 (FP8 E4M3FN max magnitude)
55+
- if block is all zeros, use scale = 1.0 to avoid div-by-zero
56+
- scale, clamp and cast to FP8
57+
Returns:
58+
y: Quantized FP8 tensor, same shape as input
59+
s: Per-block scales, shape (num_blocks_M, num_blocks_N)
60+
"""
61+
assert x.is_contiguous()
62+
assert x.dim() == 2
63+
64+
M, N = x.size()
65+
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
66+
m_blocks = triton.cdiv(M, block_size)
67+
n_blocks = triton.cdiv(N, block_size)
68+
s = torch.empty((m_blocks, n_blocks), dtype=torch.float32, device=x.device)
69+
70+
def grid(meta):
71+
return (
72+
triton.cdiv(M, meta["BLOCK_SIZE"]),
73+
triton.cdiv(N, meta["BLOCK_SIZE"]),
74+
)
75+
76+
_fp8_per_block_quant_kernel[grid](x, y, s, M, N, BLOCK_SIZE=block_size)
77+
78+
return y, s

0 commit comments

Comments
 (0)