66import triton .language as tl
77from ninetoothed import Symbol , Tensor
88
9- BLOCK_SIZE_M = Symbol ("BLOCK_SIZE_M" , meta = True )
10- BLOCK_SIZE_N = Symbol ("BLOCK_SIZE_N" , meta = True )
9+ BLOCK_SIZE = Symbol ("BLOCK_SIZE" , meta = True )
1110
1211
1312@ninetoothed .jit
1413def swiglu_kernel (
15- a : Tensor (2 ).tile ((BLOCK_SIZE_M , BLOCK_SIZE_N )),
16- b : Tensor (2 ).tile ((BLOCK_SIZE_M , BLOCK_SIZE_N )),
17- c : Tensor (2 ).tile ((BLOCK_SIZE_M , BLOCK_SIZE_N )),
14+ a : Tensor (1 ).tile ((BLOCK_SIZE , )),
15+ b : Tensor (1 ).tile ((BLOCK_SIZE , )),
16+ c : Tensor (1 ).tile ((BLOCK_SIZE , )),
1817):
1918 b_loaded = b
2019 gate = b_loaded * ntl .sigmoid (ntl .cast (b_loaded , ntl .float32 ))
2120 c = a * gate # noqa: F841
2221
2322
24- def ninetoothed_swiglu (a , b ):
25- c = torch .empty_like (a )
23+ def swiglu (a , b ):
24+ a_1d = a .flatten ()
25+ b_1d = b .flatten ()
2626
27- swiglu_kernel ( a , b , c )
27+ c = torch . empty_like ( a_1d )
2828
29- return c
29+ swiglu_kernel (a_1d , b_1d , c )
30+
31+ return c .view_as (a )
3032
3133
3234@triton .jit
3335def triton_swiglu_kernel (
34- a_ptr ,
35- b_ptr ,
36- c_ptr ,
37- m ,
38- n ,
39- a_stride_m ,
40- a_stride_n ,
41- b_stride_m ,
42- b_stride_n ,
43- c_stride_m ,
44- c_stride_n ,
45- BLOCK_SIZE : tl .constexpr ,
36+ a_ptr , b_ptr , c_ptr , num_elements : tl .constexpr , BLOCK_SIZE : tl .constexpr
4637):
4738 pid = tl .program_id (0 )
48- block_start = pid * BLOCK_SIZE
49- offsets = block_start + tl .arange (0 , BLOCK_SIZE )
50-
51- rows = offsets // n
52- cols = offsets % n
53-
54- mask = (rows < m ) & (cols < n )
39+ offsets = pid * BLOCK_SIZE + tl .arange (0 , BLOCK_SIZE )
40+ mask = offsets < num_elements
5541
56- a_offsets = rows * a_stride_m + cols * a_stride_n
57- b_offsets = rows * b_stride_m + cols * b_stride_n
58- c_offsets = rows * c_stride_m + cols * c_stride_n
59-
60- a = tl .load (a_ptr + a_offsets , mask = mask , other = 0.0 )
61- b = tl .load (b_ptr + b_offsets , mask = mask , other = 0.0 )
42+ a = tl .load (a_ptr + offsets , mask = mask , other = 0.0 )
43+ b = tl .load (b_ptr + offsets , mask = mask , other = 0.0 )
6244
6345 silu_b = b * tl .sigmoid (tl .cast (b , tl .float32 ))
6446 c = a * silu_b
6547
66- tl .store (c_ptr + c_offsets , c , mask = mask )
48+ tl .store (c_ptr + offsets , c , mask = mask )
6749
6850
6951def triton_swiglu (a : torch .Tensor , b : torch .Tensor ) -> torch .Tensor :
70- m , n = a .shape
71- c = torch .empty_like (a )
52+ # Flatten the inputs so that the kernel always works on 1D tensors
53+ a_flat = a .flatten ()
54+ b_flat = b .flatten ()
55+ c_flat = torch .empty_like (a_flat )
56+ num_elements = a_flat .numel ()
7257
7358 def grid (meta ):
74- return (triton .cdiv (m * n , meta ["BLOCK_SIZE" ]),)
75-
76- triton_swiglu_kernel [grid ](
77- a ,
78- b ,
79- c ,
80- m ,
81- n ,
82- a .stride (0 ),
83- a .stride (1 ),
84- b .stride (0 ),
85- b .stride (1 ),
86- c .stride (0 ),
87- c .stride (1 ),
88- BLOCK_SIZE = 1024 ,
89- )
59+ return (triton .cdiv (num_elements , meta ["BLOCK_SIZE" ]),)
60+
61+ triton_swiglu_kernel [grid ](a_flat , b_flat , c_flat , num_elements , BLOCK_SIZE = 1024 )
9062
91- return c
63+ return c_flat . view_as ( a )
9264
9365
9466def torch_swiglu (
@@ -108,7 +80,7 @@ def torch_swiglu(
10880 b = torch .rand (shape , dtype = dtype , device = device )
10981 c = torch .rand (shape , dtype = dtype , device = device )
11082
111- ninetoothed_output = ninetoothed_swiglu (a , b )
83+ ninetoothed_output = swiglu (a , b )
11284 torch_output = torch_swiglu (a , b )
11385 triton_output = triton_swiglu (a , b )
11486 print (ninetoothed_output )
@@ -149,7 +121,7 @@ def benchmark(m, n, provider):
149121
150122 if provider == "ninetoothed" :
151123 ms , min_ms , max_ms = triton .testing .do_bench (
152- lambda : ninetoothed_swiglu (a , b ), quantiles = quantiles
124+ lambda : swiglu (a , b ), quantiles = quantiles
153125 )
154126 elif provider == "torch" :
155127 ms , min_ms , max_ms = triton .testing .do_bench (
0 commit comments