@@ -707,11 +707,11 @@ def _fused_moe_batched_kernel(
 # Autotune configs for batched INT8 GEMM1 (gate+up projection, W4A8).
 _BATCHED_GEMM1_INT8_CONFIGS = [
     triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2),
-    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3),
     triton.Config(
-        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2
     ),
+    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3),
+    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3),
     triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=4),
     triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=4),
 ]
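
For context, and not part of the diff itself: a config list like _BATCHED_GEMM1_INT8_CONFIGS is typically handed to Triton's autotuner, which times every candidate tile shape on first launch and caches the winner per autotune key. A minimal sketch of that wiring, with the kernel name and key fields assumed rather than taken from this PR:

import triton
import triton.language as tl

@triton.autotune(
    configs=_BATCHED_GEMM1_INT8_CONFIGS,
    key=["N", "K"],  # assumed key: re-tune whenever these problem sizes change
)
@triton.jit
def _example_batched_int8_kernel(  # hypothetical name, for illustration only
    a_ptr, b_ptr, c_ptr, N, K,
    BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    ...

Each triton.Config pairs a tile shape with launch parameters: num_warps sets the warps per block and num_stages the software-pipelining depth, which is why the smaller BLOCK_SIZE_K tiles in the list can afford more stages.
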
@@ -833,7 +833,10 @@ def _fused_moe_batched_int8_kernel(
         else:
             # Multi-group tile: dequantize weights per group, use float matmul
             b_dequant = (b_int8.to(tl.float32) * b_scale).to(compute_type)
-            acc += tl.dot(a_int8.to(compute_type), b_dequant).to(tl.float32) * a_scale[:, None]
+            acc += (
+                tl.dot(a_int8.to(compute_type), b_dequant).to(tl.float32)
+                * a_scale[:, None]
+            )
 
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
@@ -977,11 +980,11 @@ def _fused_moe_silu_batched_kernel(
 # Autotune configs for batched INT8 GEMM2 (down projection + SiLU, W4A8).
 _BATCHED_GEMM2_INT8_CONFIGS = [
     triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2),
-    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2),
-    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3),
     triton.Config(
-        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2
     ),
+    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3),
+    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3),
     triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=4),
     triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=4),
 ]
@@ -1105,7 +1108,10 @@ def _fused_moe_silu_batched_int8_kernel(
         else:
             # Multi-group tile: dequantize weights per group, use float matmul
             b_dequant = (b_int8.to(tl.float32) * b_scale).to(compute_type)
-            acc += tl.dot(a_int8.to(compute_type), b_dequant).to(tl.float32) * a_scale[:, None]
+            acc += (
+                tl.dot(a_int8.to(compute_type), b_dequant).to(tl.float32)
+                * a_scale[:, None]
+            )
 
         a_gate_ptrs += BLOCK_SIZE_K * stride_ak
         a_up_ptrs += BLOCK_SIZE_K * stride_ak
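
For reference, the expression reflowed in both kernels above is the W4A8 multi-group path: the INT4 weights (already unpacked to b_int8) are dequantized per quantization group into the compute dtype, multiplied against the quantized activations in floating point, and the product is rescaled by the per-row activation scale. A NumPy sketch of the same arithmetic, with all shapes and dtypes assumed for illustration:

import numpy as np

rng = np.random.default_rng(0)

# Assumed tile shapes: (BLOCK_M, BLOCK_K) activations, (BLOCK_K, BLOCK_N) weights.
a_int8 = rng.integers(-128, 128, (16, 32), dtype=np.int8)   # INT8 activations
a_scale = rng.random(16, dtype=np.float32)                  # per-row activation scales
b_int8 = rng.integers(-8, 8, (32, 8), dtype=np.int8)        # unpacked INT4 weights
b_scale = rng.random((32, 8), dtype=np.float32)             # per-group weight scales, pre-broadcast

# Mirrors the kernel: dequantize B per group, float matmul, then apply A's scales.
b_dequant = b_int8.astype(np.float32) * b_scale
acc = (a_int8.astype(np.float32) @ b_dequant) * a_scale[:, None]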