 from torch.distributed.tensor import Replicate

 from modelopt.torch.utils import get_unwrapped_name, print_rank_0
+from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

 if TYPE_CHECKING:
     from collections.abc import Generator

+    from modelopt.torch.opt.searcher import ForwardLoop
+
 __all__ = [
     "EXPORT_MODE",
     "convert_quantization_axis_to_reduce_axis",
@@ -808,3 +811,69 @@ def update_quant_cfg_with_kv_cache_quant(
     quant_cfg["algorithm"] = "max"
     print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}")
     return quant_cfg
+
+
+class LayerActivationGetter:
+    """Helper class for collecting layer activations during forward passes.
+
+    This class enables sequential layer calibration by patching layers to capture
+    their inputs during forward passes.
+    """
+
+    def __init__(self, model: nn.Module):
+        self.model = model
+
+    @staticmethod
+    def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False):
+        """Patch a layer to collect its inputs during forward passes."""
+
+        def _forward_w_data_collection(self, *args, **kwargs):
+            """Custom forward that records the layer's inputs.
+
+            Note: 'self' refers to the patched layer.
+            """
+            assert len(args) >= 1, "expected at least one positional input"
+            # Only collect the inputs to the layer
+            self.inputs.append((args, kwargs))
+            if getattr(self, "_stop_after_collection", False):
+                raise StopIteration()
+            # Otherwise run the original forward so downstream layers still execute.
+            return self._original_forward(*args, **kwargs)
+
+        bind_forward_method(layer, _forward_w_data_collection, "_original_forward")
+        layer.inputs = []
+        layer._stop_after_collection = stop_after_collection
+
+    @staticmethod
+    def _unpatch_and_cleanup_layer(layer: torch.nn.Module):
+        """Restore a layer's original forward method and clean up."""
+        unpatch_forward_method(layer, "_original_forward")
+        del layer.inputs
+        if hasattr(layer, "_stop_after_collection"):
+            del layer._stop_after_collection
+
+    def get_input_activations(self, layer: torch.nn.Module, forward_loop: "ForwardLoop") -> list:
+        """Collect input activations for a layer by running the forward loop.
+
+        Propagation stops at the patched layer for each batch (saving compute by not
+        running deeper layers), but the forward_loop still processes all batches.
+
+        This is typically used to collect the input activations of the model's first decoder layer.
+        """
+
+        # Wrap model forward to catch StopIteration per-batch
+        def _early_stop_forward(self, *args, **kwargs):
+            try:
+                return self._original_forward(*args, **kwargs)
+            except StopIteration:
+                return None  # Stop propagation but allow next batch
+
+        bind_forward_method(self.model, _early_stop_forward, "_original_forward")
+        self._patch_and_initialize_layer(layer, stop_after_collection=True)
+        try:
+            forward_loop(self.model)
+            inputs = layer.inputs.copy()
+        finally:
+            self._unpatch_and_cleanup_layer(layer)
+            unpatch_forward_method(self.model, "_original_forward")
+        return inputs
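
A minimal usage sketch of the new helper (not part of the diff: `model`, `first_decoder_layer`, and `forward_loop` are placeholders, where `forward_loop` is any callable matching the `ForwardLoop` protocol that feeds calibration batches through the model):

    getter = LayerActivationGetter(model)
    first_decoder_layer = model.layers[0]  # placeholder: first decoder block of the model
    # One (args, kwargs) entry per calibration batch, captured at the layer boundary;
    # deeper layers are skipped for each batch while the loop still sees every batch.
    batched_inputs = getter.get_input_activations(first_decoder_layer, forward_loop)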