Skip to content

Commit f2ab54e

Browse files
authored
add triton-fp8w8a8g128 quant type. (#1214)
1 parent a27dfc8 commit f2ab54e

10 files changed

Lines changed: 664 additions & 60 deletions

File tree

docs/CN/source/tutorial/api_server_args.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,8 @@ PD 分离模式参数
384384
* ``vllm-fp8w8a8-b128``
385385
* ``deepgemm-fp8w8a8-b128``
386386
* ``triton-fp8w8a8-block128``
387+
* ``triton-fp8w8a8g128``: 权重 per-channel 量化和激活 per-group 128 量化
388+
* ``triton-fp8w8a8g64``: 权重 per-channel 量化和激活 per-group 64 量化
387389
* ``awq``
388390
* ``awq_marlin``
389391
* ``none`` (默认)

docs/EN/source/tutorial/api_server_args.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,8 @@ Quantization Parameters
376376
* ``vllm-fp8w8a8-b128``
377377
* ``deepgemm-fp8w8a8-b128``
378378
* ``triton-fp8w8a8-block128``
379+
* ``triton-fp8w8a8g128``: weight per-channel quant and activation per-group 128 quant
380+
* ``triton-fp8w8a8g64``: weight per-channel quant and activation per-group 64 quant
379381
* ``awq``
380382
* ``awq_marlin``
381383
* ``none`` (default)

lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def _per_token_group_quant_fp8(
2929
xs_n,
3030
xs_row_major: tl.constexpr,
3131
BLOCK: tl.constexpr,
32+
NEED_MASK: tl.constexpr,
3233
):
3334
g_id = tl.program_id(0)
3435
y_ptr += g_id * y_stride
@@ -41,9 +42,15 @@ def _per_token_group_quant_fp8(
4142
y_s_ptr += col_id * xs_m + row_id # col major
4243

4344
cols = tl.arange(0, BLOCK) # N <= BLOCK
44-
mask = cols < N
4545

46-
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
46+
if NEED_MASK:
47+
mask = cols < N
48+
other = 0.0
49+
else:
50+
mask = None
51+
other = None
52+
53+
y = tl.load(y_ptr + cols, mask=mask, other=other).to(tl.float32)
4754
# Quant
4855
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
4956
y_s = _absmax / fp8_max
@@ -99,6 +106,7 @@ def lightllm_per_token_group_quant_fp8(
99106
xs_n=xs_n,
100107
xs_row_major=xs_row_major,
101108
BLOCK=BLOCK,
109+
NEED_MASK=BLOCK != group_size,
102110
num_warps=num_warps,
103111
num_stages=num_stages,
104112
)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
from lightllm.utils.dist_utils import get_current_device_id
5+
6+
7+
@triton.jit
def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_N: tl.constexpr):
    """Per-row (per-channel) FP8 E4M3 quantization kernel.

    Each program quantizes one row of the (M, N) input ``x_ptr``: it
    computes the row's absolute maximum, derives a float32 scale
    ``absmax / 448.0`` (448 is the max finite fp8e4m3 value), writes the
    quantized row to ``y_ptr`` and the per-row scale to ``s_ptr``.

    ``M`` is unused inside the kernel (the launch grid already has one
    program per row); it is kept for interface stability.
    """
    m_index = tl.program_id(axis=0)

    offs_n = tl.arange(0, BLOCK_N)
    mask = offs_n < N

    x = tl.load(x_ptr + m_index * N + offs_n, mask=mask, other=0.0).to(tl.float32)

    amax = tl.max(tl.abs(x))

    max_fp8e4m3_val = 448.0
    # Clamp absmax so an all-zero row cannot produce scale == 0, and divide
    # by the *same* scale that is stored.  The original code divided by
    # (scale + 1e-6) but stored plain `scale`, so dequantization
    # (y * scale) did not exactly invert quantization.
    scale = tl.maximum(amax, 1e-6) / max_fp8e4m3_val
    y = (x / scale).to(y_ptr.dtype.element_ty)

    tl.store(y_ptr + m_index * N + offs_n, y, mask=mask)
    tl.store(s_ptr + m_index, scale)
24+
25+
26+
def mm_weight_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to fp8e4m3 with one float32 scale per row.

    Returns ``(quantized, scales)`` where ``quantized`` has the same
    (M, N) shape and device as ``x`` and ``scales`` has shape (M, 1).
    """
    assert x.is_contiguous(), "Input tensor must be contiguous"
    num_rows, num_cols = x.shape

    quantized = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    row_scales = torch.empty((num_rows, 1), dtype=torch.float32, device=x.device)

    # One program per row; the block must cover a full row.
    block_n = triton.next_power_of_2(num_cols)
    weight_quant_kernel[(num_rows,)](
        x, row_scales, quantized, num_rows, num_cols, BLOCK_N=block_n, num_warps=16
    )
    return quantized, row_scales
36+
37+
38+
def weight_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-channel fp8e4m3 quantization of a 2-D or 3-D weight tensor.

    The tensor is moved to the current CUDA device first.  For a 2-D
    (M, N) input this returns (M, N) fp8 weights and (M, 1) float32
    scales; for a 3-D (E, M, N) input (e.g. stacked expert weights) it
    returns (E, M, N) fp8 weights and (E, M, 1) scales.

    Raises:
        AssertionError: if ``x`` is not contiguous or not 2-D/3-D.
    """
    assert x.is_contiguous(), "Input tensor must be contiguous"
    assert x.dim() in (2, 3), f"expected a 2-D or 3-D tensor, got {x.dim()}-D"
    x = x.cuda(get_current_device_id())
    if x.dim() == 3:
        # Flatten the leading dim into the row axis so one kernel launch
        # covers all matrices, instead of one launch per matrix as before.
        # Valid because x is contiguous, and per-row scales are unaffected
        # by how rows are grouped.
        e, m, n = x.shape
        y_flat, s_flat = mm_weight_quant(x.view(e * m, n))
        return y_flat.view(e, m, n), s_flat.view(e, m, 1)
    return mm_weight_quant(x)

0 commit comments

Comments
 (0)