Skip to content

Commit e00e8a6

Browse files
committed
cleanup
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent cb459b1 commit e00e8a6

3 files changed

Lines changed: 33 additions & 34 deletions

File tree

examples/vllm_serve/fakequant_worker.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,11 @@ def _fakequant_run_prolog_worker(self) -> None:
143143

144144
mtq.fold_weight(model)
145145
for name, module in model.named_modules():
146-
if is_weight_quantizer_state_key(name):
147-
assert not module.is_enabled, f"quantizer {name} is still enabled"
146+
if is_weight_quantizer_state_key(name) and module.is_enabled:
147+
raise RuntimeError(
148+
f"Weight quantizer {name!r} is still enabled after fold_weight — "
149+
"double-quantization would corrupt activations."
150+
)
148151

149152

150153
class FakeQuantWorker(BaseWorker):

examples/vllm_serve/vllm_reload_utils.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,7 @@
2323
import torch
2424
from vllm.distributed.parallel_state import get_tp_group
2525

26-
from modelopt.torch.export.plugins.vllm_fakequant_hf import (
27-
is_weight_quantizer_state_key,
28-
merge_amax_tensors_for_vllm_group,
29-
)
26+
from modelopt.torch.export.plugins.vllm_fakequant_hf import is_weight_quantizer_state_key
3027
from modelopt.torch.opt.conversion import (
3128
ModelLikeModule,
3229
ModeloptStateManager,
@@ -140,6 +137,33 @@ def _group_keys_for_vllm(
140137
return vllm_state_dict, merge_groups
141138

142139

140+
def merge_amax_tensors_for_vllm_group(tensors: list[torch.Tensor]) -> torch.Tensor:
141+
"""Combine `_amax` buffers from a merge group into a single tensor.
142+
143+
Used when HuggingFace module names are folded to vLLM names (e.g. q/k/v → qkv_proj).
144+
145+
- If every tensor has the same shape, take the element-wise maximum over the group
146+
(conservative when each branch carried the same axis layout).
147+
- If shapes differ (e.g. GQA q vs k), try ``torch.cat(..., dim=0)`` when valid for
148+
per-channel amax; otherwise fall back to a scalar max over all elements.
149+
"""
150+
if not tensors:
151+
raise ValueError("merge_amax_tensors_for_vllm_group: expected at least one tensor")
152+
if len(tensors) == 1:
153+
return tensors[0]
154+
155+
first = tensors[0]
156+
if all(t.shape == first.shape for t in tensors):
157+
stacked = torch.stack([t.float() for t in tensors], dim=0)
158+
return torch.amax(stacked, dim=0).to(dtype=first.dtype, device=first.device)
159+
160+
try:
161+
return torch.cat(tensors, dim=0).to(dtype=first.dtype, device=first.device)
162+
except RuntimeError:
163+
flat = torch.cat([t.reshape(-1).float() for t in tensors])
164+
return torch.max(flat).to(dtype=first.dtype, device=first.device)
165+
166+
143167
def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[str, Any]]) -> Any:
144168
"""
145169
Merge values by taking max for amax, concatenating for others.

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
__all__ = [
3737
"export_hf_vllm_fq_checkpoint",
3838
"is_weight_quantizer_state_key",
39-
"merge_amax_tensors_for_vllm_group",
4039
]
4140

4241
# Matches ``…weight_quantizer``, ``…weight_quantizer.0``, ``…w13_weight_quantizer.0``, etc.
@@ -51,33 +50,6 @@ def is_weight_quantizer_state_key(key: str) -> bool:
5150
return bool(_WEIGHT_QUANTIZER_STATE_KEY.search(key))
5251

5352

54-
def merge_amax_tensors_for_vllm_group(tensors: list[torch.Tensor]) -> torch.Tensor:
55-
"""Combine `_amax` buffers from a merge group into a single tensor.
56-
57-
Used when HuggingFace module names are folded to vLLM names (e.g. q/k/v → qkv_proj).
58-
59-
- If every tensor has the same shape, take the element-wise maximum over the group
60-
(conservative when each branch carried the same axis layout).
61-
- If shapes differ (e.g. GQA q vs k), try ``torch.cat(..., dim=0)`` when valid for
62-
per-channel amax; otherwise fall back to a scalar max over all elements.
63-
"""
64-
if not tensors:
65-
raise ValueError("merge_amax_tensors_for_vllm_group: expected at least one tensor")
66-
if len(tensors) == 1:
67-
return tensors[0]
68-
69-
first = tensors[0]
70-
if all(t.shape == first.shape for t in tensors):
71-
stacked = torch.stack([t.float() for t in tensors], dim=0)
72-
return torch.amax(stacked, dim=0).to(dtype=first.dtype, device=first.device)
73-
74-
try:
75-
return torch.cat(tensors, dim=0).to(dtype=first.dtype, device=first.device)
76-
except RuntimeError:
77-
flat = torch.cat([t.reshape(-1).float() for t in tensors])
78-
return torch.max(flat).to(dtype=first.dtype, device=first.device)
79-
80-
8153
def disable_rotate(quantizer: TensorQuantizer):
8254
"""Return a disabled copy of the quantizer's ``_rotate`` field, preserving its type."""
8355
if isinstance(quantizer._rotate, RotateConfig):

0 commit comments

Comments (0)