Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 123 additions & 32 deletions src/transformers/integrations/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
Transpose,
WeightConverter,
WeightRenaming,
dot_natural_key,
rename_source_key,
)
from ..utils import (
CONFIG_NAME,
Expand All @@ -47,7 +49,7 @@
logging,
)
from ..utils.hub import DownloadKwargs
from ..utils.loading_report import log_state_dict_report
from ..utils.loading_report import LoadStateDictInfo, log_state_dict_report


if is_torch_available():
Expand Down Expand Up @@ -428,6 +430,45 @@ class PeftAdapterMixin:
_prepare_peft_hotswap_kwargs: dict | None = None
peft_config: dict[str, PeftConfigLike]

def _resolve_adapter_state_dict(

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This is not new code, it's just moved to a separate method to avoid duplication. Original code here:

all_pointer = set()
if adapter_state_dict is not None:
merged_state_dict = adapter_state_dict
elif (
checkpoint_files is not None
and checkpoint_files[0].endswith(".safetensors")
and adapter_state_dict is None
):
merged_state_dict = {}
for file in checkpoint_files:
file_pointer = safe_open(file, framework="pt", device="cpu")
all_pointer.add(file_pointer)
for k in file_pointer.keys():
merged_state_dict[k] = file_pointer.get_tensor(k)
# Checkpoints are .bin
elif checkpoint_files is not None:
merged_state_dict = {}
for ckpt_file in checkpoint_files:
merged_state_dict.update(load_state_dict(ckpt_file))
else:
raise ValueError("Neither a state dict nor checkpoint files were found.")
adapter_state_dict = merged_state_dict
if any(not isinstance(v, torch.Tensor) for v in adapter_state_dict.values()):
raise ValueError("Expected all values in the adapter state dict to be tensors.")

self, adapter_state_dict: dict[str, "torch.Tensor"] | None, checkpoint_files
) -> dict[str, torch.Tensor]:
# Materialize the adapter state dict from `adapter_state_dict` or `checkpoint_files`. Used by paths
# that bypass `self._load_pretrained_model` (which would otherwise read the files itself).
from ..modeling_utils import load_state_dict

all_pointer = set()
if adapter_state_dict is not None:
merged_state_dict = adapter_state_dict
elif (
checkpoint_files is not None
and checkpoint_files[0].endswith(".safetensors")
and adapter_state_dict is None
):
merged_state_dict = {}
for file in checkpoint_files:
file_pointer = safe_open(file, framework="pt", device="cpu")
all_pointer.add(file_pointer)
for k in file_pointer.keys():
merged_state_dict[k] = file_pointer.get_tensor(k)
# Checkpoints are .bin
elif checkpoint_files is not None:
merged_state_dict = {}
for ckpt_file in checkpoint_files:
merged_state_dict.update(load_state_dict(ckpt_file))
else:
raise ValueError("Neither a state dict nor checkpoint files were found.")

return merged_state_dict

def _set_peft_inference_mode(self) -> None:
from peft.tuners.tuners_utils import BaseTunerLayer

self.eval()
for module in self.modules():
if isinstance(module, BaseTunerLayer):
module.requires_grad_(False)

