|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# This source code is licensed under the BSD-style license found in the |
| 6 | +# LICENSE file in the root directory of this source tree. |
| 7 | + |
| 8 | +""" |
| 9 | +Functional correctness tests for INT4 matmul and dequant Triton kernels. |
| 10 | +
|
| 11 | +Tests both int4_matmul (fused W4A16 GEMM) and dequant_w4_to_bf16 (weight |
| 12 | +dequantization) against eager PyTorch references. Uses 0.01 absolute |
| 13 | +tolerance to account for INT4 quantization noise and bf16 rounding. |
| 14 | +
|
| 15 | +Usage: |
| 16 | + python -m pytest backends/cuda/tests/test_int4_matmul.py -v |
| 17 | +""" |
| 18 | + |
| 19 | +import unittest |
| 20 | + |
| 21 | +import torch |
| 22 | +import torch.nn as nn |
| 23 | + |
| 24 | +from executorch.backends.cuda.triton.kernels.int4_matmul import ( |
| 25 | + dequant_w4_to_bf16, |
| 26 | + int4_matmul, |
| 27 | + int4_matvec, |
| 28 | +) |
| 29 | + |
| 30 | +ATOL = 0.01 |
| 31 | +DEVICE = "cuda" |
| 32 | + |
| 33 | + |
| 34 | +def _quantize_simple(w_bf16, group_size): |
| 35 | + """Quantize [N, K] bf16 weight to simple packed INT4 + per-group scales. |
| 36 | +
|
| 37 | + Returns: |
| 38 | + w_packed: [N, K//2] int8 — two INT4 values per byte |
| 39 | + w_scale: [N, K//group_size] bf16 — symmetric scales |
| 40 | + w_ref: [N, K] bf16 — dequantized reference matching kernel's computation |
| 41 | + """ |
| 42 | + N, K = w_bf16.shape |
| 43 | + w = w_bf16.float() |
| 44 | + w_grouped = w.reshape(N, K // group_size, group_size) |
| 45 | + scale = w_grouped.abs().amax(dim=-1, keepdim=True) / 7.0 |
| 46 | + scale = scale.clamp(min=1e-10) |
| 47 | + int_data = (w_grouped / scale).round().clamp(-8, 7).to(torch.int8) |
| 48 | + # Kernel dequant: (uint4 - 8) * scale = int_data * scale |
| 49 | + scale_bf16 = scale.to(torch.bfloat16) |
| 50 | + w_ref = ((int_data.float()) * scale_bf16.float()).reshape(N, K).to(torch.bfloat16) |
| 51 | + scale_bf16 = scale_bf16.reshape(N, K // group_size) |
| 52 | + int_data = int_data.reshape(N, K) |
| 53 | + uint4 = (int_data + 8).to(torch.int16) |
| 54 | + packed = (uint4[:, 0::2] | (uint4[:, 1::2] << 4)).to(torch.int8) |
| 55 | + return packed.to(DEVICE), scale_bf16.to(DEVICE), w_ref.to(DEVICE) |
| 56 | + |
| 57 | + |
| 58 | +def _eager_int4_matmul(x, w_ref): |
| 59 | + """Reference matmul: x @ w_ref.T in float32, cast to bf16.""" |
| 60 | + return (x.float() @ w_ref.float().T).to(torch.bfloat16) |
| 61 | + |
| 62 | + |
| 63 | +class TestDequantW4ToBf16(unittest.TestCase): |
| 64 | + """Tests for dequant_w4_to_bf16 Triton kernel.""" |
| 65 | + |
| 66 | + def _run_dequant(self, N, K, group_size): |
| 67 | + torch.manual_seed(42) |
| 68 | + w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE) |
| 69 | + packed, scale, w_ref = _quantize_simple(w, group_size) |
| 70 | + |
| 71 | + out = dequant_w4_to_bf16(packed, scale, group_size) |
| 72 | + |
| 73 | + self.assertEqual(out.shape, (N, K)) |
| 74 | + self.assertEqual(out.dtype, torch.bfloat16) |
| 75 | + max_err = (out.float() - w_ref.float()).abs().max().item() |
| 76 | + self.assertLess( |
| 77 | + max_err, ATOL, f"dequant [{N}x{K}] gs={group_size}: max_err={max_err}" |
| 78 | + ) |
| 79 | + |
| 80 | + def test_square(self): |
| 81 | + self._run_dequant(256, 256, 32) |
| 82 | + |
| 83 | + def test_tall(self): |
| 84 | + self._run_dequant(2048, 256, 32) |
| 85 | + |
| 86 | + def test_wide(self): |
| 87 | + self._run_dequant(256, 2048, 128) |
| 88 | + |
| 89 | + def test_production_qkv(self): |
| 90 | + self._run_dequant(2048, 2048, 128) |
| 91 | + |
| 92 | + def test_production_shared_expert(self): |
| 93 | + self._run_dequant(1024, 2048, 128) |
| 94 | + |
| 95 | + def test_group_size_32(self): |
| 96 | + self._run_dequant(512, 512, 32) |
| 97 | + |
| 98 | + def test_group_size_128(self): |
| 99 | + self._run_dequant(512, 2048, 128) |
| 100 | + |
| 101 | + def test_non_power_of_two_N(self): |
| 102 | + self._run_dequant(12352, 2048, 128) |
| 103 | + |
| 104 | + def test_small(self): |
| 105 | + self._run_dequant(16, 64, 32) |
| 106 | + |
| 107 | + |
| 108 | +class TestInt4Matmul(unittest.TestCase): |
| 109 | + """Tests for int4_matmul Triton kernel (fused W4A16 GEMM).""" |
| 110 | + |
| 111 | + def _run_matmul(self, M, N, K, group_size): |
| 112 | + torch.manual_seed(42) |
| 113 | + w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE) |
| 114 | + packed, scale, w_ref = _quantize_simple(w, group_size) |
| 115 | + x = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE) |
| 116 | + |
| 117 | + out = int4_matmul(x, packed, scale, group_size) |
| 118 | + ref = _eager_int4_matmul(x, w_ref) |
| 119 | + |
| 120 | + self.assertEqual(out.shape, (M, N)) |
| 121 | + self.assertEqual(out.dtype, torch.bfloat16) |
| 122 | + self.assertTrue( |
| 123 | + torch.allclose(out.float(), ref.float(), atol=ATOL, rtol=0.01), |
| 124 | + f"int4_matmul M={M} [{N}x{K}] gs={group_size}: " |
| 125 | + f"max_abs_err={(out.float() - ref.float()).abs().max().item():.4f}, " |
| 126 | + f"max_rel_err={((out.float() - ref.float()).abs() / ref.float().abs().clamp(min=1e-6)).max().item():.4f}", |
| 127 | + ) |
| 128 | + |
| 129 | + # --- Decode (M=1) --- |
| 130 | + def test_decode_square(self): |
| 131 | + self._run_matmul(1, 256, 256, 32) |
| 132 | + |
| 133 | + def test_decode_qkv(self): |
| 134 | + self._run_matmul(1, 2048, 2048, 128) |
| 135 | + |
| 136 | + def test_decode_kv_proj(self): |
| 137 | + self._run_matmul(1, 256, 2048, 128) |
| 138 | + |
| 139 | + def test_decode_shared_expert(self): |
| 140 | + self._run_matmul(1, 1024, 2048, 128) |
| 141 | + |
| 142 | + def test_decode_large_N(self): |
| 143 | + self._run_matmul(1, 12352, 2048, 128) |
| 144 | + |
| 145 | + # --- Small prefill --- |
| 146 | + def test_prefill_4(self): |
| 147 | + self._run_matmul(4, 2048, 2048, 128) |
| 148 | + |
| 149 | + def test_prefill_16(self): |
| 150 | + self._run_matmul(16, 2048, 2048, 128) |
| 151 | + |
| 152 | + def test_prefill_64(self): |
| 153 | + self._run_matmul(64, 2048, 2048, 128) |
| 154 | + |
| 155 | + # --- Large prefill --- |
| 156 | + def test_prefill_256(self): |
| 157 | + self._run_matmul(256, 2048, 2048, 128) |
| 158 | + |
| 159 | + def test_prefill_1024(self): |
| 160 | + self._run_matmul(1024, 2048, 2048, 128) |
| 161 | + |
| 162 | + def test_prefill_4095(self): |
| 163 | + self._run_matmul(4095, 2048, 2048, 128) |
| 164 | + |
| 165 | + # --- Edge cases --- |
| 166 | + def test_group_size_32(self): |
| 167 | + self._run_matmul(4, 512, 512, 32) |
| 168 | + |
| 169 | + def test_non_power_of_two_M(self): |
| 170 | + self._run_matmul(7, 256, 256, 32) |
| 171 | + |
| 172 | + def test_non_power_of_two_N(self): |
| 173 | + self._run_matmul(4, 12352, 2048, 128) |
| 174 | + |
| 175 | + def test_small(self): |
| 176 | + self._run_matmul(1, 16, 64, 32) |
| 177 | + |
| 178 | + |
| 179 | +class TestInt4Matvec(unittest.TestCase): |
| 180 | + """Tests for int4_matvec Triton kernel (M=1 decode).""" |
| 181 | + |
| 182 | + def _run_matvec(self, N, K, group_size): |
| 183 | + torch.manual_seed(42) |
| 184 | + w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE) |
| 185 | + packed, scale, w_ref = _quantize_simple(w, group_size) |
| 186 | + x = torch.randn(K, dtype=torch.bfloat16, device=DEVICE) |
| 187 | + |
| 188 | + out = int4_matvec(x.unsqueeze(0), packed, scale, group_size) |
| 189 | + ref = int4_matmul(x.unsqueeze(0), packed, scale, group_size) |
| 190 | + |
| 191 | + self.assertEqual(out.shape, (1, N)) |
| 192 | + self.assertEqual(out.dtype, torch.bfloat16) |
| 193 | + # atol=1.0 for large accumulation across K, rtol=0.01 for relative |
| 194 | + self.assertTrue( |
| 195 | + torch.allclose(out.float(), ref.float(), atol=1.0, rtol=0.01), |
| 196 | + f"int4_matvec [{N}x{K}] gs={group_size}: " |
| 197 | + f"max_err={(out.float() - ref.float()).abs().max().item():.4f}, " |
| 198 | + f"max_rel={((out.float()-ref.float()).abs()/(ref.float().abs().clamp(min=0.1))).max().item():.4f}", |
| 199 | + ) |
| 200 | + |
| 201 | + def test_qkv_proj(self): |
| 202 | + self._run_matvec(2048, 2048, 128) |
| 203 | + |
| 204 | + def test_kv_proj(self): |
| 205 | + self._run_matvec(256, 2048, 128) |
| 206 | + |
| 207 | + def test_shared_expert(self): |
| 208 | + self._run_matvec(1024, 2048, 128) |
| 209 | + |
| 210 | + def test_large_N(self): |
| 211 | + self._run_matvec(12352, 2048, 128) |
| 212 | + |
| 213 | + def test_group_size_32(self): |
| 214 | + self._run_matvec(512, 512, 32) |
| 215 | + |
| 216 | + def test_small(self): |
| 217 | + self._run_matvec(16, 64, 32) |
| 218 | + |
| 219 | + def test_matches_int4_matmul(self): |
| 220 | + """Matvec output matches int4_matmul at M=1.""" |
| 221 | + torch.manual_seed(42) |
| 222 | + N, K, gs = 2048, 2048, 128 |
| 223 | + w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE) |
| 224 | + packed, scale, _ = _quantize_simple(w, gs) |
| 225 | + x = torch.randn(1, K, dtype=torch.bfloat16, device=DEVICE) |
| 226 | + |
| 227 | + out_mv = int4_matvec(x, packed, scale, gs) |
| 228 | + out_mm = int4_matmul(x, packed, scale, gs) |
| 229 | + |
| 230 | + self.assertTrue( |
| 231 | + torch.allclose(out_mv.float(), out_mm.float(), atol=1.0, rtol=0.01), |
| 232 | + f"matvec vs matmul: max_err={(out_mv.float() - out_mm.float()).abs().max().item():.4f}", |
| 233 | + ) |
| 234 | + |
| 235 | + |
| 236 | +class TestDequantThenMatmul(unittest.TestCase): |
| 237 | + """Tests that dequant + F.linear matches int4_matmul (both paths should agree).""" |
| 238 | + |
| 239 | + def _run(self, M, N, K, group_size): |
| 240 | + torch.manual_seed(42) |
| 241 | + w = torch.randn(N, K, dtype=torch.bfloat16, device=DEVICE) |
| 242 | + packed, scale, w_ref = _quantize_simple(w, group_size) |
| 243 | + x = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE) |
| 244 | + |
| 245 | + # Path A: fused int4_matmul |
| 246 | + out_fused = int4_matmul(x, packed, scale, group_size) |
| 247 | + |
| 248 | + # Path B: dequant + F.linear |
| 249 | + w_bf16 = dequant_w4_to_bf16(packed, scale, group_size) |
| 250 | + out_dequant = torch.nn.functional.linear(x, w_bf16) |
| 251 | + |
| 252 | + self.assertTrue( |
| 253 | + torch.allclose( |
| 254 | + out_fused.float(), out_dequant.float(), atol=ATOL, rtol=0.01 |
| 255 | + ), |
| 256 | + f"fused vs dequant M={M} [{N}x{K}]: " |
| 257 | + f"max_abs_err={(out_fused.float() - out_dequant.float()).abs().max().item():.4f}", |
| 258 | + ) |
| 259 | + |
| 260 | + def test_decode(self): |
| 261 | + self._run(1, 2048, 2048, 128) |
| 262 | + |
| 263 | + def test_prefill_short(self): |
| 264 | + self._run(64, 2048, 2048, 128) |
| 265 | + |
| 266 | + def test_prefill_long(self): |
| 267 | + self._run(1024, 2048, 2048, 128) |
| 268 | + |
| 269 | + def test_large_N(self): |
| 270 | + self._run(4, 12352, 2048, 128) |
| 271 | + |
| 272 | + |
| 273 | +if __name__ == "__main__": |
| 274 | + unittest.main() |
0 commit comments