@@ -227,22 +227,20 @@ def _patched_forward(self, *args, **kwargs):
227227 f"Layer { info .name } is in 'run' mode but has no cached inputs to replay."
228228 )
229229 real_args , real_kwargs = info .cached_inputs .popleft ()
230- if (
231- real_args
232- and isinstance (real_args [0 ], torch .Tensor )
233- and real_args [0 ].device .type == "cpu"
234- ):
235- device = get_module_device (self )
236- real_args = _move_to_device (real_args , device )
237- real_kwargs = _move_to_device (real_kwargs , device )
230+ # Captured inputs are stored on CPU (see "capture" branch); move
231+ # back to the layer's device for replay. `_move_to_device` is a
232+ # no-op for tensors already on `device`.
233+ device = get_module_device (self )
234+ real_args = _move_to_device (real_args , device )
235+ real_kwargs = _move_to_device (real_kwargs , device )
238236 output = self ._original_forward (* real_args , ** real_kwargs )
239237 info .output_meta = LayerActivationCollector ._extract_output_meta (output )
240238 return output
241239
242240 if info .mode == "capture" :
243241 # Offload captured inputs to CPU at append time. For early layers
244- # on a single GPU (e.g. layer 0– 2 on GPU 0 with seq_device_map),
245- # accumulating thousands of batches' worth of (bs × seq × hidden)
242+ # on a single GPU (e.g. layer 0- 2 on GPU 0 with seq_device_map),
243+ # accumulating thousands of batches' worth of (bs x seq x hidden)
246244 # activations on-device saturates that GPU during the capture loop
247245 # and OOMs before _set_layer_states gets a chance to move them.
248246 # The "run" branch already handles CPU-resident inputs (see the
@@ -333,11 +331,8 @@ def _set_layer_states(self, layer_idx: int):
333331 "was called for every preceding layer in order."
334332 )
335333 prev .mode = "run"
336- cpu = torch .device ("cpu" )
337- prev .cached_inputs = deque (
338- (_move_to_device (args , cpu ), _move_to_device (kwargs , cpu ))
339- for args , kwargs in prev .collected_inputs
340- )
334+ # Inputs are already CPU-resident at capture time (see _patched_forward).
335+ prev .cached_inputs = deque (prev .collected_inputs )
341336 prev .collected_inputs = []
342337
343338 cur = self ._decoder_layers [layer_idx ]._layerwise_calib
@@ -534,9 +529,6 @@ def _save_layer(
534529 torch .save (output_meta , os .path .join (d , "output_meta.pt" ))
535530 if next_inputs is not None :
536531 torch .save (next_inputs , os .path .join (d , "next_inputs.pt" ))
537- amax_state = {k : v for k , v in weights .items () if "_amax" in k }
538- if amax_state :
539- torch .save (amax_state , os .path .join (d , "quantizer_amaxes.pt" ))
540532 _write_manifest (checkpoint_dir , idx , num_layers )
541533
542534
@@ -635,17 +627,8 @@ def setup_resume(self, layers: nn.ModuleList) -> list | None:
635627 # Keep on CPU — _patched_forward's run mode moves each entry to device on pop.
636628 return next_inputs
637629
638- def full_restore (
639- self , layers : nn .ModuleList , model : nn .Module , restore_weights : bool = True
640- ) -> None :
641- """Restore weights and quantizer state for layers 0..K-1 after the calibration loop.
642-
643- Args:
644- restore_weights: If False, skip reloading ``weights.pt`` and load only the
645- ``_amax`` values (from ``quantizer_amaxes.pt`` or filtered from ``weights.pt``).
646- Set to False for calibration algorithms (max, MSE) that never modify weights
647- to avoid re-reading gigabytes of unchanged expert weights from disk.
648- """
630+ def full_restore (self , layers : nn .ModuleList , model : nn .Module ) -> None :
631+ """Restore weights and quantizer state for layers 0..K-1 after the calibration loop."""
649632 from modelopt .torch .quantization .config import QuantizeConfig
650633 from modelopt .torch .quantization .conversion import restore_quantizer_state
651634 from modelopt .torch .quantization .utils .core_utils import enable_weight_access_and_writeback
@@ -671,31 +654,13 @@ def full_restore(
671654 map_location = "cpu" ,
672655 weights_only = False ,
673656 )
657+ weights = torch .load (
658+ os .path .join (d , "weights.pt" ),
659+ map_location = "cpu" ,
660+ weights_only = False ,
661+ )
674662 restore_quantizer_state (layer , dummy_config , {"quantizer_state" : qstate })
675- if restore_weights :
676- weights = torch .load (
677- os .path .join (d , "weights.pt" ),
678- map_location = "cpu" ,
679- weights_only = False ,
680- )
681- layer .load_state_dict (weights , strict = False , assign = False )
682- else :
683- # Load only _amax entries — skip gigabytes of unchanged expert weights.
684- # Use map_location="cpu" to get fresh CPU tensors (no storage_offset).
685- # _export_fused_experts moves _amax to the weight device on demand.
686- amax_path = os .path .join (d , "quantizer_amaxes.pt" )
687- if os .path .exists (amax_path ):
688- amaxes = torch .load (amax_path , map_location = "cpu" , weights_only = False )
689- else :
690- # Legacy checkpoint: filter _amax entries from the full weights.pt.
691- weights = torch .load (
692- os .path .join (d , "weights.pt" ),
693- map_location = "cpu" ,
694- weights_only = False ,
695- )
696- amaxes = {k : v for k , v in weights .items () if "_amax" in k }
697- if amaxes :
698- layer .load_state_dict (amaxes , strict = False , assign = True )
663+ layer .load_state_dict (weights , strict = False , assign = False )
699664
700665 print_rank_0 (f"Checkpoint: restored { self .start_layer } previously calibrated layers" )
701666
0 commit comments