Skip to content

Commit d70e6ef

Browse files
committed
Support NPU fused MoE kernel
1 parent 1faa013 commit d70e6ef

5 files changed

Lines changed: 1348 additions & 7 deletions

File tree

benchmark/scripts/benchmark_fused_moe.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from utils import run_memory_benchmark
2525
from utils import run_speed_benchmark
2626

27-
from liger_kernel.ops.fused_moe import LigerFusedMoEFunction
27+
from liger_kernel.ops import LigerFusedMoEFunction
2828
from liger_kernel.utils import get_total_gpu_memory
2929
from liger_kernel.utils import infer_device
3030

@@ -106,11 +106,11 @@ def _setup_fused_moe(input: SingleBenchmarkRunInput):
106106
if input.kernel_provider == "liger":
107107

108108
def fwd_fn():
109-
return LigerFusedMoEFunction.apply(x, gup, dn, idx, wts)
109+
return LigerFusedMoEFunction.apply(x, gup, dn, idx, wts).to(device)
110110
elif input.kernel_provider == "huggingface":
111111

112112
def fwd_fn():
113-
return _huggingface_moe_forward(x, gup, dn, idx, wts)
113+
return _huggingface_moe_forward(x, gup, dn, idx, wts).to(device)
114114
else:
115115
raise ValueError(f"Unknown provider: {input.kernel_provider}")
116116

@@ -157,7 +157,10 @@ def _warmup_liger(T, E, H, intermediate_dim, K, dtype, sweep_dim):
157157
warmup_out = warmup_fn()
158158
warmup_out.sum().backward()
159159
del warmup_out
160-
torch.cuda.synchronize()
160+
if device == "cuda" and torch.cuda.is_available():
161+
torch.cuda.synchronize()
162+
elif device == "npu" and hasattr(torch, "npu") and torch.npu.is_available():
163+
torch.npu.synchronize()
161164

162165

163166
# ---------------------------------------------------------------------------
@@ -231,7 +234,11 @@ def _probe():
231234
print(f" warmup E={e_val}...")
232235
_warmup_liger(probe_T, e_val, H, intermediate_dim, K, dtype, sweep_dim="E")
233236

234-
torch.cuda.synchronize()
237+
if device == "cuda" and torch.cuda.is_available():
238+
torch.cuda.synchronize()
239+
elif device == "npu" and hasattr(torch, "npu") and torch.npu.is_available():
240+
torch.npu.synchronize()
241+
235242
print("Autotune warmup complete.\n")
236243

237244
if args.sweep_dim == "num_tokens":

src/liger_kernel/ops/backends/_ascend/ops/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
3333
from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import fused_linear_jsd_backward
3434
from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import fused_linear_jsd_forward
35+
from liger_kernel.ops.backends._ascend.ops.fused_moe import LigerFusedMoEFunction
36+
from liger_kernel.ops.backends._ascend.ops.fused_moe import compute_routing_metadata
3537
from liger_kernel.ops.backends._ascend.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
3638
from liger_kernel.ops.backends._ascend.ops.fused_neighborhood_attention import fused_neighborhood_attention_forward
3739
from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction
@@ -146,4 +148,6 @@
146148
"LigerFusedLinearCrossEntropyFunction",
147149
"fused_linear_cross_entropy_forward",
148150
"fused_linear_cross_entropy_backward",
151+
"LigerFusedMoEFunction",
152+
"compute_routing_metadata",
149153
]

0 commit comments

Comments (0)