Skip to content

Commit cfb9396

Browse files
committed
Patching all functions which are present for fused_moe
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 687ceea commit cfb9396

File tree

1 file changed

+58
-30
lines changed
  • modelopt/torch/quantization/plugins

1 file changed

+58
-30
lines changed

modelopt/torch/quantization/plugins/vllm.py

Lines changed: 58 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515

1616
"""Support quantization for VLLM layers."""
1717

18+
import contextvars
1819
import importlib
20+
from collections.abc import Callable
1921
from contextlib import contextmanager
22+
from functools import partial
2023
from itertools import chain
2124

2225
import torch
@@ -85,13 +88,20 @@
8588
)
8689

8790
vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe")
88-
_FUSED_MOE_KERNEL_FUNC = next(
89-
(
90-
n
91-
for n in ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel")
92-
if hasattr(vllm_fused_moe_package, n)
93-
),
94-
"",
91+
# vLLM may call one entry (e.g. ``dispatch_fused_moe_kernel``) which then calls another on the same
92+
# module (e.g. ``invoke_fused_moe_triton_kernel``). Patching every name would otherwise apply fakequant
93+
# twice; see ``_moe_fakequant_active`` in ``invoke_fused_moe_quantized``.
94+
_FUSED_MOE_KERNEL_CANDIDATES = (
95+
"invoke_fused_moe_kernel",
96+
"invoke_fused_moe_triton_kernel",
97+
"dispatch_fused_moe_kernel",
98+
)
99+
_FUSED_MOE_KERNEL_FUNCS = tuple(
100+
n for n in _FUSED_MOE_KERNEL_CANDIDATES if hasattr(vllm_fused_moe_package, n)
101+
)
102+
103+
_moe_fakequant_active: contextvars.ContextVar[bool] = contextvars.ContextVar(
104+
"moe_fakequant_active", default=False
95105
)
96106

97107

@@ -348,6 +358,27 @@ def invoke_fused_moe_quantized(
348358
B: torch.Tensor, # noqa: N803
349359
C: torch.Tensor, # noqa: N803
350360
*args,
361+
original_kernel: Callable,
362+
**kwargs,
363+
):
364+
# Nested module-level entry (e.g. dispatch -> triton): call the real kernel once, no second quant.
365+
if _moe_fakequant_active.get():
366+
return original_kernel(A, B, C, *args, **kwargs)
367+
token = _moe_fakequant_active.set(True)
368+
try:
369+
return self._invoke_fused_moe_quantized_function(
370+
A, B, C, *args, original_kernel=original_kernel, **kwargs
371+
)
372+
finally:
373+
_moe_fakequant_active.reset(token)
374+
375+
def _invoke_fused_moe_quantized_function(
376+
self,
377+
A: torch.Tensor, # noqa: N803
378+
B: torch.Tensor, # noqa: N803
379+
C: torch.Tensor, # noqa: N803
380+
*args,
381+
original_kernel: Callable,
351382
**kwargs,
352383
):
353384
if B is self.w13_weight:
@@ -361,10 +392,10 @@ def invoke_fused_moe_quantized(
361392
# In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
362393
# quantized weight to the kernel.
363394
B = self.w13_weight # noqa: N806
364-
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
395+
original_kernel(A, B, C, *args, **kwargs)
365396
self.w13_weight = original_weight
366397
else:
367-
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
398+
original_kernel(A, B, C, *args, **kwargs)
368399
if self.w13_output_quantizer.is_enabled:
369400
C[:] = self.w13_output_quantizer(C)
370401
elif B is self.w2_weight:
@@ -377,10 +408,10 @@ def invoke_fused_moe_quantized(
377408
# In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
378409
# quantized weight to the kernel.
379410
B = self.w2_weight # noqa: N806
380-
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
411+
original_kernel(A, B, C, *args, **kwargs)
381412
self.w2_weight = original_weight
382413
else:
383-
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
414+
original_kernel(A, B, C, *args, **kwargs)
384415
if self.w2_output_quantizer.is_enabled:
385416
C[:] = self.w2_output_quantizer(C)
386417
else:
@@ -390,35 +421,31 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
390421
# This is again due to the bad coding of vLLM
391422
# fused_moe submodule is overwritten by the fused_moe function
392423
# so we need to import the fused_moe module explicitly
393-
assert (
394-
_FUSED_MOE_KERNEL_FUNC != ""
395-
and getattr(vllm_fused_moe_package, _FUSED_MOE_KERNEL_FUNC, None) is not None
424+
assert _FUSED_MOE_KERNEL_FUNCS and all(
425+
getattr(vllm_fused_moe_package, n, None) is not None for n in _FUSED_MOE_KERNEL_FUNCS
396426
)
397427
# This context manager will conflict with torch.compile
398428
# with replace_function(
399429
# vllm_fused_moe_package,
400430
# "invoke_fused_moe_kernel",
401431
# self.invoke_fused_moe_quantized,
402432
# ):
433+
originals = {n: getattr(vllm_fused_moe_package, n) for n in _FUSED_MOE_KERNEL_FUNCS}
403434
try:
404-
original_invoke_fused_moe_kernel = getattr(
405-
vllm_fused_moe_package,
406-
_FUSED_MOE_KERNEL_FUNC,
407-
None,
408-
)
409-
setattr(
410-
vllm_fused_moe_package, "_invoke_fused_moe_kernel", original_invoke_fused_moe_kernel
411-
)
412-
setattr(vllm_fused_moe_package, _FUSED_MOE_KERNEL_FUNC, self.invoke_fused_moe_quantized)
435+
for n in _FUSED_MOE_KERNEL_FUNCS:
436+
setattr(
437+
vllm_fused_moe_package,
438+
n,
439+
partial(
440+
self.invoke_fused_moe_quantized,
441+
original_kernel=originals[n],
442+
),
443+
)
413444
output = super().forward(hidden_states, router_logits)
414-
setattr(
415-
vllm_fused_moe_package, _FUSED_MOE_KERNEL_FUNC, original_invoke_fused_moe_kernel
416-
)
417445
return output
418446
finally:
419-
setattr(
420-
vllm_fused_moe_package, _FUSED_MOE_KERNEL_FUNC, original_invoke_fused_moe_kernel
421-
)
447+
for n in _FUSED_MOE_KERNEL_FUNCS:
448+
setattr(vllm_fused_moe_package, n, originals[n])
422449

423450
@torch.no_grad()
424451
def fold_weight(self, keep_attrs: bool = False):
@@ -438,7 +465,8 @@ def fold_weight(self, keep_attrs: bool = False):
438465
)
439466
self.w2_weight_quantizer.disable()
440467

441-
torch.cuda.empty_cache()
468+
if torch.cuda.is_available():
469+
torch.cuda.empty_cache()
442470

443471

444472
@QuantModuleRegistry.register({vllm_fused_moe_layer.FusedMoE: "vllm_FusedMoE"})

0 commit comments

Comments (0)