
Fix OFT Conv2d weight reshaping for non-square kernels#3150

Open
Chessing234 wants to merge 2 commits into huggingface:main from Chessing234:fix/oft-conv2d-nonsquare-kernel

Conversation

@Chessing234
Contributor

Summary

The OFT Conv2d layer computes the filter dimension and reshapes weights using kernel_size[0] * kernel_size[0], squaring the kernel height instead of computing height × width (kernel_size[0] * kernel_size[1]). Similarly, the 4D reshape uses kernel_size[0], kernel_size[0] instead of kernel_size[0], kernel_size[1].

This works by accident for square kernels (3×3, 5×5, etc.) but produces wrong dimensions for any asymmetric kernel (e.g., 3×1, 1×7, 3×5), causing RuntimeError: shape mismatch during forward/merge/unmerge.

Fix: Replace all 5 occurrences of the second kernel_size[0] with kernel_size[1] in oft/layer.py.

Note: The same pattern exists in boft/layer.py and hra/layer.py. Happy to extend this fix to those files if desired.
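A minimal, self-contained sketch of the arithmetic (plain PyTorch, not PEFT's actual layer code): for a Conv2d weight of shape `(out_channels, in_channels, kh, kw)`, the flattened per-filter dimension is `in_channels * kh * kw`, so reshaping with `kernel_size[0] * kernel_size[0]` only happens to work when `kh == kw`.

```python
import torch.nn as nn

# Non-square kernel, as in the bug report: kernel_size = (3, 1).
conv = nn.Conv2d(3, 16, kernel_size=(3, 1))
kh, kw = conv.kernel_size

buggy_fan_in = conv.in_channels * kh * kh    # 3 * 3 * 3 = 27 (height squared)
correct_fan_in = conv.in_channels * kh * kw  # 3 * 3 * 1 = 9  (height * width)

# The correct reshape succeeds: weight has 16 * 3 * 3 * 1 = 144 elements.
flat = conv.weight.view(conv.out_channels, correct_fan_in)

# The buggy reshape target (16 * 27 = 432 elements) cannot match and raises.
raised = False
try:
    conv.weight.view(conv.out_channels, buggy_fan_in)
except RuntimeError:
    raised = True  # "shape ... is invalid for input of size 144"
```

For a square kernel `kh == kw`, `buggy_fan_in == correct_fan_in`, which is why existing tests pass on main.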

Test plan

  • Existing OFT tests with square kernels should continue to pass
  • OFT with a non-square Conv2d kernel (e.g., nn.Conv2d(3, 16, (3, 1))) should now work

🤖 Generated with Claude Code

All weight reshape operations in the OFT Conv2d layer use
kernel_size[0] * kernel_size[0], squaring the height dimension
instead of computing height * width. This gives wrong filter
dimensions for non-square kernels (e.g. 3x1, 1x7), causing shape
mismatches in forward/merge/unmerge. It works by accident for square
kernels.

Note: the same pattern exists in boft/layer.py and hra/layer.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@BenjaminBossan
Member

@Chessing234 Thanks for the PR. Could you please extend the unit tests to cover this case? I.e. add a test case that fails with the current main branch but is fixed with your branch? It should be fine to add a one-off test for this in test_custom_models.py inside the TestPeftCustomModel class.

The same pattern exists in boft/layer.py and hra/layer.py. Happy to extend this fix to those files if desired.

This would be really appreciated, better to have that in one PR than multiple smaller ones. The unit test can be parametrized to cover these PEFT methods too.

@Chessing234
Contributor Author

Added a standalone test test_oft_conv2d_non_square_kernel in TestPeftCustomModel: it builds a Conv2d with kernel (3, 5), wraps it in OFT, runs a forward pass, then calls merge_and_unload and checks the merged output matches. Fails on main (shape mismatch from the kernel_size[0] ** 2 miscalculation); passes with this PR.
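The shape arithmetic the test exercises can be sketched without PEFT (the real test uses `OFTConfig`/`get_peft_model` and `merge_and_unload`; the model shapes below mirror the description but the `in_channels`/`out_channels` values are assumptions):

```python
import torch
import torch.nn as nn

# Conv2d with a non-square (3, 5) kernel and the input from the test.
conv = nn.Conv2d(5, 8, kernel_size=(3, 5))
X = torch.arange(5 * 5 * 3 * 5, dtype=torch.float).reshape(5, 5, 3, 5)

# Forward pass works: H=3 fits kh=3, W=5 fits kw=5 -> output (5, 8, 1, 1).
out = conv(X)

# On main, OFT flattens the weight with in * kh * kh = 5 * 3 * 3 = 45
# instead of in * kh * kw = 5 * 3 * 5 = 75, which cannot match the
# 8 * 5 * 3 * 5 = 600-element weight and raises the reported error.
raised = False
try:
    conv.weight.view(8, 5 * 3 * 3)
except RuntimeError:
    raised = True
```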

Member

@BenjaminBossan left a comment


Thanks for adding the test, but it's in the wrong place. Please check the comment. You can ensure that the test passes by running:

pytest tests/test_custom_models.py -k test_oft_conv2d_non_square_kernel

Also, as mentioned above, it would be much better to have the BOFT and HRA fixes in the same PR, so please add those as well.

"base_model.model.lin0.oft_R.adapter1.weight",
)

def test_oft_conv2d_non_square_kernel(self):
Member


This test is in the wrong place. Please move it to the end of TestPeftCustomModel.

peft_model = get_peft_model(model, config)

X = torch.arange(5 * 5 * 3 * 5, dtype=torch.float, device=self.torch_device).reshape(5, 5, 3, 5)
output = peft_model(X)
Member


Let's add a comment: "# Ensure that the forward pass does not raise". This is what we actually want to test. Equality of non-merged vs merged is nice to test too, but not the critical issue.
