Skip to content

Commit 1dff054

Browse files
committed
Add demo_tilelang_to_ninetoothed.py
1 parent d9b4a82 commit 1dff054

1 file changed

Lines changed: 194 additions & 0 deletions

File tree

demo_tilelang_to_ninetoothed.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
import torch
2+
import triton
3+
4+
import ops.ninetoothed.kernels.mm
5+
import ops.tilelang.kernels.mm
6+
import ops.triton.kernels.mm
7+
import tilelang_to_ninetoothed
8+
9+
# Tile sizes handed to the TileLang kernel factory and, as block_M / block_N /
# block_K, to the NineToothed kernel derived from it.
BLOCK_SIZE_M = BLOCK_SIZE_N = 128
BLOCK_SIZE_K = 32
12+
13+
# Kernel objects exported by the NineToothed and Triton packages, used as-is.
ninetoothed_mm_kernel = ops.ninetoothed.kernels.mm.kernel
triton_mm_kernel = ops.triton.kernels.mm.kernel

# The TileLang module exposes a factory `mm`; it is called once here with the
# module's own M / N / K and the tile sizes defined in this file.
tilelang_mm_kernel = ops.tilelang.kernels.mm.mm(
    ops.tilelang.kernels.mm.M,
    ops.tilelang.kernels.mm.N,
    ops.tilelang.kernels.mm.K,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    BLOCK_SIZE_K,
)

# The transform receives the un-called factory itself, not the kernel built
# from it above.
ninetoothed_mm_kernel_from_tilelang = (
    tilelang_to_ninetoothed.transform_tilelang_to_ninetoothed(ops.tilelang.kernels.mm.mm)
)
31+
32+
33+
def ninetoothed_mm(input, other):
    """Run the NineToothed mm kernel on ``input`` and ``other``.

    Allocates an uninitialised ``(input.shape[0], other.shape[1])`` tensor
    with ``input``'s dtype and device, passes it to the kernel as the third
    argument, and returns it.
    """
    rows = input.shape[0]
    cols = other.shape[1]
    output = torch.empty((rows, cols), dtype=input.dtype, device=input.device)

    ninetoothed_mm_kernel(input, other, output)

    return output
