diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index c2d4d4bfca..90532efe38 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -15,6 +15,7 @@ import copy import glob +import hashlib import inspect import json import logging @@ -854,3 +855,35 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod print(f"Successfully copied {len(copied_files)} custom model files to {export_path}") else: print("No custom model files found to copy") + + +def needs_checkpoint_path_update(quant_cfg: dict) -> bool: + """Check if quant_cfg has a layerwise_checkpoint_dir that should be auto-resolved to a unique subpath.""" + algorithm = quant_cfg.get("algorithm") + if not isinstance(algorithm, dict): + return False + return algorithm.get("layerwise_checkpoint_dir") is not None + + +def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict: + """Append a unique ``_`` subdirectory to layerwise_checkpoint_dir. + + Allows a single recipe to be reused across models without checkpoint collisions. + Must only be called when :func:`needs_checkpoint_path_update` returns True. + """ + algorithm = quant_cfg["algorithm"] + base_dir = algorithm["layerwise_checkpoint_dir"] + + name = model_path.rstrip("/") + if "/" in name and not os.path.isabs(name): + name = name.replace("/", "--") + else: + name = Path(name).name + + config_hash = hashlib.sha256(json.dumps(quant_cfg, default=str).encode()).hexdigest()[:8] + + quant_cfg = copy.deepcopy(quant_cfg) + quant_cfg["algorithm"]["layerwise_checkpoint_dir"] = os.path.join( + base_dir, f"{name}_{config_hash}" + ) + return quant_cfg diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 327605406c..c405de51e7 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -34,6 +34,8 @@ is_enc_dec, is_nemotron_vl, load_mtp_weights, + needs_checkpoint_path_update, + resolve_checkpoint_dir, run_nemotron_vl_preview, ) from torch.utils.data import DataLoader @@ -91,8 +93,9 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: for i, entry in enumerate(quant_cfg): if entry.get("quantizer_name") != "*[kv]_bmm_quantizer": continue - assert isinstance(entry.get("cfg", {}), dict) - quant_cfg[i] = {**entry, "cfg": {**entry.get("cfg", {}), "use_constant_amax": True}} + cfg = entry.get("cfg") or {} + assert isinstance(cfg, dict) + quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}} break @@ -759,7 +762,9 @@ def export_quantized( # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) # Store the MTP layer prefixes on the model for later exclusion from quantization if args.vllm_fakequant_export: - export_hf_vllm_fq_checkpoint(full_model, export_dir=export_path) + export_hf_vllm_fq_checkpoint( + full_model, export_dir=export_path, inplace_mem_efficient=True + ) else: mtp_layer_prefixes, mtp_state_dict = load_mtp_weights( full_model, args.pyt_ckpt_path @@ -1104,6 +1109,12 @@ def quantize_main( quant_cfg = copy.deepcopy(quant_cfg) _set_kv_cache_constant_amax(quant_cfg["quant_cfg"]) + if needs_checkpoint_path_update(quant_cfg): + quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path) + print( + f"Auto-resolved layerwise_checkpoint_dir: {quant_cfg['algorithm']['layerwise_checkpoint_dir']}" + ) + if args.qformat in QUANT_CFG_CHOICES: mono_quantize( args, diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 1908354a0a..44b2d55ba8 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -24,6 +24,8 @@ from modelopt.torch.quantization.conversion import quantizer_state from modelopt.torch.quantization.nn import QuantModule, TensorQuantizer from modelopt.torch.quantization.utils import get_quantizer_state_dict +from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector from modelopt.torch.utils import get_unwrapped_name __all__ = ["export_hf_vllm_fq_checkpoint"] @@ -38,9 +40,75 @@ def disable_rotate(quantizer: TensorQuantizer): return False +def _fakequant_module_weights( + module: nn.Module, + module_name: str, + model: nn.Module, + state_dict: dict | None, + input_quantizers_folded_pqs: set, + fakequant_weights: set, + inplace: bool, +): + """Apply fake-quant to a single QuantModule's weights. + + When ``inplace=False``, reads/writes weights from/to ``state_dict``. + When ``inplace=True``, modifies the module's weight parameters directly. + """ + if not isinstance(module, QuantModule): + return + for attr_name, quantizer in module.named_children(): + if not ( + attr_name.endswith("weight_quantizer") + and isinstance(quantizer, TensorQuantizer) + and quantizer.fake_quant + and quantizer.is_enabled + ): + continue + weight_name = attr_name.removesuffix("_quantizer") + prefix = f"{module_name}." if module_name else "" + sd_key = f"{prefix}{weight_name}" + assert sd_key not in fakequant_weights, f"Weight {sd_key} has already been fakequantized" + + if inplace: + w = getattr(module, weight_name) + w_quant = quantizer(w.float()).to(w.dtype) + else: + assert state_dict is not None + if sd_key not in state_dict: + continue + w = state_dict[sd_key] + w_quant = quantizer(w.float()).to(w.dtype) + + # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) + # Only valid when input_quantizer does NOT fake-quant activations. If it does + # fake_quant(x*s), the non-linearity prevents folding s into W. + inp_attr = attr_name.replace("weight_quantizer", "input_quantizer") + if hasattr(module, inp_attr): + inp_q = getattr(module, inp_attr) + if ( + hasattr(inp_q, "_pre_quant_scale") + and inp_q._pre_quant_scale is not None + and inp_q._disabled + ): + scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device) + w_quant = (w_quant * scale[None, :]).to(w_quant.dtype) + inp_q_key = get_unwrapped_name( + f"{module_name}.{inp_attr}" if module_name else inp_attr, model + ) + input_quantizers_folded_pqs.add(inp_q_key) + + if inplace: + w.data.copy_(w_quant) + else: + assert state_dict is not None + state_dict[sd_key] = w_quant.cpu() + fakequant_weights.add(sd_key) + + def export_hf_vllm_fq_checkpoint( model: nn.Module, export_dir: Path | str, + inplace_mem_efficient: bool = False, ): """Export quantized HF weights + ``vllm_fq_modelopt_state.pth`` for vLLM fake-quant reload. @@ -53,59 +121,56 @@ def export_hf_vllm_fq_checkpoint( Args: model: In-memory quantized model. export_dir: Output dir for HF files and ``vllm_fq_modelopt_state.pth``. + inplace_mem_efficient: When True, applies fake-quant inplace one decoder layer at + a time using ``enable_weight_access_and_writeback``, avoiding full state + dict materialization. This is destructive — model weights are permanently + modified and weight quantizers are not re-enabled after export. """ export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) # Step 1: Build the folded HF state dict. - # model.state_dict() returns detached copies of all tensors, so model - # parameters are never modified. Apply each weight quantizer's fake-quant - # to the corresponding weight tensor in the copy. - state_dict = model.state_dict() fakequant_weights = set() - input_quantizers_folded_pqs = ( - set() - ) # keys for input_quantizers where pre_quant_scale was folded + input_quantizers_folded_pqs = set() with torch.inference_mode(): - for module_name, module in model.named_modules(): - if not isinstance(module, QuantModule): - continue - for attr_name, quantizer in module.named_children(): - if not ( - attr_name.endswith("weight_quantizer") - and isinstance(quantizer, TensorQuantizer) - and quantizer.fake_quant - and quantizer.is_enabled - ): + if inplace_mem_efficient: + # Inplace path: iterate decoder layers, one offload<->onload per layer. + decoder_layers = LayerActivationCollector.get_decoder_layers(model) + assert decoder_layers is not None, ( + "inplace_mem_efficient=True requires a model with discoverable decoder layers" + ) + for name, module in model.named_modules(): + if module not in decoder_layers: continue - weight_name = attr_name.removesuffix("_quantizer") - prefix = f"{module_name}." if module_name else "" - sd_key = f"{prefix}{weight_name}" - assert sd_key not in fakequant_weights, ( - f"Weight {sd_key} has already been fakequantized" - ) - if sd_key in state_dict: - w = state_dict[sd_key] - w_quant = quantizer(w.float()).to(w.dtype).cpu() - # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) - # Only valid when input_quantizer does NOT fake-quant activations. If it does - # fake_quant(x*s), the non-linearity prevents folding s into W. - inp_attr = attr_name.replace("weight_quantizer", "input_quantizer") - if hasattr(module, inp_attr): - inp_q = getattr(module, inp_attr) - if ( - hasattr(inp_q, "_pre_quant_scale") - and inp_q._pre_quant_scale is not None - and inp_q._disabled - ): - scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device) - w_quant = (w_quant * scale[None, :]).to(w_quant.dtype) - inp_q_key = get_unwrapped_name( - f"{module_name}.{inp_attr}" if module_name else inp_attr, model - ) - input_quantizers_folded_pqs.add(inp_q_key) - state_dict[sd_key] = w_quant - fakequant_weights.add(sd_key) + with enable_weight_access_and_writeback(module, module): + for sub_name, sub_mod in module.named_modules(): + full_name = f"{name}.{sub_name}" if sub_name else name + _fakequant_module_weights( + sub_mod, + full_name, + model, + None, + input_quantizers_folded_pqs, + fakequant_weights, + inplace=True, + ) + # Meta tensors for offloaded weights (free); offload maps now have + # fakequanted values via writeback. + state_dict = model.state_dict() + else: + # Default path: full state_dict copy, fakequant into the copy. + state_dict = model.state_dict() + for module_name, module in model.named_modules(): + with enable_weight_access_and_writeback(module, model): + _fakequant_module_weights( + module, + module_name, + model, + state_dict, + input_quantizers_folded_pqs, + fakequant_weights, + inplace=False, + ) # Filter quantizer tensors out for a clean HF checkpoint. clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k} @@ -164,6 +229,7 @@ def export_hf_vllm_fq_checkpoint( # Step 3: Save HF weights using the pre-built folded state dict. model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False) - for wq, orig_rotate in wqs_to_restore: - wq.enable() - wq._rotate = orig_rotate + if not inplace_mem_efficient: + for wq, orig_rotate in wqs_to_restore: + wq.enable() + wq._rotate = orig_rotate diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 99c729efbc..3f24ac09a4 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1217,16 +1217,36 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): ), ) - use_sequential: bool = ModeloptField( + layerwise: bool = ModeloptField( default=False, - title="Enable sequential layer-by-layer calibration.", + title="Enable layerwise (layer-by-layer) calibration.", description=( - "If True, the calibration algorithm is applied sequentially to each decoder block. " - "Each layer's inputs are captured via a single forward pass that reflects the " + "If True, the calibration algorithm is applied layer by layer. " + "Each layer's inputs are captured via a forward pass that reflects the " "quantization of all preceding layers, incurring O(N) forward passes for N layers." ), ) + layerwise_checkpoint_dir: str | None = ModeloptField( + default=None, + title="Checkpoint directory for layerwise calibration.", + description=( + "If set together with layerwise=True, per-layer checkpoints are saved to this " + "directory during calibration. On restart, calibration resumes from the last " + "completed layer." + ), + ) + + @model_validator(mode="after") + def validate_layerwise_checkpoint_dir(self): + """Raise if layerwise_checkpoint_dir is set but layerwise is False.""" + if self.layerwise_checkpoint_dir is not None and not self.layerwise: + raise ValueError( + "layerwise_checkpoint_dir requires layerwise=True. " + "Set layerwise=True or remove layerwise_checkpoint_dir." + ) + return self + class MaxCalibConfig(QuantizeAlgorithmConfig): """The config for max calibration algorithm. diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index c81d5c89c7..1328ef5821 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -60,10 +60,10 @@ from .model_calib import ( awq, gptq, + layerwise_calibrate, local_hessian_calibrate, max_calibrate, mse_calibrate, - sequential_calibrate, smoothquant, svdquant, ) @@ -222,7 +222,8 @@ def wrapped_calib_func( """ kwargs = config.model_dump() method = kwargs.pop("method") - sequential = kwargs.pop("use_sequential", False) + layerwise = kwargs.pop("layerwise", False) + checkpoint_dir = kwargs.pop("layerwise_checkpoint_dir", None) if method is not None and "awq" in method: # For backward compatibility kwargs["algorithm"] = method @@ -237,17 +238,17 @@ def wrapped_calib_func( module._moe_calib_experts_ratio = moe_calib_experts_ratio if func is not None: - if sequential: + if layerwise: + # All currently implemented PTQ algorithms support layerwise calibration; + # future algorithms that need full-model context must add a guard here. if forward_loop is None: raise ValueError("forward_loop is required for calibration but got None.") - assert method in ["max", "gptq"], ( - f"Sequential calibration currently only supports max and gptq calibration, got {method}" - ) - # Wrap with sequential processing - sequential_calibrate( + # Wrap with layerwise processing + layerwise_calibrate( model, forward_loop=forward_loop, calib_func=func, + checkpoint_dir=checkpoint_dir, **kwargs, ) else: diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 35a0e931c9..6db1e82cbd 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -28,7 +28,10 @@ from tqdm import tqdm from modelopt.torch.opt.searcher import ForwardLoop -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.utils.layerwise_calib import ( + LayerActivationCollector, + _CheckpointState, +) from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method @@ -44,6 +47,7 @@ is_quantized_column_parallel_linear, is_quantized_linear, is_quantized_row_parallel_linear, + persistent_materialization, promote_nvfp4_static_quantizers, quantizer_attr_names, reduce_amax, @@ -53,9 +57,9 @@ __all__ = [ "awq", + "layerwise_calibrate", "local_hessian_calibrate", "max_calibrate", - "sequential_calibrate", "smoothquant", "svdquant", ] @@ -1552,21 +1556,27 @@ def postprocess(module, name): @torch.no_grad() -def sequential_calibrate( +def layerwise_calibrate( model: nn.Module, forward_loop: ForwardLoop, calib_func: Callable, **calib_kwargs, ): - """Sequential calibration - a sequential layer-by-layer calibration algorithm. + """Layerwise calibration - a layer-by-layer calibration algorithm. Runs the full model forward per layer but patches decoder layers with a skip / run / capture strategy so that inter-layer logic in parent modules (e.g. mask construction) executes naturally without model-specific hooks. + + If ``checkpoint_dir`` is passed (via ``calib_kwargs``), per-layer checkpoints + are saved after each layer completes. On restart, calibration resumes from + the last completed layer. """ + checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None) + if forward_loop is None: raise ValueError( - "forward_loop must not be None for sequential calibration. " + "forward_loop must not be None for layerwise calibration. " "Please provide a valid forward_loop callable." ) @@ -1574,31 +1584,57 @@ def sequential_calibrate( if transformer_layers is None or len(transformer_layers) == 0: raise ValueError( "Could not find transformer layers in model. " - "Sequential calibration requires a model with identifiable transformer layers." + "Layerwise calibration requires a model with identifiable transformer layers." ) - print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers") + num_layers = len(transformer_layers) + print_rank_0(f"Layerwise calibration: Found {num_layers} transformer layers") + + ckpt = _CheckpointState.from_folder(checkpoint_dir, num_layers) + start_layer = ckpt.start_layer if ckpt else 0 input_getter = LayerActivationCollector(model) input_getter._patch_all_layers(decoder_layers=transformer_layers) + resumed_inputs = ckpt.setup_resume(transformer_layers) if ckpt and start_layer > 0 else None + try: - for layer_idx, layer in enumerate(transformer_layers): - print_rank_0(f"Calibrating layer {layer_idx + 1}/{len(transformer_layers)}") - layer_inputs = input_getter.get_input_activations(layer, forward_loop) + # Bootstrap: get first layer's inputs (or use resumed inputs). + layer_inputs = input_getter.get_first_layer_inputs( + start_layer, resumed_inputs, forward_loop + ) + + for layer_idx in range(start_layer, num_layers): + layer = transformer_layers[layer_idx] def _layer_forward_loop(m, _inputs=layer_inputs): for args, kwargs_input in _inputs: m(*args, **kwargs_input) - calib_func(layer, _layer_forward_loop, **calib_kwargs) + with persistent_materialization(layer): + calib_func(layer, _layer_forward_loop, **calib_kwargs) + + # Run one more forward to get next layer's inputs and set + # output_meta on the just-calibrated layer (via "run" mode). + is_last = layer_idx + 1 >= num_layers + if not is_last: + next_inputs = input_getter.cache_outputs_for_next_layer_calib(layer, forward_loop) + else: + next_inputs = None + + if ckpt: + ckpt.save(layer_idx, layer, model, transformer_layers, next_inputs) del layer_inputs torch.cuda.empty_cache() + layer_inputs = next_inputs # noqa: F841 (used in next iteration's closure) finally: input_getter._unpatch_all_layers() - print_rank_0("Sequential calibration completed") + if ckpt: + ckpt.full_restore(transformer_layers, model) + + print_rank_0("Layerwise calibration completed") @torch.no_grad() @@ -1610,12 +1646,12 @@ def gptq( ): """GPTQ quantization. - Works in two modes depending on ``use_sequential`` in the config: + Works in two modes depending on ``layerwise`` in the config: - * **Sequential** (``use_sequential=True``): ``sequential_calibrate`` calls this + * **Layerwise** (``layerwise=True``): ``layerwise_calibrate`` calls this function once per decoder layer with updated activations, producing more accurate Hessian estimates. - * **Non-sequential** (``use_sequential=False``): called once on the full model. + * **Non-layerwise** (``layerwise=False``): called once on the full model. All layers are quantized in parallel from the original activations. Per-module steps: @@ -1628,7 +1664,7 @@ def gptq( Args: model: The module to quantize — either the full model or a single decoder - layer when invoked by ``sequential_calibrate``. + layer when invoked by ``layerwise_calibrate``. forward_loop: Callable that replays calibration inputs through *model*. perc_damp: Percentage of avg Hessian diagonal for damping (default: 0.01). block_size: Block size for GPTQ weight update. @@ -1663,8 +1699,10 @@ def gptq( handle.cleanup() print_rank_0("Updating weights using GPTQ algorithm...") + name_to_module = dict(model.named_modules()) for handle in gptq_handles.values(): - handle.update_weights(block_size, perc_damp) + with enable_weight_access_and_writeback(handle.module, model, name_to_module): + handle.update_weights(block_size, perc_damp) handle.free() del gptq_handles diff --git a/modelopt/torch/quantization/plugins/accelerate.py b/modelopt/torch/quantization/plugins/accelerate.py index 13999df0f0..f80e2478dc 100644 --- a/modelopt/torch/quantization/plugins/accelerate.py +++ b/modelopt/torch/quantization/plugins/accelerate.py @@ -31,51 +31,77 @@ __all__ = ["init_quantized_weights"] -def _get_cpu_offload_hook(hook): +def _get_offload_hook(hook): if isinstance(hook, AlignDevicesHook) and hook.offload and hook.weights_map is not None: - assert "weight" in hook.weights_map - if ( - isinstance(hook.weights_map, PrefixedDataset) - and hook.weights_map.prefix + "weight" not in hook.weights_map.dataset.state_dict - ): - raise NotImplementedError( - "This layer could be offloaded to disk. We don't support this yet." - ) + assert len(hook.weights_map) > 0 return hook elif isinstance(hook, SequentialHook): for h in hook.hooks: - align_hook = _get_cpu_offload_hook(h) + align_hook = _get_offload_hook(h) if align_hook is not None: return align_hook return None +def _writeback_params_to_weights_map(module, align_hook): + """Write all non-meta parameters and buffers back to the hook's CPU weights_map.""" + for name, tensor in module.state_dict(keep_vars=True).items(): + if tensor.device.type == "meta": + continue + if isinstance(align_hook.weights_map, PrefixedDataset): + key = align_hook.weights_map.prefix + name + w_map = align_hook.weights_map.dataset.state_dict + else: + w_map = align_hook.weights_map + key = name + if key in w_map: + w_map[key] = tensor.detach().to(w_map[key].device, dtype=w_map[key].dtype) + elif ( + isinstance(align_hook.weights_map, PrefixedDataset) + and hasattr(align_hook.weights_map.dataset, "index") + and key in align_hook.weights_map.dataset.index + ): + # Disk-offloaded weight: promote into state_dict so the next + # pre_forward picks up the modified tensor instead of the stale + # on-disk version. OffloadedWeightsLoader.__getitem__ gives + # state_dict priority over index, so this is sufficient. + w_map[key] = tensor.detach().cpu() + + @contextmanager def weight_access_and_writeback_context(module): - """Context manager for weight access and writeback for modules managed by accelerate.""" + """Context manager for weight access and writeback for modules managed by accelerate. + + Handles CPU-offloaded and disk-offloaded models. Iterates over the module and all + its descendants, materializing weights from any offload hook found and writing them + back on exit. ``pre_forward`` is skipped on modules whose weights are already + materialized (not on meta) to avoid overwriting them with stale CPU copies. + """ assert hasattr(module, "_hf_hook") - align_hook = _get_cpu_offload_hook(module._hf_hook) - if align_hook: - # Accelerate uses AlignDevicesHook to offload weights to CPU/Disk and then reload them in the forward pass - # The CPU/Disk offloaded weights are managed by PrefixDataset and OffloadedWeightsLoader - # See https://github.com/huggingface/accelerate/blame/f48d95c4939b281505a45b3d6e0bf554b65cc1ea/src/accelerate/utils/offload.py#L104-L141 - # TODO: Add support for disk-offloaded models if needed (they will be really slow, hence low priority) + materialized: list[tuple[torch.nn.Module, AlignDevicesHook, bool]] = [] + for mod in module.modules(): + if not hasattr(mod, "_hf_hook"): + continue + hook = _get_offload_hook(mod._hf_hook) + if hook is None: + continue + # Only call pre_forward if weights need materializing; already-materialized + # weights would be overwritten with stale CPU state_dict values. + needs_materialize = any(p.device.type == "meta" for p in mod.parameters()) + if needs_materialize: + hook.pre_forward(mod) + hook.offload = False + materialized.append((mod, hook, needs_materialize)) - # This will load the weights from CPU state_dict and move it to the GPU from meta device - align_hook.pre_forward(module) try: yield finally: - if align_hook: - # Update the weight in the CPU state_dict - if isinstance(align_hook.weights_map, PrefixedDataset): - key = align_hook.weights_map.prefix + "weight" - w_map = align_hook.weights_map.dataset.state_dict - else: - key, w_map = "weight", align_hook.weights_map - w_map[key] = module.weight.data.to(w_map[key].device, dtype=w_map[key].dtype) - align_hook.post_forward(module, None) + for mod, hook, was_materialized in materialized: + hook.offload = True + _writeback_params_to_weights_map(mod, hook) + if was_materialized: + hook.post_forward(mod, None) @contextmanager diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 82ab589934..59bcd215bb 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -39,7 +39,7 @@ from ..nn.modules.quant_linear import _QuantLinear from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE from ..utils import replace_function, sync_moe_expert_amax -from ..utils.activation_collector import LayerActivationCollector +from ..utils.layerwise_calib import LayerActivationCollector from .attention import register_attention_for_kv_quant from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin diff --git a/modelopt/torch/quantization/utils/__init__.py b/modelopt/torch/quantization/utils/__init__.py index 2660363209..dfc23c42ee 100644 --- a/modelopt/torch/quantization/utils/__init__.py +++ b/modelopt/torch/quantization/utils/__init__.py @@ -16,8 +16,8 @@ # ruff: noqa: F405 """Quantization utilities.""" -from .activation_collector import LayerActivationCollector from .core_utils import * +from .layerwise_calib import LayerActivationCollector __all__ = [ "EXPORT_MODE", diff --git a/modelopt/torch/quantization/utils/activation_collector.py b/modelopt/torch/quantization/utils/activation_collector.py deleted file mode 100644 index 5f187fdcb2..0000000000 --- a/modelopt/torch/quantization/utils/activation_collector.py +++ /dev/null @@ -1,335 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Sequential calibration layer patching and activation capture. - -This module provides :class:`LayerActivationCollector`, a stateful helper that -patches decoder layers with a skip / run / capture strategy for efficient -layer-by-layer calibration. -""" - -from collections import deque -from dataclasses import dataclass, field -from typing import Any - -import torch -import torch.nn as nn - -from modelopt.torch.opt.searcher import ForwardLoop -from modelopt.torch.utils import print_rank_0 -from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method - - -class _EarlyStopForwardError(Exception): - """Raised to halt the forward pass after capturing layer inputs.""" - - -@dataclass -class _LayerCalibState: - """Mutable per-layer state used during sequential calibration. - - Attached to each decoder layer as ``_seq_calib`` and accessed by the - patched forward to decide skip / run / capture / original behaviour. - """ - - mode: str = "original" - name: str = "" - cached_inputs: deque = field(default_factory=deque) - collected_inputs: list = field(default_factory=list) - output_meta: tuple | None = None - - -class LayerActivationCollector: - """Collects layer activations for sequential (layer-by-layer) calibration. - - Each decoder layer is patched with a unified forward whose behaviour is - governed by a per-layer :class:`_LayerCalibState`: - - * **skip** — return a zero-filled dummy whose shape and type match the - layer's real output (reconstructed from lightweight metadata). No - computation is performed. The correctly shaped dummy ensures un-patched - inter-layer operations in the parent forward (e.g. LayerNorm, tuple - unpacking) do not raise shape or type errors. - * **run** — replay previously captured inputs through the original forward, - ignoring whatever the parent passes in. Only the just-calibrated layer - uses this mode, so its output reflects updated weights. - * **capture** — record ``(args, kwargs)`` and raise - ``_EarlyStopForwardError`` to halt the forward pass early. - * **original** — call the original forward unchanged. - - Because the *run* layer discards upstream values, skip-layer outputs are - never consumed for real computation. - """ - - # Global registry of (predicate, discoverer) pairs. Populated at import time - # by plugins (e.g. huggingface.py, megatron.py). Order matters: the first - # matching entry wins, so more specific predicates (e.g. Nemotron-H) must be - # registered before generic ones (e.g. homogeneous HF models). - # - # This is intentionally a mutable class variable shared across all instances: - # plugins register once at import time, and the registry is read-only after - # that. register_decoder_layer_support() guards against duplicate entries. - _decoder_layer_support: list[tuple[Any, Any]] = [] - _LAYER_ATTR = "_seq_calib" - - def __init__(self, model: nn.Module): - """Initialize the collector for the given model.""" - self.model = model - self._decoder_layers: nn.ModuleList | None = None - self._layer_to_idx: dict[nn.Module, int] = {} - self._patched = False - - @staticmethod - def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None: - """Return decoder layers supported by sequential calibration.""" - for is_supported, discoverer in LayerActivationCollector._decoder_layer_support: - if not is_supported(model): - continue - decoder_layers = discoverer(model) - if decoder_layers is not None: - return decoder_layers - return None - - @staticmethod - def is_supported(model: nn.Module) -> bool: - """Whether the model supports decoder-layer sequential calibration.""" - return LayerActivationCollector.get_decoder_layers(model) is not None - - @classmethod - def register_decoder_layer_support(cls, is_supported: Any, discoverer: Any): - """Register a (predicate, discoverer) pair for decoder-layer detection.""" - entry = (is_supported, discoverer) - if entry not in cls._decoder_layer_support: - cls._decoder_layer_support.append(entry) - - @staticmethod - def _extract_output_meta(output): - """Extract lightweight (shape, dtype, device) metadata from a layer output. - - Recursively handles tensors, tuples, lists, and non-tensor values (e.g. None). - The returned structure can be passed to ``_zeros_from_meta`` to reconstruct a - zero-filled output with identical shape and type. - """ - if isinstance(output, torch.Tensor): - return ("tensor", output.shape, output.dtype, output.device) - if isinstance(output, tuple): - return ( - "tuple", - tuple(LayerActivationCollector._extract_output_meta(o) for o in output), - ) - if isinstance(output, list): - return ("list", [LayerActivationCollector._extract_output_meta(o) for o in output]) - return ("other", output) - - @staticmethod - def _zeros_from_meta(meta): - """Reconstruct a zero-filled output from metadata produced by ``_extract_output_meta``.""" - tag = meta[0] - if tag == "tensor": - _, shape, dtype, device = meta - return torch.zeros(shape, dtype=dtype, device=device) - if tag == "tuple": - return tuple(LayerActivationCollector._zeros_from_meta(m) for m in meta[1]) - if tag == "list": - return [LayerActivationCollector._zeros_from_meta(m) for m in meta[1]] - # "other" values are expected to be lightweight non-tensors (e.g. None, small scalars). - # The value is returned directly (not copied); callers must not mutate it. - # In practice this is safe because skip-mode outputs are immediately discarded by the - # downstream run-mode layer, which replays from its own cached inputs instead. - return meta[1] - - def _patch_all_layers(self, decoder_layers: nn.ModuleList | None = None): - """Bind the unified forward to every decoder layer and the model. Called once. - - Args: - decoder_layers: Pre-resolved decoder layers. If *None*, layers are - discovered via :meth:`get_decoder_layers`. - """ - - def _patched_forward(self, *args, **kwargs): - """Unified forward bound to every decoder layer during sequential calibration. - - ``self`` here is the decoder layer module (bound via ``bind_forward_method``). - All per-layer state is accessed through ``self._seq_calib``. - """ - info: _LayerCalibState = self._seq_calib - - if info.mode == "skip": - if info.output_meta is None: - raise RuntimeError( - f"Layer {info.name} is in 'skip' mode but has no output_meta. " - "This indicates a state-machine bug: the layer should have run " - "in 'run' mode (which sets output_meta) before transitioning to 'skip'." - ) - return LayerActivationCollector._zeros_from_meta(info.output_meta) - - if info.mode == "run": - assert info.cached_inputs, ( - f"Layer {info.name} is in 'run' mode but has no cached inputs to replay." - ) - real_args, real_kwargs = info.cached_inputs.popleft() - output = self._original_forward(*real_args, **real_kwargs) - info.output_meta = LayerActivationCollector._extract_output_meta(output) - return output - - if info.mode == "capture": - info.collected_inputs.append((args, kwargs)) - raise _EarlyStopForwardError() - - return self._original_forward(*args, **kwargs) - - if decoder_layers is not None: - self._decoder_layers = decoder_layers - else: - self._decoder_layers = self.get_decoder_layers(self.model) - assert self._decoder_layers is not None - - self._layer_to_idx = {layer: i for i, layer in enumerate(self._decoder_layers)} - module_to_name = {m: name for name, m in self.model.named_modules()} - - try: - for layer in self._decoder_layers: - layer._seq_calib = _LayerCalibState( - name=module_to_name.get(layer, type(layer).__name__), - ) - bind_forward_method(layer, _patched_forward, "_original_forward") - - def _early_stop_forward(module_self, *args, **kwargs): - try: - return module_self._original_forward(*args, **kwargs) - except _EarlyStopForwardError: - return None - - bind_forward_method(self.model, _early_stop_forward, "_original_forward") - except Exception: - self._cleanup_layers() - raise - - self._patched = True - - def _cleanup_layers(self): - """Best-effort cleanup of any patched layers and model forward.""" - if hasattr(self.model, "_original_forward"): - unpatch_forward_method(self.model, "_original_forward") - - if self._decoder_layers is not None: - for layer in self._decoder_layers: - if hasattr(layer, "_original_forward"): - unpatch_forward_method(layer, "_original_forward") - if hasattr(layer, self._LAYER_ATTR): - delattr(layer, self._LAYER_ATTR) - - def _unpatch_all_layers(self): - """Restore original forwards and clean up state attributes. Called once.""" - if not self._patched: - return - self._cleanup_layers() - self._patched = False - - def _set_layer_states(self, layer_idx: int): - """Transition layer modes for the next calibration step. - - When calibrating layer *i*, three transitions happen: - - * Layer ``i - 2`` → **skip** (fully done, free its cached inputs). - * Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights). - * Layer ``i`` → **capture** (record inputs, then early-stop). - """ - assert self._decoder_layers is not None - - if layer_idx > 1: - done = self._decoder_layers[layer_idx - 2]._seq_calib - # output_meta is intentionally kept: skip mode needs it to produce - # correctly shaped zero-filled outputs for the parent forward. - done.mode = "skip" - done.cached_inputs.clear() - - if layer_idx > 0: - prev = self._decoder_layers[layer_idx - 1]._seq_calib - if not prev.collected_inputs: - raise RuntimeError( - f"Layer {layer_idx - 1} ({prev.name!r}) has no collected inputs to replay. " - "Layers must be calibrated sequentially — ensure get_input_activations() " - "was called for every preceding layer in order." - ) - prev.mode = "run" - prev.cached_inputs = deque(prev.collected_inputs) - prev.collected_inputs = [] - - cur = self._decoder_layers[layer_idx]._seq_calib - cur.mode = "capture" - cur.collected_inputs = [] - - def _log_layer_summary(self, layer_idx: int): - """Log a one-line summary of layer modes for the current calibration step.""" - assert self._decoder_layers is not None - n = len(self._decoder_layers) - groups: dict[str, list[int]] = {} - for i, layer in enumerate(self._decoder_layers): - mode = layer._seq_calib.mode - if mode in ("skip", "run", "capture"): - groups.setdefault(mode, []).append(i + 1) - parts = [f"{mode}: {groups[mode]}" for mode in ("skip", "run", "capture") if mode in groups] - print_rank_0(f"Calibrating layer {layer_idx + 1}/{n} | {' | '.join(parts)}") - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - @torch.no_grad() - def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: - """Collect input activations for *layer* by running a full model forward. - - Layers before the target are skipped or re-run (if just calibrated), the - target layer captures its inputs, and an early-stop prevents unnecessary - computation beyond the target. - - :meth:`_patch_all_layers` must be called before this method. - - Note: the model forward returns ``None`` for every batch during capture - (because ``_EarlyStopForwardError`` short-circuits the forward pass). - Callers should not rely on the model's return value within *forward_loop*. - """ - if not self._patched: - raise RuntimeError( - "get_input_activations() requires _patch_all_layers() to be called first." - ) - layer_idx = self._layer_to_idx[layer] - self._set_layer_states(layer_idx) - self._log_layer_summary(layer_idx) - - info = layer._seq_calib - try: - forward_loop(self.model) - except Exception: - # Reset the current layer so subsequent calls don't see stale state. - info.mode = "original" - info.collected_inputs = [] - raise - - if not info.collected_inputs: - info.mode = "original" - raise RuntimeError( - f"Layer {info.name!r} collected no inputs during forward_loop. " - "The forward loop did not reach this layer — check that forward_loop() " - "actually calls the model and that the layer is in the forward path." - ) - - inputs = list(info.collected_inputs) - # After capture, set to original so calib_func can call the layer's - # real forward directly. The layer will transition to run → skip - # in subsequent iterations via _set_layer_states. - info.mode = "original" - return inputs diff --git a/modelopt/torch/quantization/utils/calib_utils.py b/modelopt/torch/quantization/utils/calib_utils.py index e52a8438d5..e2d8ccf2a2 100644 --- a/modelopt/torch/quantization/utils/calib_utils.py +++ b/modelopt/torch/quantization/utils/calib_utils.py @@ -96,7 +96,7 @@ def __init__(self, module, name, offload_to_cpu=False): self.name = name in_features = module.weight.shape[-1] device = module.weight.device - if offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65: + if device.type == "meta" or (offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65): device = "cpu" self.hessian = torch.zeros(in_features, in_features, dtype=torch.float32, device=device) self.n_samples = 0 diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 273d7564c6..29661e18f5 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -423,47 +423,70 @@ def _get_enclosing_fsdp_module( return root_model +def _set_parameter(module: nn.Module, name: str, value: nn.Parameter): + """Set a parameter on a module by dotted name (e.g. ``self_attn.q_proj.weight``).""" + parts = name.rsplit(".", 1) + if len(parts) == 2: + parent = module.get_submodule(parts[0]) + attr = parts[1] + else: + parent = module + attr = name + parent._parameters[attr] = value + + @contextmanager def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn.Module): """Context manager for FSDP2 weight access and writeback. - Note this context will gather the weight across FSDP/HSDP shards. If TP is implemented with DTensor, - the weight will be a local tensor of the TP DTensor under this context. + Gathers sharded DTensor parameters across FSDP/HSDP shards so they can be + read or modified. Works for both leaf modules (single ``weight``) and + composite modules like decoder layers (all ``named_parameters``). + + If TP is implemented with DTensor, the weight will be a local tensor of the + TP DTensor under this context. """ assert isinstance(root_model, torch.distributed.fsdp.FSDPModule), "We only support FSDP2" assert not hasattr(module, "_hf_hook"), "We dont support FSDP2 with HF accelerate hooks" - assert isinstance(module.weight, torch.distributed.tensor.DTensor) fsdp_module = _get_enclosing_fsdp_module(module, root_model) assert fsdp_module is not None, "Module is not wrapped by FSDP" fsdp_device_mesh = _get_fsdp2_mesh(fsdp_module) fsdp_dim = fsdp_device_mesh.ndim - original_placements = module.weight.placements - original_device_mesh = module.weight.device_mesh - original_weight = module.weight - # Assuming the first fsdp_dim dimensions are for FSDP/HSDP, we only collect the tensor over FSDP/HSDP dimension, - # the TP will be handled by the TP reduction. - if fsdp_dim != original_device_mesh.ndim: - assert fsdp_device_mesh.mesh_dim_names == original_device_mesh.mesh_dim_names[:fsdp_dim], ( - "FSDP2 mesh should be a slice of DTesnor's device mesh." + # Collect all DTensor parameters, replacing them with local replicated copies. + originals: dict[str, tuple] = {} + for name, param in module.named_parameters(): + if not isinstance(param, torch.distributed.tensor.DTensor): + continue + original_placements = param.placements + original_device_mesh = param.device_mesh + if fsdp_dim != original_device_mesh.ndim: + assert ( + fsdp_device_mesh.mesh_dim_names == original_device_mesh.mesh_dim_names[:fsdp_dim] + ), "FSDP2 mesh should be a slice of DTensor's device mesh." + collected = param.redistribute( + placements=[Replicate()] * fsdp_dim + list(original_placements[fsdp_dim:]), + device_mesh=original_device_mesh, ) - - weight_collected = original_weight.redistribute( - placements=[Replicate()] * fsdp_dim + list(original_placements[fsdp_dim:]), - device_mesh=original_device_mesh, - ) - new_weight = nn.Parameter(weight_collected.to_local()) - module._parameters["weight"] = new_weight + originals[name] = (param, collected, original_placements, original_device_mesh) + _set_parameter(module, name, nn.Parameter(collected.to_local())) yield - original_weight.to_local().data.copy_( - weight_collected.redistribute( - placements=original_placements, device_mesh=original_device_mesh - ).to_local() - ) - module._parameters["weight"] = original_weight + # Write back and restore original DTensor parameters. + for name, ( + original_param, + collected, + original_placements, + original_device_mesh, + ) in originals.items(): + original_param.to_local().data.copy_( + collected.redistribute( + placements=original_placements, device_mesh=original_device_mesh + ).to_local() + ) + _set_parameter(module, name, original_param) @contextmanager @@ -471,7 +494,7 @@ def enable_weight_access_and_writeback(module, root_model, name_to_module: dict """Enable weight access and writeback for a module. Useful for modules with weight not intact such as Linear layer in FSDP wrapped model or - HF accelerate CPU off-loaded models. + HF accelerate offloaded models (CPU or disk). Args: module: The module to access weights for. @@ -498,6 +521,22 @@ def enable_weight_access_and_writeback(module, root_model, name_to_module: dict yield +@contextmanager +def persistent_materialization(layer): + """Keep all layer weights materialized on GPU for the duration. + + Suppresses per-forward weight transfers so that N calibration batches + pay the cost of one load/unload instead of N. + + - **FSDP2**: patches ``FSDPParamGroup.unshard/reshard`` to no-ops, then + gathers weights once via ``enable_weight_access_and_writeback``. + - **Accelerate**: materializes weights and sets ``hook.offload = False`` + so per-forward hooks skip materialization/offloading. + """ + with _disable_fsdp_unshard_reshard(layer), enable_weight_access_and_writeback(layer, layer): + yield + + def get_quantizer_state_dict(model: nn.Module): """Get the state dict of the quantizers in the model.""" # We should not call model.state_dict() here. @@ -607,6 +646,24 @@ def _init_mp_dtypes(self) -> None: ) +@contextmanager +def _disable_fsdp_unshard_reshard(layer): + """Disable FSDP2 unshard/reshard if *layer* is FSDP-wrapped.""" + if isinstance(layer, FSDPModule): + _pg_cls = torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup + orig_unshard = _pg_cls.unshard + orig_reshard = _pg_cls.reshard + _pg_cls.unshard = lambda self, async_op=False: None + _pg_cls.reshard = lambda self: None + try: + yield + finally: + _pg_cls.unshard = orig_unshard + _pg_cls.reshard = orig_reshard + else: + yield + + def get_prefixed_param_names(parent_model, target_module): """Get parameter names for a target module prefixed with the parent model name. diff --git a/modelopt/torch/quantization/utils/layerwise_calib.py b/modelopt/torch/quantization/utils/layerwise_calib.py new file mode 100644 index 0000000000..aed403ad87 --- /dev/null +++ b/modelopt/torch/quantization/utils/layerwise_calib.py @@ -0,0 +1,684 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layerwise calibration layer patching, activation capture, and checkpoint save/resume. + +This module provides :class:`LayerActivationCollector`, a stateful helper that +patches decoder layers with a skip / run / capture strategy for efficient +layer-by-layer calibration, and :class:`_CheckpointState` for persisting +per-layer calibration progress to disk. +""" + +from __future__ import annotations + +import json +import os +import shutil +from collections import deque +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import torch +import torch.nn as nn + +from modelopt.torch.utils import distributed as dist +from modelopt.torch.utils import print_rank_0 +from modelopt.torch.utils.network import ( + bind_forward_method, + get_module_device, + unpatch_forward_method, +) + +if TYPE_CHECKING: + from modelopt.torch.opt.searcher import ForwardLoop + + +class _EarlyStopForwardError(Exception): + """Raised to halt the forward pass after capturing layer inputs.""" + + +@dataclass +class _LayerCalibState: + """Mutable per-layer state used during layerwise calibration. + + Attached to each decoder layer as ``_layerwise_calib`` and accessed by the + patched forward to decide skip / run / capture / original behaviour. + """ + + mode: str = "original" + name: str = "" + cached_inputs: deque = field(default_factory=deque) + collected_inputs: list = field(default_factory=list) + output_meta: tuple | None = None + + +class _SkipLayer(nn.Module): + """Parameter-free stand-in for a fully calibrated decoder layer. + + Replaces the real layer in the ModuleList so that framework hooks + (accelerate, FSDP2, etc.) have no parameters to transfer. Holds a + reference to the original layer for restoration during cleanup. + """ + + def __init__(self, original: nn.Module): + super().__init__() + # Bypass nn.Module.__setattr__ to avoid registering original as a submodule. + object.__setattr__(self, "_original", original) + self._layerwise_calib = _LayerCalibState(mode="skip") + + _PROXY_BLOCKLIST = frozenset({"_hf_hook", "_old_forward"}) + + def __getattr__(self, name: str): + # Proxy non-special attribute lookups to the original layer so that + # parent-model code that accesses layer-level attributes (e.g., + # NemotronH's ``block_type``) still works when the layer is replaced + # with a _SkipLayer. Accelerate hook attrs are blocked so the + # framework does not attempt to manage this parameter-free stand-in. + try: + return super().__getattr__(name) + except AttributeError: + if name in self._PROXY_BLOCKLIST: + raise + return getattr(object.__getattribute__(self, "_original"), name) + + def forward(self, *args, **kwargs): + return LayerActivationCollector._zeros_from_meta( + self._original._layerwise_calib.output_meta + ) + + +class LayerActivationCollector: + """Collects layer activations for layerwise (layer-by-layer) calibration. + + Each decoder layer is patched with a unified forward whose behaviour is + governed by a per-layer :class:`_LayerCalibState`: + + * **skip** — return a zero-filled dummy whose shape and type match the + layer's real output (reconstructed from lightweight metadata). No + computation is performed. The correctly shaped dummy ensures un-patched + inter-layer operations in the parent forward (e.g. LayerNorm, tuple + unpacking) do not raise shape or type errors. + * **run** — replay previously captured inputs through the original forward, + ignoring whatever the parent passes in. Only the just-calibrated layer + uses this mode, so its output reflects updated weights. + * **capture** — record ``(args, kwargs)`` and raise + ``_EarlyStopForwardError`` to halt the forward pass early. + * **original** — call the original forward unchanged. + + Because the *run* layer discards upstream values, skip-layer outputs are + never consumed for real computation. + """ + + _decoder_layer_support: list[tuple[Any, Any]] = [] + _LAYER_ATTR = "_layerwise_calib" + + def __init__(self, model: nn.Module): + """Initialize the collector for the given model.""" + self.model = model + self._decoder_layers: nn.ModuleList | None = None + self._layer_to_idx: dict[nn.Module, int] = {} + self._patched = False + + def _swap_to_dummy(self, idx: int): + """Replace decoder layer *idx* with a parameter-free dummy. + + ``output_meta`` is intentionally preserved on the original layer: the + ``_SkipLayer`` reads it to produce correctly shaped zero-filled outputs + for the parent forward pass. + """ + assert self._decoder_layers is not None + layer = self._decoder_layers[idx] + layer._layerwise_calib.mode = "skip" + layer._layerwise_calib.cached_inputs.clear() + self._decoder_layers[idx] = _SkipLayer(layer) + + @staticmethod + def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None: + """Return decoder layers supported by layerwise calibration.""" + for is_supported, discoverer in LayerActivationCollector._decoder_layer_support: + if not is_supported(model): + continue + decoder_layers = discoverer(model) + if decoder_layers is not None: + return decoder_layers + return None + + @staticmethod + def is_supported(model: nn.Module) -> bool: + """Whether the model supports decoder-layer layerwise calibration.""" + return LayerActivationCollector.get_decoder_layers(model) is not None + + @classmethod + def register_decoder_layer_support(cls, is_supported: Any, discoverer: Any): + """Register a (predicate, discoverer) pair for decoder-layer detection.""" + entry = (is_supported, discoverer) + if entry not in cls._decoder_layer_support: + cls._decoder_layer_support.append(entry) + + @staticmethod + def _extract_output_meta(output): + """Extract lightweight (shape, dtype, device) metadata from a layer output. + + Recursively handles tensors, tuples, lists, and non-tensor values (e.g. None). + The returned structure can be passed to ``_zeros_from_meta`` to reconstruct a + zero-filled output with identical shape and type. + """ + if isinstance(output, torch.Tensor): + return ("tensor", output.shape, output.dtype, output.device) + if isinstance(output, tuple): + return ( + "tuple", + tuple(LayerActivationCollector._extract_output_meta(o) for o in output), + ) + if isinstance(output, list): + return ("list", [LayerActivationCollector._extract_output_meta(o) for o in output]) + return ("other", output) + + @staticmethod + def _zeros_from_meta(meta): + """Reconstruct a zero-filled output from metadata produced by ``_extract_output_meta``.""" + tag = meta[0] + if tag == "tensor": + _, shape, dtype, device = meta + return torch.zeros(shape, dtype=dtype, device=device) + if tag == "tuple": + return tuple(LayerActivationCollector._zeros_from_meta(m) for m in meta[1]) + if tag == "list": + return [LayerActivationCollector._zeros_from_meta(m) for m in meta[1]] + # "other" values are lightweight non-tensors (e.g. None, small scalars). + # Returned directly (not copied); safe because skip-mode outputs are + # immediately discarded by the downstream run-mode layer. + return meta[1] + + def _patch_all_layers(self, decoder_layers: nn.ModuleList | None = None): + """Bind the unified forward to every decoder layer and the model. Called once. + + Args: + decoder_layers: Pre-resolved decoder layers. If *None*, layers are + discovered via :meth:`get_decoder_layers`. + """ + + def _patched_forward(self, *args, **kwargs): + info: _LayerCalibState = self._layerwise_calib + + if info.mode == "skip": + if info.output_meta is None: + raise RuntimeError( + f"Layer {info.name} is in 'skip' mode but has no output_meta. " + "This indicates a state-machine bug: the layer should have run " + "in 'run' mode (which sets output_meta) before transitioning to 'skip'." + ) + return LayerActivationCollector._zeros_from_meta(info.output_meta) + + if info.mode == "run": + assert info.cached_inputs, ( + f"Layer {info.name} is in 'run' mode but has no cached inputs to replay." + ) + real_args, real_kwargs = info.cached_inputs.popleft() + output = self._original_forward(*real_args, **real_kwargs) + info.output_meta = LayerActivationCollector._extract_output_meta(output) + return output + + if info.mode == "capture": + info.collected_inputs.append((args, kwargs)) + raise _EarlyStopForwardError() + + return self._original_forward(*args, **kwargs) + + if decoder_layers is not None: + self._decoder_layers = decoder_layers + else: + self._decoder_layers = self.get_decoder_layers(self.model) + assert self._decoder_layers is not None + + self._layer_to_idx = {layer: i for i, layer in enumerate(self._decoder_layers)} + module_to_name = {m: name for name, m in self.model.named_modules()} + + try: + for layer in self._decoder_layers: + layer._layerwise_calib = _LayerCalibState( + name=module_to_name.get(layer, type(layer).__name__), + ) + bind_forward_method(layer, _patched_forward, "_original_forward") + + def _early_stop_forward(module_self, *args, **kwargs): + try: + return module_self._original_forward(*args, **kwargs) + except _EarlyStopForwardError: + return None + + bind_forward_method(self.model, _early_stop_forward, "_original_forward") + except Exception: + self._cleanup_layers() + raise + + self._patched = True + + def _cleanup_layers(self): + """Best-effort cleanup of any patched layers and model forward.""" + if self._decoder_layers is not None: + for idx, layer in enumerate(self._decoder_layers): + if isinstance(layer, _SkipLayer): + self._decoder_layers[idx] = layer._original + + if hasattr(self.model, "_original_forward"): + unpatch_forward_method(self.model, "_original_forward") + + if self._decoder_layers is not None: + for layer in self._decoder_layers: + if hasattr(layer, "_original_forward"): + unpatch_forward_method(layer, "_original_forward") + if hasattr(layer, self._LAYER_ATTR): + delattr(layer, self._LAYER_ATTR) + + def _unpatch_all_layers(self): + """Restore original forwards and clean up state attributes. Called once.""" + if not self._patched: + return + self._cleanup_layers() + self._patched = False + + def _set_layer_states(self, layer_idx: int): + """Transition layer modes for the next calibration step. + + When calibrating layer *i*, three transitions happen: + + * Layer ``i - 2`` → **skip** (fully done, free its cached inputs). + * Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights). + * Layer ``i`` → **capture** (record inputs, then early-stop). + """ + assert self._decoder_layers is not None + + if layer_idx > 1: + idx = layer_idx - 2 + if not isinstance(self._decoder_layers[idx], _SkipLayer): + self._swap_to_dummy(idx) + + if layer_idx > 0: + prev = self._decoder_layers[layer_idx - 1]._layerwise_calib + if not prev.collected_inputs: + raise RuntimeError( + f"Layer {layer_idx - 1} ({prev.name!r}) has no collected inputs to replay. " + "Layers must be calibrated sequentially — ensure get_input_activations() " + "was called for every preceding layer in order." + ) + prev.mode = "run" + prev.cached_inputs = deque(prev.collected_inputs) + prev.collected_inputs = [] + + cur = self._decoder_layers[layer_idx]._layerwise_calib + cur.mode = "capture" + cur.collected_inputs = [] + + def _log_layer_summary(self, layer_idx: int): + """Log a one-line summary of layer modes for the current calibration step.""" + assert self._decoder_layers is not None + n = len(self._decoder_layers) + groups: dict[str, list[int]] = {} + for i, layer in enumerate(self._decoder_layers): + mode = layer._layerwise_calib.mode + if mode in ("skip", "run", "capture"): + groups.setdefault(mode, []).append(i + 1) + + parts = [] + for mode in ("skip", "run", "capture"): + if mode not in groups: + continue + ids = groups[mode] + parts.append(f"{mode}: {len(ids)}" if mode == "skip" else f"{mode}: {ids}") + print_rank_0(f"Calibrating layer {layer_idx + 1}/{n} | {' | '.join(parts)}") + + @torch.no_grad() + def get_input_activations(self, layer: torch.nn.Module, forward_loop: ForwardLoop) -> list: + """Collect input activations for *layer* by running a full model forward. + + Layers before the target are skipped or re-run (if just calibrated), the + target layer captures its inputs, and an early-stop prevents unnecessary + computation beyond the target. + + :meth:`_patch_all_layers` must be called before this method. + + Note: the model forward returns ``None`` for every batch during capture + (because ``_EarlyStopForwardError`` short-circuits the forward pass). + Callers should not rely on the model's return value within *forward_loop*. + """ + if not self._patched: + raise RuntimeError( + "get_input_activations() requires _patch_all_layers() to be called first." + ) + layer_idx = self._layer_to_idx[layer] + self._set_layer_states(layer_idx) + self._log_layer_summary(layer_idx) + + info = layer._layerwise_calib + try: + forward_loop(self.model) + except Exception: + # Reset the current layer so subsequent calls don't see stale state. + info.mode = "original" + info.collected_inputs = [] + raise + + if not info.collected_inputs: + info.mode = "original" + raise RuntimeError( + f"Layer {info.name!r} collected no inputs during forward_loop. " + "The forward loop did not reach this layer — check that forward_loop() " + "actually calls the model and that the layer is in the forward path." + ) + + inputs = list(info.collected_inputs) + # Reset to original so calib_func can call the layer's real forward + # directly. The layer will transition to run → skip in subsequent + # iterations via _set_layer_states. + info.mode = "original" + return inputs + + def get_first_layer_inputs( + self, + start_layer: int, + resumed_inputs: list | None, + forward_loop: ForwardLoop, + ) -> list: + """Get inputs for the first layer to calibrate, handling resume. + + If *resumed_inputs* is provided, sets skip mode on layers ``0..start_layer-1`` + and seeds the start layer's ``collected_inputs`` for subsequent + ``cache_outputs_for_next_layer_calib`` calls. Otherwise, captures inputs + via a normal forward pass. + """ + assert self._decoder_layers is not None + + if resumed_inputs is not None: + print_rank_0(f"Calibrating layer {start_layer + 1} (resumed)") + for i in range(start_layer): + self._swap_to_dummy(i) + layer = self._decoder_layers[start_layer] + layer._layerwise_calib.collected_inputs = resumed_inputs + layer._layerwise_calib.mode = "original" + return resumed_inputs + + return self.get_input_activations(self._decoder_layers[start_layer], forward_loop) + + @torch.no_grad() + def cache_outputs_for_next_layer_calib( + self, layer: torch.nn.Module, forward_loop: ForwardLoop + ) -> list: + """Run a forward pass after calibrating *layer* to capture the next layer's inputs. + + This puts *layer* into "run" mode (setting its ``output_meta``) and the + next layer into "capture" mode, then runs *forward_loop*. Returns the + captured inputs for the next layer. + + Must be called only when a next layer exists (i.e. *layer* is not the + last decoder layer). + """ + assert self._decoder_layers is not None + layer_idx = self._layer_to_idx[layer] + next_idx = layer_idx + 1 + assert next_idx < len(self._decoder_layers), "No next layer to capture inputs for." + from .core_utils import persistent_materialization + + next_layer = self._decoder_layers[next_idx] + with persistent_materialization(layer): + return self.get_input_activations(next_layer, forward_loop) + + +def _move_to_device(obj: Any, device: torch.device) -> Any: + """Recursively move tensors to *device*. Non-tensors are returned as-is.""" + if isinstance(obj, torch.Tensor): + return obj.to(device) + if isinstance(obj, dict): + return {k: _move_to_device(v, device) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + moved = [_move_to_device(v, device) for v in obj] + return type(obj)(moved) + return obj + + +def _remap_output_metadata_device(meta: tuple, device: torch.device) -> tuple: + """Patch the device field inside output_meta tuples so _zeros_from_meta uses *device*.""" + tag = meta[0] + if tag == "tensor": + _, shape, dtype, _old_device = meta + return ("tensor", shape, dtype, device) + if tag == "tuple": + return ("tuple", tuple(_remap_output_metadata_device(m, device) for m in meta[1])) + if tag == "list": + return ("list", [_remap_output_metadata_device(m, device) for m in meta[1]]) + return meta + + +def _read_manifest(checkpoint_dir: str) -> dict | None: + """Read manifest.json from *checkpoint_dir*. Returns None if missing or corrupt.""" + path = os.path.join(checkpoint_dir, "manifest.json") + if not os.path.isfile(path): + return None + try: + with open(path) as f: + return json.load(f) + except (json.JSONDecodeError, OSError): + return None + + +def _write_manifest(checkpoint_dir: str, last_completed_layer: int, num_layers: int) -> None: + """Atomically write manifest.json.""" + path = os.path.join(checkpoint_dir, "manifest.json") + tmp = path + ".tmp" + with open(tmp, "w") as f: + json.dump( + {"last_completed_layer": last_completed_layer, "num_layers": num_layers}, + f, + ) + os.replace(tmp, path) + + +def _layer_dir(checkpoint_dir: str, idx: int) -> str: + return os.path.join(checkpoint_dir, f"layer_{idx:04d}") + + +def _save_layer( + checkpoint_dir: str, + idx: int, + weights: dict, + qstate: dict, + output_meta: tuple, + next_inputs: list | None, + num_layers: int, +) -> None: + """Save a single layer checkpoint and update the manifest atomically.""" + d = _layer_dir(checkpoint_dir, idx) + if os.path.isdir(d): + shutil.rmtree(d) + os.makedirs(d) + torch.save(weights, os.path.join(d, "weights.pt")) + torch.save(qstate, os.path.join(d, "quantizer_state.pt")) + torch.save(output_meta, os.path.join(d, "output_meta.pt")) + if next_inputs is not None: + torch.save(next_inputs, os.path.join(d, "next_inputs.pt")) + _write_manifest(checkpoint_dir, idx, num_layers) + + +def detect_resume_point(checkpoint_dir: str) -> tuple[int, dict] | None: + """Detect where to resume from an existing checkpoint directory. + + Returns ``(start_layer, manifest)`` if there is work to resume, + or ``None`` if the directory is empty, corrupt, or calibration was already complete. + """ + manifest = _read_manifest(checkpoint_dir) + if manifest is None: + return None + last = manifest.get("last_completed_layer") + total = manifest.get("num_layers") + if last is None or total is None: + return None + if last + 1 >= total: + return None + return (last + 1, manifest) + + +class _CheckpointState: + """Manages checkpoint save and restore for layerwise calibration. + + Handles both saving per-layer checkpoints during calibration and + restoring from a previous partial run. + + .. todo:: + Support distributed checkpoint save/restore for FSDP2: + use ``torch.distributed.checkpoint`` (or save only from rank 0 + barrier) + and broadcast restored state to all ranks during resume. + """ + + def __init__(self, checkpoint_dir: str, num_layers: int, start_layer: int = 0): + if dist.is_initialized() and dist.size() > 1: + raise RuntimeError( + "Layerwise calibration checkpointing is not supported in " + "multi-process distributed jobs (e.g. FSDP2). " + "Use single-process calibration or disable checkpointing." + ) + + self.checkpoint_dir = checkpoint_dir + self.num_layers = num_layers + self.start_layer = start_layer + + @classmethod + def from_folder(cls, checkpoint_dir: str | None, num_layers: int) -> _CheckpointState | None: + """Create from folder. Detects resume point. Returns None if no checkpoint_dir.""" + if not checkpoint_dir: + return None + os.makedirs(checkpoint_dir, exist_ok=True) + info = detect_resume_point(checkpoint_dir) + if info is not None: + manifest_num_layers = info[1].get("num_layers") + if manifest_num_layers is not None and manifest_num_layers != num_layers: + raise ValueError( + f"Checkpoint num_layers mismatch: manifest has {manifest_num_layers} " + f"but model has {num_layers}. Use a fresh checkpoint directory." + ) + start = info[0] if info else 0 + if start > 0: + print_rank_0( + f"Checkpoint: resuming layerwise calibration from layer {start}/{num_layers}" + ) + return cls(checkpoint_dir, num_layers, start_layer=start) + + def setup_resume(self, layers: nn.ModuleList) -> list | None: + """Load output_meta for skip layers 0..K-1, return next_inputs for layer K. + + Sets ``output_meta`` on each already-calibrated layer so that + skip mode can produce correctly shaped dummy outputs. + """ + if self.start_layer == 0: + return None + + last_ckpt = self.start_layer - 1 + + for i in range(self.start_layer): + d = _layer_dir(self.checkpoint_dir, i) + # weights_only=False is safe: file is internally generated by _save_layer, not user-supplied + meta = torch.load( + os.path.join(d, "output_meta.pt"), map_location="cpu", weights_only=False + ) + layer_device = get_module_device(layers[i]) + meta = _remap_output_metadata_device(meta, layer_device) + layers[i]._layerwise_calib.output_meta = meta + + d = _layer_dir(self.checkpoint_dir, last_ckpt) + next_inputs_path = os.path.join(d, "next_inputs.pt") + if not os.path.isfile(next_inputs_path): + raise FileNotFoundError(f"Cannot resume: next_inputs.pt missing for layer {last_ckpt}") + # weights_only=False is safe: file is internally generated by _save_layer, not user-supplied + next_inputs = torch.load(next_inputs_path, map_location="cpu", weights_only=False) + resume_device = get_module_device(layers[self.start_layer]) + next_inputs = _move_to_device(next_inputs, resume_device) + return next_inputs + + def full_restore(self, layers: nn.ModuleList, model: nn.Module) -> None: + """Restore weights and quantizer state for layers 0..K-1 after the calibration loop.""" + from modelopt.torch.quantization.config import QuantizeConfig + from modelopt.torch.quantization.conversion import restore_quantizer_state + from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback + + if self.start_layer == 0: + return + + dummy_config = QuantizeConfig() + name_to_module = dict(model.named_modules()) + for i in range(self.start_layer): + layer = layers[i] + d = _layer_dir(self.checkpoint_dir, i) + + # Resolve layer_device and load inside the context so params are + # materialized — otherwise get_module_device can return meta. + with enable_weight_access_and_writeback(layer, model, name_to_module): + layer_device = get_module_device(layer) + # weights_only=False is safe: files are internally generated by _save_layer + qstate = torch.load( + os.path.join(d, "quantizer_state.pt"), + map_location=layer_device, + weights_only=False, + ) + weights = torch.load( + os.path.join(d, "weights.pt"), + map_location=layer_device, + weights_only=False, + ) + restore_quantizer_state(layer, dummy_config, {"quantizer_state": qstate}) + layer.load_state_dict(weights, strict=False, assign=True) + + print_rank_0(f"Checkpoint: restored {self.start_layer} previously calibrated layers") + + def save( + self, + layer_idx: int, + layer: nn.Module, + model: nn.Module, + layers: nn.ModuleList, + next_layer_inputs: list | None = None, + ) -> None: + """Snapshot layer state and write checkpoint to disk in one step. + + Args: + layer_idx: Index of the layer just calibrated. + layer: The layer module (weights may be on GPU or managed by accelerate/FSDP2). + model: The full model (needed for ``enable_weight_access_and_writeback``). + layers: The decoder layer list (to read ``output_meta``). + next_layer_inputs: Inputs for the next layer (``None`` for the final layer). + """ + from modelopt.torch.quantization.conversion import quantizer_state + from modelopt.torch.quantization.utils.core_utils import enable_weight_access_and_writeback + + _cpu = torch.device("cpu") + with enable_weight_access_and_writeback(layer, model): + weights = _move_to_device(layer.state_dict(), _cpu) + qstate = _move_to_device(quantizer_state(layer), _cpu) + + output_meta = getattr(layer._layerwise_calib, "output_meta", None) + if output_meta is None: + # Placeholder for the last layer: output_meta is never used for skip mode + # since there is no subsequent layer that needs a correctly shaped dummy output. + output_meta = LayerActivationCollector._extract_output_meta(torch.zeros(1)) + + _save_layer( + self.checkpoint_dir, + layer_idx, + weights, + qstate, + _move_to_device(output_meta, _cpu), + _move_to_device(next_layer_inputs, _cpu) if next_layer_inputs is not None else None, + self.num_layers, + ) + suffix = " (final)" if next_layer_inputs is None else "" + print_rank_0(f"Checkpoint: saved layer {layer_idx}{suffix}") diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 1e9a7fbbbd..01cb3abe88 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -601,16 +601,28 @@ def _forward_loop( dataloader: DataLoader containing the batched input data allowed_non_tensor_keys: Set of key names whose values may be non-tensor types """ - with torch.no_grad(): - is_enc_dec = model_type_is_enc_dec(model) - infer_method = model.generate if is_enc_dec else model.forward - max_working_batch_size = None # Initialize max working batch size as None + # Disable KV caching during calibration — it is unnecessary overhead and causes + # correctness issues with hybrid Mamba/attention models whose cache state is mutated + # in-place (e.g., NemotronH). + config = getattr(model, "config", None) + prev_use_cache = getattr(config, "use_cache", None) + if config is not None and prev_use_cache is not None: + config.use_cache = False - for _, data in enumerate(tqdm(dataloader)): - # Process batch and update max working batch size - max_working_batch_size = _process_batch( - data, infer_method, max_working_batch_size, allowed_non_tensor_keys - ) + try: + with torch.no_grad(): + is_enc_dec = model_type_is_enc_dec(model) + infer_method = model.generate if is_enc_dec else model.forward + max_working_batch_size = None # Initialize max working batch size as None + + for _, data in enumerate(tqdm(dataloader)): + # Process batch and update max working batch size + max_working_batch_size = _process_batch( + data, infer_method, max_working_batch_size, allowed_non_tensor_keys + ) + finally: + if config is not None and prev_use_cache is not None: + config.use_cache = prev_use_cache def create_forward_loop( diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index b54332375b..440ca522d1 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -90,12 +90,43 @@ def is_parallel(model: nn.Module) -> bool: return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)) +def _get_execution_device_from_hook(module: nn.Module) -> torch.device | None: + """Extract the execution device from an accelerate ``_hf_hook``, if present. + + Handles both ``AlignDevicesHook`` (direct) and ``SequentialHook`` (which + may wrap one or more ``AlignDevicesHook`` instances). Returns ``None`` + when no hook is found or the hook carries no ``execution_device``. + """ + hook = getattr(module, "_hf_hook", None) + if hook is None: + return None + + dev = getattr(hook, "execution_device", None) + if dev is not None: + return torch.device("cuda", dev) if isinstance(dev, int) else torch.device(dev) + + for h in getattr(hook, "hooks", ()): + dev = getattr(h, "execution_device", None) + if dev is not None: + return torch.device("cuda", dev) if isinstance(dev, int) else torch.device(dev) + + return None + + def get_module_device(module: nn.Module) -> torch.device: - """Get the device of a PyTorch module.""" + """Get the device of a PyTorch module. + + For modules managed by accelerate (``_hf_hook``), returns the hook's + ``execution_device`` which is the authoritative device even when + parameters are offloaded to CPU/meta between forward calls. + """ + hook_device = _get_execution_device_from_hook(module) + if hook_device is not None: + return hook_device + try: return next(module.parameters()).device except StopIteration: - # For modules without parameters return torch.device("cpu") @@ -590,21 +621,29 @@ def get_unwrapped_name(name: str, model: nn.Module | None = None) -> str: @contextmanager def temporarily_remove_accelerate_hook(module): - """Context manager to temporarily remove accelerate hook from a module.""" - accelerate_hook = None - if hasattr(module, "_hf_hook"): - # A module with forward method patched by accelerate - from accelerate.hooks import add_hook_to_module, remove_hook_from_module + """Context manager to temporarily bypass the accelerate hook on a module. + + Swaps ``module.forward`` with the pre-hook forward (``_old_forward``) so + that code inside the context sees the un-hooked forward. On exit the + hook-wrapped forward is restored and ``_old_forward`` is updated to + reflect any changes made inside the context. - accelerate_hook = module._hf_hook - remove_hook_from_module(module) + This avoids ``remove_hook_from_module`` / ``add_hook_to_module`` entirely, + sidestepping ``init_hook`` which would call ``set_module_tensor_to_device`` + and fail when newly-added quantizer modules have weights on the meta device. + """ + hooked_forward = None + cached_old_forward = None + if hasattr(module, "_hf_hook"): + hooked_forward = module.forward + cached_old_forward = module._old_forward + module.forward = cached_old_forward try: yield finally: - if accelerate_hook is not None: - from accelerate.hooks import add_hook_to_module - - add_hook_to_module(module, accelerate_hook) + if hooked_forward is not None: + module._old_forward = module.forward + module.forward = hooked_forward def bind_forward_method( diff --git a/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml index 6fe4a8c3d1..862929ef34 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yaml @@ -15,7 +15,7 @@ metadata: recipe_type: ptq - description: NVFP4 MLP/MoE weight only (W4A16), FP8 KV cache, max calibration. + description: NVFP4 W4A4, FP8 KV cache, max calibration. quantize: algorithm: max quant_cfg: diff --git a/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml b/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml index a62051b659..99098c9d6d 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml @@ -15,11 +15,12 @@ metadata: recipe_type: ptq - description: NVFP4 weight and activation (W4A4), gptq sequential calibration. + description: NVFP4 weight and activation (W4A4), gptq layerwise calibration. quantize: algorithm: method: gptq - use_sequential: true + layerwise: true + layerwise_checkpoint_dir: output/layerwise_ckpts/ quant_cfg: - quantizer_name: '*' enable: false diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml index cc332733a0..4274e40b62 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml @@ -15,9 +15,12 @@ metadata: recipe_type: ptq - description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max calibration. + description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max layerwise calibration. quantize: - algorithm: max + algorithm: + method: max + layerwise: true + layerwise_checkpoint_dir: output/layerwise_ckpts/ quant_cfg: - quantizer_name: '*' enable: false diff --git a/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py b/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py index 8ed1039e59..49e74e5851 100644 --- a/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py +++ b/tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py @@ -13,9 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy +import json +import os +import shutil + import pytest import torch -from _test_utils.torch.quantization.quantize_common import INT4_AWQ_CLIP_CFG from _test_utils.torch.transformers_models import create_tiny_llama_dir from accelerate import init_empty_weights, load_checkpoint_and_dispatch from transformers import AutoConfig, AutoModelForCausalLM @@ -25,19 +29,11 @@ enable_weight_access_and_writeback, is_quantized_linear, ) +from modelopt.torch.quantization.utils.layerwise_calib import _layer_dir -@pytest.mark.parametrize( - "quant_cfg", - [ - mtq.INT4_AWQ_CFG, - mtq.INT8_SMOOTHQUANT_CFG, - INT4_AWQ_CLIP_CFG, - mtq.NVFP4_SVDQUANT_DEFAULT_CFG, - mtq.INT8_DEFAULT_CFG, - ], -) -def test_cpu_offloaded_tinyllama(tmp_path, quant_cfg): +def test_cpu_offloaded_tinyllama(tmp_path): + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=2) config = AutoConfig.from_pretrained(tiny_llama_dir) @@ -73,3 +69,575 @@ def test_cpu_offloaded_tinyllama(tmp_path, quant_cfg): assert torch.allclose(module.weight, model_ref.get_submodule(name).weight) assert torch.allclose(output_ref.logits, output_test.logits) + + +def _make_cpu_offloaded_model(tmp_path, num_hidden_layers=3): + """Create a tiny LLaMA model with layer 0 offloaded to CPU via accelerate.""" + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_hidden_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + device_map = { + n: 0 + for n, m in model.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + + model = load_checkpoint_and_dispatch(model, tiny_llama_dir, device_map=device_map) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + return model, config, tiny_llama_dir, inputs + + +def _make_layerwise_cfg(base_cfg): + """Add layerwise=True to a quant config's algorithm field.""" + cfg = copy.deepcopy(base_cfg) + algo = cfg.get("algorithm", "max") + if isinstance(algo, str): + cfg["algorithm"] = {"method": algo, "layerwise": True} + else: + algo["layerwise"] = True + return cfg + + +def _make_layerwise_checkpoint_cfg(base_cfg, checkpoint_dir): + """Add layerwise=True and layerwise_checkpoint_dir to a quant config's algorithm field.""" + cfg = _make_layerwise_cfg(base_cfg) + cfg["algorithm"]["layerwise_checkpoint_dir"] = checkpoint_dir + return cfg + + +@pytest.mark.parametrize("use_checkpoint", [False, True], ids=["no_ckpt", "ckpt"]) +def test_layerwise_calibrate_cpu_offloaded(tmp_path, use_checkpoint): + """Layerwise calibration on CPU-offloaded model matches GPU-only reference.""" + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + if use_checkpoint: + ckpt_dir = str(tmp_path / "seq_ckpt") + seq_cfg = _make_layerwise_checkpoint_cfg(quant_cfg, ckpt_dir) + else: + seq_cfg = _make_layerwise_cfg(quant_cfg) + + # Reference: GPU-only model with layerwise calibration + ref_cfg = _make_layerwise_cfg(quant_cfg) + model_ref = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, torch_dtype=config.torch_dtype + ).cuda() + mtq.quantize(model_ref, ref_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Test: CPU-offloaded model + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model = load_checkpoint_and_dispatch(model, tiny_llama_dir, device_map=device_map) + + mtq.quantize(model, seq_cfg, lambda model: model(inputs)) + output_test = model(inputs) + + for name, module in model.named_modules(): + if is_quantized_linear(module): + with enable_weight_access_and_writeback(module, model): + assert torch.allclose(module.weight, model_ref.get_submodule(name).weight), ( + f"Weight mismatch at {name}" + ) + + assert torch.allclose(output_ref.logits, output_test.logits) + + if use_checkpoint: + manifest_path = os.path.join(ckpt_dir, "manifest.json") + assert os.path.isfile(manifest_path) + with open(manifest_path) as f: + manifest = json.load(f) + assert manifest["last_completed_layer"] == num_layers - 1 + assert manifest["num_layers"] == num_layers + + +def test_sequential_checkpoint_resume_cpu_offloaded(tmp_path): + """Resume from a partial checkpoint on a CPU-offloaded model matches a full run.""" + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + ckpt_dir = str(tmp_path / "seq_ckpt") + seq_ckpt_cfg = _make_layerwise_checkpoint_cfg(quant_cfg, ckpt_dir) + + # Full reference run with checkpointing + with init_empty_weights(): + model_ref = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model_ref.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model_ref = load_checkpoint_and_dispatch(model_ref, tiny_llama_dir, device_map=device_map) + mtq.quantize(model_ref, seq_ckpt_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Simulate crash after layer 0 by truncating the manifest and removing later layers + last_completed_layer = 0 + manifest_path = os.path.join(ckpt_dir, "manifest.json") + with open(manifest_path, "w") as f: + json.dump({"last_completed_layer": last_completed_layer, "num_layers": num_layers}, f) + for i in range(last_completed_layer + 1, num_layers): + d = _layer_dir(ckpt_dir, i) + if os.path.isdir(d): + shutil.rmtree(d) + + # Resume from a fresh CPU-offloaded model + with init_empty_weights(): + model_resumed = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model_resumed.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model_resumed = load_checkpoint_and_dispatch( + model_resumed, tiny_llama_dir, device_map=device_map + ) + mtq.quantize(model_resumed, seq_ckpt_cfg, lambda model: model(inputs)) + output_resumed = model_resumed(inputs) + + assert torch.allclose(output_ref.logits, output_resumed.logits), ( + "Resumed checkpoint should produce identical output to full run" + ) + + +def test_sequential_checkpoint_resume_multi_offload(tmp_path): + """Resume with multiple layers offloaded exercises per-layer device resolution.""" + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + ckpt_dir = str(tmp_path / "seq_ckpt") + seq_ckpt_cfg = _make_layerwise_checkpoint_cfg(mtq.INT4_AWQ_CFG, ckpt_dir) + + def _make_multi_offload_model(): + with init_empty_weights(): + m = AutoModelForCausalLM.from_config(config) + dmap = { + n: 0 + for n, mod in m.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + dmap["model.layers.0"] = "cpu" + dmap["model.layers.1"] = "cpu" + return load_checkpoint_and_dispatch(m, tiny_llama_dir, device_map=dmap) + + # Full reference run + model_ref = _make_multi_offload_model() + mtq.quantize(model_ref, seq_ckpt_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Simulate crash after layer 0 + last_completed_layer = 0 + manifest_path = os.path.join(ckpt_dir, "manifest.json") + with open(manifest_path, "w") as f: + json.dump({"last_completed_layer": last_completed_layer, "num_layers": num_layers}, f) + for i in range(last_completed_layer + 1, num_layers): + d = _layer_dir(ckpt_dir, i) + if os.path.isdir(d): + shutil.rmtree(d) + + # Resume from fresh model with same offload layout + model_resumed = _make_multi_offload_model() + mtq.quantize(model_resumed, seq_ckpt_cfg, lambda model: model(inputs)) + output_resumed = model_resumed(inputs) + + assert torch.allclose(output_ref.logits, output_resumed.logits), ( + "Resumed checkpoint with multi-offload should match full run" + ) + + +def _make_gptq_sequential_cfg(base_cfg): + """Create a sequential GPTQ config from a base quantization config.""" + cfg = copy.deepcopy(base_cfg) + cfg["algorithm"] = {"method": "gptq", "layerwise": True} + return cfg + + +def _make_gptq_sequential_checkpoint_cfg(base_cfg, checkpoint_dir): + """Create a sequential GPTQ config with checkpoint dir.""" + cfg = _make_gptq_sequential_cfg(base_cfg) + cfg["algorithm"]["layerwise_checkpoint_dir"] = checkpoint_dir + return cfg + + +@pytest.mark.parametrize("use_checkpoint", [False, True], ids=["no_ckpt", "ckpt"]) +def test_sequential_gptq_cpu_offloaded(tmp_path, use_checkpoint): + """Sequential GPTQ (weight-modifying) on CPU-offloaded model matches GPU-only reference.""" + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + if use_checkpoint: + ckpt_dir = str(tmp_path / "gptq_ckpt") + seq_cfg = _make_gptq_sequential_checkpoint_cfg(mtq.NVFP4_AWQ_LITE_CFG, ckpt_dir) + else: + seq_cfg = _make_gptq_sequential_cfg(mtq.NVFP4_AWQ_LITE_CFG) + + # Reference: GPU-only model + ref_cfg = _make_gptq_sequential_cfg(mtq.NVFP4_AWQ_LITE_CFG) + model_ref = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, torch_dtype=config.torch_dtype + ).cuda() + mtq.quantize(model_ref, ref_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Test: CPU-offloaded model + model, _, _, _ = _make_cpu_offloaded_model(tmp_path / "offloaded", num_hidden_layers=num_layers) + mtq.quantize(model, seq_cfg, lambda model: model(inputs)) + output_test = model(inputs) + + for name, module in model.named_modules(): + if is_quantized_linear(module): + with enable_weight_access_and_writeback(module, model): + assert torch.allclose(module.weight, model_ref.get_submodule(name).weight), ( + f"Weight mismatch at {name}" + ) + + assert torch.allclose(output_ref.logits, output_test.logits) + + +def test_sequential_gptq_checkpoint_resume_cpu_offloaded(tmp_path): + """GPTQ checkpoint resume with CPU offloading restores modified weights correctly.""" + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + ckpt_dir = str(tmp_path / "gptq_ckpt") + seq_ckpt_cfg = _make_gptq_sequential_checkpoint_cfg(mtq.NVFP4_AWQ_LITE_CFG, ckpt_dir) + + # Full reference run with checkpointing + with init_empty_weights(): + model_ref = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model_ref.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model_ref = load_checkpoint_and_dispatch(model_ref, tiny_llama_dir, device_map=device_map) + mtq.quantize(model_ref, seq_ckpt_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Simulate crash after layer 0 + last_completed_layer = 0 + manifest_path = os.path.join(ckpt_dir, "manifest.json") + with open(manifest_path, "w") as f: + json.dump({"last_completed_layer": last_completed_layer, "num_layers": num_layers}, f) + for i in range(last_completed_layer + 1, num_layers): + d = _layer_dir(ckpt_dir, i) + if os.path.isdir(d): + shutil.rmtree(d) + + # Resume from fresh CPU-offloaded model + with init_empty_weights(): + model_resumed = AutoModelForCausalLM.from_config(config) + device_map = { + n: 0 + for n, m in model_resumed.named_modules() + if "layers" not in n or n.split("layers.")[-1].isdigit() + } + device_map["model.layers.0"] = "cpu" + model_resumed = load_checkpoint_and_dispatch( + model_resumed, tiny_llama_dir, device_map=device_map + ) + mtq.quantize(model_resumed, seq_ckpt_cfg, lambda model: model(inputs)) + output_resumed = model_resumed(inputs) + + assert torch.allclose(output_ref.logits, output_resumed.logits), ( + "GPTQ resumed checkpoint should produce identical output to full run" + ) + + +class _TupleReturningBlock(torch.nn.Module): + """Decoder layer that returns a tuple, mimicking HuggingFace decoder layers.""" + + def __init__(self, dim=16): + super().__init__() + self.linear = torch.nn.Linear(dim, dim, bias=False) + + def forward(self, x, **kwargs): + return (self.linear(x), None) + + +class _TupleUnpackingModel(torch.nn.Module): + """Parent model that unpacks layer outputs as tuples.""" + + def __init__(self, n_layers=4, dim=16): + super().__init__() + self.layers = torch.nn.ModuleList([_TupleReturningBlock(dim) for _ in range(n_layers)]) + + def forward(self, x): + for layer in self.layers: + x, _ = layer(x) + return x + + +def test_skip_dummy_has_no_hf_hook(monkeypatch): + """Dummies must not carry _hf_hook from the original layer.""" + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + + from modelopt.torch.quantization.utils.layerwise_calib import ( + LayerActivationCollector, + _SkipLayer, + ) + + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], + ) + + model = _TupleUnpackingModel(n_layers=4, dim=16) + data = [torch.randn(2, 16)] + + for layer in model.layers: + hook = AlignDevicesHook(execution_device=torch.device("cpu")) + add_hook_to_module(layer, hook) + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + for layer in list(model.layers): + collector.get_input_activations(layer, forward_loop) + + for i in range(2): + dummy = model.layers[i] + assert isinstance(dummy, _SkipLayer) + assert not hasattr(dummy, "_hf_hook"), f"Dummy at {i} should not have _hf_hook" + finally: + collector._unpatch_all_layers() + + +def test_persistent_materialization_cpu_offloaded(tmp_path): + """persistent_materialization keeps CPU-offloaded weights on GPU and writes back modifications.""" + import torch.nn as nn + from accelerate.hooks import AlignDevicesHook + + from modelopt.torch.quantization.utils import persistent_materialization + + model, config, _, inputs = _make_cpu_offloaded_model(tmp_path) + offloaded_layer = model.model.layers[0] + + # Verify offloaded (meta device) + assert all(p.device.type == "meta" for p in offloaded_layer.parameters()) + + # Save reference weight + linear = None + with enable_weight_access_and_writeback(offloaded_layer, model): + linear = next(m for m in offloaded_layer.modules() if isinstance(m, nn.Linear)) + ref_weight = linear.weight.clone() + + with persistent_materialization(offloaded_layer): + # Params materialized on GPU + assert all( + p.device.type == "cuda" for p in offloaded_layer.parameters() if p.device.type != "meta" + ) + + # Run multiple forward passes (hooks don't re-offload) + for _ in range(3): + model(inputs) + + # Modify a weight + linear.weight.data.add_(1.0) + + # Verify hooks have offload=False during context + for mod in offloaded_layer.modules(): + if hasattr(mod, "_hf_hook"): + hook = mod._hf_hook + if isinstance(hook, AlignDevicesHook): + assert not hook.offload + + # After context: back to meta device (offloaded) + assert all(p.device.type == "meta" for p in offloaded_layer.parameters()) + + # Verify weight modification persisted through writeback + with enable_weight_access_and_writeback(offloaded_layer, model): + assert torch.allclose(linear.weight, ref_weight + 1.0) + + +def _make_disk_offload_device_map(model): + """Build a device_map with layer 0 on disk, everything else on GPU 0. + + Ancestor modules (``""`` and ``"model"``) are excluded so that + ``dispatch_model`` does not attach a ``place_submodules=True`` hook that + would try to move disk-offloaded meta tensors to GPU (which fails because + no ``value`` is available — unlike CPU offload where weights are on CPU and + can be moved directly). + """ + device_map = { + n: 0 + for n, m in model.named_modules() + if n not in ("", "model") and ("layers" not in n or n.split("layers.")[-1].isdigit()) + } + device_map["model.layers.0"] = "disk" + return device_map + + +def _make_disk_offloaded_model(tmp_path, num_hidden_layers=3): + """Create a tiny LLaMA model with layer 0 offloaded to disk via accelerate.""" + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_hidden_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + device_map = _make_disk_offload_device_map(model) + offload_dir = str(tmp_path / "offload") + model = load_checkpoint_and_dispatch( + model, tiny_llama_dir, device_map=device_map, offload_folder=offload_dir + ) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + return model, config, tiny_llama_dir, inputs + + +def test_disk_offloaded_tinyllama(tmp_path): + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=2) + + config = AutoConfig.from_pretrained(tiny_llama_dir) + + model_ref = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, torch_dtype=config.torch_dtype + ).cuda() + inputs = torch.randint(0, model_ref.config.vocab_size, (1, 4)).cuda() + + mtq.quantize(model_ref, quant_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + device_map = _make_disk_offload_device_map(model) + offload_dir = str(tmp_path / "offload") + model = load_checkpoint_and_dispatch( + model, tiny_llama_dir, device_map=device_map, offload_folder=offload_dir + ) + + assert all(p.device == torch.device("meta") for p in model.model.layers[0].parameters()) + + mtq.quantize(model, quant_cfg, lambda model: model(inputs)) + output_test = model(inputs) + + for name, module in model.named_modules(): + if is_quantized_linear(module): + with enable_weight_access_and_writeback(module, model): + assert torch.allclose(module.weight, model_ref.get_submodule(name).weight) + + assert torch.allclose(output_ref.logits, output_test.logits) + + +def test_persistent_materialization_disk_offloaded(tmp_path): + """persistent_materialization keeps disk-offloaded weights on GPU and writes back modifications.""" + import torch.nn as nn + from accelerate.hooks import AlignDevicesHook + + from modelopt.torch.quantization.utils import persistent_materialization + + model, config, _, inputs = _make_disk_offloaded_model(tmp_path) + offloaded_layer = model.model.layers[0] + + # Verify offloaded (meta device) + assert all(p.device.type == "meta" for p in offloaded_layer.parameters()) + + # Save reference weight + linear = None + with enable_weight_access_and_writeback(offloaded_layer, model): + linear = next(m for m in offloaded_layer.modules() if isinstance(m, nn.Linear)) + ref_weight = linear.weight.clone() + + with persistent_materialization(offloaded_layer): + # Params materialized on GPU + assert all( + p.device.type == "cuda" for p in offloaded_layer.parameters() if p.device.type != "meta" + ) + + # Run multiple forward passes (hooks don't re-offload) + for _ in range(3): + model(inputs) + + # Modify a weight + linear.weight.data.add_(1.0) + + # Verify hooks have offload=False during context + for mod in offloaded_layer.modules(): + if hasattr(mod, "_hf_hook"): + hook = mod._hf_hook + if isinstance(hook, AlignDevicesHook): + assert not hook.offload + + # After context: back to meta device (offloaded) + assert all(p.device.type == "meta" for p in offloaded_layer.parameters()) + + # Verify weight modification persisted through writeback + with enable_weight_access_and_writeback(offloaded_layer, model): + assert torch.allclose(linear.weight, ref_weight + 1.0) + + +@pytest.mark.parametrize("use_checkpoint", [False, True], ids=["no_ckpt", "ckpt"]) +def test_layerwise_calibrate_disk_offloaded(tmp_path, use_checkpoint): + """Layerwise calibration on disk-offloaded model matches GPU-only reference.""" + quant_cfg = mtq.NVFP4_AWQ_LITE_CFG + num_layers = 3 + tiny_llama_dir = create_tiny_llama_dir(tmp_path, num_hidden_layers=num_layers) + config = AutoConfig.from_pretrained(tiny_llama_dir) + inputs = torch.randint(0, config.vocab_size, (1, 4)).cuda() + + if use_checkpoint: + ckpt_dir = str(tmp_path / "seq_ckpt") + seq_cfg = _make_layerwise_checkpoint_cfg(quant_cfg, ckpt_dir) + else: + seq_cfg = _make_layerwise_cfg(quant_cfg) + + # Reference: GPU-only model with layerwise calibration + ref_cfg = _make_layerwise_cfg(quant_cfg) + model_ref = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, torch_dtype=config.torch_dtype + ).cuda() + mtq.quantize(model_ref, ref_cfg, lambda model: model(inputs)) + output_ref = model_ref(inputs) + + # Test: disk-offloaded model + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + device_map = _make_disk_offload_device_map(model) + offload_dir = str(tmp_path / "offload") + model = load_checkpoint_and_dispatch( + model, tiny_llama_dir, device_map=device_map, offload_folder=offload_dir + ) + + mtq.quantize(model, seq_cfg, lambda model: model(inputs)) + output_test = model(inputs) + + for name, module in model.named_modules(): + if is_quantized_linear(module): + with enable_weight_access_and_writeback(module, model): + assert torch.allclose(module.weight, model_ref.get_submodule(name).weight), ( + f"Weight mismatch at {name}" + ) + + assert torch.allclose(output_ref.logits, output_test.logits) diff --git a/tests/gpu/torch/quantization/test_fsdp2.py b/tests/gpu/torch/quantization/test_fsdp2.py index 4889b6dc8c..c5584ece5c 100644 --- a/tests/gpu/torch/quantization/test_fsdp2.py +++ b/tests/gpu/torch/quantization/test_fsdp2.py @@ -128,3 +128,136 @@ def test_fsdp_simple_linear(dist_workers): ) def test_nested_fsdp2_backward(quant_cfg, dist_workers): dist_workers.run(partial(_test_nested_fsdp2_backward, quant_cfg=quant_cfg)) + + +class _DecoderBlock(nn.Module): + """Minimal decoder block for FSDP2 sequential tests.""" + + def __init__(self, dim=32): + super().__init__() + self.attn = nn.Linear(dim, dim, bias=False) + self.ffn = nn.Sequential( + nn.Linear(dim, dim, bias=False), nn.ReLU(), nn.Linear(dim, dim, bias=False) + ) + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + x = x + self.attn(self.norm(x)) + x = x + self.ffn(x) + return x + + +class _SimpleTransformerModel(nn.Module): + """Model with ``model.layers`` for layerwise calibration discovery.""" + + def __init__(self, n_layers=3, dim=32): + super().__init__() + self.layers = nn.ModuleList([_DecoderBlock(dim) for _ in range(n_layers)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def _test_layerwise_calibrate_fsdp2(rank, size): + """Layerwise calibration on FSDP2-wrapped model matches non-FSDP reference.""" + from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector + + dim = 32 + torch.manual_seed(1) + model = _SimpleTransformerModel(n_layers=3, dim=dim).cuda() + inputs = torch.randn(2, 2, dim).cuda() + synchronize_state_dict(model) + + # Register discoverer for our simple model + old_support = LayerActivationCollector._decoder_layer_support[:] + LayerActivationCollector._decoder_layer_support = [ + ( + lambda m: hasattr(m, "layers") and isinstance(m.layers, nn.ModuleList), + lambda m: m.layers, + ), + *old_support, + ] + + try: + # Reference: non-FSDP layerwise calibration + ref_model = copy.deepcopy(model) + seq_cfg = copy.deepcopy(mtq.INT8_DEFAULT_CFG) + seq_cfg["algorithm"] = {"method": "max", "layerwise": True} + mtq.quantize(ref_model, seq_cfg, lambda m: m(inputs)) + output_ref = ref_model(inputs) + + # Test: FSDP2-wrapped layerwise calibration + for layer in model.layers: + fully_shard(layer) + model = fully_shard(model) + mtq.quantize(model, seq_cfg, lambda m: m(inputs)) + output_test = model(inputs) + + assert torch.allclose(output_ref, output_test) + finally: + LayerActivationCollector._decoder_layer_support = old_support + + +def test_layerwise_calibrate_fsdp2(dist_workers): + dist_workers.run(_test_layerwise_calibrate_fsdp2) + + +def _test_persistent_materialization(rank, size): + """persistent_materialization keeps weights accessible and writes back modifications.""" + from torch.distributed.tensor import DTensor + + from modelopt.torch.quantization.utils import ( + enable_weight_access_and_writeback, + persistent_materialization, + ) + + dim = 32 + torch.manual_seed(1) + model = nn.Sequential( + nn.Sequential(nn.Linear(dim, dim), nn.Linear(dim, dim)), + nn.Sequential(nn.Linear(dim, dim), nn.Linear(dim, dim)), + ).cuda(rank) + synchronize_state_dict(model) + + fully_shard(model[0]) + fully_shard(model[1]) + model = fully_shard(model) + + layer = model[0] + inputs = torch.randn(2, dim).cuda(rank) + + # Warmup forward to trigger FSDP2's lazy_init (mirrors real usage where + # layerwise_calibrate always runs get_first_layer_inputs first). + model(inputs) + + # Save reference weight (gathered) + with enable_weight_access_and_writeback(layer[0], model): + ref_weight = layer[0].weight.clone() + + # Verify sharded before context + assert isinstance(next(iter(layer.parameters())), DTensor) + + with persistent_materialization(layer): + # Params are local tensors (not DTensors) + assert not isinstance(layer[0].weight, DTensor) + assert layer[0].weight.device.type == "cuda" + + # Run multiple forward passes (FSDP hooks fire, unshard/reshard are no-ops) + for _ in range(3): + layer(inputs) + + # Modify a weight + layer[0].weight.data.add_(1.0) + + # After context: params restored to DTensors (sharded) + assert isinstance(next(iter(layer.parameters())), DTensor) + + # Verify modification persisted + with enable_weight_access_and_writeback(layer[0], model): + assert torch.allclose(layer[0].weight, ref_weight + 1.0) + + +def test_persistent_materialization(dist_workers): + dist_workers.run(_test_persistent_materialization) diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py index d183855abb..2d5f9d6d70 100644 --- a/tests/gpu/torch/quantization/test_gptq.py +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -219,7 +219,7 @@ def test_gptq_e2e_flow(quant_cfg): model.eval() quant_cfg = copy.deepcopy(quant_cfg) - quant_cfg["algorithm"] = {"method": "gptq", "use_sequential": True} + quant_cfg["algorithm"] = {"method": "gptq", "layerwise": True} calib_dataloader = get_dataset_dataloader( dataset_name="cnn_dailymail", tokenizer=tokenizer, diff --git a/tests/gpu/torch/quantization/test_sequential_calibrate.py b/tests/gpu/torch/quantization/test_layerwise_calibrate.py similarity index 90% rename from tests/gpu/torch/quantization/test_sequential_calibrate.py rename to tests/gpu/torch/quantization/test_layerwise_calibrate.py index ba71e896c7..d38b82f46f 100644 --- a/tests/gpu/torch/quantization/test_sequential_calibrate.py +++ b/tests/gpu/torch/quantization/test_layerwise_calibrate.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Integration tests for sequential_calibrate and LayerActivationCollector.""" +"""Integration tests for layerwise_calibrate and LayerActivationCollector.""" import torch import torch.nn as nn -from modelopt.torch.quantization.model_calib import sequential_calibrate -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.model_calib import layerwise_calibrate +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector class _DecoderBlock(nn.Module): @@ -101,7 +101,7 @@ def _register_test_discoverer(monkeypatch): ) -def test_seq_calib_func_called_per_layer(monkeypatch): +def test_layerwise_calib_func_called_per_layer(monkeypatch): _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=4) call_count = [0] @@ -109,7 +109,7 @@ def test_seq_calib_func_called_per_layer(monkeypatch): def counting_calib(layer, forward_loop, **kwargs): call_count[0] += 1 - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=counting_calib, @@ -118,7 +118,7 @@ def counting_calib(layer, forward_loop, **kwargs): assert call_count[0] == 4 -def test_seq_calib_func_receives_correct_layer(monkeypatch): +def test_layerwise_calib_func_receives_correct_layer(monkeypatch): _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=3) called_layers = [] @@ -126,7 +126,7 @@ def test_seq_calib_func_receives_correct_layer(monkeypatch): def track_layers(layer, forward_loop, **kwargs): called_layers.append(layer) - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=track_layers, @@ -136,7 +136,7 @@ def track_layers(layer, forward_loop, **kwargs): assert called_layers[i] is layer -def test_seq_calib_kwargs_forwarded(monkeypatch): +def test_layerwise_calib_kwargs_forwarded(monkeypatch): _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=2) received_kwargs = [] @@ -144,7 +144,7 @@ def test_seq_calib_kwargs_forwarded(monkeypatch): def capture_kwargs(layer, forward_loop, **kwargs): received_kwargs.append(kwargs) - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=capture_kwargs, @@ -158,7 +158,7 @@ def capture_kwargs(layer, forward_loop, **kwargs): assert kw["method"] == "max" -def test_seq_calib_layer_forward_loop_runs_all_batches(monkeypatch): +def test_layerwise_calib_layer_forward_loop_runs_all_batches(monkeypatch): """The per-layer forward loop passed to calib_func should replay all batches.""" _register_test_discoverer(monkeypatch) n_batches = 5 @@ -178,7 +178,7 @@ def counting_forward(*args, **kw): layer.forward = orig_forward batch_counts.append(counter["n"]) - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=count_batches, @@ -188,13 +188,13 @@ def counting_forward(*args, **kw): assert count == n_batches -def test_seq_calib_does_not_alter_weights(monkeypatch): - """sequential_calibrate itself should not modify model weights.""" +def test_layerwise_calib_does_not_alter_weights(monkeypatch): + """layerwise_calibrate itself should not modify model weights.""" _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=3) weights_before = {n: p.clone() for n, p in model.named_parameters()} - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: _run_forward(m, data), calib_func=lambda layer, forward_loop, **kw: None, @@ -204,7 +204,7 @@ def test_seq_calib_does_not_alter_weights(monkeypatch): assert torch.equal(p, weights_before[n]), f"Weight {n} was modified" -def test_seq_calib_activations_update_across_layers(monkeypatch): +def test_layerwise_calib_activations_update_across_layers(monkeypatch): """Subsequent layers should see activations transformed by prior layers.""" _register_test_discoverer(monkeypatch) torch.manual_seed(0) @@ -228,7 +228,7 @@ def capture_forward(*args, **kw): layer_idx = list(model.layers).index(layer) layer_inputs_record[layer_idx] = activations - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: [m(t) for t in tokens], calib_func=record_inputs, @@ -240,7 +240,7 @@ def capture_forward(*args, **kw): def test_mode_transitions_across_calibration_steps(monkeypatch): - """Verify layer modes after each sequential calibration step. + """Verify layer modes after each layerwise calibration step. After get_input_activations(layers[i]) returns, the current layer is reset to 'original'. Layers further back are left in 'run' (just calibrated) or @@ -259,7 +259,7 @@ def forward_loop(m): try: def modes(): - return [model.layers[i]._seq_calib.mode for i in range(5)] + return [model.layers[i]._layerwise_calib.mode for i in range(5)] collector.get_input_activations(model.layers[0], forward_loop) assert modes() == ["original", "original", "original", "original", "original"] @@ -316,7 +316,7 @@ def weight_doubling_calib(layer, layer_forward_loop, **kwargs): layer.weight.mul_(2.0) layer_forward_loop(layer) - sequential_calibrate( + layerwise_calibrate( model, forward_loop=forward_loop, calib_func=weight_doubling_calib, diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py index 692ab07d4a..ae638c42ee 100644 --- a/tests/unit/torch/quantization/plugins/test_huggingface.py +++ b/tests/unit/torch/quantization/plugins/test_huggingface.py @@ -35,7 +35,7 @@ get_homogeneous_hf_decoder_layers, is_homogeneous_hf_model, ) -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector pytest.importorskip("transformers") diff --git a/tests/unit/torch/quantization/test_calib.py b/tests/unit/torch/quantization/test_calib.py index b3c372eb33..d2e6fdd03e 100644 --- a/tests/unit/torch/quantization/test_calib.py +++ b/tests/unit/torch/quantization/test_calib.py @@ -27,7 +27,7 @@ from modelopt.torch.quantization.model_calib import ( apply_pre_quant_scale_and_smooth, disable_pre_quant_scale_and_resmooth, - sequential_calibrate, + layerwise_calibrate, ) from modelopt.torch.quantization.nn import TensorQuantizer @@ -379,7 +379,7 @@ def test_svdquant_lora_weights(): assert lora_residual.shape == module.weight.shape -def test_sequential_calibrate_support_gate(): +def test_layerwise_calibrate_support_gate(): class _UnsupportedModel(nn.Module): def __init__(self): super().__init__() @@ -392,17 +392,17 @@ def forward(self, x): with ( torch.no_grad(), - pytest.raises(ValueError, match="Sequential calibration requires a model"), + pytest.raises(ValueError, match="Layerwise calibration requires a model"), ): - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: m(torch.randn(2, 4)), calib_func=lambda layer, loop: loop(layer), ) -def test_sequential_calibrate_propagates_inputs_without_replaying_full_model(monkeypatch): - from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +def test_layerwise_calibrate_propagates_inputs_without_replaying_full_model(monkeypatch): + from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector class _ToyLayer(nn.Module): def __init__(self, scale: float, bias: float): @@ -463,7 +463,7 @@ def _pre_hook(_module, args): handle.remove() observed_layer_inputs.append(captured) - sequential_calibrate(model, _forward_loop, _calib_func) + layerwise_calibrate(model, _forward_loop, _calib_func) assert forward_loop_calls == len(model.layers) assert len(observed_layer_inputs) == len(model.layers) @@ -482,9 +482,9 @@ def _pre_hook(_module, args): assert torch.allclose(observed, expected) -def test_sequential_calibrate_handles_inter_layer_logic(monkeypatch): +def test_layerwise_calibrate_handles_inter_layer_logic(monkeypatch): """Verify that parent-level inter-layer logic (e.g. mask selection) works correctly.""" - from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector + from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector class _ToyLayer(nn.Module): def __init__(self, scale: float): @@ -537,7 +537,7 @@ def _pre_hook(_module, args): handle.remove() observed_layer_inputs.append(captured) - sequential_calibrate(model, _forward_loop, _calib_func) + layerwise_calibrate(model, _forward_loop, _calib_func) assert len(observed_layer_inputs) == 3 # Layer 0 gets raw batch diff --git a/tests/unit/torch/quantization/test_sequential_calibrate.py b/tests/unit/torch/quantization/test_layerwise_calibrate.py similarity index 78% rename from tests/unit/torch/quantization/test_sequential_calibrate.py rename to tests/unit/torch/quantization/test_layerwise_calibrate.py index 14c1903de2..6596c1b4b1 100644 --- a/tests/unit/torch/quantization/test_sequential_calibrate.py +++ b/tests/unit/torch/quantization/test_layerwise_calibrate.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for sequential_calibrate and LayerActivationCollector.""" +"""Unit tests for layerwise_calibrate and LayerActivationCollector.""" from collections import deque @@ -21,8 +21,8 @@ import torch import torch.nn as nn -from modelopt.torch.quantization.model_calib import sequential_calibrate -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.model_calib import layerwise_calibrate +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector, _SkipLayer class _DecoderBlock(nn.Module): @@ -60,7 +60,7 @@ def forward(self, x, **kwargs): class _FlatMLP(nn.Module): - """No decoder-layer structure -- should be rejected by sequential_calibrate.""" + """No decoder-layer structure -- should be rejected by layerwise_calibrate.""" def __init__(self, dim=16): super().__init__() @@ -180,7 +180,7 @@ def forward_loop(m): collector._unpatch_all_layers() assert not hasattr(model, "_original_forward") - assert not hasattr(model.layers[0], "_seq_calib") + assert not hasattr(model.layers[0], "_layerwise_calib") assert not hasattr(model.layers[0], "_original_forward") @@ -201,38 +201,38 @@ def bad_forward_loop(m): collector._unpatch_all_layers() assert not hasattr(model, "_original_forward") - assert not hasattr(model.layers[0], "_seq_calib") + assert not hasattr(model.layers[0], "_layerwise_calib") -# sequential_calibrate tests -def test_seq_calib_raises_on_none_forward_loop(monkeypatch): +# layerwise_calibrate tests +def test_layerwise_calib_raises_on_none_forward_loop(monkeypatch): _register_test_discoverer(monkeypatch) model, data = _make_model_and_data(n_layers=2) with pytest.raises(ValueError, match="forward_loop must not be None"): - sequential_calibrate( + layerwise_calibrate( model, forward_loop=None, calib_func=lambda *a, **kw: None, ) -def test_seq_calib_raises_on_unrecognized_model(): +def test_layerwise_calib_raises_on_unrecognized_model(): model = _FlatMLP() with pytest.raises(ValueError, match="Could not find transformer layers"): - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: m(torch.randn(2, 16)), calib_func=lambda *a, **kw: None, ) -def test_seq_calib_empty_forward_loop_raises(monkeypatch): - """If forward_loop feeds no data, sequential_calibrate raises RuntimeError.""" +def test_layerwise_calib_empty_forward_loop_raises(monkeypatch): + """If forward_loop feeds no data, layerwise_calibrate raises RuntimeError.""" _register_test_discoverer(monkeypatch) model = _SimpleTransformerModel(n_layers=2, dim=16) with pytest.raises(RuntimeError, match="collected no inputs during forward_loop"): - sequential_calibrate( + layerwise_calibrate( model, forward_loop=lambda m: None, calib_func=lambda *a, **kw: None, @@ -344,11 +344,11 @@ def forward_loop(m): try: # Layer 0 starts as capture — no output_meta yet collector.get_input_activations(model.layers[0], forward_loop) - assert model.layers[0]._seq_calib.output_meta is None + assert model.layers[0]._layerwise_calib.output_meta is None # Calibrating layer 1 puts layer 0 into run, which sets output_meta collector.get_input_activations(model.layers[1], forward_loop) - meta = model.layers[0]._seq_calib.output_meta + meta = model.layers[0]._layerwise_calib.output_meta assert meta is not None assert meta[0] == "tuple", "Tuple-returning layer should produce tuple metadata" finally: @@ -375,11 +375,11 @@ def forward_loop(m): # Before calibrating layer 2, layer 1 transitions to run. # Its cached_inputs should be populated from collected_inputs. collector._set_layer_states(2) - assert len(model.layers[1]._seq_calib.cached_inputs) == n_batches + assert len(model.layers[1]._layerwise_calib.cached_inputs) == n_batches # After the forward loop, all cached inputs should be consumed forward_loop(model) - assert len(model.layers[1]._seq_calib.cached_inputs) == 0 + assert len(model.layers[1]._layerwise_calib.cached_inputs) == 0 finally: collector._unpatch_all_layers() @@ -399,24 +399,24 @@ def test_set_layer_states_transitions(monkeypatch): try: def modes(): - return [model.layers[i]._seq_calib.mode for i in range(5)] + return [model.layers[i]._layerwise_calib.mode for i in range(5)] collector._set_layer_states(0) assert modes() == ["capture", "original", "original", "original", "original"] - model.layers[0]._seq_calib.collected_inputs = [fake_inp] + model.layers[0]._layerwise_calib.collected_inputs = [fake_inp] collector._set_layer_states(1) assert modes() == ["run", "capture", "original", "original", "original"] - model.layers[1]._seq_calib.collected_inputs = [fake_inp] + model.layers[1]._layerwise_calib.collected_inputs = [fake_inp] collector._set_layer_states(2) assert modes() == ["skip", "run", "capture", "original", "original"] - model.layers[2]._seq_calib.collected_inputs = [fake_inp] + model.layers[2]._layerwise_calib.collected_inputs = [fake_inp] collector._set_layer_states(3) assert modes() == ["skip", "skip", "run", "capture", "original"] - model.layers[3]._seq_calib.collected_inputs = [fake_inp] + model.layers[3]._layerwise_calib.collected_inputs = [fake_inp] collector._set_layer_states(4) assert modes() == ["skip", "skip", "skip", "run", "capture"] finally: @@ -446,8 +446,8 @@ def test_run_asserts_on_empty_cached_inputs(monkeypatch): collector = LayerActivationCollector(model) collector._patch_all_layers() try: - model.layers[0]._seq_calib.mode = "run" - model.layers[0]._seq_calib.cached_inputs = deque() + model.layers[0]._layerwise_calib.mode = "run" + model.layers[0]._layerwise_calib.cached_inputs = deque() with pytest.raises(AssertionError, match="no cached inputs to replay"): model(torch.randn(2, 16)) @@ -455,8 +455,8 @@ def test_run_asserts_on_empty_cached_inputs(monkeypatch): collector._unpatch_all_layers() -def test_cleanup_removes_seq_calib_attr(monkeypatch): - """After unpatch, no layer should have the _seq_calib attribute.""" +def test_cleanup_removes_layerwise_calib_attr(monkeypatch): + """After unpatch, no layer should have the _layerwise_calib attribute.""" _register_test_discoverer(monkeypatch) model = _TupleUnpackingModel(n_layers=3, dim=16) data = [torch.randn(2, 16)] @@ -472,7 +472,9 @@ def forward_loop(m): collector._unpatch_all_layers() for i, layer in enumerate(model.layers): - assert not hasattr(layer, "_seq_calib"), f"Layer {i} still has _seq_calib after cleanup" + assert not hasattr(layer, "_layerwise_calib"), ( + f"Layer {i} still has _layerwise_calib after cleanup" + ) assert not hasattr(layer, "_original_forward"), ( f"Layer {i} still has _original_forward after cleanup" ) @@ -517,15 +519,17 @@ def forward_loop(m): for d in data: m(d) + originals = list(model.layers) collector = LayerActivationCollector(model) collector._patch_all_layers() try: - for layer in model.layers: + for layer in originals: collector.get_input_activations(layer, forward_loop) - # After full calibration, layers 0 and 1 have been through 'run' and have output_meta - meta_0 = model.layers[0]._seq_calib.output_meta - meta_1 = model.layers[1]._seq_calib.output_meta + # After full calibration, layers 0 and 1 have been through 'run' and have output_meta. + # Access via originals since skip-position entries are now _SkipLayer dummies. + meta_0 = originals[0]._layerwise_calib.output_meta + meta_1 = originals[1]._layerwise_calib.output_meta assert meta_0 is not None assert meta_1 is not None # SmallBlock returns 3-element tuple, BigBlock returns 1-element tuple @@ -533,3 +537,59 @@ def forward_loop(m): assert len(meta_1[1]) == 1 finally: collector._unpatch_all_layers() + + +# --------------------------------------------------------------------------- +# _SkipLayer swap / restore tests +# --------------------------------------------------------------------------- + + +def test_skip_layers_replaced_with_dummy(monkeypatch): + """After calibrating enough layers, skip-position entries must be _SkipLayer with no params.""" + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=5, dim=16) + data = [torch.randn(2, 16) for _ in range(2)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + try: + for layer in list(model.layers): + collector.get_input_activations(layer, forward_loop) + + # Layers 0..2 should be dummies (swapped when calibrating layers 2..4) + for i in range(3): + assert isinstance(model.layers[i], _SkipLayer), f"Layer {i} should be _SkipLayer" + assert list(model.layers[i].parameters()) == [], ( + f"Layer {i} dummy should have no params" + ) + # Layers 3 (run) and 4 (original) remain real + for i in range(3, 5): + assert not isinstance(model.layers[i], _SkipLayer), f"Layer {i} should still be real" + finally: + collector._unpatch_all_layers() + + +def test_cleanup_restores_original_layers(monkeypatch): + """After _unpatch_all_layers, all ModuleList entries must be the original modules.""" + _register_test_discoverer(monkeypatch) + model = _TupleUnpackingModel(n_layers=5, dim=16) + originals = list(model.layers) + data = [torch.randn(2, 16)] + + def forward_loop(m): + for d in data: + m(d) + + collector = LayerActivationCollector(model) + collector._patch_all_layers() + for layer in originals: + collector.get_input_activations(layer, forward_loop) + collector._unpatch_all_layers() + + for i, orig in enumerate(originals): + assert model.layers[i] is orig, f"Layer {i} not restored to original after cleanup" + assert not hasattr(orig, "_layerwise_calib"), f"Layer {i} still has _layerwise_calib" diff --git a/tests/unit/torch/quantization/test_sequential_checkpoint.py b/tests/unit/torch/quantization/test_sequential_checkpoint.py new file mode 100644 index 0000000000..0e592a68c7 --- /dev/null +++ b/tests/unit/torch/quantization/test_sequential_checkpoint.py @@ -0,0 +1,185 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for layerwise calibration checkpoint save/resume.""" + +import json +import os +from types import SimpleNamespace + +import torch +import torch.nn as nn + +from modelopt.torch.quantization.model_calib import layerwise_calibrate +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector +from modelopt.torch.utils.network import get_module_device + + +class _DecoderBlock(nn.Module): + def __init__(self, dim=16): + super().__init__() + self.linear = nn.Linear(dim, dim, bias=False) + + def forward(self, x, **kwargs): + return self.linear(x) + + +class _SimpleTransformerModel(nn.Module): + def __init__(self, n_layers=3, dim=16): + super().__init__() + self.layers = nn.ModuleList([_DecoderBlock(dim) for _ in range(n_layers)]) + self.embed = nn.Embedding(32, dim) + + def forward(self, x, **kwargs): + x = self.embed(x) + for layer in self.layers: + x = layer(x) + return x + + +def _register_test_discoverer(monkeypatch): + monkeypatch.setattr( + LayerActivationCollector, + "_decoder_layer_support", + [(lambda m: hasattr(m, "layers"), lambda m: m.layers)], + ) + + +def _dummy_calib_func(layer, forward_loop, **kwargs): + """Scale all weights by 0.5 to produce a visible, deterministic change.""" + forward_loop(layer) + with torch.no_grad(): + for p in layer.parameters(): + p.mul_(0.5) + + +def _make_model_and_forward(n_layers=3, dim=16, seed=42): + torch.manual_seed(seed) + model = _SimpleTransformerModel(n_layers=n_layers, dim=dim) + tokens = [torch.randint(0, 32, (2, 8)) for _ in range(2)] + + def forward_loop(m): + for t in tokens: + m(t) + + return model, forward_loop + + +def test_full_run_creates_checkpoints(monkeypatch, tmp_path): + """layerwise_calibrate with checkpoint_dir creates correct layer dirs and manifest.""" + _register_test_discoverer(monkeypatch) + model, forward_loop = _make_model_and_forward(n_layers=3) + ckpt_dir = str(tmp_path / "ckpt") + + layerwise_calibrate(model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir) + + manifest_path = os.path.join(ckpt_dir, "manifest.json") + assert os.path.isfile(manifest_path) + with open(manifest_path) as f: + manifest = json.load(f) + assert manifest["last_completed_layer"] == 2 + assert manifest["num_layers"] == 3 + + for i in range(3): + layer_dir = os.path.join(ckpt_dir, f"layer_{i:04d}") + assert os.path.isdir(layer_dir) + assert os.path.isfile(os.path.join(layer_dir, "weights.pt")) + assert os.path.isfile(os.path.join(layer_dir, "quantizer_state.pt")) + assert os.path.isfile(os.path.join(layer_dir, "output_meta.pt")) + # All layers except the last should have next_inputs + assert os.path.isfile(os.path.join(ckpt_dir, "layer_0000", "next_inputs.pt")) + assert os.path.isfile(os.path.join(ckpt_dir, "layer_0001", "next_inputs.pt")) + assert not os.path.isfile(os.path.join(ckpt_dir, "layer_0002", "next_inputs.pt")) + + +def test_resume_matches_full_run(monkeypatch, tmp_path): + """Resume from a truncated checkpoint produces the same final weights as a full run.""" + _register_test_discoverer(monkeypatch) + ckpt_dir = str(tmp_path / "ckpt") + + # Full reference run + ref_model, forward_loop = _make_model_and_forward(n_layers=3) + layerwise_calibrate(ref_model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir) + ref_weights = {n: p.clone() for n, p in ref_model.named_parameters()} + + # Simulate crash after layer 0: truncate manifest + manifest_path = os.path.join(ckpt_dir, "manifest.json") + with open(manifest_path, "w") as f: + json.dump({"last_completed_layer": 0, "num_layers": 3}, f) + + # Resume from a fresh model + resumed_model, forward_loop = _make_model_and_forward(n_layers=3) + layerwise_calibrate(resumed_model, forward_loop, _dummy_calib_func, checkpoint_dir=ckpt_dir) + + for name, ref_param in ref_weights.items(): + resumed_param = dict(resumed_model.named_parameters())[name] + assert torch.allclose(ref_param, resumed_param, atol=1e-6), ( + f"Parameter {name} diverged after resume" + ) + + +def test_no_checkpoint_unchanged(monkeypatch): + """Without checkpoint_dir, calibration still works and modifies parameters.""" + _register_test_discoverer(monkeypatch) + model, forward_loop = _make_model_and_forward(n_layers=3) + original_weights = {n: p.clone() for n, p in model.named_parameters()} + + layerwise_calibrate(model, forward_loop, _dummy_calib_func) + + changed = False + for name, param in model.named_parameters(): + if not torch.allclose(original_weights[name], param): + changed = True + break + assert changed, "Expected calibration to modify at least one parameter" + + +# --------------------------------------------------------------------------- +# get_module_device tests +# --------------------------------------------------------------------------- + + +def test_get_module_device_no_hook(): + """Falls back to parameter device when no _hf_hook is present.""" + layer = nn.Linear(4, 4) + assert get_module_device(layer) == torch.device("cpu") + + +def test_get_module_device_with_direct_hook(): + """Returns execution_device from a direct AlignDevicesHook-style hook.""" + layer = nn.Linear(4, 4) + layer._hf_hook = SimpleNamespace(execution_device=torch.device("cuda:0")) + assert get_module_device(layer) == torch.device("cuda:0") + + +def test_get_module_device_with_sequential_hook(): + """Returns execution_device from an AlignDevicesHook wrapped in SequentialHook.""" + layer = nn.Linear(4, 4) + inner_hook = SimpleNamespace(execution_device=torch.device("cuda:1")) + layer._hf_hook = SimpleNamespace(hooks=[inner_hook]) + assert get_module_device(layer) == torch.device("cuda:1") + + +def test_get_module_device_hook_without_execution_device(): + """Falls back to parameters when hook has no execution_device.""" + layer = nn.Linear(4, 4) + layer._hf_hook = SimpleNamespace() + assert get_module_device(layer) == torch.device("cpu") + + +def test_get_module_device_parameterless_module(): + """Returns cpu for a module with no parameters and no hook.""" + module = nn.Module() + assert get_module_device(module) == torch.device("cpu") diff --git a/tests/unit/torch/quantization/test_utils.py b/tests/unit/torch/quantization/test_utils.py index 92fe1345f9..73d3423ba5 100644 --- a/tests/unit/torch/quantization/test_utils.py +++ b/tests/unit/torch/quantization/test_utils.py @@ -20,7 +20,7 @@ convert_quantization_axis_to_reduce_axis, reduce_block_amax, ) -from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector +from modelopt.torch.quantization.utils.layerwise_calib import LayerActivationCollector @pytest.mark.parametrize(