Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,16 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
),
)

use_sequential: bool = ModeloptField(
default=False,
title="Enable sequential layer-by-layer calibration.",
description=(
"If True, the calibration algorithm is applied sequentially to each decoder block. "
"The current approach recomputes a full forward pass per layer to propagate updated activations,"
"incurring O(N²) cost. Future revisions will add caching to eliminate redundant passes."
),
)


class MaxCalibConfig(QuantizeAlgorithmConfig):
"""The config for max calibration algorithm.
Expand Down
20 changes: 18 additions & 2 deletions modelopt/torch/quantization/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
local_hessian_calibrate,
max_calibrate,
mse_calibrate,
sequential_calibrate,
smoothquant,
svdquant,
)
Expand Down Expand Up @@ -221,6 +222,7 @@ def wrapped_calib_func(
"""
kwargs = config.model_dump()
method = kwargs.pop("method")
sequential = kwargs.pop("use_sequential", False)
if method is not None and "awq" in method:
# For backward compatibility
kwargs["algorithm"] = method
Expand All @@ -235,8 +237,22 @@ def wrapped_calib_func(
module._moe_calib_experts_ratio = moe_calib_experts_ratio

if func is not None:
# Call the function with forward_loop as a separate argument
func(model, forward_loop=forward_loop, **kwargs)
if sequential:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Design feedback: Consider adding an explicit validation for forward_loop when use_sequential=True. Without it, the error from sequential_calibrate is harder to diagnose:

if sequential and forward_loop is None:
    raise ValueError("forward_loop must be provided when use_sequential=True")

This is a small addition that improves the developer experience.

if forward_loop is None:
raise ValueError("forward_loop is required for calibration but got None.")
assert method in ["max"], (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this True? How can we use this for GPTQ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is just targeting the sequential calibration flow. I plan on adding gptq in this assertion in the GPTQ support PR

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this sequential calibration will work OOTB for mse_calibrate and local_hessian_calibrate. But I can double check after this PR lands

f"Sequential calibration currently only supports max calibration, got {method}"
)
# Wrap with sequential processing
sequential_calibrate(
model,
forward_loop=forward_loop,
calib_func=func,
**kwargs,
)
else:
# Direct calibration (existing behavior)
func(model, forward_loop=forward_loop, **kwargs)

# Lets get the latest metadata for the quantizer states
metadata = {}
Expand Down
53 changes: 51 additions & 2 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,14 @@
from tqdm import tqdm

from modelopt.torch.opt.searcher import ForwardLoop
from modelopt.torch.quantization.utils import LayerActivationCollector
from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
from modelopt.torch.utils.network import (
bind_forward_method,
get_decoder_layers,
unpatch_forward_method,
)
from modelopt.torch.utils.perf import get_used_gpu_mem_fraction

from .calib import MseCalibrator, NVFP4MSECalibrator
Expand All @@ -49,7 +54,14 @@
weight_attr_names,
)

__all__ = ["awq", "local_hessian_calibrate", "max_calibrate", "smoothquant", "svdquant"]
__all__ = [
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: The __all__ export now includes "sequential_calibrate" which is good since it is now part of the public API. 👍

"awq",
"local_hessian_calibrate",
"max_calibrate",
"sequential_calibrate",
"smoothquant",
"svdquant",
]


def weight_only_quantize(model: nn.Module):
Expand Down Expand Up @@ -1819,3 +1831,40 @@ def hessian_hook(module, input, output):
torch.cuda.empty_cache()

print_rank_0("GPTQ-lite quantization completed successfully")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Documentation suggestion: Consider adding a note about the computational complexity in the docstring. Users should understand that this implementation runs O(n) forward passes where n is the number of layers:

"""Sequential calibration - a sequential layer-by-layer calibration algorithm.

Note: This implementation runs O(n) full forward passes where n is the number of 
transformer layers. This is the simplest approach that handles arbitrary model 
architectures (including those with residual connections). Future optimizations 
may include activation caching.

Args:
    model: Model to be calibrated (must have identifiable transformer layers).
    ...
"""

This sets clear expectations about the trade-off: memory efficiency vs computation.



