55
66from pynumdiff .utils import utility
77
8-
9- def spectraldiff ( x , dt , params = None , options = None , high_freq_cutoff = None , even_extension = True , pad_to_zero_dxdt = True ):
8+ def spectraldiff ( x , dt , params = None , options = None , high_freq_cutoff = None , even_extension = True ,
9+ pad_to_zero_dxdt = True , axis = 0 ):
1010 """Take a derivative in the Fourier domain, with high frequency attentuation.
1111
1212 :param np.array[float] x: data to differentiate
@@ -33,45 +33,48 @@ def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_e
3333 elif high_freq_cutoff is None :
3434 raise ValueError ("`high_freq_cutoff` must be given." )
3535
36- L = len (x )
36+ L = x .shape [axis ]
37+ x = np .asarray (x )
3738
38- # make derivative go to zero at ends (optional)
39+ # Make derivative go to zero at the ends (optional)
3940 if pad_to_zero_dxdt :
4041 padding = 100
41- pre = getattr ( x , 'values' , x ) [0 ]* np . ones ( padding ) # getattr to use .values if x is a pandas Series
42- post = getattr ( x , 'values' , x ) [- 1 ]* np . ones ( padding )
43- x = np .hstack ((pre , x , post )) # extend the edges
42+ pre = np . repeat ( np . take ( x , [0 ], axis = axis ), padding , axis = axis ) # take keeps dimensions, unlike x[0]
43+ post = np . repeat ( np . take ( x , [- 1 ], axis = axis ), padding , axis = axis )
44+ x = np .concatenate ((pre , x , post ), axis = axis ) # extend the edges
4445 kernel = utility .mean_kernel (padding // 2 )
45- x_hat = utility .convolutional_smoother (x , kernel ) # smooth the edges in
46- x_hat [padding :- padding ] = x [padding :- padding ] # replace middle with original signal
47- x = x_hat
46+ x_smoothed = utility .convolutional_smoother (x , kernel , axis = axis ) # smooth the padded edges in
47+ center = (slice (None ),)* axis + (slice (padding , L + padding ),) + (slice (None ),)* (x .ndim - axis - 1 )
48+ x_smoothed [center ] = x [center ] # restore original signal in the middle
49+ x = x_smoothed
4850 else :
4951 padding = 0
5052
5153 # Do even extension (optional)
5254 if even_extension is True :
53- x = np .hstack ((x , x [::- 1 ]))
55+ x = np .concatenate ((x , np .flip (x , axis = axis )), axis = axis )
56+
57+ s = [np .newaxis for dim in x .shape ]; s [axis ] = slice (None ); s = tuple (s ) # for elevating vectors to have same dimension as data
5458
5559 # Form wavenumbers
56- N = len ( x )
60+ N = x . shape [ axis ]
5761 k = np .concatenate ((np .arange (N // 2 + 1 ), np .arange (- N // 2 + 1 , 0 )))
5862 if N % 2 == 0 : k [N // 2 ] = 0 # odd derivatives get the Nyquist element zeroed out
5963
6064 # Filter to zero out higher wavenumbers
61- discrete_cutoff = int (high_freq_cutoff * N / 2 ) # Nyquist is at N/2 location, and we're cutting off as a fraction of that
62- filt = np .ones (k .shape ); filt [discrete_cutoff :N - discrete_cutoff ] = 0
65+ discrete_cutoff = int (high_freq_cutoff * N / 2 ) # Nyquist is at N/2 location, and we're cutting off as a fraction of that
66+ filt = np .ones (k .shape ) # start with all frequencies passing
67+ filt [discrete_cutoff :- discrete_cutoff ] = 0 # zero out high-frequency components
6368
6469 # Smoothed signal
65- X = np .fft .fft (x )
66- x_hat = np .real (np .fft .ifft (filt * X ))
67- x_hat = x_hat [padding :L + padding ]
70+ X = np .fft .fft (x , axis = axis )
71+ x_hat = np .real (np .fft .ifft (filt [s ] * X , axis = axis ))
6872
6973 # Derivative = 90 deg phase shift
7074 omega = 2 * np .pi / (dt * N ) # factor of 2pi/T turns wavenumbers into frequencies in radians/s
71- dxdt_hat = np .real (np .fft .ifft (1j * k * omega * filt * X ))
72- dxdt_hat = dxdt_hat [padding :L + padding ]
75+ dxdt = np .real (np .fft .ifft (1j * k [s ] * omega * filt [s ] * X , axis = axis ))
7376
74- return x_hat , dxdt_hat
77+ return ( x_hat [ center ], dxdt [ center ]) if pad_to_zero_dxdt else ( x_hat , dxdt )
7578
7679
7780def rbfdiff (x , dt_or_t , sigma = 1 , lmbd = 0.01 ):
0 commit comments