Skip to content

Commit c2574df

Browse files
committed
revert atomic changes
1 parent 6c3bab1 commit c2574df

1 file changed

Lines changed: 12 additions & 13 deletions

File tree

backends/cuda/triton/kernels/fused_moe.py

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ def _fused_moe_kernel(
8484
# Pointers
8585
A, # [M, K] bf16 activations
8686
B, # [E, N, K//2] int8 packed INT4 weights
87-
C, # [M, N] fp32 output (atomic accumulation across experts)
87+
C, # [M * top_k, N] bf16 output
8888
B_scale, # [E, N, K//group_size] bf16 scales
8989
topk_ids, # [M * top_k] int64 expert indices
9090
topk_weights, # [M * top_k] float32 router weights
@@ -241,7 +241,6 @@ def _fused_moe_silu_kernel(
241241
group_size: tl.constexpr,
242242
BLOCK_SIZE_N: tl.constexpr,
243243
BLOCK_SIZE_K: tl.constexpr,
244-
top_k: tl.constexpr,
245244
compute_type: tl.constexpr,
246245
):
247246
"""GEMM2 with fused SiLU activation.
@@ -336,13 +335,12 @@ def _fused_moe_silu_kernel(
336335
a_up_ptrs += BLOCK_SIZE_K * stride_ak
337336
b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
338337

339-
# Multiply by router weight and atomically accumulate into token row
338+
# Multiply by router weight
340339
weight = tl.load(topk_weights + pair_idx)
341340
acc = acc * weight
342341

343-
token_idx = pair_idx // top_k
344-
c_ptrs = C + token_idx * stride_cm + offs_n * stride_cn
345-
tl.atomic_add(c_ptrs, acc, mask=n_mask)
342+
c_ptrs = C + pair_idx * stride_cm + offs_n * stride_cn
343+
tl.store(c_ptrs, acc.to(compute_type), mask=n_mask)
346344

347345

348346
# ---------------------------------------------------------------------------
@@ -430,16 +428,17 @@ def grid1(meta):
430428
)
431429

432430
# ---- GEMM2 with fused SiLU: reads gate+up from cache1, no intermediate buffer ----
433-
# Zero-init FP32 buffer — atomic_add in the kernel accumulates across top_k experts
434-
output = torch.zeros(M, N2, dtype=torch.float32, device=hidden_states.device)
431+
cache3 = torch.empty(
432+
num_pairs, N2, dtype=hidden_states.dtype, device=hidden_states.device
433+
)
435434

436435
def grid2(meta):
437436
return (num_pairs * triton.cdiv(N2, meta["BLOCK_SIZE_N"]),)
438437

439438
wrap_triton(_fused_moe_silu_kernel)[grid2](
440439
cache1,
441440
w2,
442-
output,
441+
cache3,
443442
w2_scale,
444443
topk_ids_flat,
445444
topk_weights_flat,
@@ -451,17 +450,17 @@ def grid2(meta):
451450
stride_be=w2.stride(0),
452451
stride_bk=w2.stride(2),
453452
stride_bn=w2.stride(1),
454-
stride_cm=output.stride(0),
455-
stride_cn=output.stride(1),
453+
stride_cm=cache3.stride(0),
454+
stride_cn=cache3.stride(1),
456455
stride_bse=w2_scale.stride(0),
457456
stride_bsk=w2_scale.stride(2),
458457
stride_bsn=w2_scale.stride(1),
459458
group_size=group_size,
460-
top_k=top_k,
461459
compute_type=tl.bfloat16,
462460
)
463461

464-
return output.to(hidden_states.dtype)
462+
# ---- Sum across top-k experts ----
463+
return cache3.view(M, top_k, N2).sum(dim=1)
465464

466465

467466
@fused_moe.register_fake

0 commit comments

Comments (0)