|
24 | 24 | from utils import run_memory_benchmark |
25 | 25 | from utils import run_speed_benchmark |
26 | 26 |
|
27 | | -from liger_kernel.ops.fused_moe import LigerFusedMoEFunction |
| 27 | +from liger_kernel.ops import LigerFusedMoEFunction |
28 | 28 | from liger_kernel.utils import get_total_gpu_memory |
29 | 29 | from liger_kernel.utils import infer_device |
30 | 30 |
|
@@ -106,11 +106,11 @@ def _setup_fused_moe(input: SingleBenchmarkRunInput): |
106 | 106 | if input.kernel_provider == "liger": |
107 | 107 |
|
108 | 108 | def fwd_fn(): |
109 | | - return LigerFusedMoEFunction.apply(x, gup, dn, idx, wts) |
| 109 | + return LigerFusedMoEFunction.apply(x, gup, dn, idx, wts).to(device) |
110 | 110 | elif input.kernel_provider == "huggingface": |
111 | 111 |
|
112 | 112 | def fwd_fn(): |
113 | | - return _huggingface_moe_forward(x, gup, dn, idx, wts) |
| 113 | + return _huggingface_moe_forward(x, gup, dn, idx, wts).to(device) |
114 | 114 | else: |
115 | 115 | raise ValueError(f"Unknown provider: {input.kernel_provider}") |
116 | 116 |
|
@@ -157,7 +157,10 @@ def _warmup_liger(T, E, H, intermediate_dim, K, dtype, sweep_dim): |
157 | 157 | warmup_out = warmup_fn() |
158 | 158 | warmup_out.sum().backward() |
159 | 159 | del warmup_out |
160 | | - torch.cuda.synchronize() |
| 160 | + if device == "cuda" and torch.cuda.is_available(): |
| 161 | + torch.cuda.synchronize() |
| 162 | + elif device == "npu" and hasattr(torch, "npu") and torch.npu.is_available(): |
| 163 | + torch.npu.synchronize() |
161 | 164 |
|
162 | 165 |
|
163 | 166 | # --------------------------------------------------------------------------- |
@@ -231,7 +234,11 @@ def _probe(): |
231 | 234 | print(f" warmup E={e_val}...") |
232 | 235 | _warmup_liger(probe_T, e_val, H, intermediate_dim, K, dtype, sweep_dim="E") |
233 | 236 |
|
234 | | - torch.cuda.synchronize() |
| 237 | + if device == "cuda" and torch.cuda.is_available(): |
| 238 | + torch.cuda.synchronize() |
| 239 | + elif device == "npu" and hasattr(torch, "npu") and torch.npu.is_available(): |
| 240 | + torch.npu.synchronize() |
| 241 | + |
235 | 242 | print("Autotune warmup complete.\n") |
236 | 243 |
|
237 | 244 | if args.sweep_dim == "num_tokens": |
|
0 commit comments