Skip to content

Commit dabd601

Browse files
committed
lint code
1 parent 7ee78a1 commit dabd601

4 files changed

Lines changed: 23 additions & 209 deletions

File tree

benchmark/scripts/benchmark_fused_moe.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,11 @@ def _setup_fused_moe(input: SingleBenchmarkRunInput):
106106
if input.kernel_provider == "liger":
107107

108108
def fwd_fn():
109-
return LigerFusedMoEFunction.apply(x, gup, dn, idx, wts).to(device)
109+
return LigerFusedMoEFunction.apply(x, gup, dn, idx, wts)
110110
elif input.kernel_provider == "huggingface":
111111

112112
def fwd_fn():
113-
return _huggingface_moe_forward(x, gup, dn, idx, wts).to(device)
113+
return _huggingface_moe_forward(x, gup, dn, idx, wts)
114114
else:
115115
raise ValueError(f"Unknown provider: {input.kernel_provider}")
116116

@@ -157,9 +157,9 @@ def _warmup_liger(T, E, H, intermediate_dim, K, dtype, sweep_dim):
157157
warmup_out = warmup_fn()
158158
warmup_out.sum().backward()
159159
del warmup_out
160-
if device == "cuda" and torch.cuda.is_available():
160+
if device == "cuda":
161161
torch.cuda.synchronize()
162-
elif device == "npu" and hasattr(torch, "npu") and torch.npu.is_available():
162+
elif device == "npu":
163163
torch.npu.synchronize()
164164

165165

@@ -234,9 +234,9 @@ def _probe():
234234
print(f" warmup E={e_val}...")
235235
_warmup_liger(probe_T, e_val, H, intermediate_dim, K, dtype, sweep_dim="E")
236236

237-
if device == "cuda" and torch.cuda.is_available():
237+
if device == "cuda":
238238
torch.cuda.synchronize()
239-
elif device == "npu" and hasattr(torch, "npu") and torch.npu.is_available():
239+
elif device == "npu":
240240
torch.npu.synchronize()
241241

242242
print("Autotune warmup complete.\n")

src/liger_kernel/ops/backends/_ascend/ops/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,4 +150,7 @@
150150
"fused_linear_cross_entropy_backward",
151151
"LigerFusedMoEFunction",
152152
"compute_routing_metadata",
153+
"LigerMHCCoeffsFunction",
154+
"LigerMHCPreFunction",
155+
"LigerMHCPostResFunction",
153156
]

src/liger_kernel/ops/backends/_ascend/ops/fused_moe.py

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import torch
99
import triton
1010

11+
from liger_kernel.ops.utils import ensure_contiguous
12+
1113
from .fused_moe_kernels import _fused_down_proj_kernel
1214
from .fused_moe_kernels import _fused_up_proj_swiglu_kernel
1315
from .fused_moe_kernels import _moe_bwd_down_proj_kernel
@@ -18,21 +20,11 @@
1820
from .fused_moe_kernels import _moe_router_prefix_sum_kernel
1921
from .fused_moe_kernels import _moe_router_scatter_kernel
2022
from .fused_moe_kernels import _token_gather_weighted_sum_kernel
21-
from liger_kernel.ops.utils import ensure_contiguous
2223

23-
# Token-dimension tile size for M.
24-
# Not in the inner-loop autotune because tile_row_start/tile_expert and the
25-
# grid dim-0 (num_m_tiles) must be recomputed for every candidate value.
26-
# To tune: change this constant and re-run benchmarks.
27-
# Smaller M-tiles compile/run reliably on Ascend (UB scratch is tight vs CUDA).
24+
# Token-dimension tile size used across fused MoE kernels.
2825
BLOCK_M_TOKEN = 32
2926

3027

