Commit 3036e8f

addressed comments

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>

1 parent 56bec34

1 file changed: modelopt/torch/quantization/plugins/vllm.py
Lines changed: 40 additions & 17 deletions
@@ -33,7 +33,6 @@
 
 from ...utils.distributed import ParallelState
 from ..nn import QuantLinearConvBase, QuantModule, QuantModuleRegistry, TensorQuantizer
-from ..utils import replace_function  # pragma: no cover
 from .custom import CUSTOM_MODEL_PLUGINS
 
 # Try multiple import paths for vLLM compatibility across versions
@@ -86,6 +85,14 @@
 )
 
 vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe")
+_FUSED_MOE_KERNEL_FUNC = next(
+    (
+        n
+        for n in ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel")
+        if hasattr(vllm_fused_moe_package, n)
+    ),
+    "",
+)
 
 
 @contextmanager
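Note: the hunk above resolves the kernel entry point once at import time by probing candidate attribute names with next(...) and hasattr(...), so the plugin keeps working when vLLM renames the function between versions. Below is a minimal, self-contained sketch of that version-probing idiom; the stand-in module (math) and the fake alias name are illustrative only and not part of vLLM or the commit:

    import math

    def resolve_attr_name(module, candidates):
        """Return the first attribute name that exists on `module`, else ""."""
        return next((n for n in candidates if hasattr(module, n)), "")

    # math stands in for vllm_fused_moe_package; "integer_sqrt" is a made-up
    # alias that does not exist, so the probe falls through to "isqrt".
    name = resolve_attr_name(math, ("isqrt", "integer_sqrt"))
    assert name == "isqrt"
    print(getattr(math, name)(17))  # -> 4

Resolving the name once at module scope also lets the next hunk delete the per-instance discovery loop that previously ran in every _setup() call.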
@@ -335,15 +342,6 @@ def _setup(self):
         )
         self.parallel_state = create_parallel_state()
 
-        if getattr(self, "invoke_fused_moe_kernel_func", None) is None:  # pragma: no cover
-            for name in ("invoke_fused_moe_kernel", "dispatch_fused_moe_kernel"):
-                if hasattr(vllm_fused_moe_package, name):
-                    self.invoke_fused_moe_kernel_func = name
-                    break
-        assert (  # pragma: no cover
-            getattr(self, "invoke_fused_moe_kernel_func", None) is not None
-        ), "fused_moe_kernel is not found"
-
     def invoke_fused_moe_quantized(
         self,
         A: torch.Tensor,  # noqa: N803
@@ -389,13 +387,38 @@ def invoke_fused_moe_quantized(
         raise ValueError("Cannot determine first or second layer of expert")
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        with replace_function(  # pragma: no cover
-            vllm_fused_moe_package,
-            self.invoke_fused_moe_kernel_func,
-            self.invoke_fused_moe_quantized,
-            og_func_cache_name="_invoke_fused_moe_kernel",
-        ):
-            return super().forward(hidden_states, router_logits)
+        # This is again due to the bad coding of vLLM:
+        # the fused_moe submodule is overwritten by the fused_moe function,
+        # so we need to import the fused_moe module explicitly.
+        assert (
+            _FUSED_MOE_KERNEL_FUNC != ""
+            and getattr(vllm_fused_moe_package, _FUSED_MOE_KERNEL_FUNC, None) is not None
+        )
+        # This context manager will conflict with torch.compile:
+        # with replace_function(
+        #     vllm_fused_moe_package,
+        #     "invoke_fused_moe_kernel",
+        #     self.invoke_fused_moe_quantized,
+        # ):
+        try:
+            original_invoke_fused_moe_kernel = getattr(
+                vllm_fused_moe_package,
+                _FUSED_MOE_KERNEL_FUNC,
+                None,
+            )
+            setattr(
+                vllm_fused_moe_package, "_invoke_fused_moe_kernel", original_invoke_fused_moe_kernel
+            )
+            setattr(vllm_fused_moe_package, _FUSED_MOE_KERNEL_FUNC, self.invoke_fused_moe_quantized)
+            output = super().forward(hidden_states, router_logits)
+            setattr(
+                vllm_fused_moe_package, _FUSED_MOE_KERNEL_FUNC, original_invoke_fused_moe_kernel
+            )
+            return output
+        finally:
+            setattr(
+                vllm_fused_moe_package, _FUSED_MOE_KERNEL_FUNC, original_invoke_fused_moe_kernel
+            )
 
     @torch.no_grad()
     def fold_weight(self, keep_attrs: bool = False):
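Note: the new forward() replaces the replace_function context manager (which, per the in-diff comment, conflicts with torch.compile) with plain setattr patching guarded by try/finally, so the original kernel is restored even if super().forward() raises. A minimal sketch of that patching idiom, with illustrative names only (pkg stands in for vllm_fused_moe_package, quantized_kernel for invoke_fused_moe_quantized):

    import types

    # Fake module with one function; not a real vLLM object.
    pkg = types.SimpleNamespace(kernel=lambda x: x + 1)

    def quantized_kernel(x):
        return x * 10  # override in effect only for the duration of the call

    def run_with_patched_kernel(x):
        original = pkg.kernel
        try:
            pkg.kernel = quantized_kernel   # swap in the override
            return pkg.kernel(x)            # callees looking up pkg.kernel see it
        finally:
            pkg.kernel = original           # always restore, even on exceptions

    print(run_with_patched_kernel(3))  # -> 30
    print(pkg.kernel(3))               # -> 4, original restored

One observation on the committed code: the setattr restore just before `return output` is redundant, since the finally block performs the same restore on every exit path.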
