|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""Benchmark INT4 matmul strategies for M=1 decode.""" |
| 3 | + |
| 4 | +import torch |
| 5 | +import triton |
| 6 | +import triton.language as tl |
| 7 | +from triton.testing import do_bench |
| 8 | + |
| 9 | + |
| 10 | +# Strategy 1: tl.dot with BLOCK_M=16 padding (current approach) |
| 11 | +@triton.jit |
| 12 | +def _int4_dot_kernel( |
| 13 | + A, |
| 14 | + B, |
| 15 | + C, |
| 16 | + B_scale, |
| 17 | + M, |
| 18 | + N: tl.constexpr, |
| 19 | + K: tl.constexpr, |
| 20 | + stride_am, |
| 21 | + stride_ak, |
| 22 | + stride_bn, |
| 23 | + stride_bk, |
| 24 | + stride_cm, |
| 25 | + stride_cn, |
| 26 | + stride_bsn, |
| 27 | + stride_bsk, |
| 28 | + group_size: tl.constexpr, |
| 29 | + BLOCK_SIZE_M: tl.constexpr, |
| 30 | + BLOCK_SIZE_N: tl.constexpr, |
| 31 | + BLOCK_SIZE_K: tl.constexpr, |
| 32 | +): |
| 33 | + pid = tl.program_id(0) |
| 34 | + num_n = tl.cdiv(N, BLOCK_SIZE_N) |
| 35 | + mb = pid // num_n |
| 36 | + nb = pid % num_n |
| 37 | + offs_m = mb * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) |
| 38 | + offs_n = nb * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) |
| 39 | + offs_k = tl.arange(0, BLOCK_SIZE_K) |
| 40 | + mm = offs_m < M |
| 41 | + nm = offs_n < N |
| 42 | + a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak |
| 43 | + b_ptrs = B + offs_n[None, :] * stride_bn + (offs_k[:, None] // 2) * stride_bk |
| 44 | + b_shift = (offs_k[:, None] % 2) * 4 |
| 45 | + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) |
| 46 | + for ks in range(0, tl.cdiv(K, BLOCK_SIZE_K)): |
| 47 | + kr = K - ks * BLOCK_SIZE_K |
| 48 | + km = offs_k < kr |
| 49 | + a = tl.load(a_ptrs, mask=mm[:, None] & km[None, :], other=0.0) |
| 50 | + b = tl.load(b_ptrs, mask=km[:, None] & nm[None, :], other=0) |
| 51 | + b = (b >> b_shift) & 0xF |
| 52 | + gi = (BLOCK_SIZE_K * ks) // group_size |
| 53 | + sp = B_scale + offs_n[None, :] * stride_bsn + gi * stride_bsk |
| 54 | + bs = tl.load(sp, mask=nm[None, :], other=0.0).to(tl.float32) |
| 55 | + bd = ((b.to(tl.float32) - 8.0) * bs).to(tl.bfloat16) |
| 56 | + acc += tl.dot(a.to(tl.bfloat16), bd) |
| 57 | + a_ptrs += BLOCK_SIZE_K * stride_ak |
| 58 | + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk |
| 59 | + c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn |
| 60 | + tl.store(c_ptrs, acc.to(tl.bfloat16), mask=mm[:, None] & nm[None, :]) |
| 61 | + |
| 62 | + |
| 63 | +# Strategy 2: vec-mat with tl.sum (no tl.dot, no M padding waste) |
| 64 | +@triton.jit |
| 65 | +def _int4_vecmat_kernel( |
| 66 | + A, |
| 67 | + B, |
| 68 | + C, |
| 69 | + B_scale, |
| 70 | + N: tl.constexpr, |
| 71 | + K: tl.constexpr, |
| 72 | + stride_bn, |
| 73 | + stride_bk, |
| 74 | + stride_bsn, |
| 75 | + stride_bsk, |
| 76 | + group_size: tl.constexpr, |
| 77 | + BLOCK_SIZE_N: tl.constexpr, |
| 78 | + BLOCK_SIZE_K: tl.constexpr, |
| 79 | +): |
| 80 | + nb = tl.program_id(0) |
| 81 | + offs_n = nb * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) |
| 82 | + offs_k = tl.arange(0, BLOCK_SIZE_K) |
| 83 | + nm = offs_n < N |
| 84 | + b_ptrs = B + offs_n[None, :] * stride_bn + (offs_k[:, None] // 2) * stride_bk |
| 85 | + b_shift = (offs_k[:, None] % 2) * 4 |
| 86 | + a_ptrs = A + offs_k |
| 87 | + acc = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32) |
| 88 | + for ks in range(0, tl.cdiv(K, BLOCK_SIZE_K)): |
| 89 | + kr = K - ks * BLOCK_SIZE_K |
| 90 | + km = offs_k < kr |
| 91 | + a = tl.load(a_ptrs, mask=km, other=0.0).to(tl.float32) # [BK] |
| 92 | + b = tl.load(b_ptrs, mask=km[:, None] & nm[None, :], other=0) |
| 93 | + b = (b >> b_shift) & 0xF |
| 94 | + gi = (BLOCK_SIZE_K * ks) // group_size |
| 95 | + sp = B_scale + offs_n * stride_bsn + gi * stride_bsk |
| 96 | + bs = tl.load(sp, mask=nm, other=0.0).to(tl.float32) # [BN] |
| 97 | + bd = (b.to(tl.float32) - 8.0) * bs[None, :] # [BK, BN] |
| 98 | + acc += tl.sum(a[:, None] * bd, axis=0) # [BN] |
| 99 | + a_ptrs += BLOCK_SIZE_K |
| 100 | + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk |
| 101 | + c_ptrs = C + offs_n |
| 102 | + tl.store(c_ptrs, acc.to(tl.bfloat16), mask=nm) |
| 103 | + |
| 104 | + |
| 105 | +# Strategy 3: split-K with tl.dot — more CTAs, then atomic reduce |
| 106 | +@triton.jit |
| 107 | +def _int4_splitk_kernel( |
| 108 | + A, |
| 109 | + B, |
| 110 | + C, |
| 111 | + B_scale, |
| 112 | + M, |
| 113 | + N: tl.constexpr, |
| 114 | + K: tl.constexpr, |
| 115 | + stride_am, |
| 116 | + stride_ak, |
| 117 | + stride_bn, |
| 118 | + stride_bk, |
| 119 | + stride_cm, |
| 120 | + stride_cn, |
| 121 | + stride_bsn, |
| 122 | + stride_bsk, |
| 123 | + group_size: tl.constexpr, |
| 124 | + K_SPLITS: tl.constexpr, |
| 125 | + BLOCK_SIZE_M: tl.constexpr, |
| 126 | + BLOCK_SIZE_N: tl.constexpr, |
| 127 | + BLOCK_SIZE_K: tl.constexpr, |
| 128 | +): |
| 129 | + pid = tl.program_id(0) |
| 130 | + num_n = tl.cdiv(N, BLOCK_SIZE_N) |
| 131 | + num_nk = num_n * K_SPLITS |
| 132 | + mb = pid // num_nk |
| 133 | + nk = pid % num_nk |
| 134 | + nb = nk // K_SPLITS |
| 135 | + ks_id = nk % K_SPLITS |
| 136 | + |
| 137 | + offs_m = mb * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) |
| 138 | + offs_n = nb * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) |
| 139 | + offs_k = tl.arange(0, BLOCK_SIZE_K) |
| 140 | + mm = offs_m < M |
| 141 | + nm = offs_n < N |
| 142 | + |
| 143 | + k_per_split = tl.cdiv(K, K_SPLITS) |
| 144 | + k_start = ks_id * k_per_split |
| 145 | + k_end = tl.minimum(k_start + k_per_split, K) |
| 146 | + |
| 147 | + a_ptrs = A + offs_m[:, None] * stride_am + (k_start + offs_k[None, :]) * stride_ak |
| 148 | + b_ptrs = ( |
| 149 | + B + offs_n[None, :] * stride_bn + ((k_start + offs_k[:, None]) // 2) * stride_bk |
| 150 | + ) |
| 151 | + b_shift = ((k_start + offs_k[:, None]) % 2) * 4 |
| 152 | + |
| 153 | + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) |
| 154 | + num_steps = tl.cdiv(k_end - k_start, BLOCK_SIZE_K) |
| 155 | + for step in range(0, num_steps): |
| 156 | + abs_k = k_start + step * BLOCK_SIZE_K + offs_k |
| 157 | + km = abs_k < k_end |
| 158 | + a = tl.load(a_ptrs, mask=mm[:, None] & km[None, :], other=0.0) |
| 159 | + b = tl.load(b_ptrs, mask=km[:, None] & nm[None, :], other=0) |
| 160 | + b = (b >> b_shift) & 0xF |
| 161 | + gi = (k_start + step * BLOCK_SIZE_K) // group_size |
| 162 | + sp = B_scale + offs_n[None, :] * stride_bsn + gi * stride_bsk |
| 163 | + bs = tl.load(sp, mask=nm[None, :], other=0.0).to(tl.float32) |
| 164 | + bd = ((b.to(tl.float32) - 8.0) * bs).to(tl.bfloat16) |
| 165 | + acc += tl.dot(a.to(tl.bfloat16), bd) |
| 166 | + a_ptrs += BLOCK_SIZE_K * stride_ak |
| 167 | + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk |
| 168 | + b_shift = (offs_k[:, None] % 2) * 4 # reset shift after first step |
| 169 | + |
| 170 | + c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn |
| 171 | + if K_SPLITS == 1: |
| 172 | + tl.store(c_ptrs, acc.to(tl.bfloat16), mask=mm[:, None] & nm[None, :]) |
| 173 | + else: |
| 174 | + tl.atomic_add(c_ptrs, acc.to(tl.bfloat16), mask=mm[:, None] & nm[None, :]) |
| 175 | + |
| 176 | + |
| 177 | +def main(): |
| 178 | + import torch.nn as nn |
| 179 | + from executorch.extension.llm.export.quantize import quantize_model_ |
| 180 | + from torchao.quantization.quant_primitives import ( |
| 181 | + choose_qparams_affine, |
| 182 | + MappingType, |
| 183 | + quantize_affine, |
| 184 | + ) |
| 185 | + |
| 186 | + gs = 128 |
| 187 | + shapes = [ |
| 188 | + (2048, 2048, "q/o_proj"), |
| 189 | + (12352, 2048, "shared_g+u"), |
| 190 | + (256, 2048, "k/v_proj"), |
| 191 | + ] |
| 192 | + |
| 193 | + for N, K, label in shapes: |
| 194 | + w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") |
| 195 | + sc, zp = choose_qparams_affine( |
| 196 | + w.float(), |
| 197 | + MappingType.SYMMETRIC, |
| 198 | + (1, gs), |
| 199 | + target_dtype=torch.int8, |
| 200 | + quant_min=-8, |
| 201 | + quant_max=7, |
| 202 | + ) |
| 203 | + idata = quantize_affine( |
| 204 | + w.float(), |
| 205 | + (1, gs), |
| 206 | + sc, |
| 207 | + zp, |
| 208 | + output_dtype=torch.int8, |
| 209 | + quant_min=-8, |
| 210 | + quant_max=7, |
| 211 | + ) |
| 212 | + u4 = (idata + 8).to(torch.int16) |
| 213 | + packed = (u4[:, 0::2] | (u4[:, 1::2] << 4)).to(torch.int8).cuda() |
| 214 | + w_scale = sc.reshape(N, -1).to(torch.bfloat16).cuda() |
| 215 | + |
| 216 | + linear = nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda") |
| 217 | + wr = nn.ModuleDict({"linear": linear}) |
| 218 | + quantize_model_( |
| 219 | + wr, |
| 220 | + qlinear_config="4w", |
| 221 | + qlinear_group_size=gs, |
| 222 | + qlinear_packing_format="tile_packed_to_4d", |
| 223 | + ) |
| 224 | + tw = wr.linear.weight |
| 225 | + |
| 226 | + x = torch.randn(1, K, dtype=torch.bfloat16, device="cuda") |
| 227 | + t_tiny = ( |
| 228 | + do_bench( |
| 229 | + lambda: nn.functional.linear(x, tw), |
| 230 | + warmup=50, |
| 231 | + rep=200, |
| 232 | + return_mode="median", |
| 233 | + ) |
| 234 | + * 1000 |
| 235 | + ) |
| 236 | + |
| 237 | + print(f"\n{'='*70}") |
| 238 | + print(f"[{N}x{K}] {label} — M=1, tinygemm={t_tiny:.1f}us") |
| 239 | + print(f"{'='*70}") |
| 240 | + |
| 241 | + # Strategy 1: tl.dot with various configs |
| 242 | + print("\n--- Strategy 1: tl.dot (BLOCK_M=16 padding) ---") |
| 243 | + out = torch.empty(1, N, dtype=torch.bfloat16, device="cuda") |
| 244 | + for BN, BK, warps, stages in [ |
| 245 | + (16, 128, 4, 5), |
| 246 | + (32, 128, 4, 5), |
| 247 | + (32, 256, 4, 3), |
| 248 | + (16, 128, 2, 5), |
| 249 | + (32, 128, 2, 5), |
| 250 | + ]: |
| 251 | + grid = ((N + BN - 1) // BN,) |
| 252 | + |
| 253 | + def run(_BN=BN, _BK=BK, _w=warps, _s=stages, _g=grid): |
| 254 | + _int4_dot_kernel[_g]( |
| 255 | + x, |
| 256 | + packed, |
| 257 | + out, |
| 258 | + w_scale, |
| 259 | + 1, |
| 260 | + N, |
| 261 | + K, |
| 262 | + x.stride(0), |
| 263 | + x.stride(1), |
| 264 | + packed.stride(0), |
| 265 | + packed.stride(1), |
| 266 | + out.stride(0), |
| 267 | + out.stride(1), |
| 268 | + w_scale.stride(0), |
| 269 | + w_scale.stride(1), |
| 270 | + gs, |
| 271 | + BLOCK_SIZE_M=16, |
| 272 | + BLOCK_SIZE_N=_BN, |
| 273 | + BLOCK_SIZE_K=_BK, |
| 274 | + num_warps=_w, |
| 275 | + num_stages=_s, |
| 276 | + ) |
| 277 | + |
| 278 | + try: |
| 279 | + run() |
| 280 | + t = do_bench(run, warmup=50, rep=200, return_mode="median") * 1000 |
| 281 | + print( |
| 282 | + f" BN={BN:3d} BK={BK:3d} w={warps} s={stages}: {t:6.1f}us ({t/t_tiny:.2f}x) grid={grid[0]}" |
| 283 | + ) |
| 284 | + except Exception as e: |
| 285 | + print( |
| 286 | + f" BN={BN:3d} BK={BK:3d} w={warps} s={stages}: FAIL {str(e)[:50]}" |
| 287 | + ) |
| 288 | + |
| 289 | + # Strategy 2: vec-mat with tl.sum (no padding waste) |
| 290 | + print("\n--- Strategy 2: vec-mat tl.sum (no M padding) ---") |
| 291 | + for BN, BK, warps, stages in [ |
| 292 | + (16, 128, 4, 5), |
| 293 | + (32, 128, 4, 5), |
| 294 | + (64, 128, 4, 5), |
| 295 | + (16, 256, 4, 3), |
| 296 | + (32, 256, 4, 3), |
| 297 | + (16, 128, 2, 5), |
| 298 | + (32, 128, 2, 5), |
| 299 | + (16, 64, 2, 5), |
| 300 | + (32, 64, 2, 5), |
| 301 | + ]: |
| 302 | + grid = ((N + BN - 1) // BN,) |
| 303 | + out1d = torch.empty(N, dtype=torch.bfloat16, device="cuda") |
| 304 | + |
| 305 | + def run(_BN=BN, _BK=BK, _w=warps, _s=stages, _g=grid): |
| 306 | + _int4_vecmat_kernel[_g]( |
| 307 | + x, |
| 308 | + packed, |
| 309 | + out1d, |
| 310 | + w_scale, |
| 311 | + N, |
| 312 | + K, |
| 313 | + packed.stride(0), |
| 314 | + packed.stride(1), |
| 315 | + w_scale.stride(0), |
| 316 | + w_scale.stride(1), |
| 317 | + gs, |
| 318 | + BLOCK_SIZE_N=_BN, |
| 319 | + BLOCK_SIZE_K=_BK, |
| 320 | + num_warps=_w, |
| 321 | + num_stages=_s, |
| 322 | + ) |
| 323 | + |
| 324 | + try: |
| 325 | + run() |
| 326 | + t = do_bench(run, warmup=50, rep=200, return_mode="median") * 1000 |
| 327 | + print( |
| 328 | + f" BN={BN:3d} BK={BK:3d} w={warps} s={stages}: {t:6.1f}us ({t/t_tiny:.2f}x) grid={grid[0]}" |
| 329 | + ) |
| 330 | + except Exception as e: |
| 331 | + print( |
| 332 | + f" BN={BN:3d} BK={BK:3d} w={warps} s={stages}: FAIL {str(e)[:50]}" |
| 333 | + ) |
| 334 | + |
| 335 | + # Strategy 3: split-K with tl.dot |
| 336 | + print("\n--- Strategy 3: split-K tl.dot ---") |
| 337 | + for BN, BK, splits, warps, stages in [ |
| 338 | + (32, 128, 4, 4, 3), |
| 339 | + (32, 128, 8, 4, 3), |
| 340 | + (32, 128, 16, 4, 3), |
| 341 | + (16, 128, 4, 4, 3), |
| 342 | + (16, 128, 8, 4, 3), |
| 343 | + (16, 128, 16, 4, 3), |
| 344 | + (64, 128, 4, 4, 3), |
| 345 | + (64, 128, 8, 4, 3), |
| 346 | + ]: |
| 347 | + grid = (((N + BN - 1) // BN) * splits,) |
| 348 | + out_sk = torch.zeros(1, N, dtype=torch.bfloat16, device="cuda") |
| 349 | + |
| 350 | + def run(_BN=BN, _BK=BK, _sp=splits, _w=warps, _s=stages, _g=grid): |
| 351 | + out_sk.zero_() |
| 352 | + _int4_splitk_kernel[_g]( |
| 353 | + x, |
| 354 | + packed, |
| 355 | + out_sk, |
| 356 | + w_scale, |
| 357 | + 1, |
| 358 | + N, |
| 359 | + K, |
| 360 | + x.stride(0), |
| 361 | + x.stride(1), |
| 362 | + packed.stride(0), |
| 363 | + packed.stride(1), |
| 364 | + out_sk.stride(0), |
| 365 | + out_sk.stride(1), |
| 366 | + w_scale.stride(0), |
| 367 | + w_scale.stride(1), |
| 368 | + gs, |
| 369 | + K_SPLITS=_sp, |
| 370 | + BLOCK_SIZE_M=16, |
| 371 | + BLOCK_SIZE_N=_BN, |
| 372 | + BLOCK_SIZE_K=_BK, |
| 373 | + num_warps=_w, |
| 374 | + num_stages=_s, |
| 375 | + ) |
| 376 | + |
| 377 | + try: |
| 378 | + run() |
| 379 | + t = do_bench(run, warmup=50, rep=200, return_mode="median") * 1000 |
| 380 | + print( |
| 381 | + f" BN={BN:3d} BK={BK:3d} sp={splits:2d} w={warps} s={stages}: {t:6.1f}us ({t/t_tiny:.2f}x) grid={grid[0]}" |
| 382 | + ) |
| 383 | + except Exception as e: |
| 384 | + print( |
| 385 | + f" BN={BN:3d} BK={BK:3d} sp={splits:2d} w={warps} s={stages}: FAIL {str(e)[:50]}" |
| 386 | + ) |
| 387 | + |
| 388 | + del wr, tw, packed, w_scale |
| 389 | + torch.cuda.empty_cache() |
| 390 | + |
| 391 | + |
| 392 | +if __name__ == "__main__": |
| 393 | + main() |
0 commit comments