|
| 1 | +"""Benchmark for kbit GEMM kernel. |
| 2 | +
|
| 3 | +Measures throughput (TFLOPS) and effective memory bandwidth (GB/s) for: |
| 4 | +1. kbit_gemm_prod (production kernel, fp16 and bf16) |
| 5 | +2. cuBLAS fp16 GEMM (baseline) |
| 6 | +3. Standalone dequant + cuBLAS (simulated fused baseline) |
| 7 | +""" |
| 8 | + |
| 9 | +import argparse |
| 10 | +import sys |
| 11 | +import time |
| 12 | + |
| 13 | +import torch |
| 14 | + |
| 15 | +# Ensure bitsandbytes is importable from the worktree |
| 16 | +sys.path.insert(0, ".") |
| 17 | +import bitsandbytes # noqa: E402 |
| 18 | +from bitsandbytes import _ops # noqa: E402, F401 |
| 19 | +from scipy.stats import norm # noqa: E402 |
| 20 | + |
# Number of elements that share one absmax scale in the k-bit quantization blocks.
BLOCKSIZE = 32
| 22 | + |
| 23 | + |
def create_normal_float_codebook(k: int) -> torch.Tensor:
    """Build a 2**k-entry normal-float codebook.

    Entries are inverse-CDF values of the standard normal taken at the
    midpoints of 2**k equal-probability bins, rescaled so the largest
    magnitude is 1 (suitable for block-absmax quantization).
    """
    levels = 2 ** k
    # Midpoints of the `levels` equal-probability bins in (0, 1).
    probs = torch.linspace(0.5 / levels, 1.0 - 0.5 / levels, levels)
    codebook = torch.tensor(norm.ppf(probs.numpy()), dtype=torch.float32)
    # Normalize into [-1, 1] so per-block absmax scaling applies directly.
    return codebook / codebook.abs().max()
| 30 | + |
| 31 | + |
def quantize_kbit_ref(A, codebook, blocksize=BLOCKSIZE):
    """Reference block-wise k-bit quantizer (pure PyTorch).

    Flattens *A*, zero-pads to a multiple of *blocksize*, scales each block
    by its absmax, and maps every element to the index of the nearest
    codebook entry (ties resolve to the lower index, as argmin does).

    Returns:
        (indices, absmax): uint8 nearest-entry indices for the first
        A.numel() elements, and the per-block float32 absmax scales
        (computed over the padded blocks, unclamped).
    """
    flat = A.float().reshape(-1)
    count = flat.numel()
    remainder = count % blocksize
    if remainder:
        flat = torch.nn.functional.pad(flat, (0, blocksize - remainder))
    grouped = flat.reshape(-1, blocksize)
    scales = grouped.abs().max(dim=1).values
    # Clamp only for the division so all-zero blocks don't divide by zero;
    # the returned absmax stays unclamped.
    scaled = grouped / scales.clamp(min=1e-8).unsqueeze(1)
    # Nearest-neighbor search over a broadcast (blocks, blocksize, levels) grid.
    diff = (scaled.unsqueeze(2) - codebook.float().view(1, 1, -1)).abs()
    nearest = diff.argmin(dim=2).to(torch.uint8)
    return nearest.reshape(-1)[:count], scales
| 50 | + |
| 51 | + |
def pack_kbit_ref(indices, k, blocksize=BLOCKSIZE):
    """Bit-plane pack k-bit codebook indices into int32 words.

    For every block of `blocksize` indices this emits `k` consecutive
    32-bit words; word `bit` of a block collects bit `bit` of each index,
    with element `i`'s bit stored at bit position `i` of the word
    (so blocksize must be <= 32).

    Vectorized with torch bit ops instead of the original per-element
    Python triple loop — identical output, O(1) Python-level work.

    Args:
        indices: integer tensor of codebook indices (values < 2**k).
        k: bit width of each index.
        blocksize: elements per block; input is zero-padded to a multiple.

    Returns:
        1-D int32 tensor of num_blocks * k packed words, block-major
        (block 0 bits 0..k-1, block 1 bits 0..k-1, ...).
    """
    n = indices.numel()
    pad = (blocksize - n % blocksize) % blocksize
    if pad > 0:
        indices = torch.nn.functional.pad(indices.int(), (0, pad))
    # int64 so the accumulated word (up to 2**32 - 1) never overflows.
    blocks = indices.long().reshape(-1, blocksize)
    # Bit-plane `bit` of every element: shape (num_blocks, k, blocksize).
    planes = (blocks.unsqueeze(1) >> torch.arange(k).view(1, k, 1)) & 1
    # Element i contributes its bit at position i of the packed word.
    weights = torch.ones(blocksize, dtype=torch.int64) << torch.arange(blocksize)
    words = (planes * weights).sum(dim=2)
    # Reinterpret as signed 32-bit (two's-complement wrap, as the loop did).
    words = torch.where(words >= 1 << 31, words - (1 << 32), words)
    return words.to(torch.int32).reshape(-1)
