From 2cb966e120204e5b92061cd872e7e07e621cf819 Mon Sep 17 00:00:00 2001 From: "J.L" <997529190@qq.com> Date: Mon, 30 Mar 2026 16:46:15 +0800 Subject: [PATCH 01/10] miss update --- docs/source/conceptual_guides/adapter.md | 2 +- examples/miss_finetuning/README.md | 20 +++++++++++--------- examples/miss_finetuning/miss_finetuning.py | 2 ++ 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/docs/source/conceptual_guides/adapter.md b/docs/source/conceptual_guides/adapter.md index f9ecee5d1b..f11ec8e596 100644 --- a/docs/source/conceptual_guides/adapter.md +++ b/docs/source/conceptual_guides/adapter.md @@ -127,7 +127,7 @@ Bone was deprecated and removed in PEFT v0.19.0 in favor of [MiSS](https://huggi ## MiSS [MiSS](https://huggingface.co/papers/2409.15371) MiSS (Matrix Shard Sharing) is a novel Parameter-Efficient Fine-Tuning (PEFT) method designed to address the trade-off between adaptability and efficiency in Large Language Models. The core approach of MiSS involves a simple shard-sharing mechanism. It achieves low-rank adaptation by decomposing a weight matrix into multiple fragments and then utilizing a shared, trainable "common fragment." The final low-rank update matrix is constructed by replicating these shared, partitioned shards. (MiSS is a novel PEFT method that adopts a low-rank structure, requires only a single trainable matrix, and introduces a new update mechanism distinct from LoRA, achieving an excellent balance between performance and efficiency.) -MiSS: Balancing LoRA Performance and Efficiency with Simple Shard Sharing +MiSS: Revisiting the Trade-off in LoRA with an Efficient Shard-Sharing Structure Intuitively, the shape of a single trainable matrix in MiSS is consistent with `lora_B`, so the `r` parameter in MiSS is less than the `r` in LoRA by (`in_feature * r`). diff --git a/examples/miss_finetuning/README.md b/examples/miss_finetuning/README.md index ecfbcdd4fb..72e149468f 100644 --- a/examples/miss_finetuning/README.md +++ b/examples/miss_finetuning/README.md @@ -16,7 +16,8 @@ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") tokenizer.pad_token_id = tokenizer.eos_token_id miss_config = MissConfig( - r = 64 + r = 64, + miss_dropout = 0.01 ) #bat: In this mode, you can enable nonlinear updates across different shards. # miss_config = MissConfig( @@ -69,6 +70,7 @@ python miss_finetuning.py \ --base_model_name_or_path meta-llama/Llama-2-7b-hf \ --output_dir output/miss-llama-2-7b-metamath-10k \ --miss_r 64 \ + --miss_dropout 0.01 \ --init_weights True \ --bits bf16 \ --data_path meta-math/MetaMathQA \ @@ -93,12 +95,12 @@ python miss_finetuning.py \ # Citation ```bib -@misc{kang2025balancingloraperformanceefficiency, - title={Balancing LoRA Performance and Efficiency with Simple Shard Sharing}, - author={Jiale Kang and Qingyu Yin}, - year={2025}, - eprint={2409.15371}, - archivePrefix={arXiv}, - primaryClass={cs.CL}, - url={https://arxiv.org/abs/2409.15371}, +@misc{kang2025missrevisitingtradeofflora, + title={MiSS: Revisiting the Trade-off in LoRA with an Efficient Shard-Sharing Structure}, + author={Jiale Kang and Qingyu Yin}, + year={2025}, + eprint={2409.15371}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2409.15371}, } diff --git a/examples/miss_finetuning/miss_finetuning.py b/examples/miss_finetuning/miss_finetuning.py index 91932a3a5c..c3dfe66c6c 100644 --- a/examples/miss_finetuning/miss_finetuning.py +++ b/examples/miss_finetuning/miss_finetuning.py @@ -40,6 +40,7 @@ class ScriptArguments(SFTConfig): }, ) miss_r: int = field(default=16) + miss_dropout: float = field(default=0.0) merge_and_save: bool = field(default=False) # dataset configs data_path: str = field(default="imdb", metadata={"help": "Path to the training data."}) @@ -70,6 +71,7 @@ class ScriptArguments(SFTConfig): tokenizer.pad_token_id = tokenizer.eos_token_id miss_config = MissConfig( r=script_args.miss_r, + miss_dropout=script_args.miss_dropout, target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"], bias="none", task_type="CAUSAL_LM", From 6afce906d8d5d26a284c43b2dd7c3c9f7df62610 Mon Sep 17 00:00:00 2001 From: "J.L" <997529190@qq.com> Date: Mon, 30 Mar 2026 22:10:40 +0800 Subject: [PATCH 02/10] change link --- docs/source/conceptual_guides/adapter.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/conceptual_guides/adapter.md b/docs/source/conceptual_guides/adapter.md index f11ec8e596..cbeed3987f 100644 --- a/docs/source/conceptual_guides/adapter.md +++ b/docs/source/conceptual_guides/adapter.md @@ -125,9 +125,9 @@ The higher `r`, the more trainable parameters, resulting in a larger model capac Bone was deprecated and removed in PEFT v0.19.0 in favor of [MiSS](https://huggingface.co/papers/2409.15371) (new version of paper: "MiSS: Balancing LoRA Performance and Efficiency with Simple Shard Sharing"). If you already have a Bone checkpoint, you can use `/scripts/convert-bone-to-miss.py` to convert it into a MiSS checkpoint and proceed with training using MiSS. ## MiSS -[MiSS](https://huggingface.co/papers/2409.15371) MiSS (Matrix Shard Sharing) is a novel Parameter-Efficient Fine-Tuning (PEFT) method designed to address the trade-off between adaptability and efficiency in Large Language Models. The core approach of MiSS involves a simple shard-sharing mechanism. It achieves low-rank adaptation by decomposing a weight matrix into multiple fragments and then utilizing a shared, trainable "common fragment." The final low-rank update matrix is constructed by replicating these shared, partitioned shards. (MiSS is a novel PEFT method that adopts a low-rank structure, requires only a single trainable matrix, and introduces a new update mechanism distinct from LoRA, achieving an excellent balance between performance and efficiency.) +[MiSS](https://github.com/Joluck/MiSS) MiSS (Matrix Shard Sharing) is a novel Parameter-Efficient Fine-Tuning (PEFT) method designed to address the trade-off between adaptability and efficiency in Large Language Models. The core approach of MiSS involves a simple shard-sharing mechanism. It achieves low-rank adaptation by decomposing a weight matrix into multiple fragments and then utilizing a shared, trainable "common fragment." The final low-rank update matrix is constructed by replicating these shared, partitioned shards. (MiSS is a novel PEFT method that adopts a low-rank structure, requires only a single trainable matrix, and introduces a new update mechanism distinct from LoRA, achieving an excellent balance between performance and efficiency.) -MiSS: Revisiting the Trade-off in LoRA with an Efficient Shard-Sharing Structure +MiSS: Revisiting the Trade-off in LoRA with an Efficient Shard-Sharing Structure Intuitively, the shape of a single trainable matrix in MiSS is consistent with `lora_B`, so the `r` parameter in MiSS is less than the `r` in LoRA by (`in_feature * r`). From c8d47a47770c31fd7e24b5fa8ef140b9a9186122 Mon Sep 17 00:00:00 2001 From: "J.L" <997529190@qq.com> Date: Tue, 31 Mar 2026 17:00:40 +0800 Subject: [PATCH 03/10] 1 --- docs/source/conceptual_guides/adapter.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conceptual_guides/adapter.md b/docs/source/conceptual_guides/adapter.md index cbeed3987f..825df1abac 100644 --- a/docs/source/conceptual_guides/adapter.md +++ b/docs/source/conceptual_guides/adapter.md @@ -125,7 +125,7 @@ The higher `r`, the more trainable parameters, resulting in a larger model capac Bone was deprecated and removed in PEFT v0.19.0 in favor of [MiSS](https://huggingface.co/papers/2409.15371) (new version of paper: "MiSS: Balancing LoRA Performance and Efficiency with Simple Shard Sharing"). If you already have a Bone checkpoint, you can use `/scripts/convert-bone-to-miss.py` to convert it into a MiSS checkpoint and proceed with training using MiSS. ## MiSS -[MiSS](https://github.com/Joluck/MiSS) MiSS (Matrix Shard Sharing) is a novel Parameter-Efficient Fine-Tuning (PEFT) method designed to address the trade-off between adaptability and efficiency in Large Language Models. The core approach of MiSS involves a simple shard-sharing mechanism. It achieves low-rank adaptation by decomposing a weight matrix into multiple fragments and then utilizing a shared, trainable "common fragment." The final low-rank update matrix is constructed by replicating these shared, partitioned shards. (MiSS is a novel PEFT method that adopts a low-rank structure, requires only a single trainable matrix, and introduces a new update mechanism distinct from LoRA, achieving an excellent balance between performance and efficiency.) +[MiSS](https://github.com/Joluck/MiSS) Matrix Shard Sharing is a novel Parameter-Efficient Fine-Tuning (PEFT) method designed to address the trade-off between adaptability and efficiency in Large Language Models. The core approach of MiSS involves a simple shard-sharing mechanism. It achieves low-rank adaptation by decomposing a weight matrix into multiple fragments and then utilizing a shared, trainable "common fragment." The final low-rank update matrix is constructed by replicating these shared, partitioned shards. (MiSS is a novel PEFT method that adopts a low-rank structure, requires only a single trainable matrix, and introduces a new update mechanism distinct from LoRA, achieving an excellent balance between performance and efficiency.) MiSS: Revisiting the Trade-off in LoRA with an Efficient Shard-Sharing Structure From 63d941fa88cc21e0d9a7781826adea8c6166117d Mon Sep 17 00:00:00 2001 From: Joluck <997529190@qq.com> Date: Fri, 24 Apr 2026 17:08:09 +0800 Subject: [PATCH 04/10] update --- .../MetaMathQA/default_training_params.json | 2 +- src/peft/tuners/miss/layer.py | 103 +++++++----------- 2 files changed, 41 insertions(+), 64 deletions(-) diff --git a/method_comparison/MetaMathQA/default_training_params.json b/method_comparison/MetaMathQA/default_training_params.json index a10fa49601..c45b3f05cd 100644 --- a/method_comparison/MetaMathQA/default_training_params.json +++ b/method_comparison/MetaMathQA/default_training_params.json @@ -1,5 +1,5 @@ { - "model_id": "meta-llama/Llama-3.2-3B", + "model_id": "unsloth/Llama-3.2-3B", "dtype": "bfloat16", "max_seq_length": 768, "batch_size": 4, diff --git a/src/peft/tuners/miss/layer.py b/src/peft/tuners/miss/layer.py index bf0c145191..86c21c4ee5 100644 --- a/src/peft/tuners/miss/layer.py +++ b/src/peft/tuners/miss/layer.py @@ -228,21 +228,23 @@ def unmerge(self) -> None: if active_adapter in self.miss_block.keys(): orig_weight = self.get_base_layer().weight.data.clone() if self.miss_fn == "bat": - delta_weight = self.get_delta_weight(active_adapter, orig_weight, re=True) + delta_weight = self.get_delta_weight(active_adapter, orig_weight, reverse=True) elif self.miss_fn == "mini": - delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, re=True) + delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, reverse=True) else: - delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, re=True) + delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, reverse=True) base_layer.weight.data = delta_weight.to(orig_dtype) - def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor: + def get_delta_weight(self, adapter, orig_weight, reverse: bool = False) -> torch.Tensor: """ Compute the delta weight for the given adapter. Args: adapter (str): The name of the adapter for which the delta weight should be computed. + reverse (bool): + If True, reverse the merge (unmerge). If False, apply the merge (forward). """ device = self.miss_block[adapter].device dtype = self.miss_block[adapter].dtype @@ -251,44 +253,39 @@ def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tens # (b)float16 because some CPUs have slow bf16/fp16 matmuls. cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) - weight_miss = self.miss_block[adapter] + miss_B = self.miss_block[adapter] if cast_to_fp32: - weight_miss = weight_miss.float() - orig_weight = orig_weight.to(weight_miss.dtype) - - r = weight_miss.size(-1) - if re: - o = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3) - one = torch.eye(weight_miss.size(-1)).to(weight_miss.device) - # inverse must be in float32, after that the dtype can be adjusted if needed - inv_I_plus_b = torch.inverse(one + weight_miss) - inv_I_plus_b = inv_I_plus_b.to(weight_miss.dtype) - w = (o - weight_miss) @ inv_I_plus_b - output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape) + miss_B = miss_B.float() + orig_weight = orig_weight.to(miss_B.dtype) + + r = miss_B.size(-1) + W = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3) + + if reverse: + I = torch.eye(r, device=miss_B.device, dtype=torch.float32) + inv_I_plus_miss_B = torch.inverse(I + miss_B.float()).to(miss_B.dtype) + result = (W - miss_B) @ inv_I_plus_miss_B else: - w = ( - orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3) - @ weight_miss - + weight_miss - ) - output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape) + result = W @ miss_B + miss_B + + output_tensor = result.permute(1, 2, 0, 3).reshape(*orig_weight.shape) if cast_to_fp32: output_tensor = output_tensor.to(dtype=dtype) - - # cast back the weights - self.miss_block[adapter].data = weight_miss.to(dtype) + self.miss_block[adapter].data = miss_B.to(dtype) return output_tensor - def get_delta_weight_miss(self, adapter, orig_weight, re: bool = False) -> torch.Tensor: + def get_delta_weight_miss(self, adapter, orig_weight, reverse: bool = False) -> torch.Tensor: """ Compute the delta weight for the given adapter. Args: adapter (str): The name of the adapter for which the delta weight should be computed. + reverse (bool): + If True, reverse the merge (unmerge). If False, apply the merge (forward). """ device = self.miss_block[adapter].device dtype = self.miss_block[adapter].dtype @@ -297,55 +294,35 @@ def get_delta_weight_miss(self, adapter, orig_weight, re: bool = False) -> torch # (b)float16 because some CPUs have slow bf16/fp16 matmuls. cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) - weight_miss = self.miss_block[adapter] + miss_B = self.miss_block[adapter] if cast_to_fp32: - weight_miss = weight_miss.float() + miss_B = miss_B.float() in_features = orig_weight.size(-1) out_features = orig_weight.size(0) - r = weight_miss.size(0) + r = miss_B.size(0) if self.miss_fn == "mini": - weight_miss = weight_miss.repeat(1, out_features // self.miss_mini_r[adapter]) + miss_B = miss_B.repeat(1, out_features // self.miss_mini_r[adapter]) + + sign = -1 if reverse else 1 if in_features % r != 0: - last_size = in_features % r - n_block = in_features // r - n_block_size = n_block * r - - if re: - orig_weight[:, :n_block_size] = ( - (orig_weight[:, :n_block_size].reshape(-1, n_block, r).permute(1, 2, 0) - weight_miss) - .permute(2, 0, 1) - .reshape(*orig_weight[:, :n_block_size].shape) - ) - orig_weight[:, n_block_size:] = ( - orig_weight[:, n_block_size:] - (weight_miss.transpose(0, 1))[:, :last_size] - ) - else: - orig_weight[:, :n_block_size] = ( - (orig_weight[:, :n_block_size].reshape(-1, n_block, r).permute(1, 2, 0) + weight_miss) - .permute(2, 0, 1) - .reshape(*orig_weight[:, :n_block_size].shape) - ) - orig_weight[:, n_block_size:] = ( - orig_weight[:, n_block_size:] + (weight_miss.transpose(0, 1))[:, :last_size] - ) - output_tensor = orig_weight + remainder = in_features % r + n_blocks = in_features // r + aligned_size = n_blocks * r + W_aligned = orig_weight[:, :aligned_size].reshape(-1, n_blocks, r).permute(1, 2, 0) + orig_weight[:, :aligned_size] = (W_aligned + sign * miss_B).permute(2, 0, 1).reshape(*orig_weight[:, :aligned_size].shape) + orig_weight[:, aligned_size:] = orig_weight[:, aligned_size:] + sign * miss_B.transpose(0, 1)[:, :remainder] + output_tensor = orig_weight else: - if re: - w = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) - weight_miss - output_tensor = w.permute(2, 0, 1).reshape(*orig_weight.shape) - else: - w = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) + weight_miss - output_tensor = w.permute(2, 0, 1).reshape(*orig_weight.shape) + W_blocks = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) + output_tensor = (W_blocks + sign * miss_B).permute(2, 0, 1).reshape(*orig_weight.shape) if cast_to_fp32: output_tensor = output_tensor.to(dtype=dtype) - - # cast back the weights - self.miss_block[adapter].data = weight_miss.to(dtype) + self.miss_block[adapter].data = miss_B.to(dtype) return output_tensor From 53e0aa821efbb6af99905996103906277c65392f Mon Sep 17 00:00:00 2001 From: Joluck <997529190@qq.com> Date: Mon, 27 Apr 2026 14:14:25 +0800 Subject: [PATCH 05/10] miss_to_lora --- src/peft/tuners/lora/conversion.py | 65 ++++++++++++++++++++++++++++++ src/peft/tuners/miss/layer.py | 3 +- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/lora/conversion.py b/src/peft/tuners/lora/conversion.py index 570f9b690d..2778f35b9a 100644 --- a/src/peft/tuners/lora/conversion.py +++ b/src/peft/tuners/lora/conversion.py @@ -45,11 +45,76 @@ def _find_cutoff_index(S: torch.Tensor, threshold: float) -> int: return k + 1 +@torch.no_grad() +def _convert_miss_module_to_lora( + module, rank: int | float, adapter_name: str = "default" +) -> tuple[torch.Tensor, torch.Tensor, int]: + """Convert a single MiSS layer to LoRA A and B matrices. + + For standard and mini modes, the MiSS forward pass (reshape+sum @ miss) is already a rank-r + factorization, so the exact factors are returned directly without SVD. + + For bat mode, the delta weight depends on the base weight, so SVD is used. + """ + miss_fn = module.miss_fn + miss_block = module.miss_block[adapter_name] + in_features = module.in_features + out_features = module.out_features + r_miss = module.miss_r[adapter_name] + orig_dtype = miss_block.dtype + device = miss_block.device + + if miss_fn == "bat": + base_weight = module.get_base_layer().weight.data.clone() + delta_weight = module.get_delta_weight(adapter_name, base_weight).float() + + U, S, V = torch.linalg.svd(delta_weight, full_matrices=False) + + if isinstance(rank, int): + effective_rank = rank + else: + effective_rank = _find_cutoff_index(S, threshold=rank) + + if effective_rank > U.shape[1]: + raise ValueError( + f"The chosen rank {effective_rank} is larger than the weight shape ({U.shape[1]}), please choose a " + "lower rank." + ) + + lora_B = U[:, :effective_rank] * S[:effective_rank] + lora_A = V[:effective_rank] + return lora_A.to(orig_dtype).contiguous(), lora_B.to(orig_dtype).contiguous(), effective_rank + + # Standard or mini: exact conversion using the native rank r + miss = miss_block.float() + r = miss.size(0) + + if miss_fn == "mini": + mini_r = module.miss_mini_r[adapter_name] + miss = miss.repeat(1, out_features // mini_r) + + # lora_A: structured summation matrix, shape (r, in_features) + # lora_A[j, i] = 1 if i % r == j + lora_A = torch.zeros(r, in_features, device=device, dtype=torch.float32) + indices = torch.arange(in_features, device=device) + lora_A[indices % r, indices] = 1.0 + + # lora_B = miss.T, shape (out_features, r) + lora_B = miss.T + + return lora_A.to(orig_dtype).contiguous(), lora_B.to(orig_dtype).contiguous(), r + + @torch.no_grad() def _convert_module_to_lora( module: BaseTunerLayer, rank: int | float, adapter_name: str = "default" ) -> tuple[torch.Tensor, torch.Tensor, int]: """Convert a single BaseTunerLayer's adapter weight to a LoRA weight, return A, B, and the effective rank.""" + from peft.tuners.miss.layer import MissLinear + + if isinstance(module, MissLinear): + return _convert_miss_module_to_lora(module, rank, adapter_name) + delta_weight = module.get_delta_weight(adapter_name) # Note: Explore different algorithms (truncated, randomized, ...) to see if they are more efficient diff --git a/src/peft/tuners/miss/layer.py b/src/peft/tuners/miss/layer.py index 86c21c4ee5..fc6f788304 100644 --- a/src/peft/tuners/miss/layer.py +++ b/src/peft/tuners/miss/layer.py @@ -368,8 +368,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: return result def supports_lora_conversion(self, adapter_name: str = "default") -> bool: - # only 'bat' can be converted in a straightforward way - return self.miss_fn == "bat" + return True def __repr__(self) -> str: rep = super().__repr__() From d399714424618224e632ddb9f8391ae1190a4475 Mon Sep 17 00:00:00 2001 From: Joluck <997529190@qq.com> Date: Mon, 27 Apr 2026 15:04:37 +0800 Subject: [PATCH 06/10] origin --- method_comparison/MetaMathQA/default_training_params.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/method_comparison/MetaMathQA/default_training_params.json b/method_comparison/MetaMathQA/default_training_params.json index c45b3f05cd..a10fa49601 100644 --- a/method_comparison/MetaMathQA/default_training_params.json +++ b/method_comparison/MetaMathQA/default_training_params.json @@ -1,5 +1,5 @@ { - "model_id": "unsloth/Llama-3.2-3B", + "model_id": "meta-llama/Llama-3.2-3B", "dtype": "bfloat16", "max_seq_length": 768, "batch_size": 4, From c618c2793b384a4fbd563971c5509bdae1a4ba3b Mon Sep 17 00:00:00 2001 From: Joluck <997529190@qq.com> Date: Tue, 28 Apr 2026 16:33:42 +0800 Subject: [PATCH 07/10] test unit --- tests/test_lora_conversion.py | 138 ++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) diff --git a/tests/test_lora_conversion.py b/tests/test_lora_conversion.py index 0cdbbdca98..f6437892b4 100644 --- a/tests/test_lora_conversion.py +++ b/tests/test_lora_conversion.py @@ -26,6 +26,7 @@ IA3Config, LoKrConfig, LoraConfig, + MissConfig, PeftModel, PrefixTuningConfig, convert_to_lora, @@ -553,3 +554,140 @@ def test_convert_float16_dtype(self, dtype): mse_converted = self.get_mse(output_converted, output_lokr) assert 0.0 < mse_converted < 0.1 + + +class TestMissLoraConversion: + """Test MiSS to LoRA conversion for standard, mini, and bat modes.""" + + model_id = "peft-internal-testing/tiny-random-OPTForCausalLM" + torch_device = infer_device() + base_model = None + + def get_base_model(self): + if self.base_model is None: + with hub_online_once(self.model_id): + self.base_model = AutoModelForCausalLM.from_pretrained(self.model_id).to(self.torch_device) + return copy.deepcopy(self.base_model) + + @staticmethod + def get_mse(output1, output2): + return nn.functional.mse_loss(output1.hidden_states[-1], output2.hidden_states[-1]).item() + + def _randomize_miss_blocks(self, model): + with torch.no_grad(): + for m in model.modules(): + if hasattr(m, "miss_block"): + for p in m.miss_block.values(): + p.data.normal_(0, 0.01) + + @pytest.fixture + def miss_model_standard(self): + torch.manual_seed(0) + config = MissConfig(r=4, init_weights=False, target_modules=["q_proj", "v_proj"]) + return get_peft_model(self.get_base_model(), config) + + @pytest.fixture + def miss_model_mini(self): + torch.manual_seed(0) + config = MissConfig(r=4, mini_r=2, init_weights="mini", target_modules=["q_proj", "v_proj"]) + model = get_peft_model(self.get_base_model(), config) + self._randomize_miss_blocks(model) + return model + + @pytest.fixture + def miss_model_bat(self): + torch.manual_seed(0) + config = MissConfig(r=4, init_weights="bat", target_modules=["q_proj", "v_proj"]) + model = get_peft_model(self.get_base_model(), config) + self._randomize_miss_blocks(model) + return model + + def test_miss_supports_lora_conversion(self, miss_model_standard, miss_model_mini, miss_model_bat): + assert miss_model_standard.supports_lora_conversion() + assert miss_model_mini.supports_lora_conversion() + assert miss_model_bat.supports_lora_conversion() + + def test_miss_standard_exact_conversion(self, miss_model_standard): + inputs = torch.arange(10).view(1, -1).to(self.torch_device) + with torch.inference_mode(): + output_miss = miss_model_standard(inputs, output_hidden_states=True) + + lora_config, state_dict = convert_to_lora(miss_model_standard, rank=4) + base_model = self.get_base_model() + lora_model = get_peft_model(base_model, lora_config).eval() + load_result = set_peft_model_state_dict(lora_model, state_dict) + assert not load_result.unexpected_keys + + with torch.inference_mode(): + output_lora = lora_model(inputs, output_hidden_states=True) + + mse = self.get_mse(output_lora, output_miss) + assert mse < 1e-5, f"Standard MiSS conversion should be exact, got mse={mse}" + + def test_miss_mini_exact_conversion(self, miss_model_mini): + inputs = torch.arange(10).view(1, -1).to(self.torch_device) + with torch.inference_mode(): + output_miss = miss_model_mini(inputs, output_hidden_states=True) + + lora_config, state_dict = convert_to_lora(miss_model_mini, rank=4) + base_model = self.get_base_model() + lora_model = get_peft_model(base_model, lora_config).eval() + load_result = set_peft_model_state_dict(lora_model, state_dict) + assert not load_result.unexpected_keys + + with torch.inference_mode(): + output_lora = lora_model(inputs, output_hidden_states=True) + + mse = self.get_mse(output_lora, output_miss) + assert mse < 1e-5, f"Mini MiSS conversion should be exact, got mse={mse}" + + def test_miss_bat_approximate_conversion(self, miss_model_bat): + inputs = torch.arange(10).view(1, -1).to(self.torch_device) + with torch.inference_mode(): + with miss_model_bat.disable_adapter(): + output_base = miss_model_bat(inputs, output_hidden_states=True) + output_miss = miss_model_bat(inputs, output_hidden_states=True) + + atol, rtol = 1e-4, 1e-4 + assert not torch.allclose(output_base.logits, output_miss.logits, atol=atol, rtol=rtol) + + lora_config, state_dict = convert_to_lora(miss_model_bat, rank=4) + base_model = self.get_base_model() + lora_model = get_peft_model(base_model, lora_config).eval() + load_result = set_peft_model_state_dict(lora_model, state_dict) + assert not load_result.unexpected_keys + + with torch.inference_mode(): + output_lora = lora_model(inputs, output_hidden_states=True) + + mse = self.get_mse(output_lora, output_miss) + assert 0.0 < mse < 0.1 + + def test_miss_targeted_modules_identical(self, miss_model_standard): + lora_config, lora_state_dict = convert_to_lora(miss_model_standard, rank=4) + miss_state_dict = miss_model_standard.state_dict() + + modules_miss = {k.rsplit(".", 2)[0] for k in miss_state_dict.keys() if ".miss_block" in k} + modules_lora = {k.rsplit(".", 2)[0] for k in lora_state_dict.keys() if ".lora" in k} + assert modules_miss == modules_lora + + def test_miss_save_as_lora(self, miss_model_standard, tmp_path): + inputs = torch.arange(10).view(1, -1).to(self.torch_device) + atol, rtol = 1e-4, 1e-4 + + lora_config, state_dict = convert_to_lora(miss_model_standard, rank=4) + base_model = self.get_base_model() + lora_model = get_peft_model(base_model, lora_config).eval() + set_peft_model_state_dict(lora_model, state_dict) + + with torch.inference_mode(): + output_before = lora_model(inputs).logits + + save_as_lora(tmp_path, miss_model_standard, rank=4) + base_model = self.get_base_model() + loaded_model = PeftModel.from_pretrained(base_model, tmp_path).to(self.torch_device) + + with torch.inference_mode(): + output_after = loaded_model(inputs).logits + + assert torch.allclose(output_before, output_after, atol=atol, rtol=rtol) From 4785f974ae40324434a9b96151e17ffdb1cfaf92 Mon Sep 17 00:00:00 2001 From: Joluck <997529190@qq.com> Date: Wed, 29 Apr 2026 18:58:54 +0800 Subject: [PATCH 08/10] I->eye --- src/peft/tuners/miss/layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/miss/layer.py b/src/peft/tuners/miss/layer.py index fc6f788304..d01461ffb6 100644 --- a/src/peft/tuners/miss/layer.py +++ b/src/peft/tuners/miss/layer.py @@ -263,8 +263,8 @@ def get_delta_weight(self, adapter, orig_weight, reverse: bool = False) -> torch W = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3) if reverse: - I = torch.eye(r, device=miss_B.device, dtype=torch.float32) - inv_I_plus_miss_B = torch.inverse(I + miss_B.float()).to(miss_B.dtype) + eye = torch.eye(r, device=miss_B.device, dtype=torch.float32) + inv_I_plus_miss_B = torch.inverse(eye + miss_B.float()).to(miss_B.dtype) result = (W - miss_B) @ inv_I_plus_miss_B else: result = W @ miss_B + miss_B From 8f7171b28ef58addd4e088eb7cecd10de474167f Mon Sep 17 00:00:00 2001 From: Joluck <997529190@qq.com> Date: Thu, 30 Apr 2026 14:38:44 +0800 Subject: [PATCH 09/10] make style --- src/peft/tuners/miss/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/miss/layer.py b/src/peft/tuners/miss/layer.py index d01461ffb6..6deb4b772a 100644 --- a/src/peft/tuners/miss/layer.py +++ b/src/peft/tuners/miss/layer.py @@ -18,8 +18,8 @@ from typing import Any, Optional import torch -import torch.nn as nn import torch.nn.functional as F +from torch import nn from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge From 27a9f7d12a93f65fbf87ddaaad8508b9590e9f17 Mon Sep 17 00:00:00 2001 From: Joluck <997529190@qq.com> Date: Fri, 1 May 2026 22:03:58 +0800 Subject: [PATCH 10/10] fix --- src/peft/tuners/lora/conversion.py | 4 ++-- src/peft/tuners/miss/layer.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/peft/tuners/lora/conversion.py b/src/peft/tuners/lora/conversion.py index 2778f35b9a..acbcfac3dc 100644 --- a/src/peft/tuners/lora/conversion.py +++ b/src/peft/tuners/lora/conversion.py @@ -51,8 +51,8 @@ def _convert_miss_module_to_lora( ) -> tuple[torch.Tensor, torch.Tensor, int]: """Convert a single MiSS layer to LoRA A and B matrices. - For standard and mini modes, the MiSS forward pass (reshape+sum @ miss) is already a rank-r - factorization, so the exact factors are returned directly without SVD. + For standard and mini modes, the MiSS forward pass (reshape+sum @ miss) is already a rank-r factorization, so the + exact factors are returned directly without SVD. For bat mode, the delta weight depends on the base weight, so SVD is used. """ diff --git a/src/peft/tuners/miss/layer.py b/src/peft/tuners/miss/layer.py index 6deb4b772a..14a733ccc3 100644 --- a/src/peft/tuners/miss/layer.py +++ b/src/peft/tuners/miss/layer.py @@ -313,8 +313,12 @@ def get_delta_weight_miss(self, adapter, orig_weight, reverse: bool = False) -> aligned_size = n_blocks * r W_aligned = orig_weight[:, :aligned_size].reshape(-1, n_blocks, r).permute(1, 2, 0) - orig_weight[:, :aligned_size] = (W_aligned + sign * miss_B).permute(2, 0, 1).reshape(*orig_weight[:, :aligned_size].shape) - orig_weight[:, aligned_size:] = orig_weight[:, aligned_size:] + sign * miss_B.transpose(0, 1)[:, :remainder] + orig_weight[:, :aligned_size] = ( + (W_aligned + sign * miss_B).permute(2, 0, 1).reshape(*orig_weight[:, :aligned_size].shape) + ) + orig_weight[:, aligned_size:] = ( + orig_weight[:, aligned_size:] + sign * miss_B.transpose(0, 1)[:, :remainder] + ) output_tensor = orig_weight else: W_blocks = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0)