import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial


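# Tile-size / pipelining candidates explored by the autotuner (currently a single configuration).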
def get_configs():
    iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


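# Autotuned FlashAttention forward kernel: each program instance computes a
# block_M x dim output tile for one (batch, head), streaming K/V tiles with an
# online softmax so the full attention matrix is never materialized.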
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
    out_idx=[3],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
    scale = (1.0 / dim) ** 0.5 * 1.44269504  # 1/sqrt(dim) * log2(e): lets the softmax use exp2 instead of exp
    q_shape = [batch, heads, seq_q, dim]
    kv_shape = [batch, heads, seq_kv, dim]
    dtype = T.float16
    accum_dtype = T.float32

    past_len = seq_kv - seq_q
    assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"

    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),
        K: T.Tensor(kv_shape, dtype),
        V: T.Tensor(kv_shape, dtype),
        Output: T.Tensor(q_shape, dtype),
    ):
        with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

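            # Load the Q tile once and initialize the online-softmax state:
            # output accumulator, per-row normalizer, and running row max.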
            T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

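            # With a causal mask, only K/V tiles up to the last position this query
            # block can attend to are visited; otherwise the full key/value sequence is walked.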
            loop_range = (
                T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
                if is_causal
                else T.ceildiv(seq_kv, block_N)
            )

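            # Stream K/V tiles through shared memory with software pipelining (num_stages).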
            for k in T.Pipelined(loop_range, num_stages=num_stages):
                T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
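                # Pre-fill the score tile with the mask: 0 where attention is allowed,
                # -inf where it is masked (future positions or out-of-range keys).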
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        q_idx = bx * block_M + i + past_len
                        k_idx = k * block_N + j
                        acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0)
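                # S = Q @ K^T, accumulated on top of the mask values.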
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

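                # Online softmax: refresh the running row max, then rescale the old
                # statistics by exp2((old_max - new_max) * scale) before adding the new tile.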
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_M):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)

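                # Rescale the partial output accumulated from earlier tiles to the new running max.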
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]

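                # Accumulate P @ V for this tile (fp16 probabilities, fp32 accumulation).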
                T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

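            # Epilogue: divide by the softmax denominator and write the tile back
            # to global memory through shared memory.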
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])

    return main


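# Reference attention in plain PyTorch, used for correctness checks and as a latency baseline.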
def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_q = Q.size(2)
        seq_kv = K.size(2)
        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
    return output


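# Entry point: builds the kernel, checks it against ref_program, and reports
# latency/TFLOPS, or runs the autotuner when tune is set.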
def main(
    batch: int = 1,
    heads: int = 1,
    seq_q: int = 256,
    seq_kv: int = 256,
    dim: int = 64,
    is_causal: bool = False,
    tune: bool = False,
):
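    # Two GEMMs per attention call (Q @ K^T and P @ V), each costing
    # 2 * batch * heads * seq_q * seq_kv * dim FLOPs; a causal mask halves the work.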
    flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

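    # Fixed-config path: build the kernel, validate it against the PyTorch
    # reference, then benchmark both and report latency and achieved TFLOPS.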
    if not tune:
        kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128)
        ref_program_processed = partial(ref_program, is_causal=is_causal)

        profiler = kernel.get_profiler()
        profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
        print("All checks pass.")
        latency = profiler.do_bench(ref_program_processed, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = profiler.do_bench(warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
        best_latency = kernel.latency
        best_config = kernel.config
        ref_latency = kernel.ref_latency
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")
        print(f"Ref latency: {ref_latency}")


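# Fixed-configuration benchmark entry point; timings are collected via the CUPTI backend.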
def run_regression_perf(
    batch: int = 1,
    heads: int = 32,
    seq_q: int = 256,
    seq_kv: int = 256,
    dim: int = 64,
    is_causal: bool = False,
    tune: bool = False,
):
    kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256)
    profiler = kernel.get_profiler()
    return profiler.do_bench(backend="cupti")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=1, help="batch size")
    parser.add_argument("--heads", type=int, default=1, help="number of attention heads")
    parser.add_argument("--seq_q", type=int, default=256, help="query sequence length")
    parser.add_argument("--seq_kv", type=int, default=256, help="key/value sequence length")
    parser.add_argument("--dim", type=int, default=64, help="head dimension")
    parser.add_argument("--is_causal", action="store_true", help="use a causal attention mask")
    parser.add_argument("--tune", action="store_true", help="autotune tile configurations")
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)