Skip to content

Commit b6b6d1e

Browse files
committed
using replace_function
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent cca192f commit b6b6d1e

1 file changed

Lines changed: 31 additions & 22 deletions

File tree

  • modelopt/torch/quantization/plugins

modelopt/torch/quantization/plugins/vllm.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from ...utils.distributed import ParallelState
3535
from ..nn import QuantLinearConvBase, QuantModule, QuantModuleRegistry, TensorQuantizer
36+
from ..utils import replace_function
3637
from .custom import CUSTOM_MODEL_PLUGINS
3738

3839
# Try multiple import paths for vLLM compatibility across versions
@@ -86,6 +87,20 @@
8687

8788
vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe")
8889

90+
# Lazily-resolved name of the vLLM fused-MoE entrypoint; filled on first lookup.
_vllm_fused_moe_invoke_name_cache: str | None = None


def _vllm_fused_moe_invoke_name() -> str:
    """Resolve the name of vLLM's public fused_moe kernel entrypoint.

    The entrypoint was renamed across vLLM versions
    (``invoke_fused_moe_kernel`` -> ``invoke_fused_moe_triton_kernel``),
    so probe both spellings and cache whichever one exists.

    Returns:
        The attribute name present on ``vllm_fused_moe_package``.

    Raises:
        ValueError: If neither known entrypoint name is found.
    """
    global _vllm_fused_moe_invoke_name_cache
    if _vllm_fused_moe_invoke_name_cache is None:
        candidates = ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel")
        resolved = next(
            (attr for attr in candidates if hasattr(vllm_fused_moe_package, attr)),
            None,
        )
        if resolved is None:
            raise ValueError("fused_moe_kernel is not found")
        _vllm_fused_moe_invoke_name_cache = resolved
    return _vllm_fused_moe_invoke_name_cache
103+
89104

90105
@contextmanager
91106
def disable_compilation(model):
@@ -346,45 +361,39 @@ def invoke_fused_moe_quantized(
346361
# First layer of expert
347362
A = self.w13_input_quantizer(A) # noqa: N806
348363
if self.w13_weight_quantizer.is_enabled:
349-
orig, self.w13_weight = self.w13_weight, self.w13_weight_quantizer(self.w13_weight)
364+
original_weight, self.w13_weight = (
365+
self.w13_weight,
366+
self.w13_weight_quantizer(self.w13_weight),
367+
)
350368
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
351-
self.w13_weight = orig
369+
self.w13_weight = original_weight
352370
else:
353371
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
354372
if self.w13_output_quantizer.is_enabled:
355373
C[:] = self.w13_output_quantizer(C)
356374
elif B is self.w2_weight:
357375
A = self.w2_input_quantizer(A) # noqa: N806
358376
if self.w2_weight_quantizer.is_enabled:
359-
orig, self.w2_weight = self.w2_weight, self.w2_weight_quantizer(self.w2_weight)
377+
original_weight, self.w2_weight = (
378+
self.w2_weight,
379+
self.w2_weight_quantizer(self.w2_weight),
380+
)
360381
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
361-
self.w2_weight = orig
382+
self.w2_weight = original_weight
362383
else:
363384
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
364385
if self.w2_output_quantizer.is_enabled:
365386
C[:] = self.w2_output_quantizer(C)
366387
else:
367388
raise ValueError("Cannot determine first or second layer of expert")
368389

369-
@contextmanager
370-
def _patch_moe_kernel(self):
371-
"""Temporarily replace vLLM fused_moe kernel with quantized version."""
372-
# `invoke_fused_moe_kernel` was used through v0.14.0rc0; it was renamed
373-
# to `invoke_fused_moe_triton_kernel` starting from v0.14.0rc1.
374-
for attr in ["invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"]:
375-
if hasattr(vllm_fused_moe_package, attr):
376-
orig = getattr(vllm_fused_moe_package, attr)
377-
setattr(vllm_fused_moe_package, "_invoke_fused_moe_kernel", orig)
378-
setattr(vllm_fused_moe_package, attr, self.invoke_fused_moe_quantized)
379-
try:
380-
yield
381-
finally:
382-
setattr(vllm_fused_moe_package, attr, orig)
383-
return
384-
raise ValueError("fused_moe_kernel is not found")
385-
386390
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
    """Run the MoE forward pass with the fused-MoE kernel swapped for the quantized one.

    While inside the ``with`` block, the module-level vLLM fused-MoE entrypoint
    (whichever name `_vllm_fused_moe_invoke_name()` resolves for this vLLM
    version) is replaced by ``self.invoke_fused_moe_quantized``; the original
    function is stashed on the package as ``_invoke_fused_moe_kernel`` so the
    quantized wrapper can delegate to it. The patch is reverted when the
    context exits, so `super().forward` runs with the quantized kernel only
    for the duration of this call.
    """
    with replace_function(
        vllm_fused_moe_package,
        _vllm_fused_moe_invoke_name(),
        self.invoke_fused_moe_quantized,
        og_func_cache_name="_invoke_fused_moe_kernel",
    ):
        # Delegate to the parent FusedMoE forward; it will hit the patched kernel.
        return super().forward(hidden_states, router_logits)
389398

390399
@torch.no_grad()

0 commit comments

Comments
 (0)