Skip to content

Commit 96d081a

Browse files
committed
minor
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 435caab commit 96d081a

1 file changed

Lines changed: 13 additions & 18 deletions

File tree

  • modelopt/torch/quantization/plugins

modelopt/torch/quantization/plugins/vllm.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,6 @@
8787

8888
vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe")
8989

90-
_vllm_fused_moe_invoke_name_cache: str | None = None
91-
92-
93-
def _vllm_fused_moe_invoke_name() -> str:
94-
"""Return the vLLM public fused_moe entrypoint (renamed across versions)."""
95-
global _vllm_fused_moe_invoke_name_cache
96-
if _vllm_fused_moe_invoke_name_cache is not None:
97-
return _vllm_fused_moe_invoke_name_cache
98-
for name in ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"):
99-
if hasattr(vllm_fused_moe_package, name):
100-
_vllm_fused_moe_invoke_name_cache = name
101-
return name
102-
raise ValueError("fused_moe_kernel is not found")
103-
10490

10591
@contextmanager
10692
def disable_compilation(model):
@@ -349,6 +335,15 @@ def _setup(self):
349335
)
350336
self.parallel_state = create_parallel_state()
351337

338+
if getattr(self, "invoke_fused_moe_kernel_func", None) is None: # pragma: no cover
339+
for name in ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"):
340+
if hasattr(vllm_fused_moe_package, name):
341+
self.invoke_fused_moe_kernel_func = name
342+
break
343+
assert getattr(self, "invoke_fused_moe_kernel_func", None) is not None, (
344+
"fused_moe_kernel is not found"
345+
) # pragma: no cover
346+
352347
def invoke_fused_moe_quantized(
353348
self,
354349
A: torch.Tensor, # noqa: N803
@@ -360,7 +355,7 @@ def invoke_fused_moe_quantized(
360355
if B is self.w13_weight:
361356
# First layer of expert
362357
A = self.w13_input_quantizer(A) # noqa: N806
363-
if self.w13_weight_quantizer.is_enabled:
358+
if self.w13_weight_quantizer.is_enabled: # pragma: no cover
364359
original_weight, self.w13_weight = (
365360
self.w13_weight,
366361
self.w13_weight_quantizer(self.w13_weight),
@@ -376,7 +371,7 @@ def invoke_fused_moe_quantized(
376371
C[:] = self.w13_output_quantizer(C)
377372
elif B is self.w2_weight:
378373
A = self.w2_input_quantizer(A) # noqa: N806
379-
if self.w2_weight_quantizer.is_enabled:
374+
if self.w2_weight_quantizer.is_enabled: # pragma: no cover
380375
original_weight, self.w2_weight = (
381376
self.w2_weight,
382377
self.w2_weight_quantizer(self.w2_weight),
@@ -394,9 +389,9 @@ def invoke_fused_moe_quantized(
394389
raise ValueError("Cannot determine first or second layer of expert")
395390

396391
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
397-
with replace_function(
392+
with replace_function( # pragma: no cover
398393
vllm_fused_moe_package,
399-
_vllm_fused_moe_invoke_name(),
394+
self.invoke_fused_moe_kernel_func,
400395
self.invoke_fused_moe_quantized,
401396
og_func_cache_name="_invoke_fused_moe_kernel",
402397
):

0 commit comments

Comments
 (0)