
Commit 1b87855

Modular/Plugin based sequential calib
Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent 4e59790 commit 1b87855

7 files changed

Lines changed: 558 additions & 46 deletions


modelopt/torch/quantization/model_calib.py

Lines changed: 14 additions & 15 deletions
@@ -31,11 +31,7 @@
 from modelopt.torch.quantization.utils import LayerActivationCollector
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
-from modelopt.torch.utils.network import (
-    bind_forward_method,
-    get_decoder_layers,
-    unpatch_forward_method,
-)
+from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
 from modelopt.torch.utils.perf import get_used_gpu_mem_fraction

 from .calib import MseCalibrator, NVFP4MSECalibrator

@@ -1841,27 +1837,30 @@ def sequential_calibrate(
     **calib_kwargs,
 ):
     """Sequential calibration - a sequential layer-by-layer calibration algorithm."""
-    transformer_layers = get_decoder_layers(model)
-    if transformer_layers is None:
+    if not LayerActivationCollector.is_supported(model):
         raise ValueError(
-            "Could not find transformer layers in model'. "
+            "Could not find transformer layers in model. "
             "Sequential calibration requires a model with identifiable transformer layers."
         )
+    transformer_layers = LayerActivationCollector.get_decoder_layers(model)
+    assert transformer_layers is not None

     print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers")
+    if len(transformer_layers) == 0:
+        return

-    gettr = LayerActivationCollector(model)
+    input_getter = LayerActivationCollector(model)

-    for _, layer in enumerate(transformer_layers):
-        # Get updated input activations to the current layer
-        inputs = gettr.get_input_activations(layer, forward_loop)
+    for layer in transformer_layers:
+        layer_inputs = input_getter.get_input_activations(layer, forward_loop)

         # Define a forward loop for the current layer
-        def _layer_forward_loop(m):
-            for args, kwargs_input in inputs:  # noqa: F821
+        def _layer_forward_loop(m, _inputs=layer_inputs):
+            for args, kwargs_input in _inputs:
                 m(*args, **kwargs_input)

         # Call calibration function
         calib_func(layer, _layer_forward_loop, **calib_kwargs)
-        del inputs
+
+        del layer_inputs
         torch.cuda.empty_cache()
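
The rewritten loop appears intended to bind the cached activations into the per-layer forward loop through a default argument (`_inputs=layer_inputs`) instead of letting the closure look up the enclosing name `inputs`, which the old code deleted at the end of each iteration (hence the `# noqa: F821`). A minimal, self-contained sketch of the difference (toy code, not part of the commit):

# Toy illustration (not ModelOpt code): late-binding closures vs. default-argument binding.
def make_loops_late_binding(batches_per_layer):
    loops = []
    for batches in batches_per_layer:
        # Captures the *name* `batches`; every loop ends up seeing the last value.
        loops.append(lambda m: [m(b) for b in batches])
    return loops


def make_loops_bound_at_definition(batches_per_layer):
    loops = []
    for batches in batches_per_layer:
        # A default argument freezes the current value, like `_inputs=layer_inputs` above.
        loops.append(lambda m, _batches=batches: [m(b) for b in _batches])
    return loops


data = [[1, 2], [10, 20]]
late = make_loops_late_binding(data)
bound = make_loops_bound_at_definition(data)
print(late[0](lambda x: x))   # [10, 20] -- layer 0 silently reuses layer 1's batches
print(bound[0](lambda x: x))  # [1, 2]   -- layer 0 keeps its own batches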

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 58 additions & 1 deletion
@@ -56,7 +56,7 @@
 else:
     weight_dequant = None

-from ..utils import replace_function
+from ..utils import LayerActivationCollector, replace_function
 from .attention import register_attention_for_kv_quant
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin

@@ -1042,6 +1042,55 @@ def _is_supported_hf_model(model):
     return isinstance(model, tuple(supported_models))


+def is_homogenous_hf_model(model: nn.Module) -> bool:
+    decoder_layers = get_homogeneous_hf_decoder_layers(model)
+    if decoder_layers is None or len(decoder_layers) == 0:
+        return False
+    layer_classes = {type(layer) for layer in decoder_layers}
+    return len(layer_classes) == 1
+
+
+def get_homogeneous_hf_decoder_layers(model: nn.Module) -> nn.ModuleList | None:
+    if not _is_supported_hf_model(model):
+        return None
+
+    if hasattr(model, "model") and hasattr(model.model, "layers"):
+        return model.model.layers
+
+    return None
+
+
+def build_hf_homogenous_next_layer_inputs_hook(model: nn.Module):
+    def _extract_hidden_states(layer_output):
+        if isinstance(layer_output, tuple):
+            return layer_output[0]
+        if isinstance(layer_output, dict):
+            if "hidden_states" in layer_output:
+                return layer_output["hidden_states"]
+        return layer_output
+
+    def _build_next_layer_inputs_hook(prev_layer, cached_inputs):
+        next_inputs = []
+        for args, kwargs in cached_inputs:
+            prev_output = prev_layer(*args, **kwargs)
+            hidden_states = _extract_hidden_states(prev_output)
+            if len(args) >= 1:
+                next_args = (hidden_states, *args[1:])
+                next_kwargs = kwargs
+            elif "hidden_states" in kwargs:
+                next_args = args
+                next_kwargs = dict(kwargs)
+                next_kwargs["hidden_states"] = hidden_states
+            else:
+                raise ValueError(
+                    "Unable to build next-layer inputs without hidden_states in args/kwargs."
+                )
+            next_inputs.append((next_args, next_kwargs))
+        return next_inputs
+
+    return _build_next_layer_inputs_hook
+
+
 @contextmanager
 def setup_model_for_gradient_checkpointing(model: nn.Module):
     use_cache = None

@@ -1091,6 +1140,14 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
     _is_param_grad_enabled_for_auto_quantize,
 )

+LayerActivationCollector.register_decoder_layer_support(
+    is_homogenous_hf_model, get_homogeneous_hf_decoder_layers
+)
+
+LayerActivationCollector.register_next_layer_input_support(
+    is_homogenous_hf_model, build_hf_homogenous_next_layer_inputs_hook
+)
+
 CUSTOM_MODEL_PLUGINS.update(
     [
         register_falcon_linears_on_the_fly,
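
With these registrations, support for a new model family no longer requires editing a central `get_decoder_layers` helper; a plugin registers a predicate plus a layer discoverer, and optionally a next-layer-inputs builder. A hedged sketch of how a hypothetical non-HF backend could hook in, reusing only the `register_*` APIs added in this commit (the `my_backend` helpers below are made up for illustration):

import torch.nn as nn

from modelopt.torch.quantization.utils import LayerActivationCollector


def is_my_backend_model(model: nn.Module) -> bool:
    # Hypothetical predicate, e.g. a Megatron-style layout exposing model.decoder.layers.
    return hasattr(model, "decoder") and hasattr(model.decoder, "layers")


def get_my_backend_decoder_layers(model: nn.Module) -> nn.ModuleList | None:
    return model.decoder.layers if is_my_backend_model(model) else None


# Registering only decoder-layer discovery is enough for sequential calibration;
# without a next-layer-inputs hook, inputs are re-collected with the full forward loop.
LayerActivationCollector.register_decoder_layer_support(
    is_my_backend_model, get_my_backend_decoder_layers
)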

modelopt/torch/quantization/utils.py

Lines changed: 59 additions & 1 deletion
@@ -824,8 +824,43 @@ class LayerActivationCollector:
     patching layers to capture inputs/outputs during forward passes
     """

+    _next_layer_input_support: list[tuple[Any, Any]] = []
+    _decoder_layer_support: list[tuple[Any, Any]] = []
+
     def __init__(self, model: nn.Module):
         self.model = model
+        self._previous_layer = None
+        self._previous_layer_inputs = None
+
+    @staticmethod
+    def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None:
+        """Return decoder layers supported by sequential calibration."""
+        for is_supported, discoverer in LayerActivationCollector._decoder_layer_support:
+            if not is_supported(model):
+                continue
+            decoder_layers = discoverer(model)
+            if decoder_layers is not None:
+                return decoder_layers
+        return None
+
+    @staticmethod
+    def is_supported(model: nn.Module) -> bool:
+        """Whether the model supports decoder-layer sequential calibration."""
+        return LayerActivationCollector.get_decoder_layers(model) is not None
+
+    @classmethod
+    def register_next_layer_input_support(
+        cls, is_supported: Any, build_next_layer_inputs_hook: Any
+    ):
+        entry = (is_supported, build_next_layer_inputs_hook)
+        if entry not in cls._next_layer_input_support:
+            cls._next_layer_input_support.append(entry)
+
+    @classmethod
+    def register_decoder_layer_support(cls, is_supported: Any, discoverer: Any):
+        entry = (is_supported, discoverer)
+        if entry not in cls._decoder_layer_support:
+            cls._decoder_layer_support.append(entry)

     @staticmethod
     def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False):

@@ -851,8 +886,15 @@ def _unpatch_and_cleanup_layer(layer: torch.nn.Module):
         if hasattr(layer, "inputs"):
             del layer.inputs

+    def _resolve_next_layer_inputs_hook(self):
+        for is_supported, build_next_layer_inputs_hook in self._next_layer_input_support:
+            if not is_supported(self.model):
+                continue
+            return build_next_layer_inputs_hook(self.model)
+        return None
+
     @torch.no_grad()
-    def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list:
+    def _collect_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list:
         # Wrap model forward to catch _EarlyStopForward per-batch
         def _early_stop_forward(self, *args, **kwargs):
             try:

@@ -870,3 +912,19 @@ def _early_stop_forward(self, *args, **kwargs):
         unpatch_forward_method(self.model, "_original_forward")

         return inputs
+
+    @torch.no_grad()
+    def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list:
+        is_first_layer = self._previous_layer is None or self._previous_layer_inputs is None
+        if is_first_layer:
+            inputs = self._collect_input_activations(layer, forward_loop)
+        else:
+            next_layer_inputs_hook = self._resolve_next_layer_inputs_hook()
+            if next_layer_inputs_hook is None:
+                inputs = self._collect_input_activations(layer, forward_loop)
+            else:
+                inputs = next_layer_inputs_hook(self._previous_layer, self._previous_layer_inputs)
+
+        self._previous_layer = layer
+        self._previous_layer_inputs = inputs
+        return inputs
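
The new `get_input_activations` caches the previous layer and its inputs, so after the first layer the collector can derive the next layer's inputs by replaying the previous layer instead of re-running the whole model under the forward loop. A small self-contained sketch of that chaining idea (toy modules, not the ModelOpt implementation):

import torch
import torch.nn as nn

# Toy stand-in for a homogeneous decoder stack.
layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
calib_batches = [torch.randn(2, 4) for _ in range(2)]

# Inputs to layer 0 are simply the calibration batches
# (in ModelOpt this is what _collect_input_activations captures via forward patching).
cached_inputs = [((x,), {}) for x in calib_batches]

for layer in layers:
    # ... calibrate `layer` here, driving it with `cached_inputs` as its forward loop ...

    # Replay the current layer once to produce the next layer's inputs,
    # mirroring what a registered next-layer-inputs hook does.
    with torch.no_grad():
        cached_inputs = [((layer(*args, **kwargs),), {}) for args, kwargs in cached_inputs]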

modelopt/torch/utils/network.py

Lines changed: 0 additions & 29 deletions
@@ -634,32 +634,3 @@ def unpatch_forward_method(module: nn.Module, orig_forward_cache_name: str):
     with temporarily_remove_accelerate_hook(module):
         setattr(module, "forward", getattr(module, orig_forward_cache_name))
     delattr(module, orig_forward_cache_name)
-
-
-def get_decoder_layers(model: nn.Module, granularity: str = "decoder") -> nn.ModuleList | None:
-    """Detect the decoder layers from a model for sequential calibration."""
-    if granularity != "decoder":
-        raise ValueError(f"Unsupported granularity: {granularity}. Only 'decoder' is supported.")
-
-    # HuggingFace transformers pattern: model.model.layers
-    if hasattr(model, "model") and hasattr(model.model, "layers"):
-        return model.model.layers
-
-    # Megatron/MCore pattern: model.decoder.layers
-    if hasattr(model, "decoder") and hasattr(model.decoder, "layers"):
-        return model.decoder.layers
-
-    # Direct layers attribute (some models)
-    if hasattr(model, "layers") and isinstance(model.layers, nn.ModuleList):
-        return model.layers
-
-    # GPT-style: model.transformer.h
-    if hasattr(model, "transformer") and hasattr(model.transformer, "h"):
-        return model.transformer.h
-
-    # Nemotron Super/Nano
-    if hasattr(model, "backbone") and hasattr(model.backbone, "layers"):
-        return model.backbone.layers
-
-    print("No decoder layers found for model, returning None")
-    return None

tests/unit/torch/quantization/plugins/test_huggingface.py

Lines changed: 27 additions & 0 deletions
@@ -23,13 +23,19 @@
 from _test_utils.torch.misc import set_seed
 from _test_utils.torch.transformers_models import (
     create_tiny_llama_dir,
+    get_tiny_gpt_oss,
     get_tiny_llama,
     get_tiny_qwen3_moe,
     tf_modelopt_state_and_output_tester,
 )

 import modelopt.torch.quantization as mtq
 from modelopt.torch.quantization.nn import QuantLinear, QuantModuleRegistry
+from modelopt.torch.quantization.plugins.huggingface import (
+    get_homogeneous_hf_decoder_layers,
+    is_homogenous_hf_model,
+)
+from modelopt.torch.quantization.utils import LayerActivationCollector

 pytest.importorskip("transformers")

@@ -199,3 +205,24 @@ def test_quantized_transformers_save_restore(tmp_path, model_cls, quant_config):

     model_test = model_cls.from_pretrained(tiny_llama_dir / "modelopt_model")
     tf_modelopt_state_and_output_tester(model_ref, model_test)
+
+
+def test_is_homogenous_hf_model_llama():
+    model = get_tiny_llama()
+    assert is_homogenous_hf_model(model)
+
+
+def test_is_homogenous_hf_model_gpt_oss():
+    model = get_tiny_gpt_oss(num_hidden_layers=1)
+    assert is_homogenous_hf_model(model)
+
+
+def test_hf_decoder_discoverer_registration_path():
+    model = get_tiny_llama()
+    assert any(
+        is_supported is is_homogenous_hf_model and discoverer is get_homogeneous_hf_decoder_layers
+        for is_supported, discoverer in LayerActivationCollector._decoder_layer_support
+    )
+    assert LayerActivationCollector.get_decoder_layers(model) is get_homogeneous_hf_decoder_layers(
+        model
+    )
