Skip to content

Commit 358a627

Browse files
committed
Fix the shape issue in the performance testing in swiglu.py
1 parent 80009e8 commit 358a627

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

swiglu.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ def torch_swiglu(
126126

127127
@triton.testing.perf_report(
128128
triton.testing.Benchmark(
129-
x_names=["size"],
130-
x_vals=[2**i for i in range(12, 28, 1)],
129+
x_names=["m", "n"],
130+
x_vals=[128 * i for i in range(2, 50)],
131131
x_log=True,
132132
line_arg="provider",
133133
line_vals=["ninetoothed", "torch", "triton"],
@@ -138,11 +138,13 @@ def torch_swiglu(
138138
args={},
139139
)
140140
)
141-
def benchmark(size, provider):
141+
def benchmark(m, n, provider):
142+
shape = (m, n)
142143
dtype = torch.float16
143144
device = "cuda"
144-
a = torch.rand(size, dtype=dtype, device=device)
145-
b = torch.rand(size, dtype=dtype, device=device)
145+
146+
a = torch.rand(shape, dtype=dtype, device=device)
147+
b = torch.rand(shape, dtype=dtype, device=device)
146148
quantiles = [0.5, 0.2, 0.8]
147149

148150
if provider == "ninetoothed":

0 commit comments

Comments
 (0)