Skip to content

Commit d67d8fa

Browse files
committed
Add progress bar to the submodule finalize
Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
1 parent 73fecd7 commit d67d8fa

2 files changed

Lines changed: 249 additions & 57 deletions

File tree

gptqmodel/looper/weight_only_looper.py

Lines changed: 222 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import os
1718
import threading
1819
import time
1920
from concurrent.futures import as_completed
@@ -30,6 +31,7 @@
3031
from ..models._const import CPU, SUPPORTS_MODULE_TYPES
3132
from ..nn_modules.converter import MODULE_CONVERTER_MAP
3233
from ..quantization.config import BitsAndBytesConfig, FP8Config, GGUFConfig, RTNConfig, VramStrategy
34+
from ..utils import has_gil_disabled
3335
from ..utils.device import get_device
3436
from ..utils.device_telemetry import emit_device_telemetry
3537
from ..utils.logger import log_time_block, setup_logger
@@ -58,10 +60,7 @@ def __init__(self, model: BaseQModel, processor: WeightOnlyProcessor):
5860
self.processor = processor
5961
self._quant_devices = self._resolve_quant_devices()
6062
self._quant_device_rr = 0
61-
self._dense_device_rr = 0
62-
self._moe_group_device_rr = 0
6363
self._module_device_map: Dict[str, torch.device] = {}
64-
self._moe_group_device_map: Dict[str, torch.device] = {}
6564
self._quant_device_lock = threading.Lock()
6665
self._resolve_strategy_device_pools()
6766

@@ -171,38 +170,143 @@ def _extract_moe_group_key(module_name: Optional[str]) -> Optional[str]:
171170
return None
172171
return f"{prefix}.experts.{expert_id}"
173172

174-
def _select_strategy_device_for_module(self, named_module: NamedModule) -> Optional[torch.device]:
175-
"""Apply dense/MoE VRAM strategy before falling back to generic round-robin."""
173+
@staticmethod
174+
def _collect_assignable_moe_group_keys(moe_groups: Dict[str, List[str]]) -> List[str]:
175+
"""Return expert families that should stay co-located on one device."""
176176

177-
module_name = getattr(named_module, "name", None) or getattr(named_module, "full_name", None)
178-
moe_group_key = self._extract_moe_group_key(module_name) or self._extract_moe_group_key(
179-
getattr(named_module, "full_name", None)
180-
)
177+
assignable_group_keys: List[str] = []
178+
for group_key, module_names in moe_groups.items():
179+
suffixes = {name.rsplit(".", 1)[-1] for name in module_names}
180+
if {"gate_proj", "up_proj"}.issubset(suffixes) or {"w1", "w3"}.issubset(suffixes):
181+
assignable_group_keys.append(group_key)
182+
return assignable_group_keys
181183

182-
if moe_group_key is not None and self._moe_vram_strategy_explicit and self._moe_quant_devices:
183-
cached = self._moe_group_device_map.get(moe_group_key)
184-
if cached is not None:
185-
return cached
186-
# MoE BALANCED spreads expert families across the MoE pool; each
187-
# expert family keeps all projections on the same device.
188-
if self._moe_vram_strategy == VramStrategy.BALANCED and len(self._moe_quant_devices) > 1:
189-
device = self._moe_quant_devices[self._moe_group_device_rr % len(self._moe_quant_devices)]
190-
self._moe_group_device_rr += 1
191-
else:
192-
device = self._moe_quant_devices[0]
193-
self._moe_group_device_map[moe_group_key] = device
194-
return device
184+
@staticmethod
185+
def _normalize_planning_module_name(module_name: str) -> str:
186+
"""Strip model-tree annotations so planning blocks match live module names."""
195187

196-
if self._dense_vram_strategy_explicit and self._dense_quant_devices:
197-
# Dense EXCLUSIVE pins attention/router/shared/dense MLP work to
198-
# the first dense device; BALANCED round-robins dense modules.
199-
if self._dense_vram_strategy == VramStrategy.BALANCED and len(self._dense_quant_devices) > 1:
200-
device = self._dense_quant_devices[self._dense_device_rr % len(self._dense_quant_devices)]
201-
self._dense_device_rr += 1
202-
return device
203-
return self._dense_quant_devices[0]
188+
return module_name.split(":", 1)[0]
204189