| 70 | + |
| 71 | + |
def prepare_weights(K_dim, N, k):
    """Create random fp16 weights and quantize + repack them with the CUDA kernels.

    Returns:
        (packed_tiled, absmax_tiled, codebook, W): the tiled packed weight
        tensor, the tiled absmax scales, the codebook on CUDA, and the
        original fp16 weight matrix of shape (N, K_dim).
    """
    codebook = create_normal_float_codebook(k).cuda()
    W = torch.randn(N, K_dim, dtype=torch.float16, device="cuda")
    # Quantize with the fast CUDA kernel, then repack into the tiled GEMM layout.
    packed_flat, absmax = torch.ops.bitsandbytes.quantize_kbit(W.reshape(-1), codebook, k)
    packed_tiled, absmax_tiled = torch.ops.bitsandbytes.repack_kbit(
        packed_flat, absmax.cuda(), K_dim, N, k
    )
    return packed_tiled, absmax_tiled, codebook, W
| 82 | + |
| 83 | + |
def bench_kbit_gemm(M, K_dim, N, k, k_chunks, dtype, packed_tiled, absmax_tiled, codebook,
                    warmup=10, iters=100):
    """Time the production kbit GEMM kernel; returns mean seconds per call."""
    A = torch.randn(M, K_dim, dtype=dtype, device="cuda")

    def launch():
        torch.ops.bitsandbytes.kbit_gemm_prod(A, packed_tiled, absmax_tiled, codebook, K_dim, N, k, k_chunks)

    # Warmup launches are excluded from the measured window.
    for _ in range(warmup):
        launch()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        launch()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters
| 101 | + |
| 102 | + |
def bench_cublas(M, K_dim, N, dtype, W_fp16, warmup=10, iters=100):
    """Time a dense cuBLAS GEMM (torch.mm against W^T); returns mean seconds per call."""
    A = torch.randn(M, K_dim, dtype=dtype, device="cuda")
    weight = W_fp16.to(dtype).cuda()

    def launch():
        torch.mm(A, weight.T)

    # Warmup launches are excluded from the measured window.
    for _ in range(warmup):
        launch()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        launch()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters
| 120 | + |
| 121 | + |
def main():
    """Parse CLI options, benchmark each (M, K_dim, N) config, and print a comparison table."""
    parser = argparse.ArgumentParser(description="Benchmark kbit GEMM kernel")
    parser.add_argument("--k", type=int, default=4, help="Bit width (2-5)")
    parser.add_argument("--dtype", choices=["fp16", "bf16"], default="fp16")
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--iters", type=int, default=200)
    parser.add_argument("--k-chunks", type=int, default=1, help="Split-K chunks")
    args = parser.parse_args()

    dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16
    k = args.k

    # LLM-typical shapes (M = batch/sequence rows, K_dim = in features, N = out features)
    configs = [
        # (M, K_dim, N)
        (1, 4096, 4096),
        (1, 4096, 11008),
        (4, 4096, 4096),
        (4, 4096, 11008),
        (8, 4096, 4096),
        (16, 4096, 4096),
        (32, 4096, 4096),
        (64, 4096, 4096),
        (128, 4096, 4096),
    ]

    print(f"kbit GEMM Benchmark: K={k}, dtype={args.dtype}, k_chunks={args.k_chunks}")
    print(f"Warmup={args.warmup}, Iters={args.iters}")
    print()
    print(f"{'M':>5} {'K_dim':>6} {'N':>6} | {'kbit (us)':>10} {'kbit TFLOPS':>12} {'kbit GB/s':>10} | "
          f"{'cuBLAS (us)':>12} {'cuBLAS TFLOPS':>14} | {'Speedup':>8}")
    print("-" * 115)

    for M, K_dim, N in configs:
        # Pad N to multiple of 128 if needed
        N_padded = ((N + 127) // 128) * 128

        # Prepare weights (quantized + repacked on GPU; W is the fp16 original)
        packed_tiled, absmax_tiled, codebook, W = prepare_weights(K_dim, N_padded, k)

        # Benchmark kbit GEMM
        t_kbit = bench_kbit_gemm(M, K_dim, N_padded, k, args.k_chunks, dtype,
                                 packed_tiled, absmax_tiled, codebook,
                                 warmup=args.warmup, iters=args.iters)

        # Benchmark cuBLAS
        t_cublas = bench_cublas(M, K_dim, N_padded, dtype, W.half(),
                                warmup=args.warmup, iters=args.iters)

        # Compute metrics (2*M*K*N FLOPs for a GEMM)
        flops = 2 * M * K_dim * N_padded
        tflops_kbit = flops / t_kbit / 1e12
        tflops_cublas = flops / t_cublas / 1e12

        # Effective bandwidth for kbit: A (fp16) + B (compressed) + C (fp16)
        a_bytes = M * K_dim * 2
        # packed + absmax; NOTE(review): counts 1 byte per absmax entry
        # (one per 32-element block) — TODO confirm absmax storage dtype.
        b_bytes = N_padded * K_dim * k / 8 + N_padded * (K_dim // 32)
        c_bytes = M * N_padded * 2
        total_bytes = a_bytes + b_bytes + c_bytes
        gbps_kbit = total_bytes / t_kbit / 1e9

        # Speedup > 1.0 means the kbit kernel is faster than cuBLAS.
        speedup = t_cublas / t_kbit

        print(f"{M:5d} {K_dim:6d} {N_padded:6d} | {t_kbit*1e6:10.1f} {tflops_kbit:12.3f} {gbps_kbit:10.1f} | "
              f"{t_cublas*1e6:12.1f} {tflops_cublas:14.3f} | {speedup:8.2f}x")

    print()
| 189 | + |
| 190 | + |
# Script entry point.
if __name__ == "__main__":
    main()