Skip to content

Commit 809ae2c

Browse files
committed
feat: add tilelang op test
1 parent 3b5a6f9 commit 809ae2c

20 files changed

Lines changed: 1267 additions & 50 deletions

add.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,27 @@
33

44
import ops.ninetoothed.torch
55
import ops.triton.torch
6+
import ops.tilelang.torch
67

78
if __name__ == "__main__":
89
torch.manual_seed(0)
910

10-
size = 98432
11+
shape = (1024, 1024)
1112
dtype = torch.float16
1213
device = "cuda"
1314

14-
input = torch.randn(size, dtype=dtype, device=device)
15-
other = torch.randn(size, dtype=dtype, device=device)
15+
input = torch.randn(shape, dtype=dtype, device=device)
16+
other = torch.randn(shape, dtype=dtype, device=device)
1617

1718
ninetoothed_output = ops.ninetoothed.torch.add(input, other)
1819
torch_output = input + other
1920
triton_output = ops.triton.torch.add(input, other)
21+
tilelang_output = ops.tilelang.torch.add(input, other)
2022

2123
print(ninetoothed_output)
2224
print(torch_output)
2325
print(triton_output)
26+
print(tilelang_output)
2427

2528
if torch.allclose(ninetoothed_output, torch_output):
2629
print("✅ NineToothed and PyTorch match.")
@@ -30,31 +33,37 @@
3033
print("✅ NineToothed and Triton match.")
3134
else:
3235
print("❌ NineToothed and Triton differ.")
36+
if torch.allclose(ninetoothed_output, tilelang_output, atol=0, rtol=0):
37+
print("✅ NineToothed and TileLang match.")
38+
else:
39+
print("❌ NineToothed and TileLang differ.")
3340

3441
@triton.testing.perf_report(
3542
triton.testing.Benchmark(
36-
x_names=["size"],
37-
x_vals=[2**i for i in range(18, 28)],
43+
x_names=["m", "n"],
44+
x_vals=[2**i for i in range(5, 15)],
3845
x_log=True,
3946
line_arg="provider",
40-
line_vals=["ninetoothed", "torch", "triton"],
41-
line_names=["NineToothed", "PyTorch", "Triton"],
42-
styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
47+
line_vals=["ninetoothed", "torch", "triton", "tilelang"],
48+
line_names=["NineToothed", "PyTorch", "Triton", "TileLang"],
49+
styles=[("blue", "-"), ("green", "-"), ("orange", "-"), ("red", "--")],
4350
ylabel="ms",
4451
plot_name="add-performance",
4552
args={},
4653
)
4754
)
48-
def benchmark(size, provider):
49-
input = torch.randn(size, dtype=dtype, device=device)
50-
other = torch.randn(size, dtype=dtype, device=device)
55+
def benchmark(m, n, provider):
56+
input = torch.randn((m, n), dtype=dtype, device=device)
57+
other = torch.randn((m, n), dtype=dtype, device=device)
5158

5259
ninetoothed_output = ops.ninetoothed.torch.add(input, other)
5360
torch_output = torch.add(input, other)
5461
triton_output = ops.triton.torch.add(input, other)
62+
tilelang_output = ops.tilelang.torch.add(input, other)
5563

5664
assert torch.allclose(ninetoothed_output, torch_output)
5765
assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)
66+
assert torch.allclose(ninetoothed_output, tilelang_output, atol=0, rtol=0)
5867

