diff --git a/src/peft/tuners/lora/conversion.py b/src/peft/tuners/lora/conversion.py index 570f9b690d..acbcfac3dc 100644 --- a/src/peft/tuners/lora/conversion.py +++ b/src/peft/tuners/lora/conversion.py @@ -45,11 +45,76 @@ def _find_cutoff_index(S: torch.Tensor, threshold: float) -> int: return k + 1 +@torch.no_grad() +def _convert_miss_module_to_lora( + module, rank: int | float, adapter_name: str = "default" +) -> tuple[torch.Tensor, torch.Tensor, int]: + """Convert a single MiSS layer to LoRA A and B matrices. + + For standard and mini modes, the MiSS forward pass (reshape+sum @ miss) is already a rank-r factorization, so the + exact factors are returned directly without SVD. + + For bat mode, the delta weight depends on the base weight, so SVD is used. + """ + miss_fn = module.miss_fn + miss_block = module.miss_block[adapter_name] + in_features = module.in_features + out_features = module.out_features + r_miss = module.miss_r[adapter_name] + orig_dtype = miss_block.dtype + device = miss_block.device + + if miss_fn == "bat": + base_weight = module.get_base_layer().weight.data.clone() + delta_weight = module.get_delta_weight(adapter_name, base_weight).float() + + U, S, V = torch.linalg.svd(delta_weight, full_matrices=False) + + if isinstance(rank, int): + effective_rank = rank + else: + effective_rank = _find_cutoff_index(S, threshold=rank) + + if effective_rank > U.shape[1]: + raise ValueError( + f"The chosen rank {effective_rank} is larger than the weight shape ({U.shape[1]}), please choose a " + "lower rank." + ) + + lora_B = U[:, :effective_rank] * S[:effective_rank] + lora_A = V[:effective_rank] + return lora_A.to(orig_dtype).contiguous(), lora_B.to(orig_dtype).contiguous(), effective_rank + + # Standard or mini: exact conversion using the native rank r + miss = miss_block.float() + r = miss.size(0) + + if miss_fn == "mini": + mini_r = module.miss_mini_r[adapter_name] + miss = miss.repeat(1, out_features // mini_r) + + # lora_A: structured summation matrix, shape (r, in_features) + # lora_A[j, i] = 1 if i % r == j + lora_A = torch.zeros(r, in_features, device=device, dtype=torch.float32) + indices = torch.arange(in_features, device=device) + lora_A[indices % r, indices] = 1.0 + + # lora_B = miss.T, shape (out_features, r) + lora_B = miss.T + + return lora_A.to(orig_dtype).contiguous(), lora_B.to(orig_dtype).contiguous(), r + + @torch.no_grad() def _convert_module_to_lora( module: BaseTunerLayer, rank: int | float, adapter_name: str = "default" ) -> tuple[torch.Tensor, torch.Tensor, int]: """Convert a single BaseTunerLayer's adapter weight to a LoRA weight, return A, B, and the effective rank.""" + from peft.tuners.miss.layer import MissLinear + + if isinstance(module, MissLinear): + return _convert_miss_module_to_lora(module, rank, adapter_name) + delta_weight = module.get_delta_weight(adapter_name) # Note: Explore different algorithms (truncated, randomized, ...) to see if they are more efficient diff --git a/src/peft/tuners/miss/layer.py b/src/peft/tuners/miss/layer.py index bf0c145191..14a733ccc3 100644 --- a/src/peft/tuners/miss/layer.py +++ b/src/peft/tuners/miss/layer.py @@ -18,8 +18,8 @@ from typing import Any, Optional import torch -import torch.nn as nn import torch.nn.functional as F +from torch import nn from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge @@ -228,21 +228,23 @@ def unmerge(self) -> None: if active_adapter in self.miss_block.keys(): orig_weight = self.get_base_layer().weight.data.clone() if self.miss_fn == "bat": - delta_weight = self.get_delta_weight(active_adapter, orig_weight, re=True) + delta_weight = self.get_delta_weight(active_adapter, orig_weight, reverse=True) elif self.miss_fn == "mini": - delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, re=True) + delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, reverse=True) else: - delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, re=True) + delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, reverse=True) base_layer.weight.data = delta_weight.to(orig_dtype) - def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor: + def get_delta_weight(self, adapter, orig_weight, reverse: bool = False) -> torch.Tensor: """ Compute the delta weight for the given adapter. Args: adapter (str): The name of the adapter for which the delta weight should be computed. + reverse (bool): + If True, reverse the merge (unmerge). If False, apply the merge (forward). """ device = self.miss_block[adapter].device dtype = self.miss_block[adapter].dtype @@ -251,44 +253,39 @@ def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tens # (b)float16 because some CPUs have slow bf16/fp16 matmuls. cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) - weight_miss = self.miss_block[adapter] + miss_B = self.miss_block[adapter] if cast_to_fp32: - weight_miss = weight_miss.float() - orig_weight = orig_weight.to(weight_miss.dtype) - - r = weight_miss.size(-1) - if re: - o = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3) - one = torch.eye(weight_miss.size(-1)).to(weight_miss.device) - # inverse must be in float32, after that the dtype can be adjusted if needed - inv_I_plus_b = torch.inverse(one + weight_miss) - inv_I_plus_b = inv_I_plus_b.to(weight_miss.dtype) - w = (o - weight_miss) @ inv_I_plus_b - output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape) + miss_B = miss_B.float() + orig_weight = orig_weight.to(miss_B.dtype) + + r = miss_B.size(-1) + W = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3) + + if reverse: + eye = torch.eye(r, device=miss_B.device, dtype=torch.float32) + inv_I_plus_miss_B = torch.inverse(eye + miss_B.float()).to(miss_B.dtype) + result = (W - miss_B) @ inv_I_plus_miss_B else: - w = ( - orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3) - @ weight_miss - + weight_miss - ) - output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape) + result = W @ miss_B + miss_B + + output_tensor = result.permute(1, 2, 0, 3).reshape(*orig_weight.shape) if cast_to_fp32: output_tensor = output_tensor.to(dtype=dtype) - - # cast back the weights - self.miss_block[adapter].data = weight_miss.to(dtype) + self.miss_block[adapter].data = miss_B.to(dtype) return output_tensor - def get_delta_weight_miss(self, adapter, orig_weight, re: bool = False) -> torch.Tensor: + def get_delta_weight_miss(self, adapter, orig_weight, reverse: bool = False) -> torch.Tensor: """ Compute the delta weight for the given adapter. Args: adapter (str): The name of the adapter for which the delta weight should be computed. + reverse (bool): + If True, reverse the merge (unmerge). If False, apply the merge (forward). """ device = self.miss_block[adapter].device dtype = self.miss_block[adapter].dtype @@ -297,55 +294,39 @@ def get_delta_weight_miss(self, adapter, orig_weight, re: bool = False) -> torch # (b)float16 because some CPUs have slow bf16/fp16 matmuls. cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) - weight_miss = self.miss_block[adapter] + miss_B = self.miss_block[adapter] if cast_to_fp32: - weight_miss = weight_miss.float() + miss_B = miss_B.float() in_features = orig_weight.size(-1) out_features = orig_weight.size(0) - r = weight_miss.size(0) + r = miss_B.size(0) if self.miss_fn == "mini": - weight_miss = weight_miss.repeat(1, out_features // self.miss_mini_r[adapter]) + miss_B = miss_B.repeat(1, out_features // self.miss_mini_r[adapter]) + + sign = -1 if reverse else 1 if in_features % r != 0: - last_size = in_features % r - n_block = in_features // r - n_block_size = n_block * r - - if re: - orig_weight[:, :n_block_size] = ( - (orig_weight[:, :n_block_size].reshape(-1, n_block, r).permute(1, 2, 0) - weight_miss) - .permute(2, 0, 1) - .reshape(*orig_weight[:, :n_block_size].shape) - ) - orig_weight[:, n_block_size:] = ( - orig_weight[:, n_block_size:] - (weight_miss.transpose(0, 1))[:, :last_size] - ) - else: - orig_weight[:, :n_block_size] = ( - (orig_weight[:, :n_block_size].reshape(-1, n_block, r).permute(1, 2, 0) + weight_miss) - .permute(2, 0, 1) - .reshape(*orig_weight[:, :n_block_size].shape) - ) - orig_weight[:, n_block_size:] = ( - orig_weight[:, n_block_size:] + (weight_miss.transpose(0, 1))[:, :last_size] - ) - output_tensor = orig_weight + remainder = in_features % r + n_blocks = in_features // r + aligned_size = n_blocks * r + W_aligned = orig_weight[:, :aligned_size].reshape(-1, n_blocks, r).permute(1, 2, 0) + orig_weight[:, :aligned_size] = ( + (W_aligned + sign * miss_B).permute(2, 0, 1).reshape(*orig_weight[:, :aligned_size].shape) + ) + orig_weight[:, aligned_size:] = ( + orig_weight[:, aligned_size:] + sign * miss_B.transpose(0, 1)[:, :remainder] + ) + output_tensor = orig_weight else: - if re: - w = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) - weight_miss - output_tensor = w.permute(2, 0, 1).reshape(*orig_weight.shape) - else: - w = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) + weight_miss - output_tensor = w.permute(2, 0, 1).reshape(*orig_weight.shape) + W_blocks = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) + output_tensor = (W_blocks + sign * miss_B).permute(2, 0, 1).reshape(*orig_weight.shape) if cast_to_fp32: output_tensor = output_tensor.to(dtype=dtype) - - # cast back the weights - self.miss_block[adapter].data = weight_miss.to(dtype) + self.miss_block[adapter].data = miss_B.to(dtype) return output_tensor @@ -391,8 +372,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: return result def supports_lora_conversion(self, adapter_name: str = "default") -> bool: - # only 'bat' can be converted in a straightforward way - return self.miss_fn == "bat" + return True def __repr__(self) -> str: rep = super().__repr__() diff --git a/tests/test_lora_conversion.py b/tests/test_lora_conversion.py index 0cdbbdca98..f6437892b4 100644 --- a/tests/test_lora_conversion.py +++ b/tests/test_lora_conversion.py @@ -26,6 +26,7 @@ IA3Config, LoKrConfig, LoraConfig, + MissConfig, PeftModel, PrefixTuningConfig, convert_to_lora, @@ -553,3 +554,140 @@ def test_convert_float16_dtype(self, dtype): mse_converted = self.get_mse(output_converted, output_lokr) assert 0.0 < mse_converted < 0.1 + + +class TestMissLoraConversion: + """Test MiSS to LoRA conversion for standard, mini, and bat modes.""" + + model_id = "peft-internal-testing/tiny-random-OPTForCausalLM" + torch_device = infer_device() + base_model = None + + def get_base_model(self): + if self.base_model is None: + with hub_online_once(self.model_id): + self.base_model = AutoModelForCausalLM.from_pretrained(self.model_id).to(self.torch_device) + return copy.deepcopy(self.base_model) + + @staticmethod + def get_mse(output1, output2): + return nn.functional.mse_loss(output1.hidden_states[-1], output2.hidden_states[-1]).item() + + def _randomize_miss_blocks(self, model): + with torch.no_grad(): + for m in model.modules(): + if hasattr(m, "miss_block"): + for p in m.miss_block.values(): + p.data.normal_(0, 0.01) + + @pytest.fixture + def miss_model_standard(self): + torch.manual_seed(0) + config = MissConfig(r=4, init_weights=False, target_modules=["q_proj", "v_proj"]) + return get_peft_model(self.get_base_model(), config) + + @pytest.fixture + def miss_model_mini(self): + torch.manual_seed(0) + config = MissConfig(r=4, mini_r=2, init_weights="mini", target_modules=["q_proj", "v_proj"]) + model = get_peft_model(self.get_base_model(), config) + self._randomize_miss_blocks(model) + return model + + @pytest.fixture + def miss_model_bat(self): + torch.manual_seed(0) + config = MissConfig(r=4, init_weights="bat", target_modules=["q_proj", "v_proj"]) + model = get_peft_model(self.get_base_model(), config) + self._randomize_miss_blocks(model) + return model + + def test_miss_supports_lora_conversion(self, miss_model_standard, miss_model_mini, miss_model_bat): + assert miss_model_standard.supports_lora_conversion() + assert miss_model_mini.supports_lora_conversion() + assert miss_model_bat.supports_lora_conversion() + + def test_miss_standard_exact_conversion(self, miss_model_standard): + inputs = torch.arange(10).view(1, -1).to(self.torch_device) + with torch.inference_mode(): + output_miss = miss_model_standard(inputs, output_hidden_states=True) + + lora_config, state_dict = convert_to_lora(miss_model_standard, rank=4) + base_model = self.get_base_model() + lora_model = get_peft_model(base_model, lora_config).eval() + load_result = set_peft_model_state_dict(lora_model, state_dict) + assert not load_result.unexpected_keys + + with torch.inference_mode(): + output_lora = lora_model(inputs, output_hidden_states=True) + + mse = self.get_mse(output_lora, output_miss) + assert mse < 1e-5, f"Standard MiSS conversion should be exact, got mse={mse}" + + def test_miss_mini_exact_conversion(self, miss_model_mini): + inputs = torch.arange(10).view(1, -1).to(self.torch_device) + with torch.inference_mode(): + output_miss = miss_model_mini(inputs, output_hidden_states=True) + + lora_config, state_dict = convert_to_lora(miss_model_mini, rank=4) + base_model = self.get_base_model() + lora_model = get_peft_model(base_model, lora_config).eval() + load_result = set_peft_model_state_dict(lora_model, state_dict) + assert not load_result.unexpected_keys + + with torch.inference_mode(): + output_lora = lora_model(inputs, output_hidden_states=True) + + mse = self.get_mse(output_lora, output_miss) + assert mse < 1e-5, f"Mini MiSS conversion should be exact, got mse={mse}" + + def test_miss_bat_approximate_conversion(self, miss_model_bat): + inputs = torch.arange(10).view(1, -1).to(self.torch_device) + with torch.inference_mode(): + with miss_model_bat.disable_adapter(): + output_base = miss_model_bat(inputs, output_hidden_states=True) + output_miss = miss_model_bat(inputs, output_hidden_states=True) + + atol, rtol = 1e-4, 1e-4 + assert not torch.allclose(output_base.logits, output_miss.logits, atol=atol, rtol=rtol) + + lora_config, state_dict = convert_to_lora(miss_model_bat, rank=4) + base_model = self.get_base_model() + lora_model = get_peft_model(base_model, lora_config).eval() + load_result = set_peft_model_state_dict(lora_model, state_dict) + assert not load_result.unexpected_keys + + with torch.inference_mode(): + output_lora = lora_model(inputs, output_hidden_states=True) + + mse = self.get_mse(output_lora, output_miss) + assert 0.0 < mse < 0.1 + + def test_miss_targeted_modules_identical(self, miss_model_standard): + lora_config, lora_state_dict = convert_to_lora(miss_model_standard, rank=4) + miss_state_dict = miss_model_standard.state_dict() + + modules_miss = {k.rsplit(".", 2)[0] for k in miss_state_dict.keys() if ".miss_block" in k} + modules_lora = {k.rsplit(".", 2)[0] for k in lora_state_dict.keys() if ".lora" in k} + assert modules_miss == modules_lora + + def test_miss_save_as_lora(self, miss_model_standard, tmp_path): + inputs = torch.arange(10).view(1, -1).to(self.torch_device) + atol, rtol = 1e-4, 1e-4 + + lora_config, state_dict = convert_to_lora(miss_model_standard, rank=4) + base_model = self.get_base_model() + lora_model = get_peft_model(base_model, lora_config).eval() + set_peft_model_state_dict(lora_model, state_dict) + + with torch.inference_mode(): + output_before = lora_model(inputs).logits + + save_as_lora(tmp_path, miss_model_standard, rank=4) + base_model = self.get_base_model() + loaded_model = PeftModel.from_pretrained(base_model, tmp_path).to(self.torch_device) + + with torch.inference_mode(): + output_after = loaded_model(inputs).logits + + assert torch.allclose(output_before, output_after, atol=atol, rtol=rtol)