Skip to content

Commit 5a031ac

Browse files
committed
cleanup
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 96689fc commit 5a031ac

3 files changed

Lines changed: 33 additions & 34 deletions

File tree

examples/vllm_serve/fakequant_worker.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,11 @@ def _fakequant_run_prolog_worker(self) -> None:
144144

145145
mtq.fold_weight(model)
146146
for name, module in model.named_modules():
147-
if is_weight_quantizer_state_key(name):
148-
assert not module.is_enabled, f"quantizer {name} is still enabled"
147+
if is_weight_quantizer_state_key(name) and module.is_enabled:
148+
raise RuntimeError(
149+
f"Weight quantizer {name!r} is still enabled after fold_weight — "
150+
"double-quantization would corrupt activations."
151+
)
149152

150153

151154
class FakeQuantWorker(BaseWorker):

examples/vllm_serve/vllm_reload_utils.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,7 @@
2323
import torch
2424
from vllm.distributed.parallel_state import get_tp_group
2525

26-
from modelopt.torch.export.plugins.vllm_fakequant_hf import (
27-
is_weight_quantizer_state_key,
28-
merge_amax_tensors_for_vllm_group,
29-
)
26+
from modelopt.torch.export.plugins.vllm_fakequant_hf import is_weight_quantizer_state_key
3027
from modelopt.torch.opt.conversion import (
3128
ModelLikeModule,
3229
ModeloptStateManager,
@@ -163,6 +160,33 @@ def _group_keys_for_vllm(
163160
return vllm_state_dict, merge_groups
164161

165162

163+
def merge_amax_tensors_for_vllm_group(tensors: list[torch.Tensor]) -> torch.Tensor:
    """Reduce a merge group's ``_amax`` buffers to a single tensor.

    Invoked when HuggingFace module names are folded into fused vLLM names
    (e.g. q/k/v projections merging into ``qkv_proj``).

    Strategy:
      * identical shapes  -> element-wise max across the group (conservative
        when every branch carried the same axis layout);
      * mismatched shapes (e.g. GQA q vs k) -> concatenate along dim 0, which
        is valid for per-channel amax; if concat is impossible, collapse to a
        scalar max over all elements.
    """
    if not tensors:
        raise ValueError("merge_amax_tensors_for_vllm_group: expected at least one tensor")
    if len(tensors) == 1:
        return tensors[0]

    reference = tensors[0]
    if all(t.shape == reference.shape for t in tensors):
        # Promote to float for a safe reduction, then restore dtype/device.
        piled = torch.stack([t.float() for t in tensors], dim=0)
        return piled.amax(dim=0).to(dtype=reference.dtype, device=reference.device)

    try:
        return torch.cat(tensors, dim=0).to(dtype=reference.dtype, device=reference.device)
    except RuntimeError:
        # Shapes are not concat-compatible: fall back to one global scalar max.
        flattened = torch.cat([t.reshape(-1).float() for t in tensors])
        return flattened.max().to(dtype=reference.dtype, device=reference.device)
188+
189+
166190
def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[str, Any]]) -> Any:
167191
"""
168192
Merge values by taking max for amax, concatenating for others.

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
__all__ = [
3737
"export_hf_vllm_fq_checkpoint",
3838
"is_weight_quantizer_state_key",
39-
"merge_amax_tensors_for_vllm_group",
4039
]
4140

4241
# Matches ``…weight_quantizer``, ``…weight_quantizer.0``, ``…w13_weight_quantizer.0``, etc.
@@ -51,33 +50,6 @@ def is_weight_quantizer_state_key(key: str) -> bool:
5150
return bool(_WEIGHT_QUANTIZER_STATE_KEY.search(key))
5251

5352

54-
def merge_amax_tensors_for_vllm_group(tensors: list[torch.Tensor]) -> torch.Tensor:
55-
"""Combine `_amax` buffers from a merge group into a single tensor.
56-
57-
Used when HuggingFace module names are folded to vLLM names (e.g. q/k/v → qkv_proj).
58-
59-
- If every tensor has the same shape, take the element-wise maximum over the group
60-
(conservative when each branch carried the same axis layout).
61-
- If shapes differ (e.g. GQA q vs k), try ``torch.cat(..., dim=0)`` when valid for
62-
per-channel amax; otherwise fall back to a scalar max over all elements.
63-
"""
64-
if not tensors:
65-
raise ValueError("merge_amax_tensors_for_vllm_group: expected at least one tensor")
66-
if len(tensors) == 1:
67-
return tensors[0]
68-
69-
first = tensors[0]
70-
if all(t.shape == first.shape for t in tensors):
71-
stacked = torch.stack([t.float() for t in tensors], dim=0)
72-
return torch.amax(stacked, dim=0).to(dtype=first.dtype, device=first.device)
73-
74-
try:
75-
return torch.cat(tensors, dim=0).to(dtype=first.dtype, device=first.device)
76-
except RuntimeError:
77-
flat = torch.cat([t.reshape(-1).float() for t in tensors])
78-
return torch.max(flat).to(dtype=first.dtype, device=first.device)
79-
80-
8153
def disable_rotate(quantizer: TensorQuantizer):
8254
"""Return a disabled copy of the quantizer's ``_rotate`` field, preserving its type."""
8355
if isinstance(quantizer._rotate, RotateConfig):

0 commit comments

Comments (0)