Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions src/peft/tuners/lora/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,76 @@ def _find_cutoff_index(S: torch.Tensor, threshold: float) -> int:
return k + 1


@torch.no_grad()
def _convert_miss_module_to_lora(
module, rank: int | float, adapter_name: str = "default"
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""Convert a single MiSS layer to LoRA A and B matrices.

For standard and mini modes, the MiSS forward pass (reshape+sum @ miss) is already a rank-r factorization, so the
exact factors are returned directly without SVD.

For bat mode, the delta weight depends on the base weight, so SVD is used.
"""
miss_fn = module.miss_fn
miss_block = module.miss_block[adapter_name]
in_features = module.in_features
out_features = module.out_features
r_miss = module.miss_r[adapter_name]
orig_dtype = miss_block.dtype
device = miss_block.device

if miss_fn == "bat":
base_weight = module.get_base_layer().weight.data.clone()
delta_weight = module.get_delta_weight(adapter_name, base_weight).float()

U, S, V = torch.linalg.svd(delta_weight, full_matrices=False)

if isinstance(rank, int):
effective_rank = rank
else:
effective_rank = _find_cutoff_index(S, threshold=rank)

if effective_rank > U.shape[1]:
raise ValueError(
f"The chosen rank {effective_rank} is larger than the weight shape ({U.shape[1]}), please choose a "
"lower rank."
)

lora_B = U[:, :effective_rank] * S[:effective_rank]
lora_A = V[:effective_rank]
return lora_A.to(orig_dtype).contiguous(), lora_B.to(orig_dtype).contiguous(), effective_rank

# Standard or mini: exact conversion using the native rank r
miss = miss_block.float()
r = miss.size(0)

if miss_fn == "mini":
mini_r = module.miss_mini_r[adapter_name]
miss = miss.repeat(1, out_features // mini_r)

# lora_A: structured summation matrix, shape (r, in_features)
# lora_A[j, i] = 1 if i % r == j
lora_A = torch.zeros(r, in_features, device=device, dtype=torch.float32)
indices = torch.arange(in_features, device=device)
lora_A[indices % r, indices] = 1.0

# lora_B = miss.T, shape (out_features, r)
lora_B = miss.T

return lora_A.to(orig_dtype).contiguous(), lora_B.to(orig_dtype).contiguous(), r


@torch.no_grad()
def _convert_module_to_lora(
module: BaseTunerLayer, rank: int | float, adapter_name: str = "default"
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""Convert a single BaseTunerLayer's adapter weight to a LoRA weight, return A, B, and the effective rank."""
from peft.tuners.miss.layer import MissLinear

if isinstance(module, MissLinear):
return _convert_miss_module_to_lora(module, rank, adapter_name)

delta_weight = module.get_delta_weight(adapter_name)
# Note: Explore different algorithms (truncated, randomized, ...) to see if they are more efficient

Expand Down
112 changes: 46 additions & 66 deletions src/peft/tuners/miss/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn

from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge

Expand Down Expand Up @@ -228,21 +228,23 @@ def unmerge(self) -> None:
if active_adapter in self.miss_block.keys():
orig_weight = self.get_base_layer().weight.data.clone()
if self.miss_fn == "bat":
delta_weight = self.get_delta_weight(active_adapter, orig_weight, re=True)
delta_weight = self.get_delta_weight(active_adapter, orig_weight, reverse=True)
elif self.miss_fn == "mini":
delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, re=True)
delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, reverse=True)
else:
delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, re=True)
delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, reverse=True)

base_layer.weight.data = delta_weight.to(orig_dtype)

def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
def get_delta_weight(self, adapter, orig_weight, reverse: bool = False) -> torch.Tensor:
"""
Compute the delta weight for the given adapter.

Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
reverse (bool):
If True, reverse the merge (unmerge). If False, apply the merge (forward).
"""
device = self.miss_block[adapter].device
dtype = self.miss_block[adapter].dtype
Expand All @@ -251,44 +253,39 @@ def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tens
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

weight_miss = self.miss_block[adapter]
miss_B = self.miss_block[adapter]

if cast_to_fp32:
weight_miss = weight_miss.float()
orig_weight = orig_weight.to(weight_miss.dtype)

r = weight_miss.size(-1)
if re:
o = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
one = torch.eye(weight_miss.size(-1)).to(weight_miss.device)
# inverse must be in float32, after that the dtype can be adjusted if needed
inv_I_plus_b = torch.inverse(one + weight_miss)
inv_I_plus_b = inv_I_plus_b.to(weight_miss.dtype)
w = (o - weight_miss) @ inv_I_plus_b
output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape)
miss_B = miss_B.float()
orig_weight = orig_weight.to(miss_B.dtype)

