 import torch
 from vllm.distributed.parallel_state import get_tp_group
 
-from modelopt.torch.export.plugins.vllm_fakequant_hf import is_weight_quantizer_state_key
+from modelopt.torch.export.plugins.vllm_fakequant_hf import (
+    is_weight_quantizer_state_key,
+    merge_amax_tensors_for_group,
+)
 from modelopt.torch.opt.conversion import (
     ModelLikeModule,
     ModeloptStateManager,
@@ -137,33 +140,6 @@ def _group_keys_for_vllm(
     return vllm_state_dict, merge_groups
 
 
-def merge_amax_tensors_for_vllm_group(tensors: list[torch.Tensor]) -> torch.Tensor:
-    """Combine `_amax` buffers from a merge group into a single tensor.
-
-    Used when HuggingFace module names are folded to vLLM names (e.g. q/k/v → qkv_proj).
-
-    - If every tensor has the same shape, take the element-wise maximum over the group
-      (conservative when each branch carried the same axis layout).
-    - If shapes differ (e.g. GQA q vs k), try ``torch.cat(..., dim=0)`` when valid for
-      per-channel amax; otherwise fall back to a scalar max over all elements.
-    """
-    if not tensors:
-        raise ValueError("merge_amax_tensors_for_vllm_group: expected at least one tensor")
-    if len(tensors) == 1:
-        return tensors[0]
-
-    first = tensors[0]
-    if all(t.shape == first.shape for t in tensors):
-        stacked = torch.stack([t.float() for t in tensors], dim=0)
-        return torch.amax(stacked, dim=0).to(dtype=first.dtype, device=first.device)
-
-    try:
-        return torch.cat(tensors, dim=0).to(dtype=first.dtype, device=first.device)
-    except RuntimeError:
-        flat = torch.cat([t.reshape(-1).float() for t in tensors])
-        return torch.max(flat).to(dtype=first.dtype, device=first.device)
-
-
 def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[str, Any]]) -> Any:
     """
     Merge values by taking max for amax, concatenating for others.
@@ -179,7 +155,7 @@ def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[
     for dict_key in values[0]:
         tensors = [v[dict_key] for v in values]
         if "_amax" in dict_key:
-            merged_value[dict_key] = merge_amax_tensors_for_vllm_group(tensors)
+            merged_value[dict_key] = merge_amax_tensors_for_group(tensors)
         elif "_pre_quant_scale" in dict_key:
             # _pre_quant_scale is per-input-channel: identical across q/k/v projections
             # since they share the same input. Do not concatenate; take the first value.
@@ -190,7 +166,7 @@ def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[
     else:
         # Values are tensors directly
         if "_amax" in merged_key:
-            merged_value = merge_amax_tensors_for_group(values)
+            merged_value = merge_amax_tensors_for_group(values)
         else:
             merged_value = torch.cat(values, dim=0)
     return merged_value
0 commit comments