def triton_swiglu_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    data_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute ``c = a * (b * sigmoid(b))`` element-wise over flat buffers.

    Each program instance processes one ``BLOCK_SIZE``-wide slice and uses
    the same offsets for ``a_ptr``, ``b_ptr`` and ``c_ptr``.

    Args:
        a_ptr: Base pointer of the first operand.
        b_ptr: Base pointer of the operand that goes through
            ``b * sigmoid(b)``.
        c_ptr: Base pointer of the output buffer.
        data_size: Number of elements to process. Deliberately a runtime
            argument rather than ``tl.constexpr`` so that a new tensor size
            does not force a new kernel compilation.
        BLOCK_SIZE: Number of elements handled per program instance
            (compile-time constant, required by ``tl.arange``).
    """
    # NOTE(review): this function is launched as ``kernel[grid](...)``, so it
    # is expected to carry the ``@triton.jit`` decorator directly above this
    # def, outside the visible hunk - confirm.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last program may run past the end of the buffers; mask those lanes.
    mask = offsets < data_size

    a = tl.load(a_ptr + offsets, mask=mask, other=0.0)
    b = tl.load(b_ptr + offsets, mask=mask, other=0.0)

    # The sigmoid is evaluated on a float32 copy of b.
    silu_b = b * tl.sigmoid(tl.cast(b, tl.float32))
    c = a * silu_b

    tl.store(c_ptr + offsets, c, mask=mask)
6953
7054
def triton_swiglu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return ``a * (b * sigmoid(b))`` computed by ``triton_swiglu_kernel``.

    Both inputs are flattened so the kernel only ever sees 1D buffers; the
    result is reshaped back to the shape of ``a``.

    Args:
        a: First operand.
        b: Operand passed through ``b * sigmoid(b)``; must have the same
            shape as ``a``.

    Returns:
        A newly allocated tensor with the shape of ``a``.

    Raises:
        ValueError: If ``a`` and ``b`` do not have the same shape.
    """
    # The kernel indexes both inputs with the same offsets, so a shape
    # mismatch would read the smaller buffer out of bounds.
    if a.shape != b.shape:
        raise ValueError(
            f"a and b must have the same shape, got {tuple(a.shape)} "
            f"and {tuple(b.shape)}"
        )

    # Flatten the inputs so that the kernel always works on 1D tensors
    # (flatten() copies when the input is not contiguous).
    a_flat = a.flatten()
    b_flat = b.flatten()
    c_flat = torch.empty_like(a_flat)
    data_size = a_flat.numel()

    # Nothing to compute: skip the launch instead of requesting an empty grid.
    if data_size == 0:
        return c_flat.view_as(a)

    def grid(meta):
        # One program per BLOCK_SIZE-wide slice, rounding up for the tail.
        return (triton.cdiv(data_size, meta["BLOCK_SIZE"]),)

    triton_swiglu_kernel[grid](
        a_flat,
        b_flat,
        c_flat,
        data_size,
        BLOCK_SIZE=1024,
    )

    return c_flat.view_as(a)
9474
9575
9676def torch_swiglu (
0 commit comments