diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5c189bd28b..7fba13a529 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,7 @@ Changelog - Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml `_ for more details. - Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md `_ for more details. - [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution. +- Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml `_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml `_ for usage. 
**Backward Breaking Changes** diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index c2d4d4bfca..90532efe38 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -15,6 +15,7 @@ import copy import glob +import hashlib import inspect import json import logging @@ -854,3 +855,35 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod print(f"Successfully copied {len(copied_files)} custom model files to {export_path}") else: print("No custom model files found to copy") + + +def needs_checkpoint_path_update(quant_cfg: dict) -> bool: + """Check if quant_cfg has a layerwise_checkpoint_dir that should be auto-resolved to a unique subpath.""" + algorithm = quant_cfg.get("algorithm") + if not isinstance(algorithm, dict): + return False + return algorithm.get("layerwise_checkpoint_dir") is not None + + +def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict: + """Append a unique ``_`` subdirectory to layerwise_checkpoint_dir. + + Allows a single recipe to be reused across models without checkpoint collisions. + Must only be called when :func:`needs_checkpoint_path_update` returns True. 
+ """ + algorithm = quant_cfg["algorithm"] + base_dir = algorithm["layerwise_checkpoint_dir"] + + name = model_path.rstrip("/") + if "/" in name and not os.path.isabs(name): + name = name.replace("/", "--") + else: + name = Path(name).name + + config_hash = hashlib.sha256(json.dumps(quant_cfg, default=str).encode()).hexdigest()[:8] + + quant_cfg = copy.deepcopy(quant_cfg) + quant_cfg["algorithm"]["layerwise_checkpoint_dir"] = os.path.join( + base_dir, f"{name}_{config_hash}" + ) + return quant_cfg diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 327605406c..c405de51e7 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -34,6 +34,8 @@ is_enc_dec, is_nemotron_vl, load_mtp_weights, + needs_checkpoint_path_update, + resolve_checkpoint_dir, run_nemotron_vl_preview, ) from torch.utils.data import DataLoader @@ -91,8 +93,9 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: for i, entry in enumerate(quant_cfg): if entry.get("quantizer_name") != "*[kv]_bmm_quantizer": continue - assert isinstance(entry.get("cfg", {}), dict) - quant_cfg[i] = {**entry, "cfg": {**entry.get("cfg", {}), "use_constant_amax": True}} + cfg = entry.get("cfg") or {} + assert isinstance(cfg, dict) + quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}} break @@ -759,7 +762,9 @@ def export_quantized( # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization if args.vllm_fakequant_export: - export_hf_vllm_fq_checkpoint(full_model, export_dir=export_path) + export_hf_vllm_fq_checkpoint( + full_model, export_dir=export_path, inplace_mem_efficient=True + ) else: mtp_layer_prefixes, mtp_state_dict = load_mtp_weights( full_model, args.pyt_ckpt_path @@ -1104,6 +1109,12 @@ def quantize_main( quant_cfg = copy.deepcopy(quant_cfg) _set_kv_cache_constant_amax(quant_cfg["quant_cfg"]) + if 
needs_checkpoint_path_update(quant_cfg): + quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path) + print( + f"Auto-resolved layerwise_checkpoint_dir: {quant_cfg['algorithm']['layerwise_checkpoint_dir']}" + ) + if args.qformat in QUANT_CFG_CHOICES: mono_quantize( args, diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md index 3d96ebb46a..6513b5b04d 100644 --- a/examples/vllm_serve/README.md +++ b/examples/vllm_serve/README.md @@ -98,6 +98,5 @@ QUANT_CFG= QUANT_FILE_PATH= python vllm_serve_fa ## Known Problems 1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won’t align). -2. AWQ reload is not supported yet -3. KV cache quantization export and reload is not supported in MCore yet. -4. **`NVFP4_KV_CFG` and `NVFP4_AFFINE_KV_CFG` require `--enforce-eager`**; these configs use a dynamic-block Triton kernel for KV-cache quantization that is incompatible with CUDA graph capture (the kernel grid is computed from Python-level tensor shapes, which get baked in at capture time). Without `--enforce-eager`, the captured grid will be wrong for different batch sizes, producing incorrect outputs. +2. KV cache quantization export and reload is not supported in MCore yet. +3. **`NVFP4_KV_CFG` and `NVFP4_AFFINE_KV_CFG` require `--enforce-eager`**; these configs use a dynamic-block Triton kernel for KV-cache quantization that is incompatible with CUDA graph capture (the kernel grid is computed from Python-level tensor shapes, which get baked in at capture time). Without `--enforce-eager`, the captured grid will be wrong for different batch sizes, producing incorrect outputs. 
diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index b88af9c72e..4f84df0581 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -15,6 +15,7 @@ import os +import warnings from typing import Any import torch @@ -26,13 +27,16 @@ convert_modelopt_state_to_vllm, load_state_dict_from_path, restore_from_modelopt_state_vllm, + shard_pre_quant_scale_for_tp, ) import modelopt.torch.quantization as mtq +from modelopt.torch.export.plugins.vllm_fakequant_hf import is_weight_quantizer_state_key from modelopt.torch.quantization.plugins.vllm import ( disable_compilation, post_restore_vllm_parallel_linears, ) +from modelopt.torch.utils import safe_load from modelopt.torch.utils.dataset_utils import get_dataset_dataloader quant_config: dict[str, Any] = { @@ -61,28 +65,48 @@ def _fakequant_run_prolog_worker(self) -> None: model = model.unwrap() if quant_config["modelopt_state_path"]: print(f"Loading modelopt state from {quant_config['modelopt_state_path']}") - # Load on CPU to avoid failures when the checkpoint was saved from a different - # GPU mapping - modelopt_state = torch.load( - quant_config["modelopt_state_path"], weights_only=True, map_location="cpu" - ) + # Load on CPU to avoid failures when the checkpoint was saved from a different GPU mapping. 
+ modelopt_state = safe_load(quant_config["modelopt_state_path"], map_location="cpu") modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) map_fun = ( self.model_runner.model.hf_to_vllm_mapper.apply_dict if hasattr(self.model_runner.model, "hf_to_vllm_mapper") else None ) - # convert modelopt state to vllm format modelopt_state = convert_modelopt_state_to_vllm(modelopt_state, map_fun=map_fun) - # restore model from modelopt state restore_from_modelopt_state_vllm(model, modelopt_state) if modelopt_weights is not None: - # convert quantizer state values to vllm format modelopt_weights = convert_dict_to_vllm(modelopt_weights, map_fun=map_fun) mtq.utils.set_quantizer_state_dict(model, modelopt_weights) - # set_quantizer_state_dict does not invoke modelopt_post_restore (unlike restore_quantizer_state). + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + from modelopt.torch.quantization.nn import TensorQuantizer + from modelopt.torch.utils import get_unwrapped_name + + loaded_keys = { + get_unwrapped_name(n, model) + for n, m in model.named_modules() + if isinstance(m, TensorQuantizer) + } + # Same namespace as ``loaded_keys``: checkpoint keys may include DDP/FSDP + # prefixes that ``convert_dict_to_vllm`` does not strip. + pqs_in_weights = { + get_unwrapped_name(k, model) + for k, v in modelopt_weights.items() + if isinstance(v, dict) and "_pre_quant_scale" in v + } + unmatched_pqs = pqs_in_weights - loaded_keys + if unmatched_pqs: + sample = sorted(unmatched_pqs)[:20] + warnings.warn( + f"{len(unmatched_pqs)} checkpoint pre_quant_scale key(s) have no " + f"matching TensorQuantizer in the model (showing up to 20): {sample}", + stacklevel=2, + ) + # set_quantizer_state_dict does not run modelopt_post_restore (unlike restore_quantizer_state). post_restore_vllm_parallel_linears(model) + # Must follow post_restore: shard_pre_quant_scale_for_tp uses weight H_in vs pqs length. 
+ shard_pre_quant_scale_for_tp(model) else: if quant_config["quant_file_path"]: @@ -101,15 +125,13 @@ def _fakequant_run_prolog_worker(self) -> None: quant_cfg = get_quant_config(quant_config, model) - # quantize model with disable_compilation(model): print("Quantizing model...") mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) quantizer_file_path = quant_config["quant_file_path"] if quantizer_file_path: - # Get amax and other quantizer state from the quantizer file - # this can be used with Megatron-LM exported model using export_mcore_gpt_to_hf_vllm_fq + self.model_runner._dummy_run(1) current_state_dict = load_state_dict_from_path(self, quantizer_file_path, model) model.load_state_dict(current_state_dict) @@ -122,8 +144,11 @@ def _fakequant_run_prolog_worker(self) -> None: mtq.fold_weight(model) for name, module in model.named_modules(): - if name.endswith("weight_quantizer"): - assert not module.is_enabled, f"quantizer {name} is still enabled" + if is_weight_quantizer_state_key(name) and module.is_enabled: + raise RuntimeError( + f"Weight quantizer {name!r} is still enabled after fold_weight — " + "double-quantization would corrupt activations." 
+ ) class FakeQuantWorker(BaseWorker): diff --git a/examples/vllm_serve/vllm_reload_utils.py b/examples/vllm_serve/vllm_reload_utils.py index 2b59d1be2b..aa8d3a5388 100644 --- a/examples/vllm_serve/vllm_reload_utils.py +++ b/examples/vllm_serve/vllm_reload_utils.py @@ -22,6 +22,11 @@ import torch from vllm.distributed.parallel_state import get_tp_group +from modelopt.torch.export.plugins.vllm_fakequant_hf import ( + infer_quantizer_prefix_remap, + is_weight_quantizer_state_key, + merge_amax_tensors_for_group, +) from modelopt.torch.opt.conversion import ( ModelLikeModule, ModeloptStateManager, @@ -84,7 +89,7 @@ def _convert_key_for_vllm(key: str, value: Any) -> tuple[str, str | None, Any]: if "quantizer" not in key: return ("copy", key, value) - # Skip softmax_quantizer and lm_head quantizers(not needed in vLLM) + # Skip softmax_quantizer and lm_head quantizers (not needed in vLLM). if "softmax_quantizer" in key or (key.startswith("lm_head.") and "quantizer" in key): return ("skip", None, None) @@ -95,8 +100,7 @@ def _convert_key_for_vllm(key: str, value: Any) -> tuple[str, str | None, Any]: group_key = qkv_match.group(1) + "qkv_proj." + qkv_match.group(3) + suffix return ("group", group_key, value) - # Check if this is an expert gate/up projection - # if "mixer" not in key: + # Expert gate/up (per-expert) → w13 merge expert_gate_up_match = re.search( r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer)(\..+)?$", key ) @@ -113,8 +117,6 @@ def _convert_key_for_vllm(key: str, value: Any) -> tuple[str, str | None, Any]: group_key = gate_up_match.group(1) + "gate_up_proj." 
+ gate_up_match.group(3) + suffix return ("group", group_key, value) - # Check if this is an expert down_proj - # if "mixer" not in key: expert_down_match = re.search(r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer)(\..+)?$", key) if expert_down_match: suffix = expert_down_match.group(3) or "" @@ -148,9 +150,10 @@ def _group_keys_for_vllm( for key, value in state_dict.items(): action, new_key, new_value = _convert_key_for_vllm(key, value) if new_key is None or new_value is None: - assert action == "skip", ( - f"Expected action to be 'skip' for key {key}, value {value}, got {action}" - ) + if action != "skip": + raise RuntimeError( + f"Expected action to be 'skip' for key {key}, value {value}, got {action}" + ) continue if action == "copy": vllm_state_dict[new_key] = new_value @@ -176,7 +179,7 @@ def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[ for dict_key in values[0]: tensors = [v[dict_key] for v in values] if "_amax" in dict_key: - merged_value[dict_key] = torch.stack(tensors).max(dim=0)[0] + merged_value[dict_key] = merge_amax_tensors_for_group(tensors) elif "_pre_quant_scale" in dict_key: # _pre_quant_scale is per-input-channel: identical across q/k/v projections # since they share the same input. Do not concatenate; take the first value. @@ -187,7 +190,7 @@ def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[ else: # Values are tensors directly if "_amax" in merged_key: - merged_value = torch.stack(values).max(dim=0)[0] + merged_value = merge_amax_tensors_for_group(values) else: merged_value = torch.cat(values, dim=0) return merged_value @@ -231,6 +234,25 @@ def convert_dict_to_vllm( max_or_concat: Whether to merge grouped values by taking max/concatenate or require identical map_fun: Function to map the state dict to vLLM format """ + # If map_fun is provided, pre-transform quantizer key module-path prefixes so that + # HF→vLLM model renames (e.g. 
backbone.layers → model.layers) are applied before + # key grouping (q/k/v → qkv, experts.N.up_proj → experts.w13, etc.). + # This is necessary for models where the HF root module differs from vLLM's (e.g. + # NemotronH uses backbone.layers in HF but model.layers in vLLM), and for + # modelopt_state_weights where ALL keys are quantizer keys so map_fun is never + # invoked on non-quantizer keys. + if map_fun is not None: + q_only = {k: v for k, v in state_dict.items() if "_quantizer" in k} + prefix_remap = infer_quantizer_prefix_remap(q_only, map_fun) + if prefix_remap: + renamed = {} + for k, v in state_dict.items(): + if "_quantizer" in k: + first = k.split(".")[0] + k = prefix_remap.get(first, first) + k[len(first) :] + renamed[k] = v + state_dict = renamed + vllm_state_dict, merge_groups = _group_keys_for_vllm(state_dict) merge_fn = _merge_values_by_max_or_concat if max_or_concat else _merge_values_require_identical @@ -340,7 +362,26 @@ def _has_buffers(state: dict) -> bool: } # Add state for quantizers in model but not in metadata (e.g. disabled/excluded) for k in model_keys - filtered.keys(): - filtered[k] = model_qstate[k] + state = model_qstate[k] + # Weight quantizers absent from exported metadata were disabled during export + # (weights are already fake-quantized and pre_quant_scale is folded in). + # Keep them disabled on reload so fold_weight does not re-quantize the + # already-folded weights (re-quantizing distorts the pqs-scaled values). + if is_weight_quantizer_state_key(k) and not state.get("_disabled"): + state = {**state, "_disabled": True} + filtered[k] = state + + # Invariant: weight quantizers absent from export must be _disabled. 
+ for wq_k in model_keys: + if not is_weight_quantizer_state_key(wq_k): + continue + wq_state = filtered[wq_k] + if wq_k not in saved and not wq_state.get("_disabled"): + raise RuntimeError( + f"Weight quantizer {wq_k!r} is missing from saved quantizer_state but " + f"is not marked _disabled (got _disabled={wq_state.get('_disabled')!r}). " + f"vLLM fakequant export omits weight quantizer keys when weights are folded." + ) metadata["quantizer_state"] = filtered @@ -379,10 +420,123 @@ def restore_from_modelopt_state_vllm( if not manager.has_state and isinstance(model, ModelLikeModule): model = model.init_modellike() - assert not isinstance(model, ModelLikeModule), "Model must be a regular Module now!" + if isinstance(model, ModelLikeModule): + raise RuntimeError("Model must be a regular Module after restore, got ModelLikeModule") return model +def _tp_concat_shard_dims( + value_shape: tuple[int, ...], + expected_shape: tuple[int, ...], + tp_world_size: int, +) -> list[int]: + """Dims ``d`` where checkpoint looks like TP concat: ``value[d] == expected[d] * tp_world_size``.""" + return [ + d for d in range(len(expected_shape)) if value_shape[d] == expected_shape[d] * tp_world_size + ] + + +def _narrow_tensor_to_tp_local_shard( + value: torch.Tensor, + expected_shape: tuple[int, ...] 
| torch.Size, + tp_rank: int, + tp_world_size: int, + *, + context: str, +) -> torch.Tensor: + """Slice ``value`` to this TP rank when it is the concat of per-rank shards along one dim.""" + value_shape = value.shape + expected_shape = tuple(expected_shape) + if value_shape == expected_shape: + return value + if len(value_shape) != len(expected_shape): + raise ValueError( + f"{context}: rank mismatch (checkpoint={tuple(value_shape)}, expected={tuple(expected_shape)})" + ) + shard_dims = _tp_concat_shard_dims(value_shape, expected_shape, tp_world_size) + if len(shard_dims) != 1: + raise ValueError( + f"{context}: cannot infer TP shard dim " + f"(expected={tuple(expected_shape)}, checkpoint={tuple(value_shape)}, tp={tp_world_size})" + ) + d = shard_dims[0] + shard_size = expected_shape[d] + start = tp_rank * shard_size + if start + shard_size > value_shape[d]: + raise ValueError( + f"{context}: TP shard out of bounds " + f"(expected={tuple(expected_shape)}, checkpoint={tuple(value_shape)})" + ) + return value.narrow(d, start, shard_size).contiguous() + + +def _pqs_local_expected_shape(pqs: torch.Tensor, expected_in: int) -> tuple[int, ...] | None: + """Local per-rank shape for ``_pre_quant_scale`` (1-D ``[H]`` or broadcast 2-D ``[1, H]``).""" + if pqs.ndim == 1: + return (expected_in,) + if pqs.ndim == 2 and pqs.shape[0] == 1: + return (1, expected_in) + return None + + +def _expected_in_features_for_input_quantizer(parent: Any, input_quantizer_attr: str) -> int | None: + """Input feature count for the weight paired with ``*_input_quantizer`` (Linear or FusedMoE).""" + stem = input_quantizer_attr[: -len("_input_quantizer")] + w = getattr(parent, (stem + "_weight") if stem else "weight", None) + if w is None or not isinstance(w, torch.Tensor) or w.is_meta: + return None + return int(w.shape[-1] if w.ndim == 3 else w.shape[1]) + + +def shard_pre_quant_scale_for_tp(model: Any) -> None: + """Shard ``_pre_quant_scale`` in-place for the local TP rank (row-parallel inputs). 
+ + HF exports often store full (unsharded) scales; after load, row-parallel layers need + ``pqs`` narrowed to ``H_in / tp`` when ``len(pqs) == H_in * tp_world_size``. + + Call after parallel linear modules expose TP-sharded weight shapes (e.g. + ``post_restore_vllm_parallel_linears``). If run earlier, ``expected_in`` inferred from + weights can match an unsharded checkpoint and a second call becomes a no-op even when + pqs should still be narrowed. + + Args: + model: vLLM model with ``TensorQuantizer`` submodules. + """ + from modelopt.torch.quantization.nn import TensorQuantizer + + tp_group = get_tp_group() + tp_rank, tp_world_size = tp_group.rank_in_group, tp_group.world_size + if tp_world_size == 1: + return + + for qname, quantizer in model.named_modules(): + if not isinstance(quantizer, TensorQuantizer): + continue + pqs = getattr(quantizer, "_pre_quant_scale", None) + if pqs is None: + continue + last = qname.rfind(".") + if last == -1 or not qname[last + 1 :].endswith("input_quantizer"): + continue + try: + parent = model.get_submodule(qname[:last]) + except (AttributeError, LookupError): + continue + expected_in = _expected_in_features_for_input_quantizer(parent, qname[last + 1 :]) + if expected_in is None: + continue + expected_shape = _pqs_local_expected_shape(pqs, expected_in) + if expected_shape is None: + continue + quantizer._pre_quant_scale = _narrow_tensor_to_tp_local_shard( + pqs, + expected_shape, + tp_rank, + tp_world_size, + context=f"{qname}._pre_quant_scale", + ) + + def process_state_dict_for_tp(saved_qstate_dict, current_state_dict): """Shard quantizer tensors for tensor parallelism by matching expected shapes.""" tp_group = get_tp_group() @@ -393,42 +547,14 @@ def process_state_dict_for_tp(saved_qstate_dict, current_state_dict): for key, value in saved_qstate_dict.items(): if key in current_state_dict: expected = current_state_dict[key] - if not hasattr(value, "shape") or not hasattr(expected, "shape"): - result[key] = value - continue - 
expected_shape = expected.shape - value_shape = value.shape - if value_shape != expected_shape: - # Verify compatible rank before indexing - if len(value_shape) != len(expected_shape): - raise ValueError( - f"Cannot infer TP shard dim for {key}: rank mismatch " - f"(checkpoint rank={len(value_shape)}, expected rank={len(expected_shape)})" - ) - # Find the dimension that was tensor-parallel sharded. - # We expect exactly one dimension to satisfy: - # checkpoint_dim == expected_dim * tp_world_size - shard_dims = [ - d - for d in range(len(expected_shape)) - if value_shape[d] == expected_shape[d] * tp_world_size - ] - if len(shard_dims) != 1: - raise ValueError( - f"Cannot infer TP shard dim for {key}: " - f"expected_shape={tuple(expected_shape)}, checkpoint_shape={tuple(value_shape)}" - ) - - shard_dim = shard_dims[0] - shard_size = expected_shape[shard_dim] - start = tp_rank * shard_size - end = start + shard_size - if end > value_shape[shard_dim]: - raise ValueError( - f"TP shard out of bounds for {key}: " - f"expected_shape={tuple(expected_shape)}, checkpoint_shape={tuple(value_shape)}" - ) - value = value.narrow(shard_dim, start, shard_size).contiguous() + if hasattr(value, "shape") and hasattr(expected, "shape"): + value = _narrow_tensor_to_tp_local_shard( + value, + expected.shape, + tp_rank, + tp_world_size, + context=f"Key {key!r}", + ) result[key] = value return result @@ -437,12 +563,8 @@ def process_state_dict_for_tp(saved_qstate_dict, current_state_dict): def load_state_dict_from_path( fakequant_runner: Any, quantizer_file_path: str, model: Any ) -> dict[str, Any]: - fakequant_runner.model_runner._dummy_run(1) - print(f"Loading quantizer values from {quantizer_file_path}") - # Load on CPU to avoid failures when the checkpoint was saved from a different - # GPU mapping + # Load on CPU to avoid failures when the checkpoint was saved from a different GPU mapping. 
saved_quant_dict = torch.load(quantizer_file_path, weights_only=True, map_location="cpu") - # convert quant keys to vLLM format if hasattr(fakequant_runner.model_runner.model, "hf_to_vllm_mapper"): saved_quant_dict = fakequant_runner.model_runner.model.hf_to_vllm_mapper.apply_dict( saved_quant_dict @@ -455,7 +577,6 @@ def load_state_dict_from_path( saved_quant_dict = convert_dict_to_vllm(saved_quant_dict) current_state_dict = model.state_dict() - # Count quant keys in checkpoint and model checkpoint_quant_keys = [key for key in saved_quant_dict if "quantizer" in key] model_quant_keys = [key for key in current_state_dict if "quantizer" in key] ckpt_key_set = set(checkpoint_quant_keys) diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py index 7726bf61af..e8ee5afd45 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -81,8 +81,16 @@ has_mcore = True -def get_experts_list(module: torch.nn.Module, model_type: str): - """Returns list of grouped experts by linear name for given module.""" +def get_experts_list( + module: torch.nn.Module, + model_type: str, +): + """Returns list of grouped experts by linear name for given module. + + Args: + module: MoE block (e.g. MixtralSparseMoeBlock, NemotronHMOE). + model_type: `type(root_model).__name__.lower()` (may change after ModelOpt quantize). 
+ """ experts_list = [] # Define linear layer names for different model types @@ -98,6 +106,8 @@ def get_experts_list(module: torch.nn.Module, model_type: str): ] ): linear_names = ["gate_proj", "down_proj", "up_proj"] + elif "nemotronhforcausallm" in model_type: + linear_names = ["up_proj", "down_proj"] else: raise NotImplementedError(f" {model_type} not supported") @@ -305,7 +315,7 @@ def is_moe(module: nn.Module) -> bool: if name.endswith("sparsemoeblock") or "moelayer" in name: return True # Explicit matches for non-standard naming - return any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn"]) + return any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn", "nemotronhmoe"]) def is_quantlinear(module: nn.Module) -> bool: @@ -994,6 +1004,9 @@ def module_match_name_list(module, name_list): return ["w1_linear", "w2_linear", "v1_linear"] elif module_match_name_list(module, ["GptOssMoE"]): return ["gate_up_proj", "down_proj"] + elif module_match_name_list(module, ["NemotronHMOE"]): + # NemotronHMOE experts (NemotronHMLP) use up_proj and down_proj only (no gate). + return ["up_proj", "down_proj"] else: # assuming w1, w2, w3 by default return ["w1", "w2", "w3"] diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 1908354a0a..42baad912b 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -14,7 +14,14 @@ # limitations under the License. 
"""Export HuggingFace model to vLLM fakequant checkpoint.""" +import copy +import logging +import re +import warnings +from collections.abc import Callable +from contextlib import ExitStack, contextmanager from pathlib import Path +from typing import Any import torch import torch.nn as nn @@ -22,11 +29,108 @@ import modelopt.torch.opt as mto from modelopt.torch.quantization.config import RotateConfig from modelopt.torch.quantization.conversion import quantizer_state -from modelopt.torch.quantization.nn import QuantModule, TensorQuantizer +from modelopt.torch.quantization.model_calib import enable_stats_collection, finish_stats_collection +from modelopt.torch.quantization.nn import QuantModule, SequentialQuantizer, TensorQuantizer from modelopt.torch.quantization.utils import get_quantizer_state_dict -from modelopt.torch.utils import get_unwrapped_name +from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector +from modelopt.torch.utils import get_unwrapped_name, safe_save -__all__ = ["export_hf_vllm_fq_checkpoint"] +from ..layer_utils import get_experts_list, is_moe +from ..quant_utils import get_quantization_format +from ..unified_export_hf import collect_shared_input_modules + +__all__ = [ + "export_hf_vllm_fq_checkpoint", + "infer_quantizer_prefix_remap", + "is_weight_quantizer_state_key", + "merge_amax_tensors_for_group", +] + +# Matches ``…weight_quantizer``, ``…weight_quantizer.0``, ``…w13_weight_quantizer.0``, etc. +_WEIGHT_QUANTIZER_STATE_KEY = re.compile(r"(?:^|\.)(?:\w+_)?weight_quantizer(?:\.\d+)*$") + + +def is_weight_quantizer_state_key(key: str) -> bool: + """Return True for weight-quantizer state keys, including SequentialQuantizer entries. + + Matches ``weight_quantizer``, ``w13_weight_quantizer``, ``weight_quantizer.0``, etc. 
+ """ + return bool(_WEIGHT_QUANTIZER_STATE_KEY.search(key)) + + +def infer_quantizer_prefix_remap( + quantizer_keys: dict[str, Any], + map_fun: Callable[[dict[str, Any]], dict[str, Any]], +) -> dict[str, str]: + """Infer HF root name → vLLM root (e.g. ``backbone`` → ``model``) for reload/export. + + Map HF root → vLLM root (e.g. ``backbone`` → ``model``) by probing ``map_fun`` with + synthetic ``.weight`` keys and a 2-D placeholder (quantizer paths are not weight + keys). Keys under the same HF root must agree on the target root or :exc:`ValueError` is + raised; failed probes are skipped. Returns ``{hf_root: vllm_root}`` only where the root + renames; not for arbitrary layer rewrites. + + Args: + quantizer_keys: HF quantizer state paths as keys (values unused). + map_fun: HF→vLLM weight ``state_dict`` mapper, same as for ``convert_dict_to_vllm``. + + Returns: + ``{hf_root: vllm_root}`` for roots that rename; omits identity pairs. + """ + logger = logging.getLogger(__name__) + probe_weight = torch.empty((1, 1)) + observed_vllm_root: dict[str, str] = {} + + for key in quantizer_keys: + first_component = key.split(".")[0] + last_dot = key.rfind(".") + if last_dot == -1: + continue + probe_key = key[:last_dot] + ".weight" + try: + result = map_fun({probe_key: probe_weight}) + if not result: + continue + new_key = next(iter(result)) + new_first = new_key.split(".")[0] + except Exception as e: + logger.debug("prefix-remap probe failed for %r: %s", probe_key, e) + continue + + if first_component not in observed_vllm_root: + observed_vllm_root[first_component] = new_first + elif observed_vllm_root[first_component] != new_first: + raise ValueError( + "Inconsistent HF→vLLM prefix remap for " + f"{first_component!r}: probes implied " + f"{observed_vllm_root[first_component]!r} and {new_first!r}. " + "map_fun must apply one target root per HF root, or use explicit quantizer " + "key remapping." 
+ ) + + return { + hf_root: vllm_root + for hf_root, vllm_root in observed_vllm_root.items() + if hf_root != vllm_root + } + + +def _check_all_weight_quantizers_disabled(model: nn.Module) -> None: + """Export invariant before writing metadata: every weight quantizer must be off.""" + for _, module in model.named_modules(): + if not isinstance(module, QuantModule): + continue + for attr_name, quantizer in module.named_children(): + if attr_name.endswith("weight_quantizer") and isinstance( + quantizer, (TensorQuantizer, SequentialQuantizer) + ): + if quantizer.is_enabled: + raise RuntimeError( + f"vLLM fakequant export: {attr_name!r} must be disabled before saving " + f"quantizer_state (weights already folded). " + f"See filter_modelopt_state_quantizer_state_for_model in vllm_reload_utils." + ) def disable_rotate(quantizer: TensorQuantizer): @@ -38,77 +142,429 @@ def disable_rotate(quantizer: TensorQuantizer): return False +def _fakequant_module_weights( + module: nn.Module, + module_name: str, + model: nn.Module, + state_dict: dict | None, + input_quantizers_folded_pqs: set, + fakequant_weights: set, + requant_weights: set[str], + inplace: bool, +): + """Apply fake-quant to a single QuantModule's weights. + + When ``inplace=False``, reads/writes weights from/to ``state_dict``. + When ``inplace=True``, modifies the module's weight parameters directly. + """ + if not isinstance(module, QuantModule): + return + for attr_name, quantizer in module.named_children(): + if not ( + attr_name.endswith("weight_quantizer") + and isinstance(quantizer, TensorQuantizer) + and quantizer.fake_quant + and quantizer.is_enabled + ): + continue + weight_name = attr_name.removesuffix("_quantizer") + prefix = f"{module_name}." 
if module_name else "" + sd_key = f"{prefix}{weight_name}" + if sd_key in fakequant_weights: + raise RuntimeError(f"Weight {sd_key} has already been fakequantized") + + if inplace: + w = getattr(module, weight_name) + if sd_key in requant_weights: + w_quant = requant_weights_for_export(quantizer, w, copy_quantizer=False) + else: + w_quant = quantizer(w.float()).to(w.dtype) + else: + if state_dict is None: + raise RuntimeError("state_dict is required when inplace=False for fakequant export") + if sd_key not in state_dict: + continue + w = state_dict[sd_key] + if sd_key in requant_weights: + w_quant = requant_weights_for_export(quantizer, w) + else: + w_quant = quantizer(w.float()).to(w.dtype) + + # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) + # Only valid when input_quantizer does NOT fake-quant activations. If it does + # fake_quant(x*s), the non-linearity prevents folding s into W. + inp_attr = attr_name.replace("weight_quantizer", "input_quantizer") + if hasattr(module, inp_attr): + inp_q = getattr(module, inp_attr) + if ( + hasattr(inp_q, "_pre_quant_scale") + and inp_q._pre_quant_scale is not None + and not inp_q.is_enabled + ): + scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device) + w_quant = (w_quant * scale[None, :]).to(w_quant.dtype) + inp_q_key = get_unwrapped_name( + f"{module_name}.{inp_attr}" if module_name else inp_attr, model + ) + input_quantizers_folded_pqs.add(inp_q_key) + + if inplace: + w.data.copy_(w_quant) + else: + if state_dict is None: + raise RuntimeError("state_dict is required when inplace=False for fakequant export") + state_dict[sd_key] = w_quant.cpu() + fakequant_weights.add(sd_key) + + +def _collect_group_pre_quant_scales( + experts: list[nn.Module], +) -> list[torch.Tensor] | None: + """Return per-expert ``pre_quant_scale`` tensors if every expert can be averaged; else None. + + Skips groups where any expert has no input quantizer, no pqs (e.g. 
weight-only AWQ INT4), + or a disabled input quantizer (pqs already folded / not used). + """ + pre_quant_scales: list[torch.Tensor] = [] + for expert_module in experts: + input_quantizer = getattr(expert_module, "input_quantizer", None) + if ( + input_quantizer is None + or not input_quantizer.is_enabled + or input_quantizer.pre_quant_scale is None + ): + return None + pre_quant_scales.append(input_quantizer.pre_quant_scale) + return pre_quant_scales + + +def requant_weights_for_export( + quantizer: TensorQuantizer | SequentialQuantizer, + weight: torch.Tensor, + copy_quantizer: bool = True, +) -> torch.Tensor: + """Requantize folded weights after resmooth (``TensorQuantizer`` or ``SequentialQuantizer``). + + A single ``TensorQuantizer`` is treated as a one-stage chain so the same + calibrate-then-apply steps cover W4A8-style sequential weights (e.g. INT4→FP8). + + Deepcopy may leave buffers on the original device; ``.to(device=w.device)`` aligns with + ``w`` (e.g. CPU offload). + """ + if copy_quantizer: + copied = copy.deepcopy(quantizer).to(device=weight.device) + else: + copied = quantizer + quantizers: list[TensorQuantizer] = ( + list(copied) if isinstance(copied, SequentialQuantizer) else [copied] + ) + + for quantizer_copy in quantizers: + quantizer_copy.eval() + quantizer_copy.reset_amax() + enable_stats_collection(quantizer_copy) + weight_quantized = weight + for quantizer_copy in quantizers: + weight_quantized = quantizer_copy(weight_quantized) + for quantizer_copy in quantizers: + finish_stats_collection(quantizer_copy) + # Re-run application pass to get the quantized output with the freshly collected amax. + # The calibration forward above only collected stats; its output is intentionally discarded. 
+ weight_quantized = weight + for quantizer_copy in quantizers: + weight_quantized = quantizer_copy(weight_quantized) + return weight_quantized.to(weight.dtype) + + +def merge_amax_tensors_for_group(tensors: list[torch.Tensor]) -> torch.Tensor: + """Combine `_amax` buffers from a merge group into a single tensor. + + Used when HuggingFace module names are folded to vLLM names (e.g. q/k/v → qkv_proj). + + - If every tensor has the same shape, take the element-wise maximum over the group + (conservative when each branch carried the same axis layout). + - If shapes differ: ``torch.cat(..., dim=0)`` assumes **1D per-channel** amaxes in + fused order (e.g. GQA q/k/v → ``[N_q]`` + ``[N_kv]`` + ``[N_kv]``), matching vLLM’s + grouped quantizer. Not valid for 2D blockwise amax; on failure, **scalar** + max (drops channel structure). + """ + if not tensors: + raise ValueError("merge_amax_tensors_for_group: expected at least one tensor") + if len(tensors) == 1: + return tensors[0] + + first = tensors[0] + if all(t.shape == first.shape for t in tensors): + stacked = torch.stack([t.float() for t in tensors], dim=0) + return torch.amax(stacked, dim=0).to(dtype=first.dtype, device=first.device) + + try: + return torch.cat(tensors, dim=0).to(dtype=first.dtype, device=first.device) + except RuntimeError: + shapes = [tuple(t.shape) for t in tensors] + warnings.warn( + f"merge_amax_tensors_for_group: torch.cat failed for shapes {shapes}; " + "falling back to scalar max which loses per-channel amax structure.", + stacklevel=2, + ) + flat = torch.cat([t.reshape(-1).float() for t in tensors]) + return torch.max(flat).to(dtype=first.dtype, device=first.device) + + +@contextmanager +def _enable_writeback_for_group( + group: list[nn.Module], + root_model: nn.Module, + name_to_module: dict[str, nn.Module], +): + """Nest ``enable_weight_access_and_writeback`` for every module in ``group`` (one ``with``). 
+ + The stdlib pattern for a *variable* number of context managers is :class:`ExitStack`; + wrapping it here keeps call sites readable. + """ + with ExitStack() as stack: + for m in group: + stack.enter_context(enable_weight_access_and_writeback(m, root_model, name_to_module)) + yield + + +def _resmooth_experts_for_export( + model: nn.Module, + state_dict: dict[str, Any] | None, + *, + inplace: bool = False, +) -> tuple[dict[str, tuple[torch.Tensor, torch.Tensor | None]], set[str]]: + """Prepare AWQ weights for vLLM fakequant export when several linears share one input quantizer. + + PTQ can assign a different ``pre_quant_scale`` per branch (per expert, or per + q/k/v projection) even though they see the same activation. vLLM’s fused kernels expose a + **single** input quantizer for that fused group, so reload must use one scale — otherwise + activations are scaled wrong for k/v or non-primary experts. + + For each group (MoE experts via ``get_experts_list``; dense shared-input linears + via ``collect_shared_input_modules`` / hooks), average ``pre_quant_scale``, set weights to + ``W' = W * old_pqs / avg_pqs`` so the net is unchanged, merge input ``amax`` where needed, + and return per-``input_quantizer`` tensor overrides for ``modelopt_state_weights``. + + Runs only for AWQ with **enabled** input quantizers (e.g. activation-aware); if inputs are + disabled and PQS was folded into weights only, there is nothing to unify. + + ``inplace=False`` — adjust a detached ``state_dict`` copy (``state_dict`` required). + ``inplace=True`` — pass ``state_dict=None``; update live ``nn.Parameter`` data under + ``_enable_writeback_for_group`` (nested writeback per module so offloaded/meta weights + materialize before ``copy_``). 
+ """ + if not inplace and state_dict is None: + raise ValueError("state_dict is required when inplace=False") + qfmt = get_quantization_format(model) + if qfmt is None or "awq" not in qfmt.lower(): + return {}, set() + + name_to_module = dict(model.named_modules()) if inplace else None + + model_type = type(model).__name__.lower() + id_to_name: dict[int, str] = {id(m): n for n, m in model.named_modules()} + out: dict[str, tuple[torch.Tensor, torch.Tensor | None]] = {} + requant_weights: set[str] = set() + + def _process_group(modules: list[nn.Module]) -> None: + pqs_list = _collect_group_pre_quant_scales(modules) + if pqs_list is None: + return + + # Mean and clamp in float32: fp16/bf16 would underflow float32.tiny to 0 and divide by zero. + pqs_dtype = pqs_list[0].dtype + avg_pqs = torch.stack([p.float() for p in pqs_list]).mean(0) + avg_pqs = avg_pqs.clamp(min=torch.finfo(torch.float32).tiny) + + for m in modules: + nm = id_to_name.get(id(m)) + if nm is None or not hasattr(m, "weight"): + continue + w_key = f"{nm}.weight" + old_pqs = m.input_quantizer._pre_quant_scale + avg_pqs_dev = avg_pqs.to(device=old_pqs.device, dtype=old_pqs.dtype) + if torch.equal(old_pqs, avg_pqs_dev): + continue + if inplace: + w_param = m.weight + ratio = old_pqs.to(dtype=torch.float32, device=w_param.device) / avg_pqs.to( + device=w_param.device + ) + w_param.data.copy_((w_param.to(torch.float32) * ratio).to(w_param.dtype)) + else: + if state_dict is None: + raise RuntimeError( + "state_dict is required when inplace=False in _resmooth_experts_for_export" + ) + weight = state_dict[w_key] + ratio = old_pqs.to(dtype=torch.float32, device=weight.device) / avg_pqs.to( + device=weight.device + ) + state_dict[w_key] = (weight.to(torch.float32) * ratio).to(weight.dtype) + requant_weights.add(w_key) + + synced_amax: torch.Tensor | None = None + amaxes = [m.input_quantizer.amax for m in modules] + if all(a is not None for a in amaxes): + synced_amax = merge_amax_tensors_for_group(amaxes) + + 
avg_pqs_out = avg_pqs.detach().to(pqs_dtype).clone() + for m in modules: + nm = id_to_name.get(id(m)) + if nm is None: + continue + out[get_unwrapped_name(f"{nm}.input_quantizer", model)] = (avg_pqs_out, synced_amax) + + # MoE expert groups — must be enumerated by name because MoE routing sends + # different tokens to each expert, so forward hooks cannot detect them as + # sharing the same input tensor. + for _, module in model.named_modules(): + if not is_moe(module): + continue + try: + expert_groups = get_experts_list(module, model_type) + except NotImplementedError: + continue + for experts in expert_groups: + if not experts: + continue + if inplace: + if name_to_module is None: + raise RuntimeError( + "name_to_module is required when inplace=True in _resmooth_experts_for_export" + ) + with _enable_writeback_for_group(experts, model, name_to_module): + _process_group(experts) + else: + _process_group(experts) + + # Dense shared-input groups (e.g. q/k/v in GQA attention) — detected via forward + # hooks so any architecture is covered regardless of projection attribute names. + + dev = next(model.parameters()).device + + def _dummy_forward() -> None: + # Partial forward is OK: hooks record layers reached before failure. 
+ with torch.inference_mode(): + try: + model(torch.ones([1, 2], dtype=torch.long, device=dev)) + except Exception as e: + logging.getLogger(__name__).debug( + "Dummy forward for shared-input detection failed (expected for VLMs): %s", e + ) + + input_to_linear, _ = collect_shared_input_modules(model, _dummy_forward) + for modules in input_to_linear.values(): + if len(modules) <= 1: + continue + if inplace: + if name_to_module is None: + raise RuntimeError( + "name_to_module is required when inplace=True in _resmooth_experts_for_export" + ) + with _enable_writeback_for_group(modules, model, name_to_module): + _process_group(modules) + else: + _process_group(modules) + + return out, requant_weights + + def export_hf_vllm_fq_checkpoint( model: nn.Module, export_dir: Path | str, + inplace_mem_efficient: bool = False, ): """Export quantized HF weights + ``vllm_fq_modelopt_state.pth`` for vLLM fake-quant reload. Folds fake-quant weights into a ``state_dict()`` copy (optional ``pre_quant_scale`` into weight when input fake-quant is off), drops quantizer keys from the HF save, briefly disables weight quantizers to snapshot - ModelOpt/quantizer state, then re-enables them. Writes ``export_dir`` via - ``save_pretrained(..., save_modelopt_state=False)``. + ModelOpt/quantizer state, then re-enables them. Weight files are written with an + explicit ``state_dict`` (and ``hf_quantizer`` cleared during save) so safetensors + do not pick up live quantizer buffers. + + For MoE models with AWQ quantization, pre_quant_scale is averaged across experts + and input amax is unified — required because vLLM uses a single input quantizer + per expert group. By default this updates only a detached ``state_dict`` copy. + With ``inplace_mem_efficient=True``, resmooth runs **in place** on materialized + weight parameters only (no ``state_dict``), before the inplace fakequant loop. Args: model: In-memory quantized model. export_dir: Output dir for HF files and ``vllm_fq_modelopt_state.pth``. 
+ inplace_mem_efficient: When True, applies fake-quant inplace one decoder layer at + a time using ``enable_weight_access_and_writeback``, avoiding full state + dict materialization. This is destructive — model weights are permanently + modified and weight quantizers are not re-enabled after export. """ export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) - # Step 1: Build the folded HF state dict. - # model.state_dict() returns detached copies of all tensors, so model - # parameters are never modified. Apply each weight quantizer's fake-quant - # to the corresponding weight tensor in the copy. - state_dict = model.state_dict() - fakequant_weights = set() - input_quantizers_folded_pqs = ( - set() - ) # keys for input_quantizers where pre_quant_scale was folded + fakequant_weights: set[str] = set() + # Input quantizer keys whose _pre_quant_scale was folded into the weight above. + input_quantizers_folded_pqs: set[str] = set() with torch.inference_mode(): - for module_name, module in model.named_modules(): - if not isinstance(module, QuantModule): - continue - for attr_name, quantizer in module.named_children(): - if not ( - attr_name.endswith("weight_quantizer") - and isinstance(quantizer, TensorQuantizer) - and quantizer.fake_quant - and quantizer.is_enabled - ): - continue - weight_name = attr_name.removesuffix("_quantizer") - prefix = f"{module_name}." if module_name else "" - sd_key = f"{prefix}{weight_name}" - assert sd_key not in fakequant_weights, ( - f"Weight {sd_key} has already been fakequantized" + if inplace_mem_efficient: + # Resmooth shared-input groups, then fakequant (state dict and/or params). + pqs_overrides, requant_weights = _resmooth_experts_for_export(model, None, inplace=True) + # Inplace path: iterate decoder layers, one offload<->onload per layer. 
+            decoder_layers = LayerActivationCollector.get_decoder_layers(model)
+            if decoder_layers is None:
+                raise RuntimeError(
+                    "inplace_mem_efficient=True requires a model with discoverable decoder layers"
                 )
-            if sd_key in state_dict:
-                w = state_dict[sd_key]
-                w_quant = quantizer(w.float()).to(w.dtype).cpu()
-                # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
-                # Only valid when input_quantizer does NOT fake-quant activations. If it does
-                # fake_quant(x*s), the non-linearity prevents folding s into W.
-                inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
-                if hasattr(module, inp_attr):
-                    inp_q = getattr(module, inp_attr)
-                    if (
-                        hasattr(inp_q, "_pre_quant_scale")
-                        and inp_q._pre_quant_scale is not None
-                        and inp_q._disabled
-                    ):
-                        scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
-                        w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
-                        inp_q_key = get_unwrapped_name(
-                            f"{module_name}.{inp_attr}" if module_name else inp_attr, model
-                        )
-                        input_quantizers_folded_pqs.add(inp_q_key)
-                state_dict[sd_key] = w_quant
-                fakequant_weights.add(sd_key)
-
-        # Filter quantizer tensors out for a clean HF checkpoint.
-        clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k}
+            for name, module in model.named_modules():
+                if module not in decoder_layers:
+                    continue
+                # NOTE(review): pass the ROOT model as the second argument, matching every
+                # other call site in this change set (the non-inplace loop below, the group
+                # writeback helper, and the GPTQ weight update). Passing the decoder layer
+                # itself as the "model" would resolve offload hooks / name lookups against
+                # the layer instead of the full model.
+                with enable_weight_access_and_writeback(module, model):
+                    for sub_name, sub_mod in module.named_modules():
+                        full_name = f"{name}.{sub_name}" if sub_name else name
+                        _fakequant_module_weights(
+                            sub_mod,
+                            full_name,
+                            model,
+                            None,
+                            input_quantizers_folded_pqs,
+                            fakequant_weights,
+                            requant_weights,
+                            inplace=True,
+                        )
+            # Meta tensors for offloaded weights (free); offload maps now have
+            # fakequanted values via writeback.
+            state_dict = model.state_dict()
+        else:
+            state_dict = model.state_dict()
+            # Resmooth shared-input groups, then fakequant (state dict and/or params).
+ pqs_overrides, requant_weights = _resmooth_experts_for_export( + model, state_dict, inplace=False + ) + + # Default path: fakequant into the resmoothed state_dict copy (do not refresh + # from model.state_dict() or resmooth is lost). + for module_name, module in model.named_modules(): + with enable_weight_access_and_writeback(module, model): + _fakequant_module_weights( + module, + module_name, + model, + state_dict, + input_quantizers_folded_pqs, + fakequant_weights, + requant_weights, + inplace=False, + ) + + if inplace_mem_efficient: + # Let save_pretrained build its own state_dict so offloaded params go through + # its module_map / get_state_dict_from_offload path (modeling_utils.py:3967+). + # Passing state_dict= bypasses that path and crashes on meta tensors. + quantizer_keys = [k for k in state_dict if "quantizer" in k] + clean_sd = None + else: + clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k} + quantizer_keys = None # Step 2: Disable weight quantizers, save modelopt state + quantizer state # dict, then re-enable. The _disabled=True flag is captured in modelopt_state @@ -116,54 +572,87 @@ def export_hf_vllm_fq_checkpoint( # attention quantizers remain active. # Rotation is also cleared: the weight was already folded with rotation applied, # so if fold_weight is called on reload it must not re-rotate the exported weight. 
- wqs_to_restore = [] - for _, module in model.named_modules(): - if isinstance(module, QuantModule): - for attr_name, quantizer in module.named_children(): - if ( - attr_name.endswith("weight_quantizer") - and isinstance(quantizer, TensorQuantizer) - and quantizer.is_enabled - ): - quantizer.disable() - orig_rotate = quantizer._rotate - if quantizer.rotate_is_enabled: - quantizer._rotate = disable_rotate(quantizer) - wqs_to_restore.append((quantizer, orig_rotate)) - - quantizer_state_dict = get_quantizer_state_dict(model) - for key in list(quantizer_state_dict): - if key.endswith("weight_quantizer"): - # Fakequant amax is folded into HF weights; do not reload weight quantizer tensors. - quantizer_state_dict.pop(key) - elif key in input_quantizers_folded_pqs: - # pre_quant_scale was folded into the weight; keep the buffer for strict load but - # save identity so activations are not scaled twice. - qstate_val = quantizer_state_dict[key] - if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val: - quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like( - qstate_val["_pre_quant_scale"] - ) - modelopt_state = mto.modelopt_state(model) - # ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild - # ``quantizer_state`` and drop disabled weight quantizer entries (weights already folded). - qstate = quantizer_state(model) - for key in list(qstate): - if key.endswith("weight_quantizer") and qstate[key].get("_disabled"): - qstate.pop(key) - - for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []): - if mode_str == "quantize" and "metadata" in m_state: - m_state["metadata"]["quantizer_state"] = qstate - break - - # Per-quantizer tensor dict loaded alongside metadata on reload. - modelopt_state["modelopt_state_weights"] = quantizer_state_dict - torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") - - # Step 3: Save HF weights using the pre-built folded state dict. 
- model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False) - - for wq, orig_rotate in wqs_to_restore: - wq.enable() - wq._rotate = orig_rotate + wqs_to_restore: list[tuple[TensorQuantizer, Any]] = [] + try: + for _, module in model.named_modules(): + if isinstance(module, QuantModule): + for attr_name, quantizer in module.named_children(): + if not (attr_name.endswith("weight_quantizer") and quantizer.is_enabled): + continue + if isinstance(quantizer, SequentialQuantizer): + quantizer.disable() + for sub in quantizer: + orig_rotate = sub._rotate + if sub.rotate_is_enabled: + sub._rotate = disable_rotate(sub) + wqs_to_restore.append((sub, orig_rotate)) + elif isinstance(quantizer, TensorQuantizer): + quantizer.disable() + orig_rotate = quantizer._rotate + if quantizer.rotate_is_enabled: + quantizer._rotate = disable_rotate(quantizer) + wqs_to_restore.append((quantizer, orig_rotate)) + + quantizer_state_dict = get_quantizer_state_dict(model) + for key in list(quantizer_state_dict): + if is_weight_quantizer_state_key(key): + # Fakequant amax is folded into HF weights; do not reload weight quantizer tensors. + # Reload must force-disable WQs missing from saved state (see + # ``filter_modelopt_state_quantizer_state_for_model`` assertion in vllm_reload_utils). + quantizer_state_dict.pop(key) + elif key in input_quantizers_folded_pqs: + # pre_quant_scale was folded into the weight; keep the buffer for strict load but + # save identity so activations are not scaled twice. + qstate_val = quantizer_state_dict[key] + if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val: + quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like( + qstate_val["_pre_quant_scale"] + ) + + # Patch input quantizers with averaged pqs and unified amax so that vLLM's single + # per-group input quantizer sees consistent values (covers both dense qkv and MoE experts). 
+ for iq_key, (avg_pqs, max_input_amax) in pqs_overrides.items(): + if iq_key in quantizer_state_dict: + qstate_val = quantizer_state_dict[iq_key] + if isinstance(qstate_val, dict): + if "_pre_quant_scale" in qstate_val: + qstate_val["_pre_quant_scale"] = avg_pqs + if max_input_amax is not None and "_amax" in qstate_val: + qstate_val["_amax"] = max_input_amax + + modelopt_state = mto.modelopt_state(model) + _check_all_weight_quantizers_disabled(model) + # Rebuild quantizer_state from the live model (post-disable) and strip weight-quantizer + # entries. Apply to every mode that carries quantizer_state so that stale entries from + # a calibrate pass (which also stores quantizer_state in its metadata) are cleaned up. + # Reload synthesizes missing WQ rows with ``_disabled`` via + # ``filter_modelopt_state_quantizer_state_for_model``. + qstate = quantizer_state(model) + for key in list(qstate): + if is_weight_quantizer_state_key(key): + qstate.pop(key) + for _mode_str, m_state in modelopt_state.get("modelopt_state_dict", []): + md = m_state.get("metadata", {}) + if "quantizer_state" in md: + md["quantizer_state"] = qstate + + # Per-quantizer tensor dict loaded alongside metadata on reload. + modelopt_state["modelopt_state_weights"] = quantizer_state_dict + safe_save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") + + # Step 3: Save HF weights. 
+ if inplace_mem_efficient: + prev_ignore = getattr(model, "_keys_to_ignore_on_save", None) + model._keys_to_ignore_on_save = quantizer_keys + try: + model.save_pretrained(export_dir, save_modelopt_state=False) + finally: + model._keys_to_ignore_on_save = prev_ignore + else: + model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False) + + finally: + if not inplace_mem_efficient: + for wq, orig_rotate in wqs_to_restore: + wq.enable() + wq._rotate = orig_rotate diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 22d87e303f..af936a3002 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -163,7 +163,7 @@ def _save_component_state_dict_safetensors( json.dump(metadata, f, indent=4) -def _collect_shared_input_modules( +def collect_shared_input_modules( model: nn.Module, dummy_forward_fn: Callable[[], None], collect_layernorms: bool = False, @@ -387,7 +387,7 @@ def llm_dummy_forward(): else: model(fake_input) - input_to_linear, output_to_layernorm = _collect_shared_input_modules( + input_to_linear, output_to_layernorm = collect_shared_input_modules( model, llm_dummy_forward, collect_layernorms=True ) @@ -862,7 +862,7 @@ def _fuse_qkv_linears_diffusion( # Collect modules sharing the same input try: - input_to_linear, _ = _collect_shared_input_modules( + input_to_linear, _ = collect_shared_input_modules( model, dummy_forward_fn, collect_layernorms=False ) except Exception as e: diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 99c729efbc..3f24ac09a4 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1217,16 +1217,36 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): ), ) - use_sequential: bool = ModeloptField( + layerwise: bool = ModeloptField( default=False, - title="Enable sequential layer-by-layer calibration.", + title="Enable layerwise 
(layer-by-layer) calibration.", description=( - "If True, the calibration algorithm is applied sequentially to each decoder block. " - "Each layer's inputs are captured via a single forward pass that reflects the " + "If True, the calibration algorithm is applied layer by layer. " + "Each layer's inputs are captured via a forward pass that reflects the " "quantization of all preceding layers, incurring O(N) forward passes for N layers." ), ) + layerwise_checkpoint_dir: str | None = ModeloptField( + default=None, + title="Checkpoint directory for layerwise calibration.", + description=( + "If set together with layerwise=True, per-layer checkpoints are saved to this " + "directory during calibration. On restart, calibration resumes from the last " + "completed layer." + ), + ) + + @model_validator(mode="after") + def validate_layerwise_checkpoint_dir(self): + """Raise if layerwise_checkpoint_dir is set but layerwise is False.""" + if self.layerwise_checkpoint_dir is not None and not self.layerwise: + raise ValueError( + "layerwise_checkpoint_dir requires layerwise=True. " + "Set layerwise=True or remove layerwise_checkpoint_dir." + ) + return self + class MaxCalibConfig(QuantizeAlgorithmConfig): """The config for max calibration algorithm. diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index c81d5c89c7..713cdd7373 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -60,10 +60,10 @@ from .model_calib import ( awq, gptq, + layerwise_calibrate, local_hessian_calibrate, max_calibrate, mse_calibrate, - sequential_calibrate, smoothquant, svdquant, ) @@ -213,6 +213,7 @@ def wrapped_calib_func( config: QuantizeAlgorithmConfig, forward_loop: ForwardLoop | None = None, func: Callable | None = None, + supports_layerwise: bool = True, ) -> ConvertReturnType: """Wrap the calibration function to be compatible with the ModelOpt convert entrypoint. 
@@ -222,7 +223,8 @@ def wrapped_calib_func( """ kwargs = config.model_dump() method = kwargs.pop("method") - sequential = kwargs.pop("use_sequential", False) + layerwise = kwargs.pop("layerwise", False) + checkpoint_dir = kwargs.pop("layerwise_checkpoint_dir", None) if method is not None and "awq" in method: # For backward compatibility kwargs["algorithm"] = method @@ -237,17 +239,24 @@ def wrapped_calib_func( module._moe_calib_experts_ratio = moe_calib_experts_ratio if func is not None: - if sequential: + if layerwise: + # All currently implemented PTQ algorithms support layerwise calibration; + # future algorithms that need full-model context must add a guard here. + if not supports_layerwise: + raise ValueError( + f"Calibration algorithm '{method}' does not support layerwise=True. " + "Set layerwise=False, or override `_supports_layerwise = True` on the " + "corresponding CalibrateModeDescriptor once the algorithm is made " + "compatible with per-layer calibration." + ) if forward_loop is None: raise ValueError("forward_loop is required for calibration but got None.") - assert method in ["max", "gptq"], ( - f"Sequential calibration currently only supports max and gptq calibration, got {method}" - ) - # Wrap with sequential processing - sequential_calibrate( + # Wrap with layerwise processing + layerwise_calibrate( model, forward_loop=forward_loop, calib_func=func, + checkpoint_dir=checkpoint_dir, **kwargs, ) else: @@ -281,6 +290,10 @@ class BaseCalibrateModeDescriptor(ModeDescriptor): _calib_func: Callable | None + # Override to False when the algorithm requires full-model context and + # cannot run per decoder layer (e.g. needs ModeloptStateManager on the root). 
+ _supports_layerwise: bool = True + def __init__(self, *args, **kwargs): """Initialize Base calibrate mode descriptor.""" assert issubclass(self.config_class, QuantizeAlgorithmConfig), ( @@ -326,7 +339,13 @@ def convert(self) -> ConvertEntrypoint: def wrapped_func(model, config, forward_loop=None): # Access _calib_func as a class attribute to avoid binding # Check if _calib_func is defined as a class attribute - return wrapped_calib_func(model, config, forward_loop, func=self.__class__._calib_func) + return wrapped_calib_func( + model, + config, + forward_loop, + func=self.__class__._calib_func, + supports_layerwise=self.__class__._supports_layerwise, + ) return wrapped_func @@ -485,6 +504,9 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]: return SVDQuantConfig _calib_func = svdquant + # create_and_replace_svdquant_linear_on_the_fly reads ModeloptStateManager from the + # root model, which is not present when layerwise_calibrate dispatches per decoder layer. + _supports_layerwise = False @property def restore(self) -> RestoreEntrypoint: diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 35a0e931c9..6db1e82cbd 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -28,7 +28,10 @@ from tqdm import tqdm from modelopt.torch.opt.searcher import ForwardLoop -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.utils.layerwise_calib import ( + LayerActivationCollector, + _CheckpointState, +) from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method @@ -44,6 +47,7 @@ is_quantized_column_parallel_linear, is_quantized_linear, is_quantized_row_parallel_linear, + persistent_materialization, promote_nvfp4_static_quantizers, 
quantizer_attr_names, reduce_amax, @@ -53,9 +57,9 @@ __all__ = [ "awq", + "layerwise_calibrate", "local_hessian_calibrate", "max_calibrate", - "sequential_calibrate", "smoothquant", "svdquant", ] @@ -1552,21 +1556,27 @@ def postprocess(module, name): @torch.no_grad() -def sequential_calibrate( +def layerwise_calibrate( model: nn.Module, forward_loop: ForwardLoop, calib_func: Callable, **calib_kwargs, ): - """Sequential calibration - a sequential layer-by-layer calibration algorithm. + """Layerwise calibration - a layer-by-layer calibration algorithm. Runs the full model forward per layer but patches decoder layers with a skip / run / capture strategy so that inter-layer logic in parent modules (e.g. mask construction) executes naturally without model-specific hooks. + + If ``checkpoint_dir`` is passed (via ``calib_kwargs``), per-layer checkpoints + are saved after each layer completes. On restart, calibration resumes from + the last completed layer. """ + checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None) + if forward_loop is None: raise ValueError( - "forward_loop must not be None for sequential calibration. " + "forward_loop must not be None for layerwise calibration. " "Please provide a valid forward_loop callable." ) @@ -1574,31 +1584,57 @@ def sequential_calibrate( if transformer_layers is None or len(transformer_layers) == 0: raise ValueError( "Could not find transformer layers in model. " - "Sequential calibration requires a model with identifiable transformer layers." + "Layerwise calibration requires a model with identifiable transformer layers." 
) - print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") + num_layers = len(transformer_layers) + print_rank_0(f"Layerwise calibration: Found {num_layers} transformer layers") + + ckpt = _CheckpointState.from_folder(checkpoint_dir, num_layers) + start_layer = ckpt.start_layer if ckpt else 0 input_getter = LayerActivationCollector(model) input_getter._patch_all_layers(decoder_layers=transformer_layers) + resumed_inputs = ckpt.setup_resume(transformer_layers) if ckpt and start_layer > 0 else None + try: - for layer_idx, layer in enumerate(transformer_layers): - print_rank_0(f"Calibrating layer {layer_idx + 1}/{len(transformer_layers)}") - layer_inputs = input_getter.get_input_activations(layer, forward_loop) + # Bootstrap: get first layer's inputs (or use resumed inputs). + layer_inputs = input_getter.get_first_layer_inputs( + start_layer, resumed_inputs, forward_loop + ) + + for layer_idx in range(start_layer, num_layers): + layer = transformer_layers[layer_idx] def _layer_forward_loop(m, _inputs=layer_inputs): for args, kwargs_input in _inputs: m(*args, **kwargs_input) - calib_func(layer, _layer_forward_loop, **calib_kwargs) + with persistent_materialization(layer): + calib_func(layer, _layer_forward_loop, **calib_kwargs) + + # Run one more forward to get next layer's inputs and set + # output_meta on the just-calibrated layer (via "run" mode). 
+ is_last = layer_idx + 1 >= num_layers + if not is_last: + next_inputs = input_getter.cache_outputs_for_next_layer_calib(layer, forward_loop) + else: + next_inputs = None + + if ckpt: + ckpt.save(layer_idx, layer, model, transformer_layers, next_inputs) del layer_inputs torch.cuda.empty_cache() + layer_inputs = next_inputs # noqa: F841 (used in next iteration's closure) finally: input_getter._unpatch_all_layers() - print_rank_0("Sequential calibration completed") + if ckpt: + ckpt.full_restore(transformer_layers, model) + + print_rank_0("Layerwise calibration completed") @torch.no_grad() @@ -1610,12 +1646,12 @@ def gptq( ): """GPTQ quantization. - Works in two modes depending on ``use_sequential`` in the config: + Works in two modes depending on ``layerwise`` in the config: - * **Sequential** (``use_sequential=True``): ``sequential_calibrate`` calls this + * **Layerwise** (``layerwise=True``): ``layerwise_calibrate`` calls this function once per decoder layer with updated activations, producing more accurate Hessian estimates. - * **Non-sequential** (``use_sequential=False``): called once on the full model. + * **Non-layerwise** (``layerwise=False``): called once on the full model. All layers are quantized in parallel from the original activations. Per-module steps: @@ -1628,7 +1664,7 @@ def gptq( Args: model: The module to quantize — either the full model or a single decoder - layer when invoked by ``sequential_calibrate``. + layer when invoked by ``layerwise_calibrate``. forward_loop: Callable that replays calibration inputs through *model*. perc_damp: Percentage of avg Hessian diagonal for damping (default: 0.01). block_size: Block size for GPTQ weight update. 
@@ -1663,8 +1699,10 @@ def gptq( handle.cleanup() print_rank_0("Updating weights using GPTQ algorithm...") + name_to_module = dict(model.named_modules()) for handle in gptq_handles.values(): - handle.update_weights(block_size, perc_damp) + with enable_weight_access_and_writeback(handle.module, model, name_to_module): + handle.update_weights(block_size, perc_damp) handle.free() del gptq_handles diff --git a/modelopt/torch/quantization/plugins/accelerate.py b/modelopt/torch/quantization/plugins/accelerate.py index 13999df0f0..f80e2478dc 100644 --- a/modelopt/torch/quantization/plugins/accelerate.py +++ b/modelopt/torch/quantization/plugins/accelerate.py @@ -31,51 +31,77 @@ __all__ = ["init_quantized_weights"] -def _get_cpu_offload_hook(hook): +def _get_offload_hook(hook): if isinstance(hook, AlignDevicesHook) and hook.offload and hook.weights_map is not None: - assert "weight" in hook.weights_map - if ( - isinstance(hook.weights_map, PrefixedDataset) - and hook.weights_map.prefix + "weight" not in hook.weights_map.dataset.state_dict - ): - raise NotImplementedError( - "This layer could be offloaded to disk. We don't support this yet." 
- ) + assert len(hook.weights_map) > 0 return hook elif isinstance(hook, SequentialHook): for h in hook.hooks: - align_hook = _get_cpu_offload_hook(h) + align_hook = _get_offload_hook(h) if align_hook is not None: return align_hook return None +def _writeback_params_to_weights_map(module, align_hook): + """Write all non-meta parameters and buffers back to the hook's CPU weights_map.""" + for name, tensor in module.state_dict(keep_vars=True).items(): + if tensor.device.type == "meta": + continue + if isinstance(align_hook.weights_map, PrefixedDataset): + key = align_hook.weights_map.prefix + name + w_map = align_hook.weights_map.dataset.state_dict + else: + w_map = align_hook.weights_map + key = name + if key in w_map: + w_map[key] = tensor.detach().to(w_map[key].device, dtype=w_map[key].dtype) + elif ( + isinstance(align_hook.weights_map, PrefixedDataset) + and hasattr(align_hook.weights_map.dataset, "index") + and key in align_hook.weights_map.dataset.index + ): + # Disk-offloaded weight: promote into state_dict so the next + # pre_forward picks up the modified tensor instead of the stale + # on-disk version. OffloadedWeightsLoader.__getitem__ gives + # state_dict priority over index, so this is sufficient. + w_map[key] = tensor.detach().cpu() + + @contextmanager def weight_access_and_writeback_context(module): - """Context manager for weight access and writeback for modules managed by accelerate.""" + """Context manager for weight access and writeback for modules managed by accelerate. + + Handles CPU-offloaded and disk-offloaded models. Iterates over the module and all + its descendants, materializing weights from any offload hook found and writing them + back on exit. ``pre_forward`` is skipped on modules whose weights are already + materialized (not on meta) to avoid overwriting them with stale CPU copies. 
+ """ assert hasattr(module, "_hf_hook") - align_hook = _get_cpu_offload_hook(module._hf_hook) - if align_hook: - # Accelerate uses AlignDevicesHook to offload weights to CPU/Disk and then reload them in the forward pass - # The CPU/Disk offloaded weights are managed by PrefixDataset and OffloadedWeightsLoader - # See https://github.com/huggingface/accelerate/blame/f48d95c4939b281505a45b3d6e0bf554b65cc1ea/src/accelerate/utils/offload.py#L104-L141 - # TODO: Add support for disk-offloaded models if needed (they will be really slow, hence low priority) + materialized: list[tuple[torch.nn.Module, AlignDevicesHook, bool]] = [] + for mod in module.modules(): + if not hasattr(mod, "_hf_hook"): + continue + hook = _get_offload_hook(mod._hf_hook) + if hook is None: + continue + # Only call pre_forward if weights need materializing; already-materialized + # weights would be overwritten with stale CPU state_dict values. + needs_materialize = any(p.device.type == "meta" for p in mod.parameters()) + if needs_materialize: + hook.pre_forward(mod) + hook.offload = False + materialized.append((mod, hook, needs_materialize)) - # This will load the weights from CPU state_dict and move it to the GPU from meta device - align_hook.pre_forward(module) try: yield finally: - if align_hook: - # Update the weight in the CPU state_dict - if isinstance(align_hook.weights_map, PrefixedDataset): - key = align_hook.weights_map.prefix + "weight" - w_map = align_hook.weights_map.dataset.state_dict - else: - key, w_map = "weight", align_hook.weights_map - w_map[key] = module.weight.data.to(w_map[key].device, dtype=w_map[key].dtype) - align_hook.post_forward(module, None) + for mod, hook, was_materialized in materialized: + hook.offload = True + _writeback_params_to_weights_map(mod, hook) + if was_materialized: + hook.post_forward(mod, None) @contextmanager diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 82ab589934..59bcd215bb 
100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -39,7 +39,7 @@ from ..nn.modules.quant_linear import _QuantLinear from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE from ..utils import replace_function, sync_moe_expert_amax -from ..utils.activation_collector import LayerActivationCollector +from ..utils.layerwise_calib import LayerActivationCollector from .attention import register_attention_for_kv_quant from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin diff --git a/modelopt/torch/quantization/utils/__init__.py b/modelopt/torch/quantization/utils/__init__.py index 2660363209..dfc23c42ee 100644 --- a/modelopt/torch/quantization/utils/__init__.py +++ b/modelopt/torch/quantization/utils/__init__.py @@ -16,8 +16,8 @@ # ruff: noqa: F405 """Quantization utilities.""" -from .activation_collector import LayerActivationCollector from .core_utils import * +from .layerwise_calib import LayerActivationCollector __all__ = [ "EXPORT_MODE", diff --git a/modelopt/torch/quantization/utils/activation_collector.py b/modelopt/torch/quantization/utils/activation_collector.py deleted file mode 100644 index 5f187fdcb2..0000000000 --- a/modelopt/torch/quantization/utils/activation_collector.py +++ /dev/null @@ -1,335 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Sequential calibration layer patching and activation capture. - -This module provides :class:`LayerActivationCollector`, a stateful helper that -patches decoder layers with a skip / run / capture strategy for efficient -layer-by-layer calibration. -""" - -from collections import deque -from dataclasses import dataclass, field -from typing import Any - -import torch -import torch.nn as nn - -from modelopt.torch.opt.searcher import ForwardLoop -from modelopt.torch.utils import print_rank_0 -from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method - - -class _EarlyStopForwardError(Exception): - """Raised to halt the forward pass after capturing layer inputs.""" - - -@dataclass -class _LayerCalibState: - """Mutable per-layer state used during sequential calibration. - - Attached to each decoder layer as ``_seq_calib`` and accessed by the - patched forward to decide skip / run / capture / original behaviour. - """ - - mode: str = "original" - name: str = "" - cached_inputs: deque = field(default_factory=deque) - collected_inputs: list = field(default_factory=list) - output_meta: tuple | None = None - - -class LayerActivationCollector: - """Collects layer activations for sequential (layer-by-layer) calibration. - - Each decoder layer is patched with a unified forward whose behaviour is - governed by a per-layer :class:`_LayerCalibState`: - - * **skip** — return a zero-filled dummy whose shape and type match the - layer's real output (reconstructed from lightweight metadata). No - computation is performed. The correctly shaped dummy ensures un-patched - inter-layer operations in the parent forward (e.g. LayerNorm, tuple - unpacking) do not raise shape or type errors. - * **run** — replay previously captured inputs through the original forward, - ignoring whatever the parent passes in. 
Only the just-calibrated layer - uses this mode, so its output reflects updated weights. - * **capture** — record ``(args, kwargs)`` and raise - ``_EarlyStopForwardError`` to halt the forward pass early. - * **original** — call the original forward unchanged. - - Because the *run* layer discards upstream values, skip-layer outputs are - never consumed for real computation. - """ - - # Global registry of (predicate, discoverer) pairs. Populated at import time - # by plugins (e.g. huggingface.py, megatron.py). Order matters: the first - # matching entry wins, so more specific predicates (e.g. Nemotron-H) must be - # registered before generic ones (e.g. homogeneous HF models). - # - # This is intentionally a mutable class variable shared across all instances: - # plugins register once at import time, and the registry is read-only after - # that. register_decoder_layer_support() guards against duplicate entries. - _decoder_layer_support: list[tuple[Any, Any]] = [] - _LAYER_ATTR = "_seq_calib" - - def __init__(self, model: nn.Module): - """Initialize the collector for the given model.""" - self.model = model - self._decoder_layers: nn.ModuleList | None = None - self._layer_to_idx: dict[nn.Module, int] = {} - self._patched = False - - @staticmethod - def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None: - """Return decoder layers supported by sequential calibration.""" - for is_supported, discoverer in LayerActivationCollector._decoder_layer_support: - if not is_supported(model): - continue - decoder_layers = discoverer(model) - if decoder_layers is not None: - return decoder_layers - return None - - @staticmethod - def is_supported(model: nn.Module) -> bool: - """Whether the model supports decoder-layer sequential calibration.""" - return LayerActivationCollector.get_decoder_layers(model) is not None - - @classmethod - def register_decoder_layer_support(cls, is_supported: Any, discoverer: Any): - """Register a (predicate, discoverer) pair for decoder-layer 
detection.""" - entry = (is_supported, discoverer) - if entry not in cls._decoder_layer_support: - cls._decoder_layer_support.append(entry) - - @staticmethod - def _extract_output_meta(output): - """Extract lightweight (shape, dtype, device) metadata from a layer output. - - Recursively handles tensors, tuples, lists, and non-tensor values (e.g. None). - The returned structure can be passed to ``_zeros_from_meta`` to reconstruct a - zero-filled output with identical shape and type. - """ - if isinstance(output, torch.Tensor): - return ("tensor", output.shape, output.dtype, output.device) - if isinstance(output, tuple): - return ( - "tuple", - tuple(LayerActivationCollector._extract_output_meta(o) for o in output), - ) - if isinstance(output, list): - return ("list", [LayerActivationCollector._extract_output_meta(o) for o in output]) - return ("other", output) - - @staticmethod - def _zeros_from_meta(meta): - """Reconstruct a zero-filled output from metadata produced by ``_extract_output_meta``.""" - tag = meta[0] - if tag == "tensor": - _, shape, dtype, device = meta - return torch.zeros(shape, dtype=dtype, device=device) - if tag == "tuple": - return tuple(LayerActivationCollector._zeros_from_meta(m) for m in meta[1]) - if tag == "list": - return [LayerActivationCollector._zeros_from_meta(m) for m in meta[1]] - # "other" values are expected to be lightweight non-tensors (e.g. None, small scalars). - # The value is returned directly (not copied); callers must not mutate it. - # In practice this is safe because skip-mode outputs are immediately discarded by the - # downstream run-mode layer, which replays from its own cached inputs instead. - return meta[1] - - def _patch_all_layers(self, decoder_layers: nn.ModuleList | None = None): - """Bind the unified forward to every decoder layer and the model. Called once. - - Args: - decoder_layers: Pre-resolved decoder layers. If *None*, layers are - discovered via :meth:`get_decoder_layers`. 
- """ - - def _patched_forward(self, *args, **kwargs): - """Unified forward bound to every decoder layer during sequential calibration. - - ``self`` here is the decoder layer module (bound via ``bind_forward_method``). - All per-layer state is accessed through ``self._seq_calib``. - """ - info: _LayerCalibState = self._seq_calib - - if info.mode == "skip": - if info.output_meta is None: - raise RuntimeError( - f"Layer {info.name} is in 'skip' mode but has no output_meta. " - "This indicates a state-machine bug: the layer should have run " - "in 'run' mode (which sets output_meta) before transitioning to 'skip'." - ) - return LayerActivationCollector._zeros_from_meta(info.output_meta) - - if info.mode == "run": - assert info.cached_inputs, ( - f"Layer {info.name} is in 'run' mode but has no cached inputs to replay." - ) - real_args, real_kwargs = info.cached_inputs.popleft() - output = self._original_forward(*real_args, **real_kwargs) - info.output_meta = LayerActivationCollector._extract_output_meta(output) - return output - - if info.mode == "capture": - info.collected_inputs.append((args, kwargs)) - raise _EarlyStopForwardError() - - return self._original_forward(*args, **kwargs) - - if decoder_layers is not None: - self._decoder_layers = decoder_layers - else: - self._decoder_layers = self.get_decoder_layers(self.model) - assert self._decoder_layers is not None - - self._layer_to_idx = {layer: i for i, layer in enumerate(self._decoder_layers)} - module_to_name = {m: name for name, m in self.model.named_modules()} - - try: - for layer in self._decoder_layers: - layer._seq_calib = _LayerCalibState( - name=module_to_name.get(layer, type(layer).__name__), - ) - bind_forward_method(layer, _patched_forward, "_original_forward") - - def _early_stop_forward(module_self, *args, **kwargs): - try: - return module_self._original_forward(*args, **kwargs) - except _EarlyStopForwardError: - return None - - bind_forward_method(self.model, _early_stop_forward, 
"_original_forward") - except Exception: - self._cleanup_layers() - raise - - self._patched = True - - def _cleanup_layers(self): - """Best-effort cleanup of any patched layers and model forward.""" - if hasattr(self.model, "_original_forward"): - unpatch_forward_method(self.model, "_original_forward") - - if self._decoder_layers is not None: - for layer in self._decoder_layers: - if hasattr(layer, "_original_forward"): - unpatch_forward_method(layer, "_original_forward") - if hasattr(layer, self._LAYER_ATTR): - delattr(layer, self._LAYER_ATTR) - - def _unpatch_all_layers(self): - """Restore original forwards and clean up state attributes. Called once.""" - if not self._patched: - return - self._cleanup_layers() - self._patched = False - - def _set_layer_states(self, layer_idx: int): - """Transition layer modes for the next calibration step. - - When calibrating layer *i*, three transitions happen: - - * Layer ``i - 2`` → **skip** (fully done, free its cached inputs). - * Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights). - * Layer ``i`` → **capture** (record inputs, then early-stop). - """ - assert self._decoder_layers is not None - - if layer_idx > 1: - done = self._decoder_layers[layer_idx - 2]._seq_calib - # output_meta is intentionally kept: skip mode needs it to produce - # correctly shaped zero-filled outputs for the parent forward. - done.mode = "skip" - done.cached_inputs.clear() - - if layer_idx > 0: - prev = self._decoder_layers[layer_idx - 1]._seq_calib - if not prev.collected_inputs: - raise RuntimeError( - f"Layer {layer_idx - 1} ({prev.name!r}) has no collected inputs to replay. " - "Layers must be calibrated sequentially — ensure get_input_activations() " - "was called for every preceding layer in order." 
- ) - prev.mode = "run" - prev.cached_inputs = deque(prev.collected_inputs) - prev.collected_inputs = [] - - cur = self._decoder_layers[layer_idx]._seq_calib - cur.mode = "capture" - cur.collected_inputs = [] - - def _log_layer_summary(self, layer_idx: int): - """Log a one-line summary of layer modes for the current calibration step.""" - assert self._decoder_layers is not None - n = len(self._decoder_layers) - groups: dict[str, list[int]] = {} - for i, layer in enumerate(self._decoder_layers): - mode = layer._seq_calib.mode - if mode in ("skip", "run", "capture"): - groups.setdefault(mode, []).append(i + 1) - parts = [f"{mode}: {groups[mode]}" for mode in ("skip", "run", "capture") if mode in groups] - print_rank_0(f"Calibrating layer {layer_idx + 1}/{n} | {' | '.join(parts)}") - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - @torch.no_grad() - def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: - """Collect input activations for *layer* by running a full model forward. - - Layers before the target are skipped or re-run (if just calibrated), the - target layer captures its inputs, and an early-stop prevents unnecessary - computation beyond the target. - - :meth:`_patch_all_layers` must be called before this method. - - Note: the model forward returns ``None`` for every batch during capture - (because ``_EarlyStopForwardError`` short-circuits the forward pass). - Callers should not rely on the model's return value within *forward_loop*. - """ - if not self._patched: - raise RuntimeError( - "get_input_activations() requires _patch_all_layers() to be called first." 
- ) - layer_idx = self._layer_to_idx[layer] - self._set_layer_states(layer_idx) - self._log_layer_summary(layer_idx) - - info = layer._seq_calib - try: - forward_loop(self.model) - except Exception: - # Reset the current layer so subsequent calls don't see stale state. - info.mode = "original" - info.collected_inputs = [] - raise - - if not info.collected_inputs: - info.mode = "original" - raise RuntimeError( - f"Layer {info.name!r} collected no inputs during forward_loop. " - "The forward loop did not reach this layer — check that forward_loop() " - "actually calls the model and that the layer is in the forward path." - ) - - inputs = list(info.collected_inputs) - # After capture, set to original so calib_func can call the layer's - # real forward directly. The layer will transition to run → skip - # in subsequent iterations via _set_layer_states. - info.mode = "original" - return inputs diff --git a/modelopt/torch/quantization/utils/calib_utils.py b/modelopt/torch/quantization/utils/calib_utils.py index e52a8438d5..e2d8ccf2a2 100644 --- a/modelopt/torch/quantization/utils/calib_utils.py +++ b/modelopt/torch/quantization/utils/calib_utils.py @@ -96,7 +96,7 @@ def __init__(self, module, name, offload_to_cpu=False): self.name = name in_features = module.weight.shape[-1] device = module.weight.device - if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: + if device.type == "meta" or (offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65): device = "cpu" self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) self.n_samples = 0 diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 273d7564c6..29661e18f5 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -423,47 +423,70 @@ def _get_enclosing_fsdp_module( return root_model +def _set_parameter(module: nn.Module, name: str, value: 
nn.Parameter): + """Set a parameter on a module by dotted name (e.g. ``self_attn.q_proj.weight``).""" + parts = name.rsplit(".", 1) + if len(parts) == 2: + parent = module.get_submodule(parts[0]) + attr = parts[1] + else: + parent = module + attr = name + parent._parameters[attr] = value + + @contextmanager def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn.Module): """Context manager for FSDP2 weight access and writeback. - Note this context will gather the weight across FSDP/HSDP shards. If TP is implemented with DTensor, - the weight will be a local tensor of the TP DTensor under this context. + Gathers sharded DTensor parameters across FSDP/HSDP shards so they can be + read or modified. Works for both leaf modules (single ``weight``) and + composite modules like decoder layers (all ``named_parameters``). + + If TP is implemented with DTensor, the weight will be a local tensor of the + TP DTensor under this context. """ assert isinstance(root_model, torch.distributed.fsdp.FSDPModule), "We only support FSDP2" assert not hasattr(module, "_hf_hook"), "We dont support FSDP2 with HF accelerate hooks" - assert isinstance(module.weight, torch.distributed.tensor.DTensor) fsdp_module = _get_enclosing_fsdp_module(module, root_model) assert fsdp_module is not None, "Module is not wrapped by FSDP" fsdp_device_mesh = _get_fsdp2_mesh(fsdp_module) fsdp_dim = fsdp_device_mesh.ndim - original_placements = module.weight.placements - original_device_mesh = module.weight.device_mesh - original_weight = module.weight - # Assuming the first fsdp_dim dimensions are for FSDP/HSDP, we only collect the tensor over FSDP/HSDP dimension, - # the TP will be handled by the TP reduction. - if fsdp_dim != original_device_mesh.ndim: - assert fsdp_device_mesh.mesh_dim_names == original_device_mesh.mesh_dim_names[:fsdp_dim], ( - "FSDP2 mesh should be a slice of DTesnor's device mesh." + # Collect all DTensor parameters, replacing them with local replicated copies. 
+ originals: dict[str, tuple] = {} + for name, param in module.named_parameters(): + if not isinstance(param, torch.distributed.tensor.DTensor): + continue + original_placements = param.placements + original_device_mesh = param.device_mesh + if fsdp_dim != original_device_mesh.ndim: + assert ( + fsdp_device_mesh.mesh_dim_names == original_device_mesh.mesh_dim_names[:fsdp_dim] + ), "FSDP2 mesh should be a slice of DTensor's device mesh." + collected = param.redistribute( + placements=[Replicate()] * fsdp_dim + list(original_placements[fsdp_dim:]), + device_mesh=original_device_mesh, ) - - weight_collected = original_weight.redistribute( - placements=[Replicate()] * fsdp_dim + list(original_placements[fsdp_dim:]), - device_mesh=original_device_mesh, - ) - new_weight = nn.Parameter(weight_collected.to_local()) - module._parameters["weight"] = new_weight + originals[name] = (param, collected, original_placements, original_device_mesh) + _set_parameter(module, name, nn.Parameter(collected.to_local())) yield - original_weight.to_local().data.copy_( - weight_collected.redistribute( - placements=original_placements, device_mesh=original_device_mesh - ).to_local() - ) - module._parameters["weight"] = original_weight + # Write back and restore original DTensor parameters. + for name, ( + original_param, + collected, + original_placements, + original_device_mesh, + ) in originals.items(): + original_param.to_local().data.copy_( + collected.redistribute( + placements=original_placements, device_mesh=original_device_mesh + ).to_local() + ) + _set_parameter(module, name, original_param) @contextmanager @@ -471,7 +494,7 @@ def enable_weight_access_and_writeback(module, root_model, name_to_module: dict """Enable weight access and writeback for a module. Useful for modules with weight not intact such as Linear layer in FSDP wrapped model or - HF accelerate CPU off-loaded models. + HF accelerate offloaded models (CPU or disk). Args: module: The module to access weights for. 
@@ -498,6 +521,22 @@ def enable_weight_access_and_writeback(module, root_model, name_to_module: dict yield +@contextmanager +def persistent_materialization(layer): + """Keep all layer weights materialized on GPU for the duration. + + Suppresses per-forward weight transfers so that N calibration batches + pay the cost of one load/unload instead of N. + + - **FSDP2**: patches ``FSDPParamGroup.unshard/reshard`` to no-ops, then + gathers weights once via ``enable_weight_access_and_writeback``. + - **Accelerate**: materializes weights and sets ``hook.offload = False`` + so per-forward hooks skip materialization/offloading. + """ + with _disable_fsdp_unshard_reshard(layer), enable_weight_access_and_writeback(layer, layer): + yield + + def get_quantizer_state_dict(model: nn.Module): """Get the state dict of the quantizers in the model.""" # We should not call model.state_dict() here. @@ -607,6 +646,24 @@ def _init_mp_dtypes(self) -> None: ) +@contextmanager +def _disable_fsdp_unshard_reshard(layer): + """Disable FSDP2 unshard/reshard if *layer* is FSDP-wrapped.""" + if isinstance(layer, FSDPModule): + _pg_cls = torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup + orig_unshard = _pg_cls.unshard + orig_reshard = _pg_cls.reshard + _pg_cls.unshard = lambda self, async_op=False: None + _pg_cls.reshard = lambda self: None + try: + yield + finally: + _pg_cls.unshard = orig_unshard + _pg_cls.reshard = orig_reshard + else: + yield + + def get_prefixed_param_names(parent_model, target_module): """Get parameter names for a target module prefixed with the parent model name. diff --git a/modelopt/torch/quantization/utils/layerwise_calib.py b/modelopt/torch/quantization/utils/layerwise_calib.py new file mode 100644 index 0000000000..aed403ad87 --- /dev/null +++ b/modelopt/torch/quantization/utils/layerwise_calib.py @@ -0,0 +1,684 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layerwise calibration layer patching, activation capture, and checkpoint save/resume. + +This module provides :class:`LayerActivationCollector`, a stateful helper that +patches decoder layers with a skip / run / capture strategy for efficient +layer-by-layer calibration, and :class:`_CheckpointState` for persisting +per-layer calibration progress to disk. +""" + +from __future__ import annotations + +import json +import os +import shutil +from collections import deque +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import torch +import torch.nn as nn + +from modelopt.torch.utils import distributed as dist +from modelopt.torch.utils import print_rank_0 +from modelopt.torch.utils.network import ( + bind_forward_method, + get_module_device, + unpatch_forward_method, +) + +if TYPE_CHECKING: + from modelopt.torch.opt.searcher import ForwardLoop + + +class _EarlyStopForwardError(Exception): + """Raised to halt the forward pass after capturing layer inputs.""" + + +@dataclass +class _LayerCalibState: + """Mutable per-layer state used during layerwise calibration. + + Attached to each decoder layer as ``_layerwise_calib`` and accessed by the + patched forward to decide skip / run / capture / original behaviour. 
+ """ + + mode: str = "original" + name: str = "" + cached_inputs: deque = field(default_factory=deque) + collected_inputs: list = field(default_factory=list) + output_meta: tuple | None = None + + +class _SkipLayer(nn.Module): + """Parameter-free stand-in for a fully calibrated decoder layer. + + Replaces the real layer in the ModuleList so that framework hooks + (accelerate, FSDP2, etc.) have no parameters to transfer. Holds a + reference to the original layer for restoration during cleanup. + """ + + def __init__(self, original: nn.Module): + super().__init__() + # Bypass nn.Module.__setattr__ to avoid registering original as a submodule. + object.__setattr__(self, "_original", original) + self._layerwise_calib = _LayerCalibState(mode="skip") + + _PROXY_BLOCKLIST = frozenset({"_hf_hook", "_old_forward"}) + + def __getattr__(self, name: str): + # Proxy non-special attribute lookups to the original layer so that + # parent-model code that accesses layer-level attributes (e.g., + # NemotronH's ``block_type``) still works when the layer is replaced + # with a _SkipLayer. Accelerate hook attrs are blocked so the + # framework does not attempt to manage this parameter-free stand-in. + try: + return super().__getattr__(name) + except AttributeError: + if name in self._PROXY_BLOCKLIST: + raise + return getattr(object.__getattribute__(self, "_original"), name) + + def forward(self, *args, **kwargs): + return LayerActivationCollector._zeros_from_meta( + self._original._layerwise_calib.output_meta + ) + + +class LayerActivationCollector: + """Collects layer activations for layerwise (layer-by-layer) calibration. + + Each decoder layer is patched with a unified forward whose behaviour is + governed by a per-layer :class:`_LayerCalibState`: + + * **skip** — return a zero-filled dummy whose shape and type match the + layer's real output (reconstructed from lightweight metadata). No + computation is performed. 
The correctly shaped dummy ensures un-patched + inter-layer operations in the parent forward (e.g. LayerNorm, tuple + unpacking) do not raise shape or type errors. + * **run** — replay previously captured inputs through the original forward, + ignoring whatever the parent passes in. Only the just-calibrated layer + uses this mode, so its output reflects updated weights. + * **capture** — record ``(args, kwargs)`` and raise + ``_EarlyStopForwardError`` to halt the forward pass early. + * **original** — call the original forward unchanged. + + Because the *run* layer discards upstream values, skip-layer outputs are + never consumed for real computation. + """ + + _decoder_layer_support: list[tuple[Any, Any]] = [] + _LAYER_ATTR = "_layerwise_calib" + + def __init__(self, model: nn.Module): + """Initialize the collector for the given model.""" + self.model = model + self._decoder_layers: nn.ModuleList | None = None + self._layer_to_idx: dict[nn.Module, int] = {} + self._patched = False + + def _swap_to_dummy(self, idx: int): + """Replace decoder layer *idx* with a parameter-free dummy. + + ``output_meta`` is intentionally preserved on the original layer: the + ``_SkipLayer`` reads it to produce correctly shaped zero-filled outputs + for the parent forward pass. 
+ """ + assert self._decoder_layers is not None + layer = self._decoder_layers[idx] + layer._layerwise_calib.mode = "skip" + layer._layerwise_calib.cached_inputs.clear() + self._decoder_layers[idx] = _SkipLayer(layer) + + @staticmethod + def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None: + """Return decoder layers supported by layerwise calibration.""" + for is_supported, discoverer in LayerActivationCollector._decoder_layer_support: + if not is_supported(model): + continue + decoder_layers = discoverer(model) + if decoder_layers is not None: + return decoder_layers + return None + + @staticmethod + def is_supported(model: nn.Module) -> bool: + """Whether the model supports decoder-layer layerwise calibration.""" + return LayerActivationCollector.get_decoder_layers(model) is not None + + @classmethod + def register_decoder_layer_support(cls, is_supported: Any, discoverer: Any): + """Register a (predicate, discoverer) pair for decoder-layer detection.""" + entry = (is_supported, discoverer) + if entry not in cls._decoder_layer_support: + cls._decoder_layer_support.append(entry) + + @staticmethod + def _extract_output_meta(output): + """Extract lightweight (shape, dtype, device) metadata from a layer output. + + Recursively handles tensors, tuples, lists, and non-tensor values (e.g. None). + The returned structure can be passed to ``_zeros_from_meta`` to reconstruct a + zero-filled output with identical shape and type. 
+ """ + if isinstance(output, torch.Tensor): + return ("tensor", output.shape, output.dtype, output.device) + if isinstance(output, tuple): + return ( + "tuple", + tuple(LayerActivationCollector._extract_output_meta(o) for o in output), + ) + if isinstance(output, list): + return ("list", [LayerActivationCollector._extract_output_meta(o) for o in output]) + return ("other", output) + + @staticmethod + def _zeros_from_meta(meta): + """Reconstruct a zero-filled output from metadata produced by ``_extract_output_meta``.""" + tag = meta[0] + if tag == "tensor": + _, shape, dtype, device = meta + return torch.zeros(shape, dtype=dtype, device=device) + if tag == "tuple": + return tuple(LayerActivationCollector._zeros_from_meta(m) for m in meta[1]) + if tag == "list": + return [LayerActivationCollector._zeros_from_meta(m) for m in meta[1]] + # "other" values are lightweight non-tensors (e.g. None, small scalars). + # Returned directly (not copied); safe because skip-mode outputs are + # immediately discarded by the downstream run-mode layer. + return meta[1] + + def _patch_all_layers(self, decoder_layers: nn.ModuleList | None = None): + """Bind the unified forward to every decoder layer and the model. Called once. + + Args: + decoder_layers: Pre-resolved decoder layers. If *None*, layers are + discovered via :meth:`get_decoder_layers`. + """ + + def _patched_forward(self, *args, **kwargs): + info: _LayerCalibState = self._layerwise_calib + + if info.mode == "skip": + if info.output_meta is None: + raise RuntimeError( + f"Layer {info.name} is in 'skip' mode but has no output_meta. " + "This indicates a state-machine bug: the layer should have run " + "in 'run' mode (which sets output_meta) before transitioning to 'skip'." + ) + return LayerActivationCollector._zeros_from_meta(info.output_meta) + + if info.mode == "run": + assert info.cached_inputs, ( + f"Layer {info.name} is in 'run' mode but has no cached inputs to replay." 
+ ) + real_args, real_kwargs = info.cached_inputs.popleft() + output = self._original_forward(*real_args, **real_kwargs) + info.output_meta = LayerActivationCollector._extract_output_meta(output) + return output + + if info.mode == "capture": + info.collected_inputs.append((args, kwargs)) + raise _EarlyStopForwardError() + + return self._original_forward(*args, **kwargs) + + if decoder_layers is not None: + self._decoder_layers = decoder_layers + else: + self._decoder_layers = self.get_decoder_layers(self.model) + assert self._decoder_layers is not None + + self._layer_to_idx = {layer: i for i, layer in enumerate(self._decoder_layers)} + module_to_name = {m: name for name, m in self.model.named_modules()} + + try: + for layer in self._decoder_layers: + layer._layerwise_calib = _LayerCalibState( + name=module_to_name.get(layer, type(layer).__name__), + ) + bind_forward_method(layer, _patched_forward, "_original_forward") + + def _early_stop_forward(module_self, *args, **kwargs): + try: + return module_self._original_forward(*args, **kwargs) + except _EarlyStopForwardError: + return None + + bind_forward_method(self.model, _early_stop_forward, "_original_forward") + except Exception: + self._cleanup_layers() + raise + + self._patched = True + + def _cleanup_layers(self): + """Best-effort cleanup of any patched layers and model forward.""" + if self._decoder_layers is not None: + for idx, layer in enumerate(self._decoder_layers): + if isinstance(layer, _SkipLayer): + self._decoder_layers[idx] = layer._original + + if hasattr(self.model, "_original_forward"): + unpatch_forward_method(self.model, "_original_forward") + + if self._decoder_layers is not None: + for layer in self._decoder_layers: + if hasattr(layer, "_original_forward"): + unpatch_forward_method(layer, "_original_forward") + if hasattr(layer, self._LAYER_ATTR): + delattr(layer, self._LAYER_ATTR) + + def _unpatch_all_layers(self): + """Restore original forwards and clean up state attributes. 
Called once.""" + if not self._patched: + return + self._cleanup_layers() + self._patched = False + + def _set_layer_states(self, layer_idx: int): + """Transition layer modes for the next calibration step. + + When calibrating layer *i*, three transitions happen: + + * Layer ``i - 2`` → **skip** (fully done, free its cached inputs). + * Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights). + * Layer ``i`` → **capture** (record inputs, then early-stop). + """ + assert self._decoder_layers is not None + + if layer_idx > 1: + idx = layer_idx - 2 + if not isinstance(self._decoder_layers[idx], _SkipLayer): + self._swap_to_dummy(idx) + + if layer_idx > 0: + prev = self._decoder_layers[layer_idx - 1]._layerwise_calib + if not prev.collected_inputs: + raise RuntimeError( + f"Layer {layer_idx - 1} ({prev.name!r}) has no collected inputs to replay. " + "Layers must be calibrated sequentially — ensure get_input_activations() " + "was called for every preceding layer in order." 
+ ) + prev.mode = "run" + prev.cached_inputs = deque(prev.collected_inputs) + prev.collected_inputs = [] + + cur = self._decoder_layers[layer_idx]._layerwise_calib + cur.mode = "capture" + cur.collected_inputs = [] + + def _log_layer_summary(self, layer_idx: int): + """Log a one-line summary of layer modes for the current calibration step.""" + assert self._decoder_layers is not None + n = len(self._decoder_layers) + groups: dict[str, list[int]] = {} + for i, layer in enumerate(self._decoder_layers): + mode = layer._layerwise_calib.mode + if mode in ("skip", "run", "capture"): + groups.setdefault(mode, []).append(i + 1) + + parts = [] + for mode in ("skip", "run", "capture"): + if mode not in groups: + continue + ids = groups[mode] + parts.append(f"{mode}: {len(ids)}" if mode == "skip" else f"{mode}: {ids}") + print_rank_0(f"Calibrating layer {layer_idx + 1}/{n} | {' | '.join(parts)}") + + @torch.no_grad() + def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: + """Collect input activations for *layer* by running a full model forward. + + Layers before the target are skipped or re-run (if just calibrated), the + target layer captures its inputs, and an early-stop prevents unnecessary + computation beyond the target. + + :meth:`_patch_all_layers` must be called before this method. + + Note: the model forward returns ``None`` for every batch during capture + (because ``_EarlyStopForwardError`` short-circuits the forward pass). + Callers should not rely on the model's return value within *forward_loop*. + """ + if not self._patched: + raise RuntimeError( + "get_input_activations() requires _patch_all_layers() to be called first." + ) + layer_idx = self._layer_to_idx[layer] + self._set_layer_states(layer_idx) + self._log_layer_summary(layer_idx) + + info = layer._layerwise_calib + try: + forward_loop(self.model) + except Exception: + # Reset the current layer so subsequent calls don't see stale state. 
+ info.mode = "original" + info.collected_inputs = [] + raise + + if not info.collected_inputs: + info.mode = "original" + raise RuntimeError( + f"Layer {info.name!r} collected no inputs during forward_loop. " + "The forward loop did not reach this layer — check that forward_loop() " + "actually calls the model and that the layer is in the forward path." + ) + + inputs = list(info.collected_inputs) + # Reset to original so calib_func can call the layer's real forward + # directly. The layer will transition to run → skip in subsequent + # iterations via _set_layer_states. + info.mode = "original" + return inputs + + def get_first_layer_inputs( + self, + start_layer: int, + resumed_inputs: list | None, + forward_loop: ForwardLoop, + ) -> list: + """Get inputs for the first layer to calibrate, handling resume. + + If *resumed_inputs* is provided, sets skip mode on layers ``0..start_layer-1`` + and seeds the start layer's ``collected_inputs`` for subsequent + ``cache_outputs_for_next_layer_calib`` calls. Otherwise, captures inputs + via a normal forward pass. + """ + assert self._decoder_layers is not None + + if resumed_inputs is not None: + print_rank_0(f"Calibrating layer {start_layer + 1} (resumed)") + for i in range(start_layer): + self._swap_to_dummy(i) + layer = self._decoder_layers[start_layer] + layer._layerwise_calib.collected_inputs = resumed_inputs + layer._layerwise_calib.mode = "original" + return resumed_inputs + + return self.get_input_activations(self._decoder_layers[start_layer], forward_loop) + + @torch.no_grad() + def cache_outputs_for_next_layer_calib( + self, layer: torch.nn.Module, forward_loop: ForwardLoop + ) -> list: + """Run a forward pass after calibrating *layer* to capture the next layer's inputs. + + This puts *layer* into "run" mode (setting its ``output_meta``) and the + next layer into "capture" mode, then runs *forward_loop*. Returns the + captured inputs for the next layer. + + Must be called only when a next layer exists (i.e. 
*layer* is not the + last decoder layer). + """ + assert self._decoder_layers is not None + layer_idx = self._layer_to_idx[layer] + next_idx = layer_idx + 1 + assert next_idx < len(self._decoder_layers), "No next layer to capture inputs for." + from .core_utils import persistent_materialization + + next_layer = self._decoder_layers[next_idx] + with persistent_materialization(layer): + return self.get_input_activations(next_layer, forward_loop) + + +def _move_to_device(obj: Any, device: torch.device) -> Any: + """Recursively move tensors to *device*. Non-tensors are returned as-is.""" + if isinstance(obj, torch.Tensor): + return obj.to(device) + if isinstance(obj, dict): + return {k: _move_to_device(v, device) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + moved = [_move_to_device(v, device) for v in obj] + return type(obj)(moved) + return obj + + +def _remap_output_metadata_device(meta: tuple, device: torch.device) -> tuple: + """Patch the device field inside output_meta tuples so _zeros_from_meta uses *device*.""" + tag = meta[0] + if tag == "tensor": + _, shape, dtype, _old_device = meta + return ("tensor", shape, dtype, device) + if tag == "tuple": + return ("tuple", tuple(_remap_output_metadata_device(m, device) for m in meta[1])) + if tag == "list": + return ("list", [_remap_output_metadata_device(m, device) for m in meta[1]]) + return meta + + +def _read_manifest(checkpoint_dir: str) -> dict | None: + """Read manifest.json from *checkpoint_dir*. 
Returns None if missing or corrupt.""" + path = os.path.join(checkpoint_dir, "manifest.json") + if not os.path.isfile(path): + return None + try: + with open(path) as f: + return json.load(f) + except (json.JSONDecodeError, OSError): + return None + + +def _write_manifest(checkpoint_dir: str, last_completed_layer: int, num_layers: int) -> None: + """Atomically write manifest.json.""" + path = os.path.join(checkpoint_dir, "manifest.json") + tmp = path + ".tmp" + with open(tmp, "w") as f: + json.dump( + {"last_completed_layer": last_completed_layer, "num_layers": num_layers}, + f, + ) + os.replace(tmp, path) + + +def _layer_dir(checkpoint_dir: str, idx: int) -> str: + return os.path.join(checkpoint_dir, f"layer_{idx:04d}") + + +def _save_layer( + checkpoint_dir: str, + idx: int, + weights: dict, + qstate: dict, + output_meta: tuple, + next_inputs: list | None, + num_layers: int, +) -> None: + """Save a single layer checkpoint and update the manifest atomically.""" + d = _layer_dir(checkpoint_dir, idx) + if os.path.isdir(d): + shutil.rmtree(d) + os.makedirs(d) + torch.save(weights, os.path.join(d, "weights.pt")) + torch.save(qstate, os.path.join(d, "quantizer_state.pt")) + torch.save(output_meta, os.path.join(d, "output_meta.pt")) + if next_inputs is not None: + torch.save(next_inputs, os.path.join(d, "next_inputs.pt")) + _write_manifest(checkpoint_dir, idx, num_layers) + + +def detect_resume_point(checkpoint_dir: str) -> tuple[int, dict] | None: + """Detect where to resume from an existing checkpoint directory. + + Returns ``(start_layer, manifest)`` if there is work to resume, + or ``None`` if the directory is empty, corrupt, or calibration was already complete. 
+ """ + manifest = _read_manifest(checkpoint_dir) + if manifest is None: + return None + last = manifest.get("last_completed_layer") + total = manifest.get("num_layers") + if last is None or total is None: + return None + if last + 1 >= total: + return None + return (last + 1, manifest) + + +class _CheckpointState: + """Manages checkpoint save and restore for layerwise calibration. + + Handles both saving per-layer checkpoints during calibration and + restoring from a previous partial run. + + .. todo:: + Support distributed checkpoint save/restore for FSDP2: + use ``torch.distributed.checkpoint`` (or save only from rank 0 + barrier) + and broadcast restored state to all ranks during resume. + """ + + def __init__(self, checkpoint_dir: str, num_layers: int, start_layer: int = 0): + if dist.is_initialized() and dist.size() > 1: + raise RuntimeError( + "Layerwise calibration checkpointing is not supported in " + "multi-process distributed jobs (e.g. FSDP2). " + "Use single-process calibration or disable checkpointing." + ) + + self.checkpoint_dir = checkpoint_dir + self.num_layers = num_layers + self.start_layer = start_layer + + @classmethod + def from_folder(cls, checkpoint_dir: str | None, num_layers: int) -> _CheckpointState | None: + """Create from folder. Detects resume point. Returns None if no checkpoint_dir.""" + if not checkpoint_dir: + return None + os.makedirs(checkpoint_dir, exist_ok=True) + info = detect_resume_point(checkpoint_dir) + if info is not None: + manifest_num_layers = info[1].get("num_layers") + if manifest_num_layers is not None and manifest_num_layers != num_layers: + raise ValueError( + f"Checkpoint num_layers mismatch: manifest has {manifest_num_layers} " + f"but model has {num_layers}. Use a fresh checkpoint directory." 
+ ) + start = info[0] if info else 0 + if start > 0: + print_rank_0( + f"Checkpoint: resuming layerwise calibration from layer {start}/{num_layers}" + ) + return cls(checkpoint_dir, num_layers, start_layer=start) + + def setup_resume(self, layers: nn.ModuleList) -> list | None: + """Load output_meta for skip layers 0..K-1, return next_inputs for layer K. + + Sets ``output_meta`` on each already-calibrated layer so that + skip mode can produce correctly shaped dummy outputs. + """ + if self.start_layer == 0: + return None + + last_ckpt = self.start_layer - 1 + + for i in range(self.start_layer): + d = _layer_dir(self.checkpoint_dir, i) + # weights_only=False is safe: file is internally generated by _save_layer, not user-supplied + meta = torch.load( + os.path.join(d, "output_meta.pt"), map_location="cpu", weights_only=False + ) + layer_device = get_module_device(layers[i]) + meta = _remap_output_metadata_device(meta, layer_device) + layers[i]._layerwise_calib.output_meta = meta + + d = _layer_dir(self.checkpoint_dir, last_ckpt) + next_inputs_path = os.path.join(d, "next_inputs.pt") + if not os.path.isfile(next_inputs_path): + raise FileNotFoundError(f"Cannot resume: next_inputs.pt missing for layer {last_ckpt}") + # weights_only=False is safe: file is internally generated by _save_layer, not user-supplied + next_inputs = torch.load(next_inputs_path, map_location="cpu", weights_only=False) + resume_device = get_module_device(layers[self.start_layer]) + next_inputs = _move_to_device(next_inputs, resume_device) + return next_inputs + + def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None: + """Restore weights and quantizer state for layers 0..K-1 after the calibration loop.""" + from modelopt.torch.quantization.config import QuantizeConfig + from modelopt.torch.quantization.conversion import restore_quantizer_state + from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback + + if self.start_layer == 0: + return + + 
dummy_config = QuantizeConfig() + name_to_module = dict(model.named_modules()) + for i in range(self.start_layer): + layer = layers[i] + d = _layer_dir(self.checkpoint_dir, i) + + # Resolve layer_device and load inside the context so params are + # materialized — otherwise get_module_device can return meta. + with enable_weight_access_and_writeback(layer, model, name_to_module): + layer_device = get_module_device(layer) + # weights_only=False is safe: files are internally generated by _save_layer + qstate = torch.load( + os.path.join(d, "quantizer_state.pt"), + map_location=layer_device, + weights_only=False, + ) + weights = torch.load( + os.path.join(d, "weights.pt"), + map_location=layer_device, + weights_only=False, + ) + restore_quantizer_state(layer, dummy_config, {"quantizer_state": qstate}) + layer.load_state_dict(weights, strict=False, assign=True) + + print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers") + + def save( + self, + layer_idx: int, + layer: nn.Module, + model: nn.Module, + layers: nn.ModuleList, + next_layer_inputs: list | None = None, + ) -> None: + """Snapshot layer state and write checkpoint to disk in one step. + + Args: + layer_idx: Index of the layer just calibrated. + layer: The layer module (weights may be on GPU or managed by accelerate/FSDP2). + model: The full model (needed for ``enable_weight_access_and_writeback``). + layers: The decoder layer list (to read ``output_meta``). + next_layer_inputs: Inputs for the next layer (``None`` for the final layer). 
+ """ + from modelopt.torch.quantization.conversion import quantizer_state + from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback + + _cpu = torch.device("cpu") + with enable_weight_access_and_writeback(layer, model): + weights = _move_to_device(layer.state_dict(), _cpu) + qstate = _move_to_device(quantizer_state(layer), _cpu) + + output_meta = getattr(layer._layerwise_calib, "output_meta", None) + if output_meta is None: + # Placeholder for the last layer: output_meta is never used for skip mode + # since there is no subsequent layer that needs a correctly shaped dummy output. + output_meta = LayerActivationCollector._extract_output_meta(torch.zeros(1)) + + _save_layer( + self.checkpoint_dir, + layer_idx, + weights, + qstate, + _move_to_device(output_meta, _cpu), + _move_to_device(next_layer_inputs, _cpu) if next_layer_inputs is not None else None, + self.num_layers, + ) + suffix = " (final)" if next_layer_inputs is None else "" + print_rank_0(f"Checkpoint: saved layer {layer_idx}{suffix}") diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 1e9a7fbbbd..01cb3abe88 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -601,16 +601,28 @@ def _forward_loop( dataloader: DataLoader containing the batched input data allowed_non_tensor_keys: Set of key names whose values may be non-tensor types """ - with torch.no_grad(): - is_enc_dec = model_type_is_enc_dec(model) - infer_method = model.generate if is_enc_dec else model.forward - max_working_batch_size = None # Initialize max working batch size as None + # Disable KV caching during calibration — it is unnecessary overhead and causes + # correctness issues with hybrid Mamba/attention models whose cache state is mutated + # in-place (e.g., NemotronH). 
+ config = getattr(model, "config", None) + prev_use_cache = getattr(config, "use_cache", None) + if config is not None and prev_use_cache is not None: + config.use_cache = False - for _, data in enumerate(tqdm(dataloader)): - # Process batch and update max working batch size - max_working_batch_size = _process_batch( - data, infer_method, max_working_batch_size, allowed_non_tensor_keys - ) + try: + with torch.no_grad(): + is_enc_dec = model_type_is_enc_dec(model) + infer_method = model.generate if is_enc_dec else model.forward + max_working_batch_size = None # Initialize max working batch size as None + + for _, data in enumerate(tqdm(dataloader)): + # Process batch and update max working batch size + max_working_batch_size = _process_batch( + data, infer_method, max_working_batch_size, allowed_non_tensor_keys + ) + finally: + if config is not None and prev_use_cache is not None: + config.use_cache = prev_use_cache def create_forward_loop( diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index b54332375b..440ca522d1 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -90,12 +90,43 @@ def is_parallel(model: nn.Module) -> bool: return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)) +def _get_execution_device_from_hook(module: nn.Module) -> torch.device | None: + """Extract the execution device from an accelerate ``_hf_hook``, if present. + + Handles both ``AlignDevicesHook`` (direct) and ``SequentialHook`` (which + may wrap one or more ``AlignDevicesHook`` instances). Returns ``None`` + when no hook is found or the hook carries no ``execution_device``. 
+ """ + hook = getattr(module, "_hf_hook", None) + if hook is None: + return None + + dev = getattr(hook, "execution_device", None) + if dev is not None: + return torch.device("cuda", dev) if isinstance(dev, int) else torch.device(dev) + + for h in getattr(hook, "hooks", ()): + dev = getattr(h, "execution_device", None) + if dev is not None: + return torch.device("cuda", dev) if isinstance(dev, int) else torch.device(dev) + + return None + + def get_module_device(module: nn.Module) -> torch.device: - """Get the device of a PyTorch module.""" + """Get the device of a PyTorch module. + + For modules managed by accelerate (``_hf_hook``), returns the hook's + ``execution_device`` which is the authoritative device even when + parameters are offloaded to CPU/meta between forward calls. + """ + hook_device = _get_execution_device_from_hook(module) + if hook_device is not None: + return hook_device + try: return next(module.parameters()).device except StopIteration: - # For modules without parameters return torch.device("cpu") @@ -590,21 +621,29 @@ def get_unwrapped_name(name: str, model: nn.Module | None = None) -> str: @contextmanager def temporarily_remove_accelerate_hook(module): - """Context manager to temporarily remove accelerate hook from a module.""" - accelerate_hook = None - if hasattr(module, "_hf_hook"): - # A module with forward method patched by accelerate - from accelerate.hooks import add_hook_to_module, remove_hook_from_module + """Context manager to temporarily bypass the accelerate hook on a module. + + Swaps ``module.forward`` with the pre-hook forward (``_old_forward``) so + that code inside the context sees the un-hooked forward. On exit the + hook-wrapped forward is restored and ``_old_forward`` is updated to + reflect any changes made inside the context. 
- accelerate_hook = module._hf_hook - remove_hook_from_module(module) + This avoids ``remove_hook_from_module`` / ``add_hook_to_module`` entirely, + sidestepping ``init_hook`` which would call ``set_module_tensor_to_device`` + and fail when newly-added quantizer modules have weights on the meta device. + """ + hooked_forward = None + cached_old_forward = None + if hasattr(module, "_hf_hook"): + hooked_forward = module.forward + cached_old_forward = module._old_forward + module.forward = cached_old_forward try: yield finally: - if accelerate_hook is not None: - from accelerate.hooks import add_hook_to_module - - add_hook_to_module(module, accelerate_hook) + if hooked_forward is not None: + module._old_forward = module.forward + module.forward = hooked_forward def bind_forward_method( diff --git a/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml index 6fe4a8c3d1..862929ef34 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml @@ -15,7 +15,7 @@ metadata: recipe_type: ptq - description: NVFP4 MLP/MoE weight only (W4A16), FP8 KV cache, max calibration. + description: NVFP4 W4A4, FP8 KV cache, max calibration. quantize: algorithm: max quant_cfg: diff --git a/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml b/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml index a62051b659..99098c9d6d 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml @@ -15,11 +15,12 @@ metadata: recipe_type: ptq - description: NVFP4 weight and activation (W4A4), gptq sequential calibration. + description: NVFP4 weight and activation (W4A4), gptq layerwise calibration. 
quantize: algorithm: method: gptq - use_sequential: true + layerwise: true + layerwise_checkpoint_dir: output/layerwise_ckpts/ quant_cfg: - quantizer_name: '*' enable: false diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml index cc332733a0..220d062232 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml @@ -15,9 +15,12 @@ metadata: recipe_type: ptq - description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max calibration. + description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max layerwise calibration. quantize: - algorithm: max + algorithm: + method: max + # Max calibration is fast and does not typically need checkpointing. + layerwise: true quant_cfg: - quantizer_name: '*' enable: false diff --git a/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py b/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py index 8ee71ed453..021a0b6bc0 100644 --- a/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py +++ b/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py @@ -12,21 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import copy from copy import deepcopy import pytest import torch -from _test_utils.torch.transformers_models import create_tiny_llama_dir -from transformers import AutoModelForCausalLM +import transformers +from _test_utils.torch.transformers_models import create_tiny_llama_dir, create_tiny_qwen3_moe_dir +from accelerate import init_empty_weights, load_checkpoint_and_dispatch +from transformers import AutoConfig, AutoModelForCausalLM import modelopt.torch.quantization as mtq from modelopt.torch.export import export_hf_vllm_fq_checkpoint from modelopt.torch.quantization.model_quant import fold_weight +from modelopt.torch.quantization.utils import enable_weight_access_and_writeback from modelopt.torch.utils import safe_load -@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG]) -def test_hf_vllm_export(tmp_path, quant_cfg): +def _test_hf_vllm_export(tmp_path, quant_cfg, model_dir): """Test HuggingFace model export for vLLM with fake quantization. This test verifies: @@ -36,11 +39,8 @@ def test_hf_vllm_export(tmp_path, quant_cfg): 4. Weight quantizer states are empty in saved state dict; input quantizer amaxes preserved """ - # Create a tiny LLaMA model for testing - tiny_model_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=2) - # Load the model - model = AutoModelForCausalLM.from_pretrained(tiny_model_dir) + model = AutoModelForCausalLM.from_pretrained(model_dir) model = model.cuda() model.eval() @@ -57,6 +57,23 @@ def forward_loop(model): folded_model = deepcopy(model) fold_weight(folded_model) expected_weights = {k: v for k, v in folded_model.state_dict().items() if "quantizer" not in k} + # fold_weight only applies the weight quantizer's fake-quant; it does NOT fold + # input_quantizer.pre_quant_scale into the weight. The export path does: + # w_exported = fake_quant(W) * pqs[None, :] + # for modules where input_quantizer is disabled but has pqs (AWQ weight-only). + # Apply the same pqs fold here so expected_weights matches the export output. 
+ for module_name, module in folded_model.named_modules(): + inp_q = getattr(module, "input_quantizer", None) + if ( + inp_q is not None + and not inp_q.is_enabled + and getattr(inp_q, "_pre_quant_scale", None) is not None + ): + w_key = f"{module_name}.weight" if module_name else "weight" + if w_key in expected_weights: + w = expected_weights[w_key] + scale = inp_q._pre_quant_scale.squeeze().to(device=w.device) + expected_weights[w_key] = (w * scale[None, :]).to(w.dtype) del folded_model # Snapshot model state before export to verify it is not mutated @@ -111,3 +128,133 @@ def forward_loop(model): "_amax" in k for k in quantizer_state_dict_before[name] ): assert any("_amax" in k for k in state), f"input quantizer {name} should preserve _amax" + + +def _make_cpu_offloaded_model(tmp_path, num_hidden_layers=3): + """Create a tiny LLaMA model with layer 0 offloaded to CPU via accelerate.""" + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_hidden_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + device_map = { + n: 0 + for n, m in model.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + + model = load_checkpoint_and_dispatch(model, tiny_llama_dir, device_map=device_map) + return model, config, tiny_llama_dir + + +def _make_layerwise_cfg(base_cfg): + """Add layerwise=True to a quant config's algorithm field.""" + cfg = copy.deepcopy(base_cfg) + algo = cfg.get("algorithm", "max") + if isinstance(algo, str): + cfg["algorithm"] = {"method": algo, "layerwise": True} + else: + algo["layerwise"] = True + return cfg + + +@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG]) +def test_hf_vllm_export_offload(tmp_path, quant_cfg): + """Test ``inplace_mem_efficient=True`` export path on a CPU-offloaded model. 
+ + Mirrors ``test_hf_vllm_export`` but uses a CPU-offloaded model with layerwise + calibration. Skips the "model not mutated" assertion since the inplace path + is intentionally destructive. + """ + num_hidden_layers = 3 + + # Test model: CPU-offloaded, layerwise calibration + model, _config, tiny_llama_dir = _make_cpu_offloaded_model( + tmp_path / "offloaded", num_hidden_layers=num_hidden_layers + ) + model.eval() + + seq_cfg = _make_layerwise_cfg(quant_cfg) + + def forward_loop(model): + input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).cuda() + with torch.no_grad(): + model(input_ids) + + model = mtq.quantize(model, seq_cfg, forward_loop) + quantizer_state_dict_before = mtq.utils.get_quantizer_state_dict(model) + + folded_model = deepcopy(model) + with enable_weight_access_and_writeback(folded_model.model.layers[0], folded_model): + fold_weight(folded_model) + expected_weights = { + k: v.detach().clone() + for k, v in folded_model.state_dict().items() + if "quantizer" not in k + } + del folded_model + + export_dir = tmp_path / "vllm_export_offload" + export_dir.mkdir(exist_ok=True) + + # Snapshot the offloaded layer's weight before/after export to verify the + # inplace_mem_efficient path actually mutates offloaded weights (would otherwise + # be unfalsifiable if the function silently took the copy path). 
+ with enable_weight_access_and_writeback(model.model.layers[0], model): + weight_before = model.model.layers[0].self_attn.q_proj.weight.data.clone() + + export_hf_vllm_fq_checkpoint(model, export_dir=export_dir, inplace_mem_efficient=True) + + with enable_weight_access_and_writeback(model.model.layers[0], model): + weight_after = model.model.layers[0].self_attn.q_proj.weight.data.clone() + assert not torch.equal(weight_before, weight_after), ( + "inplace path must mutate offloaded layer weights" + ) + + modelopt_state_file = export_dir / "vllm_fq_modelopt_state.pth" + assert modelopt_state_file.exists(), ( + f"vllm_fq_modelopt_state.pth file should be created in {export_dir}" + ) + + hf_quant_config_file = export_dir / "hf_quant_config.json" + assert not hf_quant_config_file.exists(), ( + f"hf_quant_config.json file should not be created in {export_dir}" + ) + + model_after = AutoModelForCausalLM.from_pretrained(export_dir).cuda() + model_after.eval() + model_after_state_dict = model_after.state_dict() + for key, param in expected_weights.items(): + assert torch.allclose(param, model_after_state_dict[key], atol=1e-6), ( + f"Weight mismatch for {key}: " + f"before shape={param.shape}, after shape={model_after_state_dict[key].shape}, " + f"max diff={torch.abs(param - model_after_state_dict[key]).max()}" + ) + + quantizer_state_dict = safe_load(modelopt_state_file)["modelopt_state_weights"] + assert len(quantizer_state_dict) > 0, ( + f"modelopt_state_weights should not be empty in {modelopt_state_file}" + ) + for name, state in quantizer_state_dict.items(): + if "weight_quantizer" in name: + assert state == {}, f"weight quantizer {name} should have empty state after fold" + elif "input_quantizer" in name and any( + "_amax" in k for k in quantizer_state_dict_before[name] + ): + assert any("_amax" in k for k in state), f"input quantizer {name} should preserve _amax" + + +@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG, mtq.INT4_AWQ_CFG]) +def 
test_hf_vllm_export_tiny_llama(tmp_path, quant_cfg):
+    tiny_model_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=2)
+    _test_hf_vllm_export(tmp_path, quant_cfg, tiny_model_dir)
+
+
+@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG, mtq.INT4_AWQ_CFG])
+def test_hf_vllm_export_tiny_qwen3_moe(tmp_path, quant_cfg):
+    if quant_cfg == mtq.INT4_AWQ_CFG and transformers.__version__.startswith("5."):
+        pytest.skip("INT4_AWQ_CFG is not supported for Qwen3 MoE in transformers 5.x")
+    tiny_model_dir = create_tiny_qwen3_moe_dir(tmp_path, num_hidden_layers=2)
+    _test_hf_vllm_export(tmp_path, quant_cfg, tiny_model_dir)
diff --git a/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py b/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py
index 8ed1039e59..49e74e5851 100644
--- a/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py
+++ b/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py
@@ -13,9 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy +import json +import os +import shutil + import pytest import torch -from _test_utils.torch.quantization.quantize_common import INT4_AWQ_CLIP_CFG from _test_utils.torch.transformers_models import create_tiny_llama_dir from accelerate import init_empty_weights, load_checkpoint_and_dispatch from transformers import AutoConfig, AutoModelForCausalLM @@ -25,19 +29,11 @@ enable_weight_access_and_writeback, is_quantized_linear, ) +from modelopt.torch.quantization.utils.layerwise_calib import _layer_dir -@pytest.mark.parametrize( - "quant_cfg", - [ - mtq.INT4_AWQ_CFG, - mtq.INT8_SMOOTHQUANT_CFG, - INT4_AWQ_CLIP_CFG, - mtq.NVFP4_SVDQUANT_DEFAULT_CFG, - mtq.INT8_DEFAULT_CFG, - ], -) -def test_cpu_offloaded_tinyllama(tmp_path, quant_cfg): +def test_cpu_offloaded_tinyllama(tmp_path): + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=2) config = AutoConfig.from_pretrained(tiny_llama_dir) @@ -73,3 +69,575 @@ def test_cpu_offloaded_tinyllama(tmp_path, quant_cfg): assert torch.allclose(module.weight, model_ref.get_submodule(name).weight) assert torch.allclose(output_ref.logits, output_test.logits) + + +def _make_cpu_offloaded_model(tmp_path, num_hidden_layers=3): + """Create a tiny LLaMA model with layer 0 offloaded to CPU via accelerate.""" + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_hidden_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + device_map = { + n: 0 + for n, m in model.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + + model = load_checkpoint_and_dispatch(model, tiny_llama_dir, device_map=device_map) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + return model, config, tiny_llama_dir, inputs + + +def _make_layerwise_cfg(base_cfg): + """Add layerwise=True to a quant config's algorithm 
field.""" + cfg = copy.deepcopy(base_cfg) + algo = cfg.get("algorithm", "max") + if isinstance(algo, str): + cfg["algorithm"] = {"method": algo, "layerwise": True} + else: + algo["layerwise"] = True + return cfg + + +def _make_layerwise_checkpoint_cfg(base_cfg, checkpoint_dir): + """Add layerwise=True and layerwise_checkpoint_dir to a quant config's algorithm field.""" + cfg = _make_layerwise_cfg(base_cfg) + cfg["algorithm"]["layerwise_checkpoint_dir"] = checkpoint_dir + return cfg + + +@pytest.mark.parametrize("use_checkpoint", [False, True], ids=["no_ckpt", "ckpt"]) +def test_layerwise_calibrate_cpu_offloaded(tmp_path, use_checkpoint): + """Layerwise calibration on CPU-offloaded model matches GPU-only reference.""" + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + if use_checkpoint: + ckpt_dir = str(tmp_path / "seq_ckpt") + seq_cfg = _make_layerwise_checkpoint_cfg(quant_cfg, ckpt_dir) + else: + seq_cfg = _make_layerwise_cfg(quant_cfg) + + # Reference: GPU-only model with layerwise calibration + ref_cfg = _make_layerwise_cfg(quant_cfg) + model_ref = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, torch_dtype=config.torch_dtype + ).cuda() + mtq.quantize(model_ref, ref_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Test: CPU-offloaded model + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model = load_checkpoint_and_dispatch(model, tiny_llama_dir, device_map=device_map) + + mtq.quantize(model, seq_cfg, lambda model: model(inputs)) + output_test = model(inputs) + + for name, module in model.named_modules(): + if is_quantized_linear(module): 
+ with enable_weight_access_and_writeback(module, model): + assert torch.allclose(module.weight, model_ref.get_submodule(name).weight), ( + f"Weight mismatch at {name}" + ) + + assert torch.allclose(output_ref.logits, output_test.logits) + + if use_checkpoint: + manifest_path = os.path.join(ckpt_dir, "manifest.json") + assert os.path.isfile(manifest_path) + with open(manifest_path) as f: + manifest = json.load(f) + assert manifest["last_completed_layer"] == num_layers - 1 + assert manifest["num_layers"] == num_layers + + +def test_sequential_checkpoint_resume_cpu_offloaded(tmp_path): + """Resume from a partial checkpoint on a CPU-offloaded model matches a full run.""" + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + ckpt_dir = str(tmp_path / "seq_ckpt") + seq_ckpt_cfg = _make_layerwise_checkpoint_cfg(quant_cfg, ckpt_dir) + + # Full reference run with checkpointing + with init_empty_weights(): + model_ref = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model_ref.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model_ref = load_checkpoint_and_dispatch(model_ref, tiny_llama_dir, device_map=device_map) + mtq.quantize(model_ref, seq_ckpt_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Simulate crash after layer 0 by truncating the manifest and removing later layers + last_completed_layer = 0 + manifest_path = os.path.join(ckpt_dir, "manifest.json") + with open(manifest_path, "w") as f: + json.dump({"last_completed_layer": last_completed_layer, "num_layers": num_layers}, f) + for i in range(last_completed_layer + 1, num_layers): + d = _layer_dir(ckpt_dir, i) + if os.path.isdir(d): + shutil.rmtree(d) + + # Resume from a fresh CPU-offloaded 
model + with init_empty_weights(): + model_resumed = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model_resumed.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model_resumed = load_checkpoint_and_dispatch( + model_resumed, tiny_llama_dir, device_map=device_map + ) + mtq.quantize(model_resumed, seq_ckpt_cfg, lambda model: model(inputs)) + output_resumed = model_resumed(inputs) + + assert torch.allclose(output_ref.logits, output_resumed.logits), ( + "Resumed checkpoint should produce identical output to full run" + ) + + +def test_sequential_checkpoint_resume_multi_offload(tmp_path): + """Resume with multiple layers offloaded exercises per-layer device resolution.""" + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + ckpt_dir = str(tmp_path / "seq_ckpt") + seq_ckpt_cfg = _make_layerwise_checkpoint_cfg(mtq.INT4_AWQ_CFG, ckpt_dir) + + def _make_multi_offload_model(): + with init_empty_weights(): + m = AutoModelForCausalLM.from_config(config) + dmap = { + n: 0 + for n, mod in m.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + dmap["model.layers.0"] = "cpu" + dmap["model.layers.1"] = "cpu" + return load_checkpoint_and_dispatch(m, tiny_llama_dir, device_map=dmap) + + # Full reference run + model_ref = _make_multi_offload_model() + mtq.quantize(model_ref, seq_ckpt_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Simulate crash after layer 0 + last_completed_layer = 0 + manifest_path = os.path.join(ckpt_dir, "manifest.json") + with open(manifest_path, "w") as f: + json.dump({"last_completed_layer": last_completed_layer, "num_layers": num_layers}, f) + for i in range(last_completed_layer + 1, num_layers): + d = _layer_dir(ckpt_dir, i) + if 
os.path.isdir(d): + shutil.rmtree(d) + + # Resume from fresh model with same offload layout + model_resumed = _make_multi_offload_model() + mtq.quantize(model_resumed, seq_ckpt_cfg, lambda model: model(inputs)) + output_resumed = model_resumed(inputs) + + assert torch.allclose(output_ref.logits, output_resumed.logits), ( + "Resumed checkpoint with multi-offload should match full run" + ) + + +def _make_gptq_sequential_cfg(base_cfg): + """Create a sequential GPTQ config from a base quantization config.""" + cfg = copy.deepcopy(base_cfg) + cfg["algorithm"] = {"method": "gptq", "layerwise": True} + return cfg + + +def _make_gptq_sequential_checkpoint_cfg(base_cfg, checkpoint_dir): + """Create a sequential GPTQ config with checkpoint dir.""" + cfg = _make_gptq_sequential_cfg(base_cfg) + cfg["algorithm"]["layerwise_checkpoint_dir"] = checkpoint_dir + return cfg + + +@pytest.mark.parametrize("use_checkpoint", [False, True], ids=["no_ckpt", "ckpt"]) +def test_sequential_gptq_cpu_offloaded(tmp_path, use_checkpoint): + """Sequential GPTQ (weight-modifying) on CPU-offloaded model matches GPU-only reference.""" + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + if use_checkpoint: + ckpt_dir = str(tmp_path / "gptq_ckpt") + seq_cfg = _make_gptq_sequential_checkpoint_cfg(mtq.NVFP4_AWQ_LITE_CFG, ckpt_dir) + else: + seq_cfg = _make_gptq_sequential_cfg(mtq.NVFP4_AWQ_LITE_CFG) + + # Reference: GPU-only model + ref_cfg = _make_gptq_sequential_cfg(mtq.NVFP4_AWQ_LITE_CFG) + model_ref = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, torch_dtype=config.torch_dtype + ).cuda() + mtq.quantize(model_ref, ref_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Test: CPU-offloaded model + model, _, _, _ = _make_cpu_offloaded_model(tmp_path / "offloaded", num_hidden_layers=num_layers) + 
mtq.quantize(model, seq_cfg, lambda model: model(inputs)) + output_test = model(inputs) + + for name, module in model.named_modules(): + if is_quantized_linear(module): + with enable_weight_access_and_writeback(module, model): + assert torch.allclose(module.weight, model_ref.get_submodule(name).weight), ( + f"Weight mismatch at {name}" + ) + + assert torch.allclose(output_ref.logits, output_test.logits) + + +def test_sequential_gptq_checkpoint_resume_cpu_offloaded(tmp_path): + """GPTQ checkpoint resume with CPU offloading restores modified weights correctly.""" + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + ckpt_dir = str(tmp_path / "gptq_ckpt") + seq_ckpt_cfg = _make_gptq_sequential_checkpoint_cfg(mtq.NVFP4_AWQ_LITE_CFG, ckpt_dir) + + # Full reference run with checkpointing + with init_empty_weights(): + model_ref = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model_ref.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model_ref = load_checkpoint_and_dispatch(model_ref, tiny_llama_dir, device_map=device_map) + mtq.quantize(model_ref, seq_ckpt_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Simulate crash after layer 0 + last_completed_layer = 0 + manifest_path = os.path.join(ckpt_dir, "manifest.json") + with open(manifest_path, "w") as f: + json.dump({"last_completed_layer": last_completed_layer, "num_layers": num_layers}, f) + for i in range(last_completed_layer + 1, num_layers): + d = _layer_dir(ckpt_dir, i) + if os.path.isdir(d): + shutil.rmtree(d) + + # Resume from fresh CPU-offloaded model + with init_empty_weights(): + model_resumed = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model_resumed.named_modules() + if "layers" not in n 
or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model_resumed = load_checkpoint_and_dispatch( + model_resumed, tiny_llama_dir, device_map=device_map + ) + mtq.quantize(model_resumed, seq_ckpt_cfg, lambda model: model(inputs)) + output_resumed = model_resumed(inputs) + + assert torch.allclose(output_ref.logits, output_resumed.logits), ( + "GPTQ resumed checkpoint should produce identical output to full run" + ) + + +class _TupleReturningBlock(torch.nn.Module): + """Decoder layer that returns a tuple, mimicking HuggingFace decoder layers.""" + + def __init__(self, dim=16): + super().__init__() + self.linear = torch.nn.Linear(dim, dim, bias=False) + + def forward(self, x, **kwargs): + return (self.linear(x), None) + + +class _TupleUnpackingModel(torch.nn.Module): + """Parent model that unpacks layer outputs as tuples.""" + + def __init__(self, n_layers=4, dim=16): + super().__init__() + self.layers = torch.nn.ModuleList([_TupleReturningBlock(dim) for _ in range(n_layers)]) + + def forward(self, x): + for layer in self.layers: + x, _ = layer(x) + return x + + +def test_skip_dummy_has_no_hf_hook(monkeypatch): + """Dummies must not carry _hf_hook from the original layer.""" + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + + from modelopt.torch.quantization.utils.layerwise_calib import ( + LayerActivationCollector, + _SkipLayer, + ) + + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], + ) + + model = _TupleUnpackingModel(n_layers=4, dim=16) + data = [torch.randn(2, 16)] + + for layer in model.layers: + hook = AlignDevicesHook(execution_device=torch.device("cpu")) + add_hook_to_module(layer, hook) + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + for layer in list(model.layers): + collector.get_input_activations(layer, forward_loop) + + for i in 
range(2): + dummy = model.layers[i] + assert isinstance(dummy, _SkipLayer) + assert not hasattr(dummy, "_hf_hook"), f"Dummy at {i} should not have _hf_hook" + finally: + collector._unpatch_all_layers() + + +def test_persistent_materialization_cpu_offloaded(tmp_path): + """persistent_materialization keeps CPU-offloaded weights on GPU and writes back modifications.""" + import torch.nn as nn + from accelerate.hooks import AlignDevicesHook + + from modelopt.torch.quantization.utils import persistent_materialization + + model, config, _, inputs = _make_cpu_offloaded_model(tmp_path) + offloaded_layer = model.model.layers[0] + + # Verify offloaded (meta device) + assert all(p.device.type == "meta" for p in offloaded_layer.parameters()) + + # Save reference weight + linear = None + with enable_weight_access_and_writeback(offloaded_layer, model): + linear = next(m for m in offloaded_layer.modules() if isinstance(m, nn.Linear)) + ref_weight = linear.weight.clone() + + with persistent_materialization(offloaded_layer): + # Params materialized on GPU + assert all( + p.device.type == "cuda" for p in offloaded_layer.parameters() if p.device.type != "meta" + ) + + # Run multiple forward passes (hooks don't re-offload) + for _ in range(3): + model(inputs) + + # Modify a weight + linear.weight.data.add_(1.0) + + # Verify hooks have offload=False during context + for mod in offloaded_layer.modules(): + if hasattr(mod, "_hf_hook"): + hook = mod._hf_hook + if isinstance(hook, AlignDevicesHook): + assert not hook.offload + + # After context: back to meta device (offloaded) + assert all(p.device.type == "meta" for p in offloaded_layer.parameters()) + + # Verify weight modification persisted through writeback + with enable_weight_access_and_writeback(offloaded_layer, model): + assert torch.allclose(linear.weight, ref_weight + 1.0) + + +def _make_disk_offload_device_map(model): + """Build a device_map with layer 0 on disk, everything else on GPU 0. 
+ + Ancestor modules (``""`` and ``"model"``) are excluded so that + ``dispatch_model`` does not attach a ``place_submodules=True`` hook that + would try to move disk-offloaded meta tensors to GPU (which fails because + no ``value`` is available — unlike CPU offload where weights are on CPU and + can be moved directly). + """ + device_map = { + n: 0 + for n, m in model.named_modules() + if n not in ("", "model") and ("layers" not in n or n.split("layers.")[-1].isdigit()) + } + device_map["model.layers.0"] = "disk" + return device_map + + +def _make_disk_offloaded_model(tmp_path, num_hidden_layers=3): + """Create a tiny LLaMA model with layer 0 offloaded to disk via accelerate.""" + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_hidden_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + device_map = _make_disk_offload_device_map(model) + offload_dir = str(tmp_path / "offload") + model = load_checkpoint_and_dispatch( + model, tiny_llama_dir, device_map=device_map, offload_folder=offload_dir + ) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + return model, config, tiny_llama_dir, inputs + + +def test_disk_offloaded_tinyllama(tmp_path): + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=2) + + config = AutoConfig.from_pretrained(tiny_llama_dir) + + model_ref = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, torch_dtype=config.torch_dtype + ).cuda() + inputs = torch.randint(0, model_ref.config.vocab_size, (1, 4)).cuda() + + mtq.quantize(model_ref, quant_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + device_map = _make_disk_offload_device_map(model) + offload_dir = str(tmp_path / "offload") + model = load_checkpoint_and_dispatch( + model, tiny_llama_dir, 
device_map=device_map, offload_folder=offload_dir + ) + + assert all(p.device == torch.device("meta") for p in model.model.layers[0].parameters()) + + mtq.quantize(model, quant_cfg, lambda model: model(inputs)) + output_test = model(inputs) + + for name, module in model.named_modules(): + if is_quantized_linear(module): + with enable_weight_access_and_writeback(module, model): + assert torch.allclose(module.weight, model_ref.get_submodule(name).weight) + + assert torch.allclose(output_ref.logits, output_test.logits) + + +def test_persistent_materialization_disk_offloaded(tmp_path): + """persistent_materialization keeps disk-offloaded weights on GPU and writes back modifications.""" + import torch.nn as nn + from accelerate.hooks import AlignDevicesHook + + from modelopt.torch.quantization.utils import persistent_materialization + + model, config, _, inputs = _make_disk_offloaded_model(tmp_path) + offloaded_layer = model.model.layers[0] + + # Verify offloaded (meta device) + assert all(p.device.type == "meta" for p in offloaded_layer.parameters()) + + # Save reference weight + linear = None + with enable_weight_access_and_writeback(offloaded_layer, model): + linear = next(m for m in offloaded_layer.modules() if isinstance(m, nn.Linear)) + ref_weight = linear.weight.clone() + + with persistent_materialization(offloaded_layer): + # Params materialized on GPU + assert all( + p.device.type == "cuda" for p in offloaded_layer.parameters() if p.device.type != "meta" + ) + + # Run multiple forward passes (hooks don't re-offload) + for _ in range(3): + model(inputs) + + # Modify a weight + linear.weight.data.add_(1.0) + + # Verify hooks have offload=False during context + for mod in offloaded_layer.modules(): + if hasattr(mod, "_hf_hook"): + hook = mod._hf_hook + if isinstance(hook, AlignDevicesHook): + assert not hook.offload + + # After context: back to meta device (offloaded) + assert all(p.device.type == "meta" for p in offloaded_layer.parameters()) + + # Verify weight 
modification persisted through writeback + with enable_weight_access_and_writeback(offloaded_layer, model): + assert torch.allclose(linear.weight, ref_weight + 1.0) + + +@pytest.mark.parametrize("use_checkpoint", [False, True], ids=["no_ckpt", "ckpt"]) +def test_layerwise_calibrate_disk_offloaded(tmp_path, use_checkpoint): + """Layerwise calibration on disk-offloaded model matches GPU-only reference.""" + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + if use_checkpoint: + ckpt_dir = str(tmp_path / "seq_ckpt") + seq_cfg = _make_layerwise_checkpoint_cfg(quant_cfg, ckpt_dir) + else: + seq_cfg = _make_layerwise_cfg(quant_cfg) + + # Reference: GPU-only model with layerwise calibration + ref_cfg = _make_layerwise_cfg(quant_cfg) + model_ref = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, torch_dtype=config.torch_dtype + ).cuda() + mtq.quantize(model_ref, ref_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Test: disk-offloaded model + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + device_map = _make_disk_offload_device_map(model) + offload_dir = str(tmp_path / "offload") + model = load_checkpoint_and_dispatch( + model, tiny_llama_dir, device_map=device_map, offload_folder=offload_dir + ) + + mtq.quantize(model, seq_cfg, lambda model: model(inputs)) + output_test = model(inputs) + + for name, module in model.named_modules(): + if is_quantized_linear(module): + with enable_weight_access_and_writeback(module, model): + assert torch.allclose(module.weight, model_ref.get_submodule(name).weight), ( + f"Weight mismatch at {name}" + ) + + assert torch.allclose(output_ref.logits, output_test.logits) diff --git a/tests/gpu/torch/quantization/test_fsdp2.py b/tests/gpu/torch/quantization/test_fsdp2.py index 
4889b6dc8c..c5584ece5c 100644 --- a/tests/gpu/torch/quantization/test_fsdp2.py +++ b/tests/gpu/torch/quantization/test_fsdp2.py @@ -128,3 +128,136 @@ def test_fsdp_simple_linear(dist_workers): ) def test_nested_fsdp2_backward(quant_cfg, dist_workers): dist_workers.run(partial(_test_nested_fsdp2_backward, quant_cfg=quant_cfg)) + + +class _DecoderBlock(nn.Module): + """Minimal decoder block for FSDP2 sequential tests.""" + + def __init__(self, dim=32): + super().__init__() + self.attn = nn.Linear(dim, dim, bias=False) + self.ffn = nn.Sequential( + nn.Linear(dim, dim, bias=False), nn.ReLU(), nn.Linear(dim, dim, bias=False) + ) + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + x = x + self.attn(self.norm(x)) + x = x + self.ffn(x) + return x + + +class _SimpleTransformerModel(nn.Module): + """Model with ``model.layers`` for layerwise calibration discovery.""" + + def __init__(self, n_layers=3, dim=32): + super().__init__() + self.layers = nn.ModuleList([_DecoderBlock(dim) for _ in range(n_layers)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def _test_layerwise_calibrate_fsdp2(rank, size): + """Layerwise calibration on FSDP2-wrapped model matches non-FSDP reference.""" + from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector + + dim = 32 + torch.manual_seed(1) + model = _SimpleTransformerModel(n_layers=3, dim=dim).cuda() + inputs = torch.randn(2, 2, dim).cuda() + synchronize_state_dict(model) + + # Register discoverer for our simple model + old_support = LayerActivationCollector._decoder_layer_support[:] + LayerActivationCollector._decoder_layer_support = [ + ( + lambda m: hasattr(m, "layers") and isinstance(m.layers, nn.ModuleList), + lambda m: m.layers, + ), + *old_support, + ] + + try: + # Reference: non-FSDP layerwise calibration + ref_model = copy.deepcopy(model) + seq_cfg = copy.deepcopy(mtq.INT8_DEFAULT_CFG) + seq_cfg["algorithm"] = {"method": "max", "layerwise": True} + 
mtq.quantize(ref_model, seq_cfg, lambda m: m(inputs)) + output_ref = ref_model(inputs) + + # Test: FSDP2-wrapped layerwise calibration + for layer in model.layers: + fully_shard(layer) + model = fully_shard(model) + mtq.quantize(model, seq_cfg, lambda m: m(inputs)) + output_test = model(inputs) + + assert torch.allclose(output_ref, output_test) + finally: + LayerActivationCollector._decoder_layer_support = old_support + + +def test_layerwise_calibrate_fsdp2(dist_workers): + dist_workers.run(_test_layerwise_calibrate_fsdp2) + + +def _test_persistent_materialization(rank, size): + """persistent_materialization keeps weights accessible and writes back modifications.""" + from torch.distributed.tensor import DTensor + + from modelopt.torch.quantization.utils import ( + enable_weight_access_and_writeback, + persistent_materialization, + ) + + dim = 32 + torch.manual_seed(1) + model = nn.Sequential( + nn.Sequential(nn.Linear(dim, dim), nn.Linear(dim, dim)), + nn.Sequential(nn.Linear(dim, dim), nn.Linear(dim, dim)), + ).cuda(rank) + synchronize_state_dict(model) + + fully_shard(model[0]) + fully_shard(model[1]) + model = fully_shard(model) + + layer = model[0] + inputs = torch.randn(2, dim).cuda(rank) + + # Warmup forward to trigger FSDP2's lazy_init (mirrors real usage where + # layerwise_calibrate always runs get_first_layer_inputs first). 
+ model(inputs) + + # Save reference weight (gathered) + with enable_weight_access_and_writeback(layer[0], model): + ref_weight = layer[0].weight.clone() + + # Verify sharded before context + assert isinstance(next(iter(layer.parameters())), DTensor) + + with persistent_materialization(layer): + # Params are local tensors (not DTensors) + assert not isinstance(layer[0].weight, DTensor) + assert layer[0].weight.device.type == "cuda" + + # Run multiple forward passes (FSDP hooks fire, unshard/reshard are no-ops) + for _ in range(3): + layer(inputs) + + # Modify a weight + layer[0].weight.data.add_(1.0) + + # After context: params restored to DTensors (sharded) + assert isinstance(next(iter(layer.parameters())), DTensor) + + # Verify modification persisted + with enable_weight_access_and_writeback(layer[0], model): + assert torch.allclose(layer[0].weight, ref_weight + 1.0) + + +def test_persistent_materialization(dist_workers): + dist_workers.run(_test_persistent_materialization) diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index d183855abb..2d5f9d6d70 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -219,7 +219,7 @@ def test_gptq_e2e_flow(quant_cfg): model.eval() quant_cfg = copy.deepcopy(quant_cfg) - quant_cfg["algorithm"] = {"method": "gptq", "use_sequential": True} + quant_cfg["algorithm"] = {"method": "gptq", "layerwise": True} calib_dataloader = get_dataset_dataloader( dataset_name="cnn_dailymail", tokenizer=tokenizer, diff --git a/tests/gpu/torch/quantization/test_sequential_calibrate.py b/tests/gpu/torch/quantization/test_layerwise_calibrate.py similarity index 90% rename from tests/gpu/torch/quantization/test_sequential_calibrate.py rename to tests/gpu/torch/quantization/test_layerwise_calibrate.py index ba71e896c7..d38b82f46f 100644 --- a/tests/gpu/torch/quantization/test_sequential_calibrate.py +++ 
b/tests/gpu/torch/quantization/test_layerwise_calibrate.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Integration tests for sequential_calibrate and LayerActivationCollector.""" +"""Integration tests for layerwise_calibrate and LayerActivationCollector.""" import torch import torch.nn as nn -from modelopt.torch.quantization.model_calib import sequential_calibrate -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.model_calib import layerwise_calibrate +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector class _DecoderBlock(nn.Module): @@ -101,7 +101,7 @@ def _register_test_discoverer(monkeypatch): ) -def test_seq_calib_func_called_per_layer(monkeypatch): +def test_layerwise_calib_func_called_per_layer(monkeypatch): _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=4) call_count = [0] @@ -109,7 +109,7 @@ def test_seq_calib_func_called_per_layer(monkeypatch): def counting_calib(layer, forward_loop, **kwargs): call_count[0] += 1 - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=counting_calib, @@ -118,7 +118,7 @@ def counting_calib(layer, forward_loop, **kwargs): assert call_count[0] == 4 -def test_seq_calib_func_receives_correct_layer(monkeypatch): +def test_layerwise_calib_func_receives_correct_layer(monkeypatch): _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=3) called_layers = [] @@ -126,7 +126,7 @@ def test_seq_calib_func_receives_correct_layer(monkeypatch): def track_layers(layer, forward_loop, **kwargs): called_layers.append(layer) - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=track_layers, @@ -136,7 +136,7 @@ def track_layers(layer, forward_loop, **kwargs): assert 
called_layers[i] is layer -def test_seq_calib_kwargs_forwarded(monkeypatch): +def test_layerwise_calib_kwargs_forwarded(monkeypatch): _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=2) received_kwargs = [] @@ -144,7 +144,7 @@ def test_seq_calib_kwargs_forwarded(monkeypatch): def capture_kwargs(layer, forward_loop, **kwargs): received_kwargs.append(kwargs) - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=capture_kwargs, @@ -158,7 +158,7 @@ def capture_kwargs(layer, forward_loop, **kwargs): assert kw["method"] == "max" -def test_seq_calib_layer_forward_loop_runs_all_batches(monkeypatch): +def test_layerwise_calib_layer_forward_loop_runs_all_batches(monkeypatch): """The per-layer forward loop passed to calib_func should replay all batches.""" _register_test_discoverer(monkeypatch) n_batches = 5 @@ -178,7 +178,7 @@ def counting_forward(*args, **kw): layer.forward = orig_forward batch_counts.append(counter["n"]) - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=count_batches, @@ -188,13 +188,13 @@ def counting_forward(*args, **kw): assert count == n_batches -def test_seq_calib_does_not_alter_weights(monkeypatch): - """sequential_calibrate itself should not modify model weights.""" +def test_layerwise_calib_does_not_alter_weights(monkeypatch): + """layerwise_calibrate itself should not modify model weights.""" _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=3) weights_before = {n: p.clone() for n, p in model.named_parameters()} - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=lambda layer, forward_loop, **kw: None, @@ -204,7 +204,7 @@ def test_seq_calib_does_not_alter_weights(monkeypatch): assert torch.equal(p, weights_before[n]), f"Weight {n} was modified" -def 
test_seq_calib_activations_update_across_layers(monkeypatch): +def test_layerwise_calib_activations_update_across_layers(monkeypatch): """Subsequent layers should see activations transformed by prior layers.""" _register_test_discoverer(monkeypatch) torch.manual_seed(0) @@ -228,7 +228,7 @@ def capture_forward(*args, **kw): layer_idx = list(model.layers).index(layer) layer_inputs_record[layer_idx] = activations - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: [m(t) for t in tokens], calib_func=record_inputs, @@ -240,7 +240,7 @@ def capture_forward(*args, **kw): def test_mode_transitions_across_calibration_steps(monkeypatch): - """Verify layer modes after each sequential calibration step. + """Verify layer modes after each layerwise calibration step. After get_input_activations(layers[i]) returns, the current layer is reset to 'original'. Layers further back are left in 'run' (just calibrated) or @@ -259,7 +259,7 @@ def forward_loop(m): try: def modes(): - return [model.layers[i]._seq_calib.mode for i in range(5)] + return [model.layers[i]._layerwise_calib.mode for i in range(5)] collector.get_input_activations(model.layers[0], forward_loop) assert modes() == ["original", "original", "original", "original", "original"] @@ -316,7 +316,7 @@ def weight_doubling_calib(layer, layer_forward_loop, **kwargs): layer.weight.mul_(2.0) layer_forward_loop(layer) - sequential_calibrate( + layerwise_calibrate( model, forward_loop=forward_loop, calib_func=weight_doubling_calib, diff --git a/tests/unit/torch/export/test_vllm_quantizer_reload.py b/tests/unit/torch/export/test_vllm_quantizer_reload.py new file mode 100644 index 0000000000..49fb34c25a --- /dev/null +++ b/tests/unit/torch/export/test_vllm_quantizer_reload.py @@ -0,0 +1,156 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from modelopt.torch.export.plugins.vllm_fakequant_hf import ( + infer_quantizer_prefix_remap, + merge_amax_tensors_for_group, +) + + +def _map_backbone_to_model(sd: dict) -> dict: + """Test mapper: rename top-level ``backbone.`` to ``model.`` (typical HF vs vLLM).""" + out = {} + for k, v in sd.items(): + if k.startswith("backbone."): + out["model." + k[len("backbone.") :]] = v + else: + out[k] = v + return out + + +def test_infer_prefix_remap_simple_root_rename(): + """``infer_quantizer_prefix_remap`` infers one HF root → vLLM root from ``*.weight`` probes.""" + q = { + "backbone.layers.0.mlp.gate_proj.input_quantizer": {}, + "backbone.layers.1.self_attn.q_proj.weight_quantizer": {}, + } + rem = infer_quantizer_prefix_remap(q, _map_backbone_to_model) + assert rem == {"backbone": "model"} + + +def test_infer_prefix_remap_multiple_probes_same_root_agree(): + """Regression: every quantizer key under the same HF root must agree on the mapped vLLM root.""" + q = { + "backbone.a.w.input_quantizer": {}, + "backbone.b.w.weight_quantizer": {}, + } + rem = infer_quantizer_prefix_remap(q, _map_backbone_to_model) + assert rem == {"backbone": "model"} + + +def test_infer_prefix_remap_raises_on_inconsistent_root(): + """If ``map_fun`` maps the same HF root to different vLLM roots, raise with a clear error.""" + + def bad_map(sd: dict) -> dict: + out = {} + for k, 
v in sd.items(): + if "layers.0" in k: + out[k.replace("backbone.", "model.")] = v + elif "head" in k: + out[k.replace("backbone.", "encoder.")] = v + else: + out[k] = v + return out + + q = { + "backbone.layers.0.mlp.gate_proj.input_quantizer": {}, + "backbone.head.proj.input_quantizer": {}, + } + with pytest.raises(ValueError, match="Inconsistent HF→vLLM prefix remap"): + infer_quantizer_prefix_remap(q, bad_map) + + +def test_infer_prefix_remap_identity_empty(): + """When keys already match the mapper output, the inferred remap is empty (no rename).""" + q = {"model.layers.0.foo.input_quantizer": {}} + rem = infer_quantizer_prefix_remap(q, lambda d: dict(d)) + assert rem == {} + + +def test_infer_prefix_remap_probe_failure_skipped(): + """A probe that raises does not block remap if another key under the same root succeeds.""" + + def map_drop_layers0(sd: dict) -> dict: + out = {} + for k, v in sd.items(): + if "layers.0" in k: + raise RuntimeError("simulate missing layer") + if k.startswith("backbone."): + out["model." + k[len("backbone.") :]] = v + else: + out[k] = v + return out + + q = { + "backbone.layers.0.mlp.gate_proj.input_quantizer": {}, + "backbone.layers.1.mlp.gate_proj.input_quantizer": {}, + } + rem = infer_quantizer_prefix_remap(q, map_drop_layers0) + assert rem == {"backbone": "model"} + + +def test_infer_prefix_remap_no_quantizer_segment_still_probes_weight_path(): + """Short paths (e.g. ``embed.weight_quantizer``) still build a ``.weight`` probe path.""" + q = {"backbone.embed.weight_quantizer": {}} + rem = infer_quantizer_prefix_remap(q, _map_backbone_to_model) + assert rem == {"backbone": "model"} + + +def test_infer_prefix_remap_complex_mapper_not_one_root_raises_or_wrong(): + """Same HF root ``x`` mapping to different first components (``va.*`` vs ``vb.*``) must error.""" + + def split_map(sd: dict) -> dict: + k = next(iter(sd)) + v = sd[k] + if "branch_a" in k: + return {"va." + k[2:]: v} # x.branch_a... -> va.branch_a... + return {"vb." 
+ k[2:]: v} + + q = { + "x.branch_a.mlp.w.input_quantizer": {}, + "x.branch_b.mlp.w.input_quantizer": {}, + } + with pytest.raises(ValueError, match="Inconsistent HF→vLLM prefix remap"): + infer_quantizer_prefix_remap(q, split_map) + + +def test_merge_amax_same_shape_elementwise_max(): + """``merge_amax_tensors_for_group``: identical shapes → element-wise max (stack then amax).""" + a = torch.tensor([1.0, 4.0, 2.0]) + b = torch.tensor([2.0, 3.0, 5.0]) + out = merge_amax_tensors_for_group([a, b]) + assert torch.allclose(out, torch.tensor([2.0, 4.0, 5.0])) + + +def test_merge_amax_different_1d_lengths_uses_cat(): + """``merge_amax_tensors_for_group``: mismatched 1-D lengths (e.g. GQA q/k/v) → ``cat`` on dim 0.""" + q = torch.tensor([1.0, 2.0, 3.0]) # e.g. 3 heads + k = torch.tensor([0.5, 0.5]) # 2 KV heads + v = torch.tensor([0.5, 0.5]) + out = merge_amax_tensors_for_group([q, k, v]) + assert out.shape == (7,) + assert torch.allclose(out, torch.cat([q, k, v])) + + +def test_merge_amax_incompatible_shapes_scalar_fallback(): + """``merge_amax_tensors_for_group``: when ``cat`` fails, fall back to a scalar global max.""" + a = torch.ones(2, 3) + b = torch.ones(2, 2) # cannot cat along dim=0 with matching trailing dims + out = merge_amax_tensors_for_group([a, b]) + assert out.shape == () + assert out.item() == 1.0 diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py index 692ab07d4a..ae638c42ee 100644 --- a/tests/unit/torch/quantization/plugins/test_huggingface.py +++ b/tests/unit/torch/quantization/plugins/test_huggingface.py @@ -35,7 +35,7 @@ get_homogeneous_hf_decoder_layers, is_homogeneous_hf_model, ) -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector pytest.importorskip("transformers") diff --git a/tests/unit/torch/quantization/test_calib.py 
b/tests/unit/torch/quantization/test_calib.py index b3c372eb33..d2e6fdd03e 100644 --- a/tests/unit/torch/quantization/test_calib.py +++ b/tests/unit/torch/quantization/test_calib.py @@ -27,7 +27,7 @@ from modelopt.torch.quantization.model_calib import ( apply_pre_quant_scale_and_smooth, disable_pre_quant_scale_and_resmooth, - sequential_calibrate, + layerwise_calibrate, ) from modelopt.torch.quantization.nn import TensorQuantizer @@ -379,7 +379,7 @@ def test_svdquant_lora_weights(): assert lora_residual.shape == module.weight.shape -def test_sequential_calibrate_support_gate(): +def test_layerwise_calibrate_support_gate(): class _UnsupportedModel(nn.Module): def __init__(self): super().__init__() @@ -392,17 +392,17 @@ def forward(self, x): with ( torch.no_grad(), - pytest.raises(ValueError, match="Sequential calibration requires a model"), + pytest.raises(ValueError, match="Layerwise calibration requires a model"), ): - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: m(torch.randn(2, 4)), calib_func=lambda layer, loop: loop(layer), ) -def test_sequential_calibrate_propagates_inputs_without_replaying_full_model(monkeypatch): - from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +def test_layerwise_calibrate_propagates_inputs_without_replaying_full_model(monkeypatch): + from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector class _ToyLayer(nn.Module): def __init__(self, scale: float, bias: float): @@ -463,7 +463,7 @@ def _pre_hook(_module, args): handle.remove() observed_layer_inputs.append(captured) - sequential_calibrate(model, _forward_loop, _calib_func) + layerwise_calibrate(model, _forward_loop, _calib_func) assert forward_loop_calls == len(model.layers) assert len(observed_layer_inputs) == len(model.layers) @@ -482,9 +482,9 @@ def _pre_hook(_module, args): assert torch.allclose(observed, expected) -def 
test_sequential_calibrate_handles_inter_layer_logic(monkeypatch): +def test_layerwise_calibrate_handles_inter_layer_logic(monkeypatch): """Verify that parent-level inter-layer logic (e.g. mask selection) works correctly.""" - from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector + from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector class _ToyLayer(nn.Module): def __init__(self, scale: float): @@ -537,7 +537,7 @@ def _pre_hook(_module, args): handle.remove() observed_layer_inputs.append(captured) - sequential_calibrate(model, _forward_loop, _calib_func) + layerwise_calibrate(model, _forward_loop, _calib_func) assert len(observed_layer_inputs) == 3 # Layer 0 gets raw batch diff --git a/tests/unit/torch/quantization/test_sequential_calibrate.py b/tests/unit/torch/quantization/test_layerwise_calibrate.py similarity index 64% rename from tests/unit/torch/quantization/test_sequential_calibrate.py rename to tests/unit/torch/quantization/test_layerwise_calibrate.py index 14c1903de2..3739feff96 100644 --- a/tests/unit/torch/quantization/test_sequential_calibrate.py +++ b/tests/unit/torch/quantization/test_layerwise_calibrate.py @@ -13,16 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Unit tests for sequential_calibrate and LayerActivationCollector.""" +"""Unit tests for layerwise_calibrate and LayerActivationCollector.""" +import copy from collections import deque import pytest import torch import torch.nn as nn -from modelopt.torch.quantization.model_calib import sequential_calibrate -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.model_calib import layerwise_calibrate +from modelopt.torch.quantization.nn import TensorQuantizer +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector, _SkipLayer class _DecoderBlock(nn.Module): @@ -60,7 +63,7 @@ def forward(self, x, **kwargs): class _FlatMLP(nn.Module): - """No decoder-layer structure -- should be rejected by sequential_calibrate.""" + """No decoder-layer structure -- should be rejected by layerwise_calibrate.""" def __init__(self, dim=16): super().__init__() @@ -180,7 +183,7 @@ def forward_loop(m): collector._unpatch_all_layers() assert not hasattr(model, "_original_forward") - assert not hasattr(model.layers[0], "_seq_calib") + assert not hasattr(model.layers[0], "_layerwise_calib") assert not hasattr(model.layers[0], "_original_forward") @@ -201,38 +204,38 @@ def bad_forward_loop(m): collector._unpatch_all_layers() assert not hasattr(model, "_original_forward") - assert not hasattr(model.layers[0], "_seq_calib") + assert not hasattr(model.layers[0], "_layerwise_calib") -# sequential_calibrate tests -def test_seq_calib_raises_on_none_forward_loop(monkeypatch): +# layerwise_calibrate tests +def test_layerwise_calib_raises_on_none_forward_loop(monkeypatch): _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=2) with pytest.raises(ValueError, match="forward_loop must not be None"): - sequential_calibrate( + layerwise_calibrate( model, forward_loop=None, calib_func=lambda *a, **kw: None, ) -def 
test_seq_calib_raises_on_unrecognized_model(): +def test_layerwise_calib_raises_on_unrecognized_model(): model = _FlatMLP() with pytest.raises(ValueError, match="Could not find transformer layers"): - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: m(torch.randn(2, 16)), calib_func=lambda *a, **kw: None, ) -def test_seq_calib_empty_forward_loop_raises(monkeypatch): - """If forward_loop feeds no data, sequential_calibrate raises RuntimeError.""" +def test_layerwise_calib_empty_forward_loop_raises(monkeypatch): + """If forward_loop feeds no data, layerwise_calibrate raises RuntimeError.""" _register_test_discoverer(monkeypatch) model = _SimpleTransformerModel(n_layers=2, dim=16) with pytest.raises(RuntimeError, match="collected no inputs during forward_loop"): - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: None, calib_func=lambda *a, **kw: None, @@ -344,11 +347,11 @@ def forward_loop(m): try: # Layer 0 starts as capture — no output_meta yet collector.get_input_activations(model.layers[0], forward_loop) - assert model.layers[0]._seq_calib.output_meta is None + assert model.layers[0]._layerwise_calib.output_meta is None # Calibrating layer 1 puts layer 0 into run, which sets output_meta collector.get_input_activations(model.layers[1], forward_loop) - meta = model.layers[0]._seq_calib.output_meta + meta = model.layers[0]._layerwise_calib.output_meta assert meta is not None assert meta[0] == "tuple", "Tuple-returning layer should produce tuple metadata" finally: @@ -375,11 +378,11 @@ def forward_loop(m): # Before calibrating layer 2, layer 1 transitions to run. # Its cached_inputs should be populated from collected_inputs. 
collector._set_layer_states(2) - assert len(model.layers[1]._seq_calib.cached_inputs) == n_batches + assert len(model.layers[1]._layerwise_calib.cached_inputs) == n_batches # After the forward loop, all cached inputs should be consumed forward_loop(model) - assert len(model.layers[1]._seq_calib.cached_inputs) == 0 + assert len(model.layers[1]._layerwise_calib.cached_inputs) == 0 finally: collector._unpatch_all_layers() @@ -399,24 +402,24 @@ def test_set_layer_states_transitions(monkeypatch): try: def modes(): - return [model.layers[i]._seq_calib.mode for i in range(5)] + return [model.layers[i]._layerwise_calib.mode for i in range(5)] collector._set_layer_states(0) assert modes() == ["capture", "original", "original", "original", "original"] - model.layers[0]._seq_calib.collected_inputs = [fake_inp] + model.layers[0]._layerwise_calib.collected_inputs = [fake_inp] collector._set_layer_states(1) assert modes() == ["run", "capture", "original", "original", "original"] - model.layers[1]._seq_calib.collected_inputs = [fake_inp] + model.layers[1]._layerwise_calib.collected_inputs = [fake_inp] collector._set_layer_states(2) assert modes() == ["skip", "run", "capture", "original", "original"] - model.layers[2]._seq_calib.collected_inputs = [fake_inp] + model.layers[2]._layerwise_calib.collected_inputs = [fake_inp] collector._set_layer_states(3) assert modes() == ["skip", "skip", "run", "capture", "original"] - model.layers[3]._seq_calib.collected_inputs = [fake_inp] + model.layers[3]._layerwise_calib.collected_inputs = [fake_inp] collector._set_layer_states(4) assert modes() == ["skip", "skip", "skip", "run", "capture"] finally: @@ -446,8 +449,8 @@ def test_run_asserts_on_empty_cached_inputs(monkeypatch): collector = LayerActivationCollector(model) collector._patch_all_layers() try: - model.layers[0]._seq_calib.mode = "run" - model.layers[0]._seq_calib.cached_inputs = deque() + model.layers[0]._layerwise_calib.mode = "run" + model.layers[0]._layerwise_calib.cached_inputs = 
deque() with pytest.raises(AssertionError, match="no cached inputs to replay"): model(torch.randn(2, 16)) @@ -455,8 +458,8 @@ def test_run_asserts_on_empty_cached_inputs(monkeypatch): collector._unpatch_all_layers() -def test_cleanup_removes_seq_calib_attr(monkeypatch): - """After unpatch, no layer should have the _seq_calib attribute.""" +def test_cleanup_removes_layerwise_calib_attr(monkeypatch): + """After unpatch, no layer should have the _layerwise_calib attribute.""" _register_test_discoverer(monkeypatch) model = _TupleUnpackingModel(n_layers=3, dim=16) data = [torch.randn(2, 16)] @@ -472,7 +475,9 @@ def forward_loop(m): collector._unpatch_all_layers() for i, layer in enumerate(model.layers): - assert not hasattr(layer, "_seq_calib"), f"Layer {i} still has _seq_calib after cleanup" + assert not hasattr(layer, "_layerwise_calib"), ( + f"Layer {i} still has _layerwise_calib after cleanup" + ) assert not hasattr(layer, "_original_forward"), ( f"Layer {i} still has _original_forward after cleanup" ) @@ -517,15 +522,17 @@ def forward_loop(m): for d in data: m(d) + originals = list(model.layers) collector = LayerActivationCollector(model) collector._patch_all_layers() try: - for layer in model.layers: + for layer in originals: collector.get_input_activations(layer, forward_loop) - # After full calibration, layers 0 and 1 have been through 'run' and have output_meta - meta_0 = model.layers[0]._seq_calib.output_meta - meta_1 = model.layers[1]._seq_calib.output_meta + # After full calibration, layers 0 and 1 have been through 'run' and have output_meta. + # Access via originals since skip-position entries are now _SkipLayer dummies. 
+ meta_0 = originals[0]._layerwise_calib.output_meta + meta_1 = originals[1]._layerwise_calib.output_meta assert meta_0 is not None assert meta_1 is not None # SmallBlock returns 3-element tuple, BigBlock returns 1-element tuple @@ -533,3 +540,182 @@ def forward_loop(m): assert len(meta_1[1]) == 1 finally: collector._unpatch_all_layers() + + +# --------------------------------------------------------------------------- +# _SkipLayer swap / restore tests +# --------------------------------------------------------------------------- + + +def test_skip_layers_replaced_with_dummy(monkeypatch): + """After calibrating enough layers, skip-position entries must be _SkipLayer with no params.""" + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=5, dim=16) + data = [torch.randn(2, 16) for _ in range(2)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + for layer in list(model.layers): + collector.get_input_activations(layer, forward_loop) + + # Layers 0..2 should be dummies (swapped when calibrating layers 2..4) + for i in range(3): + assert isinstance(model.layers[i], _SkipLayer), f"Layer {i} should be _SkipLayer" + assert list(model.layers[i].parameters()) == [], ( + f"Layer {i} dummy should have no params" + ) + # Layers 3 (run) and 4 (original) remain real + for i in range(3, 5): + assert not isinstance(model.layers[i], _SkipLayer), f"Layer {i} should still be real" + finally: + collector._unpatch_all_layers() + + +def test_cleanup_restores_original_layers(monkeypatch): + """After _unpatch_all_layers, all ModuleList entries must be the original modules.""" + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=5, dim=16) + originals = list(model.layers) + data = [torch.randn(2, 16)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + for layer in 
originals: + collector.get_input_activations(layer, forward_loop) + collector._unpatch_all_layers() + + for i, orig in enumerate(originals): + assert model.layers[i] is orig, f"Layer {i} not restored to original after cleanup" + assert not hasattr(orig, "_layerwise_calib"), f"Layer {i} still has _layerwise_calib" + + +def _int8_layerwise_config(algorithm: dict) -> dict: + """Start from the shipped INT8 config and enable layerwise in the algorithm block. + + Using a real shipped config guarantees the same include/exclude rules + production PTQ relies on, so algorithm dispatch matches real usage. + """ + cfg = copy.deepcopy(mtq.INT8_SMOOTHQUANT_CFG) + cfg["algorithm"] = algorithm + return cfg + + +def _awq_layerwise_config() -> dict: + """INT4 weight-only AWQ config sized for the _DecoderBlock test model.""" + cfg = copy.deepcopy(mtq.INT4_AWQ_CFG) + # Resize AWQ block to fit dim=16 hidden. + for entry in cfg["quant_cfg"]: + if entry.get("quantizer_name") == "*weight_quantizer": + entry.setdefault("cfg", {})["block_sizes"] = {-1: 8, "type": "static"} + cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 0.5, "layerwise": True} + return cfg + + +def _svdquant_layerwise_config() -> dict: + """SVDQuant config sized for the _DecoderBlock test model.""" + cfg = copy.deepcopy(mtq.INT4_AWQ_CFG) + for entry in cfg["quant_cfg"]: + if entry.get("quantizer_name") == "*weight_quantizer": + entry.setdefault("cfg", {})["block_sizes"] = {-1: 8, "type": "static"} + cfg["algorithm"] = {"method": "svdquant", "lowrank": 4, "layerwise": True} + return cfg + + +def test_mtq_quantize_layerwise_e2e_max(monkeypatch): + """End-to-end: mtq.quantize with layerwise=True produces populated amax values. + + ``max`` is the representative algorithm for the layerwise happy path because + every other algorithm seeds amax via max_calibrate first — if max works, the + shared skip/run/capture machinery is sound. 
Other algorithms are covered by + the dispatch-only test below to avoid hardware requirements (e.g. gptq needs + CUDA) or unnecessary duplication. + """ + _register_test_discoverer(monkeypatch) + config = _int8_layerwise_config({"method": "max", "layerwise": True}) + + torch.manual_seed(0) + model = _SimpleTransformerModel(n_layers=3, dim=16) + calib_data = [torch.randint(0, 32, (2, 8)) for _ in range(2)] + + def forward_loop(m): + for batch in calib_data: + m(batch) + + model = mtq.quantize(model, config, forward_loop=forward_loop) + + for i, layer in enumerate(model.layers): + assert not isinstance(layer, _SkipLayer), f"layer {i} left as _SkipLayer" + assert not hasattr(layer, "_layerwise_calib"), f"layer {i} leaked _layerwise_calib" + + amax_count = sum( + 1 + for layer in model.layers + for module in layer.modules() + if ( + isinstance(module, TensorQuantizer) + and module.is_enabled + and getattr(module, "_amax", None) is not None + ) + ) + assert amax_count > 0, "no TensorQuantizer in decoder layers had _amax populated" + + with torch.no_grad(): + model(calib_data[0]) + + +@pytest.mark.parametrize( + "algorithm", + ["gptq", "awq_lite", "smoothquant", "mse"], +) +def test_mtq_quantize_layerwise_dispatches_for_algorithm(monkeypatch, algorithm): + """Every layerwise-supporting algorithm must route through layerwise_calibrate. + + Stubs layerwise_calibrate to a spy so the dispatch contract is checked without + running the algorithm's full calibration — lets ``gptq`` (CUDA-only at runtime) + and other expensive algorithms participate in CPU unit tests. 
+ """ + spy: dict = {} + + def stub(model, forward_loop, calib_func, **kwargs): + spy["calib_func"] = calib_func + spy["kwargs"] = kwargs + + monkeypatch.setattr("modelopt.torch.quantization.mode.layerwise_calibrate", stub) + + if algorithm == "awq_lite": + config = _awq_layerwise_config() + else: + config = _int8_layerwise_config({"method": algorithm, "layerwise": True}) + + torch.manual_seed(0) + model = _SimpleTransformerModel(n_layers=2, dim=16) + mtq.quantize( + model, + config, + forward_loop=lambda m: m(torch.randint(0, 32, (2, 8))), + ) + + assert "calib_func" in spy, f"{algorithm} did not dispatch through layerwise_calibrate" + assert callable(spy["calib_func"]) + + +def test_mtq_quantize_layerwise_raises_for_unsupported_algorithm(): + """Modes with ``_supports_layerwise = False`` must raise a clear ValueError.""" + config = _svdquant_layerwise_config() + torch.manual_seed(0) + model = _SimpleTransformerModel(n_layers=2, dim=16) + with pytest.raises(ValueError, match="does not support layerwise=True"): + mtq.quantize( + model, + config, + forward_loop=lambda m: m(torch.randint(0, 32, (2, 8))), + ) diff --git a/tests/unit/torch/quantization/test_sequential_checkpoint.py b/tests/unit/torch/quantization/test_sequential_checkpoint.py new file mode 100644 index 0000000000..0e592a68c7 --- /dev/null +++ b/tests/unit/torch/quantization/test_sequential_checkpoint.py @@ -0,0 +1,185 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for layerwise calibration checkpoint save/resume.""" + +import json +import os +from types import SimpleNamespace + +import torch +import torch.nn as nn + +from modelopt.torch.quantization.model_calib import layerwise_calibrate +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector +from modelopt.torch.utils.network import get_module_device + + +class _DecoderBlock(nn.Module): + def __init__(self, dim=16): + super().__init__() + self.linear = nn.Linear(dim, dim, bias=False) + + def forward(self, x, **kwargs): + return self.linear(x) + + +class _SimpleTransformerModel(nn.Module): + def __init__(self, n_layers=3, dim=16): + super().__init__() + self.layers = nn.ModuleList([_DecoderBlock(dim) for _ in range(n_layers)]) + self.embed = nn.Embedding(32, dim) + + def forward(self, x, **kwargs): + x = self.embed(x) + for layer in self.layers: + x = layer(x) + return x + + +def _register_test_discoverer(monkeypatch): + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], + ) + + +def _dummy_calib_func(layer, forward_loop, **kwargs): + """Scale all weights by 0.5 to produce a visible, deterministic change.""" + forward_loop(layer) + with torch.no_grad(): + for p in layer.parameters(): + p.mul_(0.5) + + +def _make_model_and_forward(n_layers=3, dim=16, seed=42): + torch.manual_seed(seed) + model = _SimpleTransformerModel(n_layers=n_layers, dim=dim) + tokens = [torch.randint(0, 32, (2, 8)) for _ in range(2)] + + def forward_loop(m): + for t in tokens: + m(t) + + return model, forward_loop + + +def test_full_run_creates_checkpoints(monkeypatch, tmp_path): + """layerwise_calibrate with checkpoint_dir creates correct layer dirs and manifest.""" + _register_test_discoverer(monkeypatch) + model, forward_loop = _make_model_and_forward(n_layers=3) + 
ckpt_dir = str(tmp_path / "ckpt") + + layerwise_calibrate(model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir) + + manifest_path = os.path.join(ckpt_dir, "manifest.json") + assert os.path.isfile(manifest_path) + with open(manifest_path) as f: + manifest = json.load(f) + assert manifest["last_completed_layer"] == 2 + assert manifest["num_layers"] == 3 + + for i in range(3): + layer_dir = os.path.join(ckpt_dir, f"layer_{i:04d}") + assert os.path.isdir(layer_dir) + assert os.path.isfile(os.path.join(layer_dir, "weights.pt")) + assert os.path.isfile(os.path.join(layer_dir, "quantizer_state.pt")) + assert os.path.isfile(os.path.join(layer_dir, "output_meta.pt")) + # All layers except the last should have next_inputs + assert os.path.isfile(os.path.join(ckpt_dir, "layer_0000", "next_inputs.pt")) + assert os.path.isfile(os.path.join(ckpt_dir, "layer_0001", "next_inputs.pt")) + assert not os.path.isfile(os.path.join(ckpt_dir, "layer_0002", "next_inputs.pt")) + + +def test_resume_matches_full_run(monkeypatch, tmp_path): + """Resume from a truncated checkpoint produces the same final weights as a full run.""" + _register_test_discoverer(monkeypatch) + ckpt_dir = str(tmp_path / "ckpt") + + # Full reference run + ref_model, forward_loop = _make_model_and_forward(n_layers=3) + layerwise_calibrate(ref_model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir) + ref_weights = {n: p.clone() for n, p in ref_model.named_parameters()} + + # Simulate crash after layer 0: truncate manifest + manifest_path = os.path.join(ckpt_dir, "manifest.json") + with open(manifest_path, "w") as f: + json.dump({"last_completed_layer": 0, "num_layers": 3}, f) + + # Resume from a fresh model + resumed_model, forward_loop = _make_model_and_forward(n_layers=3) + layerwise_calibrate(resumed_model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir) + + for name, ref_param in ref_weights.items(): + resumed_param = dict(resumed_model.named_parameters())[name] + assert 
torch.allclose(ref_param, resumed_param, atol=1e-6), ( + f"Parameter {name} diverged after resume" + ) + + +def test_no_checkpoint_unchanged(monkeypatch): + """Without checkpoint_dir, calibration still works and modifies parameters.""" + _register_test_discoverer(monkeypatch) + model, forward_loop = _make_model_and_forward(n_layers=3) + original_weights = {n: p.clone() for n, p in model.named_parameters()} + + layerwise_calibrate(model, forward_loop, _dummy_calib_func) + + changed = False + for name, param in model.named_parameters(): + if not torch.allclose(original_weights[name], param): + changed = True + break + assert changed, "Expected calibration to modify at least one parameter" + + +# --------------------------------------------------------------------------- +# get_module_device tests +# --------------------------------------------------------------------------- + + +def test_get_module_device_no_hook(): + """Falls back to parameter device when no _hf_hook is present.""" + layer = nn.Linear(4, 4) + assert get_module_device(layer) == torch.device("cpu") + + +def test_get_module_device_with_direct_hook(): + """Returns execution_device from a direct AlignDevicesHook-style hook.""" + layer = nn.Linear(4, 4) + layer._hf_hook = SimpleNamespace(execution_device=torch.device("cuda:0")) + assert get_module_device(layer) == torch.device("cuda:0") + + +def test_get_module_device_with_sequential_hook(): + """Returns execution_device from an AlignDevicesHook wrapped in SequentialHook.""" + layer = nn.Linear(4, 4) + inner_hook = SimpleNamespace(execution_device=torch.device("cuda:1")) + layer._hf_hook = SimpleNamespace(hooks=[inner_hook]) + assert get_module_device(layer) == torch.device("cuda:1") + + +def test_get_module_device_hook_without_execution_device(): + """Falls back to parameters when hook has no execution_device.""" + layer = nn.Linear(4, 4) + layer._hf_hook = SimpleNamespace() + assert get_module_device(layer) == torch.device("cpu") + + +def 
test_get_module_device_parameterless_module(): + """Returns cpu for a module with no parameters and no hook.""" + module = nn.Module() + assert get_module_device(module) == torch.device("cpu") diff --git a/tests/unit/torch/quantization/test_utils.py b/tests/unit/torch/quantization/test_utils.py index 92fe1345f9..73d3423ba5 100644 --- a/tests/unit/torch/quantization/test_utils.py +++ b/tests/unit/torch/quantization/test_utils.py @@ -20,7 +20,7 @@ convert_quantization_axis_to_reduce_axis, reduce_block_amax, ) -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector @pytest.mark.parametrize(