Commit e4831e4
[NPU] Add NPU Fused MoE kernel (#1183)
## Motivation

This PR ports `fused_moe.py` and `fused_moe_kernels.py` to an NPU-affine implementation while preserving the original math. The computational definition is unchanged: the forward pass remains `W1 (gate/up) -> SwiGLU -> W2 -> token-weighted gather`, and the backward pass still follows `dA' = dO @ W2^T` to produce `d_pre_act / dS / dW2 / dX / dW1`. The main changes are execution-strategy optimizations for the NPU.

## Note: Use the Skill

For this fused_moe kernel migration, we followed the skill document from #1197.

## Testing Done

- Hardware Type: Ascend 910B2
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

🤖 Generated with: [cursor](https://cursor.com/).
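For readers who want the preserved math spelled out, here is a minimal dense-PyTorch sketch of the computation described above. The function name, argument names, and shapes are illustrative assumptions for this sketch, not the fused kernel's actual API:

```python
# Reference sketch of the preserved math:
# W1 (gate/up) -> SwiGLU -> W2 -> token-weighted gather.
# All names and shapes here are assumptions, not the kernel's real signature.
import torch
import torch.nn.functional as F

def moe_reference_forward(x, w1, w2, topk_ids, topk_weights):
    """x: (T, H); w1: (E, H, 2*I) packing gate/up; w2: (E, I, H);
    topk_ids, topk_weights: (T, K) from the router."""
    out = torch.zeros_like(x)
    for e in range(w1.shape[0]):                # loop over experts
        for k in range(topk_ids.shape[1]):      # loop over top-k slots
            rows = topk_ids[:, k] == e          # tokens routed to expert e
            if not rows.any():
                continue
            pre_act = x[rows] @ w1[e]           # W1 (gate/up)
            gate, up = pre_act.chunk(2, dim=-1)
            act = F.silu(gate) * up             # SwiGLU
            expert_out = act @ w2[e]            # W2
            out[rows] += topk_weights[rows, k].unsqueeze(-1) * expert_out
    return out                                  # token-weighted gather

# Autograd of this graph matches the backward the PR keeps: per expert,
# d(act) = dO_e @ w2[e].T (the commit's dA' = dO @ W2^T), from which
# d_pre_act, dS, dW2, dX, and dW1 follow by the chain rule.
```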
1 parent dcd404b commit e4831e4

5 files changed

Lines changed: 1166 additions & 5 deletions

benchmark/scripts/benchmark_fused_moe.py

Lines changed: 18 additions & 3 deletions

@@ -24,7 +24,7 @@
 from utils import run_memory_benchmark
 from utils import run_speed_benchmark

-from liger_kernel.ops.fused_moe import LigerFusedMoEFunction
+from liger_kernel.ops import LigerFusedMoEFunction
 from liger_kernel.utils import get_total_gpu_memory
 from liger_kernel.utils import infer_device

@@ -157,7 +157,14 @@ def _warmup_liger(T, E, H, intermediate_dim, K, dtype, sweep_dim):
     warmup_out = warmup_fn()
     warmup_out.sum().backward()
     del warmup_out
-    torch.cuda.synchronize()
+    if device == "cuda":
+        torch.cuda.synchronize()
+    elif device == "npu":
+        torch.npu.synchronize()
+    elif device == "xpu":
+        torch.xpu.synchronize()
+    else:
+        torch.cpu.synchronize()


 # ---------------------------------------------------------------------------
@@ -231,7 +238,15 @@ def _probe():
         print(f" warmup E={e_val}...")
         _warmup_liger(probe_T, e_val, H, intermediate_dim, K, dtype, sweep_dim="E")

-    torch.cuda.synchronize()
+    if device == "cuda":
+        torch.cuda.synchronize()
+    elif device == "npu":
+        torch.npu.synchronize()
+    elif device == "xpu":
+        torch.xpu.synchronize()
+    else:
+        torch.cpu.synchronize()
+
     print("Autotune warmup complete.\n")

     if args.sweep_dim == "num_tokens":
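As a side note on the pattern above: the same four-way branch appears in both hunks. A hypothetical consolidation (not part of this commit), assuming each backend module on `torch` exposes a `synchronize()` of the same shape — true for `cuda`/`xpu`/`cpu` in recent PyTorch and for `npu` once `torch_npu` is installed:

```python
import torch

def synchronize(device: str) -> None:
    """Block until all queued work on the given device type has finished.
    Hypothetical helper; falls back to torch.cpu when the backend module
    (e.g. torch.npu without torch_npu) is absent."""
    backend = getattr(torch, device, torch.cpu)
    backend.synchronize()
```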

src/liger_kernel/ops/backends/_ascend/ops/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -32,6 +32,8 @@
 from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import fused_linear_jsd_backward
 from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import fused_linear_jsd_forward
+from liger_kernel.ops.backends._ascend.ops.fused_moe import LigerFusedMoEFunction
+from liger_kernel.ops.backends._ascend.ops.fused_moe import compute_routing_metadata
 from liger_kernel.ops.backends._ascend.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
 from liger_kernel.ops.backends._ascend.ops.fused_neighborhood_attention import fused_neighborhood_attention_forward
 from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction
@@ -149,6 +151,8 @@
     "LigerFusedLinearCrossEntropyFunction",
     "fused_linear_cross_entropy_forward",
     "fused_linear_cross_entropy_backward",
+    "LigerFusedMoEFunction",
+    "compute_routing_metadata",
     "LigerMHCCoeffsFunction",
     "LigerMHCPreFunction",
     "LigerMHCPostResFunction",
