Skip to content

Commit 037f9a9

Browse files
committed
fixing issues
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 5ecfaca commit 037f9a9

File tree

1 file changed

+19
-18
lines changed
  • modelopt/torch/quantization/plugins

1 file changed

+19
-18
lines changed

modelopt/torch/quantization/plugins/vllm.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,6 @@
8787

8888
vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe")
8989

90-
_vllm_fused_moe_invoke_name_cache: str | None = None
91-
92-
93-
def _vllm_fused_moe_invoke_name() -> str:
94-
"""Return the vLLM public fused_moe entrypoint (renamed across versions)."""
95-
global _vllm_fused_moe_invoke_name_cache
96-
if _vllm_fused_moe_invoke_name_cache is not None:
97-
return _vllm_fused_moe_invoke_name_cache
98-
for name in ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"):
99-
if hasattr(vllm_fused_moe_package, name):
100-
_vllm_fused_moe_invoke_name_cache = name
101-
return name
102-
raise ValueError("fused_moe_kernel is not found")
103-
10490

10591
@contextmanager
10692
def disable_compilation(model):
@@ -349,6 +335,15 @@ def _setup(self):
349335
)
350336
self.parallel_state = create_parallel_state()
351337

338+
if getattr(self, "invoke_fused_moe_kernel_func", None) is None: # pragma: no cover
339+
for name in ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"):
340+
if hasattr(vllm_fused_moe_package, name):
341+
self.invoke_fused_moe_kernel_func = name
342+
break
343+
assert ( # pragma: no cover
344+
getattr(self, "invoke_fused_moe_kernel_func", None) is not None
345+
), "fused_moe_kernel is not found"
346+
352347
def invoke_fused_moe_quantized(
353348
self,
354349
A: torch.Tensor, # noqa: N803
@@ -360,11 +355,14 @@ def invoke_fused_moe_quantized(
360355
if B is self.w13_weight:
361356
# First layer of expert
362357
A = self.w13_input_quantizer(A) # noqa: N806
363-
if self.w13_weight_quantizer.is_enabled:
358+
if self.w13_weight_quantizer.is_enabled: # pragma: no cover
364359
original_weight, self.w13_weight = (
365360
self.w13_weight,
366361
self.w13_weight_quantizer(self.w13_weight),
367362
)
363+
# If the weight quantizer has not been folded into the weight yet (e.g. in vllm_serve_fakequant),
364+
# rebind B to the freshly quantized self.w13_weight; B still refers to the original tensor here.
365+
B = self.w13_weight # noqa: N806
368366
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
369367
self.w13_weight = original_weight
370368
else:
@@ -373,11 +371,14 @@ def invoke_fused_moe_quantized(
373371
C[:] = self.w13_output_quantizer(C)
374372
elif B is self.w2_weight:
375373
A = self.w2_input_quantizer(A) # noqa: N806
376-
if self.w2_weight_quantizer.is_enabled:
374+
if self.w2_weight_quantizer.is_enabled: # pragma: no cover
377375
original_weight, self.w2_weight = (
378376
self.w2_weight,
379377
self.w2_weight_quantizer(self.w2_weight),
380378
)
379+
# If the weight quantizer has not been folded into the weight yet (e.g. in vllm_serve_fakequant),
380+
# rebind B to the freshly quantized self.w2_weight; B still refers to the original tensor here.
381+
B = self.w2_weight # noqa: N806
381382
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
382383
self.w2_weight = original_weight
383384
else:
@@ -388,9 +389,9 @@ def invoke_fused_moe_quantized(
388389
raise ValueError("Cannot determine first or second layer of expert")
389390

390391
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
391-
with replace_function(
392+
with replace_function( # pragma: no cover
392393
vllm_fused_moe_package,
393-
_vllm_fused_moe_invoke_name(),
394+
self.invoke_fused_moe_kernel_func,
394395
self.invoke_fused_moe_quantized,
395396
og_func_cache_name="_invoke_fused_moe_kernel",
396397
):

0 commit comments

Comments
 (0)