1414# limitations under the License.
1515"""Export HuggingFace model to vLLM fakequant checkpoint."""
1616
17+ import logging
18+ import time
19+ from collections import defaultdict
20+ from concurrent .futures import ThreadPoolExecutor
21+ from dataclasses import dataclass
1722from pathlib import Path
1823
1924import torch
2833
2934__all__ = ["export_hf_vllm_fq_checkpoint" ]
3035
36+ logger = logging .getLogger (__name__ )
37+
38+
@dataclass
class _WeightQuantWork:
    """A single weight tensor to be fake-quantized during export.

    Produced by ``_collect_quant_work`` and consumed by ``_process_weight``.
    The last two fields describe an optional ``pre_quant_scale`` fold and
    default to ``None`` (no fold), so callers only need to supply them when
    a fold is actually eligible.
    """

    # Key of the weight tensor in the exported state_dict.
    sd_key: str
    # Weight quantizer whose fake-quant is applied to ``weight``.
    quantizer: TensorQuantizer
    # The weight tensor (taken from the state_dict copy) to quantize.
    weight: torch.Tensor
    # Input quantizer carrying a ``_pre_quant_scale`` to fold, if any.
    inp_q: TensorQuantizer | None = None
    # Unwrapped name of ``inp_q`` (recorded so the caller can mark the
    # input quantizer as having had its pre_quant_scale folded).
    inp_q_key: str | None = None
49+
50+
def _collect_quant_work(
    model: nn.Module, state_dict: dict[str, torch.Tensor]
) -> list[_WeightQuantWork]:
    """Collect all weight quantization work items from the model.

    Walks every ``QuantModule`` in *model*, selects enabled fake-quant weight
    quantizers, and pairs each with the matching weight tensor in
    *state_dict*. When the sibling input quantizer is disabled but carries a
    ``_pre_quant_scale``, it is recorded on the work item so the scale can be
    folded into the quantized weight later.

    Args:
        model: In-memory quantized model.
        state_dict: Copy of the model's state dict whose weight tensors will
            be replaced by their fake-quantized versions.

    Returns:
        One ``_WeightQuantWork`` per (module, weight_quantizer) pair whose
        weight key is present in *state_dict*.

    Raises:
        ValueError: If the same weight key would be fake-quantized twice.
    """
    work_items: list[_WeightQuantWork] = []
    seen_keys: set[str] = set()
    for module_name, module in model.named_modules():
        if not isinstance(module, QuantModule):
            continue
        for attr_name, quantizer in module.named_children():
            if not (
                attr_name.endswith("weight_quantizer")
                and isinstance(quantizer, TensorQuantizer)
                and quantizer.fake_quant
                and quantizer.is_enabled
            ):
                continue
            weight_name = attr_name.removesuffix("_quantizer")
            prefix = f"{module_name}." if module_name else ""
            sd_key = f"{prefix}{weight_name}"
            # Raise (not assert) so the invariant survives ``python -O``.
            if sd_key in seen_keys:
                raise ValueError(f"Weight {sd_key} has already been fakequantized")
            seen_keys.add(sd_key)
            if sd_key not in state_dict:
                continue
            # Check for pre_quant_scale folding eligibility:
            # (x*s) @ fake_quant(W) == x @ (fake_quant(W)*s) only holds when
            # the input quantizer does NOT fake-quant activations, i.e. it
            # is disabled.
            inp_q = None
            inp_q_key = None
            inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
            if hasattr(module, inp_attr):
                candidate = getattr(module, inp_attr)
                if (
                    hasattr(candidate, "_pre_quant_scale")
                    and candidate._pre_quant_scale is not None
                    and candidate._disabled
                ):
                    inp_q = candidate
                    inp_q_key = get_unwrapped_name(
                        f"{module_name}.{inp_attr}" if module_name else inp_attr, model
                    )
            work_items.append(
                _WeightQuantWork(
                    sd_key=sd_key,
                    quantizer=quantizer,
                    weight=state_dict[sd_key],
                    inp_q=inp_q,
                    inp_q_key=inp_q_key,
                )
            )
    return work_items
