@@ -28,7 +28,7 @@
 from tqdm import tqdm

 from modelopt.torch.opt.searcher import ForwardLoop
-from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector
+from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
@@ -1563,7 +1563,15 @@ def sequential_calibrate(
     Runs the full model forward per layer but patches decoder layers with a
     skip / run / capture strategy so that inter-layer logic in parent modules
     (e.g. mask construction) executes naturally without model-specific hooks.
+
+    If ``checkpoint_dir`` is passed (via ``calib_kwargs``), per-layer checkpoints
+    are saved after each layer completes. On restart, calibration resumes from
+    the last completed layer.
     """
+    from modelopt.torch.quantization.utils.layerwise_calib import _CheckpointState
+
+    checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None)
+
     if forward_loop is None:
         raise ValueError(
             "forward_loop must not be None for sequential calibration. "
@@ -1577,30 +1585,77 @@ def sequential_calibrate(
15771585 "Sequential calibration requires a model with identifiable transformer layers."
15781586 )
15791587
1580- print_rank_0 (f"Sequential calibration: Found { len (transformer_layers )} transformer layers" )
1588+ num_layers = len (transformer_layers )
1589+ print_rank_0 (f"Sequential calibration: Found { num_layers } transformer layers" )
1590+
1591+ ckpt = _CheckpointState .from_folder (checkpoint_dir , num_layers )
1592+ start_layer = ckpt .start_layer if ckpt else 0
15811593
15821594 input_getter = LayerActivationCollector (model )
15831595 input_getter ._patch_all_layers (decoder_layers = transformer_layers )
15841596
1597+ resumed_inputs = ckpt .setup_resume (transformer_layers ) if ckpt and start_layer > 0 else None
1598+
15851599 try :
1586- for layer_idx , layer in enumerate (transformer_layers ):
1587- print_rank_0 (f"Calibrating layer { layer_idx + 1 } /{ len (transformer_layers )} " )
1588- layer_inputs = input_getter .get_input_activations (layer , forward_loop )
1600+ for layer_idx , layer in enumerate (list (transformer_layers )):
1601+ if layer_idx < start_layer :
1602+ continue
1603+
1604+ layer_inputs = _get_layer_inputs (
1605+ layer_idx , start_layer , resumed_inputs , layer , input_getter , forward_loop
1606+ )
1607+ if ckpt :
1608+ ckpt .save_prev (transformer_layers , layer_inputs )
15891609
15901610 def _layer_forward_loop (m , _inputs = layer_inputs ):
15911611 for args , kwargs_input in _inputs :
15921612 m (* args , ** kwargs_input )
15931613
15941614 calib_func (layer , _layer_forward_loop , ** calib_kwargs )
15951615
1616+ if ckpt :
1617+ ckpt .stash (layer_idx , layer , model )
1618+
15961619 del layer_inputs
15971620 torch .cuda .empty_cache ()
1621+
1622+ if ckpt :
1623+ ckpt .save_last (transformer_layers )
15981624 finally :
15991625 input_getter ._unpatch_all_layers ()
16001626
1627+ if ckpt :
1628+ ckpt .full_restore (transformer_layers , model )
1629+
16011630 print_rank_0 ("Sequential calibration completed" )
16021631
16031632
1633+ def _get_layer_inputs (
1634+ layer_idx : int ,
1635+ start_layer : int ,
1636+ resumed_inputs : list | None ,
1637+ layer : nn .Module ,
1638+ input_getter : LayerActivationCollector ,
1639+ forward_loop : ForwardLoop ,
1640+ ) -> list :
1641+ """Get inputs for a layer, using resumed_inputs for the first resumed layer."""
1642+ if layer_idx == start_layer and resumed_inputs is not None :
1643+ print_rank_0 (f"Calibrating layer { layer_idx + 1 } (resumed)" )
1644+ # Manually set skip mode on all already-calibrated layers (output_meta
1645+ # was loaded by setup_resume). Don't call _set_layer_states which
1646+ # assumes the normal sequential progression with collected_inputs.
1647+ assert input_getter ._decoder_layers is not None
1648+ for i in range (start_layer ):
1649+ input_getter ._swap_to_dummy (i )
1650+ # Seed collected_inputs so the next _set_layer_states call can
1651+ # transition this layer to "run" mode.
1652+ layer ._seq_calib .collected_inputs = resumed_inputs
1653+ layer ._seq_calib .mode = "original"
1654+ return resumed_inputs
1655+
1656+ return input_getter .get_input_activations (layer , forward_loop )
1657+
1658+
16041659@torch .no_grad ()
16051660def gptq (
16061661 model : nn .Module ,
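This hunk relies on `LayerActivationCollector` to record each layer's inputs during a full-model forward and on the `_layer_forward_loop` closure to replay them into the per-layer calibration. The collector's internals (dummy swapping, `output_meta`, `_seq_calib` state) are not shown in the diff; the following is only a generic sketch of the capture-and-replay idea using a plain forward pre-hook, with no claim that modelopt implements it this way:

```python
# Generic capture-and-replay sketch; not the actual LayerActivationCollector.
import torch
import torch.nn as nn


def capture_layer_inputs(layer: nn.Module, run_full_model) -> list:
    """Record every (args, kwargs) pair the layer receives during one full forward."""
    captured = []

    def pre_hook(module, args, kwargs):
        # Keep detached references to exactly what the layer saw.
        captured.append(
            (
                tuple(a.detach() if torch.is_tensor(a) else a for a in args),
                {k: v.detach() if torch.is_tensor(v) else v for k, v in kwargs.items()},
            )
        )

    handle = layer.register_forward_pre_hook(pre_hook, with_kwargs=True)
    try:
        run_full_model()  # the forward_loop drives the whole model once
    finally:
        handle.remove()
    return captured


def replay(layer: nn.Module, captured: list) -> None:
    # Mirrors _layer_forward_loop: feed every recorded (args, kwargs) pair back
    # through the single layer so its quantizer calibrators see real activations.
    for args, kwargs in captured:
        layer(*args, **kwargs)
```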
@@ -1663,8 +1718,10 @@ def gptq(
         handle.cleanup()

     print_rank_0("Updating weights using GPTQ algorithm...")
+    name_to_module = dict(model.named_modules())
     for handle in gptq_handles.values():
-        handle.update_weights(block_size, perc_damp)
+        with enable_weight_access_and_writeback(handle.module, model, name_to_module):
+            handle.update_weights(block_size, perc_damp)
         handle.free()
     del gptq_handles

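The GPTQ loop change precomputes `name_to_module` once (so `model.named_modules()` is not re-walked for every handle) and wraps each in-place weight update in `enable_weight_access_and_writeback`. The behavior of that context manager is not shown in this diff; the snippet below is only a schematic of the general access-then-write-back pattern, with hypothetical `get_weight` / `put_weight` callbacks, not modelopt's implementation:

```python
# Schematic stand-in for a weight access/writeback context manager.
from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def weight_access_and_writeback(module: nn.Module, get_weight, put_weight):
    """Expose a plain parameter for in-place edits, then persist the edits on exit.

    get_weight / put_weight are hypothetical callbacks that fetch and store the
    module's true weight representation (e.g. a compressed or sharded buffer).
    """
    original = module.weight
    module.weight = nn.Parameter(get_weight(module))
    try:
        yield module  # in-place updates (e.g. GPTQ) happen here
    finally:
        put_weight(module, module.weight.detach())
        module.weight = original
```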