1414# limitations under the License.
1515"""Export HuggingFace model to vLLM fakequant checkpoint."""
1616
17- import contextlib
1817import copy
18+ import logging
1919import re
20+ from collections .abc import Callable
2021from pathlib import Path
2122from typing import Any
2223
2728from modelopt .torch .quantization .config import RotateConfig
2829from modelopt .torch .quantization .conversion import quantizer_state
2930from modelopt .torch .quantization .model_calib import enable_stats_collection , finish_stats_collection
30- from modelopt .torch .quantization .nn import QuantModule , TensorQuantizer
31+ from modelopt .torch .quantization .nn import QuantModule , SequentialQuantizer , TensorQuantizer
3132from modelopt .torch .quantization .utils import get_quantizer_state_dict
3233from modelopt .torch .utils import get_unwrapped_name , safe_save
3334
3738
3839__all__ = [
3940 "export_hf_vllm_fq_checkpoint" ,
41+ "infer_quantizer_prefix_remap" ,
4042 "is_weight_quantizer_state_key" ,
4143 "merge_amax_tensors_for_group" ,
4244]
@@ -53,6 +55,79 @@ def is_weight_quantizer_state_key(key: str) -> bool:
5355 return bool (_WEIGHT_QUANTIZER_STATE_KEY .search (key ))
5456
5557
def infer_quantizer_prefix_remap(
    quantizer_keys: dict[str, Any],
    map_fun: Callable[[dict[str, Any]], dict[str, Any]],
) -> dict[str, str]:
    """Infer the HF → vLLM root-module rename (e.g. ``backbone`` → ``model``).

    Each quantizer state path is turned into a synthetic ``<module>.weight``
    probe carrying a small 2-D placeholder tensor (quantizer paths are never
    real weight keys) and pushed through ``map_fun``. The first dotted
    component of the mapped key names the vLLM root for that HF root. Probes
    that raise or map to nothing are skipped; two probes under the same HF
    root that disagree on the target raise :exc:`ValueError`. Only whole-root
    renames are detected — arbitrary layer rewrites are out of scope.

    Args:
        quantizer_keys: HF quantizer state paths (only the keys are read).
        map_fun: HF→vLLM weight ``state_dict`` mapper, same as for
            ``convert_dict_to_vllm``.

    Returns:
        ``{hf_root: vllm_root}`` limited to roots that actually rename;
        identity pairs are omitted.
    """
    logger = logging.getLogger(__name__)
    placeholder = torch.empty((1, 1))
    root_targets: dict[str, str] = {}

    for state_key in quantizer_keys:
        hf_root = state_key.split(".", 1)[0]
        stem, sep, _tail = state_key.rpartition(".")
        if not sep:
            # No module path to probe (single-component key).
            continue
        weight_key = stem + ".weight"
        try:
            mapped = map_fun({weight_key: placeholder})
            if not mapped:
                continue
            target_key = next(iter(mapped))
            target_root = target_key.split(".", 1)[0]
        except Exception as e:
            logger.debug("prefix-remap probe failed for %r: %s", weight_key, e)
            continue

        previous = root_targets.get(hf_root)
        if previous is None:
            root_targets[hf_root] = target_root
        elif previous != target_root:
            raise ValueError(
                "Inconsistent HF→vLLM prefix remap for "
                f"{hf_root!r}: probes implied "
                f"{root_targets[hf_root]!r} and {target_root!r}. "
                "map_fun must apply one target root per HF root, or use explicit quantizer "
                "key remapping."
            )

    remap: dict[str, str] = {}
    for hf_root, vllm_root in root_targets.items():
        if hf_root != vllm_root:
            remap[hf_root] = vllm_root
    return remap
114+
def _check_all_weight_quantizers_disabled(model: nn.Module) -> None:
    """Assert the export invariant that no weight quantizer is still enabled.

    Weights are already folded by the time metadata is written, so an enabled
    weight quantizer would double-quantize on reload — fail loudly here.
    """
    for _, mod in model.named_modules():
        if not isinstance(mod, QuantModule):
            continue
        for child_name, child in mod.named_children():
            if not child_name.endswith("weight_quantizer"):
                continue
            if not isinstance(child, (TensorQuantizer, SequentialQuantizer)):
                continue
            assert not child.is_enabled, (
                f"vLLM fakequant export: {child_name!r} must be disabled before saving "
                f"quantizer_state (weights already folded). "
                f"See filter_modelopt_state_quantizer_state_for_model in vllm_reload_utils."
            )
