diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index cd269c9361..80a2a68761 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1130,6 +1130,16 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): ), ) + use_sequential: bool = ModeloptField( + default=False, + title="Enable sequential layer-by-layer calibration.", + description=( + "If True, the calibration algorithm is applied sequentially to each decoder block. " + "The current approach recomputes a full forward pass per layer to propagate updated activations, " + "incurring O(N²) cost. Future revisions will add caching to eliminate redundant passes." + ), + ) + class MaxCalibConfig(QuantizeAlgorithmConfig): """The config for max calibration algorithm. diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 077f294b97..e08efece9a 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -63,6 +63,7 @@ local_hessian_calibrate, max_calibrate, mse_calibrate, + sequential_calibrate, smoothquant, svdquant, ) @@ -221,6 +222,7 @@ def wrapped_calib_func( """ kwargs = config.model_dump() method = kwargs.pop("method") + sequential = kwargs.pop("use_sequential", False) if method is not None and "awq" in method: # For backward compatibility kwargs["algorithm"] = method @@ -235,8 +237,22 @@ def wrapped_calib_func( module._moe_calib_experts_ratio = moe_calib_experts_ratio if func is not None: - # Call the function with forward_loop as a separate argument - func(model, forward_loop=forward_loop, **kwargs) + if sequential: + if forward_loop is None: + raise ValueError("forward_loop is required for calibration but got None.") + assert method in ["max"], ( + f"Sequential calibration currently only supports max calibration, got {method}" + ) + # Wrap with sequential processing + sequential_calibrate( + model, + forward_loop=forward_loop, + calib_func=func, + **kwargs, + 
) + else: + # Direct calibration (existing behavior) + func(model, forward_loop=forward_loop, **kwargs) # Lets get the latest metadata for the quantizer states metadata = {} diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 350af429be..70f036a8d6 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -28,9 +28,14 @@ from tqdm import tqdm from modelopt.torch.opt.searcher import ForwardLoop +from modelopt.torch.quantization.utils import LayerActivationCollector from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState -from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method +from modelopt.torch.utils.network import ( + bind_forward_method, + get_decoder_layers, + unpatch_forward_method, +) from modelopt.torch.utils.perf import get_used_gpu_mem_fraction from .calib import MseCalibrator, NVFP4MSECalibrator @@ -49,7 +54,14 @@ weight_attr_names, ) -__all__ = ["awq", "local_hessian_calibrate", "max_calibrate", "smoothquant", "svdquant"] +__all__ = [ + "awq", + "local_hessian_calibrate", + "max_calibrate", + "sequential_calibrate", + "smoothquant", + "svdquant", +] def weight_only_quantize(model: nn.Module): @@ -1819,3 +1831,40 @@ def hessian_hook(module, input, output): torch.cuda.empty_cache() print_rank_0("GPTQ-lite quantization completed successfully") + + +@torch.no_grad() +def sequential_calibrate( + model: nn.Module, + forward_loop: ForwardLoop, + calib_func: Callable, + **calib_kwargs, +): + """Sequential calibration - a sequential layer-by-layer calibration algorithm.""" + if forward_loop is None: + raise ValueError("forward_loop must not be None for sequential calibration.") + + transformer_layers = get_decoder_layers(model) + if transformer_layers is None: + raise ValueError( + "Could not find transformer layers in model. 
" + "Sequential calibration requires a model with identifiable transformer layers." + ) + + print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") + + gettr = LayerActivationCollector(model) + + for layer in transformer_layers: + # Get updated input activations to the current layer + layer_inputs = gettr.get_input_activations(layer, forward_loop) + + # Define a forward loop for the current layer + def _layer_forward_loop(m, _inputs=layer_inputs): + for args, kwargs_input in _inputs: + m(*args, **kwargs_input) + + # Call calibration function + calib_func(layer, _layer_forward_loop, **calib_kwargs) + del layer_inputs + torch.cuda.empty_cache() diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 6cf6bc90fe..df5f288e22 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -29,10 +29,13 @@ from torch.distributed.tensor import Replicate from modelopt.torch.utils import get_unwrapped_name, print_rank_0 +from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method if TYPE_CHECKING: from collections.abc import Generator + from modelopt.torch.opt.searcher import ForwardLoop + __all__ = [ "EXPORT_MODE", "convert_quantization_axis_to_reduce_axis", @@ -808,3 +811,64 @@ def update_quant_cfg_with_kv_cache_quant( quant_cfg["algorithm"] = "max" print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}") return quant_cfg + + +class _EarlyStopForwardError(Exception): + """Error to stop the forward pass after collection.""" + + +class LayerActivationCollector: + """Helper class for collecting layer activations during forward passes. 
+ + This class allows for sequential layer calibration by + patching layers to capture inputs/outputs during forward passes + """ + + def __init__(self, model: nn.Module): + self.model = model + + @staticmethod + def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False): + """Patch a layer to collect inputs during forward passes.""" + + def _forward_w_data_collection(self, *args, **kwargs): + # Note: 'self' refers to the patched layer. + assert len(args) >= 1, ( + f"Expected at least 1 positional arg, got {len(args)} args and {list(kwargs.keys())} kwargs" + ) + # Only collect the inputs to the layer + self.inputs.append((args, kwargs)) + if stop_after_collection: + raise _EarlyStopForwardError() # Stop the forward pass after collection + + return self._original_forward(*args, **kwargs) + + bind_forward_method(layer, _forward_w_data_collection, "_original_forward") + layer.inputs = [] + + @staticmethod + def _unpatch_and_cleanup_layer(layer: torch.nn.Module): + if hasattr(layer, "_original_forward"): + unpatch_forward_method(layer, "_original_forward") + if hasattr(layer, "inputs"): + del layer.inputs + + @torch.no_grad() + def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: + # Wrap model forward to catch _EarlyStopForward per-batch + def _early_stop_forward(self, *args, **kwargs): + try: + return self._original_forward(*args, **kwargs) + except _EarlyStopForwardError: + return None # Stop propagation but allow next batch + + try: + bind_forward_method(self.model, _early_stop_forward, "_original_forward") + self._patch_and_initialize_layer(layer, stop_after_collection=True) + forward_loop(self.model) + inputs = layer.inputs.copy() + finally: + self._unpatch_and_cleanup_layer(layer) + unpatch_forward_method(self.model, "_original_forward") + + return inputs diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index b54332375b..21c096db2e 100644 --- 
a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -634,3 +634,36 @@ def unpatch_forward_method(module: nn.Module, orig_forward_cache_name: str): with temporarily_remove_accelerate_hook(module): setattr(module, "forward", getattr(module, orig_forward_cache_name)) delattr(module, orig_forward_cache_name) + + +def get_decoder_layers(model: nn.Module, granularity: str = "decoder") -> nn.ModuleList | None: + """Detect the decoder layers from a model for sequential calibration. + + This temporary decoder-layer detection heuristic will be replaced with a more robust solution + that also supports FSDP/DDP models. + """ + if granularity != "decoder": + raise ValueError(f"Unsupported granularity: {granularity}. Only 'decoder' is supported.") + + # HuggingFace transformers pattern: model.model.layers + if hasattr(model, "model") and hasattr(model.model, "layers"): + return model.model.layers + + # Megatron/MCore pattern: model.decoder.layers + if hasattr(model, "decoder") and hasattr(model.decoder, "layers"): + return model.decoder.layers + + # Direct layers attribute (some models) + if hasattr(model, "layers") and isinstance(model.layers, nn.ModuleList): + return model.layers + + # GPT-style: model.transformer.h + if hasattr(model, "transformer") and hasattr(model.transformer, "h"): + return model.transformer.h + + # Nemotron Super/Nano + if hasattr(model, "backbone") and hasattr(model.backbone, "layers"): + return model.backbone.layers + + print("No decoder layers found for model, returning None") + return None diff --git a/tests/unit/torch/quantization/test_sequential_calibrate.py b/tests/unit/torch/quantization/test_sequential_calibrate.py new file mode 100644 index 0000000000..3b6b166bed --- /dev/null +++ b/tests/unit/torch/quantization/test_sequential_calibrate.py @@ -0,0 +1,356 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for sequential_calibrate and LayerActivationCollector.""" + +import pytest +import torch +import torch.nn as nn + +from modelopt.torch.quantization.model_calib import sequential_calibrate +from modelopt.torch.quantization.utils import LayerActivationCollector + + +class _DecoderBlock(nn.Module): + """Minimal transformer decoder block.""" + + def __init__(self, dim=16): + super().__init__() + self.attn = nn.Linear(dim, dim, bias=False) + self.ffn = nn.Sequential( + nn.Linear(dim, dim * 4, bias=False), + nn.ReLU(), + nn.Linear(dim * 4, dim, bias=False), + ) + self.norm = nn.LayerNorm(dim) + + def forward(self, x, **kwargs): + x = x + self.attn(self.norm(x)) + x = x + self.ffn(x) + return x + + +class _SimpleTransformerModel(nn.Module): + """model.layers (ModuleList) -- the simplest pattern recognised by get_decoder_layers.""" + + def __init__(self, n_layers=3, dim=16): + super().__init__() + self.layers = nn.ModuleList([_DecoderBlock(dim) for _ in range(n_layers)]) + self.embed = nn.Embedding(32, dim) + + def forward(self, x, **kwargs): + x = self.embed(x) + for layer in self.layers: + x = layer(x) + return x + + +class _FlatMLP(nn.Module): + """No decoder-layer structure -- should be rejected by sequential_calibrate.""" + + def __init__(self, dim=16): + super().__init__() + self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)) + + def 
forward(self, x): + return self.net(x) + + +class _SimpleTwoLayerModel(nn.Module): + """Minimal model with explicit layers for activation-collection tests.""" + + def __init__(self, dim=16): + super().__init__() + self.layers = nn.ModuleList( + [nn.Linear(dim, dim, bias=False), nn.Linear(dim, dim, bias=False)] + ) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def _make_model_and_data(n_layers=3, dim=16, n_batches=2, batch_size=4): + torch.manual_seed(42) + model = _SimpleTransformerModel(n_layers=n_layers, dim=dim) + tokens = [torch.randint(0, 32, (batch_size, 8)) for _ in range(n_batches)] + return model, tokens + + +def _run_forward(model, data): + for batch in data: + model(batch) + + +# LayerActivationCollector tests + + +def test_collector_collects_correct_number_of_inputs(): + torch.manual_seed(0) + model = _SimpleTwoLayerModel(dim=8) + collector = LayerActivationCollector(model) + data = [torch.randn(2, 8) for _ in range(3)] + + def forward_loop(m): + for d in data: + m(d) + + inputs = collector.get_input_activations(model.layers[0], forward_loop) + assert len(inputs) == 3 + + +def test_collector_activations_match_expected(): + """First layer should receive the raw input data.""" + torch.manual_seed(0) + model = _SimpleTwoLayerModel(dim=8) + collector = LayerActivationCollector(model) + data = [torch.randn(2, 8)] + + def forward_loop(m): + for d in data: + m(d) + + inputs = collector.get_input_activations(model.layers[0], forward_loop) + args, kwargs = inputs[0] + assert torch.allclose(args[0], data[0]) + + +def test_collector_second_layer_receives_transformed_input(): + """Second layer should receive first layer's output, not raw input.""" + torch.manual_seed(0) + model = _SimpleTwoLayerModel(dim=8) + collector = LayerActivationCollector(model) + x = torch.randn(2, 8) + + def forward_loop(m): + m(x) + + expected = model.layers[0](x) + inputs = collector.get_input_activations(model.layers[1], forward_loop) + args, _ = 
inputs[0] + assert torch.allclose(args[0], expected) + + +def test_collector_forward_is_restored_after_collection(): + model = _SimpleTwoLayerModel(dim=8) + collector = LayerActivationCollector(model) + + def forward_loop(m): + m(torch.randn(2, 8)) + + collector.get_input_activations(model.layers[0], forward_loop) + + assert not hasattr(model, "_original_forward") + assert not hasattr(model.layers[0], "inputs") + assert not hasattr(model.layers[0], "_original_forward") + + +def test_collector_cleanup_on_forward_loop_error(): + """Patching should be cleaned up even if forward_loop raises.""" + model = _SimpleTwoLayerModel(dim=8) + collector = LayerActivationCollector(model) + + def bad_forward_loop(m): + raise RuntimeError("intentional error") + + with pytest.raises(RuntimeError, match="intentional error"): + collector.get_input_activations(model.layers[0], bad_forward_loop) + + assert not hasattr(model, "_original_forward") + assert not hasattr(model.layers[0], "inputs") + + +# sequential_calibrate tests + + +def test_seq_calib_raises_on_none_forward_loop(): + model, _ = _make_model_and_data(n_layers=2) + with pytest.raises(ValueError, match="forward_loop must not be None"): + sequential_calibrate( + model, + forward_loop=None, + calib_func=lambda *a, **kw: None, + ) + + +def test_seq_calib_raises_on_unrecognized_model(): + model = _FlatMLP() + with pytest.raises(ValueError, match="Could not find transformer layers"): + sequential_calibrate( + model, + forward_loop=lambda m: m(torch.randn(2, 16)), + calib_func=lambda *a, **kw: None, + ) + + +def test_seq_calib_func_called_per_layer(): + model, data = _make_model_and_data(n_layers=4) + call_count = [0] + + def counting_calib(layer, forward_loop, **kwargs): + call_count[0] += 1 + + sequential_calibrate( + model, + forward_loop=lambda m: _run_forward(m, data), + calib_func=counting_calib, + ) + + assert call_count[0] == 4 + + +def test_seq_calib_func_receives_correct_layer(): + model, data = 
_make_model_and_data(n_layers=3) + called_layers = [] + + def track_layers(layer, forward_loop, **kwargs): + called_layers.append(layer) + + sequential_calibrate( + model, + forward_loop=lambda m: _run_forward(m, data), + calib_func=track_layers, + ) + + for i, layer in enumerate(model.layers): + assert called_layers[i] is layer + + +def test_seq_calib_kwargs_forwarded(): + model, data = _make_model_and_data(n_layers=2) + received_kwargs = [] + + def capture_kwargs(layer, forward_loop, **kwargs): + received_kwargs.append(kwargs) + + sequential_calibrate( + model, + forward_loop=lambda m: _run_forward(m, data), + calib_func=capture_kwargs, + alpha=0.5, + method="max", + ) + + assert len(received_kwargs) == 2 + for kw in received_kwargs: + assert kw["alpha"] == 0.5 + assert kw["method"] == "max" + + +def test_seq_calib_layer_forward_loop_runs_all_batches(): + """The per-layer forward loop passed to calib_func should replay all batches.""" + n_batches = 5 + model, data = _make_model_and_data(n_layers=2, n_batches=n_batches) + batch_counts = [] + + def count_batches(layer, forward_loop, **kwargs): + counter = {"n": 0} + orig_forward = layer.forward + + def counting_forward(*args, **kw): + counter["n"] += 1 + return orig_forward(*args, **kw) + + layer.forward = counting_forward + forward_loop(layer) + layer.forward = orig_forward + batch_counts.append(counter["n"]) + + sequential_calibrate( + model, + forward_loop=lambda m: _run_forward(m, data), + calib_func=count_batches, + ) + + for count in batch_counts: + assert count == n_batches + + +def test_seq_calib_does_not_alter_weights(): + """sequential_calibrate itself should not modify model weights.""" + model, data = _make_model_and_data(n_layers=3) + weights_before = {n: p.clone() for n, p in model.named_parameters()} + + sequential_calibrate( + model, + forward_loop=lambda m: _run_forward(m, data), + calib_func=lambda layer, forward_loop, **kw: None, + ) + + for n, p in model.named_parameters(): + assert 
torch.equal(p, weights_before[n]), f"Weight {n} was modified" + + +def test_seq_calib_activations_update_across_layers(): + """Subsequent layers should see activations transformed by prior layers.""" + torch.manual_seed(0) + model = _SimpleTransformerModel(n_layers=2, dim=16) + tokens = [torch.randint(0, 32, (2, 4))] + + layer_inputs_record = {} + + def record_inputs(layer, forward_loop, **kwargs): + activations = [] + orig_forward = layer.forward + + def capture_forward(*args, **kw): + activations.append(args[0].clone()) + return orig_forward(*args, **kw) + + layer.forward = capture_forward + forward_loop(layer) + layer.forward = orig_forward + + layer_idx = list(model.layers).index(layer) + layer_inputs_record[layer_idx] = activations + + sequential_calibrate( + model, + forward_loop=lambda m: [m(t) for t in tokens], + calib_func=record_inputs, + ) + + assert not torch.allclose(layer_inputs_record[0][0], layer_inputs_record[1][0]), ( + "Layer 1 should receive different activations than layer 0" + ) + + +def test_seq_calib_empty_forward_loop(): + """If forward_loop feeds no data, calib_func still gets called with an empty replay.""" + model = _SimpleTransformerModel(n_layers=2, dim=16) + replay_counts = [] + + def check_empty_replay(layer, forward_loop, **kwargs): + counter = {"n": 0} + orig_forward = layer.forward + + def counting_forward(*args, **kw): + counter["n"] += 1 + return orig_forward(*args, **kw) + + layer.forward = counting_forward + forward_loop(layer) + layer.forward = orig_forward + replay_counts.append(counter["n"]) + + sequential_calibrate( + model, + forward_loop=lambda m: None, + calib_func=check_empty_replay, + ) + + for count in replay_counts: + assert count == 0