100+
101+
def _process_weight(item: _WeightQuantWork) -> tuple[str, torch.Tensor, str | None]:
    """Fake-quantize a single weight tensor and optionally fold pre_quant_scale.

    Returns (sd_key, quantized_weight_on_cpu, inp_q_key_or_None).
    """
    original = item.weight
    # Quantize in float32, then restore the original dtype and move to CPU
    # so results can be merged into the exported state_dict.
    quantized = item.quantizer(original.float()).to(original.dtype).cpu()
    if item.inp_q is not None:
        # Fold the per-channel pre_quant_scale into the quantized weight.
        pqs = item.inp_q._pre_quant_scale.squeeze().to(device=quantized.device)
        quantized = (quantized * pqs[None, :]).to(quantized.dtype)
    return item.sd_key, quantized, item.inp_q_key
113+
114+
def _process_device_batch(items: list[_WeightQuantWork], device: torch.device):
    """Fake-quantize every work item pinned to a single GPU.

    Intended to run in a dedicated thread (one per GPU device); the CUDA
    device context is set for the duration of the batch.
    """
    with torch.cuda.device(device):
        processed = list(map(_process_weight, items))
        # Make sure all kernels launched on this device have finished
        # before handing results back to the dispatching thread.
        torch.cuda.synchronize(device)
    return processed
