Skip to content

Commit 80009e8

Browse files
committed
Make a few naming and formatting changes
1 parent b07beb8 commit 80009e8

1 file changed

Lines changed: 27 additions & 22 deletions

File tree

swiglu.py

Lines changed: 27 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,8 @@ def swiglu_kernel(
1616
b: Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N)),
1717
c: Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N)),
1818
):
19-
magic = b
20-
gate = magic * ntl.sigmoid(ntl.cast(magic, ntl.float32))
19+
b_loaded = b
20+
gate = b_loaded * ntl.sigmoid(ntl.cast(b_loaded, ntl.float32))
2121
c = a * gate # noqa: F841
2222

2323

@@ -34,28 +34,28 @@ def triton_swiglu_kernel(
3434
a_ptr,
3535
b_ptr,
3636
c_ptr,
37-
M,
38-
N,
39-
stride_am,
40-
stride_an,
41-
stride_bm,
42-
stride_bn,
43-
stride_cm,
44-
stride_cn,
37+
m,
38+
n,
39+
a_stride_m,
40+
a_stride_n,
41+
b_stride_m,
42+
b_stride_n,
43+
c_stride_m,
44+
c_stride_n,
4545
BLOCK_SIZE: tl.constexpr,
4646
):
4747
pid = tl.program_id(0)
4848
block_start = pid * BLOCK_SIZE
4949
offsets = block_start + tl.arange(0, BLOCK_SIZE)
5050

51-
rows = offsets // N
52-
cols = offsets % N
51+
rows = offsets // n
52+
cols = offsets % n
5353

54-
mask = (rows < M) & (cols < N)
54+
mask = (rows < m) & (cols < n)
5555

56-
a_offsets = rows * stride_am + cols * stride_an
57-
b_offsets = rows * stride_bm + cols * stride_bn
58-
c_offsets = rows * stride_cm + cols * stride_cn
56+
a_offsets = rows * a_stride_m + cols * a_stride_n
57+
b_offsets = rows * b_stride_m + cols * b_stride_n
58+
c_offsets = rows * c_stride_m + cols * c_stride_n
5959

6060
a = tl.load(a_ptr + a_offsets, mask=mask, other=0.0)
6161
b = tl.load(b_ptr + b_offsets, mask=mask, other=0.0)
@@ -87,6 +87,7 @@ def grid(meta):
8787
c.stride(1),
8888
BLOCK_SIZE=1024,
8989
)
90+
9091
return c
9192

9293

@@ -101,9 +102,11 @@ def torch_swiglu(
101102
torch.manual_seed(0)
102103

103104
shape = (13, 3)
104-
a = torch.rand(shape, device="cuda", dtype=torch.float16)
105-
b = torch.rand(shape, device="cuda", dtype=torch.float16)
106-
c = torch.rand(shape, device="cuda", dtype=torch.float16)
105+
dtype = torch.float16
106+
device = "cuda"
107+
a = torch.rand(shape, dtype=dtype, device=device)
108+
b = torch.rand(shape, dtype=dtype, device=device)
109+
c = torch.rand(shape, dtype=dtype, device=device)
107110

108111
ninetoothed_output = ninetoothed_swiglu(a, b)
109112
torch_output = torch_swiglu(a, b)
@@ -131,13 +134,15 @@ def torch_swiglu(
131134
line_names=["NineToothed", "PyTorch", "Triton"],
132135
styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
133136
ylabel="GB/s",
134-
plot_name="vector-addition-performance",
137+
plot_name="swiglu-performance",
135138
args={},
136139
)
137140
)
138141
def benchmark(size, provider):
139-
a = torch.rand(size, device="cuda", dtype=torch.float16)
140-
b = torch.rand(size, device="cuda", dtype=torch.float16)
142+
dtype = torch.float16
143+
device = "cuda"
144+
a = torch.rand(size, dtype=dtype, device=device)
145+
b = torch.rand(size, dtype=dtype, device=device)
141146
quantiles = [0.5, 0.2, 0.8]
142147

143148
if provider == "ninetoothed":

0 commit comments

Comments
 (0)