Skip to content

Commit ed92844

Browse files
committed
Merge branch 'master' of github.com:InfiniTensor/ninetoothed-examples into experiment
2 parents bec9334 + c6a5a62 commit ed92844

19 files changed

Lines changed: 1654 additions & 82 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ These approaches allow you to obtain results in seconds. However, selecting opti
3333

3434
This project includes code modified or inspired from the following open-source repositories:
3535

36+
* [https://github.com/huggingface/transformers](https://github.com/huggingface/transformers)
3637
* [https://github.com/triton-lang/triton](https://github.com/triton-lang/triton)
3738
* [https://github.com/ROCm/triton](https://github.com/ROCm/triton)
3839
* [https://github.com/l1351868270/implicit_gemm.triton](https://github.com/l1351868270/implicit_gemm.triton)

add.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import triton.language as tl
55
from ninetoothed import Symbol, Tensor
66

7-
BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)
7+
BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True)
88

99

1010
@ninetoothed.jit
@@ -19,18 +19,14 @@ def add_kernel(
1919
def add(lhs, rhs):
2020
output = torch.empty_like(lhs)
2121

22-
add_kernel(lhs, rhs, output)
22+
add_kernel(lhs, rhs, output, BLOCK_SIZE=1024)
2323

2424
return output
2525

2626

2727
@triton.jit
2828
def triton_add_kernel(
29-
lhs_ptr,
30-
rhs_ptr,
31-
output_ptr,
32-
n_elements,
33-
BLOCK_SIZE: tl.constexpr,
29+
lhs_ptr, rhs_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
3430
):
3531
pid = tl.program_id(0)
3632

@@ -59,16 +55,22 @@ def grid(meta):
5955

6056
if __name__ == "__main__":
6157
torch.manual_seed(0)
58+
6259
size = 98432
6360
dtype = torch.float16
64-
lhs = torch.rand(size, dtype=dtype, device="cuda")
65-
rhs = torch.rand(size, dtype=dtype, device="cuda")
61+
device = "cuda"
62+
63+
lhs = torch.rand(size, dtype=dtype, device=device)
64+
rhs = torch.rand(size, dtype=dtype, device=device)
65+
6666
ninetoothed_output = add(lhs, rhs)
6767
torch_output = lhs + rhs
6868
triton_output = triton_add(lhs, rhs)
69+
6970
print(ninetoothed_output)
7071
print(torch_output)
7172
print(triton_output)
73+
7274
if torch.allclose(ninetoothed_output, torch_output):
7375
print("✅ NineToothed and PyTorch match.")
7476
else:
@@ -93,19 +95,20 @@ def grid(meta):
9395
)
9496
)
9597
def benchmark(size, provider):
96-
lhs = torch.randn(size, device="cuda", dtype=torch.float16)
97-
rhs = torch.randn(size, device="cuda", dtype=torch.float16)
98+
lhs = torch.randn(size, dtype=dtype, device=device)
99+
rhs = torch.randn(size, dtype=dtype, device=device)
98100

99101
ninetoothed_output = add(lhs, rhs)
100-
torch_output = lhs + rhs
102+
torch_output = torch.add(lhs, rhs)
101103
triton_output = triton_add(lhs, rhs)
104+
102105
assert torch.allclose(ninetoothed_output, torch_output)
103106
assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)
104107

105108
if provider == "ninetoothed":
106109
ms = triton.testing.do_bench(lambda: add(lhs, rhs))
107110
elif provider == "torch":
108-
ms = triton.testing.do_bench(lambda: lhs + rhs)
111+
ms = triton.testing.do_bench(lambda: torch.add(lhs, rhs))
109112
elif provider == "triton":
110113
ms = triton.testing.do_bench(lambda: triton_add(lhs, rhs))
111114

