Skip to content

Commit e705362

Browse files
committed
Add Triton INT4 dense kernels with dequant prefill path for Qwen3.5 MoE
Add three new Triton kernels for dense W4A16 linear projections that replace tinygemm's tiled INT4 format with simple [N, K//2] packed weights (same format as MoE experts):
- int4_matmul: fused dequant+tl.dot GEMM for medium M (prefill crossover)
- int4_matvec: bandwidth-optimized vec-mat for M=1 decode
- dequant_w4_to_bf16: weight dequant for large-M prefill via Inductor mm

W4DequantLinear wraps these with dual decode/prefill dispatch:
- Decode (M=1): int4_matvec (73 tok/s, ~35% slower than tinygemm)
- Prefill (M>1): dequant+F.linear via Inductor (3400 tok/s at 3K tokens, +67% over tinygemm baseline)

Single 18GB weight blob (no duplication). Decode perf regression is a known trade-off for uniform weight format — to be revisited with a CUDA C++ matvec kernel.

Also adds INT8 dynamic-activation MoE tests and comprehensive correctness tests (48 tests, all passing at rtol=0.01).

Co-authored-by: Claude <noreply@anthropic.com>
ghstack-source-id: f484044
Pull Request resolved: #19188
1 parent 32c49a3 commit e705362

9 files changed

Lines changed: 1996 additions & 0 deletions

File tree

