 import torch
 import torch.nn as nn

+from modelopt.torch.utils import distributed as dist
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.network import (
     bind_forward_method,
@@ -77,6 +78,16 @@ def __init__(self, original: nn.Module):
         object.__setattr__(self, "_original", original)
         self._layerwise_calib = _LayerCalibState(mode="skip")

+    def __getattr__(self, name: str):
+        # Proxy non-special attribute lookups to the original layer so that
+        # parent-model code that accesses layer-level attributes (e.g.,
+        # NemotronH's ``block_type``) still works when the layer is replaced
+        # with a _SkipLayer.
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(object.__getattribute__(self, "_original"), name)
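+
+    # Illustrative lookup order (editor's sketch; ``block_type`` stands in for
+    # any model-specific attribute and is not defined on _SkipLayer):
+    #
+    #   skip = _SkipLayer(layer)
+    #   skip.forward        # resolved on _SkipLayer itself
+    #   skip.block_type     # misses on the wrapper -> proxied to layer.block_type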
+
     def forward(self, *args, **kwargs):
         return LayerActivationCollector._zeros_from_meta(
             self._original._layerwise_calib.output_meta
@@ -315,7 +326,13 @@ def _log_layer_summary(self, layer_idx: int):
             mode = layer._layerwise_calib.mode
             if mode in ("skip", "run", "capture"):
                 groups.setdefault(mode, []).append(i + 1)
-        parts = [f"{mode}: {groups[mode]}" for mode in ("skip", "run", "capture") if mode in groups]
+
+        parts = []
+        for mode in ("skip", "run", "capture"):
+            if mode not in groups:
+                continue
+            ids = groups[mode]
+            parts.append(f"{mode}: {len(ids)}" if mode == "skip" else f"{mode}: {ids}")
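+        # Example log line (editor's note, hypothetical 32-layer run): skipped
+        # layers are summarized by count, run/capture layers listed by index:
+        #   Calibrating layer 6/32 | skip: 29 | run: [4, 5] | capture: [6]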
         print_rank_0(f"Calibrating layer {layer_idx + 1}/{n} | {' | '.join(parts)}")

     @torch.no_grad()
@@ -489,6 +506,42 @@ def _save_layer(
     _write_manifest(checkpoint_dir, idx, num_layers)


+def needs_checkpoint_path_update(quant_cfg: dict) -> bool:
+    """Check if quant_cfg has a checkpoint_dir that should be auto-resolved to a unique subpath."""
+    algorithm = quant_cfg.get("algorithm")
+    if algorithm is None or isinstance(algorithm, str):
+        return False
+    return algorithm.get("checkpoint_dir") is not None
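+
+# Illustrative behavior (editor's note, hypothetical configs):
+#   needs_checkpoint_path_update({"algorithm": "max"})                       # False
+#   needs_checkpoint_path_update({"algorithm": {"method": "awq_lite"}})      # False
+#   needs_checkpoint_path_update({"algorithm": {"checkpoint_dir": "/tmp"}})  # True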
+
+
+def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict:
+    """Append a unique ``<model_name>_<config_hash>`` subdirectory to checkpoint_dir.
+
+    Allows a single recipe to be reused across models without checkpoint collisions.
+    Must only be called when :func:`needs_checkpoint_path_update` returns True.
+    """
+    import copy
+    import hashlib
+    from pathlib import Path
+
+    algorithm = quant_cfg["algorithm"]
+    base_dir = algorithm["checkpoint_dir"]
+
+    # Relative paths containing "/" (e.g. hub-style "org/model" IDs) are
+    # flattened to "org--model"; other paths keep only their final component.
+    name = model_path.rstrip("/")
+    if "/" in name and not os.path.isabs(name):
+        name = name.replace("/", "--")
+    else:
+        name = Path(name).name
+
+    # Short, stable hash of the full quant config so that different recipes
+    # never share a checkpoint directory.
+    config_hash = hashlib.sha256(
+        json.dumps(quant_cfg, sort_keys=True, default=str).encode()
+    ).hexdigest()[:8]
+
+    quant_cfg = copy.deepcopy(quant_cfg)
+    quant_cfg["algorithm"]["checkpoint_dir"] = os.path.join(base_dir, f"{name}_{config_hash}")
+    return quant_cfg
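+
+# Illustrative usage (editor's sketch; recipe values and model path are
+# hypothetical):
+#
+#   cfg = {"algorithm": {"method": "awq_lite", "checkpoint_dir": "/tmp/calib"}}
+#   if needs_checkpoint_path_update(cfg):
+#       cfg = resolve_checkpoint_dir(cfg, "meta-llama/Llama-3-8B")
+#   # cfg["algorithm"]["checkpoint_dir"] ->
+#   #   "/tmp/calib/meta-llama--Llama-3-8B_<8-char-hash>"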
+
+
 def detect_resume_point(checkpoint_dir: str) -> tuple[int, dict] | None:
     """Detect where to resume from an existing checkpoint directory.

@@ -512,9 +565,21 @@ class _CheckpointState:

     Handles both saving per-layer checkpoints during calibration and
     restoring from a previous partial run.
+
+    .. todo::
+        Support distributed checkpoint save/restore for FSDP2:
+        use ``torch.distributed.checkpoint`` (or save only from rank 0 + barrier)
+        and broadcast restored state to all ranks during resume.
     """

     def __init__(self, checkpoint_dir: str, num_layers: int, start_layer: int = 0):
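+        # Fail fast under multi-process launchers (editor's note): per-layer
+        # checkpoints are currently written without rank coordination (see the
+        # class-level todo), so only single-process runs are supported.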
+        if dist.is_initialized() and dist.size() > 1:
+            raise RuntimeError(
+                "Layerwise calibration checkpointing is not supported in "
+                "multi-process distributed jobs (e.g. FSDP2). "
+                "Use single-process calibration or disable checkpointing."
+            )
+
         self.checkpoint_dir = checkpoint_dir
         self.num_layers = num_layers
         self.start_layer = start_layer