@torch.no_grad()
def sequential_calibrate(
    model: nn.Module,
    forward_loop: ForwardLoop,
    calib_func: Callable,
    **calib_kwargs,
):
    """Sequential calibration - a sequential layer-by-layer calibration algorithm.

    Note: for a model with N decoder layers, this implementation re-runs the full
    ``forward_loop`` once per layer so that each layer is calibrated on activations
    produced by the already-calibrated layers before it — i.e. O(N) full forward
    passes and O(N^2) total layer computation. Future revisions may cache
    activations to eliminate the redundant passes.

    Args:
        model: Model to calibrate. Must have identifiable decoder layers
            (see ``get_decoder_layers``).
        forward_loop: Callable that takes the model and runs calibration data
            through it. Must not be None.
        calib_func: Per-layer calibration function, invoked as
            ``calib_func(layer, layer_forward_loop, **calib_kwargs)``.
        **calib_kwargs: Extra keyword arguments forwarded to ``calib_func``.

    Raises:
        ValueError: If ``forward_loop`` is None or no decoder layers are found.
    """
    if forward_loop is None:
        raise ValueError("forward_loop must not be None for sequential calibration.")

    transformer_layers = get_decoder_layers(model)
    if transformer_layers is None:
        raise ValueError(
            "Could not find transformer layers in model. "
            "Sequential calibration requires a model with identifiable transformer layers."
        )

    print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers")

    collector = LayerActivationCollector(model)

    for layer in transformer_layers:
        # Re-run the forward loop so the captured inputs reflect the layers
        # already calibrated in previous iterations.
        layer_inputs = collector.get_input_activations(layer, forward_loop)

        # Replay the captured (args, kwargs) pairs through the current layer only.
        # ``_inputs`` is bound as a default to avoid the late-binding closure pitfall.
        def _layer_forward_loop(m, _inputs=layer_inputs):
            for args, kwargs_input in _inputs:
                m(*args, **kwargs_input)

        # Calibrate just this layer using its own replayed inputs.
        calib_func(layer, _layer_forward_loop, **calib_kwargs)

        # Free the captured activations before moving to the next layer.
        del layer_inputs
        torch.cuda.empty_cache()
Comment thread
sugunav14 marked this conversation as resolved.
64 changes: 64 additions & 0 deletions modelopt/torch/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@
from torch.distributed.tensor import Replicate

from modelopt.torch.utils import get_unwrapped_name, print_rank_0
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

if TYPE_CHECKING:
from collections.abc import Generator

from modelopt.torch.opt.searcher import ForwardLoop

__all__ = [
"EXPORT_MODE",
"convert_quantization_axis_to_reduce_axis",
Expand Down Expand Up @@ -808,3 +811,64 @@ def update_quant_cfg_with_kv_cache_quant(
quant_cfg["algorithm"] = "max"
print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}")
return quant_cfg


class _EarlyStopForwardError(Exception):
"""Error to stop the forward pass after collection."""


class LayerActivationCollector:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make it in a separate file?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's a generic helper that can help you get and cache (in the future) input activations wouldn't it make sense to have it here? Otherwise we would be creating a separate file just for this class.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think creating a separate file for this class does not hurt and improves the readability

"""Helper class for collecting layer activations during forward passes.

This class allows for sequential layer calibration by
patching layers to capture inputs/outputs during forward passes
"""

    def __init__(self, model: nn.Module):
        """Initialize the collector.

        Args:
            model: Top-level module whose forward passes will drive collection.
        """
        self.model = model

    @staticmethod
    def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False):
        """Patch a layer to collect inputs during forward passes.

        Replaces ``layer.forward`` with a wrapper that appends each call's
        ``(args, kwargs)`` to ``layer.inputs``, and initializes that buffer.

        Args:
            layer: Module whose forward method is patched in place.
            stop_after_collection: If True, the wrapper raises
                ``_EarlyStopForwardError`` immediately after recording the inputs,
                so neither this layer nor the rest of the model executes. If False,
                the original forward is invoked and its result returned.
        """

        def _forward_w_data_collection(self, *args, **kwargs):
            # Note: 'self' refers to the patched layer.
            assert len(args) >= 1, (
                f"Expected at least 1 positional arg, got {len(args)} args and {list(kwargs.keys())} kwargs"
            )
            # Only collect the inputs to the layer
            self.inputs.append((args, kwargs))
            if stop_after_collection:
                raise _EarlyStopForwardError()  # Stop the forward pass after collection

            # Delegate to the original forward so downstream computation is unchanged.
            return self._original_forward(*args, **kwargs)

        # Stash the original forward under ``_original_forward`` and install the wrapper.
        bind_forward_method(layer, _forward_w_data_collection, "_original_forward")
        # Buffer of captured (args, kwargs) tuples, one entry per forward call.
        layer.inputs = []

    @staticmethod
    def _unpatch_and_cleanup_layer(layer: torch.nn.Module):
        """Undo ``_patch_and_initialize_layer``: restore the forward and drop the buffer.

        Safe to call on a layer that was never patched — both steps are guarded
        with ``hasattr``.
        """
        if hasattr(layer, "_original_forward"):
            unpatch_forward_method(layer, "_original_forward")
        if hasattr(layer, "inputs"):
            del layer.inputs

    @torch.no_grad()
    def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list:
        """Run ``forward_loop`` on the model and capture the inputs reaching ``layer``.

        The target layer is patched to record its inputs and raise
        ``_EarlyStopForwardError`` right after, so each batch's forward stops as
        soon as the layer is hit. The model's own forward is wrapped to swallow
        that exception per batch, letting the loop continue with the next batch.

        Args:
            layer: Submodule of ``self.model`` whose input activations are wanted.
            forward_loop: Callable that takes the model and runs data through it.

        Returns:
            A list of ``(args, kwargs)`` tuples, one per forward call that reached
            the layer.
        """

        # Wrap model forward to catch _EarlyStopForwardError per-batch
        def _early_stop_forward(self, *args, **kwargs):
            try:
                return self._original_forward(*args, **kwargs)
            except _EarlyStopForwardError:
                return None  # Stop propagation but allow next batch

        try:
            bind_forward_method(self.model, _early_stop_forward, "_original_forward")
            self._patch_and_initialize_layer(layer, stop_after_collection=True)
            forward_loop(self.model)
            # Copy before cleanup, since unpatching deletes ``layer.inputs``.
            inputs = layer.inputs.copy()
        finally:
            # Always restore both the layer's and the model's original forwards,
            # even if the forward loop raised.
            self._unpatch_and_cleanup_layer(layer)
            unpatch_forward_method(self.model, "_original_forward")

        return inputs