r = miss_B.size(-1)
W = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)

if reverse:
eye = torch.eye(r, device=miss_B.device, dtype=torch.float32)
inv_I_plus_miss_B = torch.inverse(eye + miss_B.float()).to(miss_B.dtype)
result = (W - miss_B) @ inv_I_plus_miss_B
else:
w = (
orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
@ weight_miss
+ weight_miss
)
output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape)
result = W @ miss_B + miss_B

output_tensor = result.permute(1, 2, 0, 3).reshape(*orig_weight.shape)

if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)

# cast back the weights
self.miss_block[adapter].data = weight_miss.to(dtype)
self.miss_block[adapter].data = miss_B.to(dtype)

return output_tensor

def get_delta_weight_miss(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
def get_delta_weight_miss(self, adapter, orig_weight, reverse: bool = False) -> torch.Tensor:
"""
Compute the delta weight for the given adapter.

Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
reverse (bool):
If True, reverse the merge (unmerge). If False, apply the merge (forward).
"""
device = self.miss_block[adapter].device
dtype = self.miss_block[adapter].dtype
Expand All @@ -297,55 +294,39 @@ def get_delta_weight_miss(self, adapter, orig_weight, re: bool = False) -> torch
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

weight_miss = self.miss_block[adapter]
miss_B = self.miss_block[adapter]

if cast_to_fp32:
weight_miss = weight_miss.float()
miss_B = miss_B.float()

in_features = orig_weight.size(-1)
out_features = orig_weight.size(0)
r = weight_miss.size(0)
r = miss_B.size(0)
if self.miss_fn == "mini":
weight_miss = weight_miss.repeat(1, out_features // self.miss_mini_r[adapter])
miss_B = miss_B.repeat(1, out_features // self.miss_mini_r[adapter])

sign = -1 if reverse else 1

if in_features % r != 0:
last_size = in_features % r
n_block = in_features // r
n_block_size = n_block * r

if re:
orig_weight[:, :n_block_size] = (
(orig_weight[:, :n_block_size].reshape(-1, n_block, r).permute(1, 2, 0) - weight_miss)
.permute(2, 0, 1)
.reshape(*orig_weight[:, :n_block_size].shape)
)
orig_weight[:, n_block_size:] = (
orig_weight[:, n_block_size:] - (weight_miss.transpose(0, 1))[:, :last_size]
)
else:
orig_weight[:, :n_block_size] = (
(orig_weight[:, :n_block_size].reshape(-1, n_block, r).permute(1, 2, 0) + weight_miss)
.permute(2, 0, 1)
.reshape(*orig_weight[:, :n_block_size].shape)
)
orig_weight[:, n_block_size:] = (
orig_weight[:, n_block_size:] + (weight_miss.transpose(0, 1))[:, :last_size]
)
output_tensor = orig_weight
remainder = in_features % r
n_blocks = in_features // r
aligned_size = n_blocks * r

W_aligned = orig_weight[:, :aligned_size].reshape(-1, n_blocks, r).permute(1, 2, 0)
orig_weight[:, :aligned_size] = (
(W_aligned + sign * miss_B).permute(2, 0, 1).reshape(*orig_weight[:, :aligned_size].shape)
)
orig_weight[:, aligned_size:] = (
orig_weight[:, aligned_size:] + sign * miss_B.transpose(0, 1)[:, :remainder]
)
output_tensor = orig_weight
else:
if re:
w = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) - weight_miss
output_tensor = w.permute(2, 0, 1).reshape(*orig_weight.shape)
else:
w = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) + weight_miss
output_tensor = w.permute(2, 0, 1).reshape(*orig_weight.shape)
W_blocks = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0)
output_tensor = (W_blocks + sign * miss_B).permute(2, 0, 1).reshape(*orig_weight.shape)

if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)

# cast back the weights
self.miss_block[adapter].data = weight_miss.to(dtype)
self.miss_block[adapter].data = miss_B.to(dtype)

return output_tensor

Expand Down Expand Up @@ -391,8 +372,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
return result

def supports_lora_conversion(self, adapter_name: str = "default") -> bool:
# only 'bat' can be converted in a straightforward way
return self.miss_fn == "bat"
return True

def __repr__(self) -> str:
rep = super().__repr__()
Expand Down
Loading
Loading