Commit 1957037
Add Triton INT4 dense kernels with dequant prefill path for Qwen3.5 MoE
Add three new Triton kernels for dense W4A16 linear projections that replace
tinygemm's tiled INT4 format with simple [N, K//2] packed weights (same
format as the MoE experts):

- int4_matmul: fused dequant + tl.dot GEMM for medium M (prefill crossover)
- int4_matvec: bandwidth-optimized vec-mat for M=1 decode
- dequant_w4_to_bf16: weight dequant for large-M prefill via an Inductor mm

W4DequantLinear wraps these with dual decode/prefill dispatch:

- Decode (M=1): int4_matvec (73 tok/s, ~35% slower than tinygemm)
- Prefill (M>1): dequant + F.linear via Inductor (3400 tok/s at 3K tokens,
  +67% over the tinygemm baseline)

Single 18GB weight blob (no duplication). The decode perf regression is a
known trade-off for the uniform weight format; it will be revisited with a
CUDA C++ matvec kernel.

Also adds INT8 dynamic-activation MoE tests and comprehensive correctness
tests (48 tests, all passing at rtol=0.01).

Co-authored-by: Claude <noreply@anthropic.com>

ghstack-source-id: 91b6a04
Pull Request resolved: #19188
1 parent 6204cf4 commit 1957037

5 files changed

Lines changed: 1023 additions & 0 deletions
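
W4DequantLinear itself is not among the files shown below, so here is a minimal sketch of the dual decode/prefill dispatch the commit message describes, assuming a module that holds the [N, K//2] packed weights and per-group scales; the constructor shape and the exact M == 1 crossover are assumptions, and the medium-M int4_matmul path is omitted for brevity:

import torch
import torch.nn.functional as F

from executorch.backends.cuda.triton.kernels.int4_matmul import (
    dequant_w4_to_bf16,
    int4_matvec,
)


class W4DequantLinearSketch(torch.nn.Module):
    # Hypothetical stand-in for W4DequantLinear; not the committed code.
    def __init__(self, w_packed, w_scale, group_size):
        super().__init__()
        self.w_packed = w_packed  # [N, K//2] int8, two INT4 values per byte
        self.w_scale = w_scale  # [N, K//group_size] bf16 per-group scales
        self.group_size = group_size

    def forward(self, x):  # x: [M, K] bf16
        if x.shape[0] == 1:
            # Decode: bandwidth-bound vec-mat straight off the packed weights.
            return int4_matvec(x, self.w_packed, self.w_scale, self.group_size)
        # Prefill: dequantize once, then an ordinary bf16 GEMM Inductor can tune.
        w_bf16 = dequant_w4_to_bf16(self.w_packed, self.w_scale, self.group_size)
        return F.linear(x, w_bf16)

The point of the design is a single packed weight blob serving both paths: decode reads it directly, while prefill pays a one-off dequant so the matmul itself runs at full bf16 GEMM throughput.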

backends/cuda/tests/test_int4_matmul.py

Lines changed: 274 additions & 0 deletions
@@ -0,0 +1,274 @@

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Functional correctness tests for INT4 matmul and dequant Triton kernels.

Tests both int4_matmul (fused W4A16 GEMM) and dequant_w4_to_bf16 (weight
dequantization) against eager PyTorch references. Uses 0.01 absolute
tolerance to account for INT4 quantization noise and bf16 rounding.

Usage:
    python -m pytest backends/cuda/tests/test_int4_matmul.py -v
"""

import unittest

import torch
import torch.nn as nn

from executorch.backends.cuda.triton.kernels.int4_matmul import (
    dequant_w4_to_bf16,
    int4_matmul,
    int4_matvec,
)

ATOL = 0.01
DEVICE = "cuda"


def _quantize_simple(w_bf16, group_size):
    """Quantize [N, K] bf16 weight to simple packed INT4 + per-group scales.

    Returns:
        w_packed: [N, K//2] int8 — two INT4 values per byte
        w_scale: [N, K//group_size] bf16 — symmetric scales
        w_ref: [N, K] bf16 — dequantized reference matching kernel's computation
    """
    N, K = w_bf16.shape
    w = w_bf16.float()
    w_grouped = w.reshape(N, K // group_size, group_size)
    scale = w_grouped.abs().amax(dim=-1, keepdim=True) / 7.0
    scale = scale.clamp(min=1e-10)
    int_data = (w_grouped / scale).round().clamp(-8, 7).to(torch.int8)
    # Kernel dequant: (uint4 - 8) * scale = int_data * scale
    scale_bf16 = scale.to(torch.bfloat16)
    w_ref = ((int_data.float()) * scale_bf16.float()).reshape(N, K).to(torch.bfloat16)
    scale_bf16 = scale_bf16.reshape(N, K // group_size)
    int_data = int_data.reshape(N, K)
    uint4 = (int_data + 8).to(torch.int16)
    packed = (uint4[:, 0::2] | (uint4[:, 1::2] << 4)).to(torch.int8)
    return packed.to(DEVICE), scale_bf16.to(DEVICE), w_ref.to(DEVICE)


def _eager_int4_matmul(x, w_ref):
    """Reference matmul: x @ w_ref.T in float32, cast to bf16."""
    return (x.float() @ w_ref.float().T).to(torch.bfloat16)


class TestDequantW4ToBf16(unittest.TestCase):
    """Tests for dequant_w4_to_bf16 Triton kernel."""

    def _run_dequant(self, N, K, group_size):
        torch.manual_seed(42)
        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
        packed, scale, w_ref = _quantize_simple(w, group_size)

        out = dequant_w4_to_bf16(packed, scale, group_size)

        self.assertEqual(out.shape, (N, K))
        self.assertEqual(out.dtype, torch.bfloat16)
        max_err = (out.float() - w_ref.float()).abs().max().item()
        self.assertLess(
            max_err, ATOL, f"dequant [{N}x{K}] gs={group_size}: max_err={max_err}"
        )

    def test_square(self):
        self._run_dequant(256, 256, 32)

    def test_tall(self):
        self._run_dequant(2048, 256, 32)

    def test_wide(self):
        self._run_dequant(256, 2048, 128)

    def test_production_qkv(self):
        self._run_dequant(2048, 2048, 128)

    def test_production_shared_expert(self):
        self._run_dequant(1024, 2048, 128)

    def test_group_size_32(self):
        self._run_dequant(512, 512, 32)

    def test_group_size_128(self):
        self._run_dequant(512, 2048, 128)

    def test_non_power_of_two_N(self):
        self._run_dequant(12352, 2048, 128)

    def test_small(self):
        self._run_dequant(16, 64, 32)


class TestInt4Matmul(unittest.TestCase):
    """Tests for int4_matmul Triton kernel (fused W4A16 GEMM)."""

    def _run_matmul(self, M, N, K, group_size):
        torch.manual_seed(42)
        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
        packed, scale, w_ref = _quantize_simple(w, group_size)
        x = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE)

        out = int4_matmul(x, packed, scale, group_size)
        ref = _eager_int4_matmul(x, w_ref)

        self.assertEqual(out.shape, (M, N))
        self.assertEqual(out.dtype, torch.bfloat16)
        self.assertTrue(
            torch.allclose(out.float(), ref.float(), atol=ATOL, rtol=0.01),
            f"int4_matmul M={M} [{N}x{K}] gs={group_size}: "
            f"max_abs_err={(out.float() - ref.float()).abs().max().item():.4f}, "
            f"max_rel_err={((out.float() - ref.float()).abs() / ref.float().abs().clamp(min=1e-6)).max().item():.4f}",
        )

    # --- Decode (M=1) ---
    def test_decode_square(self):
        self._run_matmul(1, 256, 256, 32)

    def test_decode_qkv(self):
        self._run_matmul(1, 2048, 2048, 128)

    def test_decode_kv_proj(self):
        self._run_matmul(1, 256, 2048, 128)

    def test_decode_shared_expert(self):
        self._run_matmul(1, 1024, 2048, 128)

    def test_decode_large_N(self):
        self._run_matmul(1, 12352, 2048, 128)

    # --- Small prefill ---
    def test_prefill_4(self):
        self._run_matmul(4, 2048, 2048, 128)

    def test_prefill_16(self):
        self._run_matmul(16, 2048, 2048, 128)

    def test_prefill_64(self):
        self._run_matmul(64, 2048, 2048, 128)

    # --- Large prefill ---
    def test_prefill_256(self):
        self._run_matmul(256, 2048, 2048, 128)

    def test_prefill_1024(self):
        self._run_matmul(1024, 2048, 2048, 128)

    def test_prefill_4095(self):
        self._run_matmul(4095, 2048, 2048, 128)

    # --- Edge cases ---
    def test_group_size_32(self):
        self._run_matmul(4, 512, 512, 32)

    def test_non_power_of_two_M(self):
        self._run_matmul(7, 256, 256, 32)

    def test_non_power_of_two_N(self):
        self._run_matmul(4, 12352, 2048, 128)

    def test_small(self):
        self._run_matmul(1, 16, 64, 32)


class TestInt4Matvec(unittest.TestCase):
    """Tests for int4_matvec Triton kernel (M=1 decode)."""

    def _run_matvec(self, N, K, group_size):
        torch.manual_seed(42)
        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
        packed, scale, w_ref = _quantize_simple(w, group_size)
        x = torch.randn(K, dtype=torch.bfloat16, device=DEVICE)

        out = int4_matvec(x.unsqueeze(0), packed, scale, group_size)
        ref = int4_matmul(x.unsqueeze(0), packed, scale, group_size)

        self.assertEqual(out.shape, (1, N))
        self.assertEqual(out.dtype, torch.bfloat16)
        # atol=1.0 for large accumulation across K, rtol=0.01 for relative
        self.assertTrue(
            torch.allclose(out.float(), ref.float(), atol=1.0, rtol=0.01),
            f"int4_matvec [{N}x{K}] gs={group_size}: "
            f"max_err={(out.float() - ref.float()).abs().max().item():.4f}, "
            f"max_rel={((out.float()-ref.float()).abs()/(ref.float().abs().clamp(min=0.1))).max().item():.4f}",
        )

    def test_qkv_proj(self):
        self._run_matvec(2048, 2048, 128)

    def test_kv_proj(self):
        self._run_matvec(256, 2048, 128)

    def test_shared_expert(self):
        self._run_matvec(1024, 2048, 128)

    def test_large_N(self):
        self._run_matvec(12352, 2048, 128)

    def test_group_size_32(self):
        self._run_matvec(512, 512, 32)

    def test_small(self):
        self._run_matvec(16, 64, 32)

    def test_matches_int4_matmul(self):
        """Matvec output matches int4_matmul at M=1."""
        torch.manual_seed(42)
        N, K, gs = 2048, 2048, 128
        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
        packed, scale, _ = _quantize_simple(w, gs)
        x = torch.randn(1, K, dtype=torch.bfloat16, device=DEVICE)

        out_mv = int4_matvec(x, packed, scale, gs)
        out_mm = int4_matmul(x, packed, scale, gs)

        self.assertTrue(
            torch.allclose(out_mv.float(), out_mm.float(), atol=1.0, rtol=0.01),
            f"matvec vs matmul: max_err={(out_mv.float() - out_mm.float()).abs().max().item():.4f}",
        )


class TestDequantThenMatmul(unittest.TestCase):
    """Tests that dequant + F.linear matches int4_matmul (both paths should agree)."""

    def _run(self, M, N, K, group_size):
        torch.manual_seed(42)
        w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE)
        packed, scale, w_ref = _quantize_simple(w, group_size)
        x = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE)

        # Path A: fused int4_matmul
        out_fused = int4_matmul(x, packed, scale, group_size)

        # Path B: dequant + F.linear
        w_bf16 = dequant_w4_to_bf16(packed, scale, group_size)
        out_dequant = torch.nn.functional.linear(x, w_bf16)

        self.assertTrue(
            torch.allclose(
                out_fused.float(), out_dequant.float(), atol=ATOL, rtol=0.01
            ),
            f"fused vs dequant M={M} [{N}x{K}]: "
            f"max_abs_err={(out_fused.float() - out_dequant.float()).abs().max().item():.4f}",
        )

    def test_decode(self):
        self._run(1, 2048, 2048, 128)

    def test_prefill_short(self):
        self._run(64, 2048, 2048, 128)

    def test_prefill_long(self):
        self._run(1024, 2048, 2048, 128)

    def test_large_N(self):
        self._run(4, 12352, 2048, 128)


if __name__ == "__main__":
    unittest.main()

backends/cuda/triton/kernels/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -10,10 +10,16 @@
     fused_moe_batched_gemm,
     moe_align_block_size,
 )
+
+from executorch.backends.cuda.triton.kernels.int4_matmul import (
+    dequant_w4_to_bf16,
+    int4_matvec,
+)
 from executorch.backends.cuda.triton.kernels.sdpa import sdpa, sdpa_decode_splitk
 from executorch.backends.cuda.triton.kernels.topk import topk

 __all__ = [
+    "dequant_w4_to_bf16",
     "fused_moe",
     "fused_moe_batched",
     "fused_moe_batched_gemm",