Skip to content

Commit 6806f16

Browse files
committed
minor
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent e00e8a6 commit 6806f16

File tree

2 files changed

+35
-31
lines changed

2 files changed

+35
-31
lines changed

examples/vllm_serve/vllm_reload_utils.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
import torch
2424
from vllm.distributed.parallel_state import get_tp_group
2525

26-
from modelopt.torch.export.plugins.vllm_fakequant_hf import is_weight_quantizer_state_key
26+
from modelopt.torch.export.plugins.vllm_fakequant_hf import (
27+
is_weight_quantizer_state_key,
28+
merge_amax_tensors_for_group,
29+
)
2730
from modelopt.torch.opt.conversion import (
2831
ModelLikeModule,
2932
ModeloptStateManager,
@@ -137,33 +140,6 @@ def _group_keys_for_vllm(
137140
return vllm_state_dict, merge_groups
138141

139142

140-
def merge_amax_tensors_for_vllm_group(tensors: list[torch.Tensor]) -> torch.Tensor:
141-
"""Combine `_amax` buffers from a merge group into a single tensor.
142-
143-
Used when HuggingFace module names are folded to vLLM names (e.g. q/k/v → qkv_proj).
144-
145-
- If every tensor has the same shape, take the element-wise maximum over the group
146-
(conservative when each branch carried the same axis layout).
147-
- If shapes differ (e.g. GQA q vs k), try ``torch.cat(..., dim=0)`` when valid for
148-
per-channel amax; otherwise fall back to a scalar max over all elements.
149-
"""
150-
if not tensors:
151-
raise ValueError("merge_amax_tensors_for_vllm_group: expected at least one tensor")
152-
if len(tensors) == 1:
153-
return tensors[0]
154-
155-
first = tensors[0]
156-
if all(t.shape == first.shape for t in tensors):
157-
stacked = torch.stack([t.float() for t in tensors], dim=0)
158-
return torch.amax(stacked, dim=0).to(dtype=first.dtype, device=first.device)
159-
160-
try:
161-
return torch.cat(tensors, dim=0).to(dtype=first.dtype, device=first.device)
162-
except RuntimeError:
163-
flat = torch.cat([t.reshape(-1).float() for t in tensors])
164-
return torch.max(flat).to(dtype=first.dtype, device=first.device)
165-
166-
167143
def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[str, Any]]) -> Any:
168144
"""
169145
Merge values by taking max for amax, concatenating for others.
@@ -179,7 +155,7 @@ def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[
179155
for dict_key in values[0]:
180156
tensors = [v[dict_key] for v in values]
181157
if "_amax" in dict_key:
182-
merged_value[dict_key] = merge_amax_tensors_for_vllm_group(tensors)
158+
merged_value[dict_key] = merge_amax_tensors_for_group(tensors)
183159
elif "_pre_quant_scale" in dict_key:
184160
# _pre_quant_scale is per-input-channel: identical across q/k/v projections
185161
# since they share the same input. Do not concatenate; take the first value.
@@ -190,7 +166,7 @@ def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[
190166
else:
191167
# Values are tensors directly
192168
if "_amax" in merged_key:
193-
merged_value = merge_amax_tensors_for_vllm_group(values)
169+
merged_value = merge_amax_tensors_for_group(values)
194170
else:
195171
merged_value = torch.cat(values, dim=0)
196172
return merged_value

modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
__all__ = [
3737
"export_hf_vllm_fq_checkpoint",
3838
"is_weight_quantizer_state_key",
39+
"merge_amax_tensors_for_group",
3940
]
4041

4142
# Matches ``…weight_quantizer``, ``…weight_quantizer.0``, ``…w13_weight_quantizer.0``, etc.
@@ -90,6 +91,33 @@ def requant_weights_for_export(
9091
return quantizer_copy(w.float()).to(w.dtype)
9192

9293

94+
def merge_amax_tensors_for_group(tensors: list[torch.Tensor]) -> torch.Tensor:
    """Fold a merge group's `_amax` buffers into one tensor.

    Applied when HuggingFace module names collapse into a single vLLM module
    (e.g. q/k/v projections becoming ``qkv_proj``).

    Strategy, in order:
    - identical shapes: element-wise maximum across the group (conservative
      when every branch carries the same axis layout);
    - differing shapes (e.g. GQA q vs k): attempt ``torch.cat(..., dim=0)``,
      which is valid for per-channel amax;
    - otherwise: collapse everything to a single scalar maximum.

    Raises:
        ValueError: if ``tensors`` is empty.
    """
    if not tensors:
        raise ValueError("merge_amax_tensors_for_group: expected at least one tensor")
    if len(tensors) == 1:
        return tensors[0]

    reference = tensors[0]
    same_shape = all(t.shape == reference.shape for t in tensors)

    if same_shape:
        # Element-wise max over a stacked float copy; restore the first
        # tensor's dtype/device on the way out.
        merged = torch.stack([t.float() for t in tensors], dim=0).amax(dim=0)
        return merged.to(dtype=reference.dtype, device=reference.device)

    try:
        return torch.cat(tensors, dim=0).to(dtype=reference.dtype, device=reference.device)
    except RuntimeError:
        # Shapes are not concatenable along dim 0 — fall back to a global
        # scalar max over all elements of every tensor.
        flattened = torch.cat([t.reshape(-1).float() for t in tensors])
        return flattened.max().to(dtype=reference.dtype, device=reference.device)
119+
120+
93121
def _resmooth_experts_for_export(
94122
model: nn.Module,
95123
state_dict: dict[str, Any],
@@ -147,7 +175,7 @@ def _resmooth_experts_for_export(
147175
if iq0.is_enabled:
148176
amaxes = [e.input_quantizer.amax for e in experts]
149177
if all(a is not None for a in amaxes):
150-
max_in_amax = torch.stack(amaxes).max()
178+
max_in_amax = merge_amax_tensors_for_group(amaxes)
151179

152180
avg_out = avg_pqs.detach().clone()
153181
for ex in experts:

0 commit comments

Comments
 (0)