addmm.py

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
import random
2+
3+
import ninetoothed
4+
import torch
5+
import triton
6+
import triton.language as tl
7+
from ninetoothed import Tensor
8+
9+
import matmul
10+
11+
12+
def arrangement(input, mat1, mat2, beta, alpha, output):
13+
_, _, input_arranged = matmul.arrangement(mat1, mat2, input)
14+
15+
mat1_arranged, mat2_arranged, output_arranged = matmul.arrangement(
16+
mat1, mat2, output
17+
)
18+
19+
return input_arranged, mat1_arranged, mat2_arranged, beta, alpha, output_arranged
20+
21+
22+
def application(input, mat1, mat2, beta, alpha, output):
23+
matmul.application(mat1, mat2, output)
24+
output = beta * input + alpha * output
25+
26+
27+
tensors = (Tensor(2), Tensor(2), Tensor(2), Tensor(0), Tensor(0), Tensor(2))
28+
addmm_kernel = ninetoothed.make(arrangement, application, tensors)
29+
30+
31+
def addmm(input, mat1, mat2, beta=1, alpha=1):
32+
output_shape = (mat1.shape[0], mat2.shape[1])
33+
output = torch.empty(output_shape, dtype=mat1.dtype, device=mat1.device)
34+
35+
addmm_kernel(input, mat1, mat2, beta, alpha, output)
36+
37+
return output
38+
39+
40+
@triton.autotune(
41+
configs=[
42+
triton.Config(
43+
{
44+
"BLOCK_SIZE_M": 128,
45+
"BLOCK_SIZE_N": 256,
46+
"BLOCK_SIZE_K": 64,
47+
"GROUP_SIZE_M": 8,
48+
},
49+
num_stages=3,
50+
num_warps=8,
51+
),
52+
triton.Config(
53+
{
54+
"BLOCK_SIZE_M": 64,
55+
"BLOCK_SIZE_N": 256,
56+
"BLOCK_SIZE_K": 32,
57+
"GROUP_SIZE_M": 8,
58+
},
59+
num_stages=4,
60+
num_warps=4,
61+
),
62+
triton.Config(
63+
{
64+
"BLOCK_SIZE_M": 128,
65+
"BLOCK_SIZE_N": 128,
66+
"BLOCK_SIZE_K": 32,
67+
"GROUP_SIZE_M": 8,
68+
},
69+
num_stages=4,
70+
num_warps=4,
71+
),
72+
triton.Config(
73+
{
74+
"BLOCK_SIZE_M": 128,
75+
"BLOCK_SIZE_N": 64,
76+
"BLOCK_SIZE_K": 32,
77+
"GROUP_SIZE_M": 8,
78+
},
79+
num_stages=4,
80+
num_warps=4,
81+
),
82+
triton.Config(
83+
{
84+
"BLOCK_SIZE_M": 64,
85+
"BLOCK_SIZE_N": 128,
86+
"BLOCK_SIZE_K": 32,
87+
"GROUP_SIZE_M": 8,
88+
},
89+
num_stages=4,
90+
num_warps=4,
91+
),
92+
triton.Config(
93+
{
94+
"BLOCK_SIZE_M": 128,
95+
"BLOCK_SIZE_N": 32,
96+
"BLOCK_SIZE_K": 32,
97+
"GROUP_SIZE_M": 8,
98+
},
99+
num_stages=4,
100+
num_warps=4,
101+
),
102+
triton.Config(
103+
{
104+
"BLOCK_SIZE_M": 64,
105+
"BLOCK_SIZE_N": 32,
106+
"BLOCK_SIZE_K": 32,
107+
"GROUP_SIZE_M": 8,
108+
},
109+
num_stages=5,
110+
num_warps=2,
111+
),
112+
triton.Config(
113+
{
114+
"BLOCK_SIZE_M": 32,
115+
"BLOCK_SIZE_N": 64,
116+
"BLOCK_SIZE_K": 32,
117+
"GROUP_SIZE_M": 8,
118+
},
119+
num_stages=5,
120+
num_warps=2,
121+
),
122+
],
123+
key=["m", "n", "k"],
124+
)
125+
@triton.jit
126+
def triton_addmm_kernel(
127+
input_ptr,
128+
mat1_ptr,
129+
mat2_ptr,
130+
output_ptr,
131+
m,
132+
n,
133+
k,
134+
input_stride_m,
135+
input_stride_n,
136+
mat1_stride_m,
137+
mat1_stride_k,
138+
mat2_stride_k,
139+
mat2_stride_n,
140+
output_stride_m,
141+
output_stride_n,
142+
beta,
143+
alpha,
144+
BLOCK_SIZE_M: tl.constexpr,
145+
BLOCK_SIZE_N: tl.constexpr,
146+
BLOCK_SIZE_K: tl.constexpr,
147+
GROUP_SIZE_M: tl.constexpr,
148+
):
149+
pid = tl.program_id(0)
150+
num_pid_m = tl.cdiv(m, BLOCK_SIZE_M)
151+
num_pid_n = tl.cdiv(n, BLOCK_SIZE_N)
152+
num_pid_in_group = GROUP_SIZE_M * num_pid_n
153+
group_id = pid // num_pid_in_group
154+
first_pid_m = group_id * GROUP_SIZE_M
155+
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
156+
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
157+
pid_n = (pid % num_pid_in_group) // group_size_m
158+
159+
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % m
160+
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n
161+
offs_k = tl.arange(0, BLOCK_SIZE_K)
162+
mat1_ptrs = mat1_ptr + (
163+
offs_am[:, None] * mat1_stride_m + offs_k[None, :] * mat1_stride_k
164+
)
165+
mat2_ptrs = mat2_ptr + (
166+
offs_k[:, None] * mat2_stride_k + offs_bn[None, :] * mat2_stride_n
167+
)
168+
169+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
170+
for i in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
171+
mat1 = tl.load(
172+
mat1_ptrs, mask=offs_k[None, :] < k - i * BLOCK_SIZE_K, other=0.0
173+
)
174+
mat2 = tl.load(
175+
mat2_ptrs, mask=offs_k[:, None] < k - i * BLOCK_SIZE_K, other=0.0
176+
)
177+
accumulator = tl.dot(mat1, mat2, accumulator)
178+
mat1_ptrs += BLOCK_SIZE_K * mat1_stride_k
179+
mat2_ptrs += BLOCK_SIZE_K * mat2_stride_k
180+
181+
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
182+
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
183+
184+
mask_c = (offs_cm[:, None] < m) & (offs_cn[None, :] < n)
185+
186+
input_ptrs = (
187+
input_ptr
188+
+ input_stride_m * offs_cm[:, None]
189+
+ input_stride_n * offs_cn[None, :]
190+
)
191+
input = tl.load(input_ptrs, mask=mask_c)
192+
193+
output = beta * input + alpha * accumulator.to(tl.float16)
194+
195+
output_ptrs = (
196+
output_ptr
197+
+ output_stride_m * offs_cm[:, None]
198+
+ output_stride_n * offs_cn[None, :]
199+
)
200+
tl.store(output_ptrs, output, mask=mask_c)
201+
202+
203+
def triton_addmm(input, mat1, mat2, beta=1, alpha=1):
204+
output_shape = (mat1.shape[0], mat2.shape[1])
205+
output = torch.empty(output_shape, dtype=mat1.dtype, device=mat1.device)
206+
207+
def grid(meta):
208+
return (
209+
triton.cdiv(mat1.shape[0], meta["BLOCK_SIZE_M"])
210+
* triton.cdiv(mat2.shape[1], meta["BLOCK_SIZE_N"]),
211+
)
212+
213+
triton_addmm_kernel[grid](
214+
input,
215+
mat1,
216+
mat2,
217+
output,
218+
mat1.shape[0],
219+
mat2.shape[1],
220+
mat1.shape[1],
221+
input.stride(0),
222+
input.stride(1),
223+
mat1.stride(0),
224+
mat1.stride(1),
225+
mat2.stride(0),
226+
mat2.stride(1),
227+
output.stride(0),
228+
output.stride(1),
229+
beta,
230+
alpha,
231+
)
232+
233+
return output
234+
235+
236+
if __name__ == "__main__":
237+
random.seed(0)
238+
torch.manual_seed(0)
239+
240+
shape = (512, 512)
241+
dtype = torch.float16
242+
device = "cuda"
243+
244+
input = torch.randn(shape, dtype=dtype, device=device)
245+
mat1 = torch.randn(shape, dtype=dtype, device=device)
246+
mat2 = torch.randn(shape, dtype=dtype, device=device)
247+
beta = random.uniform(0, 1)
248+
alpha = random.uniform(0, 1)
249+
250+
ninetoothed_output = addmm(input, mat1, mat2, beta=beta, alpha=alpha)
251+
torch_output = torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
252+
triton_output = triton_addmm(input, mat1, mat2, beta=beta, alpha=alpha)
253+
254+
print(ninetoothed_output)
255+
print(torch_output)
256+
print(triton_output)
257+
258+
if torch.allclose(ninetoothed_output, torch_output, atol=0.01, rtol=0.01):
259+
print("✅ NineToothed and PyTorch match.")
260+
else:
261+
print("❌ NineToothed and PyTorch differ.")
262+
if torch.allclose(ninetoothed_output, triton_output):
263+
print("✅ NineToothed and Triton match.")
264+
else:
265+
print("❌ NineToothed and Triton differ.")
266+
267+
@triton.testing.perf_report(
268+
triton.testing.Benchmark(
269+
x_names=["m", "n", "k"],
270+
x_vals=[128 * i for i in range(2, 33)],
271+
line_arg="provider",
272+
line_vals=["ninetoothed", "torch", "triton"],
273+
line_names=["NineToothed", "PyTorch", "Triton"],
274+
styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
275+
ylabel="ms",
276+
plot_name="addmm-performance",
277+
args={},
278+
)
279+
)
280+
def benchmark(m, n, k, provider):
281+
input = torch.randn((m, n), dtype=dtype, device=device)
282+
mat1 = torch.randn((m, k), dtype=dtype, device=device)
283+
mat2 = torch.randn((k, n), dtype=dtype, device=device)
284+
beta = random.uniform(0, 1)
285+
alpha = random.uniform(0, 1)
286+
287+
if provider == "ninetoothed":
288+
ms = triton.testing.do_bench(
289+
lambda: addmm(input, mat1, mat2, beta=beta, alpha=alpha)
290+
)
291+
elif provider == "torch":
292+
ms = triton.testing.do_bench(
293+
lambda: torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
294+
)
295+
elif provider == "triton":
296+
ms = triton.testing.do_bench(
297+
lambda: triton_addmm(input, mat1, mat2, beta=beta, alpha=alpha)
298+
)
299+
300+
return ms
301+
302+
benchmark.run(show_plots=True, print_data=True, save_path=".")

0 commit comments

Comments
 (0)