1414
1515from __future__ import annotations
1616
17+ import os
1718import threading
1819import time
1920from concurrent .futures import as_completed
3031from ..models ._const import CPU , SUPPORTS_MODULE_TYPES
3132from ..nn_modules .converter import MODULE_CONVERTER_MAP
3233from ..quantization .config import BitsAndBytesConfig , FP8Config , GGUFConfig , RTNConfig , VramStrategy
34+ from ..utils import has_gil_disabled
3335from ..utils .device import get_device
3436from ..utils .device_telemetry import emit_device_telemetry
3537from ..utils .logger import log_time_block , setup_logger
@@ -58,10 +60,7 @@ def __init__(self, model: BaseQModel, processor: WeightOnlyProcessor):
5860 self .processor = processor
5961 self ._quant_devices = self ._resolve_quant_devices ()
6062 self ._quant_device_rr = 0
61- self ._dense_device_rr = 0
62- self ._moe_group_device_rr = 0
6363 self ._module_device_map : Dict [str , torch .device ] = {}
64- self ._moe_group_device_map : Dict [str , torch .device ] = {}
6564 self ._quant_device_lock = threading .Lock ()
6665 self ._resolve_strategy_device_pools ()
6766
@@ -171,38 +170,143 @@ def _extract_moe_group_key(module_name: Optional[str]) -> Optional[str]:
171170 return None
172171 return f"{ prefix } .experts.{ expert_id } "
173172
174- def _select_strategy_device_for_module (self , named_module : NamedModule ) -> Optional [torch .device ]:
175- """Apply dense/MoE VRAM strategy before falling back to generic round-robin."""
173+ @staticmethod
174+ def _collect_assignable_moe_group_keys (moe_groups : Dict [str , List [str ]]) -> List [str ]:
175+ """Return expert families that should stay co-located on one device."""
176176
177- module_name = getattr (named_module , "name" , None ) or getattr (named_module , "full_name" , None )
178- moe_group_key = self ._extract_moe_group_key (module_name ) or self ._extract_moe_group_key (
179- getattr (named_module , "full_name" , None )
180- )
177+ assignable_group_keys : List [str ] = []
178+ for group_key , module_names in moe_groups .items ():
179+ suffixes = {name .rsplit ("." , 1 )[- 1 ] for name in module_names }
180+ if {"gate_proj" , "up_proj" }.issubset (suffixes ) or {"w1" , "w3" }.issubset (suffixes ):
181+ assignable_group_keys .append (group_key )
182+ return assignable_group_keys
181183
182- if moe_group_key is not None and self ._moe_vram_strategy_explicit and self ._moe_quant_devices :
183- cached = self ._moe_group_device_map .get (moe_group_key )
184- if cached is not None :
185- return cached
186- # MoE BALANCED spreads expert families across the MoE pool; each
187- # expert family keeps all projections on the same device.
188- if self ._moe_vram_strategy == VramStrategy .BALANCED and len (self ._moe_quant_devices ) > 1 :
189- device = self ._moe_quant_devices [self ._moe_group_device_rr % len (self ._moe_quant_devices )]
190- self ._moe_group_device_rr += 1
191- else :
192- device = self ._moe_quant_devices [0 ]
193- self ._moe_group_device_map [moe_group_key ] = device
194- return device
184+ @staticmethod
185+ def _normalize_planning_module_name (module_name : str ) -> str :
186+ """Strip model-tree annotations so planning blocks match live module names."""
195187
196- if self ._dense_vram_strategy_explicit and self ._dense_quant_devices :
197- # Dense EXCLUSIVE pins attention/router/shared/dense MLP work to
198- # the first dense device; BALANCED round-robins dense modules.
199- if self ._dense_vram_strategy == VramStrategy .BALANCED and len (self ._dense_quant_devices ) > 1 :
200- device = self ._dense_quant_devices [self ._dense_device_rr % len (self ._dense_quant_devices )]
201- self ._dense_device_rr += 1
202- return device
203- return self ._dense_quant_devices [0 ]
188+ return module_name .split (":" , 1 )[0 ]
204189
205- return None
190+ def _collect_dense_groups (
191+ self ,
192+ layer_candidate_names : List [str ],
193+ layer_moe_group_key_by_name : Dict [str , Optional [str ]],
194+ planning_layer_modules : Optional [List [List [str ]]],
195+ ) -> Dict [str , List [str ]]:
196+ """Collect dense modules into model-tree-defined calculation groups."""
197+
198+ remaining_dense_names = [
199+ module_name
200+ for module_name in layer_candidate_names
201+ if layer_moe_group_key_by_name .get (module_name ) is None
202+ ]
203+ remaining_dense_set = set (remaining_dense_names )
204+ dense_groups : Dict [str , List [str ]] = {}
205+
206+ if planning_layer_modules :
207+ for block_index , block in enumerate (planning_layer_modules ):
208+ block_dense_names : List [str ] = []
209+ block_seen = set ()
210+ for block_entry in block :
211+ module_name = self ._normalize_planning_module_name (block_entry )
212+ if module_name in block_seen or module_name not in remaining_dense_set :
213+ continue
214+ block_seen .add (module_name )
215+ if layer_moe_group_key_by_name .get (module_name ) is not None :
216+ continue
217+ block_dense_names .append (module_name )
218+
219+ if block_dense_names :
220+ dense_groups [f"planning:{ block_index } " ] = block_dense_names
221+ for module_name in block_dense_names :
222+ remaining_dense_set .discard (module_name )
223+
224+ for module_name in remaining_dense_names :
225+ if module_name not in remaining_dense_set :
226+ continue
227+ dense_groups [module_name ] = [module_name ]
228+ remaining_dense_set .discard (module_name )
229+
230+ return dense_groups
231+
232+ def _build_layer_strategy_device_map (
233+ self ,
234+ * ,
235+ full : Dict [str , torch .nn .Module ],
236+ planning_layer_modules : Optional [List [List [str ]]],
237+ ) -> Dict [str , torch .device ]:
238+ """Build the dense/MoE preferred-device map for one layer."""
239+
240+ dense_strategy_active = self ._dense_vram_strategy_explicit
241+ moe_strategy_active = self ._moe_vram_strategy_explicit
242+ if not dense_strategy_active and not moe_strategy_active :
243+ return {}
244+
245+ layer_candidate_names = list (full .keys ())
246+ moe_group_key_by_name = {
247+ module_name : self ._extract_moe_group_key (module_name )
248+ for module_name in layer_candidate_names
249+ }
250+ moe_groups : Dict [str , List [str ]] = {}
251+ for module_name , group_key in moe_group_key_by_name .items ():
252+ if group_key is not None :
253+ moe_groups .setdefault (group_key , []).append (module_name )
254+
255+ dense_groups = self ._collect_dense_groups (
256+ layer_candidate_names ,
257+ moe_group_key_by_name ,
258+ planning_layer_modules ,
259+ )
260+ preferred_devices : Dict [str , torch .device ] = {}
261+ dense_devices = [
262+ device for device in self ._dense_quant_devices
263+ if device is not None and getattr (device , "type" , None ) != "cpu"
264+ ] or list (self ._dense_quant_devices )
265+ moe_devices = [
266+ device for device in self ._moe_quant_devices
267+ if device is not None and getattr (device , "type" , None ) != "cpu"
268+ ] or list (self ._moe_quant_devices )
269+
270+ if dense_strategy_active and dense_groups and dense_devices :
271+ dense_group_keys = list (dense_groups .keys ())
272+ for group_index , group_key in enumerate (dense_group_keys ):
273+ # Dense EXCLUSIVE pins the serial path to the first dense
274+ # device; BALANCED spreads model-tree calculation groups.
275+ target_device = (
276+ dense_devices [group_index % len (dense_devices )]
277+ if self ._dense_vram_strategy == VramStrategy .BALANCED and len (dense_devices ) > 1
278+ else dense_devices [0 ]
279+ )
280+ for module_name in dense_groups [group_key ]:
281+ preferred_devices [module_name ] = target_device
282+
283+ if moe_strategy_active and moe_groups and moe_devices :
284+ assignable_group_keys = self ._collect_assignable_moe_group_keys (moe_groups )
285+ for group_index , group_key in enumerate (assignable_group_keys ):
286+ # MoE BALANCED spreads expert families across the MoE pool;
287+ # every projection in one expert family stays co-located.
288+ target_device = (
289+ moe_devices [group_index % len (moe_devices )]
290+ if self ._moe_vram_strategy == VramStrategy .BALANCED and len (moe_devices ) > 1
291+ else moe_devices [0 ]
292+ )
293+ for module_name in moe_groups [group_key ]:
294+ preferred_devices [module_name ] = target_device
295+
296+ gil_env = os .environ .get ("PYTHON_GIL" )
297+ gil_disabled = has_gil_disabled ()
298+ free_threaded_parallel_quant_eligible = bool (gil_disabled and len (self ._moe_quant_devices ) > 0 )
299+ log .info (
300+ "ModuleLooper: MoE quant runtime dense_pool=%s moe_pool=%s "
301+ "PYTHON_GIL=%s gil_disabled=%s free_threaded_parallel_quant_eligible=%s" ,
302+ dense_devices ,
303+ moe_devices ,
304+ gil_env ,
305+ gil_disabled ,
306+ free_threaded_parallel_quant_eligible ,
307+ )
308+
309+ return preferred_devices
206310
207311 def _assign_quant_device_for_module (self , named_module : NamedModule , fallback_device : torch .device ) -> torch .device :
208312 """Pick and memoize the quantization device for one named module."""
@@ -220,21 +324,14 @@ def _assign_quant_device_for_module(self, named_module: NamedModule, fallback_de
220324
221325 preferred_device = normalize_device_like (named_module .state .get ("preferred_quant_device" ))
222326 if preferred_device is not None and any (device == preferred_device for device in self ._quant_devices ):
327+ # Dense/MoE strategy placement is planned before this point,
328+ # matching ModuleLooper's preferred-device handoff.
223329 device = preferred_device
224330 emit_device_telemetry (
225331 "weight_only_quant_device_preferred_hint" ,
226332 module = key ,
227333 target_device = device ,
228334 )
229- elif (strategy_device := self ._select_strategy_device_for_module (named_module )) is not None :
230- device = strategy_device
231- emit_device_telemetry (
232- "weight_only_quant_device_strategy" ,
233- module = key ,
234- target_device = device ,
235- dense_strategy = self ._dense_vram_strategy .value ,
236- moe_strategy = self ._moe_vram_strategy .value ,
237- )
238335 elif len (self ._quant_devices ) <= 1 :
239336 device = self ._quant_devices [0 ]
240337 else :
@@ -402,29 +499,73 @@ def _finalize_subset_modules(
402499 }
403500 use_parallel_finalize = len (finalize_tasks ) > 1 and len (unique_targets ) > 1
404501
502+ finalize_count = len (finalize_tasks )
503+ finalize_pb = log .pb (range (finalize_count )).manual ().set (show_left_steps = False )
504+ known_layers = sorted (
505+ {
506+ getattr (named , "layer_index" , None )
507+ for named , _ , _ in finalize_tasks
508+ if getattr (named , "layer_index" , None ) is not None
509+ }
510+ )
511+ includes_unknown = any (getattr (named , "layer_index" , None ) is None for named , _ , _ in finalize_tasks )
512+ layer_heading = "Layer ?"
513+ if known_layers :
514+ sample_layers = ", " .join (str (idx ) for idx in known_layers [:3 ])
515+ if len (known_layers ) > 3 :
516+ sample_layers += ", ..."
517+ suffix = ", ?" if includes_unknown else ""
518+ prefix = "Layer" if len (known_layers ) == 1 else "Layers"
519+ layer_heading = f"{ prefix } { sample_layers } { suffix } "
520+ elif includes_unknown :
521+ layer_heading = "Layer ?"
522+
523+ finalize_pb .title (
524+ f"{ layer_heading } Submodule finalize 0/{ finalize_count } "
525+ ).subtitle ("Waiting for completions..." ).draw ()
526+
527+ completed = 0
528+
529+ def _advance_finalize_progress (named : NamedModule , module_label : str ) -> None :
530+ nonlocal completed
531+
532+ completed += 1
533+ layer_idx = getattr (named , "layer_index" , None )
534+ layer_label = f"Layer { layer_idx } " if layer_idx is not None else "Layer ?"
535+ finalize_pb .next ()
536+ finalize_pb .title (
537+ f"{ layer_label } Finalize { completed } /{ finalize_count } "
538+ ).subtitle (f"{ self .processor .name ()} : { module_label } " ).draw ()
539+
405540 emit_device_telemetry (
406541 "weight_only_finalize_subset" ,
407542 module_count = len (finalize_tasks ),
408543 target_devices = [target_device for _ , _ , target_device in finalize_tasks ],
409544 parallel = use_parallel_finalize ,
410545 )
411546
412- if not use_parallel_finalize :
413- for named , active_qcfg , _target_device in finalize_tasks :
414- self ._finalize_quantized_module (named , active_qcfg )
415- return
547+ try :
548+ if not use_parallel_finalize :
549+ for named , active_qcfg , _target_device in finalize_tasks :
550+ module_label = self ._finalize_quantized_module (named , active_qcfg )
551+ _advance_finalize_progress (named , module_label )
552+ return
416553
417- futures = [
418- DEVICE_THREAD_POOL .submit (
419- target_device ,
420- self ._finalize_quantized_module ,
421- named ,
422- active_qcfg ,
423- )
424- for named , active_qcfg , target_device in finalize_tasks
425- ]
426- for future in as_completed (futures ):
427- future .result ()
554+ future_map = {
555+ DEVICE_THREAD_POOL .submit (
556+ target_device ,
557+ self ._finalize_quantized_module ,
558+ named ,
559+ active_qcfg ,
560+ ): named
561+ for named , active_qcfg , target_device in finalize_tasks
562+ }
563+ for future in as_completed (future_map ):
564+ named = future_map [future ]
565+ module_label = future .result ()
566+ _advance_finalize_progress (named , module_label )
567+ finally :
568+ finalize_pb .close ()
428569
429570 def _quantize_subset_modules (
430571 self ,
@@ -582,6 +723,16 @@ def loop(self, **kwargs):
582723 is_awq_quantize = False ,
583724 include_capture_only = False ,
584725 )
726+ full_layer_modules = getattr (self .gptq_model , "full_layer_modules" , None )
727+ if callable (full_layer_modules ):
728+ planning_layer_modules = full_layer_modules (
729+ model_config = self .gptq_model .model .config ,
730+ is_awq_quantize = False ,
731+ include_capture_only = False ,
732+ )
733+ else :
734+ planning_layer_modules = layer_modules
735+
585736 if not quant_config .true_sequential :
586737 layer_modules = [sum (layer_modules , [])]
587738
@@ -644,6 +795,11 @@ def loop(self, **kwargs):
644795 # transforms so quantization targets the final layer layout.
645796 materialize_model (module )
646797 full = find_modules (module , name = self .gptq_model .lm_head if is_lm_head_module else "" )
798+ layer_strategy_modules = None if is_lm_head_module else planning_layer_modules
799+ layer_strategy_device_map = self ._build_layer_strategy_device_map (
800+ full = full ,
801+ planning_layer_modules = layer_strategy_modules ,
802+ )
647803
648804 self .processor .collect_memory_info (layer_index )
649805 for subset_names in subsets :
@@ -660,6 +816,17 @@ def loop(self, **kwargs):
660816 if named is None :
661817 continue
662818
819+ preferred_device = layer_strategy_device_map .get (module_name )
820+ if preferred_device is not None :
821+ # Weight-only has no SubsetPlan, so store the same
822+ # preferred-device hint ModuleLooper would consume.
823+ named .state ["preferred_quant_device" ] = preferred_device
824+ emit_device_telemetry (
825+ "weight_only_strategy_preferred_device" ,
826+ module = named .full_name ,
827+ target_device = preferred_device ,
828+ )
829+
663830 if preprocessor is not None :
664831 preprocessor .preprocess (named )
665832 if isinstance (named .state .get ("auto_module_decoder" ), dict ):
0 commit comments