4242
4343
# Autotune configs for GEMM1 (_fused_moe_kernel).
# Qwen3.5 MoE dimensions (M=1, N=1024, K=2048, 8 experts, group_size=32).
# NOTE(review): all BLOCK_SIZE_K values are <= 32; presumably this keeps
# BLOCK_K <= group_size so the kernel takes its one-scale-per-tile fast
# path — confirm against the BLOCK_SIZE_K <= group_size branch in the kernel.
_GEMM1_CONFIGS = [
    triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=4),
    triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=5),
    triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=3),
    triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=4),
    triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=3),
    triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=4),
    triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 16}, num_warps=2, num_stages=4),
    triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 16}, num_warps=2, num_stages=5),
]
5761
# Autotune configs for GEMM2 (_fused_moe_silu_kernel).
# Qwen3.5 MoE dimensions (M=1, N=2048, K=512, 8 experts, group_size=32).
# NOTE(review): all BLOCK_SIZE_K values are <= 32; presumably this keeps
# BLOCK_K <= group_size so the kernel takes its one-scale-per-tile fast
# path — confirm against the BLOCK_SIZE_K <= group_size branch in the kernel.
_GEMM2_CONFIGS = [
    triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=4),
    triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=5),
    triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=3),
    triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=4),
    triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=3),
    triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=4),
    triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 16}, num_warps=2, num_stages=4),
    triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 16}, num_warps=2, num_stages=5),
]
7179
7280
@@ -76,7 +84,7 @@ def _fused_moe_kernel(
7684 # Pointers
7785 A , # [M, K] bf16 activations
7886 B , # [E, N, K//2] int8 packed INT4 weights
79- C , # [M * top_k , N] bf16 output
87+ C , # [M, N] fp32 output (atomic accumulation across experts)
8088 B_scale , # [E, N, K//group_size] bf16 scales
8189 topk_ids , # [M * top_k] int64 expert indices
8290 topk_weights , # [M * top_k] float32 router weights
@@ -144,11 +152,16 @@ def _fused_moe_kernel(
144152 k_remaining = K - k_step * BLOCK_SIZE_K
145153 k_mask = offs_k < k_remaining
146154
147- # Load A tile [BLOCK_SIZE_K]
148- a = tl .load (a_ptrs , mask = k_mask , other = 0.0 )
155+ # Load A tile [BLOCK_SIZE_K] — reused across N-blocks, keep in L2
156+ a = tl .load (a_ptrs , mask = k_mask , other = 0.0 , eviction_policy = "evict_last" )
149157
150- # Load B tile [BLOCK_SIZE_K, BLOCK_SIZE_N] and unpack INT4
151- b = tl .load (b_ptrs , mask = k_mask [:, None ] & n_mask [None , :], other = 0 )
158+ # Load B tile [BLOCK_SIZE_K, BLOCK_SIZE_N] and unpack INT4 — streaming
159+ b = tl .load (
160+ b_ptrs ,
161+ mask = k_mask [:, None ] & n_mask [None , :],
162+ other = 0 ,
163+ eviction_policy = "evict_first" ,
164+ )
152165 b = (b >> b_shifter ) & 0xF
153166
154167 # Load per-group scales and dequantize
@@ -161,9 +174,12 @@ def _fused_moe_kernel(
161174 + offs_n [None , :] * stride_bsn
162175 + group_idx * stride_bsk
163176 )
164- b_scale = tl .load (scale_ptrs , mask = n_mask [None , :], other = 0.0 ).to (
165- tl .float32
166- )
177+ b_scale = tl .load (
178+ scale_ptrs ,
179+ mask = n_mask [None , :],
180+ other = 0.0 ,
181+ eviction_policy = "evict_first" ,
182+ ).to (tl .float32 )
167183 else :
168184 scale_ptrs = (
169185 B_scale
@@ -172,7 +188,10 @@ def _fused_moe_kernel(
172188 + ((offs_k [:, None ] + BLOCK_SIZE_K * k_step ) // group_size ) * stride_bsk
173189 )
174190 b_scale = tl .load (
175- scale_ptrs , mask = k_mask [:, None ] & n_mask [None , :], other = 0.0
191+ scale_ptrs ,
192+ mask = k_mask [:, None ] & n_mask [None , :],
193+ other = 0.0 ,
194+ eviction_policy = "evict_first" ,
176195 ).to (tl .float32 )
177196
178197 # Dequantize and accumulate in float32: vector-matrix multiply
@@ -222,6 +241,7 @@ def _fused_moe_silu_kernel(
222241 group_size : tl .constexpr ,
223242 BLOCK_SIZE_N : tl .constexpr ,
224243 BLOCK_SIZE_K : tl .constexpr ,
244+ top_k : tl .constexpr ,
225245 compute_type : tl .constexpr ,
226246):
227247 """GEMM2 with fused SiLU activation.
@@ -263,13 +283,22 @@ def _fused_moe_silu_kernel(
263283 k_remaining = K - k_step * BLOCK_SIZE_K
264284 k_mask = offs_k < k_remaining
265285
266- # Load gate and up in float32, apply SiLU(gate) * up
267- gate = tl .load (a_gate_ptrs , mask = k_mask , other = 0.0 ).to (tl .float32 )
268- up = tl .load (a_up_ptrs , mask = k_mask , other = 0.0 ).to (tl .float32 )
286+ # Load gate and up in float32, apply SiLU(gate) * up — reused across N-blocks
287+ gate = tl .load (
288+ a_gate_ptrs , mask = k_mask , other = 0.0 , eviction_policy = "evict_last"
289+ ).to (tl .float32 )
290+ up = tl .load (
291+ a_up_ptrs , mask = k_mask , other = 0.0 , eviction_policy = "evict_last"
292+ ).to (tl .float32 )
269293 a = gate * tl .sigmoid (gate ) * up
270294
271- # Load and dequantize INT4 weights
272- b = tl .load (b_ptrs , mask = k_mask [:, None ] & n_mask [None , :], other = 0 )
295+ # Load and dequantize INT4 weights — streaming
296+ b = tl .load (
297+ b_ptrs ,
298+ mask = k_mask [:, None ] & n_mask [None , :],
299+ other = 0 ,
300+ eviction_policy = "evict_first" ,
301+ )
273302 b = (b >> b_shifter ) & 0xF
274303
275304 if BLOCK_SIZE_K <= group_size :
@@ -280,9 +309,12 @@ def _fused_moe_silu_kernel(
280309 + offs_n [None , :] * stride_bsn
281310 + group_idx * stride_bsk
282311 )
283- b_scale = tl .load (scale_ptrs , mask = n_mask [None , :], other = 0.0 ).to (
284- tl .float32
285- )
312+ b_scale = tl .load (
313+ scale_ptrs ,
314+ mask = n_mask [None , :],
315+ other = 0.0 ,
316+ eviction_policy = "evict_first" ,
317+ ).to (tl .float32 )
286318 else :
287319 scale_ptrs = (
288320 B_scale
@@ -291,7 +323,10 @@ def _fused_moe_silu_kernel(
291323 + ((offs_k [:, None ] + BLOCK_SIZE_K * k_step ) // group_size ) * stride_bsk
292324 )
293325 b_scale = tl .load (
294- scale_ptrs , mask = k_mask [:, None ] & n_mask [None , :], other = 0.0
326+ scale_ptrs ,
327+ mask = k_mask [:, None ] & n_mask [None , :],
328+ other = 0.0 ,
329+ eviction_policy = "evict_first" ,
295330 ).to (tl .float32 )
296331
297332 b_dequant = (b .to (tl .float32 ) - 8.0 ) * b_scale
@@ -301,12 +336,13 @@ def _fused_moe_silu_kernel(
301336 a_up_ptrs += BLOCK_SIZE_K * stride_ak
302337 b_ptrs += (BLOCK_SIZE_K // 2 ) * stride_bk
303338
304- # Multiply by router weight
339+ # Multiply by router weight and atomically accumulate into token row
305340 weight = tl .load (topk_weights + pair_idx )
306341 acc = acc * weight
307342
308- c_ptrs = C + pair_idx * stride_cm + offs_n * stride_cn
309- tl .store (c_ptrs , acc .to (compute_type ), mask = n_mask )
343+ token_idx = pair_idx // top_k
344+ c_ptrs = C + token_idx * stride_cm + offs_n * stride_cn
345+ tl .atomic_add (c_ptrs , acc , mask = n_mask )
310346
311347
312348# ---------------------------------------------------------------------------
@@ -394,17 +430,16 @@ def grid1(meta):
394430 )
395431
396432 # ---- GEMM2 with fused SiLU: reads gate+up from cache1, no intermediate buffer ----
397- cache3 = torch .empty (
398- num_pairs , N2 , dtype = hidden_states .dtype , device = hidden_states .device
399- )
433+ # Zero-init FP32 buffer — atomic_add in the kernel accumulates across top_k experts
434+ output = torch .zeros (M , N2 , dtype = torch .float32 , device = hidden_states .device )
400435
401436 def grid2 (meta ):
402437 return (num_pairs * triton .cdiv (N2 , meta ["BLOCK_SIZE_N" ]),)
403438
404439 wrap_triton (_fused_moe_silu_kernel )[grid2 ](
405440 cache1 ,
406441 w2 ,
407- cache3 ,
442+ output ,
408443 w2_scale ,
409444 topk_ids_flat ,
410445 topk_weights_flat ,
@@ -416,17 +451,17 @@ def grid2(meta):
416451 stride_be = w2 .stride (0 ),
417452 stride_bk = w2 .stride (2 ),
418453 stride_bn = w2 .stride (1 ),
419- stride_cm = cache3 .stride (0 ),
420- stride_cn = cache3 .stride (1 ),
454+ stride_cm = output .stride (0 ),
455+ stride_cn = output .stride (1 ),
421456 stride_bse = w2_scale .stride (0 ),
422457 stride_bsk = w2_scale .stride (2 ),
423458 stride_bsn = w2_scale .stride (1 ),
424459 group_size = group_size ,
460+ top_k = top_k ,
425461 compute_type = tl .bfloat16 ,
426462 )
427463
428- # ---- Sum across top-k experts ----
429- return cache3 .view (M , top_k , N2 ).sum (dim = 1 )
464+ return output .to (hidden_states .dtype )
430465
431466
432467@fused_moe .register_fake
0 commit comments