40+
41+
42+
def triton_mm(input, other):
    """Run the Triton mm kernel on ``input`` and ``other``.

    Allocates an uninitialised ``(input.shape[0], other.shape[1])`` tensor
    with ``input``'s dtype and device, launches the kernel with the three
    tensors, the three problem sizes and the six strides, and returns the
    output tensor.
    """
    m = input.shape[0]
    n = other.shape[1]
    k = input.shape[1]
    output = torch.empty((m, n), dtype=input.dtype, device=input.device)

    def grid(meta):
        # One-element grid: tiles along the rows times tiles along the
        # columns, using the block sizes supplied in ``meta``.
        tiles_m = triton.cdiv(m, meta["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(n, meta["BLOCK_SIZE_N"])
        return (tiles_m * tiles_n,)

    launch = triton_mm_kernel[grid]
    launch(
        input,
        other,
        output,
        m,
        n,
        k,
        input.stride(0),
        input.stride(1),
        other.stride(0),
        other.stride(1),
        output.stride(0),
        output.stride(1),
    )

    return output
68+
69+
70+
def tilelang_mm(input, other):
    """Run the pre-built TileLang mm kernel on ``input`` and ``other``.

    Allocates an uninitialised ``(input.shape[0], other.shape[1])`` tensor
    with ``input``'s dtype and device, passes it to the kernel as the third
    argument, and returns it.
    """
    # NOTE(review): tilelang_mm_kernel is built once at module load from
    # ops.tilelang.kernels.mm.M / N / K -- confirm it accepts operands whose
    # shapes differ from those values.
    result_shape = (input.shape[0], other.shape[1])
    output = torch.empty(result_shape, dtype=input.dtype, device=input.device)

    tilelang_mm_kernel(input, other, output)

    return output
77+
78+
79+
def ninetoothed_from_tilelang_mm(input, other):
    """Run the NineToothed kernel derived from the TileLang ``mm`` factory.

    ``input`` and ``other`` must each unpack into exactly two dimensions.
    Allocates an uninitialised ``(m, n)`` tensor with ``input``'s dtype and
    device, calls the kernel with the three tensors plus the problem and
    block sizes as keyword arguments, and returns the output tensor.
    """
    (m, k), (_, n) = input.shape, other.shape

    output = torch.empty((m, n), dtype=input.dtype, device=input.device)

    # Problem sizes come from the operands; block sizes from module constants.
    size_kwargs = {
        "M": m,
        "N": n,
        "K": k,
        "block_M": BLOCK_SIZE_M,
        "block_N": BLOCK_SIZE_N,
        "block_K": BLOCK_SIZE_K,
    }
    ninetoothed_mm_kernel_from_tilelang(input, other, output, **size_kwargs)

    return output
98+
99+
100+
def torch_mm(input, other):
    """Return ``torch.mm(input, other)``; the PyTorch reference in this demo."""
    product = torch.mm(input, other)
    return product
102+
103+
104+
if __name__ == "__main__":
    # Demo driver: run every mm implementation once on the same random
    # operands, print the results and how each compares with the NineToothed
    # output, then benchmark all implementations across a range of sizes.
    torch.manual_seed(0)

    shape = (512, 512)
    dtype = torch.float16
    device = "cuda"

    input = torch.randn(shape, dtype=dtype, device=device)
    other = torch.randn(shape, dtype=dtype, device=device)

    ninetoothed_output = ninetoothed_mm(input, other)
    torch_output = torch_mm(input, other)
    triton_output = triton_mm(input, other)
    tilelang_output = tilelang_mm(input, other)
    ninetoothed_from_tilelang_output = ninetoothed_from_tilelang_mm(input, other)

    print(ninetoothed_output)
    print(torch_output)
    print(triton_output)
    print(tilelang_output)
    print(ninetoothed_from_tilelang_output)

    # (label, tensor, extra allclose kwargs). Only the Triton comparison asks
    # for an exact match (atol=0, rtol=0); the rest use allclose's defaults.
    comparisons = (
        ("PyTorch", torch_output, {}),
        ("Triton", triton_output, {"atol": 0, "rtol": 0}),
        ("TileLang", tilelang_output, {}),
        ("NineToothed from TileLang", ninetoothed_from_tilelang_output, {}),
    )
    for label, candidate, tolerances in comparisons:
        if torch.allclose(ninetoothed_output, candidate, **tolerances):
            print(f"✅ NineToothed and {label} match.")
        else:
            print(f"❌ NineToothed and {label} differ.")

    # Single source of truth for the benchmark: provider key, legend name,
    # plot style and the callable to time. Keeping them in one table prevents
    # the parallel lists from drifting out of step.
    provider_table = (
        ("ninetoothed", "NineToothed", ("blue", "-"), ninetoothed_mm),
        ("torch", "PyTorch", ("green", "-"), torch_mm),
        ("triton", "Triton", ("orange", "-"), triton_mm),
        ("tilelang", "TileLang", ("cyan", "-"), tilelang_mm),
        (
            "ninetoothed_from_tilelang",
            "NineToothed from TileLang",
            ("purple", "-"),
            ninetoothed_from_tilelang_mm,
        ),
    )
    implementations = {key: fn for key, _, _, fn in provider_table}

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["m", "n", "k"],
            x_vals=[2**i for i in range(8, 13)],
            x_log=True,
            line_arg="provider",
            line_vals=[key for key, _, _, _ in provider_table],
            line_names=[name for _, name, _, _ in provider_table],
            styles=[style for _, _, style, _ in provider_table],
            ylabel="ms",
            plot_name="mm-performance",
            args={},
        )
    )
    def benchmark(m, n, k, provider):
        """Time one provider on fresh random (m, k) x (k, n) operands.

        Returns whatever ``triton.testing.do_bench`` returns for the call.
        Raises ``ValueError`` for a provider key that is not in the table,
        instead of failing later on an unbound result variable.
        """
        try:
            run = implementations[provider]
        except KeyError:
            raise ValueError(f"unknown provider: {provider!r}") from None

        input = torch.randn((m, k), dtype=dtype, device=device)
        other = torch.randn((k, n), dtype=dtype, device=device)

        return triton.testing.do_bench(lambda: run(input, other))

    benchmark.run(show_plots=True, print_data=True, save_path=".")

0 commit comments

Comments
 (0)