@@ -16,8 +16,8 @@ def swiglu_kernel(
1616 b : Tensor (2 ).tile ((BLOCK_SIZE_M , BLOCK_SIZE_N )),
1717 c : Tensor (2 ).tile ((BLOCK_SIZE_M , BLOCK_SIZE_N )),
1818):
19- magic = b
20- gate = magic * ntl .sigmoid (ntl .cast (magic , ntl .float32 ))
19+ b_loaded = b
20+ gate = b_loaded * ntl .sigmoid (ntl .cast (b_loaded , ntl .float32 ))
2121 c = a * gate # noqa: F841
2222
2323
@@ -34,28 +34,28 @@ def triton_swiglu_kernel(
3434 a_ptr ,
3535 b_ptr ,
3636 c_ptr ,
37- M ,
38- N ,
39- stride_am ,
40- stride_an ,
41- stride_bm ,
42- stride_bn ,
43- stride_cm ,
44- stride_cn ,
37+ m ,
38+ n ,
39+ a_stride_m ,
40+ a_stride_n ,
41+ b_stride_m ,
42+ b_stride_n ,
43+ c_stride_m ,
44+ c_stride_n ,
4545 BLOCK_SIZE : tl .constexpr ,
4646):
4747 pid = tl .program_id (0 )
4848 block_start = pid * BLOCK_SIZE
4949 offsets = block_start + tl .arange (0 , BLOCK_SIZE )
5050
51- rows = offsets // N
52- cols = offsets % N
51+ rows = offsets // n
52+ cols = offsets % n
5353
54- mask = (rows < M ) & (cols < N )
54+ mask = (rows < m ) & (cols < n )
5555
56- a_offsets = rows * stride_am + cols * stride_an
57- b_offsets = rows * stride_bm + cols * stride_bn
58- c_offsets = rows * stride_cm + cols * stride_cn
56+ a_offsets = rows * a_stride_m + cols * a_stride_n
57+ b_offsets = rows * b_stride_m + cols * b_stride_n
58+ c_offsets = rows * c_stride_m + cols * c_stride_n
5959
6060 a = tl .load (a_ptr + a_offsets , mask = mask , other = 0.0 )
6161 b = tl .load (b_ptr + b_offsets , mask = mask , other = 0.0 )
@@ -87,6 +87,7 @@ def grid(meta):
8787 c .stride (1 ),
8888 BLOCK_SIZE = 1024 ,
8989 )
90+
9091 return c
9192
9293
@@ -101,9 +102,11 @@ def torch_swiglu(
101102 torch .manual_seed (0 )
102103
103104 shape = (13 , 3 )
104- a = torch .rand (shape , device = "cuda" , dtype = torch .float16 )
105- b = torch .rand (shape , device = "cuda" , dtype = torch .float16 )
106- c = torch .rand (shape , device = "cuda" , dtype = torch .float16 )
105+ dtype = torch .float16
106+ device = "cuda"
107+ a = torch .rand (shape , dtype = dtype , device = device )
108+ b = torch .rand (shape , dtype = dtype , device = device )
109+ c = torch .rand (shape , dtype = dtype , device = device )
107110
108111 ninetoothed_output = ninetoothed_swiglu (a , b )
109112 torch_output = torch_swiglu (a , b )
@@ -131,13 +134,15 @@ def torch_swiglu(
131134 line_names = ["NineToothed" , "PyTorch" , "Triton" ],
132135 styles = [("blue" , "-" ), ("green" , "-" ), ("orange" , "-" )],
133136 ylabel = "GB/s" ,
134- plot_name = "vector-addition -performance" ,
137+ plot_name = "swiglu -performance" ,
135138 args = {},
136139 )
137140 )
138141 def benchmark (size , provider ):
139- a = torch .rand (size , device = "cuda" , dtype = torch .float16 )
140- b = torch .rand (size , device = "cuda" , dtype = torch .float16 )
142+ dtype = torch .float16
143+ device = "cuda"
144+ a = torch .rand (size , dtype = dtype , device = device )
145+ b = torch .rand (size , dtype = dtype , device = device )
141146 quantiles = [0.5 , 0.2 , 0.8 ]
142147
143148 if provider == "ninetoothed" :
0 commit comments