205-
return None
190+
def _collect_dense_groups(
191+
self,
192+
layer_candidate_names: List[str],
193+
layer_moe_group_key_by_name: Dict[str, Optional[str]],
194+
planning_layer_modules: Optional[List[List[str]]],
195+
) -> Dict[str, List[str]]:
196+
"""Collect dense modules into model-tree-defined calculation groups."""
197+
198+
remaining_dense_names = [
199+
module_name
200+
for module_name in layer_candidate_names
201+
if layer_moe_group_key_by_name.get(module_name) is None
202+
]
203+
remaining_dense_set = set(remaining_dense_names)
204+
dense_groups: Dict[str, List[str]] = {}
205+
206+
if planning_layer_modules:
207+
for block_index, block in enumerate(planning_layer_modules):
208+
block_dense_names: List[str] = []
209+
block_seen = set()
210+
for block_entry in block:
211+
module_name = self._normalize_planning_module_name(block_entry)
212+
if module_name in block_seen or module_name not in remaining_dense_set:
213+
continue
214+
block_seen.add(module_name)
215+
if layer_moe_group_key_by_name.get(module_name) is not None:
216+
continue
217+
block_dense_names.append(module_name)
218+
219+
if block_dense_names:
220+
dense_groups[f"planning:{block_index}"] = block_dense_names
221+
for module_name in block_dense_names:
222+
remaining_dense_set.discard(module_name)
223+
224+
for module_name in remaining_dense_names:
225+
if module_name not in remaining_dense_set:
226+
continue
227+
dense_groups[module_name] = [module_name]
228+
remaining_dense_set.discard(module_name)
229+
230+
return dense_groups
231+
232+
def _build_layer_strategy_device_map(
233+
self,
234+
*,
235+
full: Dict[str, torch.nn.Module],
236+
planning_layer_modules: Optional[List[List[str]]],
237+
) -> Dict[str, torch.device]:
238+
"""Build the dense/MoE preferred-device map for one layer."""
239+
240+
dense_strategy_active = self._dense_vram_strategy_explicit
241+
moe_strategy_active = self._moe_vram_strategy_explicit
242+
if not dense_strategy_active and not moe_strategy_active:
243+
return {}
244+
245+
layer_candidate_names = list(full.keys())
246+
moe_group_key_by_name = {
247+
module_name: self._extract_moe_group_key(module_name)
248+
for module_name in layer_candidate_names
249+
}
250+
moe_groups: Dict[str, List[str]] = {}
251+
for module_name, group_key in moe_group_key_by_name.items():
252+
if group_key is not None:
253+
moe_groups.setdefault(group_key, []).append(module_name)
254+
255+
dense_groups = self._collect_dense_groups(
256+
layer_candidate_names,
257+
moe_group_key_by_name,
258+
planning_layer_modules,
259+
)
260+
preferred_devices: Dict[str, torch.device] = {}
261+
dense_devices = [
262+
device for device in self._dense_quant_devices
263+
if device is not None and getattr(device, "type", None) != "cpu"
264+
] or list(self._dense_quant_devices)
265+
moe_devices = [
266+
device for device in self._moe_quant_devices
267+
if device is not None and getattr(device, "type", None) != "cpu"
268+
] or list(self._moe_quant_devices)
269+
270+
if dense_strategy_active and dense_groups and dense_devices:
271+
dense_group_keys = list(dense_groups.keys())
272+
for group_index, group_key in enumerate(dense_group_keys):
273+
# Dense EXCLUSIVE pins the serial path to the first dense
274+
# device; BALANCED spreads model-tree calculation groups.
275+
target_device = (
276+
dense_devices[group_index % len(dense_devices)]
277+
if self._dense_vram_strategy == VramStrategy.BALANCED and len(dense_devices) > 1
278+
else dense_devices[0]
279+
)
280+
for module_name in dense_groups[group_key]:
281+
preferred_devices[module_name] = target_device
282+
283+
if moe_strategy_active and moe_groups and moe_devices:
284+
assignable_group_keys = self._collect_assignable_moe_group_keys(moe_groups)
285+
for group_index, group_key in enumerate(assignable_group_keys):
286+
# MoE BALANCED spreads expert families across the MoE pool;
287+
# every projection in one expert family stays co-located.
288+
target_device = (
289+
moe_devices[group_index % len(moe_devices)]
290+
if self._moe_vram_strategy == VramStrategy.BALANCED and len(moe_devices) > 1
291+
else moe_devices[0]
292+
)
293+
for module_name in moe_groups[group_key]:
294+
preferred_devices[module_name] = target_device
295+
296+
gil_env = os.environ.get("PYTHON_GIL")
297+
gil_disabled = has_gil_disabled()
298+
free_threaded_parallel_quant_eligible = bool(gil_disabled and len(self._moe_quant_devices) > 0)
299+
log.info(
300+
"ModuleLooper: MoE quant runtime dense_pool=%s moe_pool=%s "
301+
"PYTHON_GIL=%s gil_disabled=%s free_threaded_parallel_quant_eligible=%s",
302+
dense_devices,
303+
moe_devices,
304+
gil_env,
305+
gil_disabled,
306+
free_threaded_parallel_quant_eligible,
307+
)
308+
309+
return preferred_devices
206310

