33 | 33 |
34 | 34 | from ...utils.distributed import ParallelState |
35 | 35 | from ..nn import QuantLinearConvBase, QuantModule, QuantModuleRegistry, TensorQuantizer |
36 | | -from ..utils import replace_function # pragma: no cover |
37 | 36 | from .custom import CUSTOM_MODEL_PLUGINS |
38 | 37 |
39 | 38 | # Try multiple import paths for vLLM compatibility across versions |
86 | 85 | ) |
87 | 86 |
88 | 87 | vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe") |
| 88 | +_FUSED_MOE_KERNEL_FUNC = next(
| 89 | +    (
| 90 | +        n
| 91 | +        for n in ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel")
| 92 | +        if hasattr(vllm_fused_moe_package, n)
| 93 | +    ),
| 94 | +    "",
| 95 | +)
89 | 96 |
90 | 97 |
91 | 98 | @contextmanager |
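The _FUSED_MOE_KERNEL_FUNC probe added above resolves the fused-MoE kernel entry point by name once at import time, because different vLLM releases expose it under different attribute names. A minimal, self-contained sketch of the same probing idiom, using a stand-in namespace instead of vLLM (the names below are illustrative only):

    from types import SimpleNamespace

    # Stand-in for a module whose kernel function name differs across versions.
    fake_pkg = SimpleNamespace(invoke_fused_moe_kernel=lambda *args, **kwargs: "called")

    # Probe candidate names in preference order; fall back to "" when none match.
    kernel_name = next(
        (
            name
            for name in ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel")
            if hasattr(fake_pkg, name)
        ),
        "",
    )

    assert kernel_name == "invoke_fused_moe_kernel"
    print(getattr(fake_pkg, kernel_name)())  # -> "called"

The empty-string fallback is what the assertion in forward() later guards against: if no candidate name matches the installed vLLM, patching is refused rather than silently skipped.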
@@ -335,15 +342,6 @@ def _setup(self): |
335 | 342 | ) |
336 | 343 | self.parallel_state = create_parallel_state() |
337 | 344 |
338 | | -        if getattr(self, "invoke_fused_moe_kernel_func", None) is None:  # pragma: no cover
339 | | -            for name in ("invoke_fused_moe_kernel", "dispatch_fused_moe_kernel"):
340 | | -                if hasattr(vllm_fused_moe_package, name):
341 | | -                    self.invoke_fused_moe_kernel_func = name
342 | | -                    break
343 | | -        assert (  # pragma: no cover
344 | | -            getattr(self, "invoke_fused_moe_kernel_func", None) is not None
345 | | -        ), "fused_moe_kernel is not found"
346 | | - |
347 | 345 |     def invoke_fused_moe_quantized(
348 | 346 |         self,
349 | 347 |         A: torch.Tensor,  # noqa: N803
@@ -389,13 +387,38 @@ def invoke_fused_moe_quantized( |
389 | 387 | raise ValueError("Cannot determine first or second layer of expert") |
390 | 388 |
391 | 389 |     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
392 | | -        with replace_function(  # pragma: no cover
393 | | -            vllm_fused_moe_package,
394 | | -            self.invoke_fused_moe_kernel_func,
395 | | -            self.invoke_fused_moe_quantized,
396 | | -            og_func_cache_name="_invoke_fused_moe_kernel",
397 | | -        ):
398 | | -            return super().forward(hidden_states, router_logits)
| 390 | +        # vLLM shadows the fused_moe submodule with the fused_moe function of the
| 391 | +        # same name, so we must go through the explicitly imported module object
| 392 | +        # (vllm_fused_moe_package) and patch the kernel there by attribute name.
| 393 | +        assert (
| 394 | +            _FUSED_MOE_KERNEL_FUNC != ""
| 395 | +            and getattr(vllm_fused_moe_package, _FUSED_MOE_KERNEL_FUNC, None) is not None
| 396 | +        ), "fused_moe_kernel is not found"
| 397 | +        # Patching via the replace_function context manager conflicts with torch.compile:
| 398 | +        # with replace_function(
| 399 | +        #     vllm_fused_moe_package,
| 400 | +        #     "invoke_fused_moe_kernel",
| 401 | +        #     self.invoke_fused_moe_quantized,
| 402 | +        # ):
| 403 | +        try:
| 404 | +            original_invoke_fused_moe_kernel = getattr(
| 405 | +                vllm_fused_moe_package,
| 406 | +                _FUSED_MOE_KERNEL_FUNC,
| 407 | +                None,
| 408 | +            )
| 409 | +            setattr(
| 410 | +                vllm_fused_moe_package, "_invoke_fused_moe_kernel", original_invoke_fused_moe_kernel
| 411 | +            )
| 412 | +            setattr(vllm_fused_moe_package, _FUSED_MOE_KERNEL_FUNC, self.invoke_fused_moe_quantized)
| 413 | +            output = super().forward(hidden_states, router_logits)
| 414 | +            setattr(
| 415 | +                vllm_fused_moe_package, _FUSED_MOE_KERNEL_FUNC, original_invoke_fused_moe_kernel
| 416 | +            )
| 417 | +            return output
| 418 | +        finally:
| 419 | +            setattr(
| 420 | +                vllm_fused_moe_package, _FUSED_MOE_KERNEL_FUNC, original_invoke_fused_moe_kernel
| 421 | +            )
399 | 422 |
400 | 423 |     @torch.no_grad()
401 | 424 |     def fold_weight(self, keep_attrs: bool = False):
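The forward() override above temporarily swaps vLLM's module-level kernel function for invoke_fused_moe_quantized and then puts the original back. A minimal sketch of that save/patch/restore pattern on a stand-in namespace (the names are illustrative, not vLLM's API):

    from types import SimpleNamespace

    pkg = SimpleNamespace(kernel=lambda x: x + 1)  # original module-level function


    def quantized_kernel(x):
        return x * 10  # replacement that is visible only while the patch is active


    def run_with_patched_kernel(x):
        original = pkg.kernel
        pkg.kernel = quantized_kernel  # patch in the replacement
        try:
            # Anything reached from here sees the patched function.
            return pkg.kernel(x)
        finally:
            pkg.kernel = original  # restore on both the success and error paths


    print(run_with_patched_kernel(3))  # 30 while patched
    print(pkg.kernel(3))               # 4, the original is back

Restoring inside the finally block guarantees the original kernel is reinstated even if the patched call raises, so later calls through the same module keep working.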