|
| 1 | +import argparse |
| 2 | +import logging |
| 3 | +import time |
| 4 | + |
| 5 | +import torch |
| 6 | + |
| 7 | +import tilelang |
| 8 | +import tilelang.language as T |
| 9 | + |
| 10 | +logging.getLogger("tilelang").setLevel(logging.WARNING) |
| 11 | + |
| 12 | +BLOCK_CONFIGS = [ |
| 13 | + (16, 16, 16), |
| 14 | + (32, 32, 16), |
| 15 | + (32, 32, 32), |
| 16 | + (64, 64, 32), |
| 17 | +] |
| 18 | + |
| 19 | + |
| 20 | +@tilelang.jit |
| 21 | +def matmul_simdgroup(M, N, K, block_M=64, block_N=64, block_K=32, dtype=T.float16, accum_dtype=T.float32): |
| 22 | + |
| 23 | + @T.prim_func |
| 24 | + def gemm_kernel( |
| 25 | + A: T.Tensor((M, K), dtype), |
| 26 | + B: T.Tensor((K, N), dtype), |
| 27 | + C: T.Tensor((M, N), accum_dtype), |
| 28 | + ): |
| 29 | + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): |
| 30 | + A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared") |
| 31 | + B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared") |
| 32 | + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) |
| 33 | + T.clear(C_local) |
| 34 | + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): |
| 35 | + T.copy(A[by * block_M, ko * block_K], A_shared) |
| 36 | + T.copy(B[ko * block_K, bx * block_N], B_shared) |
| 37 | + T.gemm(A_shared, B_shared, C_local) |
| 38 | + T.copy(C_local, C[by * block_M, bx * block_N]) |
| 39 | + |
| 40 | + return gemm_kernel |
| 41 | + |
| 42 | + |
| 43 | +def _tflops(M, N, K, seconds): |
| 44 | + return 2.0 * M * N * K / seconds / 1e12 |
| 45 | + |
| 46 | + |
| 47 | +def _bench(fn, warmup, repeats): |
| 48 | + for _ in range(warmup): |
| 49 | + fn() |
| 50 | + torch.mps.synchronize() |
| 51 | + t0 = time.perf_counter() |
| 52 | + for _ in range(repeats): |
| 53 | + fn() |
| 54 | + torch.mps.synchronize() |
| 55 | + return (time.perf_counter() - t0) / repeats |
| 56 | + |
| 57 | + |
| 58 | +def bench_torch_mps(M, N, K, warmup, repeats): |
| 59 | + a = torch.randn(M, K, dtype=torch.float16, device="mps") |
| 60 | + b = torch.randn(K, N, dtype=torch.float16, device="mps") |
| 61 | + avg_s = _bench(lambda: torch.mm(a, b), warmup, repeats) |
| 62 | + return _tflops(M, N, K, avg_s) |
| 63 | + |
| 64 | + |
| 65 | +def bench_tilelang(M, N, K, block_M, block_N, block_K, warmup, repeats): |
| 66 | + kernel = matmul_simdgroup(M, N, K, block_M, block_N, block_K) |
| 67 | + a = torch.randn(M, K, dtype=torch.float16, device="mps") |
| 68 | + b = torch.randn(K, N, dtype=torch.float16, device="mps") |
| 69 | + c = torch.zeros(M, N, dtype=torch.float32, device="mps") |
| 70 | + avg_s = _bench(lambda: kernel(a, b, c), warmup, repeats) |
| 71 | + return _tflops(M, N, K, avg_s) |
| 72 | + |
| 73 | + |
| 74 | +if __name__ == "__main__": |
| 75 | + parser = argparse.ArgumentParser(description="Metal GEMM Benchmark (simdgroup)") |
| 76 | + parser.add_argument("--m", type=int, default=4096) |
| 77 | + parser.add_argument("--n", type=int, default=4096) |
| 78 | + parser.add_argument("--k", type=int, default=4096) |
| 79 | + parser.add_argument("--warmup", type=int, default=10) |
| 80 | + parser.add_argument("--repeats", type=int, default=100) |
| 81 | + parser.add_argument("--sweep", action="store_true", help="Sweep all block configs instead of using default (64,64,32)") |
| 82 | + args = parser.parse_args() |
| 83 | + |
| 84 | + M, N, K = args.m, args.n, args.k |
| 85 | + |
| 86 | + print(f"torch: {torch.__version__}") |
| 87 | + print(f"tilelang: {tilelang.__version__}") |
| 88 | + print(f"MPS: {torch.backends.mps.is_available()}") |
| 89 | + print(f"M={M}, N={N}, K={K}, warmup={args.warmup}, repeats={args.repeats}") |
| 90 | + print() |
| 91 | + |
| 92 | + ref_tflops = bench_torch_mps(M, N, K, args.warmup, args.repeats) |
| 93 | + print(f"PyTorch MPS (torch.mm fp16): {ref_tflops:.1f} TFLOPS") |
| 94 | + print() |
| 95 | + |
| 96 | + configs = BLOCK_CONFIGS if args.sweep else [(64, 64, 32)] |
| 97 | + |
| 98 | + print(f"{'block (M,N,K)':>16s} | {'TileLang':>14s} | {'Ratio':>6s}") |
| 99 | + print("-" * 44) |
| 100 | + |
| 101 | + best_tflops = 0.0 |
| 102 | + best_config = configs[0] |
| 103 | + for bM, bN, bK in configs: |
| 104 | + try: |
| 105 | + tl = bench_tilelang(M, N, K, bM, bN, bK, args.warmup, args.repeats) |
| 106 | + ratio = tl / ref_tflops * 100 |
| 107 | + tag = "" |
| 108 | + if tl > best_tflops: |
| 109 | + best_tflops = tl |
| 110 | + best_config = (bM, bN, bK) |
| 111 | + print(f"{f'({bM},{bN},{bK})':>16s} | {tl:>10.1f} TFLOPS | {ratio:>5.0f}%") |
| 112 | + except Exception as e: |
| 113 | + print(f"{f'({bM},{bN},{bK})':>16s} | {'FAILED':>14s} | {e}") |
| 114 | + |
| 115 | + if args.sweep: |
| 116 | + print() |
| 117 | + print(f"Best config: {best_config}") |
| 118 | + print(f"Best TFlops: {best_tflops:.1f}") |
| 119 | + print(f"Reference TFlops (PyTorch MPS): {ref_tflops:.1f}") |
0 commit comments