 from modelopt.torch.quantization.conversion import quantizer_state
 from modelopt.torch.quantization.nn import QuantModule, TensorQuantizer
 from modelopt.torch.quantization.utils import get_quantizer_state_dict
+from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback
+from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector
 from modelopt.torch.utils import get_unwrapped_name

 __all__ = ["export_hf_vllm_fq_checkpoint"]
@@ -38,9 +40,75 @@ def disable_rotate(quantizer: TensorQuantizer):
     return False


+def _fakequant_module_weights(
+    module: nn.Module,
+    module_name: str,
+    model: nn.Module,
+    state_dict: dict | None,
+    input_quantizers_folded_pqs: set,
+    fakequant_weights: set,
+    inplace: bool,
+):
+    """Apply fake-quant to a single QuantModule's weights.
+
+    When ``inplace=False``, reads/writes weights from/to ``state_dict``.
+    When ``inplace=True``, modifies the module's weight parameters directly.
+    """
+    if not isinstance(module, QuantModule):
+        return
+    for attr_name, quantizer in module.named_children():
+        if not (
+            attr_name.endswith("weight_quantizer")
+            and isinstance(quantizer, TensorQuantizer)
+            and quantizer.fake_quant
+            and quantizer.is_enabled
+        ):
+            continue
+        weight_name = attr_name.removesuffix("_quantizer")
+        prefix = f"{module_name}." if module_name else ""
+        sd_key = f"{prefix}{weight_name}"
+        assert sd_key not in fakequant_weights, f"Weight {sd_key} has already been fakequantized"
+
+        if inplace:
+            w = getattr(module, weight_name)
+            w_quant = quantizer(w.float()).to(w.dtype)
+        else:
+            assert state_dict is not None
+            if sd_key not in state_dict:
+                continue
+            w = state_dict[sd_key]
+            w_quant = quantizer(w.float()).to(w.dtype)
+
+        # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
+        # Only valid when the input_quantizer does NOT fake-quant activations. If it does
+        # fake_quant(x*s), the non-linearity prevents folding s into W.
+        inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
+        if hasattr(module, inp_attr):
+            inp_q = getattr(module, inp_attr)
+            if (
+                hasattr(inp_q, "_pre_quant_scale")
+                and inp_q._pre_quant_scale is not None
+                and inp_q._disabled
+            ):
+                scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
+                w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
+                inp_q_key = get_unwrapped_name(
+                    f"{module_name}.{inp_attr}" if module_name else inp_attr, model
+                )
+                input_quantizers_folded_pqs.add(inp_q_key)
+
+        if inplace:
+            w.data.copy_(w_quant)
+        else:
+            assert state_dict is not None
+            state_dict[sd_key] = w_quant.cpu()
+        fakequant_weights.add(sd_key)
+
+
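> Note: the `pre_quant_scale` folding in the helper above rests on a plain linearity identity. Below is a minimal self-contained sketch that checks it for an `nn.Linear`-style `[out, in]` weight layout; the toy `fq` and all names are illustrative, not part of this PR:

```python
import torch

def fq(w, step=0.05):
    # Toy stand-in for fake-quant: snap weights to a fixed grid.
    return (w / step).round() * step

x = torch.randn(4, 8)    # activations [batch, in]
w = torch.randn(16, 8)   # Linear weight [out, in]
s = torch.rand(8) + 0.5  # pre_quant_scale over the input dim

lhs = (x * s) @ fq(w).T           # scale applied to activations
rhs = x @ (fq(w) * s[None, :]).T  # scale folded into the fake-quantized weight
assert torch.allclose(lhs, rhs, atol=1e-5)
```

> Folding breaks only when the input quantizer itself fake-quantizes `x * s`, since `fake_quant(x * s) @ W.T` is no longer linear in `s`; that is exactly what the `inp_q._disabled` guard checks.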
 def export_hf_vllm_fq_checkpoint(
     model: nn.Module,
     export_dir: Path | str,
+    inplace_mem_efficient: bool = False,
 ):
     """Export quantized HF weights + ``vllm_fq_modelopt_state.pth`` for vLLM fake-quant reload.

@@ -53,62 +121,66 @@ def export_hf_vllm_fq_checkpoint(
     Args:
         model: In-memory quantized model.
         export_dir: Output dir for HF files and ``vllm_fq_modelopt_state.pth``.
+        inplace_mem_efficient: When True, applies fake-quant in place one decoder
+            layer at a time using ``enable_weight_access_and_writeback``, avoiding
+            full state-dict materialization. This path is destructive: model weights
+            are permanently modified and weight quantizers are not re-enabled after export.
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)

     # Step 1: Build the folded HF state dict.
-    # model.state_dict() returns detached copies of all tensors, so model
-    # parameters are never modified. Apply each weight quantizer's fake-quant
-    # to the corresponding weight tensor in the copy.
-    state_dict = model.state_dict()
     fakequant_weights = set()
-    input_quantizers_folded_pqs = (
-        set()
-    )  # keys for input_quantizers where pre_quant_scale was folded
+    input_quantizers_folded_pqs = set()
     with torch.inference_mode():
-        for module_name, module in model.named_modules():
-            if not isinstance(module, QuantModule):
-                continue
-            for attr_name, quantizer in module.named_children():
-                if not (
-                    attr_name.endswith("weight_quantizer")
-                    and isinstance(quantizer, TensorQuantizer)
-                    and quantizer.fake_quant
-                    and quantizer.is_enabled
-                ):
+        if inplace_mem_efficient:
+            # Inplace path: iterate decoder layers, one offload<->onload per layer.
+            decoder_layers = LayerActivationCollector.get_decoder_layers(model)
+            assert decoder_layers is not None, (
+                "inplace_mem_efficient=True requires a model with discoverable decoder layers"
+            )
+            for name, module in model.named_modules():
+                if module not in decoder_layers:
                    continue
-                weight_name = attr_name.removesuffix("_quantizer")
-                prefix = f"{module_name}." if module_name else ""
-                sd_key = f"{prefix}{weight_name}"
-                assert sd_key not in fakequant_weights, (
-                    f"Weight {sd_key} has already been fakequantized"
-                )
-                if sd_key in state_dict:
-                    w = state_dict[sd_key]
-                    w_quant = quantizer(w.float()).to(w.dtype).cpu()
-                    # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
-                    # Only valid when input_quantizer does NOT fake-quant activations. If it does
-                    # fake_quant(x*s), the non-linearity prevents folding s into W.
-                    inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
-                    if hasattr(module, inp_attr):
-                        inp_q = getattr(module, inp_attr)
-                        if (
-                            hasattr(inp_q, "_pre_quant_scale")
-                            and inp_q._pre_quant_scale is not None
-                            and inp_q._disabled
-                        ):
-                            scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
-                            w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
-                            inp_q_key = get_unwrapped_name(
-                                f"{module_name}.{inp_attr}" if module_name else inp_attr, model
-                            )
-                            input_quantizers_folded_pqs.add(inp_q_key)
-                    state_dict[sd_key] = w_quant
-                    fakequant_weights.add(sd_key)
-
-    # Filter quantizer tensors out for a clean HF checkpoint.
-    clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k}
+                with enable_weight_access_and_writeback(module, module):
+                    for sub_name, sub_mod in module.named_modules():
+                        full_name = f"{name}.{sub_name}" if sub_name else name
+                        _fakequant_module_weights(
+                            sub_mod,
+                            full_name,
+                            model,
+                            None,
+                            input_quantizers_folded_pqs,
+                            fakequant_weights,
+                            inplace=True,
+                        )
+            # Offloaded weights appear as meta tensors (free); the offload maps now
+            # hold the fake-quantized values via writeback.
+            state_dict = model.state_dict()
+        else:
+            # Default path: full state_dict copy, fakequant into the copy.
+            state_dict = model.state_dict()
+            for module_name, module in model.named_modules():
+                with enable_weight_access_and_writeback(module, model):
+                    _fakequant_module_weights(
+                        module,
+                        module_name,
+                        model,
+                        state_dict,
+                        input_quantizers_folded_pqs,
+                        fakequant_weights,
+                        inplace=False,
+                    )
+
+    if inplace_mem_efficient:
+        # Let save_pretrained build its own state_dict so offloaded params go through
+        # its module_map / get_state_dict_from_offload path (modeling_utils.py:3967+).
+        # Passing state_dict= bypasses that path and crashes on meta tensors.
+        quantizer_keys = [k for k in state_dict if "quantizer" in k]
+        clean_sd = None
+    else:
+        clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k}
+        quantizer_keys = None
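> Note: the two branches filter quantizer tensors at different points but end up saving the same weight keys. A small self-contained sketch of the equivalence, with toy keys that are purely illustrative:

```python
# Toy state-dict keys; illustrative only.
state_dict = {
    "model.layers.0.mlp.up_proj.weight": ...,
    "model.layers.0.mlp.up_proj.weight_quantizer._amax": ...,
}
# Default path: drop quantizer tensors eagerly, pass the result to save_pretrained.
clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k}
# Inplace path: let save_pretrained rebuild the state dict and drop the same keys
# via _keys_to_ignore_on_save (a transformers PreTrainedModel attribute that
# save_pretrained consults when building the checkpoint).
quantizer_keys = [k for k in state_dict if "quantizer" in k]
assert set(clean_sd) == set(state_dict) - set(quantizer_keys)
```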

     # Step 2: Disable weight quantizers, save modelopt state + quantizer state
     # dict, then re-enable. The _disabled=True flag is captured in modelopt_state
@@ -161,9 +233,18 @@ def export_hf_vllm_fq_checkpoint(
     modelopt_state["modelopt_state_weights"] = quantizer_state_dict
     torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")

-    # Step 3: Save HF weights using the pre-built folded state dict.
-    model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
-
-    for wq, orig_rotate in wqs_to_restore:
-        wq.enable()
-        wq._rotate = orig_rotate
+    # Step 3: Save HF weights.
+    if inplace_mem_efficient:
+        prev_ignore = getattr(model, "_keys_to_ignore_on_save", None)
+        model._keys_to_ignore_on_save = quantizer_keys
+        try:
+            model.save_pretrained(export_dir, save_modelopt_state=False)
+        finally:
+            model._keys_to_ignore_on_save = prev_ignore
+    else:
+        model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
+
+    if not inplace_mem_efficient:
+        for wq, orig_rotate in wqs_to_restore:
+            wq.enable()
+            wq._rotate = orig_rotate
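> For context, a hedged usage sketch of the exported API. The `model` setup and output paths are illustrative; `model` is assumed to already be an HF model quantized with modelopt:

```python
from pathlib import Path

# Default path: fake-quant is applied to a state-dict copy; the live model keeps
# its original weights and the weight quantizers are re-enabled after export.
export_hf_vllm_fq_checkpoint(model, Path("out/default"))

# Memory-efficient path: weights are fake-quantized in place, one decoder layer
# at a time. The model is NOT reusable afterwards: weights stay modified and
# weight quantizers stay disabled.
export_hf_vllm_fq_checkpoint(model, Path("out/inplace"), inplace_mem_efficient=True)
```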