diff --git a/backends/cuda/benchmarks/benchmark_int4_matmul.py b/backends/cuda/benchmarks/benchmark_int4_matmul.py deleted file mode 100644 index 875e9a3676e..00000000000 --- a/backends/cuda/benchmarks/benchmark_int4_matmul.py +++ /dev/null @@ -1,393 +0,0 @@ -#!/usr/bin/env python3 -"""Benchmark INT4 matmul strategies for M=1 decode.""" - -import torch -import triton -import triton.language as tl -from triton.testing import do_bench - - -# Strategy 1: tl.dot with BLOCK_M=16 padding (current approach) -@triton.jit -def _int4_dot_kernel( - A, - B, - C, - B_scale, - M, - N: tl.constexpr, - K: tl.constexpr, - stride_am, - stride_ak, - stride_bn, - stride_bk, - stride_cm, - stride_cn, - stride_bsn, - stride_bsk, - group_size: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid = tl.program_id(0) - num_n = tl.cdiv(N, BLOCK_SIZE_N) - mb = pid // num_n - nb = pid % num_n - offs_m = mb * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = nb * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - mm = offs_m < M - nm = offs_n < N - a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak - b_ptrs = B + offs_n[None, :] * stride_bn + (offs_k[:, None] // 2) * stride_bk - b_shift = (offs_k[:, None] % 2) * 4 - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for ks in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - kr = K - ks * BLOCK_SIZE_K - km = offs_k < kr - a = tl.load(a_ptrs, mask=mm[:, None] & km[None, :], other=0.0) - b = tl.load(b_ptrs, mask=km[:, None] & nm[None, :], other=0) - b = (b >> b_shift) & 0xF - gi = (BLOCK_SIZE_K * ks) // group_size - sp = B_scale + offs_n[None, :] * stride_bsn + gi * stride_bsk - bs = tl.load(sp, mask=nm[None, :], other=0.0).to(tl.float32) - bd = ((b.to(tl.float32) - 8.0) * bs).to(tl.bfloat16) - acc += tl.dot(a.to(tl.bfloat16), bd) - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk - c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn - tl.store(c_ptrs, acc.to(tl.bfloat16), mask=mm[:, None] & nm[None, :]) - - -# Strategy 2: vec-mat with tl.sum (no tl.dot, no M padding waste) -@triton.jit -def _int4_vecmat_kernel( - A, - B, - C, - B_scale, - N: tl.constexpr, - K: tl.constexpr, - stride_bn, - stride_bk, - stride_bsn, - stride_bsk, - group_size: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - nb = tl.program_id(0) - offs_n = nb * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - nm = offs_n < N - b_ptrs = B + offs_n[None, :] * stride_bn + (offs_k[:, None] // 2) * stride_bk - b_shift = (offs_k[:, None] % 2) * 4 - a_ptrs = A + offs_k - acc = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32) - for ks in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - kr = K - ks * BLOCK_SIZE_K - km = offs_k < kr - a = tl.load(a_ptrs, mask=km, other=0.0).to(tl.float32) # [BK] - b = tl.load(b_ptrs, mask=km[:, None] & nm[None, :], other=0) - b = (b >> b_shift) & 0xF - gi = (BLOCK_SIZE_K * ks) // group_size - sp = B_scale + offs_n * stride_bsn + gi * stride_bsk - bs = tl.load(sp, mask=nm, other=0.0).to(tl.float32) # [BN] - bd = (b.to(tl.float32) - 8.0) * bs[None, :] # [BK, BN] - acc += tl.sum(a[:, None] * bd, axis=0) # [BN] - a_ptrs += BLOCK_SIZE_K - b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk - c_ptrs = C + offs_n - tl.store(c_ptrs, acc.to(tl.bfloat16), mask=nm) - - -# Strategy 3: split-K with tl.dot — more CTAs, then atomic reduce -@triton.jit -def _int4_splitk_kernel( - A, - B, - C, - 
B_scale, - M, - N: tl.constexpr, - K: tl.constexpr, - stride_am, - stride_ak, - stride_bn, - stride_bk, - stride_cm, - stride_cn, - stride_bsn, - stride_bsk, - group_size: tl.constexpr, - K_SPLITS: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid = tl.program_id(0) - num_n = tl.cdiv(N, BLOCK_SIZE_N) - num_nk = num_n * K_SPLITS - mb = pid // num_nk - nk = pid % num_nk - nb = nk // K_SPLITS - ks_id = nk % K_SPLITS - - offs_m = mb * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = nb * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - mm = offs_m < M - nm = offs_n < N - - k_per_split = tl.cdiv(K, K_SPLITS) - k_start = ks_id * k_per_split - k_end = tl.minimum(k_start + k_per_split, K) - - a_ptrs = A + offs_m[:, None] * stride_am + (k_start + offs_k[None, :]) * stride_ak - b_ptrs = ( - B + offs_n[None, :] * stride_bn + ((k_start + offs_k[:, None]) // 2) * stride_bk - ) - b_shift = ((k_start + offs_k[:, None]) % 2) * 4 - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - num_steps = tl.cdiv(k_end - k_start, BLOCK_SIZE_K) - for step in range(0, num_steps): - abs_k = k_start + step * BLOCK_SIZE_K + offs_k - km = abs_k < k_end - a = tl.load(a_ptrs, mask=mm[:, None] & km[None, :], other=0.0) - b = tl.load(b_ptrs, mask=km[:, None] & nm[None, :], other=0) - b = (b >> b_shift) & 0xF - gi = (k_start + step * BLOCK_SIZE_K) // group_size - sp = B_scale + offs_n[None, :] * stride_bsn + gi * stride_bsk - bs = tl.load(sp, mask=nm[None, :], other=0.0).to(tl.float32) - bd = ((b.to(tl.float32) - 8.0) * bs).to(tl.bfloat16) - acc += tl.dot(a.to(tl.bfloat16), bd) - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk - b_shift = (offs_k[:, None] % 2) * 4 # reset shift after first step - - c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn - if K_SPLITS == 1: - tl.store(c_ptrs, acc.to(tl.bfloat16), mask=mm[:, None] & nm[None, :]) - else: - tl.atomic_add(c_ptrs, acc.to(tl.bfloat16), mask=mm[:, None] & nm[None, :]) - - -def main(): - import torch.nn as nn - from executorch.extension.llm.export.quantize import quantize_model_ - from torchao.quantization.quant_primitives import ( - choose_qparams_affine, - MappingType, - quantize_affine, - ) - - gs = 128 - shapes = [ - (2048, 2048, "q/o_proj"), - (12352, 2048, "shared_g+u"), - (256, 2048, "k/v_proj"), - ] - - for N, K, label in shapes: - w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") - sc, zp = choose_qparams_affine( - w.float(), - MappingType.SYMMETRIC, - (1, gs), - target_dtype=torch.int8, - quant_min=-8, - quant_max=7, - ) - idata = quantize_affine( - w.float(), - (1, gs), - sc, - zp, - output_dtype=torch.int8, - quant_min=-8, - quant_max=7, - ) - u4 = (idata + 8).to(torch.int16) - packed = (u4[:, 0::2] | (u4[:, 1::2] << 4)).to(torch.int8).cuda() - w_scale = sc.reshape(N, -1).to(torch.bfloat16).cuda() - - linear = nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda") - wr = nn.ModuleDict({"linear": linear}) - quantize_model_( - wr, - qlinear_config="4w", - qlinear_group_size=gs, - qlinear_packing_format="tile_packed_to_4d", - ) - tw = wr.linear.weight - - x = torch.randn(1, K, dtype=torch.bfloat16, device="cuda") - t_tiny = ( - do_bench( - lambda: nn.functional.linear(x, tw), - warmup=50, - rep=200, - return_mode="median", - ) - * 1000 - ) - - print(f"\n{'='*70}") - print(f"[{N}x{K}] {label} — M=1, tinygemm={t_tiny:.1f}us") - print(f"{'='*70}") - - # Strategy 1: tl.dot with various 
configs - print("\n--- Strategy 1: tl.dot (BLOCK_M=16 padding) ---") - out = torch.empty(1, N, dtype=torch.bfloat16, device="cuda") - for BN, BK, warps, stages in [ - (16, 128, 4, 5), - (32, 128, 4, 5), - (32, 256, 4, 3), - (16, 128, 2, 5), - (32, 128, 2, 5), - ]: - grid = ((N + BN - 1) // BN,) - - def run(_BN=BN, _BK=BK, _w=warps, _s=stages, _g=grid): - _int4_dot_kernel[_g]( - x, - packed, - out, - w_scale, - 1, - N, - K, - x.stride(0), - x.stride(1), - packed.stride(0), - packed.stride(1), - out.stride(0), - out.stride(1), - w_scale.stride(0), - w_scale.stride(1), - gs, - BLOCK_SIZE_M=16, - BLOCK_SIZE_N=_BN, - BLOCK_SIZE_K=_BK, - num_warps=_w, - num_stages=_s, - ) - - try: - run() - t = do_bench(run, warmup=50, rep=200, return_mode="median") * 1000 - print( - f" BN={BN:3d} BK={BK:3d} w={warps} s={stages}: {t:6.1f}us ({t/t_tiny:.2f}x) grid={grid[0]}" - ) - except Exception as e: - print( - f" BN={BN:3d} BK={BK:3d} w={warps} s={stages}: FAIL {str(e)[:50]}" - ) - - # Strategy 2: vec-mat with tl.sum (no padding waste) - print("\n--- Strategy 2: vec-mat tl.sum (no M padding) ---") - for BN, BK, warps, stages in [ - (16, 128, 4, 5), - (32, 128, 4, 5), - (64, 128, 4, 5), - (16, 256, 4, 3), - (32, 256, 4, 3), - (16, 128, 2, 5), - (32, 128, 2, 5), - (16, 64, 2, 5), - (32, 64, 2, 5), - ]: - grid = ((N + BN - 1) // BN,) - out1d = torch.empty(N, dtype=torch.bfloat16, device="cuda") - - def run(_BN=BN, _BK=BK, _w=warps, _s=stages, _g=grid): - _int4_vecmat_kernel[_g]( - x, - packed, - out1d, - w_scale, - N, - K, - packed.stride(0), - packed.stride(1), - w_scale.stride(0), - w_scale.stride(1), - gs, - BLOCK_SIZE_N=_BN, - BLOCK_SIZE_K=_BK, - num_warps=_w, - num_stages=_s, - ) - - try: - run() - t = do_bench(run, warmup=50, rep=200, return_mode="median") * 1000 - print( - f" BN={BN:3d} BK={BK:3d} w={warps} s={stages}: {t:6.1f}us ({t/t_tiny:.2f}x) grid={grid[0]}" - ) - except Exception as e: - print( - f" BN={BN:3d} BK={BK:3d} w={warps} s={stages}: FAIL {str(e)[:50]}" - ) - - # Strategy 3: split-K with tl.dot - print("\n--- Strategy 3: split-K tl.dot ---") - for BN, BK, splits, warps, stages in [ - (32, 128, 4, 4, 3), - (32, 128, 8, 4, 3), - (32, 128, 16, 4, 3), - (16, 128, 4, 4, 3), - (16, 128, 8, 4, 3), - (16, 128, 16, 4, 3), - (64, 128, 4, 4, 3), - (64, 128, 8, 4, 3), - ]: - grid = (((N + BN - 1) // BN) * splits,) - out_sk = torch.zeros(1, N, dtype=torch.bfloat16, device="cuda") - - def run(_BN=BN, _BK=BK, _sp=splits, _w=warps, _s=stages, _g=grid): - out_sk.zero_() - _int4_splitk_kernel[_g]( - x, - packed, - out_sk, - w_scale, - 1, - N, - K, - x.stride(0), - x.stride(1), - packed.stride(0), - packed.stride(1), - out_sk.stride(0), - out_sk.stride(1), - w_scale.stride(0), - w_scale.stride(1), - gs, - K_SPLITS=_sp, - BLOCK_SIZE_M=16, - BLOCK_SIZE_N=_BN, - BLOCK_SIZE_K=_BK, - num_warps=_w, - num_stages=_s, - ) - - try: - run() - t = do_bench(run, warmup=50, rep=200, return_mode="median") * 1000 - print( - f" BN={BN:3d} BK={BK:3d} sp={splits:2d} w={warps} s={stages}: {t:6.1f}us ({t/t_tiny:.2f}x) grid={grid[0]}" - ) - except Exception as e: - print( - f" BN={BN:3d} BK={BK:3d} sp={splits:2d} w={warps} s={stages}: FAIL {str(e)[:50]}" - ) - - del wr, tw, packed, w_scale - torch.cuda.empty_cache() - - -if __name__ == "__main__": - main() diff --git a/backends/cuda/benchmarks/benchmark_matvec.py b/backends/cuda/benchmarks/benchmark_matvec.py deleted file mode 100644 index 86a94eb2a97..00000000000 --- a/backends/cuda/benchmarks/benchmark_matvec.py +++ /dev/null @@ -1,352 +0,0 @@ -#!/usr/bin/env python3 
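-# All kernels below share one packing convention (a sketch of the scheme
-# used by main() when it packs the weights, not a library-defined format):
-# two unsigned 4-bit values per int8 byte, even k in the low nibble, odd k
-# in the high nibble, zero-point 8. Worked example: the int4 pair (-3, 5)
-# maps to u4 (5, 13), packs to 5 | (13 << 4) = 0xD5, and unpacks as
-# (((byte >> ((k % 2) * 4)) & 0xF) - 8) * scale.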
-"""Benchmark dedicated INT4 matvec kernels for M=1 decode.""" - -import torch -import triton -import triton.language as tl -from triton.testing import do_bench - - -@triton.jit -def _int4_matvec_v1( - X, - W, - Out, - W_scale, - N: tl.constexpr, - K: tl.constexpr, - stride_wn, - stride_wk, - stride_sn, - stride_sk, - group_size: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - """V1: Each CTA computes BLOCK_N outputs, loops over K.""" - nb = tl.program_id(0) - offs_n = nb * BLOCK_N + tl.arange(0, BLOCK_N) - nm = offs_n < N - offs_k = tl.arange(0, BLOCK_K) - - acc = tl.zeros((BLOCK_N,), dtype=tl.float32) - - for ks in range(tl.cdiv(K, BLOCK_K)): - km = offs_k < (K - ks * BLOCK_K) - x_val = tl.load(X + ks * BLOCK_K + offs_k, mask=km, other=0.0).to(tl.float32) - w_ptrs = ( - W - + offs_n[:, None] * stride_wn - + ((ks * BLOCK_K + offs_k[None, :]) // 2) * stride_wk - ) - w_shift = ((ks * BLOCK_K + offs_k[None, :]) % 2) * 4 - w_raw = tl.load(w_ptrs, mask=nm[:, None] & km[None, :], other=0) - w_uint4 = (w_raw >> w_shift) & 0xF - - gi = (ks * BLOCK_K) // group_size - s_ptrs = W_scale + offs_n * stride_sn + gi * stride_sk - scale = tl.load(s_ptrs, mask=nm, other=0.0).to(tl.float32) - - w_dq = (w_uint4.to(tl.float32) - 8.0) * scale[:, None] # [BN, BK] - acc += tl.sum(w_dq * x_val[None, :], axis=1) # [BN] - - offs_k += BLOCK_K - - tl.store(Out + offs_n, acc.to(tl.bfloat16), mask=nm) - - -@triton.jit -def _int4_matvec_v2( - X, - W, - Out, - W_scale, - N: tl.constexpr, - K: tl.constexpr, - stride_wn, - stride_wk, - stride_sn, - stride_sk, - group_size: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - """V2: Transposed load — iterate N-first, K-second for better coalescing on W.""" - nb = tl.program_id(0) - offs_n = nb * BLOCK_N + tl.arange(0, BLOCK_N) - nm = offs_n < N - offs_k = tl.arange(0, BLOCK_K) - - acc = tl.zeros((BLOCK_N,), dtype=tl.float32) - - x_base = X - w_base = W + offs_n[:, None] * stride_wn - s_base = W_scale + offs_n * stride_sn - - for ks in range(tl.cdiv(K, BLOCK_K)): - abs_k = ks * BLOCK_K + offs_k - km = abs_k < K - - x_val = tl.load(x_base + abs_k, mask=km, other=0.0).to(tl.float32) - - w_ptrs = w_base + (abs_k[None, :] // 2) * stride_wk - w_shift = (abs_k[None, :] % 2) * 4 - w_raw = tl.load(w_ptrs, mask=nm[:, None] & km[None, :], other=0) - w_uint4 = (w_raw >> w_shift) & 0xF - - gi = (ks * BLOCK_K) // group_size - scale = tl.load(s_base + gi * stride_sk, mask=nm, other=0.0).to(tl.float32) - - w_dq = (w_uint4.to(tl.float32) - 8.0) * scale[:, None] - acc += tl.sum(w_dq * x_val[None, :], axis=1) - - tl.store(Out + offs_n, acc.to(tl.bfloat16), mask=nm) - - -@triton.jit -def _int4_matvec_splitk( - X, - W, - Out, - W_scale, - N: tl.constexpr, - K: tl.constexpr, - stride_wn, - stride_wk, - stride_sn, - stride_sk, - group_size: tl.constexpr, - K_SPLITS: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - """V3: Split-K matvec — more CTAs, atomic accumulate.""" - pid = tl.program_id(0) - num_n = tl.cdiv(N, BLOCK_N) - nb = pid // K_SPLITS - kid = pid % K_SPLITS - - offs_n = nb * BLOCK_N + tl.arange(0, BLOCK_N) - nm = offs_n < N - - k_per_split = tl.cdiv(K, K_SPLITS) - k_start = kid * k_per_split - k_end = tl.minimum(k_start + k_per_split, K) - - offs_k = tl.arange(0, BLOCK_K) - acc = tl.zeros((BLOCK_N,), dtype=tl.float32) - - num_steps = tl.cdiv(k_end - k_start, BLOCK_K) - for step in range(num_steps): - abs_k = k_start + step * BLOCK_K + offs_k - km = abs_k < k_end - - x_val = tl.load(X + abs_k, mask=km, 
other=0.0).to(tl.float32) - - w_ptrs = W + offs_n[:, None] * stride_wn + (abs_k[None, :] // 2) * stride_wk - w_shift = (abs_k[None, :] % 2) * 4 - w_raw = tl.load(w_ptrs, mask=nm[:, None] & km[None, :], other=0) - w_uint4 = (w_raw >> w_shift) & 0xF - - gi = (k_start + step * BLOCK_K) // group_size - scale = tl.load( - W_scale + offs_n * stride_sn + gi * stride_sk, mask=nm, other=0.0 - ).to(tl.float32) - - w_dq = (w_uint4.to(tl.float32) - 8.0) * scale[:, None] - acc += tl.sum(w_dq * x_val[None, :], axis=1) - - if K_SPLITS == 1: - tl.store(Out + offs_n, acc.to(tl.bfloat16), mask=nm) - else: - tl.atomic_add(Out + offs_n, acc.to(tl.bfloat16), mask=nm) - - -def main(): - import executorch.backends.cuda.triton.kernels # noqa: F401 — register ops - import torch.nn as nn - from executorch.extension.llm.export.quantize import quantize_model_ - from torchao.quantization.quant_primitives import ( - choose_qparams_affine, - MappingType, - quantize_affine, - ) - - gs = 128 - shapes = [ - (2048, 2048, "q/o_proj"), - (12352, 2048, "shared_g+u"), - (256, 2048, "k/v_proj"), - ] - - for N, K, label in shapes: - w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") - sc, zp = choose_qparams_affine( - w.float(), - MappingType.SYMMETRIC, - (1, gs), - target_dtype=torch.int8, - quant_min=-8, - quant_max=7, - ) - idata = quantize_affine( - w.float(), - (1, gs), - sc, - zp, - output_dtype=torch.int8, - quant_min=-8, - quant_max=7, - ) - u4 = (idata + 8).to(torch.int16) - packed = (u4[:, 0::2] | (u4[:, 1::2] << 4)).to(torch.int8).cuda() - w_scale = sc.reshape(N, -1).to(torch.bfloat16).cuda() - - linear = nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda") - wr = nn.ModuleDict({"linear": linear}) - quantize_model_( - wr, - qlinear_config="4w", - qlinear_group_size=gs, - qlinear_packing_format="tile_packed_to_4d", - ) - tw = wr.linear.weight - - x = torch.randn(1, K, dtype=torch.bfloat16, device="cuda") - x_flat = x.squeeze(0) - - t_tiny = ( - do_bench( - lambda: nn.functional.linear(x, tw), - warmup=50, - rep=200, - return_mode="median", - ) - * 1000 - ) - - t_i4mm = ( - do_bench( - lambda: torch.ops.triton.int4_matmul(x, packed, w_scale, gs), - warmup=50, - rep=200, - return_mode="median", - ) - * 1000 - ) - - print(f"\n{'='*70}") - print( - f"[{N}x{K}] {label} — tinygemm={t_tiny:.1f}us, int4_matmul={t_i4mm:.1f}us" - ) - print(f"{'='*70}") - - out = torch.empty(N, dtype=torch.bfloat16, device="cuda") - best_t, best_cfg = float("inf"), "" - - # V1: basic matvec - print("\n--- V1: basic matvec ---") - for BN, BK, warps, stages in [ - (16, 128, 4, 3), - (16, 256, 4, 3), - (32, 128, 4, 3), - (32, 256, 4, 3), - (8, 128, 2, 3), - (8, 256, 2, 3), - (16, 128, 2, 3), - (4, 128, 2, 3), - (4, 256, 2, 3), - ]: - grid = ((N + BN - 1) // BN,) - - def run(_BN=BN, _BK=BK, _w=warps, _s=stages, _g=grid): - _int4_matvec_v1[_g]( - x_flat, - packed, - out, - w_scale, - N, - K, - packed.stride(0), - packed.stride(1), - w_scale.stride(0), - w_scale.stride(1), - gs, - BLOCK_N=_BN, - BLOCK_K=_BK, - num_warps=_w, - num_stages=_s, - ) - - try: - run() - t = do_bench(run, warmup=50, rep=200, return_mode="median") * 1000 - tag = " <<<" if t < best_t else "" - if t < best_t: - best_t, best_cfg = t, f"v1 BN={BN} BK={BK} w={warps}" - print( - f" BN={BN:2d} BK={BK:3d} w={warps}: {t:6.1f}us ({t/t_tiny:.2f}x) grid={grid[0]:5d}{tag}" - ) - except Exception as e: - print(f" BN={BN:2d} BK={BK:3d} w={warps}: FAIL {str(e)[:50]}") - - # V3: split-K matvec - print("\n--- V3: split-K matvec ---") - for BN, BK, splits, warps, stages in [ - 
(16, 128, 4, 4, 3), - (16, 128, 8, 4, 3), - (8, 128, 4, 2, 3), - (8, 128, 8, 2, 3), - (8, 128, 16, 2, 3), - (4, 128, 4, 2, 3), - (4, 128, 8, 2, 3), - (4, 128, 16, 2, 3), - (16, 64, 8, 4, 3), - (8, 64, 16, 2, 3), - ]: - grid = (((N + BN - 1) // BN) * splits,) - out_sk = torch.zeros(N, dtype=torch.bfloat16, device="cuda") - - def run(_BN=BN, _BK=BK, _sp=splits, _w=warps, _s=stages, _g=grid): - out_sk.zero_() - _int4_matvec_splitk[_g]( - x_flat, - packed, - out_sk, - w_scale, - N, - K, - packed.stride(0), - packed.stride(1), - w_scale.stride(0), - w_scale.stride(1), - gs, - K_SPLITS=_sp, - BLOCK_N=_BN, - BLOCK_K=_BK, - num_warps=_w, - num_stages=_s, - ) - - try: - run() - t = do_bench(run, warmup=50, rep=200, return_mode="median") * 1000 - tag = " <<<" if t < best_t else "" - if t < best_t: - best_t, best_cfg = t, f"v3 BN={BN} BK={BK} sp={splits} w={warps}" - print( - f" BN={BN:2d} BK={BK:3d} sp={splits:2d} w={warps}: {t:6.1f}us ({t/t_tiny:.2f}x) grid={grid[0]:5d}{tag}" - ) - except Exception as e: - print( - f" BN={BN:2d} BK={BK:3d} sp={splits:2d} w={warps}: FAIL {str(e)[:50]}" - ) - - print(f"\nBest: {best_t:.1f}us ({best_t/t_tiny:.2f}x tinygemm) — {best_cfg}") - - del wr, tw, packed, w_scale - torch.cuda.empty_cache() - - -if __name__ == "__main__": - main() diff --git a/backends/cuda/benchmarks/benchmark_moe.py b/backends/cuda/benchmarks/benchmark_moe.py deleted file mode 100644 index e64386dd50c..00000000000 --- a/backends/cuda/benchmarks/benchmark_moe.py +++ /dev/null @@ -1,460 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Benchmark the Triton fused MoE kernel against eager and torch.compile baselines. - -Measures latency across prompt lengths matching the Qwen3.5-35B-A3B model -(hidden_size=2048, num_experts=256, top_k=8, intermediate_size=512, -INT4 weight-only quantization with group_size=128). - -Usage: - python benchmark_moe.py - python benchmark_moe.py --prompt-lengths 1,8,64,512 --num_iters 200 -""" - -import argparse -from functools import partial - -import executorch.backends.cuda.triton.kernels # noqa: F401 — registers triton ops - -import torch -from triton.testing import do_bench - - -# -- Qwen3.5-35B-A3B defaults ------------------------------------------------ - -DEFAULTS = { - "num_experts": 256, - "top_k": 8, - "hidden_size": 2048, - "intermediate_size": 512, - "group_size": 128, -} - -PROMPT_LENGTHS = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4095] - - -# -- Weight / input generation ----------------------------------------------- - - -def _make_int4_weights(E, N, K, group_size, device="cuda"): - """Generate random packed INT4 weights and per-group scales. - - Returns: - w: [E, N, K//2] int8 — two INT4 values packed per byte - scale: [E, N, K//group_size] bf16 - """ - vals = torch.randint(0, 16, (E, N, K), dtype=torch.uint8, device=device) - low = vals[:, :, 0::2] - high = vals[:, :, 1::2] - packed = (high << 4) | low - w = packed.to(torch.int8) - - scale = ( - torch.randn(E, N, K // group_size, device=device, dtype=torch.bfloat16) * 0.01 - ) - return w, scale - - -# -- Dequantization ---------------------------------------------------------- - - -def _dequant_int4(w_packed, scale, group_size): - """Unpack INT4 weights and dequantize. 
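-    Inverse of the packing in _make_int4_weights: even k sits in the low
-    nibble, odd k in the high nibble, and dequant is (u4 - 8.0) * scale.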
- - w_packed: [E, N, K//2] int8 - scale: [E, N, K//group_size] bf16 - Returns: [E, N, K] bf16 - """ - w_uint8 = w_packed.to(torch.uint8) - low = (w_uint8 & 0xF).to(torch.float32) - high = ((w_uint8 >> 4) & 0xF).to(torch.float32) - E, N, Khalf = w_packed.shape - K = Khalf * 2 - vals = torch.empty(E, N, K, device=w_packed.device, dtype=torch.float32) - vals[:, :, 0::2] = low - vals[:, :, 1::2] = high - vals = vals - 8.0 - scale_expanded = scale.float().repeat_interleave(group_size, dim=2)[:, :, :K] - return (vals * scale_expanded).to(torch.bfloat16) - - -# -- Backends ----------------------------------------------------------------- - - -def _run_eager( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k, - num_experts, - group_size, -): - """Loop-based eager MoE — correctness reference only (not benchmarked).""" - M, K = hidden_states.shape - inter = w2.shape[2] * 2 - - w1_deq = _dequant_int4(w1, w1_scale, group_size) - w2_deq = _dequant_int4(w2, w2_scale, group_size) - - output = torch.zeros(M, K, device=hidden_states.device, dtype=torch.bfloat16) - for i in range(M): - for j in range(top_k): - expert_id = topk_ids[i, j].item() - weight = topk_weights[i, j] - x = hidden_states[i : i + 1] @ w1_deq[expert_id].T - gate = x[:, :inter] - up = x[:, inter:] - x = torch.nn.functional.silu(gate) * up - x = x @ w2_deq[expert_id].T - output[i] += weight * x.squeeze(0) - return output - - -def _run_eager_vectorized( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k, - num_experts, - group_size, -): - """Vectorized eager — gather + bmm, no Python loops.""" - M, K = hidden_states.shape - inter = w2.shape[2] * 2 - - w1_deq = _dequant_int4(w1, w1_scale, group_size) - w2_deq = _dequant_int4(w2, w2_scale, group_size) - - flat_ids = topk_ids.reshape(-1) - hs_rep = hidden_states.unsqueeze(1).expand(-1, top_k, -1).reshape(M * top_k, K) - gemm1_out = torch.bmm( - hs_rep.unsqueeze(1), w1_deq[flat_ids].transpose(1, 2) - ).squeeze(1) - - gate = gemm1_out[:, :inter] - up = gemm1_out[:, inter:] - act = torch.nn.functional.silu(gate) * up - - gemm2_out = torch.bmm(act.unsqueeze(1), w2_deq[flat_ids].transpose(1, 2)).squeeze(1) - - return (gemm2_out.view(M, top_k, K) * topk_weights.unsqueeze(-1)).sum(dim=1) - - -_compiled_fn = None - - -def _run_compiled( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k, - num_experts, - group_size, -): - global _compiled_fn - if _compiled_fn is None: - _compiled_fn = torch.compile(_run_eager_vectorized) - return _compiled_fn( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k, - num_experts, - group_size, - ) - - -def _run_triton( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k, - num_experts, - group_size, -): - return torch.ops.triton.fused_moe( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k=top_k, - num_experts=num_experts, - group_size=group_size, - ) - - -BACKENDS = { - "eager_vec": ("Eager (vec)", _run_eager_vectorized), - "compile": ("Compile", _run_compiled), - "triton": ("Triton fused", _run_triton), -} - -try: - from executorch.backends.cuda.triton.kernels.fused_moe import fused_moe_batched - - def _run_triton_batched( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k, - num_experts, - group_size, - ): - return fused_moe_batched( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - 
topk_ids, - top_k=top_k, - num_experts=num_experts, - group_size=group_size, - ) - - BACKENDS["triton_batched"] = ("Triton batched", _run_triton_batched) - - def _run_triton_batched_int8( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k, - num_experts, - group_size, - ): - return fused_moe_batched( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k=top_k, - num_experts=num_experts, - group_size=group_size, - activation_dtype="int8", - ) - - BACKENDS["triton_batched_int8"] = ("Triton bat-i8", _run_triton_batched_int8) -except ImportError: - pass - - -# -- Helpers ------------------------------------------------------------------ - - -def _max_abs_error(out, ref): - return (out.float() - ref.float()).abs().max().item() - - -def _bench_ms(fn, num_warmup, num_iters): - return do_bench(fn, warmup=num_warmup, rep=num_iters, return_mode="median") - - -def _try_bench(run_fn, args, num_warmup, num_iters): - fn = partial(run_fn, **args) - try: - fn() - return _bench_ms(fn, num_warmup, num_iters) - except torch.OutOfMemoryError: - torch.cuda.empty_cache() - return None - - -# -- Main --------------------------------------------------------------------- - - -@torch.inference_mode() -def run_benchmark( - prompt_lengths, - num_experts, - top_k, - hidden_size, - intermediate_size, - group_size, - num_warmup, - num_iters, -): - backends = [(name, *BACKENDS[name]) for name in BACKENDS] - - device_name = torch.cuda.get_device_name() - print() - print("=" * 100) - print("Fused MoE Benchmark — Qwen3.5-35B-A3B (W4A16)") - print(f" Device: {device_name}") - print( - f" Experts: {num_experts}, Top-K: {top_k}, Hidden: {hidden_size}, " - f"Intermediate: {intermediate_size}, Group: {group_size}" - ) - print(f" Warmup: {num_warmup}, Iters: {num_iters}") - print(f" Backends: {', '.join(label for _, label, _ in backends)}") - print("=" * 100) - - # Generate weights once (shared across prompt lengths) - w1, w1_scale = _make_int4_weights( - num_experts, 2 * intermediate_size, hidden_size, group_size - ) - w2, w2_scale = _make_int4_weights( - num_experts, hidden_size, intermediate_size, group_size - ) - - # Column layout: Shape | backend1 | backend2 | ... (dynamic widths) - col_specs = [("M (tokens)", "", 10)] - for _, label, _ in backends: - col_specs.append((label, "(ms)", max(8, len(label)))) - - col_widths = [max(len(h), len(u), mw) for h, u, mw in col_specs] - - header = " | ".join( - f"{h:<{w}}" if i == 0 else f"{h:>{w}}" - for i, ((h, _, _), w) in enumerate(zip(col_specs, col_widths)) - ) - units = " | ".join( - f"{'':>{w}}" if i == 0 else f"{u:>{w}}" - for i, ((_, u, _), w) in enumerate(zip(col_specs, col_widths)) - ) - print(header) - print(units) - print("-" * len(header)) - - for M in prompt_lengths: - hidden_states = torch.randn(M, hidden_size, device="cuda", dtype=torch.bfloat16) - router_logits = torch.randn(M, num_experts, device="cuda", dtype=torch.float32) - topk_w, topk_i = torch.topk(router_logits, top_k, dim=-1) - topk_w = torch.softmax(topk_w, dim=-1) - topk_i = topk_i.to(torch.int64) - - common_args = { - "hidden_states": hidden_states, - "w1": w1, - "w1_scale": w1_scale, - "w2": w2, - "w2_scale": w2_scale, - "topk_weights": topk_w, - "topk_ids": topk_i, - "top_k": top_k, - "num_experts": num_experts, - "group_size": group_size, - } - - # Correctness: triton vs loop-based eager reference. - # Only check at small M to avoid slow eager loop + OOM on large M. 
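-        # Note on tolerances: with bf16 math and INT4 dequantization, only
-        # approximate agreement with the eager reference is expected, hence
-        # the loose 2e-1 (bf16) and 5e-1 (int8 activation) bounds below.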
- if M <= 64: - ref_out = _run_eager(**common_args) - tri_out = _run_triton(**common_args) - err = _max_abs_error(tri_out, ref_out) - assert err < 2.0e-1, ( - f"Triton vs eager mismatch at M={M}: " - f"max abs error {err:.3e} >= 2.0e-1" - ) - if "triton_batched_int8" in BACKENDS: - _, _, run_int8 = BACKENDS["triton_batched_int8"] - int8_out = run_int8(**common_args) - int8_err = _max_abs_error(int8_out, ref_out) - assert int8_err < 5.0e-1, ( - f"Triton INT8 vs eager mismatch at M={M}: " - f"max abs error {int8_err:.3e} >= 5.0e-1" - ) - del int8_out - del ref_out, tri_out - - # Benchmark - times = {} - for name, _label, run_fn in backends: - times[name] = _try_bench(run_fn, common_args, num_warmup, num_iters) - - ci = 0 - row_parts = [f"{f'M={M}':<{col_widths[ci]}}"] - ci += 1 - for name, _, _ in backends: - t = times[name] - w = col_widths[ci] - row_parts.append(f"{t:>{w}.3f}" if t is not None else f"{'OOM':>{w}}") - ci += 1 - print(" | ".join(row_parts)) - - del hidden_states, topk_w, topk_i - torch.cuda.empty_cache() - - print("-" * len(header)) - print() - - -def main(): - parser = argparse.ArgumentParser( - description="Benchmark Triton fused MoE vs eager/compile baselines" - ) - parser.add_argument("--num-experts", type=int, default=DEFAULTS["num_experts"]) - parser.add_argument("--top-k", type=int, default=DEFAULTS["top_k"]) - parser.add_argument("--hidden-size", type=int, default=DEFAULTS["hidden_size"]) - parser.add_argument( - "--intermediate-size", type=int, default=DEFAULTS["intermediate_size"] - ) - parser.add_argument("--group-size", type=int, default=DEFAULTS["group_size"]) - parser.add_argument("--num_warmup", type=int, default=25) - parser.add_argument("--num_iters", type=int, default=100) - parser.add_argument( - "--prompt-lengths", - type=str, - default=None, - help="Comma-separated list of prompt lengths (default: standard sweep)", - ) - args = parser.parse_args() - - prompt_lengths = PROMPT_LENGTHS - if args.prompt_lengths: - prompt_lengths = [int(x.strip()) for x in args.prompt_lengths.split(",")] - - run_benchmark( - prompt_lengths=prompt_lengths, - num_experts=args.num_experts, - top_k=args.top_k, - hidden_size=args.hidden_size, - intermediate_size=args.intermediate_size, - group_size=args.group_size, - num_warmup=args.num_warmup, - num_iters=args.num_iters, - ) - - -if __name__ == "__main__": - main() diff --git a/backends/cuda/benchmarks/benchmark_sdpa.py b/backends/cuda/benchmarks/benchmark_sdpa.py deleted file mode 100644 index 3c117f4574f..00000000000 --- a/backends/cuda/benchmarks/benchmark_sdpa.py +++ /dev/null @@ -1,308 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Benchmark the Triton SDPA kernel against PyTorch SDPA backends. - -Measures latency across decode shapes matching the Qwen3.5 MoE model -(B=1, H_q=16, H_kv=2, D=256). The ET Triton kernel uses native GQA -(2 KV heads), while Flash/Efficient/Math require pre-expanded KV -(16 heads) since they lack native GQA support. 
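-
-Usage:
-    python benchmark_sdpa.py
-    python benchmark_sdpa.py --scenario decode --num_iters 200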
- -""" - -import argparse -import warnings -from functools import partial - -import torch -import torch.nn.functional as F - -from executorch.backends.cuda.triton.kernels.sdpa import ( - sdpa as triton_sdpa, - sdpa_decode_splitk as triton_splitk, -) -from torch.nn.attention import sdpa_kernel, SDPBackend -from triton.testing import do_bench - - -# PyTorch's Flash/Efficient backends don't support GQA (H_q != H_kv) directly. -# We expand KV heads via repeat_interleave so they can run, matching what -# the test reference does. This is fair: it measures the kernel itself, not -# the GQA dispatch overhead. - - -def _expand_kv(k, v, num_groups): - if num_groups > 1: - k = k.repeat_interleave(num_groups, dim=1) - v = v.repeat_interleave(num_groups, dim=1) - return k, v - - -def _expand_mask(mask, H_q): - if mask is not None and mask.shape[1] == 1 and H_q > 1: - mask = mask.expand(-1, H_q, -1, -1) - return mask - - -def _run_triton(q, k, v, attn_mask, enable_gqa): - return triton_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) - - -def _run_splitk(q, k, v, attn_mask, enable_gqa): - return triton_splitk(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) - - -def _run_pytorch_default(q, k, v, attn_mask, enable_gqa): - return F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=attn_mask, - enable_gqa=enable_gqa, - ) - - -def _make_pytorch_runner(backend: SDPBackend): - def run(q, k, v, attn_mask, enable_gqa): - with sdpa_kernel(backend): - return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) - - return run - - -# Flash doesn't support attn_mask at all, only is_causal. -# Our benchmark mask is all-ones, so no mask is equivalent. -def _run_flash(q, k, v, attn_mask, enable_gqa): - with sdpa_kernel(SDPBackend.FLASH_ATTENTION): - return F.scaled_dot_product_attention(q, k, v) - - -BACKENDS = { - "triton": ("ET Triton (GQA)", _run_triton), - "splitk": ("ET Split-K (GQA)", _run_splitk), - "pytorch": ("PyTorch", _run_pytorch_default), - "flash": ("Flash (expanded KV)", _run_flash), - "efficient": ( - "Efficient (expanded KV)", - _make_pytorch_runner(SDPBackend.EFFICIENT_ATTENTION), - ), - "math": ("Math (expanded KV)", _make_pytorch_runner(SDPBackend.MATH)), -} - -# Backends that need KV heads expanded before calling (no native GQA support) -_NEEDS_KV_EXPAND = {"flash", "efficient", "math"} - -# -- Shapes ------------------------------------------------------------------ - -# Qwen3.5 MoE: B=1, H_q=16, H_kv=2, D=256 -QWEN35_BASE = {"B": 1, "H_q": 16, "H_kv": 2, "D": 256} - -DECODE_SHAPES = [ - dict(**QWEN35_BASE, Lq=1, Lk=64), - dict(**QWEN35_BASE, Lq=1, Lk=128), - dict(**QWEN35_BASE, Lq=1, Lk=256), - dict(**QWEN35_BASE, Lq=1, Lk=512), - dict(**QWEN35_BASE, Lq=1, Lk=1024), - dict(**QWEN35_BASE, Lq=1, Lk=2048), - dict(**QWEN35_BASE, Lq=1, Lk=4096), - dict(**QWEN35_BASE, Lq=1, Lk=8192), - dict(**QWEN35_BASE, Lq=1, Lk=16384), -] - -SCENARIOS = { - "decode": DECODE_SHAPES, -} - -# -- Helpers ----------------------------------------------------------------- - - -def _make_tensors(B, H_q, H_kv, Lq, Lk, D, device="cuda", dtype=torch.bfloat16): - q = torch.randn(B, H_q, Lq, D, device=device, dtype=dtype) - k = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype) - v = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype) - mask = torch.ones(B, 1, Lq, Lk, dtype=torch.bool, device=device) - enable_gqa = H_q != H_kv - num_groups = H_q // H_kv - # Pre-expanded versions for backends without native GQA - k_exp, v_exp = _expand_kv(k, v, num_groups) - mask_exp = _expand_mask(mask, H_q) 
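-    # With the Qwen3.5 decode shapes (H_q=16, H_kv=2) this yields
-    # num_groups=8, i.e. K/V expanded from 2 to 16 heads for the backends
-    # without native GQA support.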
- return q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa - - -def _max_abs_error(out, ref): - return (out.float() - ref.float()).abs().max().item() - - -# Cross-backend validation tolerance (bf16 vs bf16). -MAX_ABS_TOL = 1e-2 - - -def _bench_us(fn, num_warmup, num_iters): - """Return median latency in microseconds using triton.testing.do_bench.""" - ms = do_bench(fn, warmup=num_warmup, rep=num_iters, return_mode="median") - return ms * 1000.0 - - -def _try_run(run_fn, q, k, v, mask, enable_gqa): - """Run a backend, returning output or None on failure.""" - try: - return run_fn(q, k, v, mask, enable_gqa) - except RuntimeError: - return None - - -def _try_bench(run_fn, q, k, v, mask, enable_gqa, num_warmup, num_iters): - """Benchmark a backend, returning median us or None on failure.""" - fn = partial(run_fn, q, k, v, mask, enable_gqa) - try: - run_fn(q, k, v, mask, enable_gqa) - return _bench_us(fn, num_warmup, num_iters) - except RuntimeError: - return None - - -# -- Main -------------------------------------------------------------------- - - -def _shape_label(shape): - return ( - f"B={shape['B']} Hq={shape['H_q']} Hkv={shape['H_kv']} " - f"D={shape['D']} Lq={shape['Lq']} Lk={shape['Lk']}" - ) - - -def _short_label(shape, scenario="decode"): - return f"Lq={shape['Lq']},Lk={shape['Lk']}" - - -@torch.inference_mode() -def run_benchmark( - scenario: str = "decode", - num_warmup: int = 25, - num_iters: int = 100, -): - shapes = SCENARIOS[scenario] - backends = [(name, *BACKENDS[name]) for name in BACKENDS] - - device_name = torch.cuda.get_device_name() - print() - print("=" * 100) - print(f"SDPA Benchmark Qwen3.5-35B-A3B — {scenario}") - print(f" Device: {device_name}") - print(f" Warmup: {num_warmup}, Iters: {num_iters}") - print(f" Backends: {', '.join(label for _, label, _ in backends)}") - print("=" * 100) - - # Build column specs: (header_text, unit_text, min_width) - # Each column gets width = max(len(header), len(unit), min_width) - max_label = max(len(_short_label(s, scenario)) for s in shapes) - col_specs = [("Shape", "", max(8, max_label))] - for _, label, _ in backends: - col_specs.append((label, "(us)", 8)) - - col_widths = [max(len(h), len(u), mw) for h, u, mw in col_specs] - - header = " | ".join( - f"{h:<{w}}" if i == 0 else f"{h:>{w}}" - for i, ((h, _, _), w) in enumerate(zip(col_specs, col_widths)) - ) - units = " | ".join( - f"{'':>{w}}" if i == 0 else f"{u:>{w}}" - for i, ((_, u, _), w) in enumerate(zip(col_specs, col_widths)) - ) - print(header) - print(units) - print("-" * len(header)) - - for shape in shapes: - q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa = _make_tensors(**shape) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - # Validate outputs across backends before benchmarking - outputs = {} - for name, _label, run_fn in backends: - if name in _NEEDS_KV_EXPAND: - bk, bv, bmask = k_exp, v_exp, mask_exp - else: - bk, bv, bmask = k, v, mask - outputs[name] = _try_run(run_fn, q, bk, bv, bmask, enable_gqa) - - # Use PyTorch F.sdpa as the trusted reference — never validate - # against our own Triton kernels. 
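-            # If the default PyTorch backend failed on this shape, ref_out
-            # stays None and cross-backend validation is skipped for the row.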
- ref_name, ref_out = None, None - if outputs.get("pytorch") is not None: - ref_name, ref_out = "pytorch", outputs["pytorch"] - - if ref_out is not None: - for name, label, _ in backends: - if name == ref_name or outputs[name] is None: - continue - err = _max_abs_error(outputs[name], ref_out) - assert err < MAX_ABS_TOL, ( - f"Output mismatch for {_shape_label(shape)}: " - f"{label} vs {BACKENDS[ref_name][0]}, " - f"max abs error {err:.3e} >= 1e-2" - ) - del outputs - - # Benchmark all backends - times = {} - for name, _label, run_fn in backends: - if name in _NEEDS_KV_EXPAND: - bk, bv, bmask = k_exp, v_exp, mask_exp - else: - bk, bv, bmask = k, v, mask - times[name] = _try_bench( - run_fn, q, bk, bv, bmask, enable_gqa, num_warmup, num_iters - ) - - # Format row using col_widths - ci = 0 - row_parts = [f"{_short_label(shape, scenario):<{col_widths[ci]}}"] - ci += 1 - for name, _, _ in backends: - t = times[name] - w = col_widths[ci] - row_parts.append(f"{t:>{w}.1f}" if t is not None else f"{'N/A':>{w}}") - ci += 1 - print(" | ".join(row_parts)) - - del q, k, v, k_exp, v_exp, mask, mask_exp - torch.cuda.empty_cache() - - print("-" * len(header)) - print() - - -def main(): - parser = argparse.ArgumentParser( - description="Benchmark Triton SDPA vs PyTorch backends" - ) - parser.add_argument( - "--scenario", - choices=list(SCENARIOS.keys()) + ["all"], - default="all", - help="Which shape set to benchmark (default: all)", - ) - parser.add_argument("--num_warmup", type=int, default=25) - parser.add_argument("--num_iters", type=int, default=100) - args = parser.parse_args() - - scenarios = list(SCENARIOS.keys()) if args.scenario == "all" else [args.scenario] - for s in scenarios: - run_benchmark( - scenario=s, - num_warmup=args.num_warmup, - num_iters=args.num_iters, - ) - - -if __name__ == "__main__": - main() diff --git a/backends/cuda/benchmarks/sweep_e2e.sh b/backends/cuda/benchmarks/sweep_e2e.sh deleted file mode 100755 index 8ddfe7958ca..00000000000 --- a/backends/cuda/benchmarks/sweep_e2e.sh +++ /dev/null @@ -1,81 +0,0 @@ -#!/bin/bash -# E2E sweep: baseline vs dequant across prompt lengths and generation lengths -set -euo pipefail - -RUNNER="${RUNNER:?Set RUNNER to the path of qwen3_5_moe_runner binary}" -TOKENIZER="${TOKENIZER:?Set TOKENIZER to the path of tokenizer.json}" -if [ -n "$LD_PRELOAD_OVERRIDE" ]; then - export LD_PRELOAD="$LD_PRELOAD_OVERRIDE" -fi - -BASELINE="${BASELINE:?Set BASELINE to the baseline model directory}" -DEQUANT="${DEQUANT:?Set DEQUANT to the dequant model directory}" -REPORT_DIR="${REPORT_DIR:-./report_baseline_vs_dequant}" -mkdir -p "$REPORT_DIR" - -# Generate prompts of various lengths -gen_prompt() { - local target=$1 - python3 -c " -base = 'The transformer architecture has revolutionized machine learning and natural language processing by enabling parallel computation across all positions in a sequence, eliminating the sequential bottleneck of recurrent models. 
' -text = base * ($target // 10 + 5) -print(text[:$target * 6]) -" -} - -run_one() { - local label=$1 dir=$2 prompt=$3 max_tok=$4 outfile=$5 - $RUNNER \ - --model_path "$dir/model.pte" \ - --data_path "$dir/aoti_cuda_blob.ptd" \ - --tokenizer_path "$TOKENIZER" \ - --prompt "$prompt" \ - --max_new_tokens "$max_tok" \ - --temperature 0 2>&1 | tee "$outfile" -} - -extract_stats() { - local file=$1 - local ptok=$(grep -oP 'Prompt Tokens: \K\d+' "$file" | tail -1) - local prate=$(grep -oP 'Prompt evaluation:.*Rate:\s*\K[\d.]+' "$file" | tail -1) - local gtok=$(grep -oP 'Generated \K\d+' "$file" | tail -1) - local drate=$(grep -oP 'Generated \d+ tokens:.*Rate:\s*\K[\d.]+' "$file" | tail -1) - local ttft=$(grep -oP 'Time to first generated token:\s*\K[\d.]+' "$file" | tail -1) - echo "$ptok,$prate,$gtok,$drate,$ttft" -} - -# ============================================================ -# Part 1: Performance sweep -# ============================================================ -echo "prompt_tokens,gen_tokens,baseline_prefill,baseline_decode,baseline_ttft,dequant_prefill,dequant_decode,dequant_ttft" > "$REPORT_DIR/sweep.csv" - -for PTARGET in 128 256 512 1024 2048; do - PROMPT=$(gen_prompt $PTARGET) - for GENTOK in 128 256 512; do - echo "=== P~${PTARGET} G=${GENTOK} ===" - - # Baseline - BFILE="$REPORT_DIR/run_baseline_p${PTARGET}_g${GENTOK}.txt" - run_one baseline "$BASELINE" "$PROMPT" "$GENTOK" "$BFILE" > /dev/null 2>&1 - BSTATS=$(extract_stats "$BFILE") - - # Dequant - DFILE="$REPORT_DIR/run_dequant_p${PTARGET}_g${GENTOK}.txt" - run_one dequant "$DEQUANT" "$PROMPT" "$GENTOK" "$DFILE" > /dev/null 2>&1 - DSTATS=$(extract_stats "$DFILE") - - BPTOK=$(echo $BSTATS | cut -d, -f1) - BPRATE=$(echo $BSTATS | cut -d, -f2) - BDRATE=$(echo $BSTATS | cut -d, -f4) - BTTFT=$(echo $BSTATS | cut -d, -f5) - DPRATE=$(echo $DSTATS | cut -d, -f2) - DDRATE=$(echo $DSTATS | cut -d, -f4) - DTTFT=$(echo $DSTATS | cut -d, -f5) - - echo "$BPTOK,$GENTOK,$BPRATE,$BDRATE,$BTTFT,$DPRATE,$DDRATE,$DTTFT" >> "$REPORT_DIR/sweep.csv" - echo " P=$BPTOK: baseline prefill=${BPRATE} decode=${BDRATE} | dequant prefill=${DPRATE} decode=${DDRATE}" - done -done - -echo "" -echo "Sweep complete. Results in $REPORT_DIR/sweep.csv"
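-
-# Example invocation (all paths are placeholders for your local artifacts):
-#   RUNNER=/path/to/qwen3_5_moe_runner TOKENIZER=/path/to/tokenizer.json \
-#   BASELINE=/path/to/baseline_model_dir DEQUANT=/path/to/dequant_model_dir \
-#   ./sweep_e2e.sh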