Commit a938963: clean up

Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
1 parent 6578dff
4 files changed: 35 additions and 31 deletions

modelopt/torch/quantization/config.py (9 additions, 0 deletions)

```diff
@@ -1097,6 +1097,15 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
         title="This field specifies the name of the calibration algorithm. If None, no calibration is performed.",
     )
 
+    use_sequential: bool = ModeloptField(
+        default=False,
+        title="Enable sequential layer-by-layer calibration.",
+        description=(
+            "If True, the calibration algorithm is applied sequentially to each decoder block. "
+            "Outputs from one layer become inputs to the next, reducing memory usage for large models."
+        ),
+    )
+
 
 class MaxCalibConfig(QuantizeAlgorithmConfig):
     """The config for max calibration algorithm.
```

modelopt/torch/quantization/model_calib.py (3 additions, 4 deletions)

```diff
@@ -28,7 +28,7 @@
 from tqdm import tqdm
 
 from modelopt.torch.opt.searcher import ForwardLoop
-from modelopt.torch.quantization.utils import LayerActivationGettr
+from modelopt.torch.quantization.utils import LayerActivationCollector
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
 from modelopt.torch.utils.network import (
@@ -1850,10 +1850,9 @@ def sequential_calibrate(
 
     print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers")
 
-    gettr = LayerActivationGettr(model)
-    inputs = gettr.get_input_activations(transformer_layers[0], forward_loop)
+    gettr = LayerActivationCollector(model)
 
-    for layer_idx, layer in enumerate(transformer_layers):
+    for _, layer in enumerate(transformer_layers):
         # Get updated input activations to the current layer
         inputs = gettr.get_input_activations(layer, forward_loop)
 
```
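The rewritten loop is the core of the sequential scheme: each iteration re-runs the forward loop up to the current layer, so the captured inputs already reflect the quantized outputs of the blocks calibrated before it. A condensed sketch of the pattern, with `calibrate_layer` as a hypothetical stand-in for whatever per-layer calibration the algorithm performs on the collected inputs:

```python
# Condensed sketch of the loop body in sequential_calibrate.
# `calibrate_layer` is a hypothetical placeholder; the collector calls
# mirror the diff above.
gettr = LayerActivationCollector(model)
for layer in transformer_layers:
    # Each call re-runs the forward loop and stops at `layer`, so the
    # inputs reflect the already-calibrated (quantized) earlier blocks.
    inputs = gettr.get_input_activations(layer, forward_loop)
    calibrate_layer(layer, inputs)  # e.g. replay inputs to gather quantizer stats
```

This is what trades compute for memory: the forward loop runs once per block, but only one block's input activations are ever held at a time.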
modelopt/torch/quantization/utils.py (22 additions, 27 deletions)

```diff
@@ -813,7 +813,11 @@ def update_quant_cfg_with_kv_cache_quant(
     return quant_cfg
 
 
-class LayerActivationGettr:
+class _EarlyStopForwardError(Exception):
+    """Error to stop the forward pass after collection."""
+
+
+class LayerActivationCollector:
     """Helper class for collecting layer activations during forward passes.
 
     This class allows for sequential layer calibration by
@@ -825,53 +829,44 @@ def __init__(self, model: nn.Module):
 
     @staticmethod
     def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False):
-        """Patch a layer to collect inputs and outputs during forward passes."""
+        """Patch a layer to collect inputs during forward passes."""
 
         def _forward_w_data_collection(self, *args, **kwargs):
-            """Custom forward that collects inputs and outputs.
-
-            Note: 'self' refers to the patched layer.
-            """
-            assert len(args) >= 1
+            # Note: 'self' refers to the patched layer.
+            assert len(args) >= 1, (
+                f"Expected at least 1 positional arg, got {len(args)} args and {list(kwargs.keys())} kwargs"
+            )
             # Only collect the inputs to the layer
             self.inputs.append((args, kwargs))
-            if getattr(self, "_stop_after_collection", False):
-                raise StopIteration()
+            if stop_after_collection:
+                raise _EarlyStopForwardError()  # Stop the forward pass after collection
 
         bind_forward_method(layer, _forward_w_data_collection, "_original_forward")
         layer.inputs = []
-        layer._stop_after_collection = stop_after_collection
 
     @staticmethod
    def _unpatch_and_cleanup_layer(layer: torch.nn.Module):
-        """Restore a layer's original forward method and clean up."""
-        unpatch_forward_method(layer, "_original_forward")
-        del layer.inputs
-        if hasattr(layer, "_stop_after_collection"):
-            del layer._stop_after_collection
+        if hasattr(layer, "_original_forward"):
+            unpatch_forward_method(layer, "_original_forward")
+        if hasattr(layer, "inputs"):
+            del layer.inputs
 
+    @torch.no_grad()
     def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list:
-        """Collect input activations for a layer by running the forward loop.
-
-        Propagation stops at the patched layer for each batch (saves compute by not running deeper layers),
-        but the forward_loop continues to process all batches.
-
-        This function is typically used to collect input activations for the first decoder layer of the model.
-        """
-
-        # Wrap model forward to catch StopIteration per-batch
+        # Wrap model forward to catch _EarlyStopForwardError per-batch
         def _early_stop_forward(self, *args, **kwargs):
             try:
                 return self._original_forward(*args, **kwargs)
-            except StopIteration:
+            except _EarlyStopForwardError:
                 return None  # Stop propagation but allow next batch
 
-        bind_forward_method(self.model, _early_stop_forward, "_original_forward")
-        self._patch_and_initialize_layer(layer, stop_after_collection=True)
         try:
+            bind_forward_method(self.model, _early_stop_forward, "_original_forward")
+            self._patch_and_initialize_layer(layer, stop_after_collection=True)
             forward_loop(self.model)
             inputs = layer.inputs.copy()
         finally:
             self._unpatch_and_cleanup_layer(layer)
             unpatch_forward_method(self.model, "_original_forward")
+
         return inputs
```
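The mechanism behind `get_input_activations` is easiest to see in isolation: the target layer is patched to record its inputs and immediately raise, while the model's forward is wrapped to swallow that exception, so each calibration batch runs only up to the target layer before the loop moves on to the next batch. Below is a self-contained sketch of that technique using plain attribute monkey-patching in place of ModelOpt's `bind_forward_method`/`unpatch_forward_method` helpers; it is an illustrative simplification, not the class itself.

```python
import torch
import torch.nn as nn


class _EarlyStop(Exception):
    """Raised by the patched layer to abort the rest of the forward pass."""


def collect_layer_inputs(model: nn.Module, layer: nn.Module, batches) -> list:
    """Capture the (args, kwargs) seen by `layer` for each batch in `batches`."""
    captured = []
    original_model_forward = model.forward

    def layer_forward(*args, **kwargs):
        captured.append((args, kwargs))  # record inputs, skip the layer itself
        raise _EarlyStop

    def model_forward(*args, **kwargs):
        try:
            return original_model_forward(*args, **kwargs)
        except _EarlyStop:
            return None  # this batch is done; move on to the next one

    layer.forward, model.forward = layer_forward, model_forward
    try:
        with torch.no_grad():
            for batch in batches:
                model(batch)
    finally:
        del layer.forward, model.forward  # restore the class-level forwards
    return captured
```

Because propagation stops at the patched layer, each batch pays only for the layers ahead of the target, which is what keeps re-running the forward loop once per decoder block affordable.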

modelopt/torch/utils/network.py (1 addition, 0 deletions)

```diff
@@ -669,4 +669,5 @@ def get_decoder_layers(model: nn.Module, granularity: str = "decoder") -> nn.Mod
     if hasattr(model, "backbone") and hasattr(model.backbone, "layers"):
         return model.backbone.layers
 
+    print("No decoder layers found for model, returning None")
     return None
```
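Since the function now reports the miss with a `print` and falls through to `return None` rather than raising, callers that require decoder layers should check the result themselves. A hypothetical caller-side guard:

```python
# Hypothetical caller-side guard for the None fallback added above.
layers = get_decoder_layers(model)
if layers is None:
    raise ValueError("Sequential calibration requires identifiable decoder layers")
```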
