44import triton .language as tl
55from .moe_silu_and_mul_config import MoeSiluAndMulKernelConfig
66
7-
@triton.jit
def _silu_and_mul_kernel_fast(
    input_ptr,
    output_ptr,
    stride_input_m,
    stride_input_n,
    stride_output_m,
    stride_output_n,
    size_m,
    size_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NEED_MASK: tl.constexpr,
):
    """Compute ``output[m, n] = silu(gate[m, n]) * up[m, n]`` for one tile.

    For every row ``m`` the kernel reads ``gate`` from input columns
    ``[0, size_n)`` and ``up`` from input columns ``[size_n, 2 * size_n)``.

    Each program handles the rows ``[pid0 * BLOCK_M, min((pid0 + 1) * BLOCK_M,
    size_m))`` one at a time and, within each row, the ``BLOCK_N`` columns
    starting at ``pid1 * BLOCK_N``.

    Args:
        input_ptr: Pointer to the input buffer holding gate and up halves.
        output_ptr: Pointer to the output buffer.
        stride_input_m: Element stride between consecutive input rows.
        stride_input_n: Accepted for interface compatibility; not used.
        stride_output_m: Element stride between consecutive output rows.
        stride_output_n: Accepted for interface compatibility; not used.
        size_m: Number of rows to process.
        size_n: Number of output columns (half the input width read per row).
        BLOCK_M: Rows handled by one program along grid axis 0.
        BLOCK_N: Columns handled by one program along grid axis 1.
        NEED_MASK: When true, column accesses are masked against ``size_n``;
            the launcher should set it when ``size_n % BLOCK_N != 0``.

    NOTE(review): column offsets are added without multiplying by
    ``stride_input_n`` / ``stride_output_n``, so the kernel assumes unit
    stride along the column axis -- confirm the caller guarantees this.
    """
    # Widen the row strides so the row-offset products below are 64-bit.
    stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
    stride_output_m = tl.cast(stride_output_m, dtype=tl.int64)

    m_block_index = tl.program_id(0)
    n_block_index = tl.program_id(1)

    # Column offsets do not depend on the row, so build them once.
    n_offsets = n_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
    gate_cols = n_offsets[None, :]
    up_cols = gate_cols + size_n

    # Clamp the row range so the last row-block stops at size_m.
    m_start_index = m_block_index * BLOCK_M
    m_end_index = tl.minimum(m_start_index + BLOCK_M, size_m)

    # NEED_MASK is a compile-time constant, so only one branch is compiled
    # and the unmasked variant carries no mask at all.
    if NEED_MASK:
        mask = gate_cols < size_n
        other = 0.0
    else:
        mask = None
        other = None

    for m_index in range(m_start_index, m_end_index):
        input_row = m_index * stride_input_m
        output_row = m_index * stride_output_m

        up = tl.load(
            input_ptr + input_row + up_cols,
            mask=mask,
            other=other,
        )
        gate = tl.load(
            input_ptr + input_row + gate_cols,
            mask=mask,
            other=other,
        ).to(tl.float32)

        # SiLU is evaluated in float32, then cast back to the input element
        # type before the multiply (same numerics as the original kernel).
        gate = gate / (1 + tl.exp(-gate))
        gate = gate.to(input_ptr.dtype.element_ty)

        tl.store(
            output_ptr + output_row + gate_cols,
            up * gate,
            mask=mask,
        )
10361
10462
10563def silu_and_mul_fwd (input : torch .Tensor , output : torch .Tensor , ** run_config ):
@@ -116,26 +74,6 @@ def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config):
11674 if not run_config :
11775 run_config = MoeSiluAndMulKernelConfig .try_to_get_best_config (M = size_m , N = size_n , out_dtype = str (output .dtype ))
11876
119- if size_m <= 4096 :
120- BLOCK_N = run_config ["BLOCK_N" ]
121- grid = (
122- size_m ,
123- triton .cdiv (size_n , BLOCK_N ),
124- )
125- NEED_MASK = size_n % BLOCK_N != 0
126- _silu_and_mul_kernel_fast [grid ](
127- input ,
128- output ,
129- stride_input_m ,
130- stride_input_n ,
131- stride_output_m ,
132- stride_output_n ,
133- size_n ,
134- BLOCK_N = BLOCK_N ,
135- NEED_MASK = NEED_MASK ,
136- )
137- return
138-
13977 BLOCK_M = run_config ["BLOCK_M" ]
14078 BLOCK_N = run_config ["BLOCK_N" ]
14179 num_warps = run_config ["num_warps" ]
@@ -144,17 +82,19 @@ def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config):
14482 triton .cdiv (size_m , BLOCK_M ),
14583 triton .cdiv (size_n , BLOCK_N ),
14684 )
147- _silu_and_mul_kernel [grid ](
148- input ,
149- output ,
150- stride_input_m ,
151- stride_input_n ,
152- stride_output_m ,
153- stride_output_n ,
154- size_m ,
155- size_n ,
85+ NEED_MASK = (size_n % BLOCK_N ) != 0
86+ _silu_and_mul_kernel_fast [grid ](
87+ input_ptr = input ,
88+ output_ptr = output ,
89+ stride_input_m = stride_input_m ,
90+ stride_input_n = stride_input_n ,
91+ stride_output_m = stride_output_m ,
92+ stride_output_n = stride_output_n ,
93+ size_m = size_m ,
94+ size_n = size_n ,
15695 BLOCK_M = BLOCK_M ,
15796 BLOCK_N = BLOCK_N ,
97+ NEED_MASK = NEED_MASK ,
15898 num_warps = num_warps ,
15999 )
160100 return
0 commit comments