Skip to content

Commit f2472b1

Browse files
Improve DIM performance using padding in FFT (#92)
1 parent 52d9eef commit f2472b1

3 files changed

Lines changed: 36 additions & 5 deletions

File tree

src/torchoptics/functional/functional.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def calculate_std(intensity: Tensor, meshgrid: tuple[Tensor, Tensor]) -> Tensor:
4242
)
4343

4444

45-
def conv2d_fft(input: Tensor, weight: Tensor) -> Tensor:
45+
def conv2d_fft(input: Tensor, weight: Tensor, fft_padding: int = 0) -> Tensor:
    """Perform a 2D convolution using Fast Fourier Transforms (FFT).

    Unlike :func:`torch.nn.functional.conv2d`, which performs cross-correlation, this function
    computes a true convolution. For a single-channel input and filter the result equals
    ``torch.nn.functional.conv2d(input, weight.flip(-1, -2))``. Only the region where the kernel
    fully overlaps the input is returned.

    Args:
        input (torch.Tensor): Input tensor to be convolved of shape :math:`(..., iH, iW)`.
        weight (torch.Tensor): Filters of shape :math:`(..., kH, kW)`.
        fft_padding (int): Number of extra zeros appended to the input in each spatial dimension
            before the FFT. Does not affect the output size, but can improve FFT performance
            when the padded size has favorable prime factors. Must be non-negative.
            Default: ``0``.

    Returns:
        torch.Tensor: Convolved output tensor of shape :math:`(..., oH, oW)`, where
        :math:`oH = iH - kH + 1` and :math:`oW = iW - kW + 1`.

    Raises:
        ValueError: If ``fft_padding`` is negative.
    """
    # A negative value would make ``fft2(..., s=...)`` trim the input (and possibly the kernel)
    # instead of zero-padding it, silently corrupting the result.
    if fft_padding < 0:
        raise ValueError(f"Expected fft_padding to be non-negative, but got {fft_padding}.")

    in_height, in_width = input.size(-2), input.size(-1)
    output_size = (in_height - weight.size(-2) + 1, in_width - weight.size(-1) + 1)
    fft_size = (in_height + fft_padding, in_width + fft_padding)

    # ``s=fft_size`` zero-pads both operands to a common transform size.
    input_fr = fft2(input, s=fft_size)
    weight_fr = fft2(weight.flip(-1, -2).conj(), s=fft_size)

    # Multiplying by the conjugate spectrum correlates the input with the flipped kernel,
    # which is a convolution with the original kernel.
    output_fr = input_fr * weight_fr.conj()

    # The leading (oH, oW) samples never touch the circular wrap-around region (nor the
    # appended zeros), so they are exactly the linear "valid" convolution.
    return ifft2(output_fr)[..., : output_size[0], : output_size[1]]
6468

src/torchoptics/propagation/direct_integration_method.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ def dim_propagation(field: Field, propagation_plane: PlanarGrid, propagation_met
2828
"""
2929
x, y = calculate_meshgrid(field, propagation_plane)
3030
impulse_response = calculate_impulse_response(field, propagation_plane, x, y, propagation_method)
31-
propagated_data = conv2d_fft(impulse_response, field.data)
31+
# fft_padding=1: the impulse response spans (N+M-1) samples per axis, so one extra zero makes the FFT length (N+M), which can improve performance
32+
propagated_data = conv2d_fft(impulse_response, field.data, fft_padding=1)
3233
return field.copy(data=propagated_data, z=propagation_plane.z, offset=propagation_plane.offset)
3334

3435

tests/functional/test_conv2d.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,29 @@ def test_conv2d_fft():
99
conv2d_output = torch.nn.functional.conv2d(input, weight.flip(-1, -2))
1010
conv2d_fft_output = conv2d_fft(input, weight)
1111
assert torch.allclose(conv2d_output, conv2d_fft_output, atol=1e-5)
12+
13+
14+
def test_conv2d_fft_large_kernel():
    """Compare conv2d_fft with conv2d for a 17x33 kernel on a 64x96 complex128 input."""
    signal = torch.randn(1, 1, 64, 96, dtype=torch.complex128)
    kernel = torch.randn(1, 1, 17, 33, dtype=torch.complex128)

    # conv2d is given the flipped kernel to produce the reference result.
    reference = torch.nn.functional.conv2d(signal, kernel.flip(-1, -2))
    result = conv2d_fft(signal, kernel)

    assert torch.allclose(reference, result, atol=1e-8)
20+
21+
22+
def test_conv2d_fft_with_padding_matches_conv2d():
    """Verify that fft_padding leaves the numerical result unchanged relative to conv2d."""
    signal = torch.randn(2, 1, 30, 45, dtype=torch.complex64)
    kernel = torch.randn(1, 1, 5, 7, dtype=torch.complex64)
    reference = torch.nn.functional.conv2d(signal, kernel.flip(-1, -2))

    # Evaluate with no padding, a small padding, and a larger padding before asserting.
    outputs = [conv2d_fft(signal, kernel, fft_padding=pad) for pad in (0, 8, 32)]

    for output in outputs:
        assert torch.allclose(reference, output, atol=1e-5)

0 commit comments

Comments
 (0)