Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/peft/tuners/oft/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ def update_layer(
if base_layer.dilation[0] > 1:
raise ValueError("Conv2d with dilation > 1 is not supported by OFT.")

conv_filter_dim = self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
conv_filter_dim = self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[1]

if r == 0 and oft_block_size != 0:
if conv_filter_dim % oft_block_size != 0 or oft_block_size > conv_filter_dim:
Expand Down Expand Up @@ -847,13 +847,13 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
oft_mat = self.get_delta_weight(active_adapter)

orig_weights = orig_weights.view(
self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[1]
)
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = torch.mm(oft_mat, orig_weights.to(oft_mat.dtype))
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = orig_weights.view(
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[1]
)

base_layer.weight.data = orig_weights.contiguous().to(orig_dtype)
Expand All @@ -862,13 +862,13 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N

orig_weights = base_layer.weight.data.clone()
orig_weights = orig_weights.view(
self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[1]
)
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = torch.mm(oft_mat, orig_weights.to(oft_mat.dtype))
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = orig_weights.view(
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[1]
)

base_layer.weight.data = orig_weights.contiguous().to(orig_dtype)
Expand Down
25 changes: 25 additions & 0 deletions tests/test_custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5019,6 +5019,31 @@ def test_requires_grad_oft_same_targets(self):
"base_model.model.lin0.oft_R.adapter1.weight",
)

def test_oft_conv2d_non_square_kernel(self):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is in the wrong place. Please move it to the end of TestPeftCustomModel.

# OFT on a Conv2d with non-square kernel used to crash because the adapter
# filter dimension was computed as kernel_size[0] ** 2 instead of
# kernel_size[0] * kernel_size[1]; the forward pass then hit a shape
# mismatch and `merge_and_unload` failed to reshape the weights.
# Minimal module wrapping a single Conv2d whose kernel is rectangular (3 x 5);
# it serves as the base model for the non-square-kernel OFT test.
class NonSquareKernelConv2D(nn.Module):
def __init__(self):
super().__init__()
# 5 input channels, 10 output channels, kernel height 3 and kernel width 5.
# The attribute name "conv2d" is the one the test's OFTConfig targets.
self.conv2d = nn.Conv2d(5, 10, kernel_size=(3, 5))

def forward(self, X):
# Cast to float and view the input as (batch, 5, 3, 5): 5 channels with a
# 3 x 5 spatial extent, i.e. the same height/width as the kernel.
X = X.float().reshape(-1, 5, 3, 5)
return self.conv2d(X)

# Seed the RNG before the model is built so the run is reproducible.
torch.manual_seed(0)
model = NonSquareKernelConv2D().to(self.torch_device)
# Attach an OFT adapter to the "conv2d" submodule (the 3 x 5 kernel Conv2d).
# NOTE(review): init_weights=False presumably yields a non-identity adapter so
# that the merged/unmerged comparison is meaningful -- confirm against OFTConfig.
config = OFTConfig(r=5, oft_block_size=0, target_modules=["conv2d"], init_weights=False)
peft_model = get_peft_model(model, config)

# Deterministic batch of 5 inputs, each 5 channels x 3 x 5, matching the
# shape the module's forward() reshapes to.
X = torch.arange(5 * 5 * 3 * 5, dtype=torch.float, device=self.torch_device).reshape(5, 5, 3, 5)
# Ensure that the forward pass does not raise
output = peft_model(X)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a comment: "# Ensure that the forward pass does not raise". This is what we actually want to test. Equality of non-merged vs merged is nice to test too, but not the critical issue.

# Secondary check: merging the adapter into the base weights must reshape the
# non-square kernel correctly, so the merged model reproduces the unmerged
# output within tolerance.
merged = peft_model.merge_and_unload()
output_merged = merged(X)
assert torch.allclose(output, output_merged, atol=1e-5, rtol=1e-5)

def test_requires_grad_hra_different_targets(self):
# test two different HRA adapters that target different modules
config0 = HRAConfig(target_modules=["lin0"])
Expand Down