121+
31122
32123def disable_rotate (quantizer : TensorQuantizer ):
33124 """Return a disabled copy of the quantizer's ``_rotate`` field, preserving its type."""
@@ -41,6 +132,7 @@ def disable_rotate(quantizer: TensorQuantizer):
41132def export_hf_vllm_fq_checkpoint (
42133 model : nn .Module ,
43134 export_dir : Path | str ,
135+ parallel : bool = True ,
44136):
45137 """Export quantized HF weights + ``vllm_fq_modelopt_state.pth`` for vLLM fake-quant reload.
46138
@@ -53,6 +145,9 @@ def export_hf_vllm_fq_checkpoint(
53145 Args:
54146 model: In-memory quantized model.
55147 export_dir: Output dir for HF files and ``vllm_fq_modelopt_state.pth``.
148+ parallel: If True, fake-quantize weights across GPUs concurrently using
149+ one thread per GPU device. Falls back to sequential when all weights
150+ are on the same device or on CPU. Default True.
56151 """
57152 export_dir = Path (export_dir )
58153 export_dir .mkdir (parents = True , exist_ok = True )
@@ -62,50 +157,60 @@ def export_hf_vllm_fq_checkpoint(
62157 # parameters are never modified. Apply each weight quantizer's fake-quant
63158 # to the corresponding weight tensor in the copy.
64159 state_dict = model .state_dict ()
65- fakequant_weights = set ()
66- input_quantizers_folded_pqs = (
67- set ()
68- ) # keys for input_quantizers where pre_quant_scale was folded
160+ fakequant_weights : set [str ] = set ()
161+ input_quantizers_folded_pqs : set [str ] = set ()
162+
163+ work_items = _collect_quant_work (model , state_dict )
164+
165+ # Group work items by device for parallel dispatch.
166+ device_groups : dict [torch .device , list [_WeightQuantWork ]] = defaultdict (list )
167+ for item in work_items :
168+ device_groups [item .weight .device ].append (item )
169+
170+ num_cuda_devices = sum (1 for d in device_groups if d .type == "cuda" )
171+ use_parallel = parallel and num_cuda_devices > 1
172+
173+ t0 = time .monotonic ()
69174 with torch .inference_mode ():
70- for module_name , module in model . named_modules () :
71- if not isinstance ( module , QuantModule ):
72- continue
73- for attr_name , quantizer in module . named_children ():
74- if not (
75- attr_name . endswith ( "weight_quantizer" )
76- and isinstance ( quantizer , TensorQuantizer )
77- and quantizer . fake_quant
78- and quantizer . is_enabled
79- ):
80- continue
81- weight_name = attr_name . removesuffix ( "_quantizer" )
82- prefix = f" { module_name } ." if module_name else ""
83- sd_key = f" { prefix } { weight_name } "
84- assert sd_key not in fakequant_weights , (
85- f"Weight { sd_key } has already been fakequantized"
86- )
87- if sd_key in state_dict :
88- w = state_dict [ sd_key ]
89- w_quant = quantizer ( w . float ()). to ( w . dtype ). cpu ()
90- # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s )
91- # Only valid when input_quantizer does NOT fake-quant activations. If it does
92- # fake_quant(x*s), the non-linearity prevents folding s into W.
93- inp_attr = attr_name . replace ( "weight_quantizer" , "input_quantizer" )
94- if hasattr ( module , inp_attr ):
95- inp_q = getattr ( module , inp_attr )
96- if (
97- hasattr ( inp_q , "_pre_quant_scale" )
98- and inp_q . _pre_quant_scale is not None
99- and inp_q . _disabled
100- ):
101- scale = inp_q . _pre_quant_scale . squeeze (). to ( device = w_quant . device )
102- w_quant = ( w_quant * scale [ None , :]). to ( w_quant . dtype )
103- inp_q_key = get_unwrapped_name (
104- f" { module_name } . { inp_attr } " if module_name else inp_attr , model
105- )
106- input_quantizers_folded_pqs . add ( inp_q_key )
107- state_dict [ sd_key ] = w_quant
108- fakequant_weights . add ( sd_key )
175+ if use_parallel :
176+ logger . info (
177+ "Parallel export: %d weights across %d GPUs (%s)" ,
178+ len ( work_items ),
179+ num_cuda_devices ,
180+ ", " . join ( f" { d } : { len ( items ) } weights" for d , items in device_groups . items ()),
181+ )
182+ all_results : list [ tuple [ str , torch . Tensor , str | None ]] = []
183+ with ThreadPoolExecutor ( max_workers = num_cuda_devices ) as pool :
184+ futures = []
185+ for device , items in device_groups . items ():
186+ if device . type == "cuda" :
187+ futures . append ( pool . submit ( _process_device_batch , items , device ))
188+ else :
189+ # CPU weights: process inline (no thread needed).
190+ all_results . extend ([ _process_weight ( item ) for item in items ])
191+ for future in futures :
192+ all_results . extend ( future . result ())
193+ for sd_key , w_quant , inp_q_key in all_results :
194+ state_dict [ sd_key ] = w_quant
195+ fakequant_weights . add ( sd_key )
196+ if inp_q_key is not None :
197+ input_quantizers_folded_pqs . add ( inp_q_key )
198+ else :
199+ # Sequential fallback (single GPU, CPU, or parallel=False).
200+ for item in work_items :
201+ sd_key , w_quant , inp_q_key = _process_weight ( item )
202+ state_dict [ sd_key ] = w_quant
203+ fakequant_weights . add ( sd_key )
204+ if inp_q_key is not None :
205+ input_quantizers_folded_pqs . add ( inp_q_key )
206+
207+ elapsed = time . monotonic () - t0
208+ logger . info (
209+ "Export step 1 (%s): %d weights fake-quantized in %.1fs" ,
210+ "parallel" if use_parallel else "sequential" ,
211+ len ( fakequant_weights ),
212+ elapsed ,
213+ )
109214
110215 # Filter quantizer tensors out for a clean HF checkpoint.
111216 clean_sd = {k : v for k , v in state_dict .items () if "quantizer" not in k }
@@ -166,4 +271,5 @@ def export_hf_vllm_fq_checkpoint(
166271
167272 for wq , orig_rotate in wqs_to_restore :
168273 wq .enable ()
169- wq ._rotate = orig_rotate
274+ if orig_rotate is not None :
275+ wq ._rotate = orig_rotate
0 commit comments