@@ -84,7 +84,7 @@ def _fused_moe_kernel(
8484 # Pointers
8585 A , # [M, K] bf16 activations
8686 B , # [E, N, K//2] int8 packed INT4 weights
87- C , # [M, N] fp32 output (atomic accumulation across experts)
87+ C , # [M * top_k , N] bf16 output
8888 B_scale , # [E, N, K//group_size] bf16 scales
8989 topk_ids , # [M * top_k] int64 expert indices
9090 topk_weights , # [M * top_k] float32 router weights
@@ -241,7 +241,6 @@ def _fused_moe_silu_kernel(
241241 group_size : tl .constexpr ,
242242 BLOCK_SIZE_N : tl .constexpr ,
243243 BLOCK_SIZE_K : tl .constexpr ,
244- top_k : tl .constexpr ,
245244 compute_type : tl .constexpr ,
246245):
247246 """GEMM2 with fused SiLU activation.
@@ -336,13 +335,12 @@ def _fused_moe_silu_kernel(
336335 a_up_ptrs += BLOCK_SIZE_K * stride_ak
337336 b_ptrs += (BLOCK_SIZE_K // 2 ) * stride_bk
338337
339- # Multiply by router weight and atomically accumulate into token row
338+ # Multiply by router weight; each (token, expert) pair stores its own row of C, so a plain store replaces the atomic add
340339 weight = tl .load (topk_weights + pair_idx )
341340 acc = acc * weight
342341
343- token_idx = pair_idx // top_k
344- c_ptrs = C + token_idx * stride_cm + offs_n * stride_cn
345- tl .atomic_add (c_ptrs , acc , mask = n_mask )
342+ c_ptrs = C + pair_idx * stride_cm + offs_n * stride_cn
343+ tl .store (c_ptrs , acc .to (compute_type ), mask = n_mask )
346344
347345
348346# ---------------------------------------------------------------------------
@@ -430,16 +428,17 @@ def grid1(meta):
430428 )
431429
432430 # ---- GEMM2 with fused SiLU: reads gate+up from cache1, no intermediate buffer ----
433- # Zero-init FP32 buffer — atomic_add in the kernel accumulates across top_k experts
434- output = torch .zeros (M , N2 , dtype = torch .float32 , device = hidden_states .device )
431+ cache3 = torch .empty (
432+ num_pairs , N2 , dtype = hidden_states .dtype , device = hidden_states .device
433+ )
435434
436435 def grid2 (meta ):
437436 return (num_pairs * triton .cdiv (N2 , meta ["BLOCK_SIZE_N" ]),)
438437
439438 wrap_triton (_fused_moe_silu_kernel )[grid2 ](
440439 cache1 ,
441440 w2 ,
442- output ,
441+ cache3 ,
443442 w2_scale ,
444443 topk_ids_flat ,
445444 topk_weights_flat ,
@@ -451,17 +450,17 @@ def grid2(meta):
451450 stride_be = w2 .stride (0 ),
452451 stride_bk = w2 .stride (2 ),
453452 stride_bn = w2 .stride (1 ),
454- stride_cm = output .stride (0 ),
455- stride_cn = output .stride (1 ),
453+ stride_cm = cache3 .stride (0 ),
454+ stride_cn = cache3 .stride (1 ),
456455 stride_bse = w2_scale .stride (0 ),
457456 stride_bsk = w2_scale .stride (2 ),
458457 stride_bsn = w2_scale .stride (1 ),
459458 group_size = group_size ,
460- top_k = top_k ,
461459 compute_type = tl .bfloat16 ,
462460 )
463461
464- return output .to (hidden_states .dtype )
462+ # ---- Sum each token's top_k rows of cache3 (already router-weighted inside the kernel) ----
463+ return cache3 .view (M , top_k , N2 ).sum (dim = 1 )
465464
466465
467466@fused_moe .register_fake
0 commit comments