3636from ..utils .device import get_device , get_device_new
3737from ..utils .looper_helpers import normalize_device_like
3838from ..utils .logger import live_renderables_suppressed , log_time_block , setup_logger
39- from ..utils .model import find_modules , get_module
39+ from ..utils .model import find_modules , get_layer_name , get_module
4040from ..utils .offload import offload_to_disk
4141from ..utils .torch import CPU , torch_empty_cache , torch_sync
4242from .stage_subset import SubsetPlan , build_layer_subset_plans , run_subset_stage
@@ -49,11 +49,11 @@ def _find_last_quantized_layer_index(
4949 looper : "ModuleLooper" ,
5050 * ,
5151 layer_modules : List [List [str ]],
52- layers_prefix : Optional [str ],
52+ layer_names : Optional [List [ str ] ],
5353 layer_count : int ,
5454) -> Optional [int ]:
5555 """Return the highest layer index whose tracked modules are not all dynamically skipped."""
56- if looper .gptq_model .quantize_config .lm_head or not layers_prefix :
56+ if looper .gptq_model .quantize_config .lm_head or not layer_names :
5757 return None
5858
5959 layer_module_names = {
@@ -67,8 +67,9 @@ def _find_last_quantized_layer_index(
6767
6868 last_quantized_layer_index = - 1
6969 for candidate_layer_index in range (layer_count ):
70+ layer_name = get_layer_name (layer_names , candidate_layer_index )
7071 for module_name in layer_module_names :
71- module_full_name = f"{ layers_prefix } . { candidate_layer_index } .{ module_name } "
72+ module_full_name = f"{ layer_name } .{ module_name } "
7273 # If at least one module in this layer is not dynamically excluded,
7374 # the layer still needs forward/quantization work.
7475 if looper .gptq_model .quantize_config .dynamic_get (layer_name = module_full_name ) != False :
@@ -387,7 +388,7 @@ def run_layer_stage(
387388 layers : List [torch .nn .Module ],
388389 layer_modules : List [List [str ]],
389390 planning_layer_modules : List [List [str ]],
390- layers_prefix : Optional [str ],
391+ layer_names : Optional [List [ str ] ],
391392 fallback ,
392393 shared_kv_cache_dict : Dict [int , torch .Tensor ],
393394 pb ,
@@ -403,7 +404,7 @@ def run_layer_stage(
403404 last_quantized_layer_index = _find_last_quantized_layer_index (
404405 looper ,
405406 layer_modules = layer_modules ,
406- layers_prefix = layers_prefix ,
407+ layer_names = layer_names ,
407408 layer_count = layer_count ,
408409 )
409410
@@ -436,10 +437,12 @@ def run_layer_stage(
436437 layer_title = "Quantizing lm_head"
437438 module = get_module (looper .gptq_model .model , key = looper .gptq_model .lm_head )
438439 pristine_group_module = None
440+ layer_name = ""
439441 else :
440442 layer_title = f"Quantizing layer { layer_index } of { layer_count - 1 } "
441443 module = layers [layer_index ]
442444 pristine_group_module = None
445+ layer_name = get_layer_name (layer_names , layer_index )
443446
444447 pb .title (layer_title ).subtitle ("" ).draw ()
445448 if durable_progress_logs :
@@ -483,8 +486,8 @@ def run_layer_stage(
483486
484487 layers [layer_index ] = module
485488
486- if layers_prefix :
487- layer_descriptor = f" { layers_prefix } . { layer_index } "
489+ if layer_name :
490+ layer_descriptor = layer_name
488491 else :
489492 layer_descriptor = str (layer_index )
490493
@@ -530,7 +533,7 @@ def run_layer_stage(
530533 full = full ,
531534 is_lm_head_module = is_lm_head_module ,
532535 layer_index = layer_index ,
533- layers_prefix = layers_prefix ,
536+ layers_prefix = layer_name ,
534537 fallback = fallback ,
535538 )
536539 if durable_progress_logs :
0 commit comments