Skip to content

Commit fcdaf65

Browse files
authored
Support decoder block-level sequential calibration (#924)
## What does this PR do? **Type of change:** New feature <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> **Overview:** Add support for sequential calibration of layers (at decoder level granularity) in ModelOpt. Calibration flow: 1. Get the list of decoder blocks. 2. For the current block, get its input activations (considering weight and activation QDQ from all previous blocks) and call the specified calibration function. Functions added: 1. get_decoder_layers() -> to detect and get the list of blocks to iterate over 2. LayerActivationCollector class -> to get input activations to the layer 3. sequential_calibrate() -> to perform the described calibration flow 4. use_sequential field in QuantizeAlgorithmConfig ## Usage <!-- You can potentially add a usage example below. --> ```python # Sample config NVFP4_DEFAULT_CFG = { "quant_cfg": { "*weight_quantizer": { "num_bits": (2, 1), "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, "axis": None, "enable": True, }, "*input_quantizer": { "num_bits": (2, 1), "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, "axis": None, "enable": True, }, **_default_disabled_quantizer_cfg, }, "algorithm": { "method": "max", "use_sequential": True, }, } ``` Set use_sequential=True in QUANT_CFG's "algorithm" section. ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why.
--> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Sequential layer-by-layer calibration: Quantization now supports processing decoder layers sequentially to improve memory efficiency on large models. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
1 parent 2905cb0 commit fcdaf65

File tree

6 files changed

+532
-4
lines changed

6 files changed

+532
-4
lines changed

modelopt/torch/quantization/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,16 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
11301130
),
11311131
)
11321132

1133+
use_sequential: bool = ModeloptField(
1134+
default=False,
1135+
title="Enable sequential layer-by-layer calibration.",
1136+
description=(
1137+
"If True, the calibration algorithm is applied sequentially to each decoder block. "
1138+
"The current approach recomputes a full forward pass per layer to propagate updated activations,"
1139+
"incurring O(N²) cost. Future revisions will add caching to eliminate redundant passes."
1140+
),
1141+
)
1142+
11331143

11341144
class MaxCalibConfig(QuantizeAlgorithmConfig):
11351145
"""The config for max calibration algorithm.

modelopt/torch/quantization/mode.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
local_hessian_calibrate,
6464
max_calibrate,
6565
mse_calibrate,
66+
sequential_calibrate,
6667
smoothquant,
6768
svdquant,
6869
)
@@ -221,6 +222,7 @@ def wrapped_calib_func(
221222
"""
222223
kwargs = config.model_dump()
223224
method = kwargs.pop("method")
225+
sequential = kwargs.pop("use_sequential", False)
224226
if method is not None and "awq" in method:
225227
# For backward compatibility
226228
kwargs["algorithm"] = method
@@ -235,8 +237,22 @@ def wrapped_calib_func(
235237
module._moe_calib_experts_ratio = moe_calib_experts_ratio
236238

237239
if func is not None:
238-
# Call the function with forward_loop as a separate argument
239-
func(model, forward_loop=forward_loop, **kwargs)
240+
if sequential:
241+
if forward_loop is None:
242+
raise ValueError("forward_loop is required for calibration but got None.")
243+
assert method in ["max"], (
244+
f"Sequential calibration currently only supports max calibration, got {method}"
245+
)
246+
# Wrap with sequential processing
247+
sequential_calibrate(
248+
model,
249+
forward_loop=forward_loop,
250+
calib_func=func,
251+
**kwargs,
252+
)
253+
else:
254+
# Direct calibration (existing behavior)
255+
func(model, forward_loop=forward_loop, **kwargs)
240256

241257
# Lets get the latest metadata for the quantizer states
242258
metadata = {}

modelopt/torch/quantization/model_calib.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,14 @@
2828
from tqdm import tqdm
2929

3030
from modelopt.torch.opt.searcher import ForwardLoop
31+
from modelopt.torch.quantization.utils import LayerActivationCollector
3132
from modelopt.torch.utils import print_rank_0
3233
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
33-
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
34+
from modelopt.torch.utils.network import (
35+
bind_forward_method,
36+
get_decoder_layers,
37+
unpatch_forward_method,
38+
)
3439
from modelopt.torch.utils.perf import get_used_gpu_mem_fraction
3540

3641
from .calib import MseCalibrator, NVFP4MSECalibrator
@@ -49,7 +54,14 @@
4954
weight_attr_names,
5055
)
5156