def load_adapter(
self,
peft_model_id: str | None = None,
Expand Down Expand Up @@ -508,7 +549,7 @@ def load_adapter(
from peft import PeftType
from peft.utils.save_and_load import _maybe_shard_state_dict_for_tp

from ..modeling_utils import LoadStateDictConfig, _get_resolved_checkpoint_files, load_state_dict
from ..modeling_utils import LoadStateDictConfig, _get_resolved_checkpoint_files

if local_files_only:
kwargs["local_files_only"] = True
Expand Down Expand Up @@ -625,8 +666,6 @@ def is_adapter_key(key: str) -> bool:

device_map = getattr(self, "hf_device_map", {"": self.device})

# If the model is tensor parallel, we handle the sharding of the state dict here since the logic in `self._load_pretrained_model`
# is not compatible with the way PEFT adapter should be sharded.
has_tp_adapters = False
for module in self.modules():
tp_info = getattr(module, "_tp_info", None)
Expand All @@ -635,35 +674,72 @@ def is_adapter_key(key: str) -> bool:
break

if has_tp_adapters:
all_pointer = set()
if adapter_state_dict is not None:
merged_state_dict = adapter_state_dict
elif (
checkpoint_files is not None
and checkpoint_files[0].endswith(".safetensors")
and adapter_state_dict is None
):
merged_state_dict = {}
for file in checkpoint_files:
file_pointer = safe_open(file, framework="pt", device="cpu")
all_pointer.add(file_pointer)
for k in file_pointer.keys():
merged_state_dict[k] = file_pointer.get_tensor(k)
# Checkpoints are .bin
elif checkpoint_files is not None:
merged_state_dict = {}
for ckpt_file in checkpoint_files:
merged_state_dict.update(load_state_dict(ckpt_file))
else:
raise ValueError("Neither a state dict nor checkpoint files were found.")

adapter_state_dict = merged_state_dict
adapter_state_dict = self._resolve_adapter_state_dict(adapter_state_dict, checkpoint_files)

if any(not isinstance(v, torch.Tensor) for v in adapter_state_dict.values()):
raise ValueError("Expected all values in the adapter state dict to be tensors.")

_maybe_shard_state_dict_for_tp(self, adapter_state_dict, adapter_name)

if hotswap:
# Bypass the standard loader and use PEFT's hotswap path so that LoRA weights
# whose rank differs from the existing adapter's are copied (and zero-padded)
# in place rather than triggering a "size mismatch" reinit, and so the LoRA
# scaling is updated alongside the weights.
from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict

adapter_state_dict = self._resolve_adapter_state_dict(adapter_state_dict, checkpoint_files)

# Need to apply conversions manually as we don't use _load_pretrained_model. Same logic as in:
# https://github.com/huggingface/transformers/blob/a8f150d35d5863971db1e5c1dbc2a1c265f27f96/src/transformers/core_model_loading.py#L1222
renamings = [r for r in peft_weight_conversions if isinstance(r, WeightRenaming)]
converters = [c for c in peft_weight_conversions if isinstance(c, WeightConverter)]
pattern_to_converter = {p: c for c in converters for p in c.source_patterns}
meta_state_dict = self.state_dict()
conversion_mapping: dict[str, WeightConverter] = {}
processed_state_dict = {}
# Sort by `dot_natural_key` so converters such as MergeModulelist collect experts in numeric order.
for key, value in sorted(adapter_state_dict.items(), key=lambda kv: dot_natural_key(kv[0])):
renamed_key, source_pattern = rename_source_key(
key, renamings, converters, self.base_model_prefix, meta_state_dict
)
if source_pattern is not None:
# A WeightConverter matched: bucket the tensor so its operations can run over all siblings.
mapping = conversion_mapping.setdefault(
renamed_key, copy.deepcopy(pattern_to_converter[source_pattern])
)
mapping.add_tensor(renamed_key, key, source_pattern, value)
else:
processed_state_dict[renamed_key] = value

for layer_name, mapping in conversion_mapping.items():
realized = mapping.convert(layer_name, model=self, config=self.config)
for target_name, param in realized.items():
processed_state_dict[target_name] = param[0] if isinstance(param, list) else param

check_hotswap_configs_compatible(self.peft_config[adapter_name], peft_config)
Comment thread
ArthurZucker marked this conversation as resolved.
try:
hotswap_adapter_from_state_dict(
model=self,
state_dict=processed_state_dict,
adapter_name=adapter_name,
config=peft_config,
)
except Exception as e:
logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error:\n{e}")
raise

if peft_config.inference_mode:
self._set_peft_inference_mode()

return LoadStateDictInfo(
missing_keys=set(),
unexpected_keys=set(),
mismatched_keys=set(),
error_msgs=[],
conversion_errors={},
)

load_config = replace(
load_config,
pretrained_model_name_or_path=peft_model_id,
Expand All @@ -683,12 +759,14 @@ def is_adapter_key(key: str) -> bool:
)

if peft_config.inference_mode:
from peft.tuners.tuners_utils import BaseTunerLayer
self._set_peft_inference_mode()

self.eval()
for module in self.modules():
if isinstance(module, BaseTunerLayer):
module.requires_grad_(False)
adapter_key_markers = {adapter_name}
if peft_config is not None and getattr(peft_config, "peft_type", None) is not None:
adapter_key_markers.add(peft_config.peft_type.value.lower())

def is_adapter_key(key: str) -> bool:
return any(marker in key for marker in adapter_key_markers)

loading_info.missing_keys = {k for k in loading_info.missing_keys if is_adapter_key(k)}

Expand All @@ -699,6 +777,19 @@ def is_adapter_key(key: str) -> bool:
loading_info=loading_info,
logger=logger,
)

if self._prepare_peft_hotswap_kwargs is not None:
# Apply once, after the first adapter has been loaded but before the model is
# compiled, so the LoRA layers get padded up to target_rank and a later adapter
# with a different rank can be hot-swapped in without recompiling.
from peft.utils.hotswap import prepare_model_for_compiled_hotswap

prepare_model_for_compiled_hotswap(self, config=peft_config, **self._prepare_peft_hotswap_kwargs)
self._prepare_peft_hotswap_kwargs = None

if peft_config.inference_mode:
self._set_peft_inference_mode()

return loading_info

def enable_peft_hotswap(
Expand Down
54 changes: 41 additions & 13 deletions tests/peft_integration/test_peft_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,33 +1099,41 @@ def tearDown(self):
torch.compiler.reset()
gc.collect()

def _check_model_hotswap(self, *, rank1, rank2, do_compile):
def _check_model_hotswap(
self, *, rank1, rank2, do_compile, model_id="hf-internal-testing/tiny-random-OPTForCausalLM"
):
# utility method that checks that we can successfully hotswap adapters, with the model outputs corresponding to
# the respective adapters
from peft import LoraConfig

torch.manual_seed(0)
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"

model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
input = torch.randint(0, 100, (1, 10)).to(torch_device)
with torch.inference_mode():
base_output = model(input).logits

# create 2 adapters
model.add_adapter(LoraConfig(r=rank1, init_lora_weights=False), adapter_name="adapter_1")
model.add_adapter(
LoraConfig(r=rank1, init_lora_weights=False, target_modules=["q_proj", "v_proj"]), adapter_name="adapter_1"
)
with torch.inference_mode():
lora_1_output = model(input).logits

# second adapter may have a different rank
model.add_adapter(LoraConfig(r=rank2, init_lora_weights=False), adapter_name="adapter_2")
model.add_adapter(
LoraConfig(r=rank2, init_lora_weights=False, target_modules=["q_proj", "v_proj"]), adapter_name="adapter_2"
)
model.set_adapter("adapter_2")
with torch.inference_mode():
lora_2_output = model(input).logits

# sanity checks
self.assertFalse(torch.allclose(base_output, lora_1_output, atol=1e-6, rtol=1e-6))
self.assertFalse(torch.allclose(base_output, lora_2_output, atol=1e-6, rtol=1e-6))
self.assertFalse(torch.allclose(lora_1_output, lora_2_output, atol=1e-6, rtol=1e-6))
atol = 2e-3
rtol = 1e-6
self.assertFalse(torch.allclose(base_output, lora_1_output, atol=atol, rtol=rtol))
self.assertFalse(torch.allclose(base_output, lora_2_output, atol=atol, rtol=rtol))
self.assertFalse(torch.allclose(lora_1_output, lora_2_output, atol=atol, rtol=rtol))

with tempfile.TemporaryDirectory() as tmpdirname:
path_1 = os.path.join(tmpdirname, "adapter_1")
Expand All @@ -1143,26 +1151,32 @@ def _check_model_hotswap(self, *, rank1, rank2, do_compile):
model.enable_peft_hotswap(target_rank=max(rank1, rank2))

# load the first adapter without hotswap (hotswap requires an existing adapter)
model.load_adapter(path_1, adapter_name="adapter_1")
model.load_adapter(path_1, adapter_name="adapter_1", is_trainable=False)
if do_compile:
# compile the model after loading the first adapter
model = torch.compile(model, mode="reduce-overhead")
if "mixtral" not in model_id.lower():
model = torch.compile(model, mode="reduce-overhead")
else:
# The tiny mixtral model is incompatible with 'reduce-overhead', resulting in:
# > torch.AcceleratorError: CUDA error: operation failed due to a previous error during capture
# For the purpose of this test, 'reduce-overhead' is not material, so we drop it here.
model = torch.compile(model)

with torch.inference_mode():
lora_1_output_loaded = model(input).logits
self.assertTrue(torch.allclose(lora_1_output, lora_1_output_loaded, atol=1e-6, rtol=1e-6))
self.assertTrue(torch.allclose(lora_1_output, lora_1_output_loaded, atol=atol, rtol=rtol))

# hotswap in adapter_2 again, output should be same as lora_2_output
if enable_hotswap:
# after calling enable_peft_hotswap, hotswap will automatically be enabled
model.load_adapter(path_2, adapter_name="adapter_1")
model.load_adapter(path_2, adapter_name="adapter_1", is_trainable=False)
else:
# enable_peft_hotswap was not called, need to explicitly pass hotswap=True
model.load_adapter(path_2, adapter_name="adapter_1", hotswap=True)
model.load_adapter(path_2, adapter_name="adapter_1", hotswap=True, is_trainable=False)

with torch.inference_mode():
lora_2_output_loaded = model(input).logits
self.assertTrue(torch.allclose(lora_2_output, lora_2_output_loaded, atol=1e-6, rtol=1e-6))
self.assertTrue(torch.allclose(lora_2_output, lora_2_output_loaded, atol=atol, rtol=rtol))

def test_hotswap_wrong_peft_type_raises(self):
# only LoRA is supported for now
Expand Down Expand Up @@ -1313,3 +1327,17 @@ def test_maybe_load_adapters_path_not_overwritten_for_complete_model(self):
# Load from the saved path and make sure it actually loads despite
# the invalid adapter config path
AutoModel.from_pretrained(tmp_dir)

def test_mixtral_hotswap_without_compile_works(self):
# test a model that usees weight conversion
model_id = "hf-internal-testing/Mixtral-tiny"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are the experts properly targeted ? (the only that actually do convert?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the experts are not covered by this test. As mentioned, it wouldn't work for lack of support in PEFT itself. I have an item on my TODO list to see if expert (or rather, nn.Parameter) targeting can be supported with hotswapping. When/if that lands, I can update the Transformers test to check for that.

self._check_model_hotswap(rank1=7, rank2=13, do_compile=False, model_id=model_id)

def test_mixtral_hotswap_with_compile_works(self):
# test a model that usees weight conversion
model_id = "hf-internal-testing/Mixtral-tiny"
with (
torch._dynamo.config.patch(error_on_recompile=True),
torch._inductor.utils.fresh_inductor_cache(),
):
self._check_model_hotswap(rank1=8, rank2=8, do_compile=True, model_id=model_id)
Loading