Lines changed: 393 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,393 @@
1+
#!/usr/bin/env python3
2+
"""Benchmark INT4 matmul strategies for M=1 decode."""
3+
4+
import torch
5+
import triton
6+
import triton.language as tl
7+
from triton.testing import do_bench
8+
9+
10+
# Strategy 1: tl.dot with BLOCK_M=16 padding (current approach)
11+
@triton.jit
def _int4_dot_kernel(
    A,
    B,
    C,
    B_scale,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    stride_bsn,
    stride_bsk,
    group_size: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Fused INT4-dequant GEMM tile: C[m, n] = sum_k A[m, k] * dequant(B)[n, k].

    Each program produces one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C. It walks
    K in BLOCK_SIZE_K steps, dequantizes the weight tile on the fly and
    accumulates in float32 with ``tl.dot`` on bfloat16 operands.

    Weight layout as addressed by this kernel: element ``(n, k // 2)`` of B
    carries two 4-bit values, k even in bits 0-3 and k odd in bits 4-7. Each
    nibble is shifted by -8 and multiplied by ``B_scale[n, k // group_size]``.

    Args:
        A: activation pointer, addressed with (stride_am, stride_ak).
        B: packed weight pointer, addressed with (stride_bn, stride_bk).
        C: output pointer, addressed with (stride_cm, stride_cn); written as
            bfloat16.
        B_scale: per-group scale pointer, addressed with
            (stride_bsn, stride_bsk).
        M: number of valid rows of A / C (rows >= M are masked).
        N, K: problem sizes (compile-time constants).
        group_size: number of consecutive k values sharing one scale.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: tile sizes. The packed
            pointer advances by ``BLOCK_SIZE_K // 2`` per step, so
            BLOCK_SIZE_K must be even.
    """
    pid = tl.program_id(0)
    num_n = tl.cdiv(N, BLOCK_SIZE_N)
    mb = pid // num_n
    nb = pid % num_n
    offs_m = mb * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = nb * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    mm = offs_m < M
    nm = offs_n < N
    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    # Two k values share one packed element; the shift picks the nibble.
    b_ptrs = B + offs_n[None, :] * stride_bn + (offs_k[:, None] // 2) * stride_bk
    b_shift = (offs_k[:, None] % 2) * 4
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for ks in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        kr = K - ks * BLOCK_SIZE_K
        km = offs_k < kr
        a = tl.load(a_ptrs, mask=mm[:, None] & km[None, :], other=0.0)
        b = tl.load(b_ptrs, mask=km[:, None] & nm[None, :], other=0)
        b = (b >> b_shift) & 0xF
        if group_size % BLOCK_SIZE_K == 0:
            # Tiles start on multiples of BLOCK_SIZE_K, so when group_size is
            # a multiple of it a tile never straddles two groups: one scale
            # row per tile is enough.
            gi = (BLOCK_SIZE_K * ks) // group_size
            sp = B_scale + offs_n[None, :] * stride_bsn + gi * stride_bsk
            bs = tl.load(sp, mask=nm[None, :], other=0.0).to(tl.float32)
        else:
            # A tile spans several groups (e.g. BLOCK_SIZE_K=256 with
            # group_size=128): look the scale up per k so every row of the
            # tile is dequantized with its own group's scale.
            gi = (BLOCK_SIZE_K * ks + offs_k[:, None]) // group_size
            sp = B_scale + offs_n[None, :] * stride_bsn + gi * stride_bsk
            bs = tl.load(sp, mask=km[:, None] & nm[None, :], other=0.0).to(
                tl.float32
            )
        bd = ((b.to(tl.float32) - 8.0) * bs).to(tl.bfloat16)
        acc += tl.dot(a.to(tl.bfloat16), bd)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
    c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.bfloat16), mask=mm[:, None] & nm[None, :])
61+
62+
63+
# Strategy 2: vec-mat with tl.sum (no tl.dot, no M padding waste)
64+
@triton.jit
def _int4_vecmat_kernel(
    A,
    B,
    C,
    B_scale,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_bn,
    stride_bk,
    stride_bsn,
    stride_bsk,
    group_size: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Vector-matrix product against INT4 weights: C[n] = sum_k A[k] * dequant(B)[n, k].

    Each program produces BLOCK_SIZE_N outputs. It walks K in BLOCK_SIZE_K
    steps, dequantizes the weight tile and reduces with an elementwise
    multiply followed by ``tl.sum`` (no ``tl.dot``), accumulating in float32.

    A and C are addressed as ``A + k`` and ``C + n``, i.e. a single vector
    with unit element stride is read and written. B and B_scale use the same
    packed layout as the other kernels in this file: two 4-bit values per
    stored element (k even in bits 0-3, k odd in bits 4-7), each offset by -8
    and scaled by ``B_scale[n, k // group_size]``.

    Args:
        A: activation vector pointer.
        B: packed weight pointer, addressed with (stride_bn, stride_bk).
        C: output vector pointer; written as bfloat16.
        B_scale: per-group scale pointer, addressed with
            (stride_bsn, stride_bsk).
        N, K: problem sizes (compile-time constants).
        group_size: number of consecutive k values sharing one scale.
        BLOCK_SIZE_N, BLOCK_SIZE_K: tile sizes. The packed pointer advances
            by ``BLOCK_SIZE_K // 2`` per step, so BLOCK_SIZE_K must be even.
    """
    nb = tl.program_id(0)
    offs_n = nb * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    nm = offs_n < N
    # Two k values share one packed element; the shift picks the nibble.
    b_ptrs = B + offs_n[None, :] * stride_bn + (offs_k[:, None] // 2) * stride_bk
    b_shift = (offs_k[:, None] % 2) * 4
    a_ptrs = A + offs_k
    acc = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32)
    for ks in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        kr = K - ks * BLOCK_SIZE_K
        km = offs_k < kr
        a = tl.load(a_ptrs, mask=km, other=0.0).to(tl.float32)  # [BK]
        b = tl.load(b_ptrs, mask=km[:, None] & nm[None, :], other=0)
        b = (b >> b_shift) & 0xF
        if group_size % BLOCK_SIZE_K == 0:
            # Tiles start on multiples of BLOCK_SIZE_K, so when group_size is
            # a multiple of it the whole tile shares one scale row.
            gi = (BLOCK_SIZE_K * ks) // group_size
            sp = B_scale + offs_n * stride_bsn + gi * stride_bsk
            bs_row = tl.load(sp, mask=nm, other=0.0).to(tl.float32)  # [BN]
            bs = bs_row[None, :]  # [1, BN], broadcast over k
        else:
            # A tile spans several groups (e.g. BLOCK_SIZE_K=256 with
            # group_size=128): look the scale up per k instead.
            gi = (BLOCK_SIZE_K * ks + offs_k[:, None]) // group_size
            sp = B_scale + offs_n[None, :] * stride_bsn + gi * stride_bsk
            bs = tl.load(sp, mask=km[:, None] & nm[None, :], other=0.0).to(
                tl.float32
            )  # [BK, BN]
        bd = (b.to(tl.float32) - 8.0) * bs  # [BK, BN]
        acc += tl.sum(a[:, None] * bd, axis=0)  # [BN]
        a_ptrs += BLOCK_SIZE_K
        b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
    c_ptrs = C + offs_n
    tl.store(c_ptrs, acc.to(tl.bfloat16), mask=nm)
103+
104+
105+
# Strategy 3: split-K with tl.dot — more CTAs, then atomic reduce
106+
@triton.jit
def _int4_splitk_kernel(
    A,
    B,
    C,
    B_scale,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    stride_bsn,
    stride_bsk,
    group_size: tl.constexpr,
    K_SPLITS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Split-K variant of the fused INT4-dequant ``tl.dot`` GEMM.

    The program id is decomposed into (M tile, N tile, K split). Each program
    reduces only its own ``[k_start, k_end)`` slice of K for one
    BLOCK_SIZE_M x BLOCK_SIZE_N output tile. With ``K_SPLITS == 1`` the tile
    is stored directly; otherwise partial sums are combined with
    ``tl.atomic_add`` on the bfloat16 output, so C must be zeroed by the
    caller before launch.

    B and B_scale use the same packed layout as the other kernels in this
    file: two 4-bit values per stored element (k even in bits 0-3, k odd in
    bits 4-7), each offset by -8 and scaled by
    ``B_scale[n, k // group_size]``.

    Args:
        A: activation pointer, addressed with (stride_am, stride_ak).
        B: packed weight pointer, addressed with (stride_bn, stride_bk).
        C: output pointer, addressed with (stride_cm, stride_cn).
        B_scale: per-group scale pointer, addressed with
            (stride_bsn, stride_bsk).
        M: number of valid rows of A / C (rows >= M are masked).
        N, K: problem sizes (compile-time constants).
        group_size: number of consecutive k values sharing one scale.
        K_SPLITS: number of programs cooperating along K per output tile.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: tile sizes. The packed
            pointer advances by ``BLOCK_SIZE_K // 2`` per step, so
            BLOCK_SIZE_K must be even.
    """
    pid = tl.program_id(0)
    num_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_nk = num_n * K_SPLITS
    mb = pid // num_nk
    nk = pid % num_nk
    nb = nk // K_SPLITS
    ks_id = nk % K_SPLITS

    offs_m = mb * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = nb * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    mm = offs_m < M
    nm = offs_n < N

    # Ceil-div written out on the constexprs so the result is usable in the
    # compile-time branch below.
    k_per_split = (K + K_SPLITS - 1) // K_SPLITS
    k_start = ks_id * k_per_split
    k_end = tl.minimum(k_start + k_per_split, K)

    a_ptrs = A + offs_m[:, None] * stride_am + (k_start + offs_k[None, :]) * stride_ak
    b_ptrs = (
        B + offs_n[None, :] * stride_bn + ((k_start + offs_k[:, None]) // 2) * stride_bk
    )
    # Nibble selector for absolute k. Every step advances k by the even
    # BLOCK_SIZE_K, so the parity (and hence this shift) is the same for all
    # steps and must not be recomputed without k_start.
    b_shift = ((k_start + offs_k[:, None]) % 2) * 4

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    num_steps = tl.cdiv(k_end - k_start, BLOCK_SIZE_K)
    for step in range(0, num_steps):
        abs_k = k_start + step * BLOCK_SIZE_K + offs_k
        km = abs_k < k_end
        a = tl.load(a_ptrs, mask=mm[:, None] & km[None, :], other=0.0)
        b = tl.load(b_ptrs, mask=km[:, None] & nm[None, :], other=0)
        b = (b >> b_shift) & 0xF
        if group_size % BLOCK_SIZE_K == 0 and k_per_split % BLOCK_SIZE_K == 0:
            # Every tile starts on a multiple of BLOCK_SIZE_K and group_size
            # is a multiple of it, so a tile never straddles two groups.
            gi = (k_start + step * BLOCK_SIZE_K) // group_size
            sp = B_scale + offs_n[None, :] * stride_bsn + gi * stride_bsk
            bs = tl.load(sp, mask=nm[None, :], other=0.0).to(tl.float32)
        else:
            # Unaligned split or tile wider than a group: per-k scale lookup.
            gi = abs_k[:, None] // group_size
            sp = B_scale + offs_n[None, :] * stride_bsn + gi * stride_bsk
            bs = tl.load(sp, mask=km[:, None] & nm[None, :], other=0.0).to(
                tl.float32
            )
        bd = ((b.to(tl.float32) - 8.0) * bs).to(tl.bfloat16)
        acc += tl.dot(a.to(tl.bfloat16), bd)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk

    c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    if K_SPLITS == 1:
        tl.store(c_ptrs, acc.to(tl.bfloat16), mask=mm[:, None] & nm[None, :])
    else:
        tl.atomic_add(c_ptrs, acc.to(tl.bfloat16), mask=mm[:, None] & nm[None, :])
175+
176+
177+
def main():
    """Time the three Triton INT4 strategies at M=1 and print them next to a baseline.

    For every (N, K, label) shape the function:
      1. quantizes a random bfloat16 weight to 4-bit values with per-group
         scales and packs two values per int8 element,
      2. builds a baseline ``nn.Linear`` quantized through ``quantize_model_``
         ("4w", "tile_packed_to_4d") and times ``nn.functional.linear`` on it
         (reported as "tinygemm"),
      3. sweeps a table of launch configurations for each kernel, printing
         the median time (``do_bench`` result scaled by 1000, reported as us)
         and its ratio to the baseline, or a truncated error on failure.
    """
    import torch.nn as nn
    from executorch.extension.llm.export.quantize import quantize_model_
    from torchao.quantization.quant_primitives import (
        choose_qparams_affine,
        MappingType,
        quantize_affine,
    )

    group_size = 128
    quant_block = (1, group_size)
    int4_range = {"quant_min": -8, "quant_max": 7}
    bench_opts = {"warmup": 50, "rep": 200, "return_mode": "median"}

    shapes = [
        (2048, 2048, "q/o_proj"),
        (12352, 2048, "shared_g+u"),
        (256, 2048, "k/v_proj"),
    ]

    # (BLOCK_SIZE_N, BLOCK_SIZE_K, num_warps, num_stages)
    dot_configs = [
        (16, 128, 4, 5),
        (32, 128, 4, 5),
        (32, 256, 4, 3),
        (16, 128, 2, 5),
        (32, 128, 2, 5),
    ]
    # (BLOCK_SIZE_N, BLOCK_SIZE_K, num_warps, num_stages)
    vecmat_configs = [
        (16, 128, 4, 5),
        (32, 128, 4, 5),
        (64, 128, 4, 5),
        (16, 256, 4, 3),
        (32, 256, 4, 3),
        (16, 128, 2, 5),
        (32, 128, 2, 5),
        (16, 64, 2, 5),
        (32, 64, 2, 5),
    ]
    # (BLOCK_SIZE_N, BLOCK_SIZE_K, K_SPLITS, num_warps, num_stages)
    splitk_configs = [
        (32, 128, 4, 4, 3),
        (32, 128, 8, 4, 3),
        (32, 128, 16, 4, 3),
        (16, 128, 4, 4, 3),
        (16, 128, 8, 4, 3),
        (16, 128, 16, 4, 3),
        (64, 128, 4, 4, 3),
        (64, 128, 8, 4, 3),
    ]

    for N, K, label in shapes:
        # --- Quantize a random weight and pack two nibbles per int8 element.
        weight = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
        weight_f32 = weight.float()
        scales, zero_points = choose_qparams_affine(
            weight_f32,
            MappingType.SYMMETRIC,
            quant_block,
            target_dtype=torch.int8,
            **int4_range,
        )
        q_int = quantize_affine(
            weight_f32,
            quant_block,
            scales,
            zero_points,
            output_dtype=torch.int8,
            **int4_range,
        )
        # Shift [-8, 7] to [0, 15]; even columns land in the low nibble.
        nibbles = (q_int + 8).to(torch.int16)
        low_nibble = nibbles[:, 0::2]
        high_nibble = nibbles[:, 1::2] << 4
        packed = (low_nibble | high_nibble).to(torch.int8).cuda()
        w_scale = scales.reshape(N, -1).to(torch.bfloat16).cuda()

        # --- Baseline: quantized nn.Linear timed through F.linear.
        linear = nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
        wrapper = nn.ModuleDict({"linear": linear})
        quantize_model_(
            wrapper,
            qlinear_config="4w",
            qlinear_group_size=group_size,
            qlinear_packing_format="tile_packed_to_4d",
        )
        tiled_weight = wrapper.linear.weight

        x = torch.randn(1, K, dtype=torch.bfloat16, device="cuda")
        baseline_ms = do_bench(
            lambda: nn.functional.linear(x, tiled_weight), **bench_opts
        )
        t_tiny = baseline_ms * 1000

        print(f"\n{'='*70}")
        print(f"[{N}x{K}] {label} — M=1, tinygemm={t_tiny:.1f}us")
        print(f"{'='*70}")

        # --- Strategy 1: tl.dot with various configs
        print("\n--- Strategy 1: tl.dot (BLOCK_M=16 padding) ---")
        out_dot = torch.empty(1, N, dtype=torch.bfloat16, device="cuda")
        for BN, BK, warps, stages in dot_configs:
            grid = (triton.cdiv(N, BN),)

            # Invoked only inside this iteration, so reading the loop
            # variables directly is equivalent to binding them as defaults.
            def run_dot():
                _int4_dot_kernel[grid](
                    x,
                    packed,
                    out_dot,
                    w_scale,
                    1,
                    N,
                    K,
                    x.stride(0),
                    x.stride(1),
                    packed.stride(0),
                    packed.stride(1),
                    out_dot.stride(0),
                    out_dot.stride(1),
                    w_scale.stride(0),
                    w_scale.stride(1),
                    group_size,
                    BLOCK_SIZE_M=16,
                    BLOCK_SIZE_N=BN,
                    BLOCK_SIZE_K=BK,
                    num_warps=warps,
                    num_stages=stages,
                )

            try:
                run_dot()  # compile / fail fast before timing
                t = do_bench(run_dot, **bench_opts) * 1000
                print(
                    f" BN={BN:3d} BK={BK:3d} w={warps} s={stages}: {t:6.1f}us ({t/t_tiny:.2f}x) grid={grid[0]}"
                )
            except Exception as e:
                print(
                    f" BN={BN:3d} BK={BK:3d} w={warps} s={stages}: FAIL {str(e)[:50]}"
                )

        # --- Strategy 2: vec-mat with tl.sum (no padding waste)
        print("\n--- Strategy 2: vec-mat tl.sum (no M padding) ---")
        for BN, BK, warps, stages in vecmat_configs:
            grid = (triton.cdiv(N, BN),)
            out1d = torch.empty(N, dtype=torch.bfloat16, device="cuda")

            def run_vecmat():
                _int4_vecmat_kernel[grid](
                    x,
                    packed,
                    out1d,
                    w_scale,
                    N,
                    K,
                    packed.stride(0),
                    packed.stride(1),
                    w_scale.stride(0),
                    w_scale.stride(1),
                    group_size,
                    BLOCK_SIZE_N=BN,
                    BLOCK_SIZE_K=BK,
                    num_warps=warps,
                    num_stages=stages,
                )

            try:
                run_vecmat()
                t = do_bench(run_vecmat, **bench_opts) * 1000
                print(
                    f" BN={BN:3d} BK={BK:3d} w={warps} s={stages}: {t:6.1f}us ({t/t_tiny:.2f}x) grid={grid[0]}"
                )
            except Exception as e:
                print(
                    f" BN={BN:3d} BK={BK:3d} w={warps} s={stages}: FAIL {str(e)[:50]}"
                )

        # --- Strategy 3: split-K with tl.dot
        print("\n--- Strategy 3: split-K tl.dot ---")
        for BN, BK, splits, warps, stages in splitk_configs:
            grid = (triton.cdiv(N, BN) * splits,)
            out_sk = torch.zeros(1, N, dtype=torch.bfloat16, device="cuda")

            def run_splitk():
                # Partial sums are accumulated into the output, so it is
                # cleared on every launch (and therefore inside the timing).
                out_sk.zero_()
                _int4_splitk_kernel[grid](
                    x,
                    packed,
                    out_sk,
                    w_scale,
                    1,
                    N,
                    K,
                    x.stride(0),
                    x.stride(1),
                    packed.stride(0),
                    packed.stride(1),
                    out_sk.stride(0),
                    out_sk.stride(1),
                    w_scale.stride(0),
                    w_scale.stride(1),
                    group_size,
                    K_SPLITS=splits,
                    BLOCK_SIZE_M=16,
                    BLOCK_SIZE_N=BN,
                    BLOCK_SIZE_K=BK,
                    num_warps=warps,
                    num_stages=stages,
                )

            try:
                run_splitk()
                t = do_bench(run_splitk, **bench_opts) * 1000
                print(
                    f" BN={BN:3d} BK={BK:3d} sp={splits:2d} w={warps} s={stages}: {t:6.1f}us ({t/t_tiny:.2f}x) grid={grid[0]}"
                )
            except Exception as e:
                print(
                    f" BN={BN:3d} BK={BK:3d} sp={splits:2d} w={warps} s={stages}: FAIL {str(e)[:50]}"
                )

        # Drop this shape's large tensors before moving to the next one.
        del wrapper, tiled_weight, packed, w_scale
        torch.cuda.empty_cache()
390+
391+
392+
# Run the benchmark sweep only when this file is executed as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)