52-
__all__ = ["awq", "local_hessian_calibrate", "max_calibrate", "smoothquant", "svdquant"]
57+
__all__ = [
58+
"awq",
59+
"local_hessian_calibrate",
60+
"max_calibrate",
61+
"sequential_calibrate",
62+
"smoothquant",
63+
"svdquant",
64+
]
5365

5466

5567
def weight_only_quantize(model: nn.Module):
@@ -1819,3 +1831,40 @@ def hessian_hook(module, input, output):
18191831
torch.cuda.empty_cache()
18201832

18211833
print_rank_0("GPTQ-lite quantization completed successfully")
1834+
1835+
1836+
@torch.no_grad()
def sequential_calibrate(
    model: nn.Module,
    forward_loop: ForwardLoop,
    calib_func: Callable,
    **calib_kwargs,
):
    """Sequential calibration - a sequential layer-by-layer calibration algorithm.

    For each decoder block, the inputs reaching that block are re-captured with a
    full forward pass (so weight/activation QDQ effects of already-calibrated
    blocks are reflected), and ``calib_func`` is then run on that block alone
    with the captured inputs. The repeated full forward passes make this O(N^2)
    in the number of blocks.

    Args:
        model: The model whose decoder blocks are calibrated sequentially.
        forward_loop: Callable that takes ``model`` and runs calibration data through it.
        calib_func: Per-block calibration function invoked as
            ``calib_func(module, forward_loop=..., **calib_kwargs)`` (e.g. ``max_calibrate``).
        **calib_kwargs: Extra keyword arguments forwarded to ``calib_func``.

    Raises:
        ValueError: If ``forward_loop`` is None, or no decoder layers can be detected.
    """
    if forward_loop is None:
        raise ValueError("forward_loop must not be None for sequential calibration.")

    transformer_layers = get_decoder_layers(model)
    if transformer_layers is None:
        # NOTE: fixed a stray apostrophe that was previously in this message.
        raise ValueError(
            "Could not find transformer layers in model. "
            "Sequential calibration requires a model with identifiable transformer layers."
        )

    print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers")

    collector = LayerActivationCollector(model)

    for layer in transformer_layers:
        # Re-run the full forward pass so the captured inputs reflect the
        # quantization of all previously calibrated blocks.
        layer_inputs = collector.get_input_activations(layer, forward_loop)

        # Replay the captured (args, kwargs) pairs through this block alone.
        # The inputs are bound via a default argument so each iteration's
        # closure sees its own capture list.
        def _layer_forward_loop(m, _inputs=layer_inputs):
            for args, kwargs_input in _inputs:
                m(*args, **kwargs_input)

        # Pass forward_loop by keyword, matching how calibration functions are
        # invoked elsewhere (func(model, forward_loop=..., **kwargs)).
        calib_func(layer, forward_loop=_layer_forward_loop, **calib_kwargs)

        # Release the captured activations before moving to the next block.
        del layer_inputs
        torch.cuda.empty_cache()

modelopt/torch/quantization/utils.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,13 @@
2929
from torch.distributed.tensor import Replicate
3030

3131
from modelopt.torch.utils import get_unwrapped_name, print_rank_0
32+
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
3233

3334
if TYPE_CHECKING:
3435
from collections.abc import Generator
3536

