1717import copy
1818import logging
1919import re
20+ import warnings
2021from collections .abc import Callable
2122from pathlib import Path
2223from typing import Any
@@ -121,11 +122,12 @@ def _check_all_weight_quantizers_disabled(model: nn.Module) -> None:
121122 if attr_name .endswith ("weight_quantizer" ) and isinstance (
122123 quantizer , (TensorQuantizer , SequentialQuantizer )
123124 ):
124- assert not quantizer .is_enabled , (
125- f"vLLM fakequant export: { attr_name !r} must be disabled before saving "
126- f"quantizer_state (weights already folded). "
127- f"See filter_modelopt_state_quantizer_state_for_model in vllm_reload_utils."
128- )
125+ if quantizer .is_enabled :
126+ raise RuntimeError (
127+ f"vLLM fakequant export: { attr_name !r} must be disabled before saving "
128+ f"quantizer_state (weights already folded). "
129+ f"See filter_modelopt_state_quantizer_state_for_model in vllm_reload_utils."
130+ )
129131
130132
131133def disable_rotate (quantizer : TensorQuantizer ):
@@ -171,25 +173,27 @@ def requant_weights_for_export(
171173 ``w`` (e.g. CPU offload).
172174 """
173175 copied = copy .deepcopy (quantizer ).to (device = weight .device )
174- sequence_quantizers : list [TensorQuantizer ] = (
176+ quantizers : list [TensorQuantizer ] = (
175177 list (copied ) if isinstance (copied , SequentialQuantizer ) else [copied ]
176178 )
177179
178- for quantizer_copy in sequence_quantizers :
180+ for quantizer_copy in quantizers :
179181 quantizer_copy .eval ()
180182 quantizer_copy .reset_amax ()
181183 enable_stats_collection (quantizer_copy )
182184 # Match legacy single-quantizer path: first calib uses ``w`` as-is; chains feed ``w`` through each quantizer in turn.
183- if len (sequence_quantizers ) == 1 :
184- weight_quantized = sequence_quantizers [0 ](weight )
185+ if len (quantizers ) == 1 :
186+ weight_quantized = quantizers [0 ](weight )
185187 else :
186- weight_quantized = weight . float ()
187- for quantizer_copy in sequence_quantizers :
188+ weight_quantized = weight
189+ for quantizer_copy in quantizers :
188190 weight_quantized = quantizer_copy (weight_quantized )
189- for quantizer_copy in sequence_quantizers :
191+ for quantizer_copy in quantizers :
190192 finish_stats_collection (quantizer_copy )
191- weight_quantized = weight .float ()
192- for quantizer_copy in sequence_quantizers :
193+ # Re-run application pass to get the quantized output with the freshly collected amax.
194+ # The calibration forward above only collected stats; its output is intentionally discarded.
195+ weight_quantized = weight
196+ for quantizer_copy in quantizers :
193197 weight_quantized = quantizer_copy (weight_quantized )
194198 return weight_quantized .to (weight .dtype )
195199
@@ -219,6 +223,12 @@ def merge_amax_tensors_for_group(tensors: list[torch.Tensor]) -> torch.Tensor:
219223 try :
220224 return torch .cat (tensors , dim = 0 ).to (dtype = first .dtype , device = first .device )
221225 except RuntimeError :
226+ shapes = [tuple (t .shape ) for t in tensors ]
227+ warnings .warn (
228+ f"merge_amax_tensors_for_group: torch.cat failed for shapes { shapes } ; "
229+ "falling back to scalar max which loses per-channel amax structure." ,
230+ stacklevel = 2 ,
231+ )
222232 flat = torch .cat ([t .reshape (- 1 ).float () for t in tensors ])
223233 return torch .max (flat ).to (dtype = first .dtype , device = first .device )
224234
@@ -258,7 +268,9 @@ def _process_group(modules: list[nn.Module]) -> None:
258268 if pqs_list is None :
259269 return
260270
261- avg_pqs = torch .stack (pqs_list ).mean (0 )
271+ # Mean and clamp in float32: in fp16/bf16, float32.tiny would underflow to 0, causing a later divide-by-zero.
272+ pqs_dtype = pqs_list [0 ].dtype
273+ avg_pqs = torch .stack ([p .float () for p in pqs_list ]).mean (0 )
262274 avg_pqs = avg_pqs .clamp (min = torch .finfo (torch .float32 ).tiny )
263275
264276 for m in modules :
@@ -270,8 +282,8 @@ def _process_group(modules: list[nn.Module]) -> None:
270282 if torch .equal (old_pqs , avg_pqs_dev ):
271283 continue
272284 weight = state_dict [f"{ nm } .weight" ]
273- ratio = old_pqs .to (dtype = torch .float32 , device = weight .device ) / avg_pqs_dev .to (
274- dtype = torch . float32 , device = weight .device
285+ ratio = old_pqs .to (dtype = torch .float32 , device = weight .device ) / avg_pqs .to (
286+ device = weight .device
275287 )
276288 state_dict [f"{ nm } .weight" ] = (weight .to (torch .float32 ) * ratio ).to (weight .dtype )
277289 requant_weights .add (f"{ nm } .weight" )
@@ -281,7 +293,7 @@ def _process_group(modules: list[nn.Module]) -> None:
281293 if all (a is not None for a in amaxes ):
282294 synced_amax = merge_amax_tensors_for_group (amaxes )
283295
284- avg_pqs_out = avg_pqs .detach ().clone ()
296+ avg_pqs_out = avg_pqs .detach ().to ( pqs_dtype ). clone ()
285297 for m in modules :
286298 nm = id_to_name .get (id (m ))
287299 if nm is None :
@@ -309,14 +321,15 @@ def _process_group(modules: list[nn.Module]) -> None:
309321
310322 def _dummy_forward () -> None :
311323 # Partial forward is OK: hooks record layers reached before failure.
312- try :
313- model (torch .ones ([1 , 2 ], dtype = torch .long , device = dev ))
314- except Exception as e :
315- import logging
324+ with torch .inference_mode ():
325+ try :
326+ model (torch .ones ([1 , 2 ], dtype = torch .long , device = dev ))
327+ except Exception as e :
328+ import logging
316329
317- logging .getLogger (__name__ ).debug (
318- "Dummy forward for shared-input detection failed (expected for VLMs): %s" , e
319- )
330+ logging .getLogger (__name__ ).debug (
331+ "Dummy forward for shared-input detection failed (expected for VLMs): %s" , e
332+ )
320333
321334 input_to_linear , _ = collect_shared_input_modules (model , _dummy_forward )
322335 for modules in input_to_linear .values ():
@@ -380,9 +393,8 @@ def export_hf_vllm_fq_checkpoint(
380393 weight_name = attr_name .removesuffix ("_quantizer" )
381394 prefix = f"{ module_name } ." if module_name else ""
382395 sd_key = f"{ prefix } { weight_name } "
383- assert sd_key not in fakequant_weights , (
384- f"Weight { sd_key } has already been fakequantized"
385- )
396+ if sd_key in fakequant_weights :
397+ raise RuntimeError (f"Weight { sd_key } has already been fakequantized" )
386398 if sd_key in state_dict :
387399 w = state_dict [sd_key ]
388400 if sd_key in requant_weights :
@@ -419,74 +431,75 @@ def export_hf_vllm_fq_checkpoint(
419431 # Rotation is also cleared: the weight was already folded with rotation applied,
420432 # so if fold_weight is called on reload it must not re-rotate the exported weight.
421433 wqs_to_restore : list [tuple [TensorQuantizer , Any ]] = []
422- for _ , module in model .named_modules ():
423- if isinstance (module , QuantModule ):
424- for attr_name , quantizer in module .named_children ():
425- if not (attr_name .endswith ("weight_quantizer" ) and quantizer .is_enabled ):
426- continue
427- if isinstance (quantizer , SequentialQuantizer ):
428- quantizer .disable ()
429- for sub in quantizer :
430- orig_rotate = sub ._rotate
431- if sub .rotate_is_enabled :
432- sub ._rotate = disable_rotate (sub )
433- wqs_to_restore .append ((sub , orig_rotate ))
434- elif isinstance (quantizer , TensorQuantizer ):
435- quantizer .disable ()
436- orig_rotate = quantizer ._rotate
437- if quantizer .rotate_is_enabled :
438- quantizer ._rotate = disable_rotate (quantizer )
439- wqs_to_restore .append ((quantizer , orig_rotate ))
440-
441- quantizer_state_dict = get_quantizer_state_dict (model )
442- for key in list (quantizer_state_dict ):
443- if is_weight_quantizer_state_key (key ):
444- # Fakequant amax is folded into HF weights; do not reload weight quantizer tensors.
445- # Reload must force-disable WQs missing from saved state (see
446- # ``filter_modelopt_state_quantizer_state_for_model`` assertion in vllm_reload_utils).
447- quantizer_state_dict .pop (key )
448- elif key in input_quantizers_folded_pqs :
449- # pre_quant_scale was folded into the weight; keep the buffer for strict load but
450- # save identity so activations are not scaled twice.
451- qstate_val = quantizer_state_dict [key ]
452- if isinstance (qstate_val , dict ) and "_pre_quant_scale" in qstate_val :
453- quantizer_state_dict [key ]["_pre_quant_scale" ] = torch .ones_like (
454- qstate_val ["_pre_quant_scale" ]
455- )
456-
457- # Patch input quantizers with averaged pqs and unified amax so that vLLM's single
458- # per-group input quantizer sees consistent values (covers both dense qkv and MoE experts).
459- for iq_key , (avg_pqs , max_input_amax ) in pqs_overrides .items ():
460- if iq_key in quantizer_state_dict :
461- qstate_val = quantizer_state_dict [iq_key ]
462- if isinstance (qstate_val , dict ):
463- if "_pre_quant_scale" in qstate_val :
464- qstate_val ["_pre_quant_scale" ] = avg_pqs
465- if max_input_amax is not None and "_amax" in qstate_val :
466- qstate_val ["_amax" ] = max_input_amax
467-
468- modelopt_state = mto .modelopt_state (model )
469- # ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild
470- # ``quantizer_state`` and strip weight-quantizer entries (same policy as
471- # ``modelopt_state_weights``). Reload synthesizes missing WQ rows with ``_disabled``.
472- _check_all_weight_quantizers_disabled (model )
473- qstate = quantizer_state (model )
474- for key in list (qstate ):
475- if is_weight_quantizer_state_key (key ):
476- qstate .pop (key )
477-
478- for mode_str , m_state in modelopt_state .get ("modelopt_state_dict" , []):
479- if mode_str == "quantize" and "metadata" in m_state :
480- m_state ["metadata" ]["quantizer_state" ] = qstate
481- break
482-
483- # Per-quantizer tensor dict loaded alongside metadata on reload.
484- modelopt_state ["modelopt_state_weights" ] = quantizer_state_dict
485- safe_save (modelopt_state , export_dir / "vllm_fq_modelopt_state.pth" )
486-
487- # Step 3: Save HF weights using the pre-built folded state dict.
488- model .save_pretrained (export_dir , state_dict = clean_sd , save_modelopt_state = False )
489-
490- for wq , orig_rotate in wqs_to_restore :
491- wq .enable ()
492- wq ._rotate = orig_rotate
434+ try :
435+ for _ , module in model .named_modules ():
436+ if isinstance (module , QuantModule ):
437+ for attr_name , quantizer in module .named_children ():
438+ if not (attr_name .endswith ("weight_quantizer" ) and quantizer .is_enabled ):
439+ continue
440+ if isinstance (quantizer , SequentialQuantizer ):
441+ quantizer .disable ()
442+ for sub in quantizer :
443+ orig_rotate = sub ._rotate
444+ if sub .rotate_is_enabled :
445+ sub ._rotate = disable_rotate (sub )
446+ wqs_to_restore .append ((sub , orig_rotate ))
447+ elif isinstance (quantizer , TensorQuantizer ):
448+ quantizer .disable ()
449+ orig_rotate = quantizer ._rotate
450+ if quantizer .rotate_is_enabled :
451+ quantizer ._rotate = disable_rotate (quantizer )
452+ wqs_to_restore .append ((quantizer , orig_rotate ))
453+
454+ quantizer_state_dict = get_quantizer_state_dict (model )
455+ for key in list (quantizer_state_dict ):
456+ if is_weight_quantizer_state_key (key ):
457+ # Fakequant amax is folded into HF weights; do not reload weight quantizer tensors.
458+ # Reload must force-disable WQs missing from saved state (see
459+ # ``filter_modelopt_state_quantizer_state_for_model`` assertion in vllm_reload_utils).
460+ quantizer_state_dict .pop (key )
461+ elif key in input_quantizers_folded_pqs :
462+ # pre_quant_scale was folded into the weight; keep the buffer for strict load but
463+ # save identity so activations are not scaled twice.
464+ qstate_val = quantizer_state_dict [key ]
465+ if isinstance (qstate_val , dict ) and "_pre_quant_scale" in qstate_val :
466+ quantizer_state_dict [key ]["_pre_quant_scale" ] = torch .ones_like (
467+ qstate_val ["_pre_quant_scale" ]
468+ )
469+
470+ # Patch input quantizers with averaged pqs and unified amax so that vLLM's single
471+ # per-group input quantizer sees consistent values (covers both dense qkv and MoE experts).
472+ for iq_key , (avg_pqs , max_input_amax ) in pqs_overrides .items ():
473+ if iq_key in quantizer_state_dict :
474+ qstate_val = quantizer_state_dict [iq_key ]
475+ if isinstance (qstate_val , dict ):
476+ if "_pre_quant_scale" in qstate_val :
477+ qstate_val ["_pre_quant_scale" ] = avg_pqs
478+ if max_input_amax is not None and "_amax" in qstate_val :
479+ qstate_val ["_amax" ] = max_input_amax
480+
481+ modelopt_state = mto .modelopt_state (model )
482+ # ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild
483+ # ``quantizer_state`` and strip weight-quantizer entries (same policy as
484+ # ``modelopt_state_weights``). Reload synthesizes missing WQ rows with ``_disabled``.
485+ _check_all_weight_quantizers_disabled (model )
486+ qstate = quantizer_state (model )
487+ for key in list (qstate ):
488+ if is_weight_quantizer_state_key (key ):
489+ qstate .pop (key )
490+
491+ for mode_str , m_state in modelopt_state .get ("modelopt_state_dict" , []):
492+ if mode_str == "quantize" and "metadata" in m_state :
493+ m_state ["metadata" ]["quantizer_state" ] = qstate
494+ break
495+
496+ # Per-quantizer tensor dict loaded alongside metadata on reload.
497+ modelopt_state ["modelopt_state_weights" ] = quantizer_state_dict
498+ safe_save (modelopt_state , export_dir / "vllm_fq_modelopt_state.pth" )
499+
500+ # Step 3: Save HF weights using the pre-built folded state dict.
501+ model .save_pretrained (export_dir , state_dict = clean_sd , save_modelopt_state = False )
502+ finally :
503+ for wq , orig_rotate in wqs_to_restore :
504+ wq .enable ()
505+ wq ._rotate = orig_rotate
0 commit comments