1515
1616"""Support quantization for VLLM layers."""
1717
18+ import contextvars
1819import importlib
20+ from collections .abc import Callable
1921from contextlib import contextmanager
22+ from functools import partial
2023from itertools import chain
2124
2225import torch
8588)
8689
# The fused-MoE kernel entry points live on this submodule; vLLM shadows the
# submodule name with a function, so it must be imported explicitly.
vllm_fused_moe_package = importlib.import_module(
    "vllm.model_executor.layers.fused_moe.fused_moe"
)

# vLLM may call one entry (e.g. ``dispatch_fused_moe_kernel``) which then calls
# another on the same module (e.g. ``invoke_fused_moe_triton_kernel``). Patching
# every name would otherwise apply fakequant twice; see ``_moe_fakequant_active``
# in ``invoke_fused_moe_quantized``.
_FUSED_MOE_KERNEL_CANDIDATES = (
    "invoke_fused_moe_kernel",
    "invoke_fused_moe_triton_kernel",
    "dispatch_fused_moe_kernel",
)
# Only the names this vLLM version actually exposes are patched.
_FUSED_MOE_KERNEL_FUNCS = tuple(
    name
    for name in _FUSED_MOE_KERNEL_CANDIDATES
    if hasattr(vllm_fused_moe_package, name)
)

# Re-entrancy guard: True while a patched kernel entry is already applying
# fakequant, so nested entries fall through to the real kernel untouched.
_moe_fakequant_active: contextvars.ContextVar[bool] = contextvars.ContextVar(
    "moe_fakequant_active", default=False
)
96106
97107
@@ -348,6 +358,27 @@ def invoke_fused_moe_quantized(
348358 B : torch .Tensor , # noqa: N803
349359 C : torch .Tensor , # noqa: N803
350360 * args ,
361+ original_kernel : Callable ,
362+ ** kwargs ,
363+ ):
364+ # Nested module-level entry (e.g. dispatch -> triton): call the real kernel once, no second quant.
365+ if _moe_fakequant_active .get ():
366+ return original_kernel (A , B , C , * args , ** kwargs )
367+ token = _moe_fakequant_active .set (True )
368+ try :
369+ return self ._invoke_fused_moe_quantized_function (
370+ A , B , C , * args , original_kernel = original_kernel , ** kwargs
371+ )
372+ finally :
373+ _moe_fakequant_active .reset (token )
374+
375+ def _invoke_fused_moe_quantized_function (
376+ self ,
377+ A : torch .Tensor , # noqa: N803
378+ B : torch .Tensor , # noqa: N803
379+ C : torch .Tensor , # noqa: N803
380+ * args ,
381+ original_kernel : Callable ,
351382 ** kwargs ,
352383 ):
353384 if B is self .w13_weight :
@@ -361,10 +392,10 @@ def invoke_fused_moe_quantized(
361392 # In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
362393 # quantized weight to the kernel.
363394 B = self .w13_weight # noqa: N806
364- vllm_fused_moe_package . _invoke_fused_moe_kernel (A , B , C , * args , ** kwargs )
395+ original_kernel (A , B , C , * args , ** kwargs )
365396 self .w13_weight = original_weight
366397 else :
367- vllm_fused_moe_package . _invoke_fused_moe_kernel (A , B , C , * args , ** kwargs )
398+ original_kernel (A , B , C , * args , ** kwargs )
368399 if self .w13_output_quantizer .is_enabled :
369400 C [:] = self .w13_output_quantizer (C )
370401 elif B is self .w2_weight :
@@ -377,10 +408,10 @@ def invoke_fused_moe_quantized(
377408 # In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
378409 # quantized weight to the kernel.
379410 B = self .w2_weight # noqa: N806
380- vllm_fused_moe_package . _invoke_fused_moe_kernel (A , B , C , * args , ** kwargs )
411+ original_kernel (A , B , C , * args , ** kwargs )
381412 self .w2_weight = original_weight
382413 else :
383- vllm_fused_moe_package . _invoke_fused_moe_kernel (A , B , C , * args , ** kwargs )
414+ original_kernel (A , B , C , * args , ** kwargs )
384415 if self .w2_output_quantizer .is_enabled :
385416 C [:] = self .w2_output_quantizer (C )
386417 else :
@@ -390,35 +421,31 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
390421 # This is again due to the bad coding of vLLM
391422 # fused_moe submodule is overwritten by the fused_moe function
392423 # so we need to import the fused_moe module explicitly
393- assert (
394- _FUSED_MOE_KERNEL_FUNC != ""
395- and getattr (vllm_fused_moe_package , _FUSED_MOE_KERNEL_FUNC , None ) is not None
424+ assert _FUSED_MOE_KERNEL_FUNCS and all (
425+ getattr (vllm_fused_moe_package , n , None ) is not None for n in _FUSED_MOE_KERNEL_FUNCS
396426 )
397427 # This context manager will conflict with torch.compile
398428 # with replace_function(
399429 # vllm_fused_moe_package,
400430 # "invoke_fused_moe_kernel",
401431 # self.invoke_fused_moe_quantized,
402432 # ):
433+ originals = {n : getattr (vllm_fused_moe_package , n ) for n in _FUSED_MOE_KERNEL_FUNCS }
403434 try :
404- original_invoke_fused_moe_kernel = getattr (
405- vllm_fused_moe_package ,
406- _FUSED_MOE_KERNEL_FUNC ,
407- None ,
408- )
409- setattr (
410- vllm_fused_moe_package , "_invoke_fused_moe_kernel" , original_invoke_fused_moe_kernel
411- )
412- setattr ( vllm_fused_moe_package , _FUSED_MOE_KERNEL_FUNC , self . invoke_fused_moe_quantized )
435+ for n in _FUSED_MOE_KERNEL_FUNCS :
436+ setattr (
437+ vllm_fused_moe_package ,
438+ n ,
439+ partial (
440+ self . invoke_fused_moe_quantized ,
441+ original_kernel = originals [ n ],
442+ ),
443+ )
413444 output = super ().forward (hidden_states , router_logits )
414- setattr (
415- vllm_fused_moe_package , _FUSED_MOE_KERNEL_FUNC , original_invoke_fused_moe_kernel
416- )
417445 return output
418446 finally :
419- setattr (
420- vllm_fused_moe_package , _FUSED_MOE_KERNEL_FUNC , original_invoke_fused_moe_kernel
421- )
447+ for n in _FUSED_MOE_KERNEL_FUNCS :
448+ setattr (vllm_fused_moe_package , n , originals [n ])
422449
423450 @torch .no_grad ()
424451 def fold_weight (self , keep_attrs : bool = False ):
@@ -438,7 +465,8 @@ def fold_weight(self, keep_attrs: bool = False):
438465 )
439466 self .w2_weight_quantizer .disable ()
440467
441- torch .cuda .empty_cache ()
468+ if torch .cuda .is_available ():
469+ torch .cuda .empty_cache ()
442470
443471
444472@QuantModuleRegistry .register ({vllm_fused_moe_layer .FusedMoE : "vllm_FusedMoE" })
0 commit comments