@@ -813,7 +813,11 @@ def update_quant_cfg_with_kv_cache_quant(
813813 return quant_cfg
814814
815815
816- class LayerActivationGettr :
816+ class _EarlyStopForwardError (Exception ):
817+ """Error to stop the forward pass after collection."""
818+
819+
820+ class LayerActivationCollector :
817821 """Helper class for collecting layer activations during forward passes.
818822
819823 This class allows for sequential layer calibration by
@@ -825,53 +829,44 @@ def __init__(self, model: nn.Module):
825829
826830 @staticmethod
827831 def _patch_and_initialize_layer (layer : torch .nn .Module , stop_after_collection : bool = False ):
828- """Patch a layer to collect inputs and outputs during forward passes."""
832+ """Patch a layer to collect inputs during forward passes."""
829833
830834 def _forward_w_data_collection (self , * args , ** kwargs ):
831- """Custom forward that collects inputs and outputs.
832-
833- Note: 'self' refers to the patched layer.
834- """
835- assert len (args ) >= 1
835+ # Note: 'self' refers to the patched layer.
836+ assert len (args ) >= 1 , (
837+ f"Expected at least 1 positional arg, got { len (args )} args and { list (kwargs .keys ())} kwargs"
838+ )
836839 # Only collect the inputs to the layer
837840 self .inputs .append ((args , kwargs ))
838- if getattr ( self , "_stop_after_collection" , False ) :
839- raise StopIteration ()
841+ if stop_after_collection :
842+ raise _EarlyStopForwardError () # Stop the forward pass after collection
840843
841844 bind_forward_method (layer , _forward_w_data_collection , "_original_forward" )
842845 layer .inputs = []
843- layer ._stop_after_collection = stop_after_collection
844846
845847 @staticmethod
846848 def _unpatch_and_cleanup_layer (layer : torch .nn .Module ):
847- """Restore a layer's original forward method and clean up."""
848- unpatch_forward_method (layer , "_original_forward" )
849- del layer .inputs
850- if hasattr (layer , "_stop_after_collection" ):
851- del layer ._stop_after_collection
849+ if hasattr (layer , "_original_forward" ):
850+ unpatch_forward_method (layer , "_original_forward" )
851+ if hasattr (layer , "inputs" ):
852+ del layer .inputs
852853
854+ @torch .no_grad ()
853855 def get_input_activations (self , layer : torch .nn .Module , forward_loop : ForwardLoop ) -> list :
854- """Collect input activations for a layer by running the forward loop.
855-
856- Propagation stops at the patched layer for each batch (saves compute by not running deeper layers),
857- but the forward_loop continues to process all batches.
858-
859- This function is typically used to collect input activations for the first decoder layer of the model.
860- """
861-
862- # Wrap model forward to catch StopIteration per-batch
856+ # Wrap model forward to catch _EarlyStopForwardError per-batch
863857 def _early_stop_forward (self , * args , ** kwargs ):
864858 try :
865859 return self ._original_forward (* args , ** kwargs )
866- except StopIteration :
860+ except _EarlyStopForwardError :
867861 return None # Stop propagation but allow next batch
868862
869- bind_forward_method (self .model , _early_stop_forward , "_original_forward" )
870- self ._patch_and_initialize_layer (layer , stop_after_collection = True )
871863 try :
864+ bind_forward_method (self .model , _early_stop_forward , "_original_forward" )
865+ self ._patch_and_initialize_layer (layer , stop_after_collection = True )
872866 forward_loop (self .model )
873867 inputs = layer .inputs .copy ()
874868 finally :
875869 self ._unpatch_and_cleanup_layer (layer )
876870 unpatch_forward_method (self .model , "_original_forward" )
871+
877872 return inputs
0 commit comments