Skip to content

Commit 7ee78a1

Browse files
committed
optimize fused_moe kernel performance
1 parent d70e6ef commit 7ee78a1

1 file changed

Lines changed: 15 additions & 14 deletions

File tree

src/liger_kernel/ops/backends/_ascend/ops/fused_moe_kernels.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -268,11 +268,9 @@ def _get_gemm_autotune_configs():
268268
# tests and first-run unbearable. One stable config is enough for correctness.
269269
# Keep BLOCK_N/BLOCK_K moderate — large tiles overflow UB in bwd kernels.
270270
return [
271-
# NOTE: Triton requires total grid programs < 65536. For large MoE shapes
272-
# (e.g. E=128, T=8192, K=8, BLOCK_M_TOKEN=32), num_m_tiles can be ~2048.
273-
# With BLOCK_N=64 and H=2048 => grid=(2048, 32) => 65536 (invalid).
274-
# Use a larger BLOCK_N to reduce grid-y and stay below the limit.
275-
triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_warps=4, num_stages=2),
271+
# NOTE: Triton requires total grid programs < 65536. Keep a robust
272+
# compile-safe config for Ascend.
273+
triton.Config({"BLOCK_N": 192, "BLOCK_K": 64}, num_warps=4, num_stages=2),
276274
]
277275

278276

@@ -473,6 +471,8 @@ def _fused_down_proj_kernel(
473471
def _get_token_gather_autotune_configs():
474472
return [
475473
triton.Config({"BLOCK_H": 128, "BLOCK_K": 8}, num_warps=4, num_stages=4),
474+
triton.Config({"BLOCK_H": 256, "BLOCK_K": 8}, num_warps=4, num_stages=4),
475+
triton.Config({"BLOCK_H": 256, "BLOCK_K": 16}, num_warps=4, num_stages=4),
476476
]
477477

478478

@@ -651,8 +651,9 @@ def _moe_bwd_down_proj_kernel(
651651
@triton.autotune(
652652
configs=[
653653
# Keep total grid programs < 65536 for large E/H/I.
654-
# Grid: (E * ceil(I/BLOCK_M), ceil(H/BLOCK_N)). Larger BLOCK_N reduces grid-y.
655-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 16}, num_warps=4, num_stages=2),
654+
# Larger BLOCK_M/BLOCK_K improve arithmetic intensity on long-sequence
655+
# backward workloads and cut K-loop overhead.
656+
triton.Config({"BLOCK_M": 128, "BLOCK_N": 240, "BLOCK_K": 32}, num_warps=4, num_stages=2),
656657
],
657658
key=["H_dim", "I_dim"],
658659
reset_to_zero=["dW2_ptr"],
@@ -712,16 +713,16 @@ def _moe_bwd_dW2_kernel(
712713
k_mask = k_idx < M_e
713714
row_offs = expert_start + k_idx
714715

715-
wact_ptrs = weighted_act_ptr + row_offs[None, :] * stride_wact_TK + i_idx[:, None] * stride_wact_I
716-
wact_tile = tl.load(wact_ptrs, mask=k_mask[None, :] & i_mask[:, None], other=0.0)
716+
wact_ptrs = weighted_act_ptr + row_offs[:, None] * stride_wact_TK + i_idx[None, :] * stride_wact_I
717+
wact_tile = tl.load(wact_ptrs, mask=k_mask[:, None] & i_mask[None, :], other=0.0)
717718

718719
token_idx = tl.load(x_gather_idx_ptr + row_offs, mask=k_mask, other=0)
719720
dout_ptrs = dout_ptr + token_idx[:, None] * stride_dout_T + h_idx[None, :] * stride_dout_H
720721
dout_tile = tl.load(dout_ptrs, mask=k_mask[:, None] & h_mask[None, :], other=0.0)
721722

722-
# dW2[h, i] = sum_t dout_g[t, h] * wa[t, i] <=> (H,T)*(T,I) with wa (T,I), dout_g (T,H)
723-
# wact_tile is (I_blk, T_blk); dout_tile is (T_blk, H_blk) — previous tl.dot(wa, dout) gave (I,H), wrong layout.
724-
acc += tl.dot(tl.trans(dout_tile), tl.trans(wact_tile))
723+
# dW2[h, i] = sum_t dout_g[t, h] * wa[t, i] <=> (H,T) @ (T,I)
724+
# Keep weighted_act as (T_blk, I_blk) so only one transpose is needed.
725+
acc += tl.dot(tl.trans(dout_tile), wact_tile)
725726

726727
# acc layout is (H_blk, I_blk) — match dW2[e, h, i] with h on the first broadcast axis.
727728
dW2_ptrs = dW2_ptr + expert_idx * stride_dW2_E + h_idx[:, None] * stride_dW2_H + i_idx[None, :] * stride_dW2_I
@@ -823,8 +824,8 @@ def _moe_bwd_dX_expanded_kernel(
823824
@triton.autotune(
824825
configs=[
825826
# Keep total grid programs < 65536 for large E/H/I.
826-
# Grid: (E * ceil(H/BLOCK_M), ceil(2I/BLOCK_N)). Larger BLOCK_N reduces grid-y.
827-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 16}, num_warps=4, num_stages=2),
827+
# Match dW2 tiling for better large-shape throughput.
828+
triton.Config({"BLOCK_M": 128, "BLOCK_N": 240, "BLOCK_K": 32}, num_warps=4, num_stages=2),
828829
],
829830
key=["H_dim", "I_dim"],
830831
reset_to_zero=["dW1_ptr"],

0 commit comments

Comments
 (0)