@@ -702,30 +702,74 @@ def moe_align_block_size(
 # Autotune configs for batched GEMM1 (gate+up projection).
 # BLOCK_M is fixed at _BATCHED_BLOCK_M; only N and K are tuned.
 _BATCHED_GEMM1_CONFIGS = [
-    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16}, num_warps=4, num_stages=3),
     triton.Config(
-        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=2
+        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=3,
     ),
     triton.Config(
-        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16}, num_warps=4, num_stages=2
+        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=2,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16},
+        num_warps=4,
+        num_stages=2,
     ),
 ]
 
 # Autotune configs for batched GEMM2 (down projection + SiLU).
 _BATCHED_GEMM2_CONFIGS = [
-    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=2),
     triton.Config(
-        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_warps=4, num_stages=2
+        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=2,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
+        num_warps=4,
+        num_stages=2,
+    ),
     triton.Config(
-        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16}, num_warps=4, num_stages=2
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16},
+        num_warps=4,
+        num_stages=2,
     ),
 ]
 
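For context (not shown in this hunk): config lists like these are normally handed to `triton.autotune`, which benchmarks each `Config` and caches the fastest one per distinct problem size. A minimal sketch of that wiring, assuming the decorator keys on `N` and `K`; the argument list below is illustrative only, not this commit's actual kernel signature:

```python
import triton
import triton.language as tl

# Hypothetical wiring (not from this commit): triton.autotune times
# every Config in the list and caches the winner per distinct (N, K).
# num_warps/num_stages go to the compiler; BLOCK_SIZE_N/BLOCK_SIZE_K/
# GROUP_SIZE_M from the chosen Config are injected as constexpr args.
@triton.autotune(configs=_BATCHED_GEMM1_CONFIGS, key=["N", "K"])
@triton.jit
def _fused_moe_batched_kernel(
    a_ptr, b_ptr, c_ptr,
    N, K,
    BLOCK_SIZE_M: tl.constexpr,  # fixed at _BATCHED_BLOCK_M by the caller
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    ...
```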
@@ -831,7 +875,8 @@ def _fused_moe_batched_kernel(
             B_scale
             + expert_id * stride_bse
             + offs_n[None, :] * stride_bsn
-            + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
+            + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size)
+            * stride_bsk
         )
         b_scale = tl.load(
             scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
@@ -967,7 +1012,8 @@ def _fused_moe_batched_int8_kernel(
             B_scale
             + expert_id * stride_bse
             + offs_n[None, :] * stride_bsn
-            + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
+            + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size)
+            * stride_bsk
         )
         b_scale = tl.load(
             scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
@@ -1085,7 +1131,8 @@ def _fused_moe_silu_batched_kernel(
             B_scale
             + expert_id * stride_bse
             + offs_n[None, :] * stride_bsn
-            + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
+            + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size)
+            * stride_bsk
         )
         b_scale = tl.load(
             scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
@@ -1227,7 +1274,8 @@ def _fused_moe_silu_batched_int8_kernel(
             B_scale
             + expert_id * stride_bse
             + offs_n[None, :] * stride_bsn
-            + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
+            + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size)
+            * stride_bsk
         )
         b_scale = tl.load(
             scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
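The line reflowed in all four kernels above is the group-wise scale lookup: weight scales are stored per (expert, n, k-group), so every `group_size` consecutive K elements share one scale, and `((offs_k + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk` selects the group's entry. A minimal PyTorch sketch of the same indexing on whole tensors, with hypothetical names (`w_q`, `w_scale`) that are not from this commit:

```python
import torch

def dequant_groupwise(w_q: torch.Tensor, w_scale: torch.Tensor,
                      group_size: int) -> torch.Tensor:
    """Dequantize [E, N, K] weights with per-group scales [E, N, K // group_size].

    The scale for element k is w_scale[..., k // group_size]; this is the
    same mapping the kernels compute in pointer form via
    ((offs_k + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk.
    """
    K = w_q.shape[-1]
    # Expand each scale across its group of K consecutive elements.
    scale = w_scale.repeat_interleave(group_size, dim=-1)[..., :K]
    return w_q.to(scale.dtype) * scale
```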