8787
# Resolve the vLLM fused-MoE implementation module at import time; resolving it
# dynamically (rather than a plain `import`) keeps a single handle we can patch.
vllm_fused_moe_package = importlib.import_module(
    "vllm.model_executor.layers.fused_moe.fused_moe"
)
8989
# Memoized entrypoint name; populated on the first successful probe below.
_vllm_fused_moe_invoke_name_cache: str | None = None


def _vllm_fused_moe_invoke_name() -> str:
    """Return the vLLM public fused_moe entrypoint (renamed across versions).

    Probes ``vllm_fused_moe_package`` for the known entrypoint names and
    memoizes the first hit in the module-level cache so the attribute
    lookup happens only once per process.

    Returns:
        The attribute name of the fused-MoE kernel entrypoint.

    Raises:
        ValueError: If no known entrypoint name exists on the vLLM module.
    """
    global _vllm_fused_moe_invoke_name_cache
    cached = _vllm_fused_moe_invoke_name_cache
    if cached is not None:
        return cached
    candidates = ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel")
    found = next(
        (name for name in candidates if hasattr(vllm_fused_moe_package, name)),
        None,
    )
    if found is None:
        raise ValueError("fused_moe_kernel is not found")
    _vllm_fused_moe_invoke_name_cache = found
    return found
103-
10490
10591@contextmanager
10692def disable_compilation (model ):
@@ -349,6 +335,15 @@ def _setup(self):
349335 )
350336 self .parallel_state = create_parallel_state ()
351337
338+ if getattr (self , "invoke_fused_moe_kernel_func" , None ) is None : # pragma: no cover
339+ for name in ("invoke_fused_moe_kernel" , "invoke_fused_moe_triton_kernel" ):
340+ if hasattr (vllm_fused_moe_package , name ):
341+ self .invoke_fused_moe_kernel_func = name
342+ break
343+ assert ( # pragma: no cover
344+ getattr (self , "invoke_fused_moe_kernel_func" , None ) is not None
345+ ), "fused_moe_kernel is not found"
346+
352347 def invoke_fused_moe_quantized (
353348 self ,
354349 A : torch .Tensor , # noqa: N803
@@ -360,11 +355,14 @@ def invoke_fused_moe_quantized(
360355 if B is self .w13_weight :
361356 # First layer of expert
362357 A = self .w13_input_quantizer (A ) # noqa: N806
363- if self .w13_weight_quantizer .is_enabled :
358+ if self .w13_weight_quantizer .is_enabled : # pragma: no cover
364359 original_weight , self .w13_weight = (
365360 self .w13_weight ,
366361 self .w13_weight_quantizer (self .w13_weight ),
367362 )
363+ # In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
364+ # quantized weight to the kernel.
365+ B = self .w13_weight # noqa: N806
368366 vllm_fused_moe_package ._invoke_fused_moe_kernel (A , B , C , * args , ** kwargs )
369367 self .w13_weight = original_weight
370368 else :
@@ -373,11 +371,14 @@ def invoke_fused_moe_quantized(
373371 C [:] = self .w13_output_quantizer (C )
374372 elif B is self .w2_weight :
375373 A = self .w2_input_quantizer (A ) # noqa: N806
376- if self .w2_weight_quantizer .is_enabled :
374+ if self .w2_weight_quantizer .is_enabled : # pragma: no cover
377375 original_weight , self .w2_weight = (
378376 self .w2_weight ,
379377 self .w2_weight_quantizer (self .w2_weight ),
380378 )
379+ # In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
380+ # quantized weight to the kernel.
381+ B = self .w2_weight # noqa: N806
381382 vllm_fused_moe_package ._invoke_fused_moe_kernel (A , B , C , * args , ** kwargs )
382383 self .w2_weight = original_weight
383384 else :
@@ -388,9 +389,9 @@ def invoke_fused_moe_quantized(
388389 raise ValueError ("Cannot determine first or second layer of expert" )
389390
390391 def forward (self , hidden_states : torch .Tensor , router_logits : torch .Tensor ):
391- with replace_function (
392+ with replace_function ( # pragma: no cover
392393 vllm_fused_moe_package ,
393- _vllm_fused_moe_invoke_name () ,
394+ self . invoke_fused_moe_kernel_func ,
394395 self .invoke_fused_moe_quantized ,
395396 og_func_cache_name = "_invoke_fused_moe_kernel" ,
396397 ):
0 commit comments