@@ -268,11 +268,9 @@ def _get_gemm_autotune_configs():
268268 # tests and first-run unbearable. One stable config is enough for correctness.
269269 # Keep BLOCK_N/BLOCK_K moderate — large tiles overflow UB in bwd kernels.
270270 return [
271- # NOTE: Triton requires total grid programs < 65536. For large MoE shapes
272- # (e.g. E=128, T=8192, K=8, BLOCK_M_TOKEN=32), num_m_tiles can be ~2048.
273- # With BLOCK_N=64 and H=2048 => grid=(2048, 32) => 65536 (invalid).
274- # Use a larger BLOCK_N to reduce grid-y and stay below the limit.
275- triton .Config ({"BLOCK_N" : 128 , "BLOCK_K" : 32 }, num_warps = 4 , num_stages = 2 ),
271+ # NOTE: Triton requires total grid programs < 65536. Keep a robust
272+ # compile-safe config for Ascend.
273+ triton .Config ({"BLOCK_N" : 192 , "BLOCK_K" : 64 }, num_warps = 4 , num_stages = 2 ),
276274 ]
277275
278276
@@ -473,6 +471,8 @@ def _fused_down_proj_kernel(
473471def _get_token_gather_autotune_configs ():
474472 return [
475473 triton .Config ({"BLOCK_H" : 128 , "BLOCK_K" : 8 }, num_warps = 4 , num_stages = 4 ),
474+ triton .Config ({"BLOCK_H" : 256 , "BLOCK_K" : 8 }, num_warps = 4 , num_stages = 4 ),
475+ triton .Config ({"BLOCK_H" : 256 , "BLOCK_K" : 16 }, num_warps = 4 , num_stages = 4 ),
476476 ]
477477
478478
@@ -651,8 +651,9 @@ def _moe_bwd_down_proj_kernel(
651651@triton .autotune (
652652 configs = [
653653 # Keep total grid programs < 65536 for large E/H/I.
654- # Grid: (E * ceil(I/BLOCK_M), ceil(H/BLOCK_N)). Larger BLOCK_N reduces grid-y.
655- triton .Config ({"BLOCK_M" : 64 , "BLOCK_N" : 128 , "BLOCK_K" : 16 }, num_warps = 4 , num_stages = 2 ),
654+ # Larger BLOCK_M/BLOCK_K improve arithmetic intensity on long-sequence
655+ # backward workloads and cut K-loop overhead.
656+ triton .Config ({"BLOCK_M" : 128 , "BLOCK_N" : 240 , "BLOCK_K" : 32 }, num_warps = 4 , num_stages = 2 ),
656657 ],
657658 key = ["H_dim" , "I_dim" ],
658659 reset_to_zero = ["dW2_ptr" ],
@@ -712,16 +713,16 @@ def _moe_bwd_dW2_kernel(
712713 k_mask = k_idx < M_e
713714 row_offs = expert_start + k_idx
714715
715- wact_ptrs = weighted_act_ptr + row_offs [None , : ] * stride_wact_TK + i_idx [:, None ] * stride_wact_I
716- wact_tile = tl .load (wact_ptrs , mask = k_mask [None , : ] & i_mask [:, None ], other = 0.0 )
716+ wact_ptrs = weighted_act_ptr + row_offs [:, None ] * stride_wact_TK + i_idx [None , : ] * stride_wact_I
717+ wact_tile = tl .load (wact_ptrs , mask = k_mask [:, None ] & i_mask [None , : ], other = 0.0 )
717718
718719 token_idx = tl .load (x_gather_idx_ptr + row_offs , mask = k_mask , other = 0 )
719720 dout_ptrs = dout_ptr + token_idx [:, None ] * stride_dout_T + h_idx [None , :] * stride_dout_H
720721 dout_tile = tl .load (dout_ptrs , mask = k_mask [:, None ] & h_mask [None , :], other = 0.0 )
721722
722- # dW2[h, i] = sum_t dout_g[t, h] * wa[t, i] <=> (H,T)*(T,I) with wa (T,I), dout_g (T,H )
723- # wact_tile is (I_blk, T_blk); dout_tile is (T_blk, H_blk) — previous tl.dot(wa, dout) gave (I,H), wrong layout .
724- acc += tl .dot (tl .trans (dout_tile ), tl . trans ( wact_tile ) )
723+ # dW2[h, i] = sum_t dout_g[t, h] * wa[t, i] <=> (H,T) @ (T,I)
724+ # Keep weighted_act as (T_blk, I_blk) so only one transpose is needed .
725+ acc += tl .dot (tl .trans (dout_tile ), wact_tile )
725726
726727 # acc layout is (H_blk, I_blk) — match dW2[e, h, i] with h on the first broadcast axis.
727728 dW2_ptrs = dW2_ptr + expert_idx * stride_dW2_E + h_idx [:, None ] * stride_dW2_H + i_idx [None , :] * stride_dW2_I
@@ -823,8 +824,8 @@ def _moe_bwd_dX_expanded_kernel(
823824@triton .autotune (
824825 configs = [
825826 # Keep total grid programs < 65536 for large E/H/I.
826- # Grid: (E * ceil(H/BLOCK_M), ceil(2I/BLOCK_N)). Larger BLOCK_N reduces grid-y .
827- triton .Config ({"BLOCK_M" : 64 , "BLOCK_N" : 128 , "BLOCK_K" : 16 }, num_warps = 4 , num_stages = 2 ),
827+ # Match dW2 tiling for better large-shape throughput .
828+ triton .Config ({"BLOCK_M" : 128 , "BLOCK_N" : 240 , "BLOCK_K" : 32 }, num_warps = 4 , num_stages = 2 ),
828829 ],
829830 key = ["H_dim" , "I_dim" ],
830831 reset_to_zero = ["dW1_ptr" ],
0 commit comments