Skip to content

Commit 6578dff

Browse files
committed
sequential flow
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
1 parent 52e662d commit 6578dff

4 files changed

Lines changed: 167 additions & 4 deletions

File tree

modelopt/torch/quantization/mode.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
local_hessian_calibrate,
6464
max_calibrate,
6565
mse_calibrate,
66+
sequential_calibrate,
6667
smoothquant,
6768
svdquant,
6869
)
@@ -221,13 +222,25 @@ def wrapped_calib_func(
221222
"""
222223
kwargs = config.model_dump()
223224
method = kwargs.pop("method")
225+
sequential = kwargs.pop("use_sequential", False)
224226
if method is not None and "awq" in method:
225227
# For backward compatibility
226228
kwargs["algorithm"] = method
227229

228230
if func is not None:
229-
# Call the function with forward_loop as a separate argument
230-
func(model, forward_loop=forward_loop, **kwargs)
231+
if sequential:
232+
# Wrap with sequential processing
233+
sequential_calibrate(
234+
model,
235+
forward_loop=forward_loop,
236+
calib_func=func,
237+
**kwargs,
238+
)
239+
else:
240+
# Direct calibration (existing behavior)
241+
func(model, forward_loop=forward_loop, **kwargs)
242+
else:
243+
raise ValueError(f"No calibration function provided for method: {method}")
231244

232245
# Lets get the latest metadata for the quantizer states
233246
metadata = {}

modelopt/torch/quantization/model_calib.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,14 @@
2828
from tqdm import tqdm
2929

3030
from modelopt.torch.opt.searcher import ForwardLoop
31+
from modelopt.torch.quantization.utils import LayerActivationGettr
3132
from modelopt.torch.utils import print_rank_0
3233
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
33-
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
34+
from modelopt.torch.utils.network import (
35+
bind_forward_method,
36+
get_decoder_layers,
37+
unpatch_forward_method,
38+
)
3439
from modelopt.torch.utils.perf import get_used_gpu_mem_fraction
3540

3641
from .calib import MseCalibrator, NVFP4MSECalibrator
@@ -49,7 +54,14 @@
4954
weight_attr_names,
5055
)
5156

52-
__all__ = ["awq", "local_hessian_calibrate", "max_calibrate", "smoothquant", "svdquant"]
57+
__all__ = [
58+
"awq",
59+
"local_hessian_calibrate",
60+
"max_calibrate",
61+
"sequential_calibrate",
62+
"smoothquant",
63+
"svdquant",
64+
]
5365

5466

5567
def weight_only_quantize(model: nn.Module):
@@ -1819,3 +1831,38 @@ def hessian_hook(module, input, output):
18191831
torch.cuda.empty_cache()
18201832

18211833
print_rank_0("GPTQ-lite quantization completed successfully")
1834+
1835+
1836+
@torch.no_grad()
def sequential_calibrate(
    model: nn.Module,
    forward_loop: ForwardLoop,
    calib_func: Callable,
    **calib_kwargs,
):
    """Sequential calibration - a sequential layer-by-layer calibration algorithm.

    For each decoder layer, the input activations feeding that layer are captured by
    running ``forward_loop`` through the (partially calibrated) model, and
    ``calib_func`` is then applied to that layer alone using a per-layer forward loop
    that replays the captured activations.

    Args:
        model: The model to calibrate. Must expose identifiable decoder layers
            (see ``get_decoder_layers``).
        forward_loop: Callable that takes the model and runs calibration data through it.
        calib_func: Per-layer calibration function, invoked as
            ``calib_func(layer, forward_loop=..., **calib_kwargs)``.
        **calib_kwargs: Extra keyword arguments forwarded to ``calib_func``.

    Raises:
        ValueError: If no decoder layers can be located in ``model``.
    """
    transformer_layers = get_decoder_layers(model)
    if transformer_layers is None:
        raise ValueError(
            "Could not find transformer layers in model. "
            "Sequential calibration requires a model with identifiable transformer layers."
        )

    print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers")

    gettr = LayerActivationGettr(model)

    for layer in transformer_layers:
        # Re-capture the inputs to this layer each iteration so they reflect any
        # quantization already applied to earlier layers. (The original code also
        # pre-collected inputs for layer 0 before the loop, but that result was
        # discarded by the first iteration — a full wasted pass over the data.)
        inputs = gettr.get_input_activations(layer, forward_loop)

        # Replay the captured activations through just this layer. Bind the captured
        # list as a default argument to avoid the late-binding-closure pitfall.
        def _layer_forward_loop(m, _inputs=inputs):
            for args, kwargs_input in _inputs:
                m(*args, **kwargs_input)

        # Calibrate this layer in isolation.
        # NOTE(review): the original passed `inputs` as a second positional argument,
        # which collides with the `forward_loop` keyword for the calibration functions
        # dispatched from mode.py (e.g. max_calibrate(model, forward_loop=...)).
        calib_func(layer, forward_loop=_layer_forward_loop, **calib_kwargs)

        # Free captured activations before collecting the next layer's inputs.
        del inputs
        torch.cuda.empty_cache()

modelopt/torch/quantization/utils.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,13 @@
2929
from torch.distributed.tensor import Replicate
3030

3131
from modelopt.torch.utils import get_unwrapped_name, print_rank_0
32+
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
3233

3334
if TYPE_CHECKING:
3435
from collections.abc import Generator
3536

37+
from modelopt.torch.opt.searcher import ForwardLoop
38+
3639
__all__ = [
3740
"EXPORT_MODE",
3841
"convert_quantization_axis_to_reduce_axis",
@@ -808,3 +811,67 @@ def update_quant_cfg_with_kv_cache_quant(
808811
quant_cfg["algorithm"] = "max"
809812
print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}")
810813
return quant_cfg
814+
815+
816+
class LayerActivationGettr:
    """Helper class for collecting layer input activations during forward passes.

    This class allows for sequential layer calibration by temporarily patching a
    layer's ``forward`` to record the positional/keyword arguments it receives.
    (Only inputs are collected — outputs are never stored.)
    """

    def __init__(self, model: nn.Module):
        # Root model whose forward pass drives activation collection.
        self.model = model

    @staticmethod
    def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False):
        """Patch a layer to collect its forward-call arguments.

        Args:
            layer: The submodule to patch.
            stop_after_collection: If True, raise ``StopIteration`` right after
                recording a batch so deeper layers are not executed (saves compute).
        """

        def _forward_w_data_collection(self, *args, **kwargs):
            """Custom forward that collects inputs.

            Note: 'self' refers to the patched layer.
            """
            assert len(args) >= 1
            # Only collect the inputs to the layer
            self.inputs.append((args, kwargs))
            if getattr(self, "_stop_after_collection", False):
                raise StopIteration()
            # Fix: delegate to the original forward so downstream modules still
            # receive real outputs when collection does not stop the batch.
            # (The original returned None here, corrupting later layers whenever
            # stop_after_collection=False.)
            return self._original_forward(*args, **kwargs)

        bind_forward_method(layer, _forward_w_data_collection, "_original_forward")
        layer.inputs = []
        layer._stop_after_collection = stop_after_collection

    @staticmethod
    def _unpatch_and_cleanup_layer(layer: torch.nn.Module):
        """Restore a layer's original forward method and clean up."""
        unpatch_forward_method(layer, "_original_forward")
        del layer.inputs
        if hasattr(layer, "_stop_after_collection"):
            del layer._stop_after_collection

    def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list:
        """Collect input activations for a layer by running the forward loop.

        Propagation stops at the patched layer for each batch (saves compute by not
        running deeper layers), but the forward_loop continues to process all batches.

        This function is typically used to collect input activations for the first
        decoder layer of the model.

        Args:
            layer: The submodule whose inputs should be captured.
            forward_loop: Callable that takes the model and runs calibration data
                through it.

        Returns:
            A list of ``(args, kwargs)`` tuples, one per forward call into ``layer``.
        """

        # Wrap model forward to catch StopIteration per-batch: the patched layer
        # raises it to short-circuit the rest of the network for that batch.
        def _early_stop_forward(self, *args, **kwargs):
            try:
                return self._original_forward(*args, **kwargs)
            except StopIteration:
                return None  # Stop propagation but allow next batch

        bind_forward_method(self.model, _early_stop_forward, "_original_forward")
        self._patch_and_initialize_layer(layer, stop_after_collection=True)
        try:
            forward_loop(self.model)
            inputs = layer.inputs.copy()
        finally:
            # Always restore both forwards, even if the forward loop raises.
            self._unpatch_and_cleanup_layer(layer)
            unpatch_forward_method(self.model, "_original_forward")
        return inputs

modelopt/torch/utils/network.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,3 +634,39 @@ def unpatch_forward_method(module: nn.Module, orig_forward_cache_name: str):
634634
with temporarily_remove_accelerate_hook(module):
635635
setattr(module, "forward", getattr(module, orig_forward_cache_name))
636636
delattr(module, orig_forward_cache_name)
637+
638+
639+
def get_decoder_layers(model: nn.Module, granularity: str = "decoder") -> nn.ModuleList | None:
    """Get the decoder layers from a model for sequential calibration.

    Args:
        model: The model to extract decoder layers from.
        granularity: The type of layers to extract. Currently only "decoder" is supported.

    Returns:
        A ModuleList of decoder layers, or None if not found.
    """
    if granularity != "decoder":
        raise ValueError(f"Unsupported granularity: {granularity}. Only 'decoder' is supported.")

    # Sentinel distinguishing "attribute absent" from an attribute whose value is None.
    _missing = object()

    # Known attribute paths to the decoder stack, probed in priority order.
    candidate_paths = (
        ("model", "layers"),      # HuggingFace transformers pattern: model.model.layers
        ("decoder", "layers"),    # Megatron/MCore pattern: model.decoder.layers
        ("layers",),              # Direct layers attribute (some models)
        ("transformer", "h"),     # GPT-style: model.transformer.h
        ("backbone", "layers"),   # Nemotron Super/Nano
    )
    for path in candidate_paths:
        node = model
        for attr in path:
            node = getattr(node, attr, _missing)
            if node is _missing:
                break
        else:
            # A bare "layers" attribute only counts when it is a real ModuleList.
            if path != ("layers",) or isinstance(node, nn.ModuleList):
                return node

    return None

0 commit comments

Comments
 (0)