Commit 4050ef5

ENH Improve MiSS code, add LoRA conversion (#3194)
Improve readability of the MiSS code. Add MiSS-to-LoRA conversion code; the standard and mini modes convert exactly, while bat mode falls back to SVD.
1 parent dc2e5b2 commit 4050ef5

3 files changed: 249 additions & 66 deletions

src/peft/tuners/lora/conversion.py (65 additions, 0 deletions)

```diff
@@ -45,11 +45,76 @@ def _find_cutoff_index(S: torch.Tensor, threshold: float) -> int:
     return k + 1


+@torch.no_grad()
+def _convert_miss_module_to_lora(
+    module, rank: int | float, adapter_name: str = "default"
+) -> tuple[torch.Tensor, torch.Tensor, int]:
+    """Convert a single MiSS layer to LoRA A and B matrices.
+
+    For standard and mini modes, the MiSS forward pass (reshape+sum @ miss) is already a rank-r factorization, so the
+    exact factors are returned directly without SVD.
+
+    For bat mode, the delta weight depends on the base weight, so SVD is used.
+    """
+    miss_fn = module.miss_fn
+    miss_block = module.miss_block[adapter_name]
+    in_features = module.in_features
+    out_features = module.out_features
+    r_miss = module.miss_r[adapter_name]
+    orig_dtype = miss_block.dtype
+    device = miss_block.device
+
+    if miss_fn == "bat":
+        base_weight = module.get_base_layer().weight.data.clone()
+        delta_weight = module.get_delta_weight(adapter_name, base_weight).float()
+
+        U, S, V = torch.linalg.svd(delta_weight, full_matrices=False)
+
+        if isinstance(rank, int):
+            effective_rank = rank
+        else:
+            effective_rank = _find_cutoff_index(S, threshold=rank)
+
+        if effective_rank > U.shape[1]:
+            raise ValueError(
+                f"The chosen rank {effective_rank} is larger than the weight shape ({U.shape[1]}), please choose a "
+                "lower rank."
+            )
+
+        lora_B = U[:, :effective_rank] * S[:effective_rank]
+        lora_A = V[:effective_rank]
+        return lora_A.to(orig_dtype).contiguous(), lora_B.to(orig_dtype).contiguous(), effective_rank
+
+    # Standard or mini: exact conversion using the native rank r
+    miss = miss_block.float()
+    r = miss.size(0)
+
+    if miss_fn == "mini":
+        mini_r = module.miss_mini_r[adapter_name]
+        miss = miss.repeat(1, out_features // mini_r)
+
+    # lora_A: structured summation matrix, shape (r, in_features)
+    # lora_A[j, i] = 1 if i % r == j
+    lora_A = torch.zeros(r, in_features, device=device, dtype=torch.float32)
+    indices = torch.arange(in_features, device=device)
+    lora_A[indices % r, indices] = 1.0
+
+    # lora_B = miss.T, shape (out_features, r)
+    lora_B = miss.T
+
+    return lora_A.to(orig_dtype).contiguous(), lora_B.to(orig_dtype).contiguous(), r
+
+
 @torch.no_grad()
 def _convert_module_to_lora(
     module: BaseTunerLayer, rank: int | float, adapter_name: str = "default"
 ) -> tuple[torch.Tensor, torch.Tensor, int]:
     """Convert a single BaseTunerLayer's adapter weight to a LoRA weight, return A, B, and the effective rank."""
+    from peft.tuners.miss.layer import MissLinear
+
+    if isinstance(module, MissLinear):
+        return _convert_miss_module_to_lora(module, rank, adapter_name)
+
     delta_weight = module.get_delta_weight(adapter_name)
     # Note: Explore different algorithms (truncated, randomized, ...) to see if they are more efficient
```
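
The exact standard-mode conversion can be sanity-checked in isolation. The following is a minimal standalone sketch (plain tensors, not the PEFT API, and a hypothetical `delta_ref` construction) verifying that the structured summation matrix `lora_A` composed with `lora_B = miss.T` reproduces a delta weight that broadcasts `miss` over the input blocks, i.e. `delta[o, i] = miss[i % r, o]`:

```python
import torch

out_features, in_features, r = 8, 12, 4
miss = torch.randn(r, out_features)  # stand-in for miss_block[adapter]

# Reference delta: miss broadcast over the in_features // r input blocks
indices = torch.arange(in_features)
delta_ref = miss[indices % r, :].T  # (out_features, in_features)

# Exact LoRA factors, mirroring _convert_miss_module_to_lora
lora_A = torch.zeros(r, in_features)
lora_A[indices % r, indices] = 1.0  # picks input column i into row i % r
lora_B = miss.T                     # (out_features, r)

assert torch.allclose(lora_B @ lora_A, delta_ref)
```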

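For bat mode the conversion goes through SVD instead. A minimal sketch of that factorization (generic tensors, not the PEFT API): keeping all singular values reconstructs the delta exactly, and truncating to a smaller `effective_rank` gives the best low-rank approximation in the Frobenius norm:

```python
import torch

delta_weight = torch.randn(8, 12).float()  # stand-in for the bat-mode delta
U, S, V = torch.linalg.svd(delta_weight, full_matrices=False)  # V holds the rows of V^T

effective_rank = min(delta_weight.shape)             # full rank -> exact reconstruction
lora_B = U[:, :effective_rank] * S[:effective_rank]  # singular values absorbed into B
lora_A = V[:effective_rank]

assert torch.allclose(lora_B @ lora_A, delta_weight, atol=1e-5)
```
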
src/peft/tuners/miss/layer.py (46 additions, 66 deletions)

```diff
@@ -18,8 +18,8 @@
 from typing import Any, Optional

 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn

 from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
```

```diff
@@ -228,21 +228,23 @@ def unmerge(self) -> None:
             if active_adapter in self.miss_block.keys():
                 orig_weight = self.get_base_layer().weight.data.clone()
                 if self.miss_fn == "bat":
-                    delta_weight = self.get_delta_weight(active_adapter, orig_weight, re=True)
+                    delta_weight = self.get_delta_weight(active_adapter, orig_weight, reverse=True)
                 elif self.miss_fn == "mini":
-                    delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, re=True)
+                    delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, reverse=True)
                 else:
-                    delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, re=True)
+                    delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, reverse=True)

                 base_layer.weight.data = delta_weight.to(orig_dtype)

-    def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
+    def get_delta_weight(self, adapter, orig_weight, reverse: bool = False) -> torch.Tensor:
         """
         Compute the delta weight for the given adapter.

         Args:
             adapter (str):
                 The name of the adapter for which the delta weight should be computed.
+            reverse (bool):
+                If True, reverse the merge (unmerge). If False, apply the merge (forward).
         """
         device = self.miss_block[adapter].device
         dtype = self.miss_block[adapter].dtype
```

```diff
@@ -251,44 +253,39 @@ def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
         # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
         cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

-        weight_miss = self.miss_block[adapter]
+        miss_B = self.miss_block[adapter]

         if cast_to_fp32:
-            weight_miss = weight_miss.float()
-            orig_weight = orig_weight.to(weight_miss.dtype)
-
-        r = weight_miss.size(-1)
-        if re:
-            o = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
-            one = torch.eye(weight_miss.size(-1)).to(weight_miss.device)
-            # inverse must be in float32, after that the dtype can be adjusted if needed
-            inv_I_plus_b = torch.inverse(one + weight_miss)
-            inv_I_plus_b = inv_I_plus_b.to(weight_miss.dtype)
-            w = (o - weight_miss) @ inv_I_plus_b
-            output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape)
+            miss_B = miss_B.float()
+            orig_weight = orig_weight.to(miss_B.dtype)
+
+        r = miss_B.size(-1)
+        W = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
+
+        if reverse:
+            eye = torch.eye(r, device=miss_B.device, dtype=torch.float32)
+            inv_I_plus_miss_B = torch.inverse(eye + miss_B.float()).to(miss_B.dtype)
+            result = (W - miss_B) @ inv_I_plus_miss_B
         else:
-            w = (
-                orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
-                @ weight_miss
-                + weight_miss
-            )
-            output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape)
+            result = W @ miss_B + miss_B
+
+        output_tensor = result.permute(1, 2, 0, 3).reshape(*orig_weight.shape)

         if cast_to_fp32:
             output_tensor = output_tensor.to(dtype=dtype)
-
-        # cast back the weights
-        self.miss_block[adapter].data = weight_miss.to(dtype)
+        self.miss_block[adapter].data = miss_B.to(dtype)

         return output_tensor

-    def get_delta_weight_miss(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
+    def get_delta_weight_miss(self, adapter, orig_weight, reverse: bool = False) -> torch.Tensor:
         """
         Compute the delta weight for the given adapter.

         Args:
             adapter (str):
                 The name of the adapter for which the delta weight should be computed.
+            reverse (bool):
+                If True, reverse the merge (unmerge). If False, apply the merge (forward).
         """
         device = self.miss_block[adapter].device
         dtype = self.miss_block[adapter].dtype
```
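
The reverse branch undoes the bat-mode merge analytically: per block, merging produces W' = W + (W @ B + B) = W (I + B) + B, so the base weight is recovered as W = (W' - B) (I + B)^{-1}. A small self-contained round-trip check of that algebra (plain tensors, and assuming, as the unmerge path implies, that the returned delta is added to the base weight during merge):

```python
import torch

r = 4
W = torch.randn(3, 5, r, r)   # blocked base weight, as after the reshape/permute
B = 0.01 * torch.randn(r, r)  # miss_B block, small so that I + B is well conditioned

merged = W + (W @ B + B)                                    # forward: base plus delta
restored = (merged - B) @ torch.inverse(torch.eye(r) + B)   # reverse branch
assert torch.allclose(restored, W, atol=1e-5)
```
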
```diff
@@ -297,55 +294,39 @@ def get_delta_weight_miss(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
         # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
         cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

-        weight_miss = self.miss_block[adapter]
+        miss_B = self.miss_block[adapter]

         if cast_to_fp32:
-            weight_miss = weight_miss.float()
+            miss_B = miss_B.float()

         in_features = orig_weight.size(-1)
         out_features = orig_weight.size(0)
-        r = weight_miss.size(0)
+        r = miss_B.size(0)
         if self.miss_fn == "mini":
-            weight_miss = weight_miss.repeat(1, out_features // self.miss_mini_r[adapter])
+            miss_B = miss_B.repeat(1, out_features // self.miss_mini_r[adapter])
+
+        sign = -1 if reverse else 1

         if in_features % r != 0:
-            last_size = in_features % r
-            n_block = in_features // r
-            n_block_size = n_block * r
-
-            if re:
-                orig_weight[:, :n_block_size] = (
-                    (orig_weight[:, :n_block_size].reshape(-1, n_block, r).permute(1, 2, 0) - weight_miss)
-                    .permute(2, 0, 1)
-                    .reshape(*orig_weight[:, :n_block_size].shape)
-                )
-                orig_weight[:, n_block_size:] = (
-                    orig_weight[:, n_block_size:] - (weight_miss.transpose(0, 1))[:, :last_size]
-                )
-            else:
-                orig_weight[:, :n_block_size] = (
-                    (orig_weight[:, :n_block_size].reshape(-1, n_block, r).permute(1, 2, 0) + weight_miss)
-                    .permute(2, 0, 1)
-                    .reshape(*orig_weight[:, :n_block_size].shape)
-                )
-                orig_weight[:, n_block_size:] = (
-                    orig_weight[:, n_block_size:] + (weight_miss.transpose(0, 1))[:, :last_size]
-                )
-            output_tensor = orig_weight
+            remainder = in_features % r
+            n_blocks = in_features // r
+            aligned_size = n_blocks * r

+            W_aligned = orig_weight[:, :aligned_size].reshape(-1, n_blocks, r).permute(1, 2, 0)
+            orig_weight[:, :aligned_size] = (
+                (W_aligned + sign * miss_B).permute(2, 0, 1).reshape(*orig_weight[:, :aligned_size].shape)
+            )
+            orig_weight[:, aligned_size:] = (
+                orig_weight[:, aligned_size:] + sign * miss_B.transpose(0, 1)[:, :remainder]
+            )
+            output_tensor = orig_weight
         else:
-            if re:
-                w = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) - weight_miss
-                output_tensor = w.permute(2, 0, 1).reshape(*orig_weight.shape)
-            else:
-                w = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) + weight_miss
-                output_tensor = w.permute(2, 0, 1).reshape(*orig_weight.shape)
+            W_blocks = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0)
+            output_tensor = (W_blocks + sign * miss_B).permute(2, 0, 1).reshape(*orig_weight.shape)

         if cast_to_fp32:
             output_tensor = output_tensor.to(dtype=dtype)
-
-        # cast back the weights
-        self.miss_block[adapter].data = weight_miss.to(dtype)
+        self.miss_block[adapter].data = miss_B.to(dtype)

         return output_tensor
```
```diff
@@ -391,8 +372,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
         return result

     def supports_lora_conversion(self, adapter_name: str = "default") -> bool:
-        # only 'bat' can be converted in a straightforward way
-        return self.miss_fn == "bat"
+        return True

     def __repr__(self) -> str:
         rep = super().__repr__()
```
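
With the exact standard/mini factorization and the SVD path for bat, every MiSS mode is now convertible, so `supports_lora_conversion` returns True unconditionally.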
