From e0cda1bca72fed9b42bdcc3415b8b810cd2737ee Mon Sep 17 00:00:00 2001 From: realAsma Date: Wed, 15 Apr 2026 14:17:14 +0000 Subject: [PATCH 1/6] Add layerwise calibration for large models This PR does three things: 1. Rename sequential_calibrate to layerwise_calibrate to better describe the layer-by-layer algorithm (use_sequential -> use_layerwise, _seq_calib -> _layerwise_calib). 2. Make layerwise calibration performant: persistent_materialization keeps the active layer on GPU for the entire calibration step, and _SkipLayer replaces fully-calibrated layers with parameter-free dummies so framework hooks (accelerate, FSDP2) skip materialization. 3. Add checkpoint save/resume so calibration of large models can be interrupted and restarted from the last completed layer. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: realAsma Add layerwise calibration for large models Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: realAsma Move checkpoint_dir helpers from library to examples/llm_ptq Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: realAsma Rename layerwise config fields and enable layerwise on experts-only recipe - use_layerwise -> layerwise, checkpoint_dir -> layerwise_checkpoint_dir - Enable layerwise calibration + checkpointing on nvfp4_experts_only-fp8_kv recipe - Add layerwise_checkpoint_dir to nvfp4_default-none_kv_gptq recipe Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: realAsma Address PR review feedback for layerwise calibration - Add inline security comments for all torch.load(weights_only=False) calls - Replace bare assert with RuntimeError for unsupported offload hook layout - Write back buffers (not just parameters) in _writeback_params_to_weights_map - Add cross-field validator rejecting layerwise_checkpoint_dir without layerwise=True - Validate num_layers mismatch on checkpoint resume - Handle integer device ordinals in _get_execution_device_from_hook - Clean up stale layer artifacts in 
partial-checkpoint tests - Guard non-dict algorithm values in needs_checkpoint_path_update - Add comment explaining dummy output_meta for last layer Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: realAsma --- examples/llm_ptq/example_utils.py | 35 + examples/llm_ptq/hf_ptq.py | 13 +- modelopt/torch/quantization/config.py | 26 +- modelopt/torch/quantization/mode.py | 16 +- modelopt/torch/quantization/model_calib.py | 72 +- .../torch/quantization/plugins/accelerate.py | 91 ++- .../torch/quantization/plugins/huggingface.py | 2 +- modelopt/torch/quantization/utils/__init__.py | 2 +- .../utils/activation_collector.py | 335 --------- .../torch/quantization/utils/calib_utils.py | 2 +- .../torch/quantization/utils/core_utils.py | 105 ++- .../quantization/utils/layerwise_calib.py | 681 ++++++++++++++++++ modelopt/torch/utils/dataset_utils.py | 30 +- modelopt/torch/utils/network.py | 65 +- .../general/ptq/nvfp4_default-fp8_kv.yaml | 2 +- .../ptq/nvfp4_default-none_kv_gptq.yaml | 5 +- .../ptq/nvfp4_experts_only-fp8_kv.yaml | 7 +- .../plugins/test_accelerate_gpu.py | 423 +++++++++++ tests/gpu/torch/quantization/test_fsdp2.py | 133 ++++ tests/gpu/torch/quantization/test_gptq.py | 2 +- ...librate.py => test_layerwise_calibrate.py} | 38 +- .../quantization/plugins/test_huggingface.py | 2 +- tests/unit/torch/quantization/test_calib.py | 20 +- ...librate.py => test_layerwise_calibrate.py} | 124 +++- .../test_sequential_checkpoint.py | 185 +++++ tests/unit/torch/quantization/test_utils.py | 2 +- 26 files changed, 1915 insertions(+), 503 deletions(-) delete mode 100644 modelopt/torch/quantization/utils/activation_collector.py create mode 100644 modelopt/torch/quantization/utils/layerwise_calib.py rename tests/gpu/torch/quantization/{test_sequential_calibrate.py => test_layerwise_calibrate.py} (90%) rename tests/unit/torch/quantization/{test_sequential_calibrate.py => test_layerwise_calibrate.py} (78%) create mode 100644 
tests/unit/torch/quantization/test_sequential_checkpoint.py diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index c2d4d4bfca..581005de43 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -15,6 +15,7 @@ import copy import glob +import hashlib import inspect import json import logging @@ -854,3 +855,37 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod print(f"Successfully copied {len(copied_files)} custom model files to {export_path}") else: print("No custom model files found to copy") + + +def needs_checkpoint_path_update(quant_cfg: dict) -> bool: + """Check if quant_cfg has a layerwise_checkpoint_dir that should be auto-resolved to a unique subpath.""" + algorithm = quant_cfg.get("algorithm") + if not isinstance(algorithm, dict): + return False + return algorithm.get("layerwise_checkpoint_dir") is not None + + +def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict: + """Append a unique ``<model_name>_<config_hash>`` subdirectory to layerwise_checkpoint_dir. + + Allows a single recipe to be reused across models without checkpoint collisions. + Must only be called when :func:`needs_checkpoint_path_update` returns True. 
+ """ + algorithm = quant_cfg["algorithm"] + base_dir = algorithm["layerwise_checkpoint_dir"] + + name = model_path.rstrip("/") + if "/" in name and not os.path.isabs(name): + name = name.replace("/", "--") + else: + name = Path(name).name + + config_hash = hashlib.sha256( + json.dumps(quant_cfg, sort_keys=True, default=str).encode() + ).hexdigest()[:8] + + quant_cfg = copy.deepcopy(quant_cfg) + quant_cfg["algorithm"]["layerwise_checkpoint_dir"] = os.path.join( + base_dir, f"{name}_{config_hash}" + ) + return quant_cfg diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 327605406c..c03e8aab3b 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -34,6 +34,8 @@ is_enc_dec, is_nemotron_vl, load_mtp_weights, + needs_checkpoint_path_update, + resolve_checkpoint_dir, run_nemotron_vl_preview, ) from torch.utils.data import DataLoader @@ -91,8 +93,9 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: for i, entry in enumerate(quant_cfg): if entry.get("quantizer_name") != "*[kv]_bmm_quantizer": continue - assert isinstance(entry.get("cfg", {}), dict) - quant_cfg[i] = {**entry, "cfg": {**entry.get("cfg", {}), "use_constant_amax": True}} + cfg = entry.get("cfg") or {} + assert isinstance(cfg, dict) + quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}} break @@ -1104,6 +1107,12 @@ def quantize_main( quant_cfg = copy.deepcopy(quant_cfg) _set_kv_cache_constant_amax(quant_cfg["quant_cfg"]) + if needs_checkpoint_path_update(quant_cfg): + quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path) + print( + f"Auto-resolved layerwise_checkpoint_dir: {quant_cfg['algorithm']['layerwise_checkpoint_dir']}" + ) + if args.qformat in QUANT_CFG_CHOICES: mono_quantize( args, diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 99c729efbc..0a52f0e866 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1217,16 +1217,36 @@ class 
QuantizeAlgorithmConfig(ModeloptBaseConfig): ), ) - use_sequential: bool = ModeloptField( + layerwise: bool = ModeloptField( default=False, - title="Enable sequential layer-by-layer calibration.", + title="Enable layerwise (layer-by-layer) calibration.", description=( - "If True, the calibration algorithm is applied sequentially to each decoder block. " + "If True, the calibration algorithm is applied to each decoder layer independently. " "Each layer's inputs are captured via a single forward pass that reflects the " "quantization of all preceding layers, incurring O(N) forward passes for N layers." ), ) + layerwise_checkpoint_dir: str | None = ModeloptField( + default=None, + title="Checkpoint directory for layerwise calibration.", + description=( + "If set together with layerwise=True, per-layer checkpoints are saved to this " + "directory during calibration. On restart, calibration resumes from the last " + "completed layer." + ), + ) + + @model_validator(mode="after") + def validate_layerwise_checkpoint_dir(self): + """Raise if layerwise_checkpoint_dir is set but layerwise is False.""" + if self.layerwise_checkpoint_dir is not None and not self.layerwise: + raise ValueError( + "layerwise_checkpoint_dir requires layerwise=True. " + "Set layerwise=True or remove layerwise_checkpoint_dir." + ) + return self + class MaxCalibConfig(QuantizeAlgorithmConfig): """The config for max calibration algorithm. 
diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index c81d5c89c7..5b00308936 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -60,10 +60,10 @@ from .model_calib import ( awq, gptq, + layerwise_calibrate, local_hessian_calibrate, max_calibrate, mse_calibrate, - sequential_calibrate, smoothquant, svdquant, ) @@ -222,7 +222,8 @@ def wrapped_calib_func( """ kwargs = config.model_dump() method = kwargs.pop("method") - sequential = kwargs.pop("use_sequential", False) + layerwise = kwargs.pop("layerwise", False) + checkpoint_dir = kwargs.pop("layerwise_checkpoint_dir", None) if method is not None and "awq" in method: # For backward compatibility kwargs["algorithm"] = method @@ -237,17 +238,16 @@ def wrapped_calib_func( module._moe_calib_experts_ratio = moe_calib_experts_ratio if func is not None: - if sequential: + if layerwise: + # TODO: add a method guard here — not all calib methods support per-layer invocation if forward_loop is None: raise ValueError("forward_loop is required for calibration but got None.") - assert method in ["max", "gptq"], ( - f"Sequential calibration currently only supports max and gptq calibration, got {method}" - ) - # Wrap with sequential processing - sequential_calibrate( + # Wrap with layerwise processing + layerwise_calibrate( model, forward_loop=forward_loop, calib_func=func, + checkpoint_dir=checkpoint_dir, **kwargs, ) else: diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 35a0e931c9..6db1e82cbd 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -28,7 +28,10 @@ from tqdm import tqdm from modelopt.torch.opt.searcher import ForwardLoop -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.utils.layerwise_calib import ( + LayerActivationCollector, + _CheckpointState, +) from 
modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method @@ -44,6 +47,7 @@ is_quantized_column_parallel_linear, is_quantized_linear, is_quantized_row_parallel_linear, + persistent_materialization, promote_nvfp4_static_quantizers, quantizer_attr_names, reduce_amax, @@ -53,9 +57,9 @@ __all__ = [ "awq", + "layerwise_calibrate", "local_hessian_calibrate", "max_calibrate", - "sequential_calibrate", "smoothquant", "svdquant", ] @@ -1552,21 +1556,27 @@ def postprocess(module, name): @torch.no_grad() -def sequential_calibrate( +def layerwise_calibrate( model: nn.Module, forward_loop: ForwardLoop, calib_func: Callable, **calib_kwargs, ): - """Sequential calibration - a sequential layer-by-layer calibration algorithm. + """Layerwise calibration - a layer-by-layer calibration algorithm. Runs the full model forward per layer but patches decoder layers with a skip / run / capture strategy so that inter-layer logic in parent modules (e.g. mask construction) executes naturally without model-specific hooks. + + If ``checkpoint_dir`` is passed (via ``calib_kwargs``), per-layer checkpoints + are saved after each layer completes. On restart, calibration resumes from + the last completed layer. """ + checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None) + if forward_loop is None: raise ValueError( - "forward_loop must not be None for sequential calibration. " + "forward_loop must not be None for layerwise calibration. " "Please provide a valid forward_loop callable." ) @@ -1574,31 +1584,57 @@ def sequential_calibrate( if transformer_layers is None or len(transformer_layers) == 0: raise ValueError( "Could not find transformer layers in model. " - "Sequential calibration requires a model with identifiable transformer layers." + "Layerwise calibration requires a model with identifiable transformer layers." 
) - print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") + num_layers = len(transformer_layers) + print_rank_0(f"Layerwise calibration: Found {num_layers} transformer layers") + + ckpt = _CheckpointState.from_folder(checkpoint_dir, num_layers) + start_layer = ckpt.start_layer if ckpt else 0 input_getter = LayerActivationCollector(model) input_getter._patch_all_layers(decoder_layers=transformer_layers) + resumed_inputs = ckpt.setup_resume(transformer_layers) if ckpt and start_layer > 0 else None + try: - for layer_idx, layer in enumerate(transformer_layers): - print_rank_0(f"Calibrating layer {layer_idx + 1}/{len(transformer_layers)}") - layer_inputs = input_getter.get_input_activations(layer, forward_loop) + # Bootstrap: get first layer's inputs (or use resumed inputs). + layer_inputs = input_getter.get_first_layer_inputs( + start_layer, resumed_inputs, forward_loop + ) + + for layer_idx in range(start_layer, num_layers): + layer = transformer_layers[layer_idx] def _layer_forward_loop(m, _inputs=layer_inputs): for args, kwargs_input in _inputs: m(*args, **kwargs_input) - calib_func(layer, _layer_forward_loop, **calib_kwargs) + with persistent_materialization(layer): + calib_func(layer, _layer_forward_loop, **calib_kwargs) + + # Run one more forward to get next layer's inputs and set + # output_meta on the just-calibrated layer (via "run" mode). 
+ is_last = layer_idx + 1 >= num_layers + if not is_last: + next_inputs = input_getter.cache_outputs_for_next_layer_calib(layer, forward_loop) + else: + next_inputs = None + + if ckpt: + ckpt.save(layer_idx, layer, model, transformer_layers, next_inputs) del layer_inputs torch.cuda.empty_cache() + layer_inputs = next_inputs # noqa: F841 (used in next iteration's closure) finally: input_getter._unpatch_all_layers() - print_rank_0("Sequential calibration completed") + if ckpt: + ckpt.full_restore(transformer_layers, model) + + print_rank_0("Layerwise calibration completed") @torch.no_grad() @@ -1610,12 +1646,12 @@ def gptq( ): """GPTQ quantization. - Works in two modes depending on ``use_sequential`` in the config: + Works in two modes depending on ``layerwise`` in the config: - * **Sequential** (``use_sequential=True``): ``sequential_calibrate`` calls this + * **Layerwise** (``layerwise=True``): ``layerwise_calibrate`` calls this function once per decoder layer with updated activations, producing more accurate Hessian estimates. - * **Non-sequential** (``use_sequential=False``): called once on the full model. + * **Non-layerwise** (``layerwise=False``): called once on the full model. All layers are quantized in parallel from the original activations. Per-module steps: @@ -1628,7 +1664,7 @@ def gptq( Args: model: The module to quantize — either the full model or a single decoder - layer when invoked by ``sequential_calibrate``. + layer when invoked by ``layerwise_calibrate``. forward_loop: Callable that replays calibration inputs through *model*. perc_damp: Percentage of avg Hessian diagonal for damping (default: 0.01). block_size: Block size for GPTQ weight update. 
@@ -1663,8 +1699,10 @@ def gptq( handle.cleanup() print_rank_0("Updating weights using GPTQ algorithm...") + name_to_module = dict(model.named_modules()) for handle in gptq_handles.values(): - handle.update_weights(block_size, perc_damp) + with enable_weight_access_and_writeback(handle.module, model, name_to_module): + handle.update_weights(block_size, perc_damp) handle.free() del gptq_handles diff --git a/modelopt/torch/quantization/plugins/accelerate.py b/modelopt/torch/quantization/plugins/accelerate.py index 13999df0f0..bbbb75930e 100644 --- a/modelopt/torch/quantization/plugins/accelerate.py +++ b/modelopt/torch/quantization/plugins/accelerate.py @@ -33,10 +33,9 @@ def _get_cpu_offload_hook(hook): if isinstance(hook, AlignDevicesHook) and hook.offload and hook.weights_map is not None: - assert "weight" in hook.weights_map - if ( - isinstance(hook.weights_map, PrefixedDataset) - and hook.weights_map.prefix + "weight" not in hook.weights_map.dataset.state_dict + assert len(hook.weights_map) > 0 + if isinstance(hook.weights_map, PrefixedDataset) and not any( + k.startswith(hook.weights_map.prefix) for k in hook.weights_map.dataset.state_dict ): raise NotImplementedError( "This layer could be offloaded to disk. We don't support this yet." 
@@ -50,32 +49,84 @@ def _get_cpu_offload_hook(hook): return None +def _writeback_params_to_weights_map(module, align_hook): + """Write all non-meta parameters and buffers back to the hook's CPU weights_map.""" + for name, tensor in module.state_dict(keep_vars=True).items(): + if tensor.device.type == "meta": + continue + if isinstance(align_hook.weights_map, PrefixedDataset): + key = align_hook.weights_map.prefix + name + w_map = align_hook.weights_map.dataset.state_dict + else: + w_map = align_hook.weights_map + key = name + if key in w_map: + w_map[key] = tensor.detach().to(w_map[key].device, dtype=w_map[key].dtype) + + @contextmanager def weight_access_and_writeback_context(module): - """Context manager for weight access and writeback for modules managed by accelerate.""" + """Context manager for weight access and writeback for modules managed by accelerate. + + Handles two cases: + 1. **Single-module**: the module's own ``_hf_hook`` is an offload hook. + 2. **Sub-module**: the module's hook is non-offloading, but its children have + offload hooks (common with ``SequentialHook`` on sub-modules placed by + ``load_checkpoint_and_dispatch``). + + For the sub-module case, ``pre_forward`` is skipped on sub-modules whose weights + are already materialized (not on meta). This allows the context manager to be + used as a pure writeback after weight-modifying algorithms. 
+ """ assert hasattr(module, "_hf_hook") align_hook = _get_cpu_offload_hook(module._hf_hook) if align_hook: - # Accelerate uses AlignDevicesHook to offload weights to CPU/Disk and then reload them in the forward pass - # The CPU/Disk offloaded weights are managed by PrefixDataset and OffloadedWeightsLoader - # See https://github.com/huggingface/accelerate/blame/f48d95c4939b281505a45b3d6e0bf554b65cc1ea/src/accelerate/utils/offload.py#L104-L141 - # TODO: Add support for disk-offloaded models if needed (they will be really slow, hence low priority) - - # This will load the weights from CPU state_dict and move it to the GPU from meta device + # Guard: the sub-module branch below is not reached when the parent has + # an offload hook. Assert that no children also carry offload hooks, + # which would require a combined writeback strategy. + if any( + _get_cpu_offload_hook(mod._hf_hook) + for mod in module.modules() + if mod is not module and hasattr(mod, "_hf_hook") + ): + raise RuntimeError( + "Both the module and one of its sub-modules have CPU-offload hooks. " + "weight_access_and_writeback_context does not support this layout yet." + ) align_hook.pre_forward(module) + align_hook.offload = False + try: + yield + finally: + align_hook.offload = True + _writeback_params_to_weights_map(module, align_hook) + align_hook.post_forward(module, None) + return + + materialized: list[tuple[torch.nn.Module, AlignDevicesHook, bool]] = [] + for mod in module.modules(): + if mod is module or not hasattr(mod, "_hf_hook"): + continue + hook = _get_cpu_offload_hook(mod._hf_hook) + if hook is None: + continue + # Only call pre_forward if weights need materializing; already-materialized + # weights would be overwritten with stale CPU state_dict values. 
+ needs_materialize = any(p.device.type == "meta" for p in mod.parameters()) + if needs_materialize: + hook.pre_forward(mod) + hook.offload = False + materialized.append((mod, hook, needs_materialize)) + try: yield finally: - if align_hook: - # Update the weight in the CPU state_dict - if isinstance(align_hook.weights_map, PrefixedDataset): - key = align_hook.weights_map.prefix + "weight" - w_map = align_hook.weights_map.dataset.state_dict - else: - key, w_map = "weight", align_hook.weights_map - w_map[key] = module.weight.data.to(w_map[key].device, dtype=w_map[key].dtype) - align_hook.post_forward(module, None) + for mod, hook, was_materialized in materialized: + hook.offload = True + _writeback_params_to_weights_map(mod, hook) + if was_materialized: + hook.post_forward(mod, None) @contextmanager diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 82ab589934..59bcd215bb 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -39,7 +39,7 @@ from ..nn.modules.quant_linear import _QuantLinear from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE from ..utils import replace_function, sync_moe_expert_amax -from ..utils.activation_collector import LayerActivationCollector +from ..utils.layerwise_calib import LayerActivationCollector from .attention import register_attention_for_kv_quant from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin diff --git a/modelopt/torch/quantization/utils/__init__.py b/modelopt/torch/quantization/utils/__init__.py index 2660363209..dfc23c42ee 100644 --- a/modelopt/torch/quantization/utils/__init__.py +++ b/modelopt/torch/quantization/utils/__init__.py @@ -16,8 +16,8 @@ # ruff: noqa: F405 """Quantization utilities.""" -from .activation_collector import LayerActivationCollector from .core_utils import * +from .layerwise_calib import LayerActivationCollector __all__ = [ 
"EXPORT_MODE", diff --git a/modelopt/torch/quantization/utils/activation_collector.py b/modelopt/torch/quantization/utils/activation_collector.py deleted file mode 100644 index 5f187fdcb2..0000000000 --- a/modelopt/torch/quantization/utils/activation_collector.py +++ /dev/null @@ -1,335 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Sequential calibration layer patching and activation capture. - -This module provides :class:`LayerActivationCollector`, a stateful helper that -patches decoder layers with a skip / run / capture strategy for efficient -layer-by-layer calibration. -""" - -from collections import deque -from dataclasses import dataclass, field -from typing import Any - -import torch -import torch.nn as nn - -from modelopt.torch.opt.searcher import ForwardLoop -from modelopt.torch.utils import print_rank_0 -from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method - - -class _EarlyStopForwardError(Exception): - """Raised to halt the forward pass after capturing layer inputs.""" - - -@dataclass -class _LayerCalibState: - """Mutable per-layer state used during sequential calibration. - - Attached to each decoder layer as ``_seq_calib`` and accessed by the - patched forward to decide skip / run / capture / original behaviour. 
- """ - - mode: str = "original" - name: str = "" - cached_inputs: deque = field(default_factory=deque) - collected_inputs: list = field(default_factory=list) - output_meta: tuple | None = None - - -class LayerActivationCollector: - """Collects layer activations for sequential (layer-by-layer) calibration. - - Each decoder layer is patched with a unified forward whose behaviour is - governed by a per-layer :class:`_LayerCalibState`: - - * **skip** — return a zero-filled dummy whose shape and type match the - layer's real output (reconstructed from lightweight metadata). No - computation is performed. The correctly shaped dummy ensures un-patched - inter-layer operations in the parent forward (e.g. LayerNorm, tuple - unpacking) do not raise shape or type errors. - * **run** — replay previously captured inputs through the original forward, - ignoring whatever the parent passes in. Only the just-calibrated layer - uses this mode, so its output reflects updated weights. - * **capture** — record ``(args, kwargs)`` and raise - ``_EarlyStopForwardError`` to halt the forward pass early. - * **original** — call the original forward unchanged. - - Because the *run* layer discards upstream values, skip-layer outputs are - never consumed for real computation. - """ - - # Global registry of (predicate, discoverer) pairs. Populated at import time - # by plugins (e.g. huggingface.py, megatron.py). Order matters: the first - # matching entry wins, so more specific predicates (e.g. Nemotron-H) must be - # registered before generic ones (e.g. homogeneous HF models). - # - # This is intentionally a mutable class variable shared across all instances: - # plugins register once at import time, and the registry is read-only after - # that. register_decoder_layer_support() guards against duplicate entries. 
- _decoder_layer_support: list[tuple[Any, Any]] = [] - _LAYER_ATTR = "_seq_calib" - - def __init__(self, model: nn.Module): - """Initialize the collector for the given model.""" - self.model = model - self._decoder_layers: nn.ModuleList | None = None - self._layer_to_idx: dict[nn.Module, int] = {} - self._patched = False - - @staticmethod - def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None: - """Return decoder layers supported by sequential calibration.""" - for is_supported, discoverer in LayerActivationCollector._decoder_layer_support: - if not is_supported(model): - continue - decoder_layers = discoverer(model) - if decoder_layers is not None: - return decoder_layers - return None - - @staticmethod - def is_supported(model: nn.Module) -> bool: - """Whether the model supports decoder-layer sequential calibration.""" - return LayerActivationCollector.get_decoder_layers(model) is not None - - @classmethod - def register_decoder_layer_support(cls, is_supported: Any, discoverer: Any): - """Register a (predicate, discoverer) pair for decoder-layer detection.""" - entry = (is_supported, discoverer) - if entry not in cls._decoder_layer_support: - cls._decoder_layer_support.append(entry) - - @staticmethod - def _extract_output_meta(output): - """Extract lightweight (shape, dtype, device) metadata from a layer output. - - Recursively handles tensors, tuples, lists, and non-tensor values (e.g. None). - The returned structure can be passed to ``_zeros_from_meta`` to reconstruct a - zero-filled output with identical shape and type. 
- """ - if isinstance(output, torch.Tensor): - return ("tensor", output.shape, output.dtype, output.device) - if isinstance(output, tuple): - return ( - "tuple", - tuple(LayerActivationCollector._extract_output_meta(o) for o in output), - ) - if isinstance(output, list): - return ("list", [LayerActivationCollector._extract_output_meta(o) for o in output]) - return ("other", output) - - @staticmethod - def _zeros_from_meta(meta): - """Reconstruct a zero-filled output from metadata produced by ``_extract_output_meta``.""" - tag = meta[0] - if tag == "tensor": - _, shape, dtype, device = meta - return torch.zeros(shape, dtype=dtype, device=device) - if tag == "tuple": - return tuple(LayerActivationCollector._zeros_from_meta(m) for m in meta[1]) - if tag == "list": - return [LayerActivationCollector._zeros_from_meta(m) for m in meta[1]] - # "other" values are expected to be lightweight non-tensors (e.g. None, small scalars). - # The value is returned directly (not copied); callers must not mutate it. - # In practice this is safe because skip-mode outputs are immediately discarded by the - # downstream run-mode layer, which replays from its own cached inputs instead. - return meta[1] - - def _patch_all_layers(self, decoder_layers: nn.ModuleList | None = None): - """Bind the unified forward to every decoder layer and the model. Called once. - - Args: - decoder_layers: Pre-resolved decoder layers. If *None*, layers are - discovered via :meth:`get_decoder_layers`. - """ - - def _patched_forward(self, *args, **kwargs): - """Unified forward bound to every decoder layer during sequential calibration. - - ``self`` here is the decoder layer module (bound via ``bind_forward_method``). - All per-layer state is accessed through ``self._seq_calib``. - """ - info: _LayerCalibState = self._seq_calib - - if info.mode == "skip": - if info.output_meta is None: - raise RuntimeError( - f"Layer {info.name} is in 'skip' mode but has no output_meta. 
" - "This indicates a state-machine bug: the layer should have run " - "in 'run' mode (which sets output_meta) before transitioning to 'skip'." - ) - return LayerActivationCollector._zeros_from_meta(info.output_meta) - - if info.mode == "run": - assert info.cached_inputs, ( - f"Layer {info.name} is in 'run' mode but has no cached inputs to replay." - ) - real_args, real_kwargs = info.cached_inputs.popleft() - output = self._original_forward(*real_args, **real_kwargs) - info.output_meta = LayerActivationCollector._extract_output_meta(output) - return output - - if info.mode == "capture": - info.collected_inputs.append((args, kwargs)) - raise _EarlyStopForwardError() - - return self._original_forward(*args, **kwargs) - - if decoder_layers is not None: - self._decoder_layers = decoder_layers - else: - self._decoder_layers = self.get_decoder_layers(self.model) - assert self._decoder_layers is not None - - self._layer_to_idx = {layer: i for i, layer in enumerate(self._decoder_layers)} - module_to_name = {m: name for name, m in self.model.named_modules()} - - try: - for layer in self._decoder_layers: - layer._seq_calib = _LayerCalibState( - name=module_to_name.get(layer, type(layer).__name__), - ) - bind_forward_method(layer, _patched_forward, "_original_forward") - - def _early_stop_forward(module_self, *args, **kwargs): - try: - return module_self._original_forward(*args, **kwargs) - except _EarlyStopForwardError: - return None - - bind_forward_method(self.model, _early_stop_forward, "_original_forward") - except Exception: - self._cleanup_layers() - raise - - self._patched = True - - def _cleanup_layers(self): - """Best-effort cleanup of any patched layers and model forward.""" - if hasattr(self.model, "_original_forward"): - unpatch_forward_method(self.model, "_original_forward") - - if self._decoder_layers is not None: - for layer in self._decoder_layers: - if hasattr(layer, "_original_forward"): - unpatch_forward_method(layer, "_original_forward") - if 
hasattr(layer, self._LAYER_ATTR): - delattr(layer, self._LAYER_ATTR) - - def _unpatch_all_layers(self): - """Restore original forwards and clean up state attributes. Called once.""" - if not self._patched: - return - self._cleanup_layers() - self._patched = False - - def _set_layer_states(self, layer_idx: int): - """Transition layer modes for the next calibration step. - - When calibrating layer *i*, three transitions happen: - - * Layer ``i - 2`` → **skip** (fully done, free its cached inputs). - * Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights). - * Layer ``i`` → **capture** (record inputs, then early-stop). - """ - assert self._decoder_layers is not None - - if layer_idx > 1: - done = self._decoder_layers[layer_idx - 2]._seq_calib - # output_meta is intentionally kept: skip mode needs it to produce - # correctly shaped zero-filled outputs for the parent forward. - done.mode = "skip" - done.cached_inputs.clear() - - if layer_idx > 0: - prev = self._decoder_layers[layer_idx - 1]._seq_calib - if not prev.collected_inputs: - raise RuntimeError( - f"Layer {layer_idx - 1} ({prev.name!r}) has no collected inputs to replay. " - "Layers must be calibrated sequentially — ensure get_input_activations() " - "was called for every preceding layer in order." 
- ) - prev.mode = "run" - prev.cached_inputs = deque(prev.collected_inputs) - prev.collected_inputs = [] - - cur = self._decoder_layers[layer_idx]._seq_calib - cur.mode = "capture" - cur.collected_inputs = [] - - def _log_layer_summary(self, layer_idx: int): - """Log a one-line summary of layer modes for the current calibration step.""" - assert self._decoder_layers is not None - n = len(self._decoder_layers) - groups: dict[str, list[int]] = {} - for i, layer in enumerate(self._decoder_layers): - mode = layer._seq_calib.mode - if mode in ("skip", "run", "capture"): - groups.setdefault(mode, []).append(i + 1) - parts = [f"{mode}: {groups[mode]}" for mode in ("skip", "run", "capture") if mode in groups] - print_rank_0(f"Calibrating layer {layer_idx + 1}/{n} | {' | '.join(parts)}") - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - @torch.no_grad() - def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: - """Collect input activations for *layer* by running a full model forward. - - Layers before the target are skipped or re-run (if just calibrated), the - target layer captures its inputs, and an early-stop prevents unnecessary - computation beyond the target. - - :meth:`_patch_all_layers` must be called before this method. - - Note: the model forward returns ``None`` for every batch during capture - (because ``_EarlyStopForwardError`` short-circuits the forward pass). - Callers should not rely on the model's return value within *forward_loop*. - """ - if not self._patched: - raise RuntimeError( - "get_input_activations() requires _patch_all_layers() to be called first." 
- ) - layer_idx = self._layer_to_idx[layer] - self._set_layer_states(layer_idx) - self._log_layer_summary(layer_idx) - - info = layer._seq_calib - try: - forward_loop(self.model) - except Exception: - # Reset the current layer so subsequent calls don't see stale state. - info.mode = "original" - info.collected_inputs = [] - raise - - if not info.collected_inputs: - info.mode = "original" - raise RuntimeError( - f"Layer {info.name!r} collected no inputs during forward_loop. " - "The forward loop did not reach this layer — check that forward_loop() " - "actually calls the model and that the layer is in the forward path." - ) - - inputs = list(info.collected_inputs) - # After capture, set to original so calib_func can call the layer's - # real forward directly. The layer will transition to run → skip - # in subsequent iterations via _set_layer_states. - info.mode = "original" - return inputs diff --git a/modelopt/torch/quantization/utils/calib_utils.py b/modelopt/torch/quantization/utils/calib_utils.py index e52a8438d5..e2d8ccf2a2 100644 --- a/modelopt/torch/quantization/utils/calib_utils.py +++ b/modelopt/torch/quantization/utils/calib_utils.py @@ -96,7 +96,7 @@ def __init__(self, module, name, offload_to_cpu=False): self.name = name in_features = module.weight.shape[-1] device = module.weight.device - if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: + if device.type == "meta" or (offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65): device = "cpu" self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) self.n_samples = 0 diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 273d7564c6..1a626a51d3 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -423,47 +423,70 @@ def _get_enclosing_fsdp_module( return root_model +def _set_parameter(module: nn.Module, name: str, value: 
nn.Parameter): + """Set a parameter on a module by dotted name (e.g. ``self_attn.q_proj.weight``).""" + parts = name.rsplit(".", 1) + if len(parts) == 2: + parent = module.get_submodule(parts[0]) + attr = parts[1] + else: + parent = module + attr = name + parent._parameters[attr] = value + + @contextmanager def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn.Module): """Context manager for FSDP2 weight access and writeback. - Note this context will gather the weight across FSDP/HSDP shards. If TP is implemented with DTensor, - the weight will be a local tensor of the TP DTensor under this context. + Gathers sharded DTensor parameters across FSDP/HSDP shards so they can be + read or modified. Works for both leaf modules (single ``weight``) and + composite modules like decoder layers (all ``named_parameters``). + + If TP is implemented with DTensor, the weight will be a local tensor of the + TP DTensor under this context. """ assert isinstance(root_model, torch.distributed.fsdp.FSDPModule), "We only support FSDP2" assert not hasattr(module, "_hf_hook"), "We dont support FSDP2 with HF accelerate hooks" - assert isinstance(module.weight, torch.distributed.tensor.DTensor) fsdp_module = _get_enclosing_fsdp_module(module, root_model) assert fsdp_module is not None, "Module is not wrapped by FSDP" fsdp_device_mesh = _get_fsdp2_mesh(fsdp_module) fsdp_dim = fsdp_device_mesh.ndim - original_placements = module.weight.placements - original_device_mesh = module.weight.device_mesh - original_weight = module.weight - # Assuming the first fsdp_dim dimensions are for FSDP/HSDP, we only collect the tensor over FSDP/HSDP dimension, - # the TP will be handled by the TP reduction. - if fsdp_dim != original_device_mesh.ndim: - assert fsdp_device_mesh.mesh_dim_names == original_device_mesh.mesh_dim_names[:fsdp_dim], ( - "FSDP2 mesh should be a slice of DTesnor's device mesh." + # Collect all DTensor parameters, replacing them with local replicated copies. 
+ originals: dict[str, tuple] = {} + for name, param in module.named_parameters(): + if not isinstance(param, torch.distributed.tensor.DTensor): + continue + original_placements = param.placements + original_device_mesh = param.device_mesh + if fsdp_dim != original_device_mesh.ndim: + assert ( + fsdp_device_mesh.mesh_dim_names == original_device_mesh.mesh_dim_names[:fsdp_dim] + ), "FSDP2 mesh should be a slice of DTensor's device mesh." + collected = param.redistribute( + placements=[Replicate()] * fsdp_dim + list(original_placements[fsdp_dim:]), + device_mesh=original_device_mesh, ) - - weight_collected = original_weight.redistribute( - placements=[Replicate()] * fsdp_dim + list(original_placements[fsdp_dim:]), - device_mesh=original_device_mesh, - ) - new_weight = nn.Parameter(weight_collected.to_local()) - module._parameters["weight"] = new_weight + originals[name] = (param, collected, original_placements, original_device_mesh) + _set_parameter(module, name, nn.Parameter(collected.to_local())) yield - original_weight.to_local().data.copy_( - weight_collected.redistribute( - placements=original_placements, device_mesh=original_device_mesh - ).to_local() - ) - module._parameters["weight"] = original_weight + # Write back and restore original DTensor parameters. + for name, ( + original_param, + collected, + original_placements, + original_device_mesh, + ) in originals.items(): + original_param.to_local().data.copy_( + collected.redistribute( + placements=original_placements, device_mesh=original_device_mesh + ).to_local() + ) + _set_parameter(module, name, original_param) @contextmanager @@ -498,6 +521,22 @@ def enable_weight_access_and_writeback(module, root_model, name_to_module: dict yield +@contextmanager +def persistent_materialization(layer): + """Keep all layer weights materialized on GPU for the duration. + + Suppresses per-forward weight transfers so that N calibration batches + pay the cost of one load/unload instead of N. 
+ + - **FSDP2**: patches ``FSDPParamGroup.unshard/reshard`` to no-ops, then + gathers weights once via ``enable_weight_access_and_writeback``. + - **Accelerate**: materializes weights and sets ``hook.offload = False`` + so per-forward hooks skip materialization/offloading. + """ + with _disable_fsdp_unshard_reshard(layer), enable_weight_access_and_writeback(layer, layer): + yield + + def get_quantizer_state_dict(model: nn.Module): """Get the state dict of the quantizers in the model.""" # We should not call model.state_dict() here. @@ -607,6 +646,24 @@ def _init_mp_dtypes(self) -> None: ) +@contextmanager +def _disable_fsdp_unshard_reshard(layer): + """Disable FSDP2 unshard/reshard if *layer* is FSDP-wrapped.""" + if isinstance(layer, FSDPModule): + _pg_cls = torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup + orig_unshard = _pg_cls.unshard + orig_reshard = _pg_cls.reshard + _pg_cls.unshard = lambda self, async_op=False: None + _pg_cls.reshard = lambda self: None + try: + yield + finally: + _pg_cls.unshard = orig_unshard + _pg_cls.reshard = orig_reshard + else: + yield + + def get_prefixed_param_names(parent_model, target_module): """Get parameter names for a target module prefixed with the parent model name. diff --git a/modelopt/torch/quantization/utils/layerwise_calib.py b/modelopt/torch/quantization/utils/layerwise_calib.py new file mode 100644 index 0000000000..0cacdab273 --- /dev/null +++ b/modelopt/torch/quantization/utils/layerwise_calib.py @@ -0,0 +1,681 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layerwise calibration layer patching, activation capture, and checkpoint save/resume. + +This module provides :class:`LayerActivationCollector`, a stateful helper that +patches decoder layers with a skip / run / capture strategy for efficient +layer-by-layer calibration, and :class:`_CheckpointState` for persisting +per-layer calibration progress to disk. +""" + +from __future__ import annotations + +import json +import os +import shutil +from collections import deque +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import torch +import torch.nn as nn + +from modelopt.torch.utils import distributed as dist +from modelopt.torch.utils import print_rank_0 +from modelopt.torch.utils.network import ( + bind_forward_method, + get_module_device, + unpatch_forward_method, +) + +if TYPE_CHECKING: + from modelopt.torch.opt.searcher import ForwardLoop + + +class _EarlyStopForwardError(Exception): + """Raised to halt the forward pass after capturing layer inputs.""" + + +@dataclass +class _LayerCalibState: + """Mutable per-layer state used during layerwise calibration. + + Attached to each decoder layer as ``_layerwise_calib`` and accessed by the + patched forward to decide skip / run / capture / original behaviour. + """ + + mode: str = "original" + name: str = "" + cached_inputs: deque = field(default_factory=deque) + collected_inputs: list = field(default_factory=list) + output_meta: tuple | None = None + + +class _SkipLayer(nn.Module): + """Parameter-free stand-in for a fully calibrated decoder layer. 
+ + Replaces the real layer in the ModuleList so that framework hooks + (accelerate, FSDP2, etc.) have no parameters to transfer. Holds a + reference to the original layer for restoration during cleanup. + """ + + def __init__(self, original: nn.Module): + super().__init__() + # Bypass nn.Module.__setattr__ to avoid registering original as a submodule. + object.__setattr__(self, "_original", original) + self._layerwise_calib = _LayerCalibState(mode="skip") + + def __getattr__(self, name: str): + # Proxy non-special attribute lookups to the original layer so that + # parent-model code that accesses layer-level attributes (e.g., + # NemotronH's ``block_type``) still works when the layer is replaced + # with a _SkipLayer. + try: + return super().__getattr__(name) + except AttributeError: + return getattr(object.__getattribute__(self, "_original"), name) + + def forward(self, *args, **kwargs): + return LayerActivationCollector._zeros_from_meta( + self._original._layerwise_calib.output_meta + ) + + +class LayerActivationCollector: + """Collects layer activations for layerwise (layer-by-layer) calibration. + + Each decoder layer is patched with a unified forward whose behaviour is + governed by a per-layer :class:`_LayerCalibState`: + + * **skip** — return a zero-filled dummy whose shape and type match the + layer's real output (reconstructed from lightweight metadata). No + computation is performed. The correctly shaped dummy ensures un-patched + inter-layer operations in the parent forward (e.g. LayerNorm, tuple + unpacking) do not raise shape or type errors. + * **run** — replay previously captured inputs through the original forward, + ignoring whatever the parent passes in. Only the just-calibrated layer + uses this mode, so its output reflects updated weights. + * **capture** — record ``(args, kwargs)`` and raise + ``_EarlyStopForwardError`` to halt the forward pass early. + * **original** — call the original forward unchanged. 
+ + Because the *run* layer discards upstream values, skip-layer outputs are + never consumed for real computation. + """ + + _decoder_layer_support: list[tuple[Any, Any]] = [] + _LAYER_ATTR = "_layerwise_calib" + + def __init__(self, model: nn.Module): + """Initialize the collector for the given model.""" + self.model = model + self._decoder_layers: nn.ModuleList | None = None + self._layer_to_idx: dict[nn.Module, int] = {} + self._patched = False + + def _swap_to_dummy(self, idx: int): + """Replace decoder layer *idx* with a parameter-free dummy. + + ``output_meta`` is intentionally preserved on the original layer: the + ``_SkipLayer`` reads it to produce correctly shaped zero-filled outputs + for the parent forward pass. + """ + assert self._decoder_layers is not None + layer = self._decoder_layers[idx] + layer._layerwise_calib.mode = "skip" + layer._layerwise_calib.cached_inputs.clear() + self._decoder_layers[idx] = _SkipLayer(layer) + + @staticmethod + def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None: + """Return decoder layers supported by layerwise calibration.""" + for is_supported, discoverer in LayerActivationCollector._decoder_layer_support: + if not is_supported(model): + continue + decoder_layers = discoverer(model) + if decoder_layers is not None: + return decoder_layers + return None + + @staticmethod + def is_supported(model: nn.Module) -> bool: + """Whether the model supports decoder-layer layerwise calibration.""" + return LayerActivationCollector.get_decoder_layers(model) is not None + + @classmethod + def register_decoder_layer_support(cls, is_supported: Any, discoverer: Any): + """Register a (predicate, discoverer) pair for decoder-layer detection.""" + entry = (is_supported, discoverer) + if entry not in cls._decoder_layer_support: + cls._decoder_layer_support.append(entry) + + @staticmethod + def _extract_output_meta(output): + """Extract lightweight (shape, dtype, device) metadata from a layer output. 
+ + Recursively handles tensors, tuples, lists, and non-tensor values (e.g. None). + The returned structure can be passed to ``_zeros_from_meta`` to reconstruct a + zero-filled output with identical shape and type. + """ + if isinstance(output, torch.Tensor): + return ("tensor", output.shape, output.dtype, output.device) + if isinstance(output, tuple): + return ( + "tuple", + tuple(LayerActivationCollector._extract_output_meta(o) for o in output), + ) + if isinstance(output, list): + return ("list", [LayerActivationCollector._extract_output_meta(o) for o in output]) + return ("other", output) + + @staticmethod + def _zeros_from_meta(meta): + """Reconstruct a zero-filled output from metadata produced by ``_extract_output_meta``.""" + tag = meta[0] + if tag == "tensor": + _, shape, dtype, device = meta + return torch.zeros(shape, dtype=dtype, device=device) + if tag == "tuple": + return tuple(LayerActivationCollector._zeros_from_meta(m) for m in meta[1]) + if tag == "list": + return [LayerActivationCollector._zeros_from_meta(m) for m in meta[1]] + # "other" values are lightweight non-tensors (e.g. None, small scalars). + # Returned directly (not copied); safe because skip-mode outputs are + # immediately discarded by the downstream run-mode layer. + return meta[1] + + def _patch_all_layers(self, decoder_layers: nn.ModuleList | None = None): + """Bind the unified forward to every decoder layer and the model. Called once. + + Args: + decoder_layers: Pre-resolved decoder layers. If *None*, layers are + discovered via :meth:`get_decoder_layers`. + """ + + def _patched_forward(self, *args, **kwargs): + info: _LayerCalibState = self._layerwise_calib + + if info.mode == "skip": + if info.output_meta is None: + raise RuntimeError( + f"Layer {info.name} is in 'skip' mode but has no output_meta. " + "This indicates a state-machine bug: the layer should have run " + "in 'run' mode (which sets output_meta) before transitioning to 'skip'." 
+ ) + return LayerActivationCollector._zeros_from_meta(info.output_meta) + + if info.mode == "run": + assert info.cached_inputs, ( + f"Layer {info.name} is in 'run' mode but has no cached inputs to replay." + ) + real_args, real_kwargs = info.cached_inputs.popleft() + output = self._original_forward(*real_args, **real_kwargs) + info.output_meta = LayerActivationCollector._extract_output_meta(output) + return output + + if info.mode == "capture": + info.collected_inputs.append((args, kwargs)) + raise _EarlyStopForwardError() + + return self._original_forward(*args, **kwargs) + + if decoder_layers is not None: + self._decoder_layers = decoder_layers + else: + self._decoder_layers = self.get_decoder_layers(self.model) + assert self._decoder_layers is not None + + self._layer_to_idx = {layer: i for i, layer in enumerate(self._decoder_layers)} + module_to_name = {m: name for name, m in self.model.named_modules()} + + try: + for layer in self._decoder_layers: + layer._layerwise_calib = _LayerCalibState( + name=module_to_name.get(layer, type(layer).__name__), + ) + bind_forward_method(layer, _patched_forward, "_original_forward") + + def _early_stop_forward(module_self, *args, **kwargs): + try: + return module_self._original_forward(*args, **kwargs) + except _EarlyStopForwardError: + return None + + bind_forward_method(self.model, _early_stop_forward, "_original_forward") + except Exception: + self._cleanup_layers() + raise + + self._patched = True + + def _cleanup_layers(self): + """Best-effort cleanup of any patched layers and model forward.""" + if self._decoder_layers is not None: + for idx, layer in enumerate(self._decoder_layers): + if isinstance(layer, _SkipLayer): + self._decoder_layers[idx] = layer._original + + if hasattr(self.model, "_original_forward"): + unpatch_forward_method(self.model, "_original_forward") + + if self._decoder_layers is not None: + for layer in self._decoder_layers: + if hasattr(layer, "_original_forward"): + unpatch_forward_method(layer, 
"_original_forward") + if hasattr(layer, self._LAYER_ATTR): + delattr(layer, self._LAYER_ATTR) + + def _unpatch_all_layers(self): + """Restore original forwards and clean up state attributes. Called once.""" + if not self._patched: + return + self._cleanup_layers() + self._patched = False + + def _set_layer_states(self, layer_idx: int): + """Transition layer modes for the next calibration step. + + When calibrating layer *i*, three transitions happen: + + * Layer ``i - 2`` → **skip** (fully done, free its cached inputs). + * Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights). + * Layer ``i`` → **capture** (record inputs, then early-stop). + """ + assert self._decoder_layers is not None + + if layer_idx > 1: + idx = layer_idx - 2 + if not isinstance(self._decoder_layers[idx], _SkipLayer): + self._swap_to_dummy(idx) + + if layer_idx > 0: + prev = self._decoder_layers[layer_idx - 1]._layerwise_calib + if not prev.collected_inputs: + raise RuntimeError( + f"Layer {layer_idx - 1} ({prev.name!r}) has no collected inputs to replay. " + "Layers must be calibrated sequentially — ensure get_input_activations() " + "was called for every preceding layer in order." 
+ ) + prev.mode = "run" + prev.cached_inputs = deque(prev.collected_inputs) + prev.collected_inputs = [] + + cur = self._decoder_layers[layer_idx]._layerwise_calib + cur.mode = "capture" + cur.collected_inputs = [] + + def _log_layer_summary(self, layer_idx: int): + """Log a one-line summary of layer modes for the current calibration step.""" + assert self._decoder_layers is not None + n = len(self._decoder_layers) + groups: dict[str, list[int]] = {} + for i, layer in enumerate(self._decoder_layers): + mode = layer._layerwise_calib.mode + if mode in ("skip", "run", "capture"): + groups.setdefault(mode, []).append(i + 1) + + parts = [] + for mode in ("skip", "run", "capture"): + if mode not in groups: + continue + ids = groups[mode] + parts.append(f"{mode}: {len(ids)}" if mode == "skip" else f"{mode}: {ids}") + print_rank_0(f"Calibrating layer {layer_idx + 1}/{n} | {' | '.join(parts)}") + + @torch.no_grad() + def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: + """Collect input activations for *layer* by running a full model forward. + + Layers before the target are skipped or re-run (if just calibrated), the + target layer captures its inputs, and an early-stop prevents unnecessary + computation beyond the target. + + :meth:`_patch_all_layers` must be called before this method. + + Note: the model forward returns ``None`` for every batch during capture + (because ``_EarlyStopForwardError`` short-circuits the forward pass). + Callers should not rely on the model's return value within *forward_loop*. + """ + if not self._patched: + raise RuntimeError( + "get_input_activations() requires _patch_all_layers() to be called first." + ) + layer_idx = self._layer_to_idx[layer] + self._set_layer_states(layer_idx) + self._log_layer_summary(layer_idx) + + info = layer._layerwise_calib + try: + forward_loop(self.model) + except Exception: + # Reset the current layer so subsequent calls don't see stale state. 
+ info.mode = "original" + info.collected_inputs = [] + raise + + if not info.collected_inputs: + info.mode = "original" + raise RuntimeError( + f"Layer {info.name!r} collected no inputs during forward_loop. " + "The forward loop did not reach this layer — check that forward_loop() " + "actually calls the model and that the layer is in the forward path." + ) + + inputs = list(info.collected_inputs) + # Reset to original so calib_func can call the layer's real forward + # directly. The layer will transition to run → skip in subsequent + # iterations via _set_layer_states. + info.mode = "original" + return inputs + + def get_first_layer_inputs( + self, + start_layer: int, + resumed_inputs: list | None, + forward_loop: ForwardLoop, + ) -> list: + """Get inputs for the first layer to calibrate, handling resume. + + If *resumed_inputs* is provided, sets skip mode on layers ``0..start_layer-1`` + and seeds the start layer's ``collected_inputs`` for subsequent + ``cache_outputs_for_next_layer_calib`` calls. Otherwise, captures inputs + via a normal forward pass. + """ + assert self._decoder_layers is not None + + if resumed_inputs is not None: + print_rank_0(f"Calibrating layer {start_layer + 1} (resumed)") + for i in range(start_layer): + self._swap_to_dummy(i) + layer = self._decoder_layers[start_layer] + layer._layerwise_calib.collected_inputs = resumed_inputs + layer._layerwise_calib.mode = "original" + return resumed_inputs + + return self.get_input_activations(self._decoder_layers[start_layer], forward_loop) + + @torch.no_grad() + def cache_outputs_for_next_layer_calib( + self, layer: torch.nn.Module, forward_loop: ForwardLoop + ) -> list: + """Run a forward pass after calibrating *layer* to capture the next layer's inputs. + + This puts *layer* into "run" mode (setting its ``output_meta``) and the + next layer into "capture" mode, then runs *forward_loop*. Returns the + captured inputs for the next layer. + + Must be called only when a next layer exists (i.e. 
*layer* is not the + last decoder layer). + """ + assert self._decoder_layers is not None + layer_idx = self._layer_to_idx[layer] + next_idx = layer_idx + 1 + assert next_idx < len(self._decoder_layers), "No next layer to capture inputs for." + from .core_utils import persistent_materialization + + next_layer = self._decoder_layers[next_idx] + with persistent_materialization(layer): + return self.get_input_activations(next_layer, forward_loop) + + +def _move_to_device(obj: Any, device: torch.device) -> Any: + """Recursively move tensors to *device*. Non-tensors are returned as-is.""" + if isinstance(obj, torch.Tensor): + return obj.to(device) + if isinstance(obj, dict): + return {k: _move_to_device(v, device) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + moved = [_move_to_device(v, device) for v in obj] + return type(obj)(moved) + return obj + + +def _remap_output_metadata_device(meta: tuple, device: torch.device) -> tuple: + """Patch the device field inside output_meta tuples so _zeros_from_meta uses *device*.""" + tag = meta[0] + if tag == "tensor": + _, shape, dtype, _old_device = meta + return ("tensor", shape, dtype, device) + if tag == "tuple": + return ("tuple", tuple(_remap_output_metadata_device(m, device) for m in meta[1])) + if tag == "list": + return ("list", [_remap_output_metadata_device(m, device) for m in meta[1]]) + return meta + + +def _read_manifest(checkpoint_dir: str) -> dict | None: + """Read manifest.json from *checkpoint_dir*. 
Returns None if missing or corrupt.""" + path = os.path.join(checkpoint_dir, "manifest.json") + if not os.path.isfile(path): + return None + try: + with open(path) as f: + return json.load(f) + except (json.JSONDecodeError, OSError): + return None + + +def _write_manifest(checkpoint_dir: str, last_completed_layer: int, num_layers: int) -> None: + """Atomically write manifest.json.""" + path = os.path.join(checkpoint_dir, "manifest.json") + tmp = path + ".tmp" + with open(tmp, "w") as f: + json.dump( + {"last_completed_layer": last_completed_layer, "num_layers": num_layers}, + f, + ) + os.replace(tmp, path) + + +def _layer_dir(checkpoint_dir: str, idx: int) -> str: + return os.path.join(checkpoint_dir, f"layer_{idx:04d}") + + +def _save_layer( + checkpoint_dir: str, + idx: int, + weights: dict, + qstate: dict, + output_meta: tuple, + next_inputs: list | None, + num_layers: int, +) -> None: + """Save a single layer checkpoint and update the manifest atomically.""" + d = _layer_dir(checkpoint_dir, idx) + if os.path.isdir(d): + shutil.rmtree(d) + os.makedirs(d) + torch.save(weights, os.path.join(d, "weights.pt")) + torch.save(qstate, os.path.join(d, "quantizer_state.pt")) + torch.save(output_meta, os.path.join(d, "output_meta.pt")) + if next_inputs is not None: + torch.save(next_inputs, os.path.join(d, "next_inputs.pt")) + _write_manifest(checkpoint_dir, idx, num_layers) + + +def detect_resume_point(checkpoint_dir: str) -> tuple[int, dict] | None: + """Detect where to resume from an existing checkpoint directory. + + Returns ``(start_layer, manifest)`` if there is work to resume, + or ``None`` if the directory is empty, corrupt, or calibration was already complete. 
+ """ + manifest = _read_manifest(checkpoint_dir) + if manifest is None: + return None + last = manifest.get("last_completed_layer") + total = manifest.get("num_layers") + if last is None or total is None: + return None + if last + 1 >= total: + return None + return (last + 1, manifest) + + +class _CheckpointState: + """Manages checkpoint save and restore for layerwise calibration. + + Handles both saving per-layer checkpoints during calibration and + restoring from a previous partial run. + + .. todo:: + Support distributed checkpoint save/restore for FSDP2: + use ``torch.distributed.checkpoint`` (or save only from rank 0 + barrier) + and broadcast restored state to all ranks during resume. + """ + + def __init__(self, checkpoint_dir: str, num_layers: int, start_layer: int = 0): + if dist.is_initialized() and dist.size() > 1: + raise RuntimeError( + "Layerwise calibration checkpointing is not supported in " + "multi-process distributed jobs (e.g. FSDP2). " + "Use single-process calibration or disable checkpointing." + ) + + self.checkpoint_dir = checkpoint_dir + self.num_layers = num_layers + self.start_layer = start_layer + + @classmethod + def from_folder(cls, checkpoint_dir: str | None, num_layers: int) -> _CheckpointState | None: + """Create from folder. Detects resume point. Returns None if no checkpoint_dir.""" + if not checkpoint_dir: + return None + os.makedirs(checkpoint_dir, exist_ok=True) + info = detect_resume_point(checkpoint_dir) + if info is not None: + manifest_num_layers = info[1].get("num_layers") + if manifest_num_layers is not None and manifest_num_layers != num_layers: + raise ValueError( + f"Checkpoint num_layers mismatch: manifest has {manifest_num_layers} " + f"but model has {num_layers}. Use a fresh checkpoint directory." 
    def setup_resume(self, layers: nn.ModuleList) -> list | None:
        """Load output_meta for skip layers 0..K-1, return next_inputs for layer K.

        Sets ``output_meta`` on each already-calibrated layer so that
        skip mode can produce correctly shaped dummy outputs.

        Args:
            layers: The decoder layer list; entries ``0..start_layer-1`` receive
                their saved ``output_meta`` and ``layers[start_layer]`` decides
                the device the resumed inputs are moved to.

        Returns:
            ``None`` when ``start_layer`` is 0 (nothing to resume); otherwise the
            contents of ``next_inputs.pt`` saved for layer ``start_layer - 1``,
            moved to the device of ``layers[start_layer]``.

        Raises:
            FileNotFoundError: If ``next_inputs.pt`` is missing for the last
                checkpointed layer.
        """
        if self.start_layer == 0:
            return None

        last_ckpt = self.start_layer - 1

        for i in range(self.start_layer):
            d = _layer_dir(self.checkpoint_dir, i)
            # weights_only=False is safe: file is internally generated by _save_layer, not user-supplied
            meta = torch.load(
                os.path.join(d, "output_meta.pt"), map_location="cpu", weights_only=False
            )
            # Metadata is loaded on CPU and then remapped to the device of the layer it belongs to.
            layer_device = get_module_device(layers[i])
            meta = _remap_output_metadata_device(meta, layer_device)
            layers[i]._layerwise_calib.output_meta = meta

        d = _layer_dir(self.checkpoint_dir, last_ckpt)
        next_inputs_path = os.path.join(d, "next_inputs.pt")
        if not os.path.isfile(next_inputs_path):
            raise FileNotFoundError(f"Cannot resume: next_inputs.pt missing for layer {last_ckpt}")
        # weights_only=False is safe: file is internally generated by _save_layer, not user-supplied
        next_inputs = torch.load(next_inputs_path, map_location="cpu", weights_only=False)
        # NOTE(review): indexes layers[self.start_layer] without a bounds check; assumes
        # start_layer < len(layers) -- confirm the caller never resumes a fully completed run.
        resume_device = get_module_device(layers[self.start_layer])
        next_inputs = _move_to_device(next_inputs, resume_device)
        return next_inputs

    def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None:
        """Restore weights and quantizer state for layers 0..K-1 after the calibration loop.

        Does nothing when ``start_layer`` is 0.

        Args:
            layers: The decoder layer list; entries ``0..start_layer-1`` are restored.
            model: The full model, passed to ``enable_weight_access_and_writeback``
                and used to build the name-to-module mapping.
        """
        # Imported locally rather than at module top (presumably to avoid an import cycle -- confirm).
        from modelopt.torch.quantization.config import QuantizeConfig
        from modelopt.torch.quantization.conversion import restore_quantizer_state
        from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback

        if self.start_layer == 0:
            return

        dummy_config = QuantizeConfig()
        # Built once and reused for every layer instead of re-walking the model per iteration.
        name_to_module = dict(model.named_modules())
        for i in range(self.start_layer):
            layer = layers[i]
            layer_device = get_module_device(layer)
            d = _layer_dir(self.checkpoint_dir, i)

            # Restore quantizer state first: may promote TensorQuantizer to
            # NVFP4StaticQuantizer, changing module structure that load_state_dict
            # expects.
            # weights_only=False is safe: file is internally generated by _save_layer, not user-supplied
            qstate = torch.load(
                os.path.join(d, "quantizer_state.pt"), map_location=layer_device, weights_only=False
            )
            restore_quantizer_state(layer, dummy_config, {"quantizer_state": qstate})

            # Load weights inside the framework's access context so that
            # managed-weight frameworks (accelerate CPU offload, FSDP2) sync
            # their internal state with the restored parameters.
            with enable_weight_access_and_writeback(layer, model, name_to_module):
                # weights_only=False is safe: file is internally generated by _save_layer, not user-supplied
                weights = torch.load(
                    os.path.join(d, "weights.pt"), map_location=layer_device, weights_only=False
                )
                # strict=False: missing/unexpected keys between checkpoint and layer are tolerated.
                layer.load_state_dict(weights, strict=False)

        print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers")

    def save(
        self,
        layer_idx: int,
        layer: nn.Module,
        model: nn.Module,
        layers: nn.ModuleList,
        next_layer_inputs: list | None = None,
    ) -> None:
        """Snapshot layer state and write checkpoint to disk in one step.

        Args:
            layer_idx: Index of the layer just calibrated.
            layer: The layer module (weights may be on GPU or managed by accelerate/FSDP2).
            model: The full model (needed for ``enable_weight_access_and_writeback``).
            layers: The decoder layer list. Not referenced by this method body
                (``output_meta`` is read from ``layer`` itself).
                NOTE(review): presumably kept for interface compatibility -- confirm.
            next_layer_inputs: Inputs for the next layer (``None`` for the final layer).
        """
        from modelopt.torch.quantization.conversion import quantizer_state
        from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback

        _cpu = torch.device("cpu")
        # Everything handed to _save_layer is first moved to CPU.
        with enable_weight_access_and_writeback(layer, model):
            weights = _move_to_device(layer.state_dict(), _cpu)
            qstate = _move_to_device(quantizer_state(layer), _cpu)

        output_meta = getattr(layer._layerwise_calib, "output_meta", None)
        if output_meta is None:
            # Placeholder for the last layer: output_meta is never used for skip mode
            # since there is no subsequent layer that needs a correctly shaped dummy output.
            output_meta = LayerActivationCollector._extract_output_meta(torch.zeros(1))

        _save_layer(
            self.checkpoint_dir,
            layer_idx,
            weights,
            qstate,
            _move_to_device(output_meta, _cpu),
            _move_to_device(next_layer_inputs, _cpu) if next_layer_inputs is not None else None,
            self.num_layers,
        )
        # The absence of next-layer inputs is what marks a layer as the final one in the log line.
        suffix = " (final)" if next_layer_inputs is None else ""
        print_rank_0(f"Checkpoint: saved layer {layer_idx}{suffix}")
+ config = getattr(model, "config", None) + prev_use_cache = getattr(config, "use_cache", None) + if config is not None and prev_use_cache is not None: + config.use_cache = False - for _, data in enumerate(tqdm(dataloader)): - # Process batch and update max working batch size - max_working_batch_size = _process_batch( - data, infer_method, max_working_batch_size, allowed_non_tensor_keys - ) + try: + with torch.no_grad(): + is_enc_dec = model_type_is_enc_dec(model) + infer_method = model.generate if is_enc_dec else model.forward + max_working_batch_size = None # Initialize max working batch size as None + + for _, data in enumerate(tqdm(dataloader)): + # Process batch and update max working batch size + max_working_batch_size = _process_batch( + data, infer_method, max_working_batch_size, allowed_non_tensor_keys + ) + finally: + if config is not None and prev_use_cache is not None: + config.use_cache = prev_use_cache def create_forward_loop( diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index b54332375b..440ca522d1 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -90,12 +90,43 @@ def is_parallel(model: nn.Module) -> bool: return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)) +def _get_execution_device_from_hook(module: nn.Module) -> torch.device | None: + """Extract the execution device from an accelerate ``_hf_hook``, if present. + + Handles both ``AlignDevicesHook`` (direct) and ``SequentialHook`` (which + may wrap one or more ``AlignDevicesHook`` instances). Returns ``None`` + when no hook is found or the hook carries no ``execution_device``. 
+ """ + hook = getattr(module, "_hf_hook", None) + if hook is None: + return None + + dev = getattr(hook, "execution_device", None) + if dev is not None: + return torch.device("cuda", dev) if isinstance(dev, int) else torch.device(dev) + + for h in getattr(hook, "hooks", ()): + dev = getattr(h, "execution_device", None) + if dev is not None: + return torch.device("cuda", dev) if isinstance(dev, int) else torch.device(dev) + + return None + + def get_module_device(module: nn.Module) -> torch.device: - """Get the device of a PyTorch module.""" + """Get the device of a PyTorch module. + + For modules managed by accelerate (``_hf_hook``), returns the hook's + ``execution_device`` which is the authoritative device even when + parameters are offloaded to CPU/meta between forward calls. + """ + hook_device = _get_execution_device_from_hook(module) + if hook_device is not None: + return hook_device + try: return next(module.parameters()).device except StopIteration: - # For modules without parameters return torch.device("cpu") @@ -590,21 +621,29 @@ def get_unwrapped_name(name: str, model: nn.Module | None = None) -> str: @contextmanager def temporarily_remove_accelerate_hook(module): - """Context manager to temporarily remove accelerate hook from a module.""" - accelerate_hook = None - if hasattr(module, "_hf_hook"): - # A module with forward method patched by accelerate - from accelerate.hooks import add_hook_to_module, remove_hook_from_module + """Context manager to temporarily bypass the accelerate hook on a module. + + Swaps ``module.forward`` with the pre-hook forward (``_old_forward``) so + that code inside the context sees the un-hooked forward. On exit the + hook-wrapped forward is restored and ``_old_forward`` is updated to + reflect any changes made inside the context. 
- accelerate_hook = module._hf_hook - remove_hook_from_module(module) + This avoids ``remove_hook_from_module`` / ``add_hook_to_module`` entirely, + sidestepping ``init_hook`` which would call ``set_module_tensor_to_device`` + and fail when newly-added quantizer modules have weights on the meta device. + """ + hooked_forward = None + cached_old_forward = None + if hasattr(module, "_hf_hook"): + hooked_forward = module.forward + cached_old_forward = module._old_forward + module.forward = cached_old_forward try: yield finally: - if accelerate_hook is not None: - from accelerate.hooks import add_hook_to_module - - add_hook_to_module(module, accelerate_hook) + if hooked_forward is not None: + module._old_forward = module.forward + module.forward = hooked_forward def bind_forward_method( diff --git a/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml index 6fe4a8c3d1..862929ef34 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml @@ -15,7 +15,7 @@ metadata: recipe_type: ptq - description: NVFP4 MLP/MoE weight only (W4A16), FP8 KV cache, max calibration. + description: NVFP4 W4A4, FP8 KV cache, max calibration. quantize: algorithm: max quant_cfg: diff --git a/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml b/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml index a62051b659..99098c9d6d 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml @@ -15,11 +15,12 @@ metadata: recipe_type: ptq - description: NVFP4 weight and activation (W4A4), gptq sequential calibration. + description: NVFP4 weight and activation (W4A4), gptq layerwise calibration. 
quantize: algorithm: method: gptq - use_sequential: true + layerwise: true + layerwise_checkpoint_dir: output/layerwise_ckpts/ quant_cfg: - quantizer_name: '*' enable: false diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml index cc332733a0..4274e40b62 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml @@ -15,9 +15,12 @@ metadata: recipe_type: ptq - description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max calibration. + description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max layerwise calibration. quantize: - algorithm: max + algorithm: + method: max + layerwise: true + layerwise_checkpoint_dir: output/layerwise_ckpts/ quant_cfg: - quantizer_name: '*' enable: false diff --git a/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py b/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py index 8ed1039e59..809268a635 100644 --- a/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py +++ b/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py @@ -13,6 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy +import json +import os +import shutil + import pytest import torch from _test_utils.torch.quantization.quantize_common import INT4_AWQ_CLIP_CFG @@ -25,6 +30,7 @@ enable_weight_access_and_writeback, is_quantized_linear, ) +from modelopt.torch.quantization.utils.layerwise_calib import _layer_dir @pytest.mark.parametrize( @@ -73,3 +79,420 @@ def test_cpu_offloaded_tinyllama(tmp_path, quant_cfg): assert torch.allclose(module.weight, model_ref.get_submodule(name).weight) assert torch.allclose(output_ref.logits, output_test.logits) + + +def _make_cpu_offloaded_model(tmp_path, num_hidden_layers=3): + """Create a tiny LLaMA model with layer 0 offloaded to CPU via accelerate.""" + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_hidden_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + device_map = { + n: 0 + for n, m in model.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + + model = load_checkpoint_and_dispatch(model, tiny_llama_dir, device_map=device_map) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + return model, config, tiny_llama_dir, inputs + + +def _make_layerwise_cfg(base_cfg): + """Add layerwise=True to a quant config's algorithm field.""" + cfg = copy.deepcopy(base_cfg) + algo = cfg.get("algorithm", "max") + if isinstance(algo, str): + cfg["algorithm"] = {"method": algo, "layerwise": True} + else: + algo["layerwise"] = True + return cfg + + +def _make_layerwise_checkpoint_cfg(base_cfg, checkpoint_dir): + """Add layerwise=True and layerwise_checkpoint_dir to a quant config's algorithm field.""" + cfg = _make_layerwise_cfg(base_cfg) + cfg["algorithm"]["layerwise_checkpoint_dir"] = checkpoint_dir + return cfg + + +@pytest.mark.parametrize( + "quant_cfg", + [mtq.INT4_AWQ_CFG, mtq.NVFP4_DEFAULT_CFG], + ids=["int4_awq", "nvfp4"], +) 
+@pytest.mark.parametrize("use_checkpoint", [False, True], ids=["no_ckpt", "ckpt"]) +def test_layerwise_calibrate_cpu_offloaded(tmp_path, quant_cfg, use_checkpoint): + """Layerwise calibration on CPU-offloaded model matches GPU-only reference.""" + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + if use_checkpoint: + ckpt_dir = str(tmp_path / "seq_ckpt") + seq_cfg = _make_layerwise_checkpoint_cfg(quant_cfg, ckpt_dir) + else: + seq_cfg = _make_layerwise_cfg(quant_cfg) + + # Reference: GPU-only model with layerwise calibration + ref_cfg = _make_layerwise_cfg(quant_cfg) + model_ref = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, torch_dtype=config.torch_dtype + ).cuda() + mtq.quantize(model_ref, ref_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Test: CPU-offloaded model + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model = load_checkpoint_and_dispatch(model, tiny_llama_dir, device_map=device_map) + + mtq.quantize(model, seq_cfg, lambda model: model(inputs)) + output_test = model(inputs) + + for name, module in model.named_modules(): + if is_quantized_linear(module): + with enable_weight_access_and_writeback(module, model): + assert torch.allclose(module.weight, model_ref.get_submodule(name).weight), ( + f"Weight mismatch at {name}" + ) + + assert torch.allclose(output_ref.logits, output_test.logits) + + if use_checkpoint: + manifest_path = os.path.join(ckpt_dir, "manifest.json") + assert os.path.isfile(manifest_path) + with open(manifest_path) as f: + manifest = json.load(f) + assert manifest["last_completed_layer"] == num_layers - 1 + assert manifest["num_layers"] == 
num_layers + + +@pytest.mark.parametrize( + "quant_cfg", + [mtq.INT4_AWQ_CFG, mtq.NVFP4_DEFAULT_CFG], + ids=["int4_awq", "nvfp4"], +) +def test_sequential_checkpoint_resume_cpu_offloaded(tmp_path, quant_cfg): + """Resume from a partial checkpoint on a CPU-offloaded model matches a full run.""" + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + ckpt_dir = str(tmp_path / "seq_ckpt") + seq_ckpt_cfg = _make_layerwise_checkpoint_cfg(quant_cfg, ckpt_dir) + + # Full reference run with checkpointing + with init_empty_weights(): + model_ref = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model_ref.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model_ref = load_checkpoint_and_dispatch(model_ref, tiny_llama_dir, device_map=device_map) + mtq.quantize(model_ref, seq_ckpt_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Simulate crash after layer 0 by truncating the manifest and removing later layers + last_completed_layer = 0 + manifest_path = os.path.join(ckpt_dir, "manifest.json") + with open(manifest_path, "w") as f: + json.dump({"last_completed_layer": last_completed_layer, "num_layers": num_layers}, f) + for i in range(last_completed_layer + 1, num_layers): + d = _layer_dir(ckpt_dir, i) + if os.path.isdir(d): + shutil.rmtree(d) + + # Resume from a fresh CPU-offloaded model + with init_empty_weights(): + model_resumed = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model_resumed.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model_resumed = load_checkpoint_and_dispatch( + model_resumed, tiny_llama_dir, device_map=device_map + ) + mtq.quantize(model_resumed, seq_ckpt_cfg, lambda 
model: model(inputs)) + output_resumed = model_resumed(inputs) + + assert torch.allclose(output_ref.logits, output_resumed.logits), ( + "Resumed checkpoint should produce identical output to full run" + ) + + +def test_sequential_checkpoint_resume_multi_offload(tmp_path): + """Resume with multiple layers offloaded exercises per-layer device resolution.""" + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + ckpt_dir = str(tmp_path / "seq_ckpt") + seq_ckpt_cfg = _make_layerwise_checkpoint_cfg(mtq.INT4_AWQ_CFG, ckpt_dir) + + def _make_multi_offload_model(): + with init_empty_weights(): + m = AutoModelForCausalLM.from_config(config) + dmap = { + n: 0 + for n, mod in m.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + dmap["model.layers.0"] = "cpu" + dmap["model.layers.1"] = "cpu" + return load_checkpoint_and_dispatch(m, tiny_llama_dir, device_map=dmap) + + # Full reference run + model_ref = _make_multi_offload_model() + mtq.quantize(model_ref, seq_ckpt_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Simulate crash after layer 0 + last_completed_layer = 0 + manifest_path = os.path.join(ckpt_dir, "manifest.json") + with open(manifest_path, "w") as f: + json.dump({"last_completed_layer": last_completed_layer, "num_layers": num_layers}, f) + for i in range(last_completed_layer + 1, num_layers): + d = _layer_dir(ckpt_dir, i) + if os.path.isdir(d): + shutil.rmtree(d) + + # Resume from fresh model with same offload layout + model_resumed = _make_multi_offload_model() + mtq.quantize(model_resumed, seq_ckpt_cfg, lambda model: model(inputs)) + output_resumed = model_resumed(inputs) + + assert torch.allclose(output_ref.logits, output_resumed.logits), ( + "Resumed checkpoint with multi-offload should match full run" + ) + + +def 
_make_gptq_sequential_cfg(base_cfg): + """Create a sequential GPTQ config from a base quantization config.""" + cfg = copy.deepcopy(base_cfg) + cfg["algorithm"] = {"method": "gptq", "layerwise": True} + return cfg + + +def _make_gptq_sequential_checkpoint_cfg(base_cfg, checkpoint_dir): + """Create a sequential GPTQ config with checkpoint dir.""" + cfg = _make_gptq_sequential_cfg(base_cfg) + cfg["algorithm"]["layerwise_checkpoint_dir"] = checkpoint_dir + return cfg + + +@pytest.mark.parametrize("use_checkpoint", [False, True], ids=["no_ckpt", "ckpt"]) +def test_sequential_gptq_cpu_offloaded(tmp_path, use_checkpoint): + """Sequential GPTQ (weight-modifying) on CPU-offloaded model matches GPU-only reference.""" + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + if use_checkpoint: + ckpt_dir = str(tmp_path / "gptq_ckpt") + seq_cfg = _make_gptq_sequential_checkpoint_cfg(mtq.NVFP4_DEFAULT_CFG, ckpt_dir) + else: + seq_cfg = _make_gptq_sequential_cfg(mtq.NVFP4_DEFAULT_CFG) + + # Reference: GPU-only model + ref_cfg = _make_gptq_sequential_cfg(mtq.NVFP4_DEFAULT_CFG) + model_ref = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, torch_dtype=config.torch_dtype + ).cuda() + mtq.quantize(model_ref, ref_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Test: CPU-offloaded model + model, _, _, _ = _make_cpu_offloaded_model(tmp_path / "offloaded", num_hidden_layers=num_layers) + mtq.quantize(model, seq_cfg, lambda model: model(inputs)) + output_test = model(inputs) + + for name, module in model.named_modules(): + if is_quantized_linear(module): + with enable_weight_access_and_writeback(module, model): + assert torch.allclose(module.weight, model_ref.get_submodule(name).weight), ( + f"Weight mismatch at {name}" + ) + + assert torch.allclose(output_ref.logits, output_test.logits) + + +def 
test_sequential_gptq_checkpoint_resume_cpu_offloaded(tmp_path): + """GPTQ checkpoint resume with CPU offloading restores modified weights correctly.""" + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + ckpt_dir = str(tmp_path / "gptq_ckpt") + seq_ckpt_cfg = _make_gptq_sequential_checkpoint_cfg(mtq.NVFP4_DEFAULT_CFG, ckpt_dir) + + # Full reference run with checkpointing + with init_empty_weights(): + model_ref = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model_ref.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model_ref = load_checkpoint_and_dispatch(model_ref, tiny_llama_dir, device_map=device_map) + mtq.quantize(model_ref, seq_ckpt_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Simulate crash after layer 0 + last_completed_layer = 0 + manifest_path = os.path.join(ckpt_dir, "manifest.json") + with open(manifest_path, "w") as f: + json.dump({"last_completed_layer": last_completed_layer, "num_layers": num_layers}, f) + for i in range(last_completed_layer + 1, num_layers): + d = _layer_dir(ckpt_dir, i) + if os.path.isdir(d): + shutil.rmtree(d) + + # Resume from fresh CPU-offloaded model + with init_empty_weights(): + model_resumed = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model_resumed.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model_resumed = load_checkpoint_and_dispatch( + model_resumed, tiny_llama_dir, device_map=device_map + ) + mtq.quantize(model_resumed, seq_ckpt_cfg, lambda model: model(inputs)) + output_resumed = model_resumed(inputs) + + assert torch.allclose(output_ref.logits, output_resumed.logits), ( + "GPTQ resumed checkpoint should produce 
identical output to full run" + ) + + +class _TupleReturningBlock(torch.nn.Module): + """Decoder layer that returns a tuple, mimicking HuggingFace decoder layers.""" + + def __init__(self, dim=16): + super().__init__() + self.linear = torch.nn.Linear(dim, dim, bias=False) + + def forward(self, x, **kwargs): + return (self.linear(x), None) + + +class _TupleUnpackingModel(torch.nn.Module): + """Parent model that unpacks layer outputs as tuples.""" + + def __init__(self, n_layers=4, dim=16): + super().__init__() + self.layers = torch.nn.ModuleList([_TupleReturningBlock(dim) for _ in range(n_layers)]) + + def forward(self, x): + for layer in self.layers: + x, _ = layer(x) + return x + + +def test_skip_dummy_has_no_hf_hook(monkeypatch): + """Dummies must not carry _hf_hook from the original layer.""" + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + + from modelopt.torch.quantization.utils.layerwise_calib import ( + LayerActivationCollector, + _SkipLayer, + ) + + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], + ) + + model = _TupleUnpackingModel(n_layers=4, dim=16) + data = [torch.randn(2, 16)] + + for layer in model.layers: + hook = AlignDevicesHook(execution_device=torch.device("cpu")) + add_hook_to_module(layer, hook) + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + for layer in list(model.layers): + collector.get_input_activations(layer, forward_loop) + + for i in range(2): + dummy = model.layers[i] + assert isinstance(dummy, _SkipLayer) + assert not hasattr(dummy, "_hf_hook"), f"Dummy at {i} should not have _hf_hook" + finally: + collector._unpatch_all_layers() + + +def test_persistent_materialization_cpu_offloaded(tmp_path): + """persistent_materialization keeps CPU-offloaded weights on GPU and writes back modifications.""" + import torch.nn as nn + from accelerate.hooks 
import AlignDevicesHook + + from modelopt.torch.quantization.utils import persistent_materialization + + model, config, _, inputs = _make_cpu_offloaded_model(tmp_path) + offloaded_layer = model.model.layers[0] + + # Verify offloaded (meta device) + assert all(p.device.type == "meta" for p in offloaded_layer.parameters()) + + # Save reference weight + linear = None + with enable_weight_access_and_writeback(offloaded_layer, model): + linear = next(m for m in offloaded_layer.modules() if isinstance(m, nn.Linear)) + ref_weight = linear.weight.clone() + + with persistent_materialization(offloaded_layer): + # Params materialized on GPU + assert all( + p.device.type == "cuda" for p in offloaded_layer.parameters() if p.device.type != "meta" + ) + + # Run multiple forward passes (hooks don't re-offload) + for _ in range(3): + model(inputs) + + # Modify a weight + linear.weight.data.add_(1.0) + + # Verify hooks have offload=False during context + for mod in offloaded_layer.modules(): + if hasattr(mod, "_hf_hook"): + hook = mod._hf_hook + if isinstance(hook, AlignDevicesHook): + assert not hook.offload + + # After context: back to meta device (offloaded) + assert all(p.device.type == "meta" for p in offloaded_layer.parameters()) + + # Verify weight modification persisted through writeback + with enable_weight_access_and_writeback(offloaded_layer, model): + assert torch.allclose(linear.weight, ref_weight + 1.0) diff --git a/tests/gpu/torch/quantization/test_fsdp2.py b/tests/gpu/torch/quantization/test_fsdp2.py index 4889b6dc8c..c5584ece5c 100644 --- a/tests/gpu/torch/quantization/test_fsdp2.py +++ b/tests/gpu/torch/quantization/test_fsdp2.py @@ -128,3 +128,136 @@ def test_fsdp_simple_linear(dist_workers): ) def test_nested_fsdp2_backward(quant_cfg, dist_workers): dist_workers.run(partial(_test_nested_fsdp2_backward, quant_cfg=quant_cfg)) + + +class _DecoderBlock(nn.Module): + """Minimal decoder block for FSDP2 sequential tests.""" + + def __init__(self, dim=32): + 
super().__init__() + self.attn = nn.Linear(dim, dim, bias=False) + self.ffn = nn.Sequential( + nn.Linear(dim, dim, bias=False), nn.ReLU(), nn.Linear(dim, dim, bias=False) + ) + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + x = x + self.attn(self.norm(x)) + x = x + self.ffn(x) + return x + + +class _SimpleTransformerModel(nn.Module): + """Model with ``model.layers`` for layerwise calibration discovery.""" + + def __init__(self, n_layers=3, dim=32): + super().__init__() + self.layers = nn.ModuleList([_DecoderBlock(dim) for _ in range(n_layers)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def _test_layerwise_calibrate_fsdp2(rank, size): + """Layerwise calibration on FSDP2-wrapped model matches non-FSDP reference.""" + from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector + + dim = 32 + torch.manual_seed(1) + model = _SimpleTransformerModel(n_layers=3, dim=dim).cuda() + inputs = torch.randn(2, 2, dim).cuda() + synchronize_state_dict(model) + + # Register discoverer for our simple model + old_support = LayerActivationCollector._decoder_layer_support[:] + LayerActivationCollector._decoder_layer_support = [ + ( + lambda m: hasattr(m, "layers") and isinstance(m.layers, nn.ModuleList), + lambda m: m.layers, + ), + *old_support, + ] + + try: + # Reference: non-FSDP layerwise calibration + ref_model = copy.deepcopy(model) + seq_cfg = copy.deepcopy(mtq.INT8_DEFAULT_CFG) + seq_cfg["algorithm"] = {"method": "max", "layerwise": True} + mtq.quantize(ref_model, seq_cfg, lambda m: m(inputs)) + output_ref = ref_model(inputs) + + # Test: FSDP2-wrapped layerwise calibration + for layer in model.layers: + fully_shard(layer) + model = fully_shard(model) + mtq.quantize(model, seq_cfg, lambda m: m(inputs)) + output_test = model(inputs) + + assert torch.allclose(output_ref, output_test) + finally: + LayerActivationCollector._decoder_layer_support = old_support + + +def 
test_layerwise_calibrate_fsdp2(dist_workers): + dist_workers.run(_test_layerwise_calibrate_fsdp2) + + +def _test_persistent_materialization(rank, size): + """persistent_materialization keeps weights accessible and writes back modifications.""" + from torch.distributed.tensor import DTensor + + from modelopt.torch.quantization.utils import ( + enable_weight_access_and_writeback, + persistent_materialization, + ) + + dim = 32 + torch.manual_seed(1) + model = nn.Sequential( + nn.Sequential(nn.Linear(dim, dim), nn.Linear(dim, dim)), + nn.Sequential(nn.Linear(dim, dim), nn.Linear(dim, dim)), + ).cuda(rank) + synchronize_state_dict(model) + + fully_shard(model[0]) + fully_shard(model[1]) + model = fully_shard(model) + + layer = model[0] + inputs = torch.randn(2, dim).cuda(rank) + + # Warmup forward to trigger FSDP2's lazy_init (mirrors real usage where + # layerwise_calibrate always runs get_first_layer_inputs first). + model(inputs) + + # Save reference weight (gathered) + with enable_weight_access_and_writeback(layer[0], model): + ref_weight = layer[0].weight.clone() + + # Verify sharded before context + assert isinstance(next(iter(layer.parameters())), DTensor) + + with persistent_materialization(layer): + # Params are local tensors (not DTensors) + assert not isinstance(layer[0].weight, DTensor) + assert layer[0].weight.device.type == "cuda" + + # Run multiple forward passes (FSDP hooks fire, unshard/reshard are no-ops) + for _ in range(3): + layer(inputs) + + # Modify a weight + layer[0].weight.data.add_(1.0) + + # After context: params restored to DTensors (sharded) + assert isinstance(next(iter(layer.parameters())), DTensor) + + # Verify modification persisted + with enable_weight_access_and_writeback(layer[0], model): + assert torch.allclose(layer[0].weight, ref_weight + 1.0) + + +def test_persistent_materialization(dist_workers): + dist_workers.run(_test_persistent_materialization) diff --git a/tests/gpu/torch/quantization/test_gptq.py 
b/tests/gpu/torch/quantization/test_gptq.py index d183855abb..2d5f9d6d70 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -219,7 +219,7 @@ def test_gptq_e2e_flow(quant_cfg): model.eval() quant_cfg = copy.deepcopy(quant_cfg) - quant_cfg["algorithm"] = {"method": "gptq", "use_sequential": True} + quant_cfg["algorithm"] = {"method": "gptq", "layerwise": True} calib_dataloader = get_dataset_dataloader( dataset_name="cnn_dailymail", tokenizer=tokenizer, diff --git a/tests/gpu/torch/quantization/test_sequential_calibrate.py b/tests/gpu/torch/quantization/test_layerwise_calibrate.py similarity index 90% rename from tests/gpu/torch/quantization/test_sequential_calibrate.py rename to tests/gpu/torch/quantization/test_layerwise_calibrate.py index ba71e896c7..d38b82f46f 100644 --- a/tests/gpu/torch/quantization/test_sequential_calibrate.py +++ b/tests/gpu/torch/quantization/test_layerwise_calibrate.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Integration tests for sequential_calibrate and LayerActivationCollector.""" +"""Integration tests for layerwise_calibrate and LayerActivationCollector.""" import torch import torch.nn as nn -from modelopt.torch.quantization.model_calib import sequential_calibrate -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.model_calib import layerwise_calibrate +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector class _DecoderBlock(nn.Module): @@ -101,7 +101,7 @@ def _register_test_discoverer(monkeypatch): ) -def test_seq_calib_func_called_per_layer(monkeypatch): +def test_layerwise_calib_func_called_per_layer(monkeypatch): _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=4) call_count = [0] @@ -109,7 +109,7 @@ def test_seq_calib_func_called_per_layer(monkeypatch): def counting_calib(layer, forward_loop, **kwargs): call_count[0] += 1 - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=counting_calib, @@ -118,7 +118,7 @@ def counting_calib(layer, forward_loop, **kwargs): assert call_count[0] == 4 -def test_seq_calib_func_receives_correct_layer(monkeypatch): +def test_layerwise_calib_func_receives_correct_layer(monkeypatch): _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=3) called_layers = [] @@ -126,7 +126,7 @@ def test_seq_calib_func_receives_correct_layer(monkeypatch): def track_layers(layer, forward_loop, **kwargs): called_layers.append(layer) - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=track_layers, @@ -136,7 +136,7 @@ def track_layers(layer, forward_loop, **kwargs): assert called_layers[i] is layer -def test_seq_calib_kwargs_forwarded(monkeypatch): +def test_layerwise_calib_kwargs_forwarded(monkeypatch): _register_test_discoverer(monkeypatch) model, data = 
_make_model_and_data(n_layers=2) received_kwargs = [] @@ -144,7 +144,7 @@ def test_seq_calib_kwargs_forwarded(monkeypatch): def capture_kwargs(layer, forward_loop, **kwargs): received_kwargs.append(kwargs) - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=capture_kwargs, @@ -158,7 +158,7 @@ def capture_kwargs(layer, forward_loop, **kwargs): assert kw["method"] == "max" -def test_seq_calib_layer_forward_loop_runs_all_batches(monkeypatch): +def test_layerwise_calib_layer_forward_loop_runs_all_batches(monkeypatch): """The per-layer forward loop passed to calib_func should replay all batches.""" _register_test_discoverer(monkeypatch) n_batches = 5 @@ -178,7 +178,7 @@ def counting_forward(*args, **kw): layer.forward = orig_forward batch_counts.append(counter["n"]) - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=count_batches, @@ -188,13 +188,13 @@ def counting_forward(*args, **kw): assert count == n_batches -def test_seq_calib_does_not_alter_weights(monkeypatch): - """sequential_calibrate itself should not modify model weights.""" +def test_layerwise_calib_does_not_alter_weights(monkeypatch): + """layerwise_calibrate itself should not modify model weights.""" _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=3) weights_before = {n: p.clone() for n, p in model.named_parameters()} - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=lambda layer, forward_loop, **kw: None, @@ -204,7 +204,7 @@ def test_seq_calib_does_not_alter_weights(monkeypatch): assert torch.equal(p, weights_before[n]), f"Weight {n} was modified" -def test_seq_calib_activations_update_across_layers(monkeypatch): +def test_layerwise_calib_activations_update_across_layers(monkeypatch): """Subsequent layers should see activations transformed by prior layers.""" 
_register_test_discoverer(monkeypatch) torch.manual_seed(0) @@ -228,7 +228,7 @@ def capture_forward(*args, **kw): layer_idx = list(model.layers).index(layer) layer_inputs_record[layer_idx] = activations - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: [m(t) for t in tokens], calib_func=record_inputs, @@ -240,7 +240,7 @@ def capture_forward(*args, **kw): def test_mode_transitions_across_calibration_steps(monkeypatch): - """Verify layer modes after each sequential calibration step. + """Verify layer modes after each layerwise calibration step. After get_input_activations(layers[i]) returns, the current layer is reset to 'original'. Layers further back are left in 'run' (just calibrated) or @@ -259,7 +259,7 @@ def forward_loop(m): try: def modes(): - return [model.layers[i]._seq_calib.mode for i in range(5)] + return [model.layers[i]._layerwise_calib.mode for i in range(5)] collector.get_input_activations(model.layers[0], forward_loop) assert modes() == ["original", "original", "original", "original", "original"] @@ -316,7 +316,7 @@ def weight_doubling_calib(layer, layer_forward_loop, **kwargs): layer.weight.mul_(2.0) layer_forward_loop(layer) - sequential_calibrate( + layerwise_calibrate( model, forward_loop=forward_loop, calib_func=weight_doubling_calib, diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py index 692ab07d4a..ae638c42ee 100644 --- a/tests/unit/torch/quantization/plugins/test_huggingface.py +++ b/tests/unit/torch/quantization/plugins/test_huggingface.py @@ -35,7 +35,7 @@ get_homogeneous_hf_decoder_layers, is_homogeneous_hf_model, ) -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector pytest.importorskip("transformers") diff --git a/tests/unit/torch/quantization/test_calib.py b/tests/unit/torch/quantization/test_calib.py 
index b3c372eb33..d2e6fdd03e 100644 --- a/tests/unit/torch/quantization/test_calib.py +++ b/tests/unit/torch/quantization/test_calib.py @@ -27,7 +27,7 @@ from modelopt.torch.quantization.model_calib import ( apply_pre_quant_scale_and_smooth, disable_pre_quant_scale_and_resmooth, - sequential_calibrate, + layerwise_calibrate, ) from modelopt.torch.quantization.nn import TensorQuantizer @@ -379,7 +379,7 @@ def test_svdquant_lora_weights(): assert lora_residual.shape == module.weight.shape -def test_sequential_calibrate_support_gate(): +def test_layerwise_calibrate_support_gate(): class _UnsupportedModel(nn.Module): def __init__(self): super().__init__() @@ -392,17 +392,17 @@ def forward(self, x): with ( torch.no_grad(), - pytest.raises(ValueError, match="Sequential calibration requires a model"), + pytest.raises(ValueError, match="Layerwise calibration requires a model"), ): - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: m(torch.randn(2, 4)), calib_func=lambda layer, loop: loop(layer), ) -def test_sequential_calibrate_propagates_inputs_without_replaying_full_model(monkeypatch): - from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +def test_layerwise_calibrate_propagates_inputs_without_replaying_full_model(monkeypatch): + from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector class _ToyLayer(nn.Module): def __init__(self, scale: float, bias: float): @@ -463,7 +463,7 @@ def _pre_hook(_module, args): handle.remove() observed_layer_inputs.append(captured) - sequential_calibrate(model, _forward_loop, _calib_func) + layerwise_calibrate(model, _forward_loop, _calib_func) assert forward_loop_calls == len(model.layers) assert len(observed_layer_inputs) == len(model.layers) @@ -482,9 +482,9 @@ def _pre_hook(_module, args): assert torch.allclose(observed, expected) -def test_sequential_calibrate_handles_inter_layer_logic(monkeypatch): +def 
test_layerwise_calibrate_handles_inter_layer_logic(monkeypatch): """Verify that parent-level inter-layer logic (e.g. mask selection) works correctly.""" - from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector + from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector class _ToyLayer(nn.Module): def __init__(self, scale: float): @@ -537,7 +537,7 @@ def _pre_hook(_module, args): handle.remove() observed_layer_inputs.append(captured) - sequential_calibrate(model, _forward_loop, _calib_func) + layerwise_calibrate(model, _forward_loop, _calib_func) assert len(observed_layer_inputs) == 3 # Layer 0 gets raw batch diff --git a/tests/unit/torch/quantization/test_sequential_calibrate.py b/tests/unit/torch/quantization/test_layerwise_calibrate.py similarity index 78% rename from tests/unit/torch/quantization/test_sequential_calibrate.py rename to tests/unit/torch/quantization/test_layerwise_calibrate.py index 14c1903de2..6596c1b4b1 100644 --- a/tests/unit/torch/quantization/test_sequential_calibrate.py +++ b/tests/unit/torch/quantization/test_layerwise_calibrate.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Unit tests for sequential_calibrate and LayerActivationCollector.""" +"""Unit tests for layerwise_calibrate and LayerActivationCollector.""" from collections import deque @@ -21,8 +21,8 @@ import torch import torch.nn as nn -from modelopt.torch.quantization.model_calib import sequential_calibrate -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.model_calib import layerwise_calibrate +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector, _SkipLayer class _DecoderBlock(nn.Module): @@ -60,7 +60,7 @@ def forward(self, x, **kwargs): class _FlatMLP(nn.Module): - """No decoder-layer structure -- should be rejected by sequential_calibrate.""" + """No decoder-layer structure -- should be rejected by layerwise_calibrate.""" def __init__(self, dim=16): super().__init__() @@ -180,7 +180,7 @@ def forward_loop(m): collector._unpatch_all_layers() assert not hasattr(model, "_original_forward") - assert not hasattr(model.layers[0], "_seq_calib") + assert not hasattr(model.layers[0], "_layerwise_calib") assert not hasattr(model.layers[0], "_original_forward") @@ -201,38 +201,38 @@ def bad_forward_loop(m): collector._unpatch_all_layers() assert not hasattr(model, "_original_forward") - assert not hasattr(model.layers[0], "_seq_calib") + assert not hasattr(model.layers[0], "_layerwise_calib") -# sequential_calibrate tests -def test_seq_calib_raises_on_none_forward_loop(monkeypatch): +# layerwise_calibrate tests +def test_layerwise_calib_raises_on_none_forward_loop(monkeypatch): _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=2) with pytest.raises(ValueError, match="forward_loop must not be None"): - sequential_calibrate( + layerwise_calibrate( model, forward_loop=None, calib_func=lambda *a, **kw: None, ) -def test_seq_calib_raises_on_unrecognized_model(): +def test_layerwise_calib_raises_on_unrecognized_model(): model = _FlatMLP() with 
pytest.raises(ValueError, match="Could not find transformer layers"): - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: m(torch.randn(2, 16)), calib_func=lambda *a, **kw: None, ) -def test_seq_calib_empty_forward_loop_raises(monkeypatch): - """If forward_loop feeds no data, sequential_calibrate raises RuntimeError.""" +def test_layerwise_calib_empty_forward_loop_raises(monkeypatch): + """If forward_loop feeds no data, layerwise_calibrate raises RuntimeError.""" _register_test_discoverer(monkeypatch) model = _SimpleTransformerModel(n_layers=2, dim=16) with pytest.raises(RuntimeError, match="collected no inputs during forward_loop"): - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: None, calib_func=lambda *a, **kw: None, @@ -344,11 +344,11 @@ def forward_loop(m): try: # Layer 0 starts as capture — no output_meta yet collector.get_input_activations(model.layers[0], forward_loop) - assert model.layers[0]._seq_calib.output_meta is None + assert model.layers[0]._layerwise_calib.output_meta is None # Calibrating layer 1 puts layer 0 into run, which sets output_meta collector.get_input_activations(model.layers[1], forward_loop) - meta = model.layers[0]._seq_calib.output_meta + meta = model.layers[0]._layerwise_calib.output_meta assert meta is not None assert meta[0] == "tuple", "Tuple-returning layer should produce tuple metadata" finally: @@ -375,11 +375,11 @@ def forward_loop(m): # Before calibrating layer 2, layer 1 transitions to run. # Its cached_inputs should be populated from collected_inputs. 
collector._set_layer_states(2) - assert len(model.layers[1]._seq_calib.cached_inputs) == n_batches + assert len(model.layers[1]._layerwise_calib.cached_inputs) == n_batches # After the forward loop, all cached inputs should be consumed forward_loop(model) - assert len(model.layers[1]._seq_calib.cached_inputs) == 0 + assert len(model.layers[1]._layerwise_calib.cached_inputs) == 0 finally: collector._unpatch_all_layers() @@ -399,24 +399,24 @@ def test_set_layer_states_transitions(monkeypatch): try: def modes(): - return [model.layers[i]._seq_calib.mode for i in range(5)] + return [model.layers[i]._layerwise_calib.mode for i in range(5)] collector._set_layer_states(0) assert modes() == ["capture", "original", "original", "original", "original"] - model.layers[0]._seq_calib.collected_inputs = [fake_inp] + model.layers[0]._layerwise_calib.collected_inputs = [fake_inp] collector._set_layer_states(1) assert modes() == ["run", "capture", "original", "original", "original"] - model.layers[1]._seq_calib.collected_inputs = [fake_inp] + model.layers[1]._layerwise_calib.collected_inputs = [fake_inp] collector._set_layer_states(2) assert modes() == ["skip", "run", "capture", "original", "original"] - model.layers[2]._seq_calib.collected_inputs = [fake_inp] + model.layers[2]._layerwise_calib.collected_inputs = [fake_inp] collector._set_layer_states(3) assert modes() == ["skip", "skip", "run", "capture", "original"] - model.layers[3]._seq_calib.collected_inputs = [fake_inp] + model.layers[3]._layerwise_calib.collected_inputs = [fake_inp] collector._set_layer_states(4) assert modes() == ["skip", "skip", "skip", "run", "capture"] finally: @@ -446,8 +446,8 @@ def test_run_asserts_on_empty_cached_inputs(monkeypatch): collector = LayerActivationCollector(model) collector._patch_all_layers() try: - model.layers[0]._seq_calib.mode = "run" - model.layers[0]._seq_calib.cached_inputs = deque() + model.layers[0]._layerwise_calib.mode = "run" + model.layers[0]._layerwise_calib.cached_inputs = 
deque() with pytest.raises(AssertionError, match="no cached inputs to replay"): model(torch.randn(2, 16)) @@ -455,8 +455,8 @@ def test_run_asserts_on_empty_cached_inputs(monkeypatch): collector._unpatch_all_layers() -def test_cleanup_removes_seq_calib_attr(monkeypatch): - """After unpatch, no layer should have the _seq_calib attribute.""" +def test_cleanup_removes_layerwise_calib_attr(monkeypatch): + """After unpatch, no layer should have the _layerwise_calib attribute.""" _register_test_discoverer(monkeypatch) model = _TupleUnpackingModel(n_layers=3, dim=16) data = [torch.randn(2, 16)] @@ -472,7 +472,9 @@ def forward_loop(m): collector._unpatch_all_layers() for i, layer in enumerate(model.layers): - assert not hasattr(layer, "_seq_calib"), f"Layer {i} still has _seq_calib after cleanup" + assert not hasattr(layer, "_layerwise_calib"), ( + f"Layer {i} still has _layerwise_calib after cleanup" + ) assert not hasattr(layer, "_original_forward"), ( f"Layer {i} still has _original_forward after cleanup" ) @@ -517,15 +519,17 @@ def forward_loop(m): for d in data: m(d) + originals = list(model.layers) collector = LayerActivationCollector(model) collector._patch_all_layers() try: - for layer in model.layers: + for layer in originals: collector.get_input_activations(layer, forward_loop) - # After full calibration, layers 0 and 1 have been through 'run' and have output_meta - meta_0 = model.layers[0]._seq_calib.output_meta - meta_1 = model.layers[1]._seq_calib.output_meta + # After full calibration, layers 0 and 1 have been through 'run' and have output_meta. + # Access via originals since skip-position entries are now _SkipLayer dummies. 
+ meta_0 = originals[0]._layerwise_calib.output_meta + meta_1 = originals[1]._layerwise_calib.output_meta assert meta_0 is not None assert meta_1 is not None # SmallBlock returns 3-element tuple, BigBlock returns 1-element tuple @@ -533,3 +537,59 @@ def forward_loop(m): assert len(meta_1[1]) == 1 finally: collector._unpatch_all_layers() + + +# --------------------------------------------------------------------------- +# _SkipLayer swap / restore tests +# --------------------------------------------------------------------------- + + +def test_skip_layers_replaced_with_dummy(monkeypatch): + """After calibrating enough layers, skip-position entries must be _SkipLayer with no params.""" + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=5, dim=16) + data = [torch.randn(2, 16) for _ in range(2)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + for layer in list(model.layers): + collector.get_input_activations(layer, forward_loop) + + # Layers 0..2 should be dummies (swapped when calibrating layers 2..4) + for i in range(3): + assert isinstance(model.layers[i], _SkipLayer), f"Layer {i} should be _SkipLayer" + assert list(model.layers[i].parameters()) == [], ( + f"Layer {i} dummy should have no params" + ) + # Layers 3 (run) and 4 (original) remain real + for i in range(3, 5): + assert not isinstance(model.layers[i], _SkipLayer), f"Layer {i} should still be real" + finally: + collector._unpatch_all_layers() + + +def test_cleanup_restores_original_layers(monkeypatch): + """After _unpatch_all_layers, all ModuleList entries must be the original modules.""" + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=5, dim=16) + originals = list(model.layers) + data = [torch.randn(2, 16)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + for layer in 
originals: + collector.get_input_activations(layer, forward_loop) + collector._unpatch_all_layers() + + for i, orig in enumerate(originals): + assert model.layers[i] is orig, f"Layer {i} not restored to original after cleanup" + assert not hasattr(orig, "_layerwise_calib"), f"Layer {i} still has _layerwise_calib" diff --git a/tests/unit/torch/quantization/test_sequential_checkpoint.py b/tests/unit/torch/quantization/test_sequential_checkpoint.py new file mode 100644 index 0000000000..0e592a68c7 --- /dev/null +++ b/tests/unit/torch/quantization/test_sequential_checkpoint.py @@ -0,0 +1,185 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for layerwise calibration checkpoint save/resume.""" + +import json +import os +from types import SimpleNamespace + +import torch +import torch.nn as nn + +from modelopt.torch.quantization.model_calib import layerwise_calibrate +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector +from modelopt.torch.utils.network import get_module_device + + +class _DecoderBlock(nn.Module): + def __init__(self, dim=16): + super().__init__() + self.linear = nn.Linear(dim, dim, bias=False) + + def forward(self, x, **kwargs): + return self.linear(x) + + +class _SimpleTransformerModel(nn.Module): + def __init__(self, n_layers=3, dim=16): + super().__init__() + self.layers = nn.ModuleList([_DecoderBlock(dim) for _ in range(n_layers)]) + self.embed = nn.Embedding(32, dim) + + def forward(self, x, **kwargs): + x = self.embed(x) + for layer in self.layers: + x = layer(x) + return x + + +def _register_test_discoverer(monkeypatch): + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], + ) + + +def _dummy_calib_func(layer, forward_loop, **kwargs): + """Scale all weights by 0.5 to produce a visible, deterministic change.""" + forward_loop(layer) + with torch.no_grad(): + for p in layer.parameters(): + p.mul_(0.5) + + +def _make_model_and_forward(n_layers=3, dim=16, seed=42): + torch.manual_seed(seed) + model = _SimpleTransformerModel(n_layers=n_layers, dim=dim) + tokens = [torch.randint(0, 32, (2, 8)) for _ in range(2)] + + def forward_loop(m): + for t in tokens: + m(t) + + return model, forward_loop + + +def test_full_run_creates_checkpoints(monkeypatch, tmp_path): + """layerwise_calibrate with checkpoint_dir creates correct layer dirs and manifest.""" + _register_test_discoverer(monkeypatch) + model, forward_loop = _make_model_and_forward(n_layers=3) + ckpt_dir = str(tmp_path / "ckpt") + + layerwise_calibrate(model, forward_loop, _dummy_calib_func, 
checkpoint_dir=ckpt_dir) + + manifest_path = os.path.join(ckpt_dir, "manifest.json") + assert os.path.isfile(manifest_path) + with open(manifest_path) as f: + manifest = json.load(f) + assert manifest["last_completed_layer"] == 2 + assert manifest["num_layers"] == 3 + + for i in range(3): + layer_dir = os.path.join(ckpt_dir, f"layer_{i:04d}") + assert os.path.isdir(layer_dir) + assert os.path.isfile(os.path.join(layer_dir, "weights.pt")) + assert os.path.isfile(os.path.join(layer_dir, "quantizer_state.pt")) + assert os.path.isfile(os.path.join(layer_dir, "output_meta.pt")) + # All layers except the last should have next_inputs + assert os.path.isfile(os.path.join(ckpt_dir, "layer_0000", "next_inputs.pt")) + assert os.path.isfile(os.path.join(ckpt_dir, "layer_0001", "next_inputs.pt")) + assert not os.path.isfile(os.path.join(ckpt_dir, "layer_0002", "next_inputs.pt")) + + +def test_resume_matches_full_run(monkeypatch, tmp_path): + """Resume from a truncated checkpoint produces the same final weights as a full run.""" + _register_test_discoverer(monkeypatch) + ckpt_dir = str(tmp_path / "ckpt") + + # Full reference run + ref_model, forward_loop = _make_model_and_forward(n_layers=3) + layerwise_calibrate(ref_model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir) + ref_weights = {n: p.clone() for n, p in ref_model.named_parameters()} + + # Simulate crash after layer 0: truncate manifest + manifest_path = os.path.join(ckpt_dir, "manifest.json") + with open(manifest_path, "w") as f: + json.dump({"last_completed_layer": 0, "num_layers": 3}, f) + + # Resume from a fresh model + resumed_model, forward_loop = _make_model_and_forward(n_layers=3) + layerwise_calibrate(resumed_model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir) + + for name, ref_param in ref_weights.items(): + resumed_param = dict(resumed_model.named_parameters())[name] + assert torch.allclose(ref_param, resumed_param, atol=1e-6), ( + f"Parameter {name} diverged after resume" + ) + + +def 
test_no_checkpoint_unchanged(monkeypatch): + """Without checkpoint_dir, calibration still works and modifies parameters.""" + _register_test_discoverer(monkeypatch) + model, forward_loop = _make_model_and_forward(n_layers=3) + original_weights = {n: p.clone() for n, p in model.named_parameters()} + + layerwise_calibrate(model, forward_loop, _dummy_calib_func) + + changed = False + for name, param in model.named_parameters(): + if not torch.allclose(original_weights[name], param): + changed = True + break + assert changed, "Expected calibration to modify at least one parameter" + + +# --------------------------------------------------------------------------- +# get_module_device tests +# --------------------------------------------------------------------------- + + +def test_get_module_device_no_hook(): + """Falls back to parameter device when no _hf_hook is present.""" + layer = nn.Linear(4, 4) + assert get_module_device(layer) == torch.device("cpu") + + +def test_get_module_device_with_direct_hook(): + """Returns execution_device from a direct AlignDevicesHook-style hook.""" + layer = nn.Linear(4, 4) + layer._hf_hook = SimpleNamespace(execution_device=torch.device("cuda:0")) + assert get_module_device(layer) == torch.device("cuda:0") + + +def test_get_module_device_with_sequential_hook(): + """Returns execution_device from an AlignDevicesHook wrapped in SequentialHook.""" + layer = nn.Linear(4, 4) + inner_hook = SimpleNamespace(execution_device=torch.device("cuda:1")) + layer._hf_hook = SimpleNamespace(hooks=[inner_hook]) + assert get_module_device(layer) == torch.device("cuda:1") + + +def test_get_module_device_hook_without_execution_device(): + """Falls back to parameters when hook has no execution_device.""" + layer = nn.Linear(4, 4) + layer._hf_hook = SimpleNamespace() + assert get_module_device(layer) == torch.device("cpu") + + +def test_get_module_device_parameterless_module(): + """Returns cpu for a module with no parameters and no hook.""" + module = 
nn.Module() + assert get_module_device(module) == torch.device("cpu") diff --git a/tests/unit/torch/quantization/test_utils.py b/tests/unit/torch/quantization/test_utils.py index 92fe1345f9..73d3423ba5 100644 --- a/tests/unit/torch/quantization/test_utils.py +++ b/tests/unit/torch/quantization/test_utils.py @@ -20,7 +20,7 @@ convert_quantization_axis_to_reduce_axis, reduce_block_amax, ) -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector @pytest.mark.parametrize( From d14ccbbecd1849ab1801a3240f9a2cab9b64092b Mon Sep 17 00:00:00 2001 From: realAsma Date: Wed, 15 Apr 2026 22:55:30 +0000 Subject: [PATCH 2/6] Fix json.dumps sort_keys error with mixed int/str keys in quant_cfg The block_sizes config has mixed int and str keys which causes TypeError when sort_keys=True is used in json.dumps for checkpoint dir hashing. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: realAsma --- examples/llm_ptq/example_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 581005de43..90532efe38 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -880,9 +880,7 @@ def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict: else: name = Path(name).name - config_hash = hashlib.sha256( - json.dumps(quant_cfg, sort_keys=True, default=str).encode() - ).hexdigest()[:8] + config_hash = hashlib.sha256(json.dumps(quant_cfg, default=str).encode()).hexdigest()[:8] quant_cfg = copy.deepcopy(quant_cfg) quant_cfg["algorithm"]["layerwise_checkpoint_dir"] = os.path.join( From fbba7d76c0da1d5bdaef3818c5ba41b5f3c10b50 Mon Sep 17 00:00:00 2001 From: realAsma Date: Thu, 16 Apr 2026 08:36:28 +0000 Subject: [PATCH 3/6] Add disk offloading support to enable_weight_access_and_writeback Co-Authored-By: Claude Opus 4.6 (1M context) 
Signed-off-by: realAsma --- .../torch/quantization/plugins/accelerate.py | 30 +-- .../torch/quantization/utils/core_utils.py | 2 +- .../plugins/test_accelerate_gpu.py | 175 ++++++++++++++++++ 3 files changed, 193 insertions(+), 14 deletions(-) diff --git a/modelopt/torch/quantization/plugins/accelerate.py b/modelopt/torch/quantization/plugins/accelerate.py index bbbb75930e..1c600cf83a 100644 --- a/modelopt/torch/quantization/plugins/accelerate.py +++ b/modelopt/torch/quantization/plugins/accelerate.py @@ -31,19 +31,13 @@ __all__ = ["init_quantized_weights"] -def _get_cpu_offload_hook(hook): +def _get_offload_hook(hook): if isinstance(hook, AlignDevicesHook) and hook.offload and hook.weights_map is not None: assert len(hook.weights_map) > 0 - if isinstance(hook.weights_map, PrefixedDataset) and not any( - k.startswith(hook.weights_map.prefix) for k in hook.weights_map.dataset.state_dict - ): - raise NotImplementedError( - "This layer could be offloaded to disk. We don't support this yet." - ) return hook elif isinstance(hook, SequentialHook): for h in hook.hooks: - align_hook = _get_cpu_offload_hook(h) + align_hook = _get_offload_hook(h) if align_hook is not None: return align_hook return None @@ -62,13 +56,23 @@ def _writeback_params_to_weights_map(module, align_hook): key = name if key in w_map: w_map[key] = tensor.detach().to(w_map[key].device, dtype=w_map[key].dtype) + elif ( + isinstance(align_hook.weights_map, PrefixedDataset) + and hasattr(align_hook.weights_map.dataset, "index") + and key in align_hook.weights_map.dataset.index + ): + # Disk-offloaded weight: promote into state_dict so the next + # pre_forward picks up the modified tensor instead of the stale + # on-disk version. OffloadedWeightsLoader.__getitem__ gives + # state_dict priority over index, so this is sufficient. 
+ w_map[key] = tensor.detach().cpu() @contextmanager def weight_access_and_writeback_context(module): """Context manager for weight access and writeback for modules managed by accelerate. - Handles two cases: + Handles CPU-offloaded and disk-offloaded models, in two layout cases: 1. **Single-module**: the module's own ``_hf_hook`` is an offload hook. 2. **Sub-module**: the module's hook is non-offloading, but its children have offload hooks (common with ``SequentialHook`` on sub-modules placed by @@ -79,19 +83,19 @@ def weight_access_and_writeback_context(module): used as a pure writeback after weight-modifying algorithms. """ assert hasattr(module, "_hf_hook") - align_hook = _get_cpu_offload_hook(module._hf_hook) + align_hook = _get_offload_hook(module._hf_hook) if align_hook: # Guard: the sub-module branch below is not reached when the parent has # an offload hook. Assert that no children also carry offload hooks, # which would require a combined writeback strategy. if any( - _get_cpu_offload_hook(mod._hf_hook) + _get_offload_hook(mod._hf_hook) for mod in module.modules() if mod is not module and hasattr(mod, "_hf_hook") ): raise RuntimeError( - "Both the module and one of its sub-modules have CPU-offload hooks. " + "Both the module and one of its sub-modules have offload hooks. " "weight_access_and_writeback_context does not support this layout yet." 
) align_hook.pre_forward(module) @@ -108,7 +112,7 @@ def weight_access_and_writeback_context(module): for mod in module.modules(): if mod is module or not hasattr(mod, "_hf_hook"): continue - hook = _get_cpu_offload_hook(mod._hf_hook) + hook = _get_offload_hook(mod._hf_hook) if hook is None: continue # Only call pre_forward if weights need materializing; already-materialized diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 1a626a51d3..29661e18f5 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -494,7 +494,7 @@ def enable_weight_access_and_writeback(module, root_model, name_to_module: dict """Enable weight access and writeback for a module. Useful for modules with weight not intact such as Linear layer in FSDP wrapped model or - HF accelerate CPU off-loaded models. + HF accelerate offloaded models (CPU or disk). Args: module: The module to access weights for. 
diff --git a/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py b/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py index 809268a635..618b35a5f5 100644 --- a/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py +++ b/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py @@ -496,3 +496,178 @@ def test_persistent_materialization_cpu_offloaded(tmp_path): # Verify weight modification persisted through writeback with enable_weight_access_and_writeback(offloaded_layer, model): assert torch.allclose(linear.weight, ref_weight + 1.0) + + +def _make_disk_offloaded_model(tmp_path, num_hidden_layers=3): + """Create a tiny LLaMA model with layer 0 offloaded to disk via accelerate.""" + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_hidden_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + device_map = { + n: 0 + for n, m in model.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "disk" + + offload_dir = str(tmp_path / "offload") + model = load_checkpoint_and_dispatch( + model, tiny_llama_dir, device_map=device_map, offload_folder=offload_dir + ) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + return model, config, tiny_llama_dir, inputs + + +@pytest.mark.parametrize( + "quant_cfg", + [ + mtq.INT4_AWQ_CFG, + mtq.INT8_SMOOTHQUANT_CFG, + INT4_AWQ_CLIP_CFG, + mtq.NVFP4_SVDQUANT_DEFAULT_CFG, + mtq.INT8_DEFAULT_CFG, + ], +) +def test_disk_offloaded_tinyllama(tmp_path, quant_cfg): + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=2) + + config = AutoConfig.from_pretrained(tiny_llama_dir) + + model_ref = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, torch_dtype=config.torch_dtype + ).cuda() + inputs = torch.randint(0, model_ref.config.vocab_size, (1, 4)).cuda() + + mtq.quantize(model_ref, quant_cfg, lambda model: model(inputs)) + 
output_ref = model_ref(inputs) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + device_map = { + n: 0 + for n, m in model.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "disk" + + offload_dir = str(tmp_path / "offload") + model = load_checkpoint_and_dispatch( + model, tiny_llama_dir, device_map=device_map, offload_folder=offload_dir + ) + + assert all(p.device == torch.device("meta") for p in model.model.layers[0].parameters()) + + mtq.quantize(model, quant_cfg, lambda model: model(inputs)) + output_test = model(inputs) + + for name, module in model.named_modules(): + if is_quantized_linear(module): + with enable_weight_access_and_writeback(module, model): + assert torch.allclose(module.weight, model_ref.get_submodule(name).weight) + + assert torch.allclose(output_ref.logits, output_test.logits) + + +def test_persistent_materialization_disk_offloaded(tmp_path): + """persistent_materialization keeps disk-offloaded weights on GPU and writes back modifications.""" + import torch.nn as nn + from accelerate.hooks import AlignDevicesHook + + from modelopt.torch.quantization.utils import persistent_materialization + + model, config, _, inputs = _make_disk_offloaded_model(tmp_path) + offloaded_layer = model.model.layers[0] + + # Verify offloaded (meta device) + assert all(p.device.type == "meta" for p in offloaded_layer.parameters()) + + # Save reference weight + linear = None + with enable_weight_access_and_writeback(offloaded_layer, model): + linear = next(m for m in offloaded_layer.modules() if isinstance(m, nn.Linear)) + ref_weight = linear.weight.clone() + + with persistent_materialization(offloaded_layer): + # Params materialized on GPU + assert all( + p.device.type == "cuda" for p in offloaded_layer.parameters() if p.device.type != "meta" + ) + + # Run multiple forward passes (hooks don't re-offload) + for _ in range(3): + model(inputs) + + # Modify a weight + 
linear.weight.data.add_(1.0) + + # Verify hooks have offload=False during context + for mod in offloaded_layer.modules(): + if hasattr(mod, "_hf_hook"): + hook = mod._hf_hook + if isinstance(hook, AlignDevicesHook): + assert not hook.offload + + # After context: back to meta device (offloaded) + assert all(p.device.type == "meta" for p in offloaded_layer.parameters()) + + # Verify weight modification persisted through writeback + with enable_weight_access_and_writeback(offloaded_layer, model): + assert torch.allclose(linear.weight, ref_weight + 1.0) + + +@pytest.mark.parametrize( + "quant_cfg", + [mtq.INT4_AWQ_CFG, mtq.NVFP4_DEFAULT_CFG], + ids=["int4_awq", "nvfp4"], +) +@pytest.mark.parametrize("use_checkpoint", [False, True], ids=["no_ckpt", "ckpt"]) +def test_layerwise_calibrate_disk_offloaded(tmp_path, quant_cfg, use_checkpoint): + """Layerwise calibration on disk-offloaded model matches GPU-only reference.""" + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + if use_checkpoint: + ckpt_dir = str(tmp_path / "seq_ckpt") + seq_cfg = _make_layerwise_checkpoint_cfg(quant_cfg, ckpt_dir) + else: + seq_cfg = _make_layerwise_cfg(quant_cfg) + + # Reference: GPU-only model with layerwise calibration + ref_cfg = _make_layerwise_cfg(quant_cfg) + model_ref = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, torch_dtype=config.torch_dtype + ).cuda() + mtq.quantize(model_ref, ref_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Test: disk-offloaded model + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "disk" + offload_dir = str(tmp_path / "offload") + model = load_checkpoint_and_dispatch( + model, 
tiny_llama_dir, device_map=device_map, offload_folder=offload_dir + ) + + mtq.quantize(model, seq_cfg, lambda model: model(inputs)) + output_test = model(inputs) + + for name, module in model.named_modules(): + if is_quantized_linear(module): + with enable_weight_access_and_writeback(module, model): + assert torch.allclose(module.weight, model_ref.get_submodule(name).weight), ( + f"Weight mismatch at {name}" + ) + + assert torch.allclose(output_ref.logits, output_test.logits) From 67dbeafa84d1640d9b0c5772bc6ed40ac4404f50 Mon Sep 17 00:00:00 2001 From: realAsma Date: Thu, 16 Apr 2026 13:15:14 +0000 Subject: [PATCH 4/6] Add memory-efficient inplace fakequant export and disk offload support Add inplace_mem_efficient mode to export_hf_vllm_fq_checkpoint that applies fake-quant one decoder layer at a time via enable_weight_access_and_writeback, avoiding full state dict materialization. Refactor fakequant logic into _fakequant_module_weights helper shared by both paths. Add disk offloading support to enable_weight_access_and_writeback alongside the existing CPU offload path. Also fix: checkpoint resume restoring quantizer state on meta-device tensors, _SkipLayer leaking _hf_hook, disk-offload test device_map, and simplify weight_access_and_writeback_context into a single loop. 
Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: realAsma --- examples/llm_ptq/hf_ptq.py | 4 +- .../torch/export/plugins/vllm_fakequant_hf.py | 162 ++++++++++++------ .../torch/quantization/plugins/accelerate.py | 39 +---- .../quantization/utils/layerwise_calib.py | 32 ++-- .../plugins/test_accelerate_gpu.py | 100 ++++------- 5 files changed, 174 insertions(+), 163 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index c03e8aab3b..c405de51e7 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -762,7 +762,9 @@ def export_quantized( # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization if args.vllm_fakequant_export: - export_hf_vllm_fq_checkpoint(full_model, export_dir=export_path) + export_hf_vllm_fq_checkpoint( + full_model, export_dir=export_path, inplace_mem_efficient=True + ) else: mtp_layer_prefixes, mtp_state_dict = load_mtp_weights( full_model, args.pyt_ckpt_path diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 1908354a0a..44b2d55ba8 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -24,6 +24,8 @@ from modelopt.torch.quantization.conversion import quantizer_state from modelopt.torch.quantization.nn import QuantModule, TensorQuantizer from modelopt.torch.quantization.utils import get_quantizer_state_dict +from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector from modelopt.torch.utils import get_unwrapped_name __all__ = ["export_hf_vllm_fq_checkpoint"] @@ -38,9 +40,75 @@ def disable_rotate(quantizer: TensorQuantizer): return False +def _fakequant_module_weights( + module: nn.Module, + 
module_name: str, + model: nn.Module, + state_dict: dict | None, + input_quantizers_folded_pqs: set, + fakequant_weights: set, + inplace: bool, +): + """Apply fake-quant to a single QuantModule's weights. + + When ``inplace=False``, reads/writes weights from/to ``state_dict``. + When ``inplace=True``, modifies the module's weight parameters directly. + """ + if not isinstance(module, QuantModule): + return + for attr_name, quantizer in module.named_children(): + if not ( + attr_name.endswith("weight_quantizer") + and isinstance(quantizer, TensorQuantizer) + and quantizer.fake_quant + and quantizer.is_enabled + ): + continue + weight_name = attr_name.removesuffix("_quantizer") + prefix = f"{module_name}." if module_name else "" + sd_key = f"{prefix}{weight_name}" + assert sd_key not in fakequant_weights, f"Weight {sd_key} has already been fakequantized" + + if inplace: + w = getattr(module, weight_name) + w_quant = quantizer(w.float()).to(w.dtype) + else: + assert state_dict is not None + if sd_key not in state_dict: + continue + w = state_dict[sd_key] + w_quant = quantizer(w.float()).to(w.dtype) + + # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) + # Only valid when input_quantizer does NOT fake-quant activations. If it does + # fake_quant(x*s), the non-linearity prevents folding s into W. 
+ inp_attr = attr_name.replace("weight_quantizer", "input_quantizer") + if hasattr(module, inp_attr): + inp_q = getattr(module, inp_attr) + if ( + hasattr(inp_q, "_pre_quant_scale") + and inp_q._pre_quant_scale is not None + and inp_q._disabled + ): + scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device) + w_quant = (w_quant * scale[None, :]).to(w_quant.dtype) + inp_q_key = get_unwrapped_name( + f"{module_name}.{inp_attr}" if module_name else inp_attr, model + ) + input_quantizers_folded_pqs.add(inp_q_key) + + if inplace: + w.data.copy_(w_quant) + else: + assert state_dict is not None + state_dict[sd_key] = w_quant.cpu() + fakequant_weights.add(sd_key) + + def export_hf_vllm_fq_checkpoint( model: nn.Module, export_dir: Path | str, + inplace_mem_efficient: bool = False, ): """Export quantized HF weights + ``vllm_fq_modelopt_state.pth`` for vLLM fake-quant reload. @@ -53,59 +121,56 @@ def export_hf_vllm_fq_checkpoint( Args: model: In-memory quantized model. export_dir: Output dir for HF files and ``vllm_fq_modelopt_state.pth``. + inplace_mem_efficient: When True, applies fake-quant inplace one decoder layer at + a time using ``enable_weight_access_and_writeback``, avoiding full state + dict materialization. This is destructive — model weights are permanently + modified and weight quantizers are not re-enabled after export. """ export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) # Step 1: Build the folded HF state dict. - # model.state_dict() returns detached copies of all tensors, so model - # parameters are never modified. Apply each weight quantizer's fake-quant - # to the corresponding weight tensor in the copy. 
- state_dict = model.state_dict() fakequant_weights = set() - input_quantizers_folded_pqs = ( - set() - ) # keys for input_quantizers where pre_quant_scale was folded + input_quantizers_folded_pqs = set() with torch.inference_mode(): - for module_name, module in model.named_modules(): - if not isinstance(module, QuantModule): - continue - for attr_name, quantizer in module.named_children(): - if not ( - attr_name.endswith("weight_quantizer") - and isinstance(quantizer, TensorQuantizer) - and quantizer.fake_quant - and quantizer.is_enabled - ): + if inplace_mem_efficient: + # Inplace path: iterate decoder layers, one offload<->onload per layer. + decoder_layers = LayerActivationCollector.get_decoder_layers(model) + assert decoder_layers is not None, ( + "inplace_mem_efficient=True requires a model with discoverable decoder layers" + ) + for name, module in model.named_modules(): + if module not in decoder_layers: continue - weight_name = attr_name.removesuffix("_quantizer") - prefix = f"{module_name}." if module_name else "" - sd_key = f"{prefix}{weight_name}" - assert sd_key not in fakequant_weights, ( - f"Weight {sd_key} has already been fakequantized" - ) - if sd_key in state_dict: - w = state_dict[sd_key] - w_quant = quantizer(w.float()).to(w.dtype).cpu() - # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) - # Only valid when input_quantizer does NOT fake-quant activations. If it does - # fake_quant(x*s), the non-linearity prevents folding s into W. 
- inp_attr = attr_name.replace("weight_quantizer", "input_quantizer") - if hasattr(module, inp_attr): - inp_q = getattr(module, inp_attr) - if ( - hasattr(inp_q, "_pre_quant_scale") - and inp_q._pre_quant_scale is not None - and inp_q._disabled - ): - scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device) - w_quant = (w_quant * scale[None, :]).to(w_quant.dtype) - inp_q_key = get_unwrapped_name( - f"{module_name}.{inp_attr}" if module_name else inp_attr, model - ) - input_quantizers_folded_pqs.add(inp_q_key) - state_dict[sd_key] = w_quant - fakequant_weights.add(sd_key) + with enable_weight_access_and_writeback(module, module): + for sub_name, sub_mod in module.named_modules(): + full_name = f"{name}.{sub_name}" if sub_name else name + _fakequant_module_weights( + sub_mod, + full_name, + model, + None, + input_quantizers_folded_pqs, + fakequant_weights, + inplace=True, + ) + # Meta tensors for offloaded weights (free); offload maps now have + # fakequanted values via writeback. + state_dict = model.state_dict() + else: + # Default path: full state_dict copy, fakequant into the copy. + state_dict = model.state_dict() + for module_name, module in model.named_modules(): + with enable_weight_access_and_writeback(module, model): + _fakequant_module_weights( + module, + module_name, + model, + state_dict, + input_quantizers_folded_pqs, + fakequant_weights, + inplace=False, + ) # Filter quantizer tensors out for a clean HF checkpoint. clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k} @@ -164,6 +229,7 @@ def export_hf_vllm_fq_checkpoint( # Step 3: Save HF weights using the pre-built folded state dict. 
model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False) - for wq, orig_rotate in wqs_to_restore: - wq.enable() - wq._rotate = orig_rotate + if not inplace_mem_efficient: + for wq, orig_rotate in wqs_to_restore: + wq.enable() + wq._rotate = orig_rotate diff --git a/modelopt/torch/quantization/plugins/accelerate.py b/modelopt/torch/quantization/plugins/accelerate.py index 1c600cf83a..f80e2478dc 100644 --- a/modelopt/torch/quantization/plugins/accelerate.py +++ b/modelopt/torch/quantization/plugins/accelerate.py @@ -72,45 +72,16 @@ def _writeback_params_to_weights_map(module, align_hook): def weight_access_and_writeback_context(module): """Context manager for weight access and writeback for modules managed by accelerate. - Handles CPU-offloaded and disk-offloaded models, in two layout cases: - 1. **Single-module**: the module's own ``_hf_hook`` is an offload hook. - 2. **Sub-module**: the module's hook is non-offloading, but its children have - offload hooks (common with ``SequentialHook`` on sub-modules placed by - ``load_checkpoint_and_dispatch``). - - For the sub-module case, ``pre_forward`` is skipped on sub-modules whose weights - are already materialized (not on meta). This allows the context manager to be - used as a pure writeback after weight-modifying algorithms. + Handles CPU-offloaded and disk-offloaded models. Iterates over the module and all + its descendants, materializing weights from any offload hook found and writing them + back on exit. ``pre_forward`` is skipped on modules whose weights are already + materialized (not on meta) to avoid overwriting them with stale CPU copies. """ assert hasattr(module, "_hf_hook") - align_hook = _get_offload_hook(module._hf_hook) - - if align_hook: - # Guard: the sub-module branch below is not reached when the parent has - # an offload hook. Assert that no children also carry offload hooks, - # which would require a combined writeback strategy. 
- if any( - _get_offload_hook(mod._hf_hook) - for mod in module.modules() - if mod is not module and hasattr(mod, "_hf_hook") - ): - raise RuntimeError( - "Both the module and one of its sub-modules have offload hooks. " - "weight_access_and_writeback_context does not support this layout yet." - ) - align_hook.pre_forward(module) - align_hook.offload = False - try: - yield - finally: - align_hook.offload = True - _writeback_params_to_weights_map(module, align_hook) - align_hook.post_forward(module, None) - return materialized: list[tuple[torch.nn.Module, AlignDevicesHook, bool]] = [] for mod in module.modules(): - if mod is module or not hasattr(mod, "_hf_hook"): + if not hasattr(mod, "_hf_hook"): continue hook = _get_offload_hook(mod._hf_hook) if hook is None: diff --git a/modelopt/torch/quantization/utils/layerwise_calib.py b/modelopt/torch/quantization/utils/layerwise_calib.py index 0cacdab273..6494d1a774 100644 --- a/modelopt/torch/quantization/utils/layerwise_calib.py +++ b/modelopt/torch/quantization/utils/layerwise_calib.py @@ -78,14 +78,19 @@ def __init__(self, original: nn.Module): object.__setattr__(self, "_original", original) self._layerwise_calib = _LayerCalibState(mode="skip") + _PROXY_BLOCKLIST = frozenset({"_hf_hook", "_old_forward"}) + def __getattr__(self, name: str): # Proxy non-special attribute lookups to the original layer so that # parent-model code that accesses layer-level attributes (e.g., # NemotronH's ``block_type``) still works when the layer is replaced - # with a _SkipLayer. + # with a _SkipLayer. Accelerate hook attrs are blocked so the + # framework does not attempt to manage this parameter-free stand-in. 
try: return super().__getattr__(name) except AttributeError: + if name in self._PROXY_BLOCKLIST: + raise return getattr(object.__getattribute__(self, "_original"), name) def forward(self, *args, **kwargs): @@ -616,23 +621,20 @@ def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None: layer_device = get_module_device(layer) d = _layer_dir(self.checkpoint_dir, i) - # Restore quantizer state first: may promote TensorQuantizer to - # NVFP4StaticQuantizer, changing module structure that load_state_dict - # expects. - # weights_only=False is safe: file is internally generated by _save_layer, not user-supplied + # weights_only=False is safe: files are internally generated by _save_layer qstate = torch.load( - os.path.join(d, "quantizer_state.pt"), map_location=layer_device, weights_only=False + os.path.join(d, "quantizer_state.pt"), + map_location=layer_device, + weights_only=False, + ) + weights = torch.load( + os.path.join(d, "weights.pt"), map_location=layer_device, weights_only=False ) - restore_quantizer_state(layer, dummy_config, {"quantizer_state": qstate}) - # Load weights inside the framework's access context so that - # managed-weight frameworks (accelerate CPU offload, FSDP2) sync - # their internal state with the restored parameters. + # Restore inside the access context so offloaded weights are + # materialized and written back after modification. 
with enable_weight_access_and_writeback(layer, model, name_to_module): - # weights_only=False is safe: file is internally generated by _save_layer, not user-supplied - weights = torch.load( - os.path.join(d, "weights.pt"), map_location=layer_device, weights_only=False - ) + restore_quantizer_state(layer, dummy_config, {"quantizer_state": qstate}) layer.load_state_dict(weights, strict=False) print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers") @@ -660,7 +662,7 @@ def save( _cpu = torch.device("cpu") with enable_weight_access_and_writeback(layer, model): weights = _move_to_device(layer.state_dict(), _cpu) - qstate = _move_to_device(quantizer_state(layer), _cpu) + qstate = _move_to_device(quantizer_state(layer), _cpu) output_meta = getattr(layer._layerwise_calib, "output_meta", None) if output_meta is None: diff --git a/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py b/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py index 618b35a5f5..49e74e5851 100644 --- a/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py +++ b/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py @@ -20,7 +20,6 @@ import pytest import torch -from _test_utils.torch.quantization.quantize_common import INT4_AWQ_CLIP_CFG from _test_utils.torch.transformers_models import create_tiny_llama_dir from accelerate import init_empty_weights, load_checkpoint_and_dispatch from transformers import AutoConfig, AutoModelForCausalLM @@ -33,17 +32,8 @@ from modelopt.torch.quantization.utils.layerwise_calib import _layer_dir -@pytest.mark.parametrize( - "quant_cfg", - [ - mtq.INT4_AWQ_CFG, - mtq.INT8_SMOOTHQUANT_CFG, - INT4_AWQ_CLIP_CFG, - mtq.NVFP4_SVDQUANT_DEFAULT_CFG, - mtq.INT8_DEFAULT_CFG, - ], -) -def test_cpu_offloaded_tinyllama(tmp_path, quant_cfg): +def test_cpu_offloaded_tinyllama(tmp_path): + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=2) config = 
AutoConfig.from_pretrained(tiny_llama_dir) @@ -119,14 +109,10 @@ def _make_layerwise_checkpoint_cfg(base_cfg, checkpoint_dir): return cfg -@pytest.mark.parametrize( - "quant_cfg", - [mtq.INT4_AWQ_CFG, mtq.NVFP4_DEFAULT_CFG], - ids=["int4_awq", "nvfp4"], -) @pytest.mark.parametrize("use_checkpoint", [False, True], ids=["no_ckpt", "ckpt"]) -def test_layerwise_calibrate_cpu_offloaded(tmp_path, quant_cfg, use_checkpoint): +def test_layerwise_calibrate_cpu_offloaded(tmp_path, use_checkpoint): """Layerwise calibration on CPU-offloaded model matches GPU-only reference.""" + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG num_layers = 3 tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) config = AutoConfig.from_pretrained(tiny_llama_dir) @@ -178,13 +164,9 @@ def test_layerwise_calibrate_cpu_offloaded(tmp_path, quant_cfg, use_checkpoint): assert manifest["num_layers"] == num_layers -@pytest.mark.parametrize( - "quant_cfg", - [mtq.INT4_AWQ_CFG, mtq.NVFP4_DEFAULT_CFG], - ids=["int4_awq", "nvfp4"], -) -def test_sequential_checkpoint_resume_cpu_offloaded(tmp_path, quant_cfg): +def test_sequential_checkpoint_resume_cpu_offloaded(tmp_path): """Resume from a partial checkpoint on a CPU-offloaded model matches a full run.""" + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG num_layers = 3 tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) config = AutoConfig.from_pretrained(tiny_llama_dir) @@ -307,12 +289,12 @@ def test_sequential_gptq_cpu_offloaded(tmp_path, use_checkpoint): if use_checkpoint: ckpt_dir = str(tmp_path / "gptq_ckpt") - seq_cfg = _make_gptq_sequential_checkpoint_cfg(mtq.NVFP4_DEFAULT_CFG, ckpt_dir) + seq_cfg = _make_gptq_sequential_checkpoint_cfg(mtq.NVFP4_AWQ_LITE_CFG, ckpt_dir) else: - seq_cfg = _make_gptq_sequential_cfg(mtq.NVFP4_DEFAULT_CFG) + seq_cfg = _make_gptq_sequential_cfg(mtq.NVFP4_AWQ_LITE_CFG) # Reference: GPU-only model - ref_cfg = _make_gptq_sequential_cfg(mtq.NVFP4_DEFAULT_CFG) + ref_cfg = 
_make_gptq_sequential_cfg(mtq.NVFP4_AWQ_LITE_CFG) model_ref = AutoModelForCausalLM.from_pretrained( tiny_llama_dir, torch_dtype=config.torch_dtype ).cuda() @@ -342,7 +324,7 @@ def test_sequential_gptq_checkpoint_resume_cpu_offloaded(tmp_path): inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() ckpt_dir = str(tmp_path / "gptq_ckpt") - seq_ckpt_cfg = _make_gptq_sequential_checkpoint_cfg(mtq.NVFP4_DEFAULT_CFG, ckpt_dir) + seq_ckpt_cfg = _make_gptq_sequential_checkpoint_cfg(mtq.NVFP4_AWQ_LITE_CFG, ckpt_dir) # Full reference run with checkpointing with init_empty_weights(): @@ -498,6 +480,24 @@ def test_persistent_materialization_cpu_offloaded(tmp_path): assert torch.allclose(linear.weight, ref_weight + 1.0) +def _make_disk_offload_device_map(model): + """Build a device_map with layer 0 on disk, everything else on GPU 0. + + Ancestor modules (``""`` and ``"model"``) are excluded so that + ``dispatch_model`` does not attach a ``place_submodules=True`` hook that + would try to move disk-offloaded meta tensors to GPU (which fails because + no ``value`` is available — unlike CPU offload where weights are on CPU and + can be moved directly). 
+ """ + device_map = { + n: 0 + for n, m in model.named_modules() + if n not in ("", "model") and ("layers" not in n or n.split("layers.")[-1].isdigit()) + } + device_map["model.layers.0"] = "disk" + return device_map + + def _make_disk_offloaded_model(tmp_path, num_hidden_layers=3): """Create a tiny LLaMA model with layer 0 offloaded to disk via accelerate.""" tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_hidden_layers) @@ -506,13 +506,7 @@ def _make_disk_offloaded_model(tmp_path, num_hidden_layers=3): with init_empty_weights(): model = AutoModelForCausalLM.from_config(config) - device_map = { - n: 0 - for n, m in model.named_modules() - if "layers" not in n or n.split("layers.")[-1].isdigit() - } - device_map["model.layers.0"] = "disk" - + device_map = _make_disk_offload_device_map(model) offload_dir = str(tmp_path / "offload") model = load_checkpoint_and_dispatch( model, tiny_llama_dir, device_map=device_map, offload_folder=offload_dir @@ -521,17 +515,8 @@ def _make_disk_offloaded_model(tmp_path, num_hidden_layers=3): return model, config, tiny_llama_dir, inputs -@pytest.mark.parametrize( - "quant_cfg", - [ - mtq.INT4_AWQ_CFG, - mtq.INT8_SMOOTHQUANT_CFG, - INT4_AWQ_CLIP_CFG, - mtq.NVFP4_SVDQUANT_DEFAULT_CFG, - mtq.INT8_DEFAULT_CFG, - ], -) -def test_disk_offloaded_tinyllama(tmp_path, quant_cfg): +def test_disk_offloaded_tinyllama(tmp_path): + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=2) config = AutoConfig.from_pretrained(tiny_llama_dir) @@ -547,13 +532,7 @@ def test_disk_offloaded_tinyllama(tmp_path, quant_cfg): with init_empty_weights(): model = AutoModelForCausalLM.from_config(config) - device_map = { - n: 0 - for n, m in model.named_modules() - if "layers" not in n or n.split("layers.")[-1].isdigit() - } - device_map["model.layers.0"] = "disk" - + device_map = _make_disk_offload_device_map(model) offload_dir = str(tmp_path / "offload") model = load_checkpoint_and_dispatch( 
model, tiny_llama_dir, device_map=device_map, offload_folder=offload_dir @@ -619,14 +598,10 @@ def test_persistent_materialization_disk_offloaded(tmp_path): assert torch.allclose(linear.weight, ref_weight + 1.0) -@pytest.mark.parametrize( - "quant_cfg", - [mtq.INT4_AWQ_CFG, mtq.NVFP4_DEFAULT_CFG], - ids=["int4_awq", "nvfp4"], -) @pytest.mark.parametrize("use_checkpoint", [False, True], ids=["no_ckpt", "ckpt"]) -def test_layerwise_calibrate_disk_offloaded(tmp_path, quant_cfg, use_checkpoint): +def test_layerwise_calibrate_disk_offloaded(tmp_path, use_checkpoint): """Layerwise calibration on disk-offloaded model matches GPU-only reference.""" + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG num_layers = 3 tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) config = AutoConfig.from_pretrained(tiny_llama_dir) @@ -649,12 +624,7 @@ def test_layerwise_calibrate_disk_offloaded(tmp_path, quant_cfg, use_checkpoint) # Test: disk-offloaded model with init_empty_weights(): model = AutoModelForCausalLM.from_config(config) - device_map = { - n: 0 - for n, m in model.named_modules() - if "layers" not in n or n.split("layers.")[-1].isdigit() - } - device_map["model.layers.0"] = "disk" + device_map = _make_disk_offload_device_map(model) offload_dir = str(tmp_path / "offload") model = load_checkpoint_and_dispatch( model, tiny_llama_dir, device_map=device_map, offload_folder=offload_dir From e9840adee227f574021272e4170ab471c1b03e61 Mon Sep 17 00:00:00 2001 From: realAsma Date: Thu, 16 Apr 2026 19:07:54 +0000 Subject: [PATCH 5/6] Address review feedback on layerwise calibration Replace TODO with explanatory comment noting that all currently implemented PTQ algorithms support layerwise calibration, and fix inaccurate "independently" wording in the layerwise docstring. 
Signed-off-by: realAsma --- modelopt/torch/quantization/config.py | 4 ++-- modelopt/torch/quantization/mode.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 0a52f0e866..3f24ac09a4 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1221,8 +1221,8 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): default=False, title="Enable layerwise (layer-by-layer) calibration.", description=( - "If True, the calibration algorithm is applied to each decoder layer independently. " - "Each layer's inputs are captured via a single forward pass that reflects the " + "If True, the calibration algorithm is applied layer by layer. " + "Each layer's inputs are captured via a forward pass that reflects the " "quantization of all preceding layers, incurring O(N) forward passes for N layers." ), ) diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 5b00308936..1328ef5821 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -239,7 +239,8 @@ def wrapped_calib_func( if func is not None: if layerwise: - # TODO: add a method guard here — not all calib methods support per-layer invocation + # All currently implemented PTQ algorithms support layerwise calibration; + # future algorithms that need full-model context must add a guard here. 
if forward_loop is None: raise ValueError("forward_loop is required for calibration but got None.") # Wrap with layerwise processing From 560a94f02865b12c10de1f3a50168bfdc1ca0a15 Mon Sep 17 00:00:00 2001 From: realAsma Date: Thu, 16 Apr 2026 21:28:48 +0000 Subject: [PATCH 6/6] Fix meta device detection in layerwise restore MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Compute layer_device and load checkpoint tensors inside enable_weight_access_and_writeback so params are materialized — otherwise get_module_device falls back to meta when no _hf_hook is attached directly on the layer (e.g. MoE with per-submodule hooks), leaving restored _amax buffers on meta. Signed-off-by: realAsma --- .../quantization/utils/layerwise_calib.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/modelopt/torch/quantization/utils/layerwise_calib.py b/modelopt/torch/quantization/utils/layerwise_calib.py index 6494d1a774..aed403ad87 100644 --- a/modelopt/torch/quantization/utils/layerwise_calib.py +++ b/modelopt/torch/quantization/utils/layerwise_calib.py @@ -618,24 +618,25 @@ def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None: name_to_module = dict(model.named_modules()) for i in range(self.start_layer): layer = layers[i] - layer_device = get_module_device(layer) d = _layer_dir(self.checkpoint_dir, i) - # weights_only=False is safe: files are internally generated by _save_layer - qstate = torch.load( - os.path.join(d, "quantizer_state.pt"), - map_location=layer_device, - weights_only=False, - ) - weights = torch.load( - os.path.join(d, "weights.pt"), map_location=layer_device, weights_only=False - ) - - # Restore inside the access context so offloaded weights are - # materialized and written back after modification. + # Resolve layer_device and load inside the context so params are + # materialized — otherwise get_module_device can return meta. 
with enable_weight_access_and_writeback(layer, model, name_to_module): + layer_device = get_module_device(layer) + # weights_only=False is safe: files are internally generated by _save_layer + qstate = torch.load( + os.path.join(d, "quantizer_state.pt"), + map_location=layer_device, + weights_only=False, + ) + weights = torch.load( + os.path.join(d, "weights.pt"), + map_location=layer_device, + weights_only=False, + ) restore_quantizer_state(layer, dummy_config, {"quantizer_state": qstate}) - layer.load_state_dict(weights, strict=False) + layer.load_state_dict(weights, strict=False, assign=True) print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers")