8787
8888vllm_fused_moe_package = importlib .import_module ("vllm.model_executor.layers.fused_moe.fused_moe" )
8989
# Lazily-resolved name of the vLLM fused-MoE entrypoint; populated on first lookup.
_vllm_fused_moe_invoke_name_cache: str | None = None


def _vllm_fused_moe_invoke_name() -> str:
    """Return the vLLM public fused_moe entrypoint (renamed across versions).

    Probes ``vllm_fused_moe_package`` for the known entrypoint names in order
    and memoizes the first match in a module-level cache so the attribute
    scan runs at most once per process.

    Raises:
        ValueError: if none of the known entrypoint names exist on the module.
    """
    global _vllm_fused_moe_invoke_name_cache
    if _vllm_fused_moe_invoke_name_cache is None:
        candidates = ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel")
        # next() with a default avoids an explicit loop-and-break; behavior is
        # identical: first candidate present on the module wins.
        found = next(
            (name for name in candidates if hasattr(vllm_fused_moe_package, name)),
            None,
        )
        if found is None:
            raise ValueError("fused_moe_kernel is not found")
        _vllm_fused_moe_invoke_name_cache = found
    return _vllm_fused_moe_invoke_name_cache
103-
10490
10591@contextmanager
10692def disable_compilation (model ):
@@ -349,6 +335,15 @@ def _setup(self):
349335 )
350336 self .parallel_state = create_parallel_state ()
351337
338+ if getattr (self , "invoke_fused_moe_kernel_func" , None ) is None : # pragma: no cover
339+ for name in ("invoke_fused_moe_kernel" , "invoke_fused_moe_triton_kernel" ):
340+ if hasattr (vllm_fused_moe_package , name ):
341+ self .invoke_fused_moe_kernel_func = name
342+ break
343+ assert getattr (self , "invoke_fused_moe_kernel_func" , None ) is not None , (
344+ "fused_moe_kernel is not found"
345+ ) # pragma: no cover
346+
352347 def invoke_fused_moe_quantized (
353348 self ,
354349 A : torch .Tensor , # noqa: N803
@@ -360,7 +355,7 @@ def invoke_fused_moe_quantized(
360355 if B is self .w13_weight :
361356 # First layer of expert
362357 A = self .w13_input_quantizer (A ) # noqa: N806
363- if self .w13_weight_quantizer .is_enabled :
358+ if self .w13_weight_quantizer .is_enabled : # pragma: no cover
364359 original_weight , self .w13_weight = (
365360 self .w13_weight ,
366361 self .w13_weight_quantizer (self .w13_weight ),
@@ -376,7 +371,7 @@ def invoke_fused_moe_quantized(
376371 C [:] = self .w13_output_quantizer (C )
377372 elif B is self .w2_weight :
378373 A = self .w2_input_quantizer (A ) # noqa: N806
379- if self .w2_weight_quantizer .is_enabled :
374+ if self .w2_weight_quantizer .is_enabled : # pragma: no cover
380375 original_weight , self .w2_weight = (
381376 self .w2_weight ,
382377 self .w2_weight_quantizer (self .w2_weight ),
@@ -394,9 +389,9 @@ def invoke_fused_moe_quantized(
394389 raise ValueError ("Cannot determine first or second layer of expert" )
395390
396391 def forward (self , hidden_states : torch .Tensor , router_logits : torch .Tensor ):
397- with replace_function (
392+ with replace_function ( # pragma: no cover
398393 vllm_fused_moe_package ,
399- _vllm_fused_moe_invoke_name () ,
394+ self . invoke_fused_moe_kernel_func ,
400395 self .invoke_fused_moe_quantized ,
401396 og_func_cache_name = "_invoke_fused_moe_kernel" ,
402397 ):
0 commit comments