Skip to content

Commit 0acc835

Browse files
committed
minor
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent d3bb642 commit 0acc835

File tree

3 files changed

+47
-5
lines changed

3 files changed

+47
-5
lines changed

examples/vllm_serve/fakequant_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
)
3232

3333
import modelopt.torch.quantization as mtq
34-
from modelopt.torch.export.hf_vllm_quantizer_merge import is_weight_quantizer_state_key
34+
from modelopt.torch.export.plugins.vllm_fakequant_hf import is_weight_quantizer_state_key
3535
from modelopt.torch.quantization.plugins.vllm import (
3636
disable_compilation,
3737
post_restore_vllm_parallel_linears,

examples/vllm_serve/vllm_reload_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import torch
2424
from vllm.distributed.parallel_state import get_tp_group
2525

26-
from modelopt.torch.export.hf_vllm_quantizer_merge import (
26+
from modelopt.torch.export.plugins.vllm_fakequant_hf import (
2727
is_weight_quantizer_state_key,
2828
merge_amax_tensors_for_vllm_group,
2929
)

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Export HuggingFace model to vLLM fakequant checkpoint."""
1616

1717
import copy
18+
import re
1819
from pathlib import Path
1920
from typing import Any
2021

@@ -29,11 +30,52 @@
2930
from modelopt.torch.quantization.utils import get_quantizer_state_dict
3031
from modelopt.torch.utils import get_unwrapped_name, safe_save
3132

32-
from ..hf_vllm_quantizer_merge import is_weight_quantizer_state_key
3333
from ..layer_utils import get_experts_list, is_moe
3434
from ..quant_utils import get_quantization_format
3535

36-
__all__ = ["export_hf_vllm_fq_checkpoint", "is_weight_quantizer_state_key"]
36+
__all__ = [
37+
"export_hf_vllm_fq_checkpoint",
38+
"is_weight_quantizer_state_key",
39+
"merge_amax_tensors_for_vllm_group",
40+
]
41+
42+
# Suffix pattern for weight-quantizer entries: plain (``weight_quantizer``),
# prefixed (``w13_weight_quantizer``), or SequentialQuantizer-indexed
# (``weight_quantizer.0``, ``weight_quantizer.0.1``), anchored at a dot
# boundary or the start of the key and at the end of the key.
_WEIGHT_QUANTIZER_STATE_KEY = re.compile(r"(?:^|\.)(?:\w+_)?weight_quantizer(?:\.\d+)*$")


def is_weight_quantizer_state_key(key: str) -> bool:
    """Tell whether *key* names a weight quantizer's state.

    Recognizes plain (``weight_quantizer``), prefixed (``w13_weight_quantizer``)
    and SequentialQuantizer-indexed (``weight_quantizer.0``) forms.
    """
    return _WEIGHT_QUANTIZER_STATE_KEY.search(key) is not None
52+
53+
54+
def merge_amax_tensors_for_vllm_group(tensors: list[torch.Tensor]) -> torch.Tensor:
    """Combine `_amax` buffers from a merge group into a single tensor.

    Used when HuggingFace module names are folded to vLLM names (e.g. q/k/v → qkv_proj).

    - If every tensor has the same shape, take the element-wise maximum over the group
      (conservative when each branch carried the same axis layout).
    - If shapes differ (e.g. GQA q vs k), try ``torch.cat(..., dim=0)`` when valid for
      per-channel amax; otherwise fall back to a scalar max over all elements.
    """
    if not tensors:
        raise ValueError("merge_amax_tensors_for_vllm_group: expected at least one tensor")
    if len(tensors) == 1:
        # Nothing to merge — hand back the sole buffer untouched.
        return tensors[0]

    ref = tensors[0]
    # Homogeneous shapes: element-wise max in float32, cast back to the
    # reference dtype/device.
    if all(t.shape == ref.shape for t in tensors):
        group = torch.stack([t.float() for t in tensors], dim=0)
        return group.amax(dim=0).to(dtype=ref.dtype, device=ref.device)

    try:
        # Heterogeneous shapes that still agree past dim 0 (per-channel amax):
        # concatenate along the channel dimension.
        return torch.cat(tensors, dim=0).to(dtype=ref.dtype, device=ref.device)
    except RuntimeError:
        # Shapes incompatible for concatenation — collapse everything to a
        # single conservative scalar maximum.
        flattened = torch.cat([t.reshape(-1).float() for t in tensors])
        return flattened.max().to(dtype=ref.dtype, device=ref.device)
3779

3880

3981
def disable_rotate(quantizer: TensorQuantizer):
@@ -217,7 +259,7 @@ def export_hf_vllm_fq_checkpoint(
217259
if (
218260
hasattr(inp_q, "_pre_quant_scale")
219261
and inp_q._pre_quant_scale is not None
220-
and inp_q._disabled
262+
and not inp_q.is_enabled
221263
):
222264
scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
223265
w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)

0 commit comments

Comments
 (0)