
Commit beac6e9

sugunav14 and realAsma authored
Sequential calibrate refactor (#982)
### What does this PR do?

**Type of change:** New feature

The current sequential calibration support has O(N^2) complexity for collecting updated activations for a decoder layer. To solve this, we adopted a modular/plugin-based approach that uses hooks to capture the updated activations by running forward on the previous decoder layer using its cached activations. This leads to an issue with nested modules: logic in the parent module might need to be replicated in the lower-level modules to ensure equivalence. For example, in the Nemotron model, the parent module NemotronHModel creates and selects the appropriate mask based on the decoder layer type (mamba vs. attention).

This PR implements a more generic solution for sequential calibration by collecting activations through the full model forward, thereby ensuring that all parent-module logic is preserved. We use a "state" attribute on the modules to indicate whether to perform recomputation or skip the layer while running module forward. This avoids redundant computation when getting updated activations. The overall flow is as follows:

1. The user registers a `get_decoder_layers()` function that returns a list of layers to be calibrated sequentially.
2. `LayerActivationCollector` goes through the list of layers and patches each module forward with a "state-aware" module forward.
3. When `model.forward()` is called, all the parent logic is recomputed as expected (embeddings, residual connections, attention-mask generation, etc.).
4. Say we are currently calibrating layer N and want its updated activations: we set layer N to capture and layer N-1 to run (that layer was processed previously, so updated activations need to be regenerated). Already-processed layers are set to skip, so when `model.forward()` is called all the previous decoder layer computations are bypassed. Layer N-1 uses its cached inputs to generate new activations. Layer N's inputs are captured using the same logic as before and cached so that they can be used to get updated activations for layer N+1.

### Usage

```python
# Sequential calibrate config
NVFP4_SEQUENTIAL_CFG = {
    "quant_cfg": {
        "*weight_quantizer": _nvfp4_quantizer,
        "*input_quantizer": _nvfp4_quantizer,
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": {"method": "max", "use_sequential": True},
}
```

### Before your PR is "*Ready for review*"

- Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, using `torch.load(..., weights_only=True)`, avoiding `pickle`, etc.).
- Is this change backward compatible?: ✅
- If you copied code from any other source, did you follow IP policy in [CONTRIBUTING.md](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md#-copying-code-from-other-sources)?: ✅
- Did you write any new necessary tests?: ✅
- Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ / ❌ / N/A

## Summary by CodeRabbit

* **New Features**
  * Public sequential, per-layer calibration API and an activation-collection utility.
  * Broader model discovery support including Nemotron-H and homogeneous HuggingFace variants.
* **Improvements**
  * Clearer validation/error messages and deterministic patching/unpatching with guaranteed cleanup and resource handling.
  * Consolidated discovery/registration flow for decoder-layer handling and improved per-layer logging/progress.
* **Tests**
  * Extensive new unit tests covering discovery, per-layer capture/replay, inter-layer behavior, and edge cases.

---

Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Co-authored-by: realAsma <akuriparambi@nvidia.com>
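The skip / run / capture flow described in the PR can be reduced to a small state machine driven by a per-layer attribute. The sketch below is a toy illustration of that idea only; the names `CalibState`, `ToyLayer`, `state_aware_forward`, and `model_forward` are hypothetical stand-ins, not ModelOpt's actual API:

```python
from enum import Enum


class CalibState(Enum):
    SKIP = "skip"        # already calibrated: pass hidden states through untouched
    RUN = "run"          # re-run on cached inputs to produce updated activations
    CAPTURE = "capture"  # record incoming activations for the layer being calibrated


class ToyLayer:
    """Stand-in for a decoder layer; its 'computation' doubles the input."""

    def __init__(self):
        self.state = CalibState.RUN

    def forward(self, hidden):
        return hidden * 2


def state_aware_forward(layer, hidden, cache):
    """What a 'state-aware' patched module forward does, keyed on layer.state."""
    if layer.state is CalibState.SKIP:
        return hidden            # skip redundant recomputation
    if layer.state is CalibState.CAPTURE:
        cache.append(hidden)     # stash inputs for later per-layer replay
        return hidden
    return layer.forward(hidden)  # RUN: regenerate activations


def model_forward(layers, x, cache):
    # In the real model, parent-module logic (embeddings, residuals,
    # attention-mask construction) runs here, fully preserved.
    for layer in layers:
        x = state_aware_forward(layer, x, cache)
    return x


# Calibrating layer N=2 of 3: layer 0 skips, layer 1 runs, layer 2 captures.
layers = [ToyLayer() for _ in range(3)]
cache = []
layers[0].state = CalibState.SKIP
layers[1].state = CalibState.RUN
layers[2].state = CalibState.CAPTURE
out = model_forward(layers, 1.0, cache)
print(cache, out)  # [2.0] 2.0
```

One `model_forward` call per layer replaces the O(N^2) re-execution of all earlier layers, since everything before layer N-1 is a pass-through.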
1 parent 7b34de6 commit beac6e9

File tree

11 files changed (+1520, −286 lines)


modelopt/torch/quantization/model_calib.py

Lines changed: 30 additions & 23 deletions
```diff
@@ -28,14 +28,10 @@
 from tqdm import tqdm
 
 from modelopt.torch.opt.searcher import ForwardLoop
-from modelopt.torch.quantization.utils import LayerActivationCollector
+from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
-from modelopt.torch.utils.network import (
-    bind_forward_method,
-    get_decoder_layers,
-    unpatch_forward_method,
-)
+from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
 from modelopt.torch.utils.perf import get_used_gpu_mem_fraction
 
 from .calib import MseCalibrator, NVFP4MSECalibrator
@@ -1848,31 +1844,42 @@ def sequential_calibrate(
     calib_func: Callable,
     **calib_kwargs,
 ):
-    """Sequential calibration - a sequential layer-by-layer calibration algorithm."""
+    """Sequential calibration - a sequential layer-by-layer calibration algorithm.
+
+    Runs the full model forward per layer but patches decoder layers with a
+    skip / run / capture strategy so that inter-layer logic in parent modules
+    (e.g. mask construction) executes naturally without model-specific hooks.
+    """
     if forward_loop is None:
-        raise ValueError("forward_loop must not be None for sequential calibration.")
+        raise ValueError(
+            "forward_loop must not be None for sequential calibration. "
+            "Please provide a valid forward_loop callable."
+        )
 
-    transformer_layers = get_decoder_layers(model)
-    if transformer_layers is None:
+    transformer_layers = LayerActivationCollector.get_decoder_layers(model)
+    if transformer_layers is None or len(transformer_layers) == 0:
         raise ValueError(
-            "Could not find transformer layers in model'. "
+            "Could not find transformer layers in model. "
             "Sequential calibration requires a model with identifiable transformer layers."
         )
 
     print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers")
 
-    gettr = LayerActivationCollector(model)
+    input_getter = LayerActivationCollector(model)
+    input_getter._patch_all_layers(decoder_layers=transformer_layers)
 
-    for layer in transformer_layers:
-        # Get updated input activations to the current layer
-        layer_inputs = gettr.get_input_activations(layer, forward_loop)
+    try:
+        for layer_idx, layer in enumerate(transformer_layers):
+            print_rank_0(f"Calibrating layer {layer_idx + 1}/{len(transformer_layers)}")
+            layer_inputs = input_getter.get_input_activations(layer, forward_loop)
 
-        # Define a forward loop for the current layer
-        def _layer_forward_loop(m, _inputs=layer_inputs):
-            for args, kwargs_input in _inputs:
-                m(*args, **kwargs_input)
+            def _layer_forward_loop(m, _inputs=layer_inputs):
+                for args, kwargs_input in _inputs:
+                    m(*args, **kwargs_input)
 
-        # Call calibration function
-        calib_func(layer, _layer_forward_loop, **calib_kwargs)
-        del layer_inputs
-        torch.cuda.empty_cache()
+            calib_func(layer, _layer_forward_loop, **calib_kwargs)
+
+            del layer_inputs
+            torch.cuda.empty_cache()
+    finally:
+        input_getter._unpatch_all_layers()
```
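A detail in the loop above: `_layer_forward_loop(m, _inputs=layer_inputs)` binds the current layer's inputs through a default argument rather than a plain closure. This is the standard guard against Python's late-binding closures, and it also keeps the inputs reachable if the function outlives the later `del layer_inputs`. A minimal demonstration of the difference:

```python
# Python closures capture variables by reference, so every closure created
# in a loop sees the variable's *final* value (late binding).
late_bound = []
for i in range(3):
    late_bound.append(lambda: i)          # all three share the same `i`

# A default argument is evaluated at definition time, freezing the value.
default_bound = []
for i in range(3):
    default_bound.append(lambda _i=i: _i)  # each lambda gets its own `_i`

print([f() for f in late_bound])      # [2, 2, 2]
print([f() for f in default_bound])   # [0, 1, 2]
```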

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 48 additions & 0 deletions
```diff
@@ -39,6 +39,7 @@
 from ..nn.modules.quant_linear import _QuantLinear
 from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE
 from ..utils import replace_function, sync_moe_expert_amax
+from ..utils.activation_collector import LayerActivationCollector
 from .attention import register_attention_for_kv_quant
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin
 
@@ -1367,6 +1368,42 @@ def _is_supported_hf_model(model):
     return isinstance(model, tuple(supported_models))
 
 
+def is_nemotron_h_model(model: nn.Module) -> bool:
+    return get_nemotron_h_decoder_layers(model) is not None
+
+
+def get_nemotron_h_decoder_layers(model: nn.Module) -> nn.ModuleList | None:
+    if not _is_supported_hf_model(model):
+        return None
+
+    if hasattr(model, "backbone") and hasattr(model.backbone, "layers"):
+        layers = model.backbone.layers
+        if len(layers) > 0 and hasattr(layers[0], "block_type"):
+            return layers
+
+    return None
+
+
+def is_homogeneous_hf_model(model: nn.Module) -> bool:
+    if is_nemotron_h_model(model):
+        return False
+    decoder_layers = get_homogeneous_hf_decoder_layers(model)
+    if decoder_layers is None or len(decoder_layers) == 0:
+        return False
+    layer_classes = {type(layer) for layer in decoder_layers}
+    return len(layer_classes) == 1
+
+
+def get_homogeneous_hf_decoder_layers(model: nn.Module) -> nn.ModuleList | None:
+    if not _is_supported_hf_model(model):
+        return None
+
+    if hasattr(model, "model") and hasattr(model.model, "layers"):
+        return model.model.layers
+
+    return None
+
+
 @contextmanager
 def setup_model_for_gradient_checkpointing(model: nn.Module):
     use_cache = None
@@ -1420,6 +1457,17 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
         _is_param_grad_enabled_for_auto_quantize,
     )
 
+    # Order matters: more specific predicates must be registered first because
+    # the first matching entry wins. Nemotron-H must precede the generic
+    # homogeneous HF discoverer (which explicitly rejects Nemotron-H).
+    LayerActivationCollector.register_decoder_layer_support(
+        is_nemotron_h_model, get_nemotron_h_decoder_layers
+    )
+
+    LayerActivationCollector.register_decoder_layer_support(
+        is_homogeneous_hf_model, get_homogeneous_hf_decoder_layers
+    )
+
     CUSTOM_MODEL_PLUGINS.update(
         [
             register_falcon_linears_on_the_fly,
```
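The registration comment in the diff above hinges on first-match semantics: `register_decoder_layer_support` pairs a predicate with a layer getter, and discovery stops at the first predicate that accepts the model. The registry internals are not shown in this diff, so the sketch below is an assumption about how such a first-match registry behaves, using hypothetical dict-based "models" in place of `nn.Module`:

```python
# Toy first-match registry: each entry is (predicate, getter); lookup
# returns the first getter whose predicate accepts the model.
_DECODER_LAYER_SUPPORT = []


def register_decoder_layer_support(predicate, getter):
    _DECODER_LAYER_SUPPORT.append((predicate, getter))


def get_decoder_layers(model):
    for predicate, getter in _DECODER_LAYER_SUPPORT:
        if predicate(model):
            return getter(model)
    return None


# A "nemotron-like" model would satisfy the generic predicate too, so the
# specific predicate must be registered first to win the first-match scan.
register_decoder_layer_support(
    lambda m: m.get("kind") == "nemotron-h", lambda m: m["backbone_layers"]
)
register_decoder_layer_support(
    lambda m: "layers" in m, lambda m: m["layers"]
)

nemotron = {"kind": "nemotron-h", "backbone_layers": ["mamba", "attn"], "layers": ["wrong"]}
print(get_decoder_layers(nemotron))  # ['mamba', 'attn'], not ['wrong']
```

Had the generic entry been registered first, the nemotron-like model would have matched it and returned the wrong layer list, which is exactly what the diff's ordering comment warns about.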
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# ruff: noqa: F405
17+
"""Quantization utilities."""
18+
19+
from .activation_collector import LayerActivationCollector
20+
from .core_utils import *
21+
22+
__all__ = [
23+
"EXPORT_MODE",
24+
"convert_quantization_axis_to_reduce_axis",
25+
"export_torch_mode",
26+
"is_quantized",
27+
"is_quantized_column_parallel_linear",
28+
"is_quantized_linear",
29+
"is_quantized_row_parallel_linear",
30+
"reduce_amax",
31+
"reduce_sum",
32+
"replace_function",
33+
"update_quant_cfg_with_kv_cache_quant",
34+
"weight_attr_names",
35+
]
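The new `__init__` above relies on `__all__` to control which names the star-import from `core_utils` re-exports. A toy demonstration of that mechanism, using a synthetic in-memory module (`toy_core_utils` and its functions are made up for illustration, not part of ModelOpt):

```python
import sys
import types

# Build a throwaway module whose __all__ lists only one of its two names,
# mimicking how core_utils limits what `from .core_utils import *` exposes.
toy = types.ModuleType("toy_core_utils")
exec(
    "__all__ = ['is_quantized']\n"
    "def is_quantized(module):\n"
    "    return hasattr(module, 'weight_quantizer')\n"
    "def _internal_helper():\n"
    "    return 'hidden'\n",
    toy.__dict__,
)
sys.modules["toy_core_utils"] = toy  # make it importable by name

ns = {}
exec("from toy_core_utils import *", ns)

# Only the name listed in __all__ is copied into the importing namespace.
print(sorted(n for n in ns if not n.startswith("__")))  # ['is_quantized']
```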
