@@ -824,8 +824,43 @@ class LayerActivationCollector:
824824 patching layers to capture inputs/outputs during forward passes
825825 """
826826
    # Class-level registries shared by all collector instances. Each entry is a
    # (predicate, factory) pair registered via the register_* classmethods below;
    # predicates take the model and return bool, and the first matching entry wins.
    # NOTE(review): entries are plain callables typed as Any — presumably
    # Callable[[nn.Module], ...]; confirm against the registration call sites.
    _next_layer_input_support: list[tuple[Any, Any]] = []
    # (is_supported, discoverer) pairs used by get_decoder_layers to locate a
    # model's decoder layers for sequential calibration.
    _decoder_layer_support: list[tuple[Any, Any]] = []
829+
    def __init__(self, model: nn.Module):
        """Bind the collector to *model* and reset per-layer caching state."""
        self.model = model
        # Cache of the most recently processed layer and the inputs that were
        # captured for it; get_input_activations uses this pair to derive the
        # next layer's inputs without re-running the full forward loop.
        self._previous_layer = None
        self._previous_layer_inputs = None
834+
835+ @staticmethod
836+ def get_decoder_layers (model : nn .Module ) -> nn .ModuleList | None :
837+ """Return decoder layers supported by sequential calibration."""
838+ for is_supported , discoverer in LayerActivationCollector ._decoder_layer_support :
839+ if not is_supported (model ):
840+ continue
841+ decoder_layers = discoverer (model )
842+ if decoder_layers is not None :
843+ return decoder_layers
844+ return None
845+
846+ @staticmethod
847+ def is_supported (model : nn .Module ) -> bool :
848+ """Whether the model supports decoder-layer sequential calibration."""
849+ return LayerActivationCollector .get_decoder_layers (model ) is not None
850+
851+ @classmethod
852+ def register_next_layer_input_support (
853+ cls , is_supported : Any , build_next_layer_inputs_hook : Any
854+ ):
855+ entry = (is_supported , build_next_layer_inputs_hook )
856+ if entry not in cls ._next_layer_input_support :
857+ cls ._next_layer_input_support .append (entry )
858+
859+ @classmethod
860+ def register_decoder_layer_support (cls , is_supported : Any , discoverer : Any ):
861+ entry = (is_supported , discoverer )
862+ if entry not in cls ._decoder_layer_support :
863+ cls ._decoder_layer_support .append (entry )
829864
830865 @staticmethod
831866 def _patch_and_initialize_layer (layer : torch .nn .Module , stop_after_collection : bool = False ):
@@ -851,8 +886,15 @@ def _unpatch_and_cleanup_layer(layer: torch.nn.Module):
851886 if hasattr (layer , "inputs" ):
852887 del layer .inputs
853888
889+ def _resolve_next_layer_inputs_hook (self ):
890+ for is_supported , build_next_layer_inputs_hook in self ._next_layer_input_support :
891+ if not is_supported (self .model ):
892+ continue
893+ return build_next_layer_inputs_hook (self .model )
894+ return None
895+
854896 @torch .no_grad ()
855- def get_input_activations (self , layer : torch .nn .Module , forward_loop : ForwardLoop ) -> list :
897+ def _collect_input_activations (self , layer : torch .nn .Module , forward_loop : ForwardLoop ) -> list :
856898 # Wrap model forward to catch _EarlyStopForward per-batch
857899 def _early_stop_forward (self , * args , ** kwargs ):
858900 try :
@@ -870,3 +912,19 @@ def _early_stop_forward(self, *args, **kwargs):
870912 unpatch_forward_method (self .model , "_original_forward" )
871913
872914 return inputs
915+
916+ @torch .no_grad ()
917+ def get_input_activations (self , layer : torch .nn .Module , forward_loop : ForwardLoop ) -> list :
918+ is_first_layer = self ._previous_layer is None or self ._previous_layer_inputs is None
919+ if is_first_layer :
920+ inputs = self ._collect_input_activations (layer , forward_loop )
921+ else :
922+ next_layer_inputs_hook = self ._resolve_next_layer_inputs_hook ()
923+ if next_layer_inputs_hook is None :
924+ inputs = self ._collect_input_activations (layer , forward_loop )
925+ else :
926+ inputs = next_layer_inputs_hook (self ._previous_layer , self ._previous_layer_inputs )
927+
928+ self ._previous_layer = layer
929+ self ._previous_layer_inputs = inputs
930+ return inputs
0 commit comments