Skip to content

Commit 15ae3c0

Browse files
authored
Merge pull request #5 from InfiniTensor/swiglu-2d-to-1d
Change the parameters of the SwiGLU kernels to vectors
2 parents 1f32628 + 906a166 commit 15ae3c0

1 file changed

Lines changed: 28 additions & 56 deletions

File tree

swiglu.py

Lines changed: 28 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -6,89 +6,61 @@
66
import triton.language as tl
77
from ninetoothed import Symbol, Tensor
88

9-
BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
10-
BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
9+
BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)
1110

1211

1312
@ninetoothed.jit
1413
def swiglu_kernel(
15-
a: Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N)),
16-
b: Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N)),
17-
c: Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N)),
14+
a: Tensor(1).tile((BLOCK_SIZE,)),
15+
b: Tensor(1).tile((BLOCK_SIZE,)),
16+
c: Tensor(1).tile((BLOCK_SIZE,)),
1817
):
1918
b_loaded = b
2019
gate = b_loaded * ntl.sigmoid(ntl.cast(b_loaded, ntl.float32))
2120
c = a * gate # noqa: F841
2221

2322

24-
def ninetoothed_swiglu(a, b):
25-
c = torch.empty_like(a)
23+
def swiglu(a, b):
24+
a_1d = a.flatten()
25+
b_1d = b.flatten()
2626

27-
swiglu_kernel(a, b, c)
27+
c = torch.empty_like(a_1d)
2828

29-
return c
29+
swiglu_kernel(a_1d, b_1d, c)
30+
31+
return c.view_as(a)
3032

3133

3234
@triton.jit
3335
def triton_swiglu_kernel(
34-
a_ptr,
35-
b_ptr,
36-
c_ptr,
37-
m,
38-
n,
39-
a_stride_m,
40-
a_stride_n,
41-
b_stride_m,
42-
b_stride_n,
43-
c_stride_m,
44-
c_stride_n,
45-
BLOCK_SIZE: tl.constexpr,
36+
a_ptr, b_ptr, c_ptr, num_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr
4637
):
4738
pid = tl.program_id(0)
48-
block_start = pid * BLOCK_SIZE
49-
offsets = block_start + tl.arange(0, BLOCK_SIZE)
50-
51-
rows = offsets // n
52-
cols = offsets % n
53-
54-
mask = (rows < m) & (cols < n)
39+
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
40+
mask = offsets < num_elements
5541

56-
a_offsets = rows * a_stride_m + cols * a_stride_n
57-
b_offsets = rows * b_stride_m + cols * b_stride_n
58-
c_offsets = rows * c_stride_m + cols * c_stride_n
59-
60-
a = tl.load(a_ptr + a_offsets, mask=mask, other=0.0)
61-
b = tl.load(b_ptr + b_offsets, mask=mask, other=0.0)
42+
a = tl.load(a_ptr + offsets, mask=mask, other=0.0)
43+
b = tl.load(b_ptr + offsets, mask=mask, other=0.0)
6244

6345
silu_b = b * tl.sigmoid(tl.cast(b, tl.float32))
6446
c = a * silu_b
6547

66-
tl.store(c_ptr + c_offsets, c, mask=mask)
48+
tl.store(c_ptr + offsets, c, mask=mask)
6749

6850

6951
def triton_swiglu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
70-
m, n = a.shape
71-
c = torch.empty_like(a)
52+
# Flatten the inputs so that the kernel always works on 1D tensors
53+
a_flat = a.flatten()
54+
b_flat = b.flatten()
55+
c_flat = torch.empty_like(a_flat)
56+
num_elements = a_flat.numel()
7257

7358
def grid(meta):
74-
return (triton.cdiv(m * n, meta["BLOCK_SIZE"]),)
75-
76-
triton_swiglu_kernel[grid](
77-
a,
78-
b,
79-
c,
80-
m,
81-
n,
82-
a.stride(0),
83-
a.stride(1),
84-
b.stride(0),
85-
b.stride(1),
86-
c.stride(0),
87-
c.stride(1),
88-
BLOCK_SIZE=1024,
89-
)
59+
return (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
60+
61+
triton_swiglu_kernel[grid](a_flat, b_flat, c_flat, num_elements, BLOCK_SIZE=1024)
9062

91-
return c
63+
return c_flat.view_as(a)
9264

9365

9466
def torch_swiglu(
@@ -108,7 +80,7 @@ def torch_swiglu(
10880
b = torch.rand(shape, dtype=dtype, device=device)
10981
c = torch.rand(shape, dtype=dtype, device=device)
11082

111-
ninetoothed_output = ninetoothed_swiglu(a, b)
83+
ninetoothed_output = swiglu(a, b)
11284
torch_output = torch_swiglu(a, b)
11385
triton_output = triton_swiglu(a, b)
11486
print(ninetoothed_output)
@@ -149,7 +121,7 @@ def benchmark(m, n, provider):
149121

150122
if provider == "ninetoothed":
151123
ms, min_ms, max_ms = triton.testing.do_bench(
152-
lambda: ninetoothed_swiglu(a, b), quantiles=quantiles
124+
lambda: swiglu(a, b), quantiles=quantiles
153125
)
154126
elif provider == "torch":
155127
ms, min_ms, max_ms = triton.testing.do_bench(

0 commit comments

Comments
 (0)