Skip to content

Commit 9a20e07

Browse files
FIX PVeRA forward implementation for bitsandbytes (#3189)
The forward pass of PVeRA requires a sampling step, which was missing from the bnb layers. This PR adds the sampling step to those layers, which fixes failing tests in the nightly CI.
1 parent 852019c commit 9a20e07

2 files changed

Lines changed: 18 additions & 15 deletions

File tree

src/peft/tuners/pvera/bnb.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import bitsandbytes as bnb
2020
import torch
21+
import torch.nn.functional as F
2122

2223
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
2324
from peft.tuners.tuners_utils import check_adapters_to_merge
@@ -44,6 +45,7 @@ def __init__(
4445
super().__init__()
4546
PveraLayer.__init__(self, base_layer)
4647
self.fan_in_fan_out = config.fan_in_fan_out
48+
self.sample_at_inference = config.sample_at_inference
4749

4850
self._active_adapter = adapter_name
4951
self.update_layer(
@@ -223,9 +225,9 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
223225
sliced_B = pvera_B[: self.out_features, :].to(x.device)
224226

225227
x_temp = dropout(x.to(lambda_d.dtype))
226-
227-
adapter_output = lambda_b * torch.nn.functional.linear(
228-
lambda_d * torch.nn.functional.linear(x_temp, sliced_A), sliced_B
228+
mu, logvar = (lambda_d * F.linear(x_temp, sliced_A)).chunk(2, dim=-1)
229+
adapter_output = lambda_b * F.linear(
230+
self._reparametrize(mu, logvar, self.sample_at_inference), sliced_B
229231
)
230232

231233
if requires_conversion:
@@ -257,6 +259,7 @@ def __init__(
257259
super().__init__()
258260
PveraLayer.__init__(self, base_layer)
259261
self.fan_in_fan_out = config.fan_in_fan_out
262+
self.sample_at_inference = config.sample_at_inference
260263

261264
self._active_adapter = adapter_name
262265
self.update_layer(
@@ -392,9 +395,9 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
392395
sliced_B = pvera_B[: self.out_features, :].to(x.device)
393396

394397
x_temp = dropout(x.to(lambda_d.dtype))
395-
396-
adapter_output = lambda_b * torch.nn.functional.linear(
397-
lambda_d * torch.nn.functional.linear(x_temp, sliced_A), sliced_B
398+
mu, logvar = (lambda_d * F.linear(x_temp, sliced_A)).chunk(2, dim=-1)
399+
adapter_output = lambda_b * F.linear(
400+
self._reparametrize(mu, logvar, self.sample_at_inference), sliced_B
398401
)
399402

400403
if requires_conversion:

src/peft/tuners/pvera/layer.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,15 @@ def reset_pvera_parameters(self, adapter_name, d_initial: float = 0.1):
143143
nn.init.zeros_(self.pvera_lambda_d[adapter_name]).fill_(d_initial)
144144
nn.init.zeros_(self.pvera_lambda_b[adapter_name])
145145

146+
def _reparametrize(self, mu, logvar, sample_at_inference):
147+
if self.training or (not self.training and sample_at_inference):
148+
std = torch.exp(0.5 * logvar)
149+
eps = torch.randn_like(std)
150+
z = mu + eps * std
151+
else:
152+
z = mu
153+
return z
154+
146155

147156
class Linear(nn.Linear, PveraLayer):
148157
# PVeRA implemented in a dense layer
@@ -259,15 +268,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
259268

260269
return output_tensor
261270

262-
def _reparametrize(self, mu, logvar, sample_at_inference):
263-
if self.training or (not self.training and sample_at_inference):
264-
std = torch.exp(0.5 * logvar)
265-
eps = torch.randn_like(std)
266-
z = mu + eps * std
267-
else:
268-
z = mu
269-
return z
270-
271271
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
272272
previous_dtype = x.dtype
273273

0 commit comments

Comments
 (0)