@@ -63,6 +63,8 @@ def __init__(
         rslora: bool = False,
         lora_plus_scale: float = 1.0,
         pissa: bool = False,
+        nola: bool = False,
+        nola_basis_num: int = 1,
         lora_use_mixer: bool = False,
         mixer_num: int = 1,
         use_mora: bool = False,
@@ -85,6 +87,8 @@ def __init__(
         # Mark the weight as unmerged
         self.merged = False
         self.pissa = pissa
+        self.nola = nola
+        self.nola_basis_num = nola_basis_num
         self.lora_use_mixer = lora_use_mixer
         self.mixer_num = mixer_num
         self.lorapro = lorapro
@@ -144,6 +148,32 @@ def __init__(
                 ),
             )
         self.apply_pissa = False
+        if nola:
+            # NOLA parameters: frozen random bases and trainable mixing coefficients.
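+            # NOLA reconstructs the LoRA factors as A = sum_k alpha_k * A_k and
+            # B = sum_k beta_k * B_k; the bases stay fixed after random
+            # initialization and only the scalar coefficients receive gradients.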
+            self.nola_basis_A = self.create_parameter(
+                shape=[nola_basis_num, in_features, r],
+                dtype=self._dtype,
+                is_bias=False,
+            )
+            self.nola_basis_A.stop_gradient = True
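+            # Basis matrices for the B factor, frozen like nola_basis_A.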
+            self.nola_basis_B = self.create_parameter(
+                shape=[nola_basis_num, r, out_features],
+                dtype=self._dtype,
+                is_bias=False,
+            )
+            self.nola_basis_B.stop_gradient = True
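+            # Zero-initialized trainable coefficients, so the initial delta
+            # weight is zero (matching standard LoRA initialization).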
+            self.nola_alpha = self.create_parameter(
+                shape=[nola_basis_num],
+                dtype=self._dtype,
+                is_bias=False,
+                default_initializer=nn.initializer.Constant(value=0.0),
+            )
+            self.nola_beta = self.create_parameter(
+                shape=[nola_basis_num],
+                dtype=self._dtype,
+                is_bias=False,
+                default_initializer=nn.initializer.Constant(value=0.0),
+            )
         if use_mora or pissa:
             self.scaling = 1.0
         elif not rslora:
@@ -179,6 +209,16 @@ def pissa_init(self, rank):
         weight = res.astype(dtype)
         self.weight.set_value(weight)

+    def get_nola_lora_matrices(self):
+        """Compute the LoRA matrices A and B from the NOLA bases and coefficients."""
+        if not self.nola:
+            return self.lora_A, self.lora_B
+        # A = sum_k alpha_k * A_k -> [in_features, r]
+        lora_A = paddle.einsum("k,kir->ir", self.nola_alpha, self.nola_basis_A)
+        # B = sum_k beta_k * B_k -> [r, out_features]
+        lora_B = paddle.einsum("k,kro->ro", self.nola_beta, self.nola_basis_B)
+        return lora_A, lora_B
+
     def rope_init(self):
         if self.cos is None or self.sin is None:
             inv_freq = 1.0 / (10000 ** (paddle.arange(0, self.r, 2, dtype=paddle.float32) / self.r))
@@ -257,6 +297,9 @@ def get_delta_weight(self, lora_A=None, lora_B=None, lora_AB=None):
             w = w[: self.out_features]
             final_weight = w
             delta_weight = final_weight.T
+        elif self.nola:
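+            # Rebuild A and B from the frozen bases, then take the usual LoRA product.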
+            lora_A, lora_B = self.get_nola_lora_matrices()
+            delta_weight = (lora_A @ lora_B) * self.scaling
         else:
             lora_A = lora_A if lora_A is not None else self.lora_A
             lora_B = lora_B if lora_B is not None else self.lora_B
@@ -299,6 +342,11 @@ def forward(self, input: paddle.Tensor, *args, **kwargs):
             input = self.lora_dropout(input)
             mora_out = self._apply_mora(input)
             result += mora_out
+        elif self.nola:
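+            # Base linear output plus the NOLA low-rank update.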
+            result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
+            input = self.lora_dropout(input)
+            lora_A, lora_B = self.get_nola_lora_matrices()
+            result += (input @ lora_A @ lora_B) * self.scaling
         else:
             result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
             if self.lora_use_mixer:
@@ -327,14 +375,16 @@ def __init__(
         use_quick_lora: bool = False,
         pissa: bool = False,
         use_mora: bool = False,
+        nola: bool = False,
+        nola_basis_num: int = 1,
         **kwargs
     ):
         RowParallelLinear.__init__(self, in_features, out_features, **kwargs)
         if not isinstance(r, int) or r <= 0:
             raise ValueError("Lora rank r should be a positive integer")

-        if pissa or use_mora:
-            raise ValueError("Pissa or Mora is not supported in model parallel by now")
+        if pissa or use_mora or nola:
+            raise ValueError("Pissa, Mora or NOLA is not supported in model parallel for now")

         self.r = r
         self.lora_alpha = lora_alpha
@@ -593,14 +643,16 @@ def __init__(
         use_quick_lora: bool = False,
         pissa: bool = False,
         use_mora: bool = False,
+        nola: bool = False,
+        nola_basis_num: int = 1,
         **kwargs
     ):
         ColumnParallelLinear.__init__(self, in_features, out_features, **kwargs)
         if not isinstance(r, int) or r <= 0:
             raise ValueError("Lora rank r should be a positive integer")

-        if pissa or use_mora:
-            raise ValueError("Pissa or Mora is not supported in model parallel by now")
+        if pissa or use_mora or nola:
+            raise ValueError("Pissa, Mora or NOLA is not supported in model parallel for now")

         self.r = r
         self.lora_alpha = lora_alpha
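
For reference, a minimal standalone sketch of the parameterization this patch adds: the LoRA factors are linear combinations of frozen random bases, and only the mixing coefficients are trainable. Shapes and einsum equations mirror the diff; the snippet is illustrative, not part of the patch.

```python
import paddle

in_features, out_features, r, k = 16, 8, 4, 2  # k plays the role of nola_basis_num

# Frozen random bases (stop_gradient=True in the patched layer)
basis_A = paddle.randn([k, in_features, r])
basis_B = paddle.randn([k, r, out_features])

# Trainable coefficients, zero-initialized as in the patch
alpha = paddle.zeros([k])
beta = paddle.zeros([k])

# Reconstruct the LoRA factors exactly as get_nola_lora_matrices does
lora_A = paddle.einsum("k,kir->ir", alpha, basis_A)  # [in_features, r]
lora_B = paddle.einsum("k,kro->ro", beta, basis_B)   # [r, out_features]

delta_weight = lora_A @ lora_B  # all zeros at init, so the base model starts unchanged
```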