Skip to content

Commit 654bd76

Browse files
committed
fixed pavels comments including some indentation, commenting, and redundancy issues :3
1 parent b2eae85 commit 654bd76

1 file changed

Lines changed: 19 additions & 23 deletions

File tree

pynumdiff/basis_fit.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55

66
from pynumdiff.utils import utility
77

8-
#maria spectral diff below
9-
10-
def spectraldiff(x, dt, axis=0, params=None, options=None, high_freq_cutoff=None,
11-
even_extension=True, pad_to_zero_dxdt=True):
8+
def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None,
9+
even_extension=True, pad_to_zero_dxdt=True, axis=0):
1210
"""Take a derivative in the Fourier domain, with high frequency attentuation.
1311
1412
:param np.array[float] x: data to differentiate
@@ -26,7 +24,7 @@ def spectraldiff(x, dt, axis=0, params=None, options=None, high_freq_cutoff=None
2624
:return: - **x_hat** (np.array) -- estimated (smoothed) x
2725
- **dxdt_hat** (np.array) -- estimated derivative of x
2826
"""
29-
if params is not None:
27+
if params is not None: # Warning to support old interface for a while. Remove these lines along with params in a future release.
3028
warn("`params` and `options` parameters will be removed in a future version. Use `high_freq_cutoff`, " +
3129
"`even_extension`, and `pad_to_zero_dxdt` instead.", DeprecationWarning)
3230
high_freq_cutoff = params[0] if isinstance(params, list) else params
@@ -38,52 +36,50 @@ def spectraldiff(x, dt, axis=0, params=None, options=None, high_freq_cutoff=None
3836

3937
x = np.asarray(x)
4038
x0 = np.moveaxis(x, axis, 0) # move time axis to the front of the array
41-
# now x0 dims are (# of data points, # of signals)
39+
# Now x0 dims are (number of data points, number of signals)
4240
L = x0.shape[0]
4341

44-
# make derivative go to zero at ends (optional)
42+
# Make derivative go to zero at the ends (optional):
4543
if pad_to_zero_dxdt:
4644
padding = 100
45+
pre = x[0] * np.ones(padding)
46+
post = x[-1] * np.ones(padding)
47+
x = np.hstack((pre, x, post)) # extend the edges
4748

48-
# just pad first and last values x100
49-
first = x0[0:1]
50-
last = x0[-1:]
49+
# Pad first and last values x100
50+
first = x0[0:1]
51+
last = x0[-1:]
5152
pre = np.repeat(first, padding, axis=0)
5253
post = np.repeat(last, padding, axis=0)
5354

54-
xpad = np.concatenate((pre, x0, post), axis=0) # i think hstack won't work with the correct axis
55-
56-
kernel = utility.mean_kernel(padding//2)
57-
x_hat0 = utility.convolutional_smoother(xpad, kernel, axis=0)
58-
59-
x_hat0[padding:-padding] = xpad[padding:-padding]
60-
x0 = x_hat0
55+
xpad = np.concatenate((pre, x0, post), axis=0) # concatenate along axis 0
6156
else:
6257
padding = 0
6358

64-
# Do even extension (optional):
59+
# Do even extension (optional)
6560
if even_extension is True:
6661
x0 = np.concatenate((x0, x0[::-1, ...]), axis=0)
6762

6863
# Form wavenumbers
6964
N = x0.shape[0]
7065
k = np.concatenate((np.arange(N//2 + 1), np.arange(-N//2 + 1, 0)))
71-
if N % 2 == 0: k[N//2] = 0 # odd derivatives get the Nyquist element zeroed out
66+
if N % 2 == 0: k[N//2] = 0 # odd derivatives get the Nyquist element zeroed out
7267

7368
# Filter to zero out higher wavenumbers
7469
discrete_cutoff = int(high_freq_cutoff * N / 2) # Nyquist is at N/2 location, and we're cutting off as a fraction of that
75-
filt = np.ones_like(k, dtype=float)
76-
filt = np.ones(k.shape); filt[discrete_cutoff:N-discrete_cutoff] = 0
70+
71+
filt = np.ones(k.shape) # start with all frequencies passing
72+
filt[discrete_cutoff:-discrete_cutoff] = 0 # zero out high-frequency components
7773
filt = filt.reshape((N,) + (1,)*(x0.ndim-1))
7874

79-
# Smoothed signal
75+
# Smoothed signal
8076
X = np.fft.fft(x0, axis=0)
8177

8278
x_hat0 = np.real(np.fft.ifft(filt * X, axis=0))
8379
x_hat0 = x_hat0[padding:L+padding]
8480

8581
# Derivative = 90 deg phase shift
86-
omega = 2*np.pi/(dt*N)
82+
omega = 2*np.pi/(dt*N) # factor of 2pi/T turns wavenumbers into frequencies in radians/s
8783
k0 = k.reshape((N,) + (1,)*(x0.ndim-1))
8884
dxdt0 = np.real(np.fft.ifft(1j * k0 * omega * filt * X, axis=0))
8985
dxdt0 = dxdt0[padding:L+padding]

0 commit comments

Comments
 (0)