Skip to content

Commit 90e3199

Browse files
committed
smoothing over
1 parent 534d12d commit 90e3199

2 files changed

Lines changed: 23 additions & 17 deletions

File tree

pynumdiff/basis_fit.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@ def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_e
4343
x = np.concatenate((pre, x, post), axis=axis) # extend the edges
4444
kernel = utility.mean_kernel(padding//2)
4545
x_smoothed = utility.convolutional_smoother(x, kernel, axis=axis) # smooth the padded edges in
46-
center = (slice(None),)*axis + (slice(padding, L+padding),) + (slice(None),)*(x.ndim-axis-1)
47-
x_smoothed[center] = x[center] # restore original signal in the middle
46+
original_signal = (slice(None),)*axis + (slice(padding, L+padding),) + (slice(None),)*(x.ndim-axis-1)
47+
x_smoothed[original_signal] = x[original_signal] # restore original signal in the middle
4848
x = x_smoothed
49+
else:
50+
padding = 0
4951

5052
# Do even extension (optional)
5153
if even_extension is True:
@@ -60,8 +62,8 @@ def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_e
6062

6163
# Filter to zero out higher wavenumbers
6264
discrete_cutoff = int(high_freq_cutoff * N / 2) # Nyquist is at N/2 location, and we're cutting off as a fraction of that
63-
filt = np.ones(k.shape) # start with all frequencies passing
64-
filt[discrete_cutoff:-discrete_cutoff] = 0 # zero out high-frequency components
65+
filt = np.ones(k.shape) # start with all frequencies passing
66+
filt[discrete_cutoff:-discrete_cutoff] = 0 # zero out high-frequency components
6567

6668
# Smoothed signal
6769
X = np.fft.fft(x, axis=axis)
@@ -71,7 +73,8 @@ def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_e
7173
omega = 2*np.pi/(dt*N) # factor of 2pi/T turns wavenumbers into frequencies in radians/s
7274
dxdt = np.real(np.fft.ifft(1j * k[s] * omega * filt[s] * X, axis=axis))
7375

74-
return (x_hat[center], dxdt[center]) if pad_to_zero_dxdt else (x_hat, dxdt)
76+
original_signal = (slice(None),)*axis + (slice(padding, L+padding),) + (slice(None),)*(x_hat.ndim-axis-1)
77+
return x_hat[original_signal], dxdt[original_signal]
7578

7679

7780
def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01, axis=0):
@@ -90,19 +93,19 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01, axis=0):
9093
:return: - **x_hat** (np.array) -- estimated (smoothed) x
9194
- **dxdt_hat** (np.array) -- estimated derivative of x
9295
"""
93-
x = np.asarray(x)
94-
x = np.moveaxis(x, axis, 0) # bring target axis to front
95-
orig_shape = x.shape
96-
N = orig_shape[0]
97-
x_2d = x.reshape(N, -1) # (N, M) — build matrix once, solve for all M columns
96+
N = x.shape[axis]
97+
x = np.moveaxis(x, axis, 0) # bring axis of differentiation to front so each N repeats comprise vector
98+
plump = x.shape
99+
x_flattened = x.reshape(N, -1) # (N, M) matrix where each column is a vector along the original axis
98100

99101
if np.isscalar(dt_or_t):
100102
t = np.arange(N)*dt_or_t
101103
else: # support variable step size for this function
102104
if N != len(dt_or_t): raise ValueError("If `dt_or_t` is given as array-like, must have same length as `x`.")
103105
t = dt_or_t
104106

105-
# The below does the approximate equivalent of this code, but sparsely in O(N sigma^2), since the rbf falls off rapidly
107+
# For each vector along the axis of differentiation, the below does the approximate equivalent of this code,
108+
# but sparsely in O(N sigma^2), since the rbf falls off rapidly
106109
# t_i, t_j = np.meshgrid(t,t)
107110
# r = t_j - t_i # radius
108111
# rbf = np.exp(-(r**2) / (2 * sigma**2)) # radial basis function kernel, O(N^2) entries
@@ -125,8 +128,9 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01, axis=0):
125128
rbf = sparse.csr_matrix((vals, (rows, cols)), shape=(N, N)) # Build sparse kernels, O(N sigma) entries
126129
drbfdt = sparse.csr_matrix((dvals, (rows, cols)), shape=(N, N))
127130
rbf_regularized = rbf + lmbd*sparse.eye(N, format="csr") # identity matrix gives a little extra height at the centers
128-
alpha = sparse.linalg.spsolve(rbf_regularized, x_2d) # solve sparse system targeting the noisy data, O(N sigma^2)
131+
alpha = sparse.linalg.spsolve(rbf_regularized, x_flattened) # solve sparse system targeting the noisy data,
132+
# can take matrix target, O(N sigma^2) for each vector
133+
x_hat_flattened = rbf @ alpha # find samples of reconstructions using the smooth bases
134+
dxdt_hat_flattened = drbfdt @ alpha
129135

130-
x_hat = np.moveaxis((rbf @ alpha).reshape(orig_shape), 0, axis) # find samples of reconstructions using the smooth bases
131-
dxdt_hat = np.moveaxis((drbfdt @ alpha).reshape(orig_shape), 0, axis)
132-
return x_hat, dxdt_hat
136+
return np.moveaxis(x_hat_flattened.reshape(plump), 0, axis), np.moveaxis(dxdt_hat_flattened.reshape(plump), 0, axis)

pynumdiff/tests/test_diff_methods.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,8 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
328328
(polydiff, {'degree': 2, 'window_size': 5}),
329329
(savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3}),
330330
(rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True}),
331-
(rbfdiff, {'sigma': 0.5, 'lmbd': 0.001}),
331+
(spectraldiff, {'high_freq_cutoff': 0.25, 'pad_to_zero_dxdt': False}),
332+
(rbfdiff, {'sigma': 0.5, 'lmbd': 1e-6}),
332333
(splinediff, {'degree': 9, 's': 1e-6}),
333334
(robustdiff, {'order':2, 'log_q':7, 'log_r':2})
334335
]
@@ -344,7 +345,8 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
344345
polydiff: [(1, -1), (1, 0)],
345346
savgoldiff: [(0, -1), (1, 1)],
346347
rtsdiff: [(1, -1), (1, 0)],
347-
rbfdiff: [(2, 1), (2, 1)],
348+
spectraldiff: [(2, 1), (3, 2)], # lot of Gibbs ringing in 2nd order derivatives along t1 with t_1^2 sin(3 pi t_2 / 2)
349+
rbfdiff: [(0, -1), (1, 0)],
348350
splinediff: [(0, -1), (1, 0)],
349351
robustdiff: [(-2, -3), (0, -1)]
350352
}

0 commit comments

Comments
 (0)