Comment on lines +816 to +874
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Preserve the original forward when not early-stopping.

_forward_w_data_collection never calls the original forward, so stop_after_collection=False makes the patched layer return None and breaks downstream execution. Either enforce early-stop or forward to _original_forward.

🐛 Proposed fix
         def _forward_w_data_collection(self, *args, **kwargs):
             # Note: 'self' refers to the patched layer.
             assert len(args) >= 1, (
                 f"Expected at least 1 positional arg, got {len(args)} args and {list(kwargs.keys())} kwargs"
             )
             # Only collect the inputs to the layer
             self.inputs.append((args, kwargs))
             if stop_after_collection:
                 raise _EarlyStopForwardError()  # Stop the forward pass after collection
+            return self._original_forward(*args, **kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/utils.py` around lines 816 - 872, The patched
layer forward (_forward_w_data_collection inside _patch_and_initialize_layer)
currently only appends inputs and never calls the original forward, so when
stop_after_collection is False the layer returns None and breaks the model;
modify _forward_w_data_collection to, after appending to self.inputs, call and
return the original forward (e.g. call self._original_forward(*args, **kwargs)
if present) when stop_after_collection is False (and retain the early raise when
True), ensuring you reference bind_forward_method/_original_forward so the
original method is invoked correctly.

33 changes: 33 additions & 0 deletions modelopt/torch/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,3 +634,36 @@ def unpatch_forward_method(module: nn.Module, orig_forward_cache_name: str):
with temporarily_remove_accelerate_hook(module):
setattr(module, "forward", getattr(module, orig_forward_cache_name))
delattr(module, orig_forward_cache_name)


def get_decoder_layers(model: nn.Module, granularity: str = "decoder") -> nn.ModuleList | None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks our modular plugin abstractions. Can we have a plugin based implementation for this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to #930

"""Detect the decoder layers from a model for sequential calibration.

This temporary decoder-layer detection heuristic will be replaced with a more robust solution
that also supports FSDP/DDP models.
"""
if granularity != "decoder":
raise ValueError(f"Unsupported granularity: {granularity}. Only 'decoder' is supported.")

# HuggingFace transformers pattern: model.model.layers
if hasattr(model, "model") and hasattr(model.model, "layers"):
return model.model.layers

# Megatron/MCore pattern: model.decoder.layers
Comment thread
sugunav14 marked this conversation as resolved.
if hasattr(model, "decoder") and hasattr(model.decoder, "layers"):
return model.decoder.layers

# Direct layers attribute (some models)
if hasattr(model, "layers") and isinstance(model.layers, nn.ModuleList):
return model.layers

# GPT-style: model.transformer.h
if hasattr(model, "transformer") and hasattr(model.transformer, "h"):
return model.transformer.h

# Nemotron Super/Nano
if hasattr(model, "backbone") and hasattr(model.backbone, "layers"):
return model.backbone.layers

print("No decoder layers found for model, returning None")
return None
Comment thread
sugunav14 marked this conversation as resolved.
Loading