1414# limitations under the License.
1515"""Export HuggingFace model to vLLM fakequant checkpoint."""
1616
17+ import contextlib
1718import copy
1819import re
1920from pathlib import Path
3233
3334from ..layer_utils import get_experts_list , is_moe
3435from ..quant_utils import get_quantization_format
36+ from ..unified_export_hf import collect_shared_input_modules
3537
3638__all__ = [
3739 "export_hf_vllm_fq_checkpoint" ,
@@ -122,12 +124,22 @@ def _resmooth_experts_for_export(
122124 model : nn .Module ,
123125 state_dict : dict [str , Any ],
124126) -> tuple [dict [str , tuple [torch .Tensor , torch .Tensor | None ]], set [str ]]:
125- """Average pqs and unify input amax across MoE experts when AWQ smoothing applies (no-op otherwise).
126-
127- Adjusts expert weights in ``state_dict`` as ``W' = W * old_pqs / avg_pqs`` and returns
128- input-quantizer overrides for ``modelopt_state_weights``. **Does nothing** for weight-only
129- MoE (no ``pre_quant_scale`` on experts) or unsupported MoE layouts — same as skipping the
130- MoE branch in :func:`requantize_resmooth_fused_llm_layers`.
127+ """Average pqs and unify input amax for all groups vLLM collapses to one input quantizer.
128+
129+ Covers two cases that are structurally identical — a set of linears sharing the
130+ same input that vLLM fuses behind a single input quantizer:
131+
132+ * **MoE experts** — vLLM uses one input quantizer per expert group.
133+ * **Dense GQA attention** — vLLM fuses q/k/v into ``qkv_proj`` with one input
134+ quantizer. AWQ calibration gives each projection a different pqs because they
135+ share the same input but have different weight magnitudes; using only q's pqs
136+ for k and v at inference corrupts those activations.
137+
138+ For each group, adjusts weights in ``state_dict`` as ``W' = W * old_pqs / avg_pqs``
139+ and returns input-quantizer overrides so every member exports the same ``avg_pqs``.
140+ Only runs when input quantizers are **enabled** (``nvfp4_awq_wa`` style); for
141+ weight-only AWQ the input quantizer is disabled and pqs is folded into each
142+ projection's own weight independently.
131143 """
132144 qfmt = get_quantization_format (model )
133145 if qfmt is None or "awq" not in qfmt .lower ():
@@ -137,61 +149,70 @@ def _resmooth_experts_for_export(
137149 id_to_name : dict [int , str ] = {id (m ): n for n , m in model .named_modules ()}
138150 out : dict [str , tuple [torch .Tensor , torch .Tensor | None ]] = {}
139151 requant_weights : set [str ] = set ()
140- for _ , module in model .named_modules ():
141- if not is_moe (module ):
142- continue
143- try :
144- expert_groups = get_experts_list (module , model_type )
145- except NotImplementedError :
146- continue
147-
148- for experts in expert_groups :
149- if not experts :
152+
153+ def _process_group (modules : list [nn .Module ]) -> None :
154+ pqs_list = _collect_expert_pre_quant_scales (modules )
155+ if pqs_list is None :
156+ return
157+
158+ avg_pqs = torch .stack (pqs_list ).mean (0 )
159+ avg_pqs = avg_pqs .clamp (min = torch .finfo (torch .float32 ).tiny )
160+
161+ for m in modules :
162+ nm = id_to_name .get (id (m ))
163+ if nm is None or f"{ nm } .weight" not in state_dict :
150164 continue
151- pre_quant_scales_list = _collect_expert_pre_quant_scales (experts )
152- if pre_quant_scales_list is None :
165+ old_pqs = m .input_quantizer ._pre_quant_scale
166+ avg_pqs_dev = avg_pqs .to (device = old_pqs .device , dtype = old_pqs .dtype )
167+ if torch .equal (old_pqs , avg_pqs_dev ):
153168 continue
169+ weight = state_dict [f"{ nm } .weight" ]
170+ ratio = old_pqs .to (dtype = torch .float32 , device = weight .device ) / avg_pqs_dev .to (
171+ dtype = torch .float32 , device = weight .device
172+ )
173+ state_dict [f"{ nm } .weight" ] = (weight .to (torch .float32 ) * ratio ).to (weight .dtype )
174+ requant_weights .add (f"{ nm } .weight" )
175+
176+ synced_amax : torch .Tensor | None = None
177+ amaxes = [m .input_quantizer .amax for m in modules ]
178+ if all (a is not None for a in amaxes ):
179+ synced_amax = merge_amax_tensors_for_group (amaxes )
180+
181+ avg_pqs_out = avg_pqs .detach ().clone ()
182+ for m in modules :
183+ nm = id_to_name .get (id (m ))
184+ if nm is None :
185+ continue
186+ out [get_unwrapped_name (f"{ nm } .input_quantizer" , model )] = (avg_pqs_out , synced_amax )
154187
155- avg_pre_quant_scale = torch .stack (pre_quant_scales_list ).mean (0 )
156- # Guard against degenerate calibration where a channel's scale is zero:
157- # zero avg_pqs would produce inf ratio and corrupt the exported weight.
158- avg_pre_quant_scale = avg_pre_quant_scale .clamp (min = torch .finfo (torch .float32 ).tiny )
159-
160- for ex in experts :
161- nm = id_to_name .get (id (ex ))
162- if nm is None or f"{ nm } .weight" not in state_dict :
163- continue
164- old_pre_quant_scale = ex .input_quantizer ._pre_quant_scale
165- avg_pre_quant_scale = avg_pre_quant_scale .to (
166- device = old_pre_quant_scale .device , dtype = old_pre_quant_scale .dtype
167- )
168- if torch .equal (old_pre_quant_scale , avg_pre_quant_scale ):
169- continue
170- weight = state_dict [f"{ nm } .weight" ]
171- updated_weight = (
172- weight .to (torch .float32 )
173- * old_pre_quant_scale .to (dtype = torch .float32 , device = weight .device )
174- / avg_pre_quant_scale .to (dtype = torch .float32 , device = weight .device )
175- ).to (weight .dtype )
176- state_dict [f"{ nm } .weight" ] = updated_weight
177- requant_weights .add (f"{ nm } .weight" )
178-
179- iq0 = experts [0 ].input_quantizer
180- synced_amax : torch .Tensor | None = None
181- if iq0 .is_enabled :
182- amaxes = [e .input_quantizer .amax for e in experts ]
183- if all (a is not None for a in amaxes ):
184- synced_amax = merge_amax_tensors_for_group (amaxes )
185-
186- avg_pre_quant_scale_output = avg_pre_quant_scale .detach ().clone ()
187- for ex in experts :
188- nm = id_to_name .get (id (ex ))
189- if nm is None :
190- continue
191- out [get_unwrapped_name (f"{ nm } .input_quantizer" , model )] = (
192- avg_pre_quant_scale_output ,
193- synced_amax ,
194- )
188+ # MoE expert groups — must be enumerated by name because MoE routing sends
189+ # different tokens to each expert, so forward hooks cannot detect them as
190+ # sharing the same input tensor.
191+ for _ , module in model .named_modules ():
192+ if is_moe (module ):
193+ try :
194+ expert_groups = get_experts_list (module , model_type )
195+ except NotImplementedError :
196+ pass
197+ else :
198+ for experts in expert_groups :
199+ if experts :
200+ _process_group (experts )
201+
202+ # Dense shared-input groups (e.g. q/k/v in GQA attention) — detected via forward
203+ # hooks so any architecture is covered regardless of projection attribute names.
204+
205+ dev = next (model .parameters ()).device
206+
207+ def _dummy_forward () -> None :
 208 # Partial forward is OK: hooks still record every layer reached before a failure
 209 # (e.g. multimodal models whose forward needs more than dummy token IDs).
209+ with contextlib .suppress (Exception ):
210+ model (torch .ones ([1 , 2 ], dtype = torch .long , device = dev ))
211+
212+ input_to_linear , _ = collect_shared_input_modules (model , _dummy_forward )
213+ for modules in input_to_linear .values ():
214+ if len (modules ) > 1 :
215+ _process_group (modules )
195216
196217 return out , requant_weights
197218
@@ -227,10 +248,10 @@ def export_hf_vllm_fq_checkpoint(
227248 # to the corresponding weight tensor in the copy.
228249 state_dict = model .state_dict ()
229250
230- # Non-mutating MoE expert resmooth: average pqs and adjust state_dict weights.
231- # Must run before the fakequant loop so that the adjusted weights are fakequanted
232- # with the correct per-block scales.
233- expert_pqs_overrides , requant_weights = _resmooth_experts_for_export (model , state_dict )
251+ # Non-mutating resmooth: average pqs across all shared-input groups (dense GQA q/k/v
252+ # and MoE experts) and adjust state_dict weights. Must run before the fakequant loop
253+ # so adjusted weights are fakequanted with the correct per-block scales.
254+ pqs_overrides , requant_weights = _resmooth_experts_for_export (model , state_dict )
234255
235256 fakequant_weights : set [str ] = set ()
236257 # Input quantizer keys whose _pre_quant_scale was folded into the weight above.
@@ -317,9 +338,9 @@ def export_hf_vllm_fq_checkpoint(
317338 qstate_val ["_pre_quant_scale" ]
318339 )
319340
320- # Patch expert input quantizers with averaged pqs and unified amax so that
321- # vLLM's single per-group input quantizer sees consistent values across experts.
322- for iq_key , (avg_pqs , max_input_amax ) in expert_pqs_overrides .items ():
341+ # Patch input quantizers with averaged pqs and unified amax so that vLLM's single
 342 # per-group input quantizer sees consistent values (covers both dense qkv groups and MoE experts).
343+ for iq_key , (avg_pqs , max_input_amax ) in pqs_overrides .items ():
323344 if iq_key in quantizer_state_dict :
324345 qstate_val = quantizer_state_dict [iq_key ]
325346 if isinstance (qstate_val , dict ):
0 commit comments