130+
56131def disable_rotate (quantizer : TensorQuantizer ):
57132 """Return a disabled copy of the quantizer's ``_rotate`` field, preserving its type."""
58133 if isinstance (quantizer ._rotate , RotateConfig ):
@@ -62,35 +137,61 @@ def disable_rotate(quantizer: TensorQuantizer):
62137 return False
63138
64139
65- def _collect_expert_pre_quant_scales (
140+ def _collect_group_pre_quant_scales (
66141 experts : list [nn .Module ],
67142) -> list [torch .Tensor ] | None :
68143 """Return per-expert ``pre_quant_scale`` tensors if every expert can be averaged; else None.
69144
70145 Skips groups where any expert has no input quantizer, no pqs (e.g. weight-only AWQ INT4),
71146 or a disabled input quantizer (pqs already folded / not used).
72147 """
73- pqs_list : list [torch .Tensor ] = []
74- for ex in experts :
75- iq = getattr (ex , "input_quantizer" , None )
76- if iq is None or not iq .is_enabled or iq .pre_quant_scale is None :
148+ pre_quant_scales : list [torch .Tensor ] = []
149+ for expert_module in experts :
150+ input_quantizer = getattr (expert_module , "input_quantizer" , None )
151+ if (
152+ input_quantizer is None
153+ or not input_quantizer .is_enabled
154+ or input_quantizer .pre_quant_scale is None
155+ ):
77156 return None
78- pqs_list .append (iq .pre_quant_scale )
79- return pqs_list
157+ pre_quant_scales .append (input_quantizer .pre_quant_scale )
158+ return pre_quant_scales
80159
81160
def requant_weights_for_export(
    quantizer: TensorQuantizer | SequentialQuantizer,
    weight: torch.Tensor,
) -> torch.Tensor:
    """Requantize folded weights after resmooth (single or sequential quantizer).

    A lone ``TensorQuantizer`` is handled as a one-stage chain, so the same
    calibrate-then-apply flow also covers W4A8-style sequential weight
    quantizers (e.g. INT4→FP8).

    The deepcopy may leave buffers on the source device; moving the copy to
    ``weight.device`` keeps it aligned with the weight (e.g. CPU offload).
    """
    chain = copy.deepcopy(quantizer).to(device=weight.device)
    stages: list[TensorQuantizer] = (
        list(chain) if isinstance(chain, SequentialQuantizer) else [chain]
    )

    # Calibration pass: reset and re-collect amax on the resmoothed weight.
    for stage in stages:
        stage.eval()
        stage.reset_amax()
        enable_stats_collection(stage)
    # Match legacy single-quantizer path: first calib uses ``weight`` as-is;
    # chains push a float copy through every stage in order.
    if len(stages) == 1:
        calibrated = stages[0](weight)
    else:
        calibrated = weight.float()
        for stage in stages:
            calibrated = stage(calibrated)
    for stage in stages:
        finish_stats_collection(stage)

    # Apply pass with the finalized amax values.
    requantized = weight.float()
    for stage in stages:
        requantized = stage(requantized)
    return requantized.to(weight.dtype)
94195
95196
96197def merge_amax_tensors_for_group (tensors : list [torch .Tensor ]) -> torch .Tensor :
@@ -100,8 +201,10 @@ def merge_amax_tensors_for_group(tensors: list[torch.Tensor]) -> torch.Tensor:
100201
101202 - If every tensor has the same shape, take the element-wise maximum over the group
102203 (conservative when each branch carried the same axis layout).
103- - If shapes differ (e.g. GQA q vs k), try ``torch.cat(..., dim=0)`` when valid for
104- per-channel amax; otherwise fall back to a scalar max over all elements.
204+ - If shapes differ: ``torch.cat(..., dim=0)`` assumes **1D per-channel** amaxes in
205+ fused order (e.g. GQA q/k/v → ``[N_q]`` + ``[N_kv]`` + ``[N_kv]``), matching vLLM’s
206+ grouped quantizer. Not valid for 2D blockwise amax; on failure, **scalar**
207+ max (drops channel structure).
105208 """
106209 if not tensors :
107210 raise ValueError ("merge_amax_tensors_for_group: expected at least one tensor" )
@@ -151,7 +254,7 @@ def _resmooth_experts_for_export(
151254 requant_weights : set [str ] = set ()
152255
153256 def _process_group (modules : list [nn .Module ]) -> None :
154- pqs_list = _collect_expert_pre_quant_scales (modules )
257+ pqs_list = _collect_group_pre_quant_scales (modules )
155258 if pqs_list is None :
156259 return
157260
@@ -205,9 +308,15 @@ def _process_group(modules: list[nn.Module]) -> None:
205308 dev = next (model .parameters ()).device
206309
def _dummy_forward() -> None:
    """Run one token batch through ``model`` so shared-input hooks can fire.

    A partial forward is fine: hooks record the layers reached before any
    failure, which is the expected outcome for e.g. VLMs fed text-only input.
    """
    try:
        model(torch.ones([1, 2], dtype=torch.long, device=dev))
    except Exception as e:
        # ``logging`` is already imported at module scope; the previous
        # function-local ``import logging`` was redundant.
        logging.getLogger(__name__).debug(
            "Dummy forward for shared-input detection failed (expected for VLMs): %s", e
        )
211320
212321 input_to_linear , _ = collect_shared_input_modules (model , _dummy_forward )
213322 for modules in input_to_linear .values ():
@@ -263,7 +372,7 @@ def export_hf_vllm_fq_checkpoint(
263372 for attr_name , quantizer in module .named_children ():
264373 if not (
265374 attr_name .endswith ("weight_quantizer" )
266- and isinstance (quantizer , TensorQuantizer )
375+ and ( isinstance (quantizer , ( TensorQuantizer , SequentialQuantizer )) )
267376 and quantizer .fake_quant
268377 and quantizer .is_enabled
269378 ):
@@ -313,11 +422,16 @@ def export_hf_vllm_fq_checkpoint(
313422 for _ , module in model .named_modules ():
314423 if isinstance (module , QuantModule ):
315424 for attr_name , quantizer in module .named_children ():
316- if (
317- attr_name .endswith ("weight_quantizer" )
318- and isinstance (quantizer , TensorQuantizer )
319- and quantizer .is_enabled
320- ):
425+ if not (attr_name .endswith ("weight_quantizer" ) and quantizer .is_enabled ):
426+ continue
427+ if isinstance (quantizer , SequentialQuantizer ):
428+ quantizer .disable ()
429+ for sub in quantizer :
430+ orig_rotate = sub ._rotate
431+ if sub .rotate_is_enabled :
432+ sub ._rotate = disable_rotate (sub )
433+ wqs_to_restore .append ((sub , orig_rotate ))
434+ elif isinstance (quantizer , TensorQuantizer ):
321435 quantizer .disable ()
322436 orig_rotate = quantizer ._rotate
323437 if quantizer .rotate_is_enabled :
@@ -328,6 +442,8 @@ def export_hf_vllm_fq_checkpoint(
328442 for key in list (quantizer_state_dict ):
329443 if is_weight_quantizer_state_key (key ):
330444 # Fakequant amax is folded into HF weights; do not reload weight quantizer tensors.
445+ # Reload must force-disable WQs missing from saved state (see
446+ # ``filter_modelopt_state_quantizer_state_for_model`` assertion in vllm_reload_utils).
331447 quantizer_state_dict .pop (key )
332448 elif key in input_quantizers_folded_pqs :
333449 # pre_quant_scale was folded into the weight; keep the buffer for strict load but
@@ -351,10 +467,12 @@ def export_hf_vllm_fq_checkpoint(
351467
352468 modelopt_state = mto .modelopt_state (model )
353469 # ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild
354- # ``quantizer_state`` and drop disabled weight quantizer entries (weights already folded).
470+ # ``quantizer_state`` and strip weight-quantizer entries (same policy as
471+ # ``modelopt_state_weights``). Reload synthesizes missing WQ rows with ``_disabled``.
472+ _check_all_weight_quantizers_disabled (model )
355473 qstate = quantizer_state (model )
356474 for key in list (qstate ):
357- if is_weight_quantizer_state_key (key ) and qstate [ key ]. get ( "_disabled" ) :
475+ if is_weight_quantizer_state_key (key ):
358476 qstate .pop (key )
359477
360478 for mode_str , m_state in modelopt_state .get ("modelopt_state_dict" , []):
0 commit comments