We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
ninetoothed_swiglu
swiglu
1 parent da33f58 commit 906a166Copy full SHA for 906a166
1 file changed
swiglu.py
@@ -20,7 +20,7 @@ def swiglu_kernel(
20
c = a * gate # noqa: F841
21
22
23
-def ninetoothed_swiglu(a, b):
+def swiglu(a, b):
24
a_1d = a.flatten()
25
b_1d = b.flatten()
26
@@ -80,7 +80,7 @@ def torch_swiglu(
80
b = torch.rand(shape, dtype=dtype, device=device)
81
c = torch.rand(shape, dtype=dtype, device=device)
82
83
- ninetoothed_output = ninetoothed_swiglu(a, b)
+ ninetoothed_output = swiglu(a, b)
84
torch_output = torch_swiglu(a, b)
85
triton_output = triton_swiglu(a, b)
86
print(ninetoothed_output)
@@ -121,7 +121,7 @@ def benchmark(m, n, provider):
121
122
if provider == "ninetoothed":
123
ms, min_ms, max_ms = triton.testing.do_bench(
124
- lambda: ninetoothed_swiglu(a, b), quantiles=quantiles
+ lambda: swiglu(a, b), quantiles=quantiles
125
)
126
elif provider == "torch":
127
0 commit comments