-
Notifications
You must be signed in to change notification settings - Fork 364
Support decoder block-level sequential calibration #924
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
694f38c
6e87be0
674e640
8af9e10
3a7a93b
16ba9f4
583b045
4b006f2
2285fba
edb0975
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,6 +63,7 @@ | |
| local_hessian_calibrate, | ||
| max_calibrate, | ||
| mse_calibrate, | ||
| sequential_calibrate, | ||
| smoothquant, | ||
| svdquant, | ||
| ) | ||
|
|
@@ -221,6 +222,7 @@ def wrapped_calib_func( | |
| """ | ||
| kwargs = config.model_dump() | ||
| method = kwargs.pop("method") | ||
| sequential = kwargs.pop("use_sequential", False) | ||
| if method is not None and "awq" in method: | ||
| # For backward compatibility | ||
| kwargs["algorithm"] = method | ||
|
|
@@ -235,8 +237,22 @@ def wrapped_calib_func( | |
| module._moe_calib_experts_ratio = moe_calib_experts_ratio | ||
|
|
||
| if func is not None: | ||
| # Call the function with forward_loop as a separate argument | ||
| func(model, forward_loop=forward_loop, **kwargs) | ||
| if sequential: | ||
| if forward_loop is None: | ||
| raise ValueError("forward_loop is required for calibration but got None.") | ||
| assert method in ["max"], ( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this True? How can we use this for GPTQ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This PR is just targeting the sequential calibration flow. I plan on adding GPTQ to this assertion in the GPTQ support PR.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this sequential calibration will work OOTB for mse_calibrate and local_hessian_calibrate. But I can double check after this PR lands |
||
| f"Sequential calibration currently only supports max calibration, got {method}" | ||
| ) | ||
| # Wrap with sequential processing | ||
| sequential_calibrate( | ||
| model, | ||
| forward_loop=forward_loop, | ||
| calib_func=func, | ||
| **kwargs, | ||
| ) | ||
| else: | ||
| # Direct calibration (existing behavior) | ||
| func(model, forward_loop=forward_loop, **kwargs) | ||
|
|
||
| # Lets get the latest metadata for the quantizer states | ||
| metadata = {} | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,9 +28,14 @@ | |
| from tqdm import tqdm | ||
|
|
||
| from modelopt.torch.opt.searcher import ForwardLoop | ||
| from modelopt.torch.quantization.utils import LayerActivationCollector | ||
| from modelopt.torch.utils import print_rank_0 | ||
| from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState | ||
| from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method | ||
| from modelopt.torch.utils.network import ( | ||
| bind_forward_method, | ||
| get_decoder_layers, | ||
| unpatch_forward_method, | ||
| ) | ||
| from modelopt.torch.utils.perf import get_used_gpu_mem_fraction | ||
|
|
||
| from .calib import MseCalibrator, NVFP4MSECalibrator | ||
|
|
@@ -49,7 +54,14 @@ | |
| weight_attr_names, | ||
| ) | ||
|
|
||
| __all__ = ["awq", "local_hessian_calibrate", "max_calibrate", "smoothquant", "svdquant"] | ||
| __all__ = [ | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: The |
||
| "awq", | ||
| "local_hessian_calibrate", | ||
| "max_calibrate", | ||
| "sequential_calibrate", | ||
| "smoothquant", | ||
| "svdquant", | ||
| ] | ||
|
|
||
|
|
||
| def weight_only_quantize(model: nn.Module): | ||
|
|
@@ -1819,3 +1831,40 @@ def hessian_hook(module, input, output): | |
| torch.cuda.empty_cache() | ||
|
|
||
| print_rank_0("GPTQ-lite quantization completed successfully") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Documentation suggestion: Consider adding a note about the computational complexity in the docstring. Users should understand that this implementation runs O(n) forward passes where n is the number of layers: """Sequential calibration - a sequential layer-by-layer calibration algorithm.
Note: This implementation runs O(n) full forward passes where n is the number of
transformer layers. This is the simplest approach that handles arbitrary model
architectures (including those with residual connections). Future optimizations
may include activation caching.
Args:
model: Model to be calibrated (must have identifiable transformer layers).
...
"""This sets clear expectations about the trade-off: memory efficiency vs computation. |
||
|
|
||
|
|
||
| @torch.no_grad() | ||
| def sequential_calibrate( | ||
| model: nn.Module, | ||
| forward_loop: ForwardLoop, | ||
| calib_func: Callable, | ||
| **calib_kwargs, | ||
| ): | ||
| """Sequential calibration - a sequential layer-by-layer calibration algorithm.""" | ||
| if forward_loop is None: | ||
| raise ValueError("forward_loop must not be None for sequential calibration.") | ||
|
|
||
| transformer_layers = get_decoder_layers(model) | ||
| if transformer_layers is None: | ||
| raise ValueError( | ||
| "Could not find transformer layers in model'. " | ||
| "Sequential calibration requires a model with identifiable transformer layers." | ||
| ) | ||
|
|
||
| print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") | ||
|
|
||
| gettr = LayerActivationCollector(model) | ||
|
|
||
| for layer in transformer_layers: | ||
| # Get updated input activations to the current layer | ||
| layer_inputs = gettr.get_input_activations(layer, forward_loop) | ||
|
|
||
| # Define a forward loop for the current layer | ||
| def _layer_forward_loop(m, _inputs=layer_inputs): | ||
| for args, kwargs_input in _inputs: | ||
| m(*args, **kwargs_input) | ||
|
|
||
| # Call calibration function | ||
| calib_func(layer, _layer_forward_loop, **calib_kwargs) | ||
|
Comment on lines
+1858
to
+1868
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is my understanding correct that for layer n, we will rerun n-1 layers? So basically there is duplicated compute?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah. That's fixed in #930 which will be merged after this |
||
| del layer_inputs | ||
| torch.cuda.empty_cache() | ||
|
sugunav14 marked this conversation as resolved.
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,10 +29,13 @@ | |
| from torch.distributed.tensor import Replicate | ||
|
|
||
| from modelopt.torch.utils import get_unwrapped_name, print_rank_0 | ||
| from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import Generator | ||
|
|
||
| from modelopt.torch.opt.searcher import ForwardLoop | ||
|
|
||
| __all__ = [ | ||
| "EXPORT_MODE", | ||
| "convert_quantization_axis_to_reduce_axis", | ||
|
|
@@ -808,3 +811,64 @@ def update_quant_cfg_with_kv_cache_quant( | |
| quant_cfg["algorithm"] = "max" | ||
| print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}") | ||
| return quant_cfg | ||
|
|
||
|
|
||
| class _EarlyStopForwardError(Exception): | ||
| """Error to stop the forward pass after collection.""" | ||
|
|
||
|
|
||
| class LayerActivationCollector: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we make it in a separate file?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since it's a generic helper that can help you get and cache (in the future) input activations, wouldn't it make sense to have it here? Otherwise we would be creating a separate file just for this class.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think creating a separate file for this class does not hurt and improves the readability |
||
| """Helper class for collecting layer activations during forward passes. | ||
|
|
||
| This class allows for sequential layer calibration by | ||
| patching layers to capture inputs/outputs during forward passes | ||
| """ | ||
|
|
||
| def __init__(self, model: nn.Module): | ||
| self.model = model | ||
|
|
||
| @staticmethod | ||
| def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False): | ||
| """Patch a layer to collect inputs during forward passes.""" | ||
|
|
||
| def _forward_w_data_collection(self, *args, **kwargs): | ||
| # Note: 'self' refers to the patched layer. | ||
| assert len(args) >= 1, ( | ||
| f"Expected at least 1 positional arg, got {len(args)} args and {list(kwargs.keys())} kwargs" | ||
| ) | ||
| # Only collect the inputs to the layer | ||
| self.inputs.append((args, kwargs)) | ||
| if stop_after_collection: | ||
| raise _EarlyStopForwardError() # Stop the forward pass after collection | ||
|
|
||
| return self._original_forward(*args, **kwargs) | ||
|
|
||
| bind_forward_method(layer, _forward_w_data_collection, "_original_forward") | ||
| layer.inputs = [] | ||
|
|
||
| @staticmethod | ||
| def _unpatch_and_cleanup_layer(layer: torch.nn.Module): | ||
| if hasattr(layer, "_original_forward"): | ||
| unpatch_forward_method(layer, "_original_forward") | ||
| if hasattr(layer, "inputs"): | ||
| del layer.inputs | ||
|
|
||
| @torch.no_grad() | ||
| def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: | ||
| # Wrap model forward to catch _EarlyStopForward per-batch | ||
| def _early_stop_forward(self, *args, **kwargs): | ||
| try: | ||
| return self._original_forward(*args, **kwargs) | ||
| except _EarlyStopForwardError: | ||
| return None # Stop propagation but allow next batch | ||
|
|
||
| try: | ||
| bind_forward_method(self.model, _early_stop_forward, "_original_forward") | ||
| self._patch_and_initialize_layer(layer, stop_after_collection=True) | ||
| forward_loop(self.model) | ||
| inputs = layer.inputs.copy() | ||
| finally: | ||
| self._unpatch_and_cleanup_layer(layer) | ||
| unpatch_forward_method(self.model, "_original_forward") | ||
|
|
||
| return inputs | ||
|
Comment on lines
+816
to
+874
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Preserve the original forward when not early-stopping.
🐛 Proposed fix def _forward_w_data_collection(self, *args, **kwargs):
# Note: 'self' refers to the patched layer.
assert len(args) >= 1, (
f"Expected at least 1 positional arg, got {len(args)} args and {list(kwargs.keys())} kwargs"
)
# Only collect the inputs to the layer
self.inputs.append((args, kwargs))
if stop_after_collection:
raise _EarlyStopForwardError() # Stop the forward pass after collection
+ return self._original_forward(*args, **kwargs)
🤖 Prompt for AI Agents |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -634,3 +634,36 @@ def unpatch_forward_method(module: nn.Module, orig_forward_cache_name: str): | |
| with temporarily_remove_accelerate_hook(module): | ||
| setattr(module, "forward", getattr(module, orig_forward_cache_name)) | ||
| delattr(module, orig_forward_cache_name) | ||
|
|
||
|
|
||
| def get_decoder_layers(model: nn.Module, granularity: str = "decoder") -> nn.ModuleList | None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This breaks our modular plugin abstractions. Can we have a plugin based implementation for this?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Related to #930 |
||
| """Detect the decoder layers from a model for sequential calibration. | ||
|
|
||
| This temporary decoder-layer detection heuristic will be replaced with a more robust solution | ||
| that also supports FSDP/DDP models. | ||
| """ | ||
| if granularity != "decoder": | ||
| raise ValueError(f"Unsupported granularity: {granularity}. Only 'decoder' is supported.") | ||
|
|
||
| # HuggingFace transformers pattern: model.model.layers | ||
| if hasattr(model, "model") and hasattr(model.model, "layers"): | ||
| return model.model.layers | ||
|
|
||
| # Megatron/MCore pattern: model.decoder.layers | ||
|
sugunav14 marked this conversation as resolved.
|
||
| if hasattr(model, "decoder") and hasattr(model.decoder, "layers"): | ||
| return model.decoder.layers | ||
|
|
||
| # Direct layers attribute (some models) | ||
| if hasattr(model, "layers") and isinstance(model.layers, nn.ModuleList): | ||
| return model.layers | ||
|
|
||
| # GPT-style: model.transformer.h | ||
| if hasattr(model, "transformer") and hasattr(model.transformer, "h"): | ||
| return model.transformer.h | ||
|
|
||
| # Nemotron Super/Nano | ||
| if hasattr(model, "backbone") and hasattr(model.backbone, "layers"): | ||
| return model.backbone.layers | ||
|
|
||
| print("No decoder layers found for model, returning None") | ||
| return None | ||
|
sugunav14 marked this conversation as resolved.
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Design feedback: Consider adding an explicit validation for
`forward_loop` when `use_sequential=True`. Without it, the error from `sequential_calibrate` is harder to diagnose. This is a small addition that improves the developer experience.