5151 triton .Config ({"BLOCK_SIZE_N" : 8 , "BLOCK_SIZE_K" : 256 }, num_warps = 2 , num_stages = 5 ),
5252 triton .Config ({"BLOCK_SIZE_N" : 8 , "BLOCK_SIZE_K" : 256 }, num_warps = 2 , num_stages = 3 ),
5353 triton .Config ({"BLOCK_SIZE_N" : 16 , "BLOCK_SIZE_K" : 256 }, num_warps = 2 , num_stages = 5 ),
54+ triton .Config ({"BLOCK_SIZE_N" : 16 , "BLOCK_SIZE_K" : 128 }, num_warps = 4 , num_stages = 4 ),
55+ triton .Config ({"BLOCK_SIZE_N" : 32 , "BLOCK_SIZE_K" : 128 }, num_warps = 4 , num_stages = 3 ),
5456]
5557
5658# Autotune configs for GEMM2 (_fused_moe_silu_kernel).
6365 triton .Config ({"BLOCK_SIZE_N" : 16 , "BLOCK_SIZE_K" : 256 }, num_warps = 4 , num_stages = 4 ),
6466 triton .Config ({"BLOCK_SIZE_N" : 8 , "BLOCK_SIZE_K" : 256 }, num_warps = 2 , num_stages = 3 ),
6567 triton .Config ({"BLOCK_SIZE_N" : 8 , "BLOCK_SIZE_K" : 256 }, num_warps = 4 , num_stages = 3 ),
68+ triton .Config ({"BLOCK_SIZE_N" : 16 , "BLOCK_SIZE_K" : 128 }, num_warps = 2 , num_stages = 3 ),
69+ triton .Config ({"BLOCK_SIZE_N" : 32 , "BLOCK_SIZE_K" : 128 }, num_warps = 4 , num_stages = 4 ),
6670]
6771
6872
@@ -171,9 +175,9 @@ def _fused_moe_kernel(
171175 scale_ptrs , mask = k_mask [:, None ] & n_mask [None , :], other = 0.0
172176 ).to (tl .float32 )
173177
174- # Dequantize and accumulate: vector-matrix multiply
175- b_dequant = (( b .to (tl .float32 ) - 8.0 ) * b_scale ). to ( compute_type )
176- acc += tl .sum (a [:, None ].to (compute_type ) * b_dequant , axis = 0 )
178+ # Dequantize and accumulate in float32 : vector-matrix multiply
179+ b_dequant = (b .to (tl .float32 ) - 8.0 ) * b_scale
180+ acc += tl .sum (a [:, None ].to (tl . float32 ) * b_dequant , axis = 0 )
177181
178182 # Advance K pointers
179183 a_ptrs += BLOCK_SIZE_K * stride_ak
@@ -259,10 +263,10 @@ def _fused_moe_silu_kernel(
259263 k_remaining = K - k_step * BLOCK_SIZE_K
260264 k_mask = offs_k < k_remaining
261265
262- # Load gate and up, apply SiLU(gate) * up
266+ # Load gate and up in float32 , apply SiLU(gate) * up
263267 gate = tl .load (a_gate_ptrs , mask = k_mask , other = 0.0 ).to (tl .float32 )
264- up = tl .load (a_up_ptrs , mask = k_mask , other = 0.0 )
265- a = ( gate * tl .sigmoid (gate ) * up ). to ( compute_type )
268+ up = tl .load (a_up_ptrs , mask = k_mask , other = 0.0 ). to ( tl . float32 )
269+ a = gate * tl .sigmoid (gate ) * up
266270
267271 # Load and dequantize INT4 weights
268272 b = tl .load (b_ptrs , mask = k_mask [:, None ] & n_mask [None , :], other = 0 )
@@ -290,8 +294,8 @@ def _fused_moe_silu_kernel(
290294 scale_ptrs , mask = k_mask [:, None ] & n_mask [None , :], other = 0.0
291295 ).to (tl .float32 )
292296
293- b_dequant = (( b .to (tl .float32 ) - 8.0 ) * b_scale ). to ( compute_type )
294- acc += tl .sum (a [:, None ]. to ( compute_type ) * b_dequant , axis = 0 )
297+ b_dequant = (b .to (tl .float32 ) - 8.0 ) * b_scale
298+ acc += tl .sum (a [:, None ] * b_dequant , axis = 0 )
295299
296300 a_gate_ptrs += BLOCK_SIZE_K * stride_ak
297301 a_up_ptrs += BLOCK_SIZE_K * stride_ak
@@ -571,6 +575,8 @@ def moe_align_block_size(
571575 triton .Config (
572576 {"BLOCK_SIZE_N" : 128 , "BLOCK_SIZE_K" : 128 }, num_warps = 4 , num_stages = 2
573577 ),
578+ triton .Config ({"BLOCK_SIZE_N" : 64 , "BLOCK_SIZE_K" : 128 }, num_warps = 8 , num_stages = 4 ),
579+ triton .Config ({"BLOCK_SIZE_N" : 128 , "BLOCK_SIZE_K" : 64 }, num_warps = 8 , num_stages = 4 ),
574580]
575581
576582# Autotune configs for batched GEMM2 (down projection + SiLU).
@@ -581,6 +587,8 @@ def moe_align_block_size(
581587 triton .Config (
582588 {"BLOCK_SIZE_N" : 128 , "BLOCK_SIZE_K" : 128 }, num_warps = 4 , num_stages = 2
583589 ),
590+ triton .Config ({"BLOCK_SIZE_N" : 64 , "BLOCK_SIZE_K" : 128 }, num_warps = 8 , num_stages = 3 ),
591+ triton .Config ({"BLOCK_SIZE_N" : 128 , "BLOCK_SIZE_K" : 64 }, num_warps = 8 , num_stages = 4 ),
584592]
585593
586594
@@ -778,9 +786,9 @@ def _fused_moe_silu_batched_kernel(
778786 k_remaining = K - k_step * BLOCK_SIZE_K
779787 k_mask = offs_k < k_remaining
780788
781- # Load gate and up tiles [BLOCK_M, BLOCK_K], apply SiLU
789+ # Load gate and up in float32 [BLOCK_M, BLOCK_K], apply SiLU
782790 gate = tl .load (a_gate_ptrs , mask = k_mask [None , :], other = 0.0 ).to (tl .float32 )
783- up = tl .load (a_up_ptrs , mask = k_mask [None , :], other = 0.0 )
791+ up = tl .load (a_up_ptrs , mask = k_mask [None , :], other = 0.0 ). to ( tl . float32 )
784792 a = (gate * tl .sigmoid (gate ) * up ).to (compute_type )
785793
786794 # Load and dequantize INT4 weights [BLOCK_K, BLOCK_N]
0 commit comments