207311
def _assign_quant_device_for_module(self, named_module: NamedModule, fallback_device: torch.device) -> torch.device:
208312
"""Pick and memoize the quantization device for one named module."""
@@ -220,21 +324,14 @@ def _assign_quant_device_for_module(self, named_module: NamedModule, fallback_de
220324

221325
preferred_device = normalize_device_like(named_module.state.get("preferred_quant_device"))
222326
if preferred_device is not None and any(device == preferred_device for device in self._quant_devices):
327+
# Dense/MoE strategy placement is planned before this point,
328+
# matching ModuleLooper's preferred-device handoff.
223329
device = preferred_device
224330
emit_device_telemetry(
225331
"weight_only_quant_device_preferred_hint",
226332
module=key,
227333
target_device=device,
228334
)
229-
elif (strategy_device := self._select_strategy_device_for_module(named_module)) is not None:
230-
device = strategy_device
231-
emit_device_telemetry(
232-
"weight_only_quant_device_strategy",
233-
module=key,
234-
target_device=device,
235-
dense_strategy=self._dense_vram_strategy.value,
236-
moe_strategy=self._moe_vram_strategy.value,
237-
)
238335
elif len(self._quant_devices) <= 1:
239336
device = self._quant_devices[0]
240337
else:
@@ -402,29 +499,73 @@ def _finalize_subset_modules(
402499
}
403500
use_parallel_finalize = len(finalize_tasks) > 1 and len(unique_targets) > 1
404501

502+
finalize_count = len(finalize_tasks)
503+
finalize_pb = log.pb(range(finalize_count)).manual().set(show_left_steps=False)
504+
known_layers = sorted(
505+
{
506+
getattr(named, "layer_index", None)
507+
for named, _, _ in finalize_tasks
508+
if getattr(named, "layer_index", None) is not None
509+
}
510+
)
511+
includes_unknown = any(getattr(named, "layer_index", None) is None for named, _, _ in finalize_tasks)
512+
layer_heading = "Layer ?"
513+
if known_layers:
514+
sample_layers = ", ".join(str(idx) for idx in known_layers[:3])
515+
if len(known_layers) > 3:
516+
sample_layers += ", ..."
517+
suffix = ", ?" if includes_unknown else ""
518+
prefix = "Layer" if len(known_layers) == 1 else "Layers"
519+
layer_heading = f"{prefix} {sample_layers}{suffix}"
520+
elif includes_unknown:
521+
layer_heading = "Layer ?"
522+
523+
finalize_pb.title(
524+
f"{layer_heading} Submodule finalize 0/{finalize_count}"
525+
).subtitle("Waiting for completions...").draw()
526+
527+
completed = 0
528+
529+
def _advance_finalize_progress(named: NamedModule, module_label: str) -> None:
530+
nonlocal completed
531+
532+
completed += 1
533+
layer_idx = getattr(named, "layer_index", None)
534+
layer_label = f"Layer {layer_idx}" if layer_idx is not None else "Layer ?"
535+
finalize_pb.next()
536+
finalize_pb.title(
537+
f"{layer_label} Finalize {completed}/{finalize_count}"
538+
).subtitle(f"{self.processor.name()}: {module_label}").draw()
539+
405540
emit_device_telemetry(
406541
"weight_only_finalize_subset",
407542
module_count=len(finalize_tasks),
408543
target_devices=[target_device for _, _, target_device in finalize_tasks],
409544
parallel=use_parallel_finalize,
410545
)
411546

412-
if not use_parallel_finalize:
413-
for named, active_qcfg, _target_device in finalize_tasks:
414-
self._finalize_quantized_module(named, active_qcfg)
415-
return
547+
try:
548+
if not use_parallel_finalize:
549+
for named, active_qcfg, _target_device in finalize_tasks:
550+
module_label = self._finalize_quantized_module(named, active_qcfg)
551+
_advance_finalize_progress(named, module_label)
552+
return
416553

417-
futures = [
418-
DEVICE_THREAD_POOL.submit(
419-
target_device,
420-
self._finalize_quantized_module,
421-
named,
422-
active_qcfg,
423-
)
424-
for named, active_qcfg, target_device in finalize_tasks
425-
]
426-
for future in as_completed(futures):
427-
future.result()
554+
future_map = {
555+
DEVICE_THREAD_POOL.submit(
556+
target_device,
557+
self._finalize_quantized_module,
558+
named,
559+
active_qcfg,
560+
): named
561+
for named, active_qcfg, target_device in finalize_tasks
562+
}
563+
for future in as_completed(future_map):
564+
named = future_map[future]
565+
module_label = future.result()
566+
_advance_finalize_progress(named, module_label)
567+
finally:
568+
finalize_pb.close()
428569

429570
def _quantize_subset_modules(
430571
self,
@@ -582,6 +723,16 @@ def loop(self, **kwargs):
582723
is_awq_quantize=False,
583724
include_capture_only=False,
584725
)
726+
full_layer_modules = getattr(self.gptq_model, "full_layer_modules", None)
727+
if callable(full_layer_modules):
728+
planning_layer_modules = full_layer_modules(
729+
model_config=self.gptq_model.model.config,
730+
is_awq_quantize=False,
731+
include_capture_only=False,
732+
)
733+
else:
734+
planning_layer_modules = layer_modules
735+
585736
if not quant_config.true_sequential:
586737
layer_modules = [sum(layer_modules, [])]
587738

@@ -644,6 +795,11 @@ def loop(self, **kwargs):
644795
# transforms so quantization targets the final layer layout.
645796
materialize_model(module)
646797
full = find_modules(module, name=self.gptq_model.lm_head if is_lm_head_module else "")
798+
layer_strategy_modules = None if is_lm_head_module else planning_layer_modules
799+
layer_strategy_device_map = self._build_layer_strategy_device_map(
800+
full=full,
801+
planning_layer_modules=layer_strategy_modules,
802+
)
647803

648804
self.processor.collect_memory_info(layer_index)
649805
for subset_names in subsets:
@@ -660,6 +816,17 @@ def loop(self, **kwargs):
660816
if named is None:
661817
continue
662818

819+
preferred_device = layer_strategy_device_map.get(module_name)
820+
if preferred_device is not None:
821+
# Weight-only has no SubsetPlan, so store the same
822+
# preferred-device hint ModuleLooper would consume.
823+
named.state["preferred_quant_device"] = preferred_device
824+
emit_device_telemetry(
825+
"weight_only_strategy_preferred_device",
826+
module=named.full_name,
827+
target_device=preferred_device,
828+
)
829+
663830
if preprocessor is not None:
664831
preprocessor.preprocess(named)
665832
if isinstance(named.state.get("auto_module_decoder"), dict):

0 commit comments

Comments
 (0)