Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 28 additions & 56 deletions swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,89 +6,61 @@
import triton.language as tl
from ninetoothed import Symbol, Tensor

# Tile-size symbols used in the `.tile(...)` annotations of the ninetoothed
# kernel: BLOCK_SIZE_M / BLOCK_SIZE_N for the 2D tiling, BLOCK_SIZE for the
# 1D tiling.
# NOTE(review): `meta=True` presumably marks them as meta-parameters whose
# values are chosen by ninetoothed rather than passed by callers -- confirm.
BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)


@ninetoothed.jit
def swiglu_kernel(
    a: Tensor(1).tile((BLOCK_SIZE,)),
    b: Tensor(1).tile((BLOCK_SIZE,)),
    c: Tensor(1).tile((BLOCK_SIZE,)),
):
    # SwiGLU on one tile: c = a * (b * sigmoid(b)).
    #
    # All three arguments are 1D tensors tiled into BLOCK_SIZE-element
    # chunks, so callers are expected to pass flattened tensors.

    # Bind the `b` tile to a local once and reuse that local in both places
    # it is needed below.
    b_loaded = b
    # The sigmoid is evaluated on a float32 copy of the gate input.
    gate = b_loaded * ntl.sigmoid(ntl.cast(b_loaded, ntl.float32))
    # Assigning to the `c` parameter is the kernel's only output; the name is
    # never read again locally, hence the F841 suppression.
    c = a * gate  # noqa: F841


def swiglu(a, b):
    """Return ``a * (b * sigmoid(b))`` computed by ``swiglu_kernel``.

    The inputs are flattened so the 1D kernel can handle tensors of any
    rank; the result is reshaped back to the shape of ``a``.

    Args:
        a: Tensor multiplied by the gated values.
        b: Gate input; must have the same shape as ``a``.

    Returns:
        A new tensor with the shape of ``a``.

    Raises:
        ValueError: If ``a`` and ``b`` have different shapes.
    """
    if a.shape != b.shape:
        raise ValueError(
            f"swiglu expects operands of equal shape, got "
            f"{tuple(a.shape)} and {tuple(b.shape)}"
        )

    # Nothing to compute; skip the kernel launch entirely.
    if a.numel() == 0:
        return torch.empty_like(a)

    a_1d = a.flatten()
    b_1d = b.flatten()

    c = torch.empty_like(a_1d)

    swiglu_kernel(a_1d, b_1d, c)

    return c.view_as(a)


# Previous name of this wrapper, kept so existing call sites keep working.
ninetoothed_swiglu = swiglu


@triton.jit
def triton_swiglu_kernel(
    a_ptr, b_ptr, c_ptr, num_elements, BLOCK_SIZE: tl.constexpr
):
    """Element-wise SwiGLU over flat buffers: ``c = a * (b * sigmoid(b))``.

    Each program instance processes ``BLOCK_SIZE`` consecutive elements
    starting at ``program_id(0) * BLOCK_SIZE``.

    Args:
        a_ptr: Pointer to the first input buffer.
        b_ptr: Pointer to the gate input buffer.
        c_ptr: Pointer to the output buffer.
        num_elements: Number of valid elements in each buffer. This is a
            runtime argument rather than ``tl.constexpr`` so that a new
            input size does not force a fresh kernel compilation.
        BLOCK_SIZE: Number of elements handled per program instance.
    """
    pid = tl.program_id(0)

    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last block may run past the end of the buffers; mask those lanes.
    mask = offsets < num_elements

    a = tl.load(a_ptr + offsets, mask=mask, other=0.0)
    b = tl.load(b_ptr + offsets, mask=mask, other=0.0)

    # The sigmoid is evaluated on a float32 copy of the gate input.
    silu_b = b * tl.sigmoid(tl.cast(b, tl.float32))
    c = a * silu_b

    tl.store(c_ptr + offsets, c, mask=mask)


def triton_swiglu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return ``a * (b * sigmoid(b))`` computed by ``triton_swiglu_kernel``.

    Args:
        a: Tensor multiplied by the gated values.
        b: Gate input; must have the same shape as ``a``.

    Returns:
        A new tensor with the shape of ``a``.

    Raises:
        ValueError: If ``a`` and ``b`` have different shapes.
    """
    if a.shape != b.shape:
        raise ValueError(
            f"triton_swiglu expects operands of equal shape, got "
            f"{tuple(a.shape)} and {tuple(b.shape)}"
        )

    # Nothing to compute; skip the kernel launch entirely.
    if a.numel() == 0:
        return torch.empty_like(a)

    # Flatten the inputs so that the kernel always works on 1D tensors
    a_flat = a.flatten()
    b_flat = b.flatten()
    c_flat = torch.empty_like(a_flat)
    num_elements = a_flat.numel()

    def grid(meta):
        # One program instance per BLOCK_SIZE-element chunk, rounded up.
        return (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)

    triton_swiglu_kernel[grid](
        a_flat, b_flat, c_flat, num_elements, BLOCK_SIZE=1024
    )

    return c_flat.view_as(a)


def torch_swiglu(
Expand All @@ -108,7 +80,7 @@ def torch_swiglu(
b = torch.rand(shape, dtype=dtype, device=device)
c = torch.rand(shape, dtype=dtype, device=device)

ninetoothed_output = ninetoothed_swiglu(a, b)
ninetoothed_output = swiglu(a, b)
torch_output = torch_swiglu(a, b)
triton_output = triton_swiglu(a, b)
print(ninetoothed_output)
Expand Down Expand Up @@ -149,7 +121,7 @@ def benchmark(m, n, provider):

if provider == "ninetoothed":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: ninetoothed_swiglu(a, b), quantiles=quantiles
lambda: swiglu(a, b), quantiles=quantiles
)
elif provider == "torch":
ms, min_ms, max_ms = triton.testing.do_bench(
Expand Down