Skip to content

Commit 900cc32

Browse files
committed
minor
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 5a031ac commit 900cc32

2 files changed

Lines changed: 35 additions & 31 deletions

File tree

examples/vllm_serve/vllm_reload_utils.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
import torch
2424
from vllm.distributed.parallel_state import get_tp_group
2525

26-
from modelopt.torch.export.plugins.vllm_fakequant_hf import is_weight_quantizer_state_key
26+
from modelopt.torch.export.plugins.vllm_fakequant_hf import (
27+
is_weight_quantizer_state_key,
28+
merge_amax_tensors_for_group,
29+
)
2730
from modelopt.torch.opt.conversion import (
2831
ModelLikeModule,
2932
ModeloptStateManager,
@@ -160,33 +163,6 @@ def _group_keys_for_vllm(
160163
return vllm_state_dict, merge_groups
161164

162165

163-
def merge_amax_tensors_for_vllm_group(tensors: list[torch.Tensor]) -> torch.Tensor:
164-
"""Combine `_amax` buffers from a merge group into a single tensor.
165-
166-
Used when HuggingFace module names are folded to vLLM names (e.g. q/k/v → qkv_proj).
167-
168-
- If every tensor has the same shape, take the element-wise maximum over the group
169-
(conservative when each branch carried the same axis layout).
170-
- If shapes differ (e.g. GQA q vs k), try ``torch.cat(..., dim=0)`` when valid for
171-
per-channel amax; otherwise fall back to a scalar max over all elements.
172-
"""
173-
if not tensors:
174-
raise ValueError("merge_amax_tensors_for_vllm_group: expected at least one tensor")
175-
if len(tensors) == 1:
176-
return tensors[0]
177-
178-
first = tensors[0]
179-
if all(t.shape == first.shape for t in tensors):
180-
stacked = torch.stack([t.float() for t in tensors], dim=0)
181-
return torch.amax(stacked, dim=0).to(dtype=first.dtype, device=first.device)
182-
183-
try:
184-
return torch.cat(tensors, dim=0).to(dtype=first.dtype, device=first.device)
185-
except RuntimeError:
186-
flat = torch.cat([t.reshape(-1).float() for t in tensors])
187-
return torch.max(flat).to(dtype=first.dtype, device=first.device)
188-
189-
190166
def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[str, Any]]) -> Any:
191167
"""
192168
Merge values by taking max for amax, concatenating for others.
@@ -202,7 +178,7 @@ def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[
202178
for dict_key in values[0]:
203179
tensors = [v[dict_key] for v in values]
204180
if "_amax" in dict_key:
205-
merged_value[dict_key] = merge_amax_tensors_for_vllm_group(tensors)
181+
merged_value[dict_key] = merge_amax_tensors_for_group(tensors)
206182
elif "_pre_quant_scale" in dict_key:
207183
# _pre_quant_scale is per-input-channel: identical across q/k/v projections
208184
# since they share the same input. Do not concatenate; take the first value.
@@ -213,7 +189,7 @@ def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[
213189
else:
214190
# Values are tensors directly
215191
if "_amax" in merged_key:
216-
merged_value = merge_amax_tensors_for_vllm_group(values)
192+
merged_value = merge_amax_tensors_for_group(values)
217193
else:
218194
merged_value = torch.cat(values, dim=0)
219195
return merged_value

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
__all__ = [
3737
"export_hf_vllm_fq_checkpoint",
3838
"is_weight_quantizer_state_key",
39+
"merge_amax_tensors_for_group",
3940
]
4041

4142
# Matches ``…weight_quantizer``, ``…weight_quantizer.0``, ``…w13_weight_quantizer.0``, etc.
@@ -90,6 +91,33 @@ def requant_weights_for_export(
9091
return quantizer_copy(w.float()).to(w.dtype)
9192

9293

94+
def merge_amax_tensors_for_group(tensors: list[torch.Tensor]) -> torch.Tensor:
95+
"""Combine `_amax` buffers from a merge group into a single tensor.
96+
97+
Used when HuggingFace module names are folded to vLLM names (e.g. q/k/v → qkv_proj).
98+
99+
- If every tensor has the same shape, take the element-wise maximum over the group
100+
(conservative when each branch carried the same axis layout).
101+
- If shapes differ (e.g. GQA q vs k), try ``torch.cat(..., dim=0)`` when valid for
102+
per-channel amax; otherwise fall back to a scalar max over all elements.
103+
"""
104+
if not tensors:
105+
raise ValueError("merge_amax_tensors_for_group: expected at least one tensor")
106+
if len(tensors) == 1:
107+
return tensors[0]
108+
109+
first = tensors[0]
110+
if all(t.shape == first.shape for t in tensors):
111+
stacked = torch.stack([t.float() for t in tensors], dim=0)
112+
return torch.amax(stacked, dim=0).to(dtype=first.dtype, device=first.device)
113+
114+
try:
115+
return torch.cat(tensors, dim=0).to(dtype=first.dtype, device=first.device)
116+
except RuntimeError:
117+
flat = torch.cat([t.reshape(-1).float() for t in tensors])
118+
return torch.max(flat).to(dtype=first.dtype, device=first.device)
119+
120+
93121
def _resmooth_experts_for_export(
94122
model: nn.Module,
95123
state_dict: dict[str, Any],
@@ -147,7 +175,7 @@ def _resmooth_experts_for_export(
147175
if iq0.is_enabled:
148176
amaxes = [e.input_quantizer.amax for e in experts]
149177
if all(a is not None for a in amaxes):
150-
max_in_amax = torch.stack(amaxes).max()
178+
max_in_amax = merge_amax_tensors_for_group(amaxes)
151179

152180
avg_out = avg_pqs.detach().clone()
153181
for ex in experts:

0 commit comments

Comments (0)