|
33 | 33 |
|
34 | 34 | from ...utils.distributed import ParallelState |
35 | 35 | from ..nn import QuantLinearConvBase, QuantModule, QuantModuleRegistry, TensorQuantizer |
| 36 | +from ..utils import replace_function |
36 | 37 | from .custom import CUSTOM_MODEL_PLUGINS |
37 | 38 |
|
38 | 39 | # Try multiple import paths for vLLM compatibility across versions |
|
86 | 87 |
|
87 | 88 | vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe") |
88 | 89 |
|
_vllm_fused_moe_invoke_name_cache: str | None = None


def _vllm_fused_moe_invoke_name() -> str:
    """Return the name of vLLM's public fused-MoE kernel entrypoint.

    vLLM renamed ``invoke_fused_moe_kernel`` to
    ``invoke_fused_moe_triton_kernel`` across versions, so probe for
    whichever attribute exists on the fused-MoE module and memoize the
    answer; the attribute scan therefore runs at most once per process.

    Raises:
        ValueError: if neither known entrypoint name is present.
    """
    global _vllm_fused_moe_invoke_name_cache
    if _vllm_fused_moe_invoke_name_cache is None:
        candidates = ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel")
        found = next(
            (name for name in candidates if hasattr(vllm_fused_moe_package, name)),
            None,
        )
        if found is None:
            raise ValueError("fused_moe_kernel is not found")
        _vllm_fused_moe_invoke_name_cache = found
    return _vllm_fused_moe_invoke_name_cache
| 103 | + |
89 | 104 |
|
90 | 105 | @contextmanager |
91 | 106 | def disable_compilation(model): |
@@ -346,45 +361,39 @@ def invoke_fused_moe_quantized( |
346 | 361 | # First layer of expert |
347 | 362 | A = self.w13_input_quantizer(A) # noqa: N806 |
348 | 363 | if self.w13_weight_quantizer.is_enabled: |
349 | | - orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight) |
| 364 | + original_weight, self.w13_weight = ( |
| 365 | + self.w13_weight, |
| 366 | + self.w13_weight_quantizer(self.w13_weight), |
| 367 | + ) |
350 | 368 | vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs) |
351 | | - self.w13_weight = orig |
| 369 | + self.w13_weight = original_weight |
352 | 370 | else: |
353 | 371 | vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs) |
354 | 372 | if self.w13_output_quantizer.is_enabled: |
355 | 373 | C[:] = self.w13_output_quantizer(C) |
356 | 374 | elif B is self.w2_weight: |
357 | 375 | A = self.w2_input_quantizer(A) # noqa: N806 |
358 | 376 | if self.w2_weight_quantizer.is_enabled: |
359 | | - orig, self.w2_weight = self.w2_weight, self.w2_weight_quantizer(self.w2_weight) |
| 377 | + original_weight, self.w2_weight = ( |
| 378 | + self.w2_weight, |
| 379 | + self.w2_weight_quantizer(self.w2_weight), |
| 380 | + ) |
360 | 381 | vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs) |
361 | | - self.w2_weight = orig |
| 382 | + self.w2_weight = original_weight |
362 | 383 | else: |
363 | 384 | vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs) |
364 | 385 | if self.w2_output_quantizer.is_enabled: |
365 | 386 | C[:] = self.w2_output_quantizer(C) |
366 | 387 | else: |
367 | 388 | raise ValueError("Cannot determine first or second layer of expert") |
368 | 389 |
|
369 | | - @contextmanager |
370 | | - def _patch_moe_kernel(self): |
371 | | - """Temporarily replace vLLM fused_moe kernel with quantized version.""" |
372 | | - # `invoke_fused_moe_kernel` was used through v0.14.0rc0; it was renamed |
373 | | - # to `invoke_fused_moe_triton_kernel` starting from v0.14.0rc1. |
374 | | - for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]: |
375 | | - if hasattr(vllm_fused_moe_package, attr): |
376 | | - orig = getattr(vllm_fused_moe_package, attr) |
377 | | - setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig) |
378 | | - setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized) |
379 | | - try: |
380 | | - yield |
381 | | - finally: |
382 | | - setattr(vllm_fused_moe_package, attr, orig) |
383 | | - return |
384 | | - raise ValueError("fused_moe_kernel is not found") |
385 | | - |
386 | 390 | def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): |
387 | | - with self._patch_moe_kernel(): |
| 391 | + with replace_function( |
| 392 | + vllm_fused_moe_package, |
| 393 | + _vllm_fused_moe_invoke_name(), |
| 394 | + self.invoke_fused_moe_quantized, |
| 395 | + og_func_cache_name="_invoke_fused_moe_kernel", |
| 396 | + ): |
388 | 397 | return super().forward(hidden_states, router_logits) |
389 | 398 |
|
390 | 399 | @torch.no_grad() |
|
0 commit comments