import torch
import triton

+from liger_kernel.ops.utils import ensure_contiguous
+
from .fused_moe_kernels import _fused_down_proj_kernel
from .fused_moe_kernels import _fused_up_proj_swiglu_kernel
from .fused_moe_kernels import _moe_bwd_down_proj_kernel
from .fused_moe_kernels import _moe_router_prefix_sum_kernel
from .fused_moe_kernels import _moe_router_scatter_kernel
from .fused_moe_kernels import _token_gather_weighted_sum_kernel
-from liger_kernel.ops.utils import ensure_contiguous

-# Token-dimension tile size for M.
-# Not in the inner-loop autotune because tile_row_start/tile_expert and the
-# grid dim-0 (num_m_tiles) must be recomputed for every candidate value.
-# To tune: change this constant and re-run benchmarks.
-# Smaller M-tiles compile/run reliably on Ascend (UB scratch is tight vs CUDA).
+# Token-dimension tile size used across fused MoE kernels.
BLOCK_M_TOKEN = 32


-# ---------------------------------------------------------------------------
-# Routing metadata
-# ---------------------------------------------------------------------------
-
-
def compute_routing_metadata(topk_indices: torch.Tensor, E: int, block_m_token: int = BLOCK_M_TOKEN):
    """Compute token→expert routing permutation metadata via 3 Triton kernels.

@@ -61,7 +53,6 @@ def compute_routing_metadata(topk_indices: torch.Tensor, E: int, block_m_token:
    TOKENS_PER_BLOCK = max(1, 1024 // K_POW2)
    n_tiles = triton.cdiv(T, TOKENS_PER_BLOCK)

-    # Kernel 1: tiled histogram → tile_expert_counts (E, n_tiles)
    tile_expert_counts = torch.empty(E, n_tiles, dtype=torch.int32, device=device)
    _moe_router_histogram_kernel[(n_tiles,)](
        topk_indices,
@@ -77,7 +68,6 @@ def compute_routing_metadata(topk_indices: torch.Tensor, E: int, block_m_token:

    expert_token_count = tile_expert_counts.sum(dim=1, dtype=torch.int32)  # (E,)

-    # Kernel 2: prefix sums + expert offsets + tile offsets (all in one pass)
    expert_start_idx = torch.empty(E + 1, dtype=torch.int32, device=device)
    expert_tile_offset = torch.empty(E + 1, dtype=torch.int32, device=device)
    _moe_router_prefix_sum_kernel[(E + 2,)](
@@ -93,19 +83,16 @@ def compute_routing_metadata(topk_indices: torch.Tensor, E: int, block_m_token:
        BLOCK_M_TOKEN=block_m_token,
    )

-    # One sync to get num_m_tiles for buffer allocation and GEMM grid.
    num_m_tiles = int(expert_tile_offset[-1].item())

    tile_row_start = torch.empty(num_m_tiles, dtype=torch.int32, device=device)
    tile_expert = torch.empty(num_m_tiles, dtype=torch.int32, device=device)

-    # Kernel 3: sort by expert + scatter permutation arrays + tile metadata
    s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
    s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
    x_gather_idx = torch.empty(TK, dtype=torch.int32, device=device)

    if TK > 0:
-        # Per-tile per-expert atomic counters (zeroed); used by Ascend-safe scatter.
        scatter_rank_scratch = torch.zeros((n_tiles, E), dtype=torch.int32, device=device)
        _moe_router_scatter_kernel[(n_tiles,)](
            s_scatter_idx,
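
For readers tracing the routing path, the sketch below is a plain-PyTorch approximation of the metadata that the histogram, prefix-sum, and scatter kernels cooperate to produce. The helper name routing_metadata_reference is hypothetical, the per-tile bookkeeping (tile_row_start/tile_expert) is omitted, and the exact permutation conventions of the Triton kernels remain the source of truth.

    import torch

    def routing_metadata_reference(topk_indices, E, block_m_token=32):  # 32 mirrors BLOCK_M_TOKEN in this file
        # topk_indices: (T, K) expert id chosen for each (token, k) slot.
        T, K = topk_indices.shape
        flat = topk_indices.reshape(-1).to(torch.int64)            # TK flattened slots
        order = torch.argsort(flat, stable=True)                   # slots grouped by expert, original order kept
        reverse = torch.empty_like(order)
        reverse[order] = torch.arange(order.numel(), device=order.device)
        counts = torch.bincount(flat, minlength=E)                 # analogue of expert_token_count
        expert_start_idx = torch.zeros(E + 1, dtype=torch.int64, device=flat.device)
        expert_start_idx[1:] = torch.cumsum(counts, dim=0)         # exclusive prefix sum over experts
        tiles_per_expert = (counts + block_m_token - 1) // block_m_token
        num_m_tiles = int(tiles_per_expert.sum())                  # each expert padded to whole M-tiles
        x_gather_idx = order // K                                  # source token row for each sorted slot
        return order, reverse, expert_start_idx, num_m_tiles, x_gather_idx
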
@@ -159,11 +146,6 @@ def _token_aggregation(Y, topk_weights_flat, s_reverse_scatter_idx, T, K, H):
    return out


-# ---------------------------------------------------------------------------
-# Autograd Function
-# ---------------------------------------------------------------------------
-
-
class LigerFusedMoEFunction(torch.autograd.Function):
    """Fused grouped GEMM MoE forward + memory-efficient backward.

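
As a rough orientation for the docstring above, here is a hedged per-token reference of what the fused forward appears to compute, inferred from the kernel names (_fused_up_proj_swiglu_kernel, _fused_down_proj_kernel, _token_gather_weighted_sum_kernel). The weight shapes, the gate/up split, and the point where the routing weight is applied are assumptions, not statements from this diff.

    import torch
    import torch.nn.functional as F

    def moe_forward_reference(x, gate_up_proj, down_proj, topk_indices, topk_weights):
        # Assumed shapes: x (T, H), gate_up_proj (E, H, 2*I), down_proj (E, I, H),
        # topk_indices/topk_weights (T, K).
        T, K = topk_indices.shape
        out = torch.zeros_like(x)
        for t in range(T):
            for k in range(K):
                e = int(topk_indices[t, k])
                pre_act = x[t] @ gate_up_proj[e]           # up projection to 2*I
                gate, up = pre_act.chunk(2, dim=-1)        # assumed gate/up layout
                y1 = F.silu(gate) * up                     # SwiGLU activation
                out[t] += topk_weights[t, k] * (y1 @ down_proj[e])  # weighted down projection
        return out
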
@@ -308,10 +290,9 @@ def backward(ctx, dO):
        TK = ctx.TK
        num_m_tiles = ctx.num_m_tiles

-        # dA' = dO @ W2^T, SwiGLU backward, write d_pre_act and dS
        d_pre_act = torch.empty(TK, 2 * intermediate_dim, dtype=dO.dtype, device=dO.device)
        weighted_act = torch.empty(TK, intermediate_dim, dtype=dO.dtype, device=dO.device)
-        dS = torch.zeros(TK, dtype=dO.dtype, device=dO.device)  # zeros: atomic_add in kernel accumulates across N-tiles
+        dS = torch.zeros(TK, dtype=dO.dtype, device=dO.device)  # atomic_add accumulates across N-tiles

        if num_m_tiles > 0:
            _moe_bwd_down_proj_kernel[lambda meta: (num_m_tiles, triton.cdiv(intermediate_dim, meta["BLOCK_N"]))](
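
The removed comment said this kernel computes dA' = dO @ W2^T, applies the SwiGLU backward, and writes d_pre_act and dS. A single-slot sketch of that math follows; it assumes y1 = silu(gate) * up, that a scalar routing weight s scales the slot's output, and that dS is the gradient with respect to that weight. These are inferences from the surrounding buffer names, not facts confirmed by the diff.

    import torch
    import torch.nn.functional as F

    def down_proj_swiglu_backward_reference(dO_row, W2_e, gate, up, s):
        # One expanded slot: dO_row (H,), W2_e (I, H), gate/up (I,), scalar routing weight s.
        dy1 = dO_row @ W2_e.t()                               # dA' = dO @ W2^T
        silu_gate = F.silu(gate)
        d_weighted = s * dy1                                   # chain rule through the routing weight
        d_up = d_weighted * silu_gate
        sig = torch.sigmoid(gate)
        d_gate = d_weighted * up * (sig + gate * sig * (1 - sig))  # d/dg silu(g) = sig(g) * (1 + g * (1 - sig(g)))
        d_pre_act_row = torch.cat([d_gate, d_up])              # mirrors the (TK, 2 * intermediate_dim) d_pre_act buffer
        dS_row = (dy1 * silu_gate * up).sum()                  # gradient w.r.t. the scalar routing weight
        return d_pre_act_row, dS_row
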
@@ -343,7 +324,6 @@ def backward(ctx, dO):
                BLOCK_M=BLOCK_M_TOKEN,
            )

-        # dW2 = (s_k * y1)^T @ dO_gathered
        ddown_proj = torch.zeros_like(down_proj)
        _moe_bwd_dW2_kernel[
            lambda meta: (
@@ -367,7 +347,6 @@ def backward(ctx, dO):
            stride_dW2_I=ddown_proj.stride(2),
        )

-        # dx_expanded = d_pre_act @ W1^T
        dx_expanded = torch.empty(TK, H, dtype=dO.dtype, device=dO.device)

        if num_m_tiles > 0:
@@ -390,12 +369,11 @@ def backward(ctx, dO):
                BLOCK_M=BLOCK_M_TOKEN,
            )

-        # dx = unweighted gather-sum of dx_expanded
        dx = torch.zeros(T, H, dtype=dO.dtype, device=dO.device)
        if TK > 0:
            _token_gather_weighted_sum_kernel[(T,)](
                dx_expanded,
-                dS,  # dummy w_ptr — never loaded when w_is_None=True
+                dS,  # dummy w_ptr; unused when w_is_None=True
                s_reverse_scatter_idx,
                dx,
                H_dim=H,
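
dx is rebuilt by the same gather-sum kernel used for the forward aggregation, just with the weights disabled (w_is_None=True). Below is a plain-PyTorch sketch of both paths, under the assumption that s_reverse_scatter_idx maps each flat (token, k) slot to its row in the expert-sorted layout; the helper name is hypothetical.

    import torch

    def token_gather_sum_reference(rows, s_reverse_scatter_idx, T, K, weights=None):
        # rows: (T*K, H) values stored in expert-sorted order.
        out = torch.zeros(T, rows.shape[1], dtype=rows.dtype, device=rows.device)
        idx = s_reverse_scatter_idx.view(T, K).long()
        for k in range(K):
            contrib = rows[idx[:, k]]                           # (T, H) gather for the k-th routed expert
            if weights is not None:                             # weighted path: forward output aggregation
                contrib = contrib * weights.view(T, K)[:, k:k + 1]
            out += contrib                                      # unweighted path matches the dx reduction
        return out
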
@@ -407,7 +385,6 @@ def backward(ctx, dO):
                w_is_None=True,
            )

-        # dW1 = X_gathered^T @ d_pre_act
        dgate_up_proj = torch.zeros_like(gate_up_proj)
        _moe_bwd_dW1_kernel[
            lambda meta: (
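
The two removed comments gave the weight-gradient formulas: dW2 = (s_k * y1)^T @ dO_gathered and dW1 = X_gathered^T @ d_pre_act. Because rows are sorted by expert, both reduce to one GEMM per contiguous expert segment; the sketch below assumes (E, in_features, out_features) weight layouts and uses a hypothetical helper name.

    import torch

    def expert_weight_grads_reference(x_gathered, weighted_act, dO_gathered, d_pre_act, expert_start_idx, E):
        # Rows are expert-sorted; expert e owns rows [expert_start_idx[e], expert_start_idx[e + 1]).
        H = x_gathered.shape[1]
        I = weighted_act.shape[1]
        dW1 = torch.zeros(E, H, d_pre_act.shape[1], dtype=d_pre_act.dtype, device=d_pre_act.device)
        dW2 = torch.zeros(E, I, dO_gathered.shape[1], dtype=d_pre_act.dtype, device=d_pre_act.device)
        for e in range(E):
            lo, hi = int(expert_start_idx[e]), int(expert_start_idx[e + 1])
            if hi > lo:
                dW1[e] = x_gathered[lo:hi].t() @ d_pre_act[lo:hi]        # dW1 = X_gathered^T @ d_pre_act
                dW2[e] = weighted_act[lo:hi].t() @ dO_gathered[lo:hi]    # dW2 = (s_k * y1)^T @ dO_gathered
        return dW1, dW2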