2323import torch
2424from vllm .distributed .parallel_state import get_tp_group
2525
26- from modelopt .torch .export .plugins .vllm_fakequant_hf import is_weight_quantizer_state_key
26+ from modelopt .torch .export .plugins .vllm_fakequant_hf import (
27+ is_weight_quantizer_state_key ,
28+ merge_amax_tensors_for_group ,
29+ )
2730from modelopt .torch .opt .conversion import (
2831 ModelLikeModule ,
2932 ModeloptStateManager ,
@@ -160,33 +163,6 @@ def _group_keys_for_vllm(
160163 return vllm_state_dict , merge_groups
161164
162165
def merge_amax_tensors_for_vllm_group(tensors: list[torch.Tensor]) -> torch.Tensor:
    """Fold the ``_amax`` buffers of a merge group into one tensor.

    Needed when several HuggingFace submodules collapse into a single vLLM
    module (e.g. q/k/v projections folded into ``qkv_proj``).

    Strategy, in order of preference:

    1. All shapes identical: element-wise maximum across the group (safe when
       every branch used the same axis layout).
    2. Shapes differ (e.g. GQA q vs. k): concatenate along dim 0, which is the
       right thing for per-output-channel amax.
    3. Concatenation impossible: collapse everything to a single scalar max.
    """
    if not tensors:
        raise ValueError("merge_amax_tensors_for_vllm_group: expected at least one tensor")
    if len(tensors) == 1:
        return tensors[0]

    reference = tensors[0]
    target = {"dtype": reference.dtype, "device": reference.device}

    if all(t.shape == reference.shape for t in tensors[1:]):
        # Element-wise max, computed in float32 for safety, then cast back.
        as_float = [t.float() for t in tensors]
        return torch.stack(as_float, dim=0).amax(dim=0).to(**target)

    try:
        return torch.cat(tensors, dim=0).to(**target)
    except RuntimeError:
        # Ragged shapes that cannot be concatenated: conservative scalar max.
        flattened = torch.cat([t.reshape(-1).float() for t in tensors])
        return flattened.max().to(**target)
188-
189-
190166def _merge_values_by_max_or_concat (merged_key : str , key_value_pairs : list [tuple [str , Any ]]) -> Any :
191167 """
192168 Merge values by taking max for amax, concatenating for others.
@@ -202,7 +178,7 @@ def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[
202178 for dict_key in values [0 ]:
203179 tensors = [v [dict_key ] for v in values ]
204180 if "_amax" in dict_key :
205- merged_value [dict_key ] = merge_amax_tensors_for_vllm_group (tensors )
181+ merged_value [dict_key ] = merge_amax_tensors_for_group (tensors )
206182 elif "_pre_quant_scale" in dict_key :
207183 # _pre_quant_scale is per-input-channel: identical across q/k/v projections
208184 # since they share the same input. Do not concatenate; take the first value.
@@ -213,7 +189,7 @@ def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[
213189 else :
214190 # Values are tensors directly
215191 if "_amax" in merged_key :
216- merged_value = merge_amax_tensors_for_vllm_group (values )
192+ merged_value = merge_amax_tensors_for_group (values )
217193 else :
218194 merged_value = torch .cat (values , dim = 0 )
219195 return merged_value
0 commit comments