From 11013225e5c18a7565e740222f19e20c683c46a9 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 28 Apr 2026 13:14:47 +0200 Subject: [PATCH 1/4] FIX Restore LoRA hotswapping functionality LoRA hotswapping was added in #41297. Due to changes in #43261, it stopped working. This PR restores the functionality. The tests already cover this and are failing, but probably no one noticed because they're slow tests. On main, they fail with mismatched sizes, which is expected as the padding of the LoRA weights is not being applied. With this PR, I can confirm that the tests pass locally. Since the two PRs were released in together in v5, there was never a Transformers release with working hotswapping functionality. Notes: The hotswap path does not use _load_pretrained_model, which means that loading the state_dict if not present is required. I hoisted that functionality from the TP path, which was already there, to re-use the same logic. I also apply weight renamings for that reason. Moreover, I moved the inference model logic to a local function, again to avoid duplicating the logic. --- src/transformers/integrations/peft.py | 108 +++++++++++++++++++------- 1 file changed, 81 insertions(+), 27 deletions(-) diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index 7b93e0a134b8..cad07bc2d3fc 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -34,6 +34,7 @@ Transpose, WeightConverter, WeightRenaming, + rename_source_key, ) from ..utils import ( CONFIG_NAME, @@ -47,7 +48,7 @@ logging, ) from ..utils.hub import DownloadKwargs -from ..utils.loading_report import log_state_dict_report +from ..utils.loading_report import LoadStateDictInfo, log_state_dict_report if is_torch_available(): @@ -506,6 +507,7 @@ def load_adapter( `find_adapter_config_file` method. """ from peft import PeftType + from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils.save_and_load import _maybe_shard_state_dict_for_tp from ..modeling_utils import LoadStateDictConfig, _get_resolved_checkpoint_files, load_state_dict @@ -618,45 +620,92 @@ def load_adapter( device_map = getattr(self, "hf_device_map", {"": self.device}) - # If the model is tensor parallel, we handle the sharding of the state dict here since the logic in `self._load_pretrained_model` - # is not compatible with the way PEFT adapter should be sharded. - has_tp_adapters = False - for module in self.modules(): - tp_info = getattr(module, "_tp_info", None) - if tp_info is not None: - has_tp_adapters = True - break - - if has_tp_adapters: + def _resolve_adapter_state_dict(): + # Materialize the adapter state dict from `adapter_state_dict` or `checkpoint_files`. Used by paths + # that bypass `self._load_pretrained_model` (which would otherwise read the files itself). all_pointer = set() if adapter_state_dict is not None: - merged_state_dict = adapter_state_dict - elif ( - checkpoint_files is not None - and checkpoint_files[0].endswith(".safetensors") - and adapter_state_dict is None - ): + return adapter_state_dict + if checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors"): merged_state_dict = {} for file in checkpoint_files: file_pointer = safe_open(file, framework="pt", device="cpu") all_pointer.add(file_pointer) for k in file_pointer.keys(): merged_state_dict[k] = file_pointer.get_tensor(k) + return merged_state_dict # Checkpoints are .bin - elif checkpoint_files is not None: + if checkpoint_files is not None: merged_state_dict = {} for ckpt_file in checkpoint_files: merged_state_dict.update(load_state_dict(ckpt_file)) - else: - raise ValueError("Neither a state dict nor checkpoint files were found.") + return merged_state_dict + raise ValueError("Neither a state dict nor checkpoint files were found.") - adapter_state_dict = merged_state_dict + def set_inference_mode(model): + model.eval() + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + module.requires_grad_(False) + + # If the model is tensor parallel, we handle the sharding of the state dict here since the logic in `self._load_pretrained_model` + # is not compatible with the way PEFT adapter should be sharded. + has_tp_adapters = False + for module in self.modules(): + tp_info = getattr(module, "_tp_info", None) + if tp_info is not None: + has_tp_adapters = True + break + + if has_tp_adapters: + adapter_state_dict = _resolve_adapter_state_dict() if any(not isinstance(v, torch.Tensor) for v in adapter_state_dict.values()): raise ValueError("Expected all values in the adapter state dict to be tensors.") _maybe_shard_state_dict_for_tp(self, adapter_state_dict, adapter_name) + if hotswap: + # Bypass the standard loader and use PEFT's hotswap path so that LoRA weights + # whose rank differs from the existing adapter's are copied (and zero-padded) + # in place rather than triggering a "size mismatch" reinit, and so the LoRA + # scaling is updated alongside the weights. + from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict + + adapter_state_dict = _resolve_adapter_state_dict() + + # need to apply conversions manually as we don't use _load_pretrained_model + renamings = [r for r in peft_weight_conversions if isinstance(r, WeightRenaming)] + converters = [c for c in peft_weight_conversions if isinstance(c, WeightConverter)] + meta_state_dict = self.state_dict() + processed_state_dict = {} + for key, value in adapter_state_dict.items(): + renamed_key, _ = rename_source_key(key, renamings, converters, self.base_model_prefix, meta_state_dict) + processed_state_dict[renamed_key] = value + + check_hotswap_configs_compatible(self.peft_config[adapter_name], peft_config) + try: + hotswap_adapter_from_state_dict( + model=self, + state_dict=processed_state_dict, + adapter_name=adapter_name, + config=peft_config, + ) + except Exception as e: + logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error:\n{e}") + raise + + if peft_config.inference_mode: + set_inference_mode(self) + + return LoadStateDictInfo( + missing_keys=set(), + unexpected_keys=set(), + mismatched_keys=set(), + error_msgs=[], + conversion_errors={}, + ) + load_config = replace( load_config, pretrained_model_name_or_path=peft_model_id, @@ -676,12 +725,7 @@ def load_adapter( ) if peft_config.inference_mode: - from peft.tuners.tuners_utils import BaseTunerLayer - - self.eval() - for module in self.modules(): - if isinstance(module, BaseTunerLayer): - module.requires_grad_(False) + set_inference_mode(self) adapter_key_markers = {adapter_name} if peft_config is not None and getattr(peft_config, "peft_type", None) is not None: @@ -699,6 +743,16 @@ def is_adapter_key(key: str) -> bool: loading_info=loading_info, logger=logger, ) + + if self._prepare_peft_hotswap_kwargs is not None: + # Apply once, after the first adapter has been loaded but before the model is + # compiled, so the LoRA layers get padded up to target_rank and a later adapter + # with a different rank can be hot-swapped in without recompiling. + from peft.utils.hotswap import prepare_model_for_compiled_hotswap + + prepare_model_for_compiled_hotswap(self, config=peft_config, **self._prepare_peft_hotswap_kwargs) + self._prepare_peft_hotswap_kwargs = None + return loading_info def enable_peft_hotswap( From 6a15499a4943b03c85e0a01a2ed5541915242257 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 11 May 2026 12:06:05 +0200 Subject: [PATCH 2/4] Reviewer feedback No local functions --- src/transformers/integrations/peft.py | 78 +++++++++++++++------------ 1 file changed, 44 insertions(+), 34 deletions(-) diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index cad07bc2d3fc..4536e00343d6 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -429,6 +429,45 @@ class PeftAdapterMixin: _prepare_peft_hotswap_kwargs: dict | None = None peft_config: dict[str, PeftConfigLike] + def _resolve_adapter_state_dict( + self, adapter_state_dict: dict[str, "torch.Tensor"] | None, checkpoint_files + ) -> dict[str, torch.Tensor]: + # Materialize the adapter state dict from `adapter_state_dict` or `checkpoint_files`. Used by paths + # that bypass `self._load_pretrained_model` (which would otherwise read the files itself). + from ..modeling_utils import load_state_dict + + all_pointer = set() + if adapter_state_dict is not None: + merged_state_dict = adapter_state_dict + elif ( + checkpoint_files is not None + and checkpoint_files[0].endswith(".safetensors") + and adapter_state_dict is None + ): + merged_state_dict = {} + for file in checkpoint_files: + file_pointer = safe_open(file, framework="pt", device="cpu") + all_pointer.add(file_pointer) + for k in file_pointer.keys(): + merged_state_dict[k] = file_pointer.get_tensor(k) + # Checkpoints are .bin + elif checkpoint_files is not None: + merged_state_dict = {} + for ckpt_file in checkpoint_files: + merged_state_dict.update(load_state_dict(ckpt_file)) + else: + raise ValueError("Neither a state dict nor checkpoint files were found.") + + return merged_state_dict + + def _set_peft_inference_mode(self) -> None: + from peft.tuners.tuners_utils import BaseTunerLayer + + self.eval() + for module in self.modules(): + if isinstance(module, BaseTunerLayer): + module.requires_grad_(False) + def load_adapter( self, peft_model_id: str | None = None, @@ -507,10 +546,9 @@ def load_adapter( `find_adapter_config_file` method. """ from peft import PeftType - from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils.save_and_load import _maybe_shard_state_dict_for_tp - from ..modeling_utils import LoadStateDictConfig, _get_resolved_checkpoint_files, load_state_dict + from ..modeling_utils import LoadStateDictConfig, _get_resolved_checkpoint_files if local_files_only: kwargs["local_files_only"] = True @@ -620,34 +658,6 @@ def load_adapter( device_map = getattr(self, "hf_device_map", {"": self.device}) - def _resolve_adapter_state_dict(): - # Materialize the adapter state dict from `adapter_state_dict` or `checkpoint_files`. Used by paths - # that bypass `self._load_pretrained_model` (which would otherwise read the files itself). - all_pointer = set() - if adapter_state_dict is not None: - return adapter_state_dict - if checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors"): - merged_state_dict = {} - for file in checkpoint_files: - file_pointer = safe_open(file, framework="pt", device="cpu") - all_pointer.add(file_pointer) - for k in file_pointer.keys(): - merged_state_dict[k] = file_pointer.get_tensor(k) - return merged_state_dict - # Checkpoints are .bin - if checkpoint_files is not None: - merged_state_dict = {} - for ckpt_file in checkpoint_files: - merged_state_dict.update(load_state_dict(ckpt_file)) - return merged_state_dict - raise ValueError("Neither a state dict nor checkpoint files were found.") - - def set_inference_mode(model): - model.eval() - for module in model.modules(): - if isinstance(module, BaseTunerLayer): - module.requires_grad_(False) - # If the model is tensor parallel, we handle the sharding of the state dict here since the logic in `self._load_pretrained_model` # is not compatible with the way PEFT adapter should be sharded. has_tp_adapters = False @@ -658,7 +668,7 @@ def set_inference_mode(model): break if has_tp_adapters: - adapter_state_dict = _resolve_adapter_state_dict() + adapter_state_dict = self._resolve_adapter_state_dict(adapter_state_dict, checkpoint_files) if any(not isinstance(v, torch.Tensor) for v in adapter_state_dict.values()): raise ValueError("Expected all values in the adapter state dict to be tensors.") @@ -672,7 +682,7 @@ def set_inference_mode(model): # scaling is updated alongside the weights. from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict - adapter_state_dict = _resolve_adapter_state_dict() + adapter_state_dict = self._resolve_adapter_state_dict(adapter_state_dict, checkpoint_files) # need to apply conversions manually as we don't use _load_pretrained_model renamings = [r for r in peft_weight_conversions if isinstance(r, WeightRenaming)] @@ -696,7 +706,7 @@ def set_inference_mode(model): raise if peft_config.inference_mode: - set_inference_mode(self) + self._set_peft_inference_mode() return LoadStateDictInfo( missing_keys=set(), @@ -725,7 +735,7 @@ def set_inference_mode(model): ) if peft_config.inference_mode: - set_inference_mode(self) + self._set_peft_inference_mode() adapter_key_markers = {adapter_name} if peft_config is not None and getattr(peft_config, "peft_type", None) is not None: From 87356d88242c1224f9131913a0d6810969052f41 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 18 May 2026 11:04:28 +0200 Subject: [PATCH 3/4] Address reviewer comments - remove incorrect comment - testing mixtral Also fixed a small bug regarding inference mode setting. For Mixtral, I had to adjust tolerances, but I visually inspected the logits and they're pretty much identical. --- src/transformers/integrations/peft.py | 8 ++-- .../peft_integration/test_peft_integration.py | 46 ++++++++++++++----- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index 4536e00343d6..da40126c5d07 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -658,8 +658,6 @@ def load_adapter( device_map = getattr(self, "hf_device_map", {"": self.device}) - # If the model is tensor parallel, we handle the sharding of the state dict here since the logic in `self._load_pretrained_model` - # is not compatible with the way PEFT adapter should be sharded. has_tp_adapters = False for module in self.modules(): tp_info = getattr(module, "_tp_info", None) @@ -734,9 +732,6 @@ def load_adapter( expected_keys=[n for n, _ in self.named_parameters()], ) - if peft_config.inference_mode: - self._set_peft_inference_mode() - adapter_key_markers = {adapter_name} if peft_config is not None and getattr(peft_config, "peft_type", None) is not None: adapter_key_markers.add(peft_config.peft_type.value.lower()) @@ -763,6 +758,9 @@ def is_adapter_key(key: str) -> bool: prepare_model_for_compiled_hotswap(self, config=peft_config, **self._prepare_peft_hotswap_kwargs) self._prepare_peft_hotswap_kwargs = None + if peft_config.inference_mode: + self._set_peft_inference_mode() + return loading_info def enable_peft_hotswap( diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index 33880c88135d..ffb65dae9417 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -1046,33 +1046,41 @@ def tearDown(self): torch.compiler.reset() gc.collect() - def _check_model_hotswap(self, *, rank1, rank2, do_compile): + def _check_model_hotswap( + self, *, rank1, rank2, do_compile, model_id="hf-internal-testing/tiny-random-OPTForCausalLM" + ): # utility method that checks that we can successfully hotswap adapters, with the model outputs corresponding to # the respective adapters from peft import LoraConfig torch.manual_seed(0) - model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" + model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device) input = torch.randint(0, 100, (1, 10)).to(torch_device) with torch.inference_mode(): base_output = model(input).logits # create 2 adapters - model.add_adapter(LoraConfig(r=rank1, init_lora_weights=False), adapter_name="adapter_1") + model.add_adapter( + LoraConfig(r=rank1, init_lora_weights=False, target_modules=["q_proj", "v_proj"]), adapter_name="adapter_1" + ) with torch.inference_mode(): lora_1_output = model(input).logits # second adapter may have a different rank - model.add_adapter(LoraConfig(r=rank2, init_lora_weights=False), adapter_name="adapter_2") + model.add_adapter( + LoraConfig(r=rank2, init_lora_weights=False, target_modules=["q_proj", "v_proj"]), adapter_name="adapter_2" + ) model.set_adapter("adapter_2") with torch.inference_mode(): lora_2_output = model(input).logits # sanity checks - self.assertFalse(torch.allclose(base_output, lora_1_output, atol=1e-6, rtol=1e-6)) - self.assertFalse(torch.allclose(base_output, lora_2_output, atol=1e-6, rtol=1e-6)) - self.assertFalse(torch.allclose(lora_1_output, lora_2_output, atol=1e-6, rtol=1e-6)) + atol = 2e-3 + rtol = 1e-6 + self.assertFalse(torch.allclose(base_output, lora_1_output, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(base_output, lora_2_output, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(lora_1_output, lora_2_output, atol=atol, rtol=rtol)) with tempfile.TemporaryDirectory() as tmpdirname: path_1 = os.path.join(tmpdirname, "adapter_1") @@ -1090,26 +1098,26 @@ def _check_model_hotswap(self, *, rank1, rank2, do_compile): model.enable_peft_hotswap(target_rank=max(rank1, rank2)) # load the first adapter without hotswap (hotswap requires an existing adapter) - model.load_adapter(path_1, adapter_name="adapter_1") + model.load_adapter(path_1, adapter_name="adapter_1", is_trainable=False) if do_compile: # compile the model after loading the first adapter model = torch.compile(model, mode="reduce-overhead") with torch.inference_mode(): lora_1_output_loaded = model(input).logits - self.assertTrue(torch.allclose(lora_1_output, lora_1_output_loaded, atol=1e-6, rtol=1e-6)) + self.assertTrue(torch.allclose(lora_1_output, lora_1_output_loaded, atol=atol, rtol=rtol)) # hotswap in adapter_2 again, output should be same as lora_2_output if enable_hotswap: # after calling enable_peft_hotswap, hotswap will automatically be enabled - model.load_adapter(path_2, adapter_name="adapter_1") + model.load_adapter(path_2, adapter_name="adapter_1", is_trainable=False) else: # enable_peft_hotswap was not called, need to explicitly pass hotswap=True - model.load_adapter(path_2, adapter_name="adapter_1", hotswap=True) + model.load_adapter(path_2, adapter_name="adapter_1", hotswap=True, is_trainable=False) with torch.inference_mode(): lora_2_output_loaded = model(input).logits - self.assertTrue(torch.allclose(lora_2_output, lora_2_output_loaded, atol=1e-6, rtol=1e-6)) + self.assertTrue(torch.allclose(lora_2_output, lora_2_output_loaded, atol=atol, rtol=rtol)) def test_hotswap_wrong_peft_type_raises(self): # only LoRA is supported for now @@ -1260,3 +1268,17 @@ def test_maybe_load_adapters_path_not_overwritten_for_complete_model(self): # Load from the saved path and make sure it actually loads despite # the invalid adapter config path AutoModel.from_pretrained(tmp_dir) + + def test_mixtral_hotswap_without_compile_works(self): + # test a model that usees weight conversion + model_id = "hf-internal-testing/Mixtral-tiny" + self._check_model_hotswap(rank1=7, rank2=13, do_compile=False, model_id=model_id) + + def test_mixtral_hotswap_with_compile_works(self): + # test a model that usees weight conversion + model_id = "hf-internal-testing/Mixtral-tiny" + with ( + torch._dynamo.config.patch(error_on_recompile=True), + torch._inductor.utils.fresh_inductor_cache(), + ): + self._check_model_hotswap(rank1=8, rank2=8, do_compile=True, model_id=model_id) From a91877f89214157e06c6d582cae858f56c86ab1a Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 3 Jun 2026 13:06:31 +0200 Subject: [PATCH 4/4] Address reviewer feedback: - properly re-implement weight conversion logic - fix mixtral test --- src/transformers/integrations/peft.py | 27 ++++++++++++++++--- .../peft_integration/test_peft_integration.py | 8 +++++- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index fccb66e13376..0105d1c34697 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -34,6 +34,7 @@ Transpose, WeightConverter, WeightRenaming, + dot_natural_key, rename_source_key, ) from ..utils import ( @@ -689,14 +690,32 @@ def is_adapter_key(key: str) -> bool: adapter_state_dict = self._resolve_adapter_state_dict(adapter_state_dict, checkpoint_files) - # need to apply conversions manually as we don't use _load_pretrained_model + # Need to apply conversions manually as we don't use _load_pretrained_model. Same logic as in: + # https://github.com/huggingface/transformers/blob/a8f150d35d5863971db1e5c1dbc2a1c265f27f96/src/transformers/core_model_loading.py#L1222 renamings = [r for r in peft_weight_conversions if isinstance(r, WeightRenaming)] converters = [c for c in peft_weight_conversions if isinstance(c, WeightConverter)] + pattern_to_converter = {p: c for c in converters for p in c.source_patterns} meta_state_dict = self.state_dict() + conversion_mapping: dict[str, WeightConverter] = {} processed_state_dict = {} - for key, value in adapter_state_dict.items(): - renamed_key, _ = rename_source_key(key, renamings, converters, self.base_model_prefix, meta_state_dict) - processed_state_dict[renamed_key] = value + # Sort by `dot_natural_key` so converters such as MergeModulelist collect experts in numeric order. + for key, value in sorted(adapter_state_dict.items(), key=lambda kv: dot_natural_key(kv[0])): + renamed_key, source_pattern = rename_source_key( + key, renamings, converters, self.base_model_prefix, meta_state_dict + ) + if source_pattern is not None: + # A WeightConverter matched: bucket the tensor so its operations can run over all siblings. + mapping = conversion_mapping.setdefault( + renamed_key, copy.deepcopy(pattern_to_converter[source_pattern]) + ) + mapping.add_tensor(renamed_key, key, source_pattern, value) + else: + processed_state_dict[renamed_key] = value + + for layer_name, mapping in conversion_mapping.items(): + realized = mapping.convert(layer_name, model=self, config=self.config) + for target_name, param in realized.items(): + processed_state_dict[target_name] = param[0] if isinstance(param, list) else param check_hotswap_configs_compatible(self.peft_config[adapter_name], peft_config) try: diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index f682f8276f5d..f0a325817aa5 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -1154,7 +1154,13 @@ def _check_model_hotswap( model.load_adapter(path_1, adapter_name="adapter_1", is_trainable=False) if do_compile: # compile the model after loading the first adapter - model = torch.compile(model, mode="reduce-overhead") + if "mixtral" not in model_id.lower(): + model = torch.compile(model, mode="reduce-overhead") + else: + # The tiny mixtral model is incompatible with 'reduce-overhead', resulting in: + # > torch.AcceleratorError: CUDA error: operation failed due to a previous error during capture + # For the purpose of this test, 'reduce-overhead' is not material, so we drop it here. + model = torch.compile(model) with torch.inference_mode(): lora_1_output_loaded = model(input).logits