diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py
index 7b93e0a134b8..4536e00343d6 100644
--- a/src/transformers/integrations/peft.py
+++ b/src/transformers/integrations/peft.py
@@ -34,6 +34,7 @@
     Transpose,
     WeightConverter,
     WeightRenaming,
+    rename_source_key,
 )
 from ..utils import (
     CONFIG_NAME,
@@ -47,7 +48,7 @@
     logging,
 )
 from ..utils.hub import DownloadKwargs
-from ..utils.loading_report import log_state_dict_report
+from ..utils.loading_report import LoadStateDictInfo, log_state_dict_report


 if is_torch_available():
@@ -428,6 +429,45 @@ class PeftAdapterMixin:
     _prepare_peft_hotswap_kwargs: dict | None = None
     peft_config: dict[str, PeftConfigLike]

+    def _resolve_adapter_state_dict(
+        self, adapter_state_dict: dict[str, "torch.Tensor"] | None, checkpoint_files
+    ) -> dict[str, torch.Tensor]:
+        # Materialize the adapter state dict from `adapter_state_dict` or `checkpoint_files`. Used by paths
+        # that bypass `self._load_pretrained_model` (which would otherwise read the files itself).
+        from ..modeling_utils import load_state_dict
+
+        all_pointer = set()
+        if adapter_state_dict is not None:
+            merged_state_dict = adapter_state_dict
+        elif (
+            checkpoint_files is not None
+            and checkpoint_files[0].endswith(".safetensors")
+            and adapter_state_dict is None
+        ):
+            merged_state_dict = {}
+            for file in checkpoint_files:
+                file_pointer = safe_open(file, framework="pt", device="cpu")
+                all_pointer.add(file_pointer)
+                for k in file_pointer.keys():
+                    merged_state_dict[k] = file_pointer.get_tensor(k)
+        # Checkpoints are .bin
+        elif checkpoint_files is not None:
+            merged_state_dict = {}
+            for ckpt_file in checkpoint_files:
+                merged_state_dict.update(load_state_dict(ckpt_file))
+        else:
+            raise ValueError("Neither a state dict nor checkpoint files were found.")
+
+        return merged_state_dict
+
+    def _set_peft_inference_mode(self) -> None:
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        self.eval()
+        for module in self.modules():
+            if isinstance(module, BaseTunerLayer):
+                module.requires_grad_(False)
+
     def load_adapter(
         self,
         peft_model_id: str | None = None,
@@ -508,7 +548,7 @@ def load_adapter(
         from peft import PeftType
         from peft.utils.save_and_load import _maybe_shard_state_dict_for_tp

-        from ..modeling_utils import LoadStateDictConfig, _get_resolved_checkpoint_files, load_state_dict
+        from ..modeling_utils import LoadStateDictConfig, _get_resolved_checkpoint_files

         if local_files_only:
             kwargs["local_files_only"] = True
@@ -628,35 +668,54 @@ def load_adapter(
                     break

         if has_tp_adapters:
-            all_pointer = set()
-            if adapter_state_dict is not None:
-                merged_state_dict = adapter_state_dict
-            elif (
-                checkpoint_files is not None
-                and checkpoint_files[0].endswith(".safetensors")
-                and adapter_state_dict is None
-            ):
-                merged_state_dict = {}
-                for file in checkpoint_files:
-                    file_pointer = safe_open(file, framework="pt", device="cpu")
-                    all_pointer.add(file_pointer)
-                    for k in file_pointer.keys():
-                        merged_state_dict[k] = file_pointer.get_tensor(k)
-            # Checkpoints are .bin
-            elif checkpoint_files is not None:
-                merged_state_dict = {}
-                for ckpt_file in checkpoint_files:
-                    merged_state_dict.update(load_state_dict(ckpt_file))
-            else:
-                raise ValueError("Neither a state dict nor checkpoint files were found.")
-
-            adapter_state_dict = merged_state_dict
+            adapter_state_dict = self._resolve_adapter_state_dict(adapter_state_dict, checkpoint_files)

             if any(not isinstance(v, torch.Tensor) for v in adapter_state_dict.values()):
                 raise ValueError("Expected all values in the adapter state dict to be tensors.")
             _maybe_shard_state_dict_for_tp(self, adapter_state_dict, adapter_name)

+        if hotswap:
+            # Bypass the standard loader and use PEFT's hotswap path so that LoRA weights
+            # whose rank differs from the existing adapter's are copied (and zero-padded)
+            # in place rather than triggering a "size mismatch" reinit, and so the LoRA
+            # scaling is updated alongside the weights.
+            from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
+
+            adapter_state_dict = self._resolve_adapter_state_dict(adapter_state_dict, checkpoint_files)
+
+            # need to apply conversions manually as we don't use _load_pretrained_model
+            renamings = [r for r in peft_weight_conversions if isinstance(r, WeightRenaming)]
+            converters = [c for c in peft_weight_conversions if isinstance(c, WeightConverter)]
+            meta_state_dict = self.state_dict()
+            processed_state_dict = {}
+            for key, value in adapter_state_dict.items():
+                renamed_key, _ = rename_source_key(key, renamings, converters, self.base_model_prefix, meta_state_dict)
+                processed_state_dict[renamed_key] = value
+
+            check_hotswap_configs_compatible(self.peft_config[adapter_name], peft_config)
+            try:
+                hotswap_adapter_from_state_dict(
+                    model=self,
+                    state_dict=processed_state_dict,
+                    adapter_name=adapter_name,
+                    config=peft_config,
+                )
+            except Exception as e:
+                logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error:\n{e}")
+                raise
+
+            if peft_config.inference_mode:
+                self._set_peft_inference_mode()
+
+            return LoadStateDictInfo(
+                missing_keys=set(),
+                unexpected_keys=set(),
+                mismatched_keys=set(),
+                error_msgs=[],
+                conversion_errors={},
+            )
+
         load_config = replace(
             load_config,
             pretrained_model_name_or_path=peft_model_id,
@@ -676,12 +735,7 @@ def load_adapter(
         )

         if peft_config.inference_mode:
-            from peft.tuners.tuners_utils import BaseTunerLayer
-
-            self.eval()
-            for module in self.modules():
-                if isinstance(module, BaseTunerLayer):
-                    module.requires_grad_(False)
+            self._set_peft_inference_mode()

         adapter_key_markers = {adapter_name}
         if peft_config is not None and getattr(peft_config, "peft_type", None) is not None:
@@ -699,6 +753,16 @@ def is_adapter_key(key: str) -> bool:
             loading_info=loading_info,
             logger=logger,
         )
+
+        if self._prepare_peft_hotswap_kwargs is not None:
+            # Apply once, after the first adapter has been loaded but before the model is
+            # compiled, so the LoRA layers get padded up to target_rank and a later adapter
+            # with a different rank can be hot-swapped in without recompiling.
+            from peft.utils.hotswap import prepare_model_for_compiled_hotswap
+
+            prepare_model_for_compiled_hotswap(self, config=peft_config, **self._prepare_peft_hotswap_kwargs)
+            self._prepare_peft_hotswap_kwargs = None
+
         return loading_info

     def enable_peft_hotswap(
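
For context, a rough usage sketch of the flow this diff wires up. The diff only shows the internals, so the details below are assumptions rather than confirmed API: the `target_rank` keyword on `enable_peft_hotswap`, the `hotswap=True` flag on `load_adapter`, and the model/adapter repo ids are hypothetical placeholders.

```python
# Hypothetical sketch, not confirmed API: assumes `enable_peft_hotswap` stashes its kwargs in
# `_prepare_peft_hotswap_kwargs` and that `load_adapter` exposes the `hotswap` flag used above.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("base-model-id")  # placeholder id

# Opt in before the first adapter is loaded; the stored kwargs are applied once, right after
# that first load, via `prepare_model_for_compiled_hotswap`, which pads the LoRA layers up to
# `target_rank` so adapters of a different rank can be swapped in later without recompiling.
model.enable_peft_hotswap(target_rank=16)
model.load_adapter("adapter-repo-r8", adapter_name="default")  # placeholder id

model.forward = torch.compile(model.forward, mode="reduce-overhead")

# Swaps the weights of the existing "default" adapter in place (PEFT's
# `hotswap_adapter_from_state_dict`) instead of registering a new adapter,
# so torch.compile does not need to retrace the graph.
model.load_adapter("adapter-repo-r16", adapter_name="default", hotswap=True)  # placeholder id
```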