37+
from modelopt.torch.opt.searcher import ForwardLoop
38+
3639
__all__ = [
3740
"EXPORT_MODE",
3841
"convert_quantization_axis_to_reduce_axis",
@@ -808,3 +811,64 @@ def update_quant_cfg_with_kv_cache_quant(
808811
quant_cfg["algorithm"] = "max"
809812
print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}")
810813
return quant_cfg
814+
815+
816+
class _EarlyStopForwardError(Exception):
    """Error to stop the forward pass after collection.

    Raised by a patched layer forward once its inputs have been recorded; caught
    and swallowed by the wrapped top-level model forward so that the remaining
    computation of the batch is skipped while the next batch can still run.
    """
818+
819+
820+
class LayerActivationCollector:
    """Helper class for collecting layer activations during forward passes.

    This class allows for sequential layer calibration by
    patching layers to capture inputs/outputs during forward passes
    """

    def __init__(self, model: nn.Module):
        # Root model whose forward is driven by the caller-provided forward_loop;
        # sub-layer inputs are captured while this model runs.
        self.model = model

    @staticmethod
    def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False):
        """Patch a layer to collect inputs during forward passes."""

        def _forward_w_data_collection(self, *args, **kwargs):
            # Note: 'self' refers to the patched layer.
            assert len(args) >= 1, (
                f"Expected at least 1 positional arg, got {len(args)} args and {list(kwargs.keys())} kwargs"
            )
            # Only collect the inputs to the layer
            self.inputs.append((args, kwargs))
            if stop_after_collection:
                raise _EarlyStopForwardError()  # Stop the forward pass after collection

            return self._original_forward(*args, **kwargs)

        # bind_forward_method stashes the layer's original forward under
        # "_original_forward" and installs the collector in its place.
        bind_forward_method(layer, _forward_w_data_collection, "_original_forward")
        layer.inputs = []

    @staticmethod
    def _unpatch_and_cleanup_layer(layer: torch.nn.Module):
        """Restore the layer's original forward and drop any collected inputs."""
        # hasattr guards make this safe to call from a `finally` even if
        # patching never completed.
        if hasattr(layer, "_original_forward"):
            unpatch_forward_method(layer, "_original_forward")
        if hasattr(layer, "inputs"):
            del layer.inputs

    @torch.no_grad()
    def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list:
        """Run ``forward_loop`` and return the (args, kwargs) pairs that reached ``layer``.

        The model's forward is wrapped so that the _EarlyStopForwardError raised by the
        patched layer aborts the rest of each batch's computation (returning None for
        that call) without aborting the loop over batches.
        """
        # Wrap model forward to catch _EarlyStopForward per-batch
        def _early_stop_forward(self, *args, **kwargs):
            try:
                return self._original_forward(*args, **kwargs)
            except _EarlyStopForwardError:
                return None  # Stop propagation but allow next batch

        try:
            bind_forward_method(self.model, _early_stop_forward, "_original_forward")
            self._patch_and_initialize_layer(layer, stop_after_collection=True)
            forward_loop(self.model)
            # Copy before unpatching, since cleanup deletes layer.inputs.
            inputs = layer.inputs.copy()
        finally:
            self._unpatch_and_cleanup_layer(layer)
            unpatch_forward_method(self.model, "_original_forward")

        return inputs

modelopt/torch/utils/network.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,3 +634,36 @@ def unpatch_forward_method(module: nn.Module, orig_forward_cache_name: str):
634634
with temporarily_remove_accelerate_hook(module):
635635
setattr(module, "forward", getattr(module, orig_forward_cache_name))
636636
delattr(module, orig_forward_cache_name)
637+
638+
639+
def get_decoder_layers(model: nn.Module, granularity: str = "decoder") -> nn.ModuleList | None:
640+
"""Detect the decoder layers from a model for sequential calibration.
641+
642+
This temporary decoder-layer detection heuristic will be replaced with a more robust solution
643+
that also supports FSDP/DDP models.
644+
"""
645+
if granularity != "decoder":
646+
raise ValueError(f"Unsupported granularity: {granularity}. Only 'decoder' is supported.")
647+
648+
# HuggingFace transformers pattern: model.model.layers
649+
if hasattr(model, "model") and hasattr(model.model, "layers"):
650+
return model.model.layers
651+
652+
# Megatron/MCore pattern: model.decoder.layers
653+
if hasattr(model, "decoder") and hasattr(model.decoder, "layers"):
654+
return model.decoder.layers
655+
656+
# Direct layers attribute (some models)
657+
if hasattr(model, "layers") and isinstance(model.layers, nn.ModuleList):
658+
return model.layers
659+
660+
# GPT-style: model.transformer.h
661+
if hasattr(model, "transformer") and hasattr(model.transformer, "h"):
662+
return model.transformer.h
663+
664+
# Nemotron Super/Nano
665+
if hasattr(model, "backbone") and hasattr(model.backbone, "layers"):
666+
return model.backbone.layers
667+
668+
print("No decoder layers found for model, returning None")
669+
return None

0 commit comments

Comments
 (0)