Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 95 additions & 31 deletions src/transformers/integrations/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Transpose,
WeightConverter,
WeightRenaming,
rename_source_key,
)
from ..utils import (
CONFIG_NAME,
Expand All @@ -47,7 +48,7 @@
logging,
)
from ..utils.hub import DownloadKwargs
from ..utils.loading_report import log_state_dict_report
from ..utils.loading_report import LoadStateDictInfo, log_state_dict_report


if is_torch_available():
Expand Down Expand Up @@ -428,6 +429,45 @@ class PeftAdapterMixin:
_prepare_peft_hotswap_kwargs: dict | None = None
peft_config: dict[str, PeftConfigLike]

def _resolve_adapter_state_dict(
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this is not new code; it has simply been moved into a separate method to avoid duplication. The original code was:

all_pointer = set()
if adapter_state_dict is not None:
merged_state_dict = adapter_state_dict
elif (
checkpoint_files is not None
and checkpoint_files[0].endswith(".safetensors")
and adapter_state_dict is None
):
merged_state_dict = {}
for file in checkpoint_files:
file_pointer = safe_open(file, framework="pt", device="cpu")
all_pointer.add(file_pointer)
for k in file_pointer.keys():
merged_state_dict[k] = file_pointer.get_tensor(k)
# Checkpoints are .bin
elif checkpoint_files is not None:
merged_state_dict = {}
for ckpt_file in checkpoint_files:
merged_state_dict.update(load_state_dict(ckpt_file))
else:
raise ValueError("Neither a state dict nor checkpoint files were found.")
adapter_state_dict = merged_state_dict
if any(not isinstance(v, torch.Tensor) for v in adapter_state_dict.values()):
raise ValueError("Expected all values in the adapter state dict to be tensors.")

self, adapter_state_dict: dict[str, "torch.Tensor"] | None, checkpoint_files
) -> dict[str, torch.Tensor]:
# Materialize the adapter state dict from `adapter_state_dict` or `checkpoint_files`. Used by paths
# that bypass `self._load_pretrained_model` (which would otherwise read the files itself).
from ..modeling_utils import load_state_dict

all_pointer = set()
if adapter_state_dict is not None:
merged_state_dict = adapter_state_dict
elif (
checkpoint_files is not None
and checkpoint_files[0].endswith(".safetensors")
and adapter_state_dict is None
):
merged_state_dict = {}
for file in checkpoint_files:
file_pointer = safe_open(file, framework="pt", device="cpu")
all_pointer.add(file_pointer)
for k in file_pointer.keys():
merged_state_dict[k] = file_pointer.get_tensor(k)
# Checkpoints are .bin
elif checkpoint_files is not None:
merged_state_dict = {}
for ckpt_file in checkpoint_files:
merged_state_dict.update(load_state_dict(ckpt_file))
else:
raise ValueError("Neither a state dict nor checkpoint files were found.")

return merged_state_dict

def _set_peft_inference_mode(self) -> None:
    """Put the model in eval mode and disable gradients on every PEFT tuner layer."""
    from peft.tuners.tuners_utils import BaseTunerLayer

    self.eval()
    tuner_layers = (m for m in self.modules() if isinstance(m, BaseTunerLayer))
    for layer in tuner_layers:
        layer.requires_grad_(False)

def load_adapter(
self,
peft_model_id: str | None = None,
Expand Down Expand Up @@ -508,7 +548,7 @@ def load_adapter(
from peft import PeftType
from peft.utils.save_and_load import _maybe_shard_state_dict_for_tp

from ..modeling_utils import LoadStateDictConfig, _get_resolved_checkpoint_files, load_state_dict
from ..modeling_utils import LoadStateDictConfig, _get_resolved_checkpoint_files

if local_files_only:
kwargs["local_files_only"] = True
Expand Down Expand Up @@ -628,35 +668,54 @@ def load_adapter(
break

if has_tp_adapters:
all_pointer = set()
if adapter_state_dict is not None:
merged_state_dict = adapter_state_dict
elif (
checkpoint_files is not None
and checkpoint_files[0].endswith(".safetensors")
and adapter_state_dict is None
):
merged_state_dict = {}
for file in checkpoint_files:
file_pointer = safe_open(file, framework="pt", device="cpu")
all_pointer.add(file_pointer)
for k in file_pointer.keys():
merged_state_dict[k] = file_pointer.get_tensor(k)
# Checkpoints are .bin
elif checkpoint_files is not None:
merged_state_dict = {}
for ckpt_file in checkpoint_files:
merged_state_dict.update(load_state_dict(ckpt_file))
else:
raise ValueError("Neither a state dict nor checkpoint files were found.")

adapter_state_dict = merged_state_dict
adapter_state_dict = self._resolve_adapter_state_dict(adapter_state_dict, checkpoint_files)

if any(not isinstance(v, torch.Tensor) for v in adapter_state_dict.values()):
raise ValueError("Expected all values in the adapter state dict to be tensors.")

_maybe_shard_state_dict_for_tp(self, adapter_state_dict, adapter_name)

if hotswap:
# Bypass the standard loader and use PEFT's hotswap path so that LoRA weights
# whose rank differs from the existing adapter's are copied (and zero-padded)
# in place rather than triggering a "size mismatch" reinit, and so the LoRA
# scaling is updated alongside the weights.
from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict

adapter_state_dict = self._resolve_adapter_state_dict(adapter_state_dict, checkpoint_files)

# need to apply conversions manually as we don't use _load_pretrained_model
renamings = [r for r in peft_weight_conversions if isinstance(r, WeightRenaming)]
converters = [c for c in peft_weight_conversions if isinstance(c, WeightConverter)]
meta_state_dict = self.state_dict()
processed_state_dict = {}
for key, value in adapter_state_dict.items():
renamed_key, _ = rename_source_key(key, renamings, converters, self.base_model_prefix, meta_state_dict)
processed_state_dict[renamed_key] = value

check_hotswap_configs_compatible(self.peft_config[adapter_name], peft_config)
try:
hotswap_adapter_from_state_dict(
model=self,
state_dict=processed_state_dict,
adapter_name=adapter_name,
config=peft_config,
)
except Exception as e:
logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error:\n{e}")
raise

if peft_config.inference_mode:
self._set_peft_inference_mode()

return LoadStateDictInfo(
missing_keys=set(),
unexpected_keys=set(),
mismatched_keys=set(),
error_msgs=[],
conversion_errors={},
)

load_config = replace(
load_config,
pretrained_model_name_or_path=peft_model_id,
Expand All @@ -676,12 +735,7 @@ def load_adapter(
)

if peft_config.inference_mode:
from peft.tuners.tuners_utils import BaseTunerLayer

self.eval()
for module in self.modules():
if isinstance(module, BaseTunerLayer):
module.requires_grad_(False)
self._set_peft_inference_mode()

adapter_key_markers = {adapter_name}
if peft_config is not None and getattr(peft_config, "peft_type", None) is not None:
Expand All @@ -699,6 +753,16 @@ def is_adapter_key(key: str) -> bool:
loading_info=loading_info,
logger=logger,
)

if self._prepare_peft_hotswap_kwargs is not None:
# Apply once, after the first adapter has been loaded but before the model is
# compiled, so the LoRA layers get padded up to target_rank and a later adapter
# with a different rank can be hot-swapped in without recompiling.
from peft.utils.hotswap import prepare_model_for_compiled_hotswap

prepare_model_for_compiled_hotswap(self, config=peft_config, **self._prepare_peft_hotswap_kwargs)
self._prepare_peft_hotswap_kwargs = None

return loading_info

def enable_peft_hotswap(
Expand Down
Loading