Skip to content

Commit 0ebeee6

Browse files
committed
minor
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 1cae4d7 commit 0ebeee6

1 file changed

Lines changed: 30 additions & 23 deletions

File tree

tests/gpu_vllm_fakequant/torch/quantization/test_vllm_dynamic_modules.py

Lines changed: 30 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -58,16 +58,12 @@
5858

5959
import vllm.model_executor.layers.fused_moe.layer as vllm_fused_moe_layer
6060
import vllm.model_executor.layers.linear as vllm_linear
61-
from _test_utils.torch.transformers_models import (
62-
create_tiny_llama_dir,
63-
create_tiny_qwen3_moe_dir,
64-
)
61+
from _test_utils.torch.transformers_models import create_tiny_llama_dir, create_tiny_qwen3_moe_dir
6562
from vllm import LLM
6663
from vllm.distributed import cleanup_dist_env_and_memory
6764

6865
from modelopt.torch.quantization.plugins.vllm import VllmMLAAttention
6966

70-
7167
# Sizes picked so vLLM accepts the head_size (must be supported by the chosen
7268
# attention backend). head_size=64 with num_heads=2 is broadly supported.
7369
_LLAMA_OVERRIDES = {
@@ -135,33 +131,42 @@ def _forward_loop(_model):
135131
quantizers_without_amax: list[str] = []
136132
enabled_quantizer_count = 0
137133

134+
def _missing(mod, name, slots):
    """Lazily produce ``"<name>.<slot>"`` for every slot lacking a ``TensorQuantizer``.

    Args:
        mod: Module whose attributes are inspected.
        name: Qualified module name, used as the prefix of each reported entry.
        slots: Attribute names that are expected to hold a ``TensorQuantizer``.

    Returns:
        A generator of dotted names. A slot is reported both when the attribute
        is absent (``getattr`` falls back to ``None``) and when it holds
        something that is not a ``TensorQuantizer``.
    """

    def _is_absent(slot):
        # A missing attribute resolves to None, which fails the isinstance check.
        return not isinstance(getattr(mod, slot, None), TensorQuantizer)

    return (f"{name}.{slot}" for slot in filter(_is_absent, slots))
138+
138139
for name, mod in model.named_modules():
139140
if isinstance(mod, vp._VLLMParallelLinear):
140141
kind = type(mod).__name__
141142
parallel_linear_counts[kind] = parallel_linear_counts.get(kind, 0) + 1
142-
for q in ("input_quantizer", "weight_quantizer", "output_quantizer"):
143-
if not isinstance(getattr(mod, q, None), TensorQuantizer):
144-
missing_quantizers.append(f"{name}.{q}")
143+
missing_quantizers.extend(
144+
_missing(mod, name, ("input_quantizer", "weight_quantizer", "output_quantizer"))
145+
)
145146
elif isinstance(mod, vp._QuantFusedMoEBase):
146147
moe_count += 1
147-
for q in (
148-
"w13_input_quantizer",
149-
"w2_input_quantizer",
150-
"w13_weight_quantizer",
151-
"w2_weight_quantizer",
152-
):
153-
if not isinstance(getattr(mod, q, None), TensorQuantizer):
154-
missing_quantizers.append(f"{name}.{q}")
148+
missing_quantizers.extend(
149+
_missing(
150+
mod,
151+
name,
152+
(
153+
"w13_input_quantizer",
154+
"w2_input_quantizer",
155+
"w13_weight_quantizer",
156+
"w2_weight_quantizer",
157+
),
158+
)
159+
)
155160
elif vp.VllmMLAAttention is not None and isinstance(mod, vp.VllmMLAAttention):
156161
mla_count += 1
157-
for q in ("q_bmm_quantizer", "kv_c_bmm_quantizer", "k_pe_bmm_quantizer"):
158-
if not isinstance(getattr(mod, q, None), TensorQuantizer):
159-
missing_quantizers.append(f"{name}.{q}")
162+
missing_quantizers.extend(
163+
_missing(mod, name, ("q_bmm_quantizer", "kv_c_bmm_quantizer", "k_pe_bmm_quantizer"))
164+
)
160165
elif isinstance(mod, vp._ATTENTION_TYPES):
161166
attention_count += 1
162-
for q in ("q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"):
163-
if not isinstance(getattr(mod, q, None), TensorQuantizer):
164-
missing_quantizers.append(f"{name}.{q}")
167+
missing_quantizers.extend(
168+
_missing(mod, name, ("q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"))
169+
)
165170

166171
# Static-amax invariant: after calibration, every enabled quantizer
167172
# must own an ``_amax`` buffer. Missing ``_amax`` means the quantizer
@@ -386,7 +391,9 @@ def test_registry_registration(vllm_cls):
386391
assert vllm_cls in QuantModuleRegistry
387392

388393

389-
@pytest.mark.skipif(VllmMLAAttention is None, reason="MLAAttention not present in this vLLM version")
394+
@pytest.mark.skipif(
395+
VllmMLAAttention is None, reason="MLAAttention not present in this vLLM version"
396+
)
390397
def test_registry_has_mla_attention():
391398
from modelopt.torch.quantization.nn import QuantModuleRegistry
392399

0 commit comments

Comments
 (0)