11import numpy as np
22
33
4- def calc_truth_fft (sky_dist ):
5- truth_fft = np .fft .fftshift (
6- np .fft .fft2 (np .fft .fftshift (sky_dist , axes = (2 , 3 )), axes = (2 , 3 )),
7- axes = (2 , 3 ),
4+ def calc_truth_fft (image ):
5+ """Calculates the Fourier transform (image space -> uv-space)
6+ for a single image or a batch of images.
7+
8+ This is shape independent as long as the last two axes are
9+ height and width, i.e. ``(..., H, W)``.
10+
11+ Parameters
12+ ----------
13+ image : array_like, shape (..., H, W)
14+ (True) sky distribution.
15+
16+ Returns
17+ -------
18+ fft_image : array_like, shape (..., H, W)
19+ Complex type array of the fft of the input image.
20+ """
21+ fft_image = np .fft .fftshift (
22+ np .fft .fft2 (np .fft .fftshift (image , axes = (- 2 , - 1 )), axes = (- 2 , - 1 )),
23+ axes = (- 2 , - 1 ),
824 )
9- return truth_fft
25+ return fft_image
1026
1127
1228def convert_amp_phase (data , sky_sim = False ):
1329 if sky_sim :
1430 amp = np .abs (data )
1531 phase = np .angle (data )
16- data = np .concatenate ((amp , phase ), axis = 1 )
32+ data = np .concatenate ((amp , phase ), axis = - 3 )
1733 else :
1834 test = data [:, 0 ] + 1j * data [:, 1 ]
1935 amp = np .abs (test )
2036 phase = np .angle (test )
21- data = np .stack ((amp , phase ), axis = 1 )
37+ data = np .stack ((amp , phase ), axis = - 3 )
2238
2339 return data
2440
@@ -28,11 +44,11 @@ def convert_real_imag(data, sky_sim=False):
2844 real = data .real
2945 imag = data .imag
3046
31- data = np .concatenate ((real , imag ), axis = 1 )
47+ data = np .concatenate ((real , imag ), axis = - 3 )
3248 else :
3349 real = data [:, 0 ]
3450 imag = data [:, 1 ]
3551
36- data = np .stack ((real , imag ), axis = 1 )
52+ data = np .stack ((real , imag ), axis = - 3 )
3753
3854 return data
0 commit comments