31-
# ---------------------------------------------------------------------------
32-
# Routing metadata
33-
# ---------------------------------------------------------------------------
34-
35-
3628
def compute_routing_metadata(topk_indices: torch.Tensor, E: int, block_m_token: int = BLOCK_M_TOKEN):
3729
"""Compute token→expert routing permutation metadata via 3 Triton kernels.
3830
@@ -61,7 +53,6 @@ def compute_routing_metadata(topk_indices: torch.Tensor, E: int, block_m_token:
6153
TOKENS_PER_BLOCK = max(1, 1024 // K_POW2)
6254
n_tiles = triton.cdiv(T, TOKENS_PER_BLOCK)
6355

64-
# Kernel 1: tiled histogram → tile_expert_counts (E, n_tiles)
6556
tile_expert_counts = torch.empty(E, n_tiles, dtype=torch.int32, device=device)
6657
_moe_router_histogram_kernel[(n_tiles,)](
6758
topk_indices,
@@ -77,7 +68,6 @@ def compute_routing_metadata(topk_indices: torch.Tensor, E: int, block_m_token:
7768

7869
expert_token_count = tile_expert_counts.sum(dim=1, dtype=torch.int32) # (E,)
7970

80-
# Kernel 2: prefix sums + expert offsets + tile offsets (all in one pass)
8171
expert_start_idx = torch.empty(E + 1, dtype=torch.int32, device=device)
8272
expert_tile_offset = torch.empty(E + 1, dtype=torch.int32, device=device)
8373
_moe_router_prefix_sum_kernel[(E + 2,)](
@@ -93,19 +83,16 @@ def compute_routing_metadata(topk_indices: torch.Tensor, E: int, block_m_token:
9383
BLOCK_M_TOKEN=block_m_token,
9484
)
9585

96-
# One sync to get num_m_tiles for buffer allocation and GEMM grid.
9786
num_m_tiles = int(expert_tile_offset[-1].item())
9887

9988
tile_row_start = torch.empty(num_m_tiles, dtype=torch.int32, device=device)
10089
tile_expert = torch.empty(num_m_tiles, dtype=torch.int32, device=device)
10190

102-
# Kernel 3: sort by expert + scatter permutation arrays + tile metadata
10391
s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
10492
s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
10593
x_gather_idx = torch.empty(TK, dtype=torch.int32, device=device)
10694

10795
if TK > 0:
108-
# Per-tile per-expert atomic counters (zeroed); used by Ascend-safe scatter.
10996
scatter_rank_scratch = torch.zeros((n_tiles, E), dtype=torch.int32, device=device)
11097
_moe_router_scatter_kernel[(n_tiles,)](
11198
s_scatter_idx,
@@ -159,11 +146,6 @@ def _token_aggregation(Y, topk_weights_flat, s_reverse_scatter_idx, T, K, H):
159146
return out
160147

161148

162-
# ---------------------------------------------------------------------------
163-
# Autograd Function
164-
# ---------------------------------------------------------------------------
165-
166-
167149
class LigerFusedMoEFunction(torch.autograd.Function):
168150
"""Fused grouped GEMM MoE forward + memory-efficient backward.
169151
@@ -308,10 +290,9 @@ def backward(ctx, dO):
308290
TK = ctx.TK
309291
num_m_tiles = ctx.num_m_tiles
310292

311-
# dA' = dO @ W2^T, SwiGLU backward, write d_pre_act and dS
312293
d_pre_act = torch.empty(TK, 2 * intermediate_dim, dtype=dO.dtype, device=dO.device)
313294
weighted_act = torch.empty(TK, intermediate_dim, dtype=dO.dtype, device=dO.device)
314-
dS = torch.zeros(TK, dtype=dO.dtype, device=dO.device) # zeros: atomic_add in kernel accumulates across N-tiles
295+
dS = torch.zeros(TK, dtype=dO.dtype, device=dO.device) # atomic_add accumulates across N-tiles
315296

316297
if num_m_tiles > 0:
317298
_moe_bwd_down_proj_kernel[lambda meta: (num_m_tiles, triton.cdiv(intermediate_dim, meta["BLOCK_N"]))](
@@ -343,7 +324,6 @@ def backward(ctx, dO):
343324
BLOCK_M=BLOCK_M_TOKEN,
344325
)
345326

346-
# dW2 = (s_k * y1)^T @ dO_gathered
347327
ddown_proj = torch.zeros_like(down_proj)
348328
_moe_bwd_dW2_kernel[
349329
lambda meta: (
@@ -367,7 +347,6 @@ def backward(ctx, dO):
367347
stride_dW2_I=ddown_proj.stride(2),
368348
)
369349

370-
# dx_expanded = d_pre_act @ W1^T
371350
dx_expanded = torch.empty(TK, H, dtype=dO.dtype, device=dO.device)
372351

373352
if num_m_tiles > 0:
@@ -390,12 +369,11 @@ def backward(ctx, dO):
390369
BLOCK_M=BLOCK_M_TOKEN,
391370
)
392371

393-
# dx = unweighted gather-sum of dx_expanded
394372
dx = torch.zeros(T, H, dtype=dO.dtype, device=dO.device)
395373
if TK > 0:
396374
_token_gather_weighted_sum_kernel[(T,)](
397375
dx_expanded,
398-
dS, # dummy w_ptr — never loaded when w_is_None=True
376+
dS, # dummy w_ptr; unused when w_is_None=True
399377
s_reverse_scatter_idx,
400378
dx,
401379
H_dim=H,
@@ -407,7 +385,6 @@ def backward(ctx, dO):
407385
w_is_None=True,
408386
)
409387

410-
# dW1 = X_gathered^T @ d_pre_act
411388
dgate_up_proj = torch.zeros_like(gate_up_proj)
412389
_moe_bwd_dW1_kernel[
413390
lambda meta: (

0 commit comments

Comments (0)