|
15 | 15 | """Export HuggingFace model to vLLM fakequant checkpoint.""" |
16 | 16 |
|
17 | 17 | import copy |
| 18 | +import re |
18 | 19 | from pathlib import Path |
19 | 20 | from typing import Any |
20 | 21 |
|
|
29 | 30 | from modelopt.torch.quantization.utils import get_quantizer_state_dict |
30 | 31 | from modelopt.torch.utils import get_unwrapped_name, safe_save |
31 | 32 |
|
32 | | -from ..hf_vllm_quantizer_merge import is_weight_quantizer_state_key |
33 | 33 | from ..layer_utils import get_experts_list, is_moe |
34 | 34 | from ..quant_utils import get_quantization_format |
35 | 35 |
|
36 | | -__all__ = ["export_hf_vllm_fq_checkpoint", "is_weight_quantizer_state_key"] |
# Public surface of this module; sorted alphabetically.
__all__ = [
    "export_hf_vllm_fq_checkpoint",
    "is_weight_quantizer_state_key",
    "merge_amax_tensors_for_vllm_group",
]
| 41 | + |
# Recognizes state keys whose final path component is a weight quantizer:
# a bare ``weight_quantizer``, a prefixed variant such as ``w13_weight_quantizer``,
# and SequentialQuantizer children like ``weight_quantizer.0`` (any depth of
# trailing numeric components).
_WEIGHT_QUANTIZER_STATE_KEY = re.compile(r"(?:^|\.)(?:\w+_)?weight_quantizer(?:\.\d+)*$")


def is_weight_quantizer_state_key(key: str) -> bool:
    """Check whether *key* addresses a weight quantizer's state.

    Accepts plain names (``weight_quantizer``), prefixed names
    (``w13_weight_quantizer``), and SequentialQuantizer sub-entries
    (``weight_quantizer.0``), at any dotted-path depth.
    """
    return _WEIGHT_QUANTIZER_STATE_KEY.search(key) is not None
| 52 | + |
| 53 | + |
def merge_amax_tensors_for_vllm_group(tensors: list[torch.Tensor]) -> torch.Tensor:
    """Fold a merge group's ``_amax`` buffers into one tensor.

    Needed when HuggingFace module names collapse onto a single vLLM module
    (e.g. q/k/v projections folding into ``qkv_proj``).

    Strategy, in order:
      1. All shapes equal -> element-wise maximum across the group (safe when
         each branch used the same axis layout).
      2. Shapes differ (e.g. GQA q vs. k) -> attempt ``torch.cat`` along dim 0,
         which is valid for per-channel amax.
      3. Concatenation impossible -> reduce to a single scalar maximum over
         every element of every tensor.

    The result is cast back to the first tensor's dtype and device.

    Raises:
        ValueError: if *tensors* is empty.
    """
    if not tensors:
        raise ValueError("merge_amax_tensors_for_vllm_group: expected at least one tensor")
    if len(tensors) == 1:
        return tensors[0]

    ref = tensors[0]
    # Case 1: identical layouts -> conservative element-wise max.
    if all(t.shape == ref.shape for t in tensors[1:]):
        merged = torch.stack([t.float() for t in tensors]).amax(dim=0)
        return merged.to(dtype=ref.dtype, device=ref.device)

    # Case 2: heterogeneous shapes -> per-channel concat when torch allows it.
    try:
        concatenated = torch.cat(tensors, dim=0)
    except RuntimeError:
        # Case 3: incompatible ranks/shapes -> global scalar max.
        flattened = torch.cat([t.float().reshape(-1) for t in tensors])
        return flattened.max().to(dtype=ref.dtype, device=ref.device)
    return concatenated.to(dtype=ref.dtype, device=ref.device)
37 | 79 |
|
38 | 80 |
|
39 | 81 | def disable_rotate(quantizer: TensorQuantizer): |
@@ -217,7 +259,7 @@ def export_hf_vllm_fq_checkpoint( |
217 | 259 | if ( |
218 | 260 | hasattr(inp_q, "_pre_quant_scale") |
219 | 261 | and inp_q._pre_quant_scale is not None |
220 | | - and inp_q._disabled |
| 262 | + and not inp_q.is_enabled |
221 | 263 | ): |
222 | 264 | scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device) |
223 | 265 | w_quant = (w_quant * scale[None, :]).to(w_quant.dtype) |
|
0 commit comments