 from typing import Any, Optional
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn
 
 from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
 
@@ -228,21 +228,23 @@ def unmerge(self) -> None:
             if active_adapter in self.miss_block.keys():
                 orig_weight = self.get_base_layer().weight.data.clone()
                 if self.miss_fn == "bat":
-                    delta_weight = self.get_delta_weight(active_adapter, orig_weight, re=True)
+                    delta_weight = self.get_delta_weight(active_adapter, orig_weight, reverse=True)
                 elif self.miss_fn == "mini":
-                    delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, re=True)
+                    delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, reverse=True)
                 else:
-                    delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, re=True)
+                    delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, reverse=True)
 
                 base_layer.weight.data = delta_weight.to(orig_dtype)
 
-    def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
+    def get_delta_weight(self, adapter, orig_weight, reverse: bool = False) -> torch.Tensor:
         """
         Compute the delta weight for the given adapter.
 
         Args:
             adapter (str):
                 The name of the adapter for which the delta weight should be computed.
+            reverse (bool):
+                If True, reverse the merge (unmerge). If False, apply the merge (forward).
         """
         device = self.miss_block[adapter].device
         dtype = self.miss_block[adapter].dtype
@@ -251,44 +253,39 @@ def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tens
         # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
         cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
 
-        weight_miss = self.miss_block[adapter]
+        miss_B = self.miss_block[adapter]
 
         if cast_to_fp32:
-            weight_miss = weight_miss.float()
-            orig_weight = orig_weight.to(weight_miss.dtype)
-
-        r = weight_miss.size(-1)
-        if re:
-            o = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
-            one = torch.eye(weight_miss.size(-1)).to(weight_miss.device)
-            # inverse must be in float32, after that the dtype can be adjusted if needed
-            inv_I_plus_b = torch.inverse(one + weight_miss)
-            inv_I_plus_b = inv_I_plus_b.to(weight_miss.dtype)
-            w = (o - weight_miss) @ inv_I_plus_b
-            output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape)
+            miss_B = miss_B.float()
+            orig_weight = orig_weight.to(miss_B.dtype)
+
+        r = miss_B.size(-1)
+        W = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
+
+        if reverse:
+            eye = torch.eye(r, device=miss_B.device, dtype=torch.float32)
+            inv_I_plus_miss_B = torch.inverse(eye + miss_B.float()).to(miss_B.dtype)
+            result = (W - miss_B) @ inv_I_plus_miss_B
         else:
-            w = (
-                orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
-                @ weight_miss
-                + weight_miss
-            )
-            output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape)
+            result = W @ miss_B + miss_B
+
+        output_tensor = result.permute(1, 2, 0, 3).reshape(*orig_weight.shape)
 
         if cast_to_fp32:
             output_tensor = output_tensor.to(dtype=dtype)
-
-            # cast back the weights
-            self.miss_block[adapter].data = weight_miss.to(dtype)
+            self.miss_block[adapter].data = miss_B.to(dtype)
 
         return output_tensor
 
-    def get_delta_weight_miss(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
+    def get_delta_weight_miss(self, adapter, orig_weight, reverse: bool = False) -> torch.Tensor:
         """
         Compute the delta weight for the given adapter.
 
         Args:
             adapter (str):
                 The name of the adapter for which the delta weight should be computed.
+            reverse (bool):
+                If True, reverse the merge (unmerge). If False, apply the merge (forward).
         """
         device = self.miss_block[adapter].device
         dtype = self.miss_block[adapter].dtype
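For reference, a minimal standalone sketch of the block-wise algebra behind the `bat` path above. With `reverse=False` the method returns the per-block delta `W_blk @ B + B`, which `merge()` (not shown in this diff) is assumed to add onto the base weight, so each merged block becomes `W_blk @ (I + B) + B`; with `reverse=True` it recovers the original blocks via `(merged_blk - B) @ (I + B)^-1`. Tensor names and shapes here are illustrative, not part of the patch:

```python
import torch

# Illustrative shapes only: a (16, 16) weight split into 4x4 blocks.
out_f, in_f, r = 16, 16, 4
W = torch.randn(out_f, in_f)
B = 0.1 * torch.randn(r, r)  # stands in for miss_block[adapter]

def to_blocks(w):
    # Same block view as get_delta_weight: (in_f // r, out_f // r, r, r)
    return w.reshape(out_f // r, r, in_f // r, r).permute(2, 0, 1, 3)

def from_blocks(blocks):
    return blocks.permute(1, 2, 0, 3).reshape(out_f, in_f)

# Forward (merge): per-block delta W_blk @ B + B, added onto the base weight,
# so each merged block is W_blk @ (I + B) + B.
merged = W + from_blocks(to_blocks(W) @ B + B)

# Reverse (unmerge): (merged_blk - B) @ (I + B)^-1 restores the original block.
inv = torch.inverse(torch.eye(r) + B)
restored = from_blocks((to_blocks(merged) - B) @ inv)

print(torch.allclose(restored, W, atol=1e-5))  # True
```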
@@ -297,55 +294,39 @@ def get_delta_weight_miss(self, adapter, orig_weight, re: bool = False) -> torch
         # (b)float16 because some CPUs have slow bf16/fp16 matmuls.
         cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
 
-        weight_miss = self.miss_block[adapter]
+        miss_B = self.miss_block[adapter]
 
         if cast_to_fp32:
-            weight_miss = weight_miss.float()
+            miss_B = miss_B.float()
 
         in_features = orig_weight.size(-1)
         out_features = orig_weight.size(0)
-        r = weight_miss.size(0)
+        r = miss_B.size(0)
         if self.miss_fn == "mini":
-            weight_miss = weight_miss.repeat(1, out_features // self.miss_mini_r[adapter])
+            miss_B = miss_B.repeat(1, out_features // self.miss_mini_r[adapter])
+
+        sign = -1 if reverse else 1
 
         if in_features % r != 0:
-            last_size = in_features % r
-            n_block = in_features // r
-            n_block_size = n_block * r
-
-            if re:
-                orig_weight[:, :n_block_size] = (
-                    (orig_weight[:, :n_block_size].reshape(-1, n_block, r).permute(1, 2, 0) - weight_miss)
-                    .permute(2, 0, 1)
-                    .reshape(*orig_weight[:, :n_block_size].shape)
-                )
-                orig_weight[:, n_block_size:] = (
-                    orig_weight[:, n_block_size:] - (weight_miss.transpose(0, 1))[:, :last_size]
-                )
-            else:
-                orig_weight[:, :n_block_size] = (
-                    (orig_weight[:, :n_block_size].reshape(-1, n_block, r).permute(1, 2, 0) + weight_miss)
-                    .permute(2, 0, 1)
-                    .reshape(*orig_weight[:, :n_block_size].shape)
-                )
-                orig_weight[:, n_block_size:] = (
-                    orig_weight[:, n_block_size:] + (weight_miss.transpose(0, 1))[:, :last_size]
-                )
-            output_tensor = orig_weight
+            remainder = in_features % r
+            n_blocks = in_features // r
+            aligned_size = n_blocks * r
 
+            W_aligned = orig_weight[:, :aligned_size].reshape(-1, n_blocks, r).permute(1, 2, 0)
+            orig_weight[:, :aligned_size] = (
+                (W_aligned + sign * miss_B).permute(2, 0, 1).reshape(*orig_weight[:, :aligned_size].shape)
+            )
+            orig_weight[:, aligned_size:] = (
+                orig_weight[:, aligned_size:] + sign * miss_B.transpose(0, 1)[:, :remainder]
+            )
+            output_tensor = orig_weight
         else:
-            if re:
-                w = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) - weight_miss
-                output_tensor = w.permute(2, 0, 1).reshape(*orig_weight.shape)
-            else:
-                w = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) + weight_miss
-                output_tensor = w.permute(2, 0, 1).reshape(*orig_weight.shape)
+            W_blocks = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0)
+            output_tensor = (W_blocks + sign * miss_B).permute(2, 0, 1).reshape(*orig_weight.shape)
 
         if cast_to_fp32:
             output_tensor = output_tensor.to(dtype=dtype)
-
-            # cast back the weights
-            self.miss_block[adapter].data = weight_miss.to(dtype)
+            self.miss_block[adapter].data = miss_B.to(dtype)
 
         return output_tensor
 
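Similarly, a small sketch (illustrative names, not part of the patch) of what the `sign` trick in `get_delta_weight_miss` computes once the old `if re:` / `else:` branches are collapsed: the weight is viewed as blocks of `r` input columns, `miss_B` is added (`sign=1`, merge) or subtracted (`sign=-1`, unmerge) per block, and when `in_features % r != 0` the leftover columns get the first `remainder` columns of `miss_B` transposed:

```python
import torch

# Illustrative shapes where in_features is not a multiple of r (remainder = 2).
out_f, in_f, r = 6, 10, 4
W = torch.randn(out_f, in_f)
B = torch.randn(r, out_f)  # stands in for miss_block[adapter] on the "miss" path

def apply_miss(weight, miss_B, r, sign):
    weight = weight.clone()
    n_blocks, remainder = weight.size(1) // r, weight.size(1) % r
    aligned = n_blocks * r
    # (out, aligned) -> (n_blocks, r, out), add/subtract miss_B per block, then undo the view.
    blocks = weight[:, :aligned].reshape(-1, n_blocks, r).permute(1, 2, 0)
    weight[:, :aligned] = (blocks + sign * miss_B).permute(2, 0, 1).reshape(weight.size(0), aligned)
    if remainder:
        # Leftover columns get the first `remainder` columns of miss_B transposed.
        weight[:, aligned:] += sign * miss_B.transpose(0, 1)[:, :remainder]
    return weight

merged = apply_miss(W, B, r, sign=1)          # forward / merge
restored = apply_miss(merged, B, r, sign=-1)  # reverse / unmerge
print(torch.allclose(restored, W))  # True
```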
@@ -391,8 +372,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
         return result
 
     def supports_lora_conversion(self, adapter_name: str = "default") -> bool:
-        # only 'bat' can be converted in a straightforward way
-        return self.miss_fn == "bat"
+        return True
 
     def __repr__(self) -> str:
         rep = super().__repr__()