Skip to content

Commit 19dbae5

Browse files
committed
Merge branch 'master' of github.com:InfiniTensor/ninetoothed-examples into end-to-end-model-inference
2 parents 93d6b4c + 15ae3c0 commit 19dbae5

4 files changed

Lines changed: 41 additions & 63 deletions

File tree

attention.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ def application(q, k, v, scale, o):
6363
o = acc # noqa: F841
6464

6565

66-
q, k, v, o = (Tensor(4, constexpr_shape=True) for _ in range(4))
66+
q, k, v, o = (
67+
Tensor(4, shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128}))
68+
for _ in range(4)
69+
)
6770
attention_kernel = ninetoothed.make(arrangement, application, (q, k, v, Tensor(0), o))
6871

6972

conv2d.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,14 @@ def arrangement(input, filter, output):
2323
return matmul.arrangement(input_flattened, filter_permuted, output_flattened)
2424

2525

26-
conv2d_kernel = ninetoothed.make(
27-
arrangement,
28-
matmul.application,
29-
(Tensor(4), Tensor(4, constexpr_shape=True), Tensor(4)),
26+
filter_shape_options = (
27+
None,
28+
None,
29+
{"constexpr": True, "upper_bound": 16},
30+
{"constexpr": True, "upper_bound": 16},
3031
)
32+
tensors = (Tensor(4), Tensor(4, shape_options=filter_shape_options), Tensor(4))
33+
conv2d_kernel = ninetoothed.make(arrangement, matmul.application, tensors)
3134

3235

3336
def conv2d(input, filter):

max_pool2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
def arrangement(input, output):
1212
BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)
1313

14-
WINDOW_HEIGHT = Symbol("WINDOW_HEIGHT", constexpr=True)
15-
WINDOW_WIDTH = Symbol("WINDOW_WIDTH", constexpr=True)
14+
WINDOW_HEIGHT = Symbol("WINDOW_HEIGHT", constexpr=True, upper_bound=16)
15+
WINDOW_WIDTH = Symbol("WINDOW_WIDTH", constexpr=True, upper_bound=16)
1616

1717
input_arranged = input.tile((1, 1, WINDOW_HEIGHT, WINDOW_WIDTH))
1818
input_arranged = input_arranged.ravel()

swiglu.py

Lines changed: 28 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -6,89 +6,61 @@
66
import triton.language as tl
77
from ninetoothed import Symbol, Tensor
88

9-
BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
10-
BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
9+
BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)
1110

1211

1312
@ninetoothed.jit
1413
def swiglu_kernel(
15-
a: Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N)),
16-
b: Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N)),
17-
c: Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N)),
14+
a: Tensor(1).tile((BLOCK_SIZE,)),
15+
b: Tensor(1).tile((BLOCK_SIZE,)),
16+
c: Tensor(1).tile((BLOCK_SIZE,)),
1817
):
1918
b_loaded = b
2019
gate = b_loaded * ntl.sigmoid(ntl.cast(b_loaded, ntl.float32))
2120
c = a * gate # noqa: F841
2221

2322

24-
def ninetoothed_swiglu(a, b):
25-
c = torch.empty_like(a)
23+
def swiglu(a, b):
24+
a_1d = a.flatten()
25+
b_1d = b.flatten()
2626

27-
swiglu_kernel(a, b, c)
27+
c = torch.empty_like(a_1d)
2828

29-
return c
29+
swiglu_kernel(a_1d, b_1d, c)
30+
31+
return c.view_as(a)
3032

3133

3234
@triton.jit
3335
def triton_swiglu_kernel(
34-
a_ptr,
35-
b_ptr,
36-
c_ptr,
37-
m,
38-
n,
39-
a_stride_m,
40-
a_stride_n,
41-
b_stride_m,
42-
b_stride_n,
43-
c_stride_m,
44-
c_stride_n,
45-
BLOCK_SIZE: tl.constexpr,
36+
a_ptr, b_ptr, c_ptr, num_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr
4637
):
4738
pid = tl.program_id(0)
48-
block_start = pid * BLOCK_SIZE
49-
offsets = block_start + tl.arange(0, BLOCK_SIZE)
50-
51-
rows = offsets // n
52-
cols = offsets % n
53-
54-
mask = (rows < m) & (cols < n)
39+
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
40+
mask = offsets < num_elements
5541

56-
a_offsets = rows * a_stride_m + cols * a_stride_n
57-
b_offsets = rows * b_stride_m + cols * b_stride_n
58-
c_offsets = rows * c_stride_m + cols * c_stride_n
59-
60-
a = tl.load(a_ptr + a_offsets, mask=mask, other=0.0)
61-
b = tl.load(b_ptr + b_offsets, mask=mask, other=0.0)
42+
a = tl.load(a_ptr + offsets, mask=mask, other=0.0)
43+
b = tl.load(b_ptr + offsets, mask=mask, other=0.0)
6244

6345
silu_b = b * tl.sigmoid(tl.cast(b, tl.float32))
6446
c = a * silu_b
6547

66-
tl.store(c_ptr + c_offsets, c, mask=mask)
48+
tl.store(c_ptr + offsets, c, mask=mask)
6749

6850

6951
def triton_swiglu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
70-
m, n = a.shape
71-
c = torch.empty_like(a)
52+
# Flatten the inputs so that the kernel always works on 1D tensors
53+
a_flat = a.flatten()
54+
b_flat = b.flatten()
55+
c_flat = torch.empty_like(a_flat)
56+
num_elements = a_flat.numel()
7257

7358
def grid(meta):
74-
return (triton.cdiv(m * n, meta["BLOCK_SIZE"]),)
75-
76-
triton_swiglu_kernel[grid](
77-
a,
78-
b,
79-
c,
80-
m,
81-
n,
82-
a.stride(0),
83-
a.stride(1),
84-
b.stride(0),
85-
b.stride(1),
86-
c.stride(0),
87-
c.stride(1),
88-
BLOCK_SIZE=1024,
89-
)
59+
return (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
60+
61+
triton_swiglu_kernel[grid](a_flat, b_flat, c_flat, num_elements, BLOCK_SIZE=1024)
9062

91-
return c
63+
return c_flat.view_as(a)
9264

9365

9466
def torch_swiglu(
@@ -108,7 +80,7 @@ def torch_swiglu(
10880
b = torch.rand(shape, dtype=dtype, device=device)
10981
c = torch.rand(shape, dtype=dtype, device=device)
11082

111-
ninetoothed_output = ninetoothed_swiglu(a, b)
83+
ninetoothed_output = swiglu(a, b)
11284
torch_output = torch_swiglu(a, b)
11385
triton_output = triton_swiglu(a, b)
11486
print(ninetoothed_output)
@@ -149,7 +121,7 @@ def benchmark(m, n, provider):
149121

150122
if provider == "ninetoothed":
151123
ms, min_ms, max_ms = triton.testing.do_bench(
152-
lambda: ninetoothed_swiglu(a, b), quantiles=quantiles
124+
lambda: swiglu(a, b), quantiles=quantiles
153125
)
154126
elif provider == "torch":
155127
ms, min_ms, max_ms = triton.testing.do_bench(

0 commit comments

Comments
 (0)