|
58 | 58 |
|
59 | 59 | import vllm.model_executor.layers.fused_moe.layer as vllm_fused_moe_layer |
60 | 60 | import vllm.model_executor.layers.linear as vllm_linear |
61 | | -from _test_utils.torch.transformers_models import ( |
62 | | - create_tiny_llama_dir, |
63 | | - create_tiny_qwen3_moe_dir, |
64 | | -) |
| 61 | +from _test_utils.torch.transformers_models import create_tiny_llama_dir, create_tiny_qwen3_moe_dir |
65 | 62 | from vllm import LLM |
66 | 63 | from vllm.distributed import cleanup_dist_env_and_memory |
67 | 64 |
|
68 | 65 | from modelopt.torch.quantization.plugins.vllm import VllmMLAAttention |
69 | 66 |
|
70 | | - |
71 | 67 | # Sizes picked so vLLM accepts the head_size (must be supported by the chosen |
72 | 68 | # attention backend). head_size=64 with num_heads=2 is broadly supported. |
73 | 69 | _LLAMA_OVERRIDES = { |
@@ -135,33 +131,42 @@ def _forward_loop(_model): |
135 | 131 | quantizers_without_amax: list[str] = [] |
136 | 132 | enabled_quantizer_count = 0 |
137 | 133 |
|
| 134 | + def _missing(mod, name, slots): |
| 135 | + return ( |
| 136 | + f"{name}.{q}" for q in slots if not isinstance(getattr(mod, q, None), TensorQuantizer) |
| 137 | + ) |
| 138 | + |
138 | 139 | for name, mod in model.named_modules(): |
139 | 140 | if isinstance(mod, vp._VLLMParallelLinear): |
140 | 141 | kind = type(mod).__name__ |
141 | 142 | parallel_linear_counts[kind] = parallel_linear_counts.get(kind, 0) + 1 |
142 | | - for q in ("input_quantizer", "weight_quantizer", "output_quantizer"): |
143 | | - if not isinstance(getattr(mod, q, None), TensorQuantizer): |
144 | | - missing_quantizers.append(f"{name}.{q}") |
| 143 | + missing_quantizers.extend( |
| 144 | + _missing(mod, name, ("input_quantizer", "weight_quantizer", "output_quantizer")) |
| 145 | + ) |
145 | 146 | elif isinstance(mod, vp._QuantFusedMoEBase): |
146 | 147 | moe_count += 1 |
147 | | - for q in ( |
148 | | - "w13_input_quantizer", |
149 | | - "w2_input_quantizer", |
150 | | - "w13_weight_quantizer", |
151 | | - "w2_weight_quantizer", |
152 | | - ): |
153 | | - if not isinstance(getattr(mod, q, None), TensorQuantizer): |
154 | | - missing_quantizers.append(f"{name}.{q}") |
| 148 | + missing_quantizers.extend( |
| 149 | + _missing( |
| 150 | + mod, |
| 151 | + name, |
| 152 | + ( |
| 153 | + "w13_input_quantizer", |
| 154 | + "w2_input_quantizer", |
| 155 | + "w13_weight_quantizer", |
| 156 | + "w2_weight_quantizer", |
| 157 | + ), |
| 158 | + ) |
| 159 | + ) |
155 | 160 | elif vp.VllmMLAAttention is not None and isinstance(mod, vp.VllmMLAAttention): |
156 | 161 | mla_count += 1 |
157 | | - for q in ("q_bmm_quantizer", "kv_c_bmm_quantizer", "k_pe_bmm_quantizer"): |
158 | | - if not isinstance(getattr(mod, q, None), TensorQuantizer): |
159 | | - missing_quantizers.append(f"{name}.{q}") |
| 162 | + missing_quantizers.extend( |
| 163 | + _missing(mod, name, ("q_bmm_quantizer", "kv_c_bmm_quantizer", "k_pe_bmm_quantizer")) |
| 164 | + ) |
160 | 165 | elif isinstance(mod, vp._ATTENTION_TYPES): |
161 | 166 | attention_count += 1 |
162 | | - for q in ("q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"): |
163 | | - if not isinstance(getattr(mod, q, None), TensorQuantizer): |
164 | | - missing_quantizers.append(f"{name}.{q}") |
| 167 | + missing_quantizers.extend( |
| 168 | + _missing(mod, name, ("q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer")) |
| 169 | + ) |
165 | 170 |
|
166 | 171 | # Static-amax invariant: after calibration, every enabled quantizer |
167 | 172 | # must own an ``_amax`` buffer. Missing ``_amax`` means the quantizer |
@@ -386,7 +391,9 @@ def test_registry_registration(vllm_cls): |
386 | 391 | assert vllm_cls in QuantModuleRegistry |
387 | 392 |
|
388 | 393 |
|
389 | | -@pytest.mark.skipif(VllmMLAAttention is None, reason="MLAAttention not present in this vLLM version") |
| 394 | +@pytest.mark.skipif( |
| 395 | + VllmMLAAttention is None, reason="MLAAttention not present in this vLLM version" |
| 396 | +) |
390 | 397 | def test_registry_has_mla_attention(): |
391 | 398 | from modelopt.torch.quantization.nn import QuantModuleRegistry |
392 | 399 |
|
|
0 commit comments