Skip to content

Commit 92900c9

Browse files
committed
Update the Triton implementation to also accept inputs of any dimensionality
1 parent 50e3c23 commit 92900c9

1 file changed

Lines changed: 17 additions & 37 deletions

File tree

swiglu.py

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -36,61 +36,41 @@ def triton_swiglu_kernel(
3636
a_ptr,
3737
b_ptr,
3838
c_ptr,
39-
m,
40-
n,
41-
a_stride_m,
42-
a_stride_n,
43-
b_stride_m,
44-
b_stride_n,
45-
c_stride_m,
46-
c_stride_n,
39+
data_size: tl.constexpr,
4740
BLOCK_SIZE: tl.constexpr,
4841
):
4942
pid = tl.program_id(0)
50-
block_start = pid * BLOCK_SIZE
51-
offsets = block_start + tl.arange(0, BLOCK_SIZE)
43+
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
44+
mask = offsets < data_size
5245

53-
rows = offsets // n
54-
cols = offsets % n
55-
56-
mask = (rows < m) & (cols < n)
57-
58-
a_offsets = rows * a_stride_m + cols * a_stride_n
59-
b_offsets = rows * b_stride_m + cols * b_stride_n
60-
c_offsets = rows * c_stride_m + cols * c_stride_n
61-
62-
a = tl.load(a_ptr + a_offsets, mask=mask, other=0.0)
63-
b = tl.load(b_ptr + b_offsets, mask=mask, other=0.0)
46+
a = tl.load(a_ptr + offsets, mask=mask, other=0.0)
47+
b = tl.load(b_ptr + offsets, mask=mask, other=0.0)
6448

6549
silu_b = b * tl.sigmoid(tl.cast(b, tl.float32))
6650
c = a * silu_b
6751

68-
tl.store(c_ptr + c_offsets, c, mask=mask)
52+
tl.store(c_ptr + offsets, c, mask=mask)
6953

7054

7155
def triton_swiglu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
72-
m, n = a.shape
73-
c = torch.empty_like(a)
56+
# Flatten the inputs so that the kernel always works on 1D tensors
57+
a_flat = a.flatten()
58+
b_flat = b.flatten()
59+
c_flat = torch.empty_like(a_flat)
60+
data_size = a_flat.numel()
7461

7562
def grid(meta):
76-
return (triton.cdiv(m * n, meta["BLOCK_SIZE"]),)
63+
return (triton.cdiv(data_size, meta["BLOCK_SIZE"]),)
7764

7865
triton_swiglu_kernel[grid](
79-
a,
80-
b,
81-
c,
82-
m,
83-
n,
84-
a.stride(0),
85-
a.stride(1),
86-
b.stride(0),
87-
b.stride(1),
88-
c.stride(0),
89-
c.stride(1),
66+
a_flat,
67+
b_flat,
68+
c_flat,
69+
data_size,
9070
BLOCK_SIZE=1024,
9171
)
9272

93-
return c
73+
return c_flat.view_as(a)
9474

9575

9676
def torch_swiglu(

0 commit comments

Comments
 (0)