@@ -9,3 +9,29 @@ def test_conv2d_fft():
99 conv2d_output = torch .nn .functional .conv2d (input , weight .flip (- 1 , - 2 ))
1010 conv2d_fft_output = conv2d_fft (input , weight )
1111 assert torch .allclose (conv2d_output , conv2d_fft_output , atol = 1e-5 )
12+
13+
14+ def test_conv2d_fft_large_kernel ():
15+ input = torch .randn (1 , 1 , 64 , 96 , dtype = torch .complex128 )
16+ weight = torch .randn (1 , 1 , 17 , 33 , dtype = torch .complex128 )
17+ conv2d_output = torch .nn .functional .conv2d (input , weight .flip (- 1 , - 2 ))
18+ conv2d_fft_output = conv2d_fft (input , weight )
19+ assert torch .allclose (conv2d_output , conv2d_fft_output , atol = 1e-8 )
20+
21+
22+ def test_conv2d_fft_with_padding_matches_conv2d ():
23+ # Ensure fft_padding does not change numerical result compared to conv2d
24+ input = torch .randn (2 , 1 , 30 , 45 , dtype = torch .complex64 )
25+ weight = torch .randn (1 , 1 , 5 , 7 , dtype = torch .complex64 )
26+ expected = torch .nn .functional .conv2d (input , weight .flip (- 1 , - 2 ))
27+
28+ # no padding
29+ out0 = conv2d_fft (input , weight , fft_padding = 0 )
30+ # small padding
31+ out8 = conv2d_fft (input , weight , fft_padding = 8 )
32+ # larger padding
33+ out32 = conv2d_fft (input , weight , fft_padding = 32 )
34+
35+ assert torch .allclose (expected , out0 , atol = 1e-5 )
36+ assert torch .allclose (expected , out8 , atol = 1e-5 )
37+ assert torch .allclose (expected , out32 , atol = 1e-5 )
0 commit comments