diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 5ad58646fc..31e09523a2 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -585,7 +585,19 @@ def add_weighted_adapter( adapters: list[str], weights: list[float], adapter_name: str, - combination_type: str = "svd", + combination_type: Literal[ + "svd", + "linear", + "cat", + "ties", + "ties_svd", + "dare_ties", + "dare_linear", + "dare_ties_svd", + "dare_linear_svd", + "magnitude_prune", + "magnitude_prune_svd", + ] = "svd", svd_rank: int | None = None, svd_clamp: int | None = None, svd_full_matrices: bool = True, @@ -612,7 +624,9 @@ def add_weighted_adapter( The merging type can be one of [`svd`, `linear`, `cat`, `ties`, `ties_svd`, `dare_ties`, `dare_linear`, `dare_ties_svd`, `dare_linear_svd`, `magnitude_prune`, `magnitude_prune_svd`]. When using the `cat` combination_type, the rank of the resulting adapter is equal to the sum of all adapters ranks (the - mixed adapter may be too big and result in OOM errors). + mixed adapter may be too big and result in OOM errors). Note that `cat` and `svd` are precise methods + and will give you good accuracy, `linear` is efficient but a very rough approximation and should be + avoided if you can afford it. svd_rank (`int`, *optional*): Rank of output adapter for svd. If None provided, will use max rank of merging adapters. svd_clamp (`float`, *optional*): @@ -738,11 +752,12 @@ def _svd_generalized_task_arithmetic_weighted_adapter( for adapter, weight in zip(adapters, weights): if adapter in target.lora_A or adapter in target.lora_embedding_A: valid_adapters.append(adapter) - valid_weights.append(weight * target.scaling[adapter]) + valid_weights.append(weight) # if no valid adapter, nothing to do if len(valid_adapters) == 0: - raise ValueError("No matching LoRAs found. Please raise an issue on Github.") + raise ValueError("No matching LoRAs found. Please raise an issue on GitHub.") + # get_delta_weight applies the scaling, no need to handle it explicitly delta_weight = [target.get_delta_weight(adapter) for adapter in valid_adapters] valid_weights = torch.tensor(valid_weights).to(delta_weight[0].device) if combination_type == "svd": diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index c31670af61..4565c5850d 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -3326,9 +3326,29 @@ def test_add_weighted_adapter_subtraction_with_negative_weights(self): dw_cancelled = module.get_delta_weight("cancelled") assert torch.allclose(dw_cancelled, torch.zeros_like(dw_cancelled)) - def test_add_weighted_adapter_negative_weight_with_different_scaling(self): - # Test negative weights with different scaling factors (lora_alpha) - # This edge case ensures negative weights work correctly with different scaling values + @pytest.mark.parametrize("weights", [[1.0, 1.0], [0.0, 1.0], [5.0, 0.01], [-1.0, -1.0], [0.5, -0.3]]) + @pytest.mark.parametrize( + "combination_type, min_corr, max_mse", + [ + # note: SVD and cat are 'precise', the others are approximation + ("svd", 0.99, 0.01), + ("cat", 0.99, 0.01), + ("linear", 0.6, 1.0), + ("ties", 0.4, 1.0), + ("ties_svd", 0.8, 1.0), + ("dare_ties", 0.1, 1.0), + ("dare_ties_svd", 0.55, 1.0), + ("dare_linear", 0.2, 1.0), + ("dare_linear_svd", 0.6, 1.0), + ("magnitude_prune", 0.55, 1.0), + ("magnitude_prune_svd", 0.9, 0.1), + ], + ) + def test_add_weighted_adapter_with_different_scaling(self, weights, combination_type, min_corr, max_mse): + # Check that the actually merged weights correspond to what their theoretical value should be. Note that each + # method is an approximation so we can never expect exact equality. We thus test for correlation and MSE as a + # proxy. The acceptance criteria are empirically determined and thus serve more as a regression test than + # actually proving that the merging method works. torch.manual_seed(42) model = MLP() @@ -3337,36 +3357,43 @@ def test_add_weighted_adapter_negative_weight_with_different_scaling(self): r=8, lora_alpha=16, # scaling = 16/8 = 2 target_modules=["lin0"], - lora_dropout=0.0, - bias="none", init_lora_weights=False, ) config2 = LoraConfig( r=8, lora_alpha=32, # scaling = 32/8 = 4 target_modules=["lin0"], - lora_dropout=0.0, - bias="none", init_lora_weights=False, ) model = get_peft_model(model, config1, adapter_name="adapter1") model.add_adapter("adapter2", config2) - - # Merge with negative weight - should handle different scalings correctly model.add_weighted_adapter( adapters=["adapter1", "adapter2"], - weights=[0.5, -0.3], - adapter_name="merged_diff_scaling", - combination_type="linear", + weights=weights, + adapter_name="merged", + combination_type=combination_type, + density=0.5, ) - - # Verify the merged adapter can run forward pass - model.set_adapter("merged_diff_scaling") + model.set_adapter("merged") dummy_input = torch.randn(2, 10) output = model(dummy_input) assert output is not None + # We cannot expect the merged weights to be approximately equal because we're dealing with rough approximations. + # Therefore, we check for correlation to verify that the direction is right and MSE to verify that the magnitude + # is right. + for module in model.modules(): + if isinstance(module, lora.LoraLayer): + dw1 = module.get_delta_weight("adapter1") + dw2 = module.get_delta_weight("adapter2") + dw_merged = module.get_delta_weight("merged") + expected = weights[0] * dw1 + weights[1] * dw2 + corr = torch.corrcoef(torch.stack((dw_merged.flatten(), expected.flatten()))) + mse = ((dw_merged - expected) ** 2).mean() + assert corr[0, 1] > min_corr + assert mse < max_mse + def test_multiple_adapters_no_needless_copy_modules_to_save(self): # See 2206 # The problem was that we keep a "global" modules_to_save on the model which contains all possible diff --git a/tests/testing_common.py b/tests/testing_common.py index caa4618694..49fc15a1f1 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -1591,21 +1591,7 @@ def _test_weighted_combination_of_adapters_lora(self, model, config, adapter_lis density=0.5, ) - new_adapters = [ - "single_adapter_reweighting", - "multi_adapter_svd_reweighting", - "multi_adapter_ties_svd_reweighting", - "multi_adapter_dare_linear_svd_reweighting", - "multi_adapter_dare_ties_svd_reweighting", - "multi_adapter_magnitude_prune_svd_reweighting", - "multi_adapter_cat_reweighting", - "multi_adapter_linear_reweighting", - "multi_adapter_linear_reweighting_single_enabled", - "multi_adapter_ties_reweighting", - "multi_adapter_dare_linear_reweighting", - "multi_adapter_dare_ties_reweighting", - "multi_adapter_magnitude_prune_reweighting", - ] + new_adapters = [k for k in model.peft_config.keys() if not k.startswith("adapter_")] for new_adapter in new_adapters: assert new_adapter in model.peft_config @@ -1614,11 +1600,11 @@ def _test_weighted_combination_of_adapters_lora(self, model, config, adapter_lis _, target, _ = _get_submodules(model, key) if isinstance(target, LoraLayer): for adapter_name in new_adapters: + # for a single adapter, the result should be exact and we can check that; otherwise, we deal with + # approximations if "single" in adapter_name: new_delta_weight = target.get_delta_weight(adapter_name) weighted_original_delta_weights = target.get_delta_weight(adapter_list[0]) * weight_list[0] - sign = 1 if weight_list[0] > 0 else -1 - weighted_original_delta_weights = sign * weighted_original_delta_weights assert torch.allclose(new_delta_weight, weighted_original_delta_weights, atol=1e-4, rtol=1e-4) elif "svd" in adapter_name: assert target.r[adapter_name] == 20 @@ -1673,7 +1659,7 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw if "gemma" in model_id.lower(): return pytest.skip("Combining Gemma adapters with SVD is currently failing") - adapter_list = ["adapter1", "adapter_2", "adapter_3"] + adapter_list = ["adapter_1", "adapter_2", "adapter_3"] weight_list = [0.5, 1.5, 1.5] negative_weight_list = [-0.5, -0.8, -1.2] # Initialize the config @@ -1690,11 +1676,22 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw model = self.transformers_class.from_pretrained(model_id) model = get_peft_model(model, config, adapter_list[0]) + # test positive weights if isinstance(config, LoraConfig): self._test_weighted_combination_of_adapters_lora(model, config, adapter_list, weight_list) - self._test_weighted_combination_of_adapters_lora(model, config, adapter_list, negative_weight_list) elif isinstance(config, IA3Config): self._test_weighted_combination_of_adapters_ia3(model, config, adapter_list, weight_list) + else: + pytest.skip(f"Test not applicable for {config}") + + del model + model = self.transformers_class.from_pretrained(model_id) + model = get_peft_model(model, config, adapter_list[0]) + + # test negative weights + if isinstance(config, LoraConfig): + self._test_weighted_combination_of_adapters_lora(model, config, adapter_list, negative_weight_list) + elif isinstance(config, IA3Config): self._test_weighted_combination_of_adapters_ia3(model, config, adapter_list, negative_weight_list) else: pytest.skip(f"Test not applicable for {config}")