55
66from pynumdiff .utils import utility
77
8- #maria spectral diff below
9-
10- def spectraldiff (x , dt , axis = 0 , params = None , options = None , high_freq_cutoff = None ,
11- even_extension = True , pad_to_zero_dxdt = True ):
8+ def spectraldiff (x , dt , params = None , options = None , high_freq_cutoff = None ,
9+ even_extension = True , pad_to_zero_dxdt = True , axis = 0 ):
1210 """Take a derivative in the Fourier domain, with high frequency attentuation.
1311
1412 :param np.array[float] x: data to differentiate
@@ -26,7 +24,7 @@ def spectraldiff(x, dt, axis=0, params=None, options=None, high_freq_cutoff=None
2624 :return: - **x_hat** (np.array) -- estimated (smoothed) x
2725 - **dxdt_hat** (np.array) -- estimated derivative of x
2826 """
29- if params is not None :
27+ if params is not None : # Warning to support old interface for a while. Remove these lines along with params in a future release.
3028 warn ("`params` and `options` parameters will be removed in a future version. Use `high_freq_cutoff`, " +
3129 "`even_extension`, and `pad_to_zero_dxdt` instead." , DeprecationWarning )
3230 high_freq_cutoff = params [0 ] if isinstance (params , list ) else params
@@ -38,52 +36,50 @@ def spectraldiff(x, dt, axis=0, params=None, options=None, high_freq_cutoff=None
3836
3937 x = np .asarray (x )
4038 x0 = np .moveaxis (x , axis , 0 ) # move time axis to the front of the array
41- # now x0 dims are (# of data points, # of signals)
39+ # Now x0 dims are (number of data points, number of signals)
4240 L = x0 .shape [0 ]
4341
44- # make derivative go to zero at ends (optional)
42+ # Make derivative go to zero at the ends (optional):
4543 if pad_to_zero_dxdt :
4644 padding = 100
45+ pre = x [0 ] * np .ones (padding )
46+ post = x [- 1 ] * np .ones (padding )
47+ x = np .hstack ((pre , x , post )) # extend the edges
4748
48- # just pad first and last values x100
49- first = x0 [0 :1 ]
50- last = x0 [- 1 :]
49+ # Pad first and last values x100
50+ first = x0 [0 :1 ]
51+ last = x0 [- 1 :]
5152 pre = np .repeat (first , padding , axis = 0 )
5253 post = np .repeat (last , padding , axis = 0 )
5354
54- xpad = np .concatenate ((pre , x0 , post ), axis = 0 ) # i think hstack won't work with the correct axis
55-
56- kernel = utility .mean_kernel (padding // 2 )
57- x_hat0 = utility .convolutional_smoother (xpad , kernel , axis = 0 )
58-
59- x_hat0 [padding :- padding ] = xpad [padding :- padding ]
60- x0 = x_hat0
55+ xpad = np .concatenate ((pre , x0 , post ), axis = 0 ) # concatenate along axis 0
6156 else :
6257 padding = 0
6358
64- # Do even extension (optional):
59+ # Do even extension (optional)
6560 if even_extension is True :
6661 x0 = np .concatenate ((x0 , x0 [::- 1 , ...]), axis = 0 )
6762
6863 # Form wavenumbers
6964 N = x0 .shape [0 ]
7065 k = np .concatenate ((np .arange (N // 2 + 1 ), np .arange (- N // 2 + 1 , 0 )))
71- if N % 2 == 0 : k [N // 2 ] = 0 # odd derivatives get the Nyquist element zeroed out
66+ if N % 2 == 0 : k [N // 2 ] = 0 # odd derivatives get the Nyquist element zeroed out
7267
7368 # Filter to zero out higher wavenumbers
7469 discrete_cutoff = int (high_freq_cutoff * N / 2 ) # Nyquist is at N/2 location, and we're cutting off as a fraction of that
75- filt = np .ones_like (k , dtype = float )
76- filt = np .ones (k .shape ); filt [discrete_cutoff :N - discrete_cutoff ] = 0
70+
71+ filt = np .ones (k .shape ) # start with all frequencies passing
72+ filt [discrete_cutoff :- discrete_cutoff ] = 0 # zero out high-frequency components
7773 filt = filt .reshape ((N ,) + (1 ,)* (x0 .ndim - 1 ))
7874
79- # Smoothed signal
75+ # Smoothed signal
8076 X = np .fft .fft (x0 , axis = 0 )
8177
8278 x_hat0 = np .real (np .fft .ifft (filt * X , axis = 0 ))
8379 x_hat0 = x_hat0 [padding :L + padding ]
8480
8581 # Derivative = 90 deg phase shift
86- omega = 2 * np .pi / (dt * N )
82+ omega = 2 * np .pi / (dt * N ) # factor of 2pi/T turns wavenumbers into frequencies in radians/s
8783 k0 = k .reshape ((N ,) + (1 ,)* (x0 .ndim - 1 ))
8884 dxdt0 = np .real (np .fft .ifft (1j * k0 * omega * filt * X , axis = 0 ))
8985 dxdt0 = dxdt0 [padding :L + padding ]
0 commit comments