5968
if provider == "ninetoothed":
6069
ms = triton.testing.do_bench(
@@ -64,6 +73,8 @@ def benchmark(size, provider):
6473
ms = triton.testing.do_bench(lambda: torch.add(input, other))
6574
elif provider == "triton":
6675
ms = triton.testing.do_bench(lambda: ops.triton.torch.add(input, other))
76+
elif provider == "tilelang":
77+
ms = triton.testing.do_bench(lambda: ops.tilelang.torch.add(input, other))
6778

6879
return ms
6980

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import argparse
2+
import torch
3+
import tilelang
4+
import tilelang.language as T
5+
6+
7+
def ref_program(x, y):
8+
return x + y
9+
10+
11+
@tilelang.jit(out_idx=[-1])
12+
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
13+
@T.prim_func
14+
def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)):
15+
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
16+
A_shared = T.alloc_shared((block_M, block_N), in_dtype)
17+
B_shared = T.alloc_shared((block_M, block_N), in_dtype)
18+
C_local = T.alloc_fragment((block_M, block_N), out_dtype)
19+
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
20+
21+
T.copy(A[by * block_M, bx * block_N], A_shared)
22+
T.copy(B[by * block_M, bx * block_N], B_shared)
23+
for local_y, local_x in T.Parallel(block_M, block_N):
24+
C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
25+
T.copy(C_local, C_shared)
26+
T.copy(C_shared, C[by * block_M, bx * block_N])
27+
28+
return elem_add
29+
30+
31+
def main(M=1024, N=1024, use_autotune=False):
32+
a = torch.randn(M, N, dtype=torch.float32, device="cuda")
33+
b = torch.randn(M, N, dtype=torch.float32, device="cuda")
34+
35+
kernel = elementwise_add(M, N, block_M=32, block_N=32, threads=128, in_dtype=T.float32, out_dtype=T.float32)
36+
37+
out = kernel(a, b)
38+
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
39+
40+
41+
def run_regression_perf():
42+
parser = argparse.ArgumentParser()
43+
parser.add_argument("--m", type=int, default=4096)
44+
parser.add_argument("--n", type=int, default=4096)
45+
args, _ = parser.parse_known_args()
46+
M, N = args.m, args.n
47+
a = torch.randn(M, N, dtype=torch.float32, device="cuda")
48+
b = torch.randn(M, N, dtype=torch.float32, device="cuda")
49+
config = {"block_M": 32, "block_N": 32, "threads": 128}
50+
kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
51+
from tilelang.profiler import do_bench
52+
53+
return do_bench(lambda: kernel(a, b), backend="cupti")
54+
55+
56+
if __name__ == "__main__":
57+
parser = argparse.ArgumentParser()
58+
parser.add_argument("--m", type=int, default=1024)
59+
parser.add_argument("--n", type=int, default=1024)
60+
args, _ = parser.parse_known_args()
61+
main(args.m, args.n)
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import torch
2+
import torch.nn.functional as F
3+
import tilelang
4+
from tilelang.autotuner import *
5+
import tilelang.language as T
6+
import itertools
7+
import argparse
8+
from functools import partial
9+
10+
11+
def get_configs():
12+
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
13+
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
14+
15+
16+
@autotune(configs=get_configs(), warmup=10, rep=10)
17+
@tilelang.jit(
18+
out_idx=[3],
19+
pass_configs={
20+
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
21+
},
22+
)
23+
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
24+
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
25+
q_shape = [batch, heads, seq_q, dim]
26+
kv_shape = [batch, heads, seq_kv, dim]
27+
dtype = T.float16
28+
accum_dtype = T.float32
29+
30+
past_len = seq_kv - seq_q
31+
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
32+
33+
@T.prim_func
34+
def main(
35+
Q: T.Tensor(q_shape, dtype),
36+
K: T.Tensor(kv_shape, dtype),
37+
V: T.Tensor(kv_shape, dtype),
38+
Output: T.Tensor(q_shape, dtype),
39+
):
40+
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
41+
Q_shared = T.alloc_shared([block_M, dim], dtype)
42+
K_shared = T.alloc_shared([block_N, dim], dtype)
43+
V_shared = T.alloc_shared([block_N, dim], dtype)
44+
O_shared = T.alloc_shared([block_M, dim], dtype)
45+
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
46+
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
47+
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
48+
scores_max = T.alloc_fragment([block_M], accum_dtype)
49+
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
50+
scores_scale = T.alloc_fragment([block_M], accum_dtype)
51+
scores_sum = T.alloc_fragment([block_M], accum_dtype)
52+
logsum = T.alloc_fragment([block_M], accum_dtype)
53+
54+
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
55+
T.fill(acc_o, 0)
56+
T.fill(logsum, 0)
57+
T.fill(scores_max, -T.infinity(accum_dtype))
58+
59+
loop_range = (
60+
T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
61+
if is_causal
62+
else T.ceildiv(seq_kv, block_N)
63+
)
64+
65+
for k in T.Pipelined(loop_range, num_stages=num_stages):
66+
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
67+
if is_causal:
68+
for i, j in T.Parallel(block_M, block_N):
69+
q_idx = bx * block_M + i + past_len
70+
k_idx = k * block_N + j
71+
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
72+
else:
73+
for i, j in T.Parallel(block_M, block_N):
74+
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0)
75+
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
76+
77+
T.copy(scores_max, scores_max_prev)
78+
T.fill(scores_max, -T.infinity(accum_dtype))
79+
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
80+
for i in T.Parallel(block_M):
81+
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
82+
for i in T.Parallel(block_M):
83+
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
84+
for i, j in T.Parallel(block_M, block_N):
85+
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
86+
T.reduce_sum(acc_s, scores_sum, dim=1)
87+
for i in T.Parallel(block_M):
88+
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
89+
T.copy(acc_s, acc_s_cast)
90+
91+
for i, j in T.Parallel(block_M, dim):
92+
acc_o[i, j] *= scores_scale[i]
93+
94+
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
95+
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
96+
97+
for i, j in T.Parallel(block_M, dim):
98+
acc_o[i, j] /= logsum[i]
99+
T.copy(acc_o, O_shared)
100+
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
101+
102+
return main
103+
104+
105+
def ref_program(Q, K, V, is_causal):
106+
dim = Q.size(-1)
107+
scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
108+
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
109+
if is_causal:
110+
seq_q = Q.size(2)
111+
seq_kv = K.size(2)
112+
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
113+
mask = mask.unsqueeze(0).unsqueeze(0)
114+
scores = scores.masked_fill(mask == 0, float("-inf"))
115+
attention_weights = F.softmax(scores, dim=-1)
116+
output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
117+
return output
118+
119+
120+
def main(
121+
batch: int = 1,
122+
heads: int = 1,
123+
seq_q: int = 256,
124+
seq_kv: int = 256,
125+
dim: int = 64,
126+
is_causal: bool = False,
127+
tune: bool = False,
128+
):
129+
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
130+
total_flops = 2 * flops_per_matmul
131+
if is_causal:
132+
total_flops *= 0.5
133+
134+
if not tune:
135+
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128)
136+
ref_program_processed = partial(ref_program, is_causal=is_causal)
137+
138+
profiler = kernel.get_profiler()
139+
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
140+
print("All checks pass.")
141+
latency = profiler.do_bench(ref_program_processed, warmup=500)
142+
print("Ref: {:.2f} ms".format(latency))
143+
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
144+
latency = profiler.do_bench(warmup=500)
145+
print("Tile-lang: {:.2f} ms".format(latency))
146+
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
147+
else:
148+
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
149+
best_latency = kernel.latency
150+
best_config = kernel.config
151+
ref_latency = kernel.ref_latency
152+
print(f"Best latency: {best_latency}")
153+
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
154+
print(f"Best config: {best_config}")
155+
print(f"Ref latency: {ref_latency}")
156+
157+
158+
def run_regression_perf(
159+
batch: int = 1,
160+
heads: int = 32,
161+
seq_q: int = 256,
162+
seq_kv: int = 256,
163+
dim: int = 64,
164+
is_causal: bool = False,
165+
tune: bool = False,
166+
):
167+
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256)
168+
profiler = kernel.get_profiler()
169+
return profiler.do_bench(backend="cupti")
170+
171+
172+
if __name__ == "__main__":
173+
parser = argparse.ArgumentParser()
174+
parser.add_argument("--batch", type=int, default=1, help="batch size")
175+
parser.add_argument("--heads", type=int, default=1, help="heads")
176+
parser.add_argument("--seq_q", type=int, default=256, help="query sequence length")
177+
parser.add_argument("--seq_kv", type=int, default=256, help="key/value sequence length")
178+
parser.add_argument("--dim", type=int, default=64, help="dim")
179+
parser.add_argument("--is_causal", action="store_true", help="causal", default=False)
180+
parser.add_argument("--tune", action="store_true", help="tune configs")
181+
args = parser.parse_args()
182+
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)

0 commit comments

Comments
 (0)