Skip to content

Commit dff1f9c

Browse files
authored
Merge pull request #197 from florisvb/rbfdiff-multidim
Add axis parameter to rbfdiff (#76)
2 parents a5247de + 6edafd5 commit dff1f9c

2 files changed

Lines changed: 30 additions & 17 deletions

File tree

pynumdiff/basis_fit.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_e
3434
raise ValueError("`high_freq_cutoff` must be given.")
3535

3636
L = x.shape[axis]
37-
x = np.asarray(x)
3837

3938
# Make derivative go to zero at the ends (optional)
4039
if pad_to_zero_dxdt:
@@ -44,11 +43,11 @@ def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_e
4443
x = np.concatenate((pre, x, post), axis=axis) # extend the edges
4544
kernel = utility.mean_kernel(padding//2)
4645
x_smoothed = utility.convolutional_smoother(x, kernel, axis=axis) # smooth the padded edges in
47-
center = (slice(None),)*axis + (slice(padding, L+padding),) + (slice(None),)*(x.ndim-axis-1)
48-
x_smoothed[center] = x[center] # restore original signal in the middle
46+
m = (slice(None),)*axis + (slice(padding, L+padding),) + (slice(None),)*(x.ndim-axis-1) # middle
47+
x_smoothed[m] = x[m] # restore original signal in the middle
4948
x = x_smoothed
5049
else:
51-
padding = 0
50+
m = (slice(None),)*axis + (slice(0, L),) + (slice(None),)*(x.ndim-axis-1) # indices where signal lives
5251

5352
# Do even extension (optional)
5453
if even_extension is True:
@@ -63,21 +62,21 @@ def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_e
6362

6463
# Filter to zero out higher wavenumbers
6564
discrete_cutoff = int(high_freq_cutoff * N / 2) # Nyquist is at N/2 location, and we're cutting off as a fraction of that
66-
filt = np.ones(k.shape) # start with all frequencies passing
67-
filt[discrete_cutoff:-discrete_cutoff] = 0 # zero out high-frequency components
65+
filt = np.ones(k.shape) # start with all frequencies passing
66+
filt[discrete_cutoff:-discrete_cutoff] = 0 # zero out high-frequency components
6867

6968
# Smoothed signal
7069
X = np.fft.fft(x, axis=axis)
7170
x_hat = np.real(np.fft.ifft(filt[s] * X, axis=axis))
7271

7372
# Derivative = 90 deg phase shift
7473
omega = 2*np.pi/(dt*N) # factor of 2pi/T turns wavenumbers into frequencies in radians/s
75-
dxdt = np.real(np.fft.ifft(1j * k[s] * omega * filt[s] * X, axis=axis))
74+
dxdt_hat = np.real(np.fft.ifft(1j * k[s] * omega * filt[s] * X, axis=axis))
7675

77-
return (x_hat[center], dxdt[center]) if pad_to_zero_dxdt else (x_hat, dxdt)
76+
return x_hat[m], dxdt_hat[m]
7877

7978

80-
def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01):
79+
def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01, axis=0):
8180
"""Find smoothed function and derivative estimates by fitting noisy data with radial-basis-functions. Naively,
8281
fill a matrix with basis function samples and solve a linear inverse problem against the data, but truncate tiny
8382
values to make columns sparse. Each basis function "hill" is topped with a "tower" of height :code:`lmbd` to reach
@@ -88,17 +87,24 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01):
8887
:math:`\\Delta t` if given as a single float, or data locations if given as an array of same length as :code:`x`.
8988
:param float sigma: controls width of radial basis functions
9089
:param float lmbd: controls smoothness
90+
:param int axis: data dimension along which differentiation is performed
9191
9292
:return: - **x_hat** (np.array) -- estimated (smoothed) x
9393
- **dxdt_hat** (np.array) -- estimated derivative of x
9494
"""
95+
N = x.shape[axis]
96+
x = np.moveaxis(x, axis, 0) # bring axis of differentiation to front so each N repeats comprise vector
97+
plump = x.shape
98+
x_flattened = x.reshape(N, -1) # (N, M) matrix where each column is a vector along the original axis
99+
95100
if np.isscalar(dt_or_t):
96-
t = np.arange(len(x))*dt_or_t
101+
t = np.arange(N)*dt_or_t
97102
else: # support variable step size for this function
98-
if len(x) != len(dt_or_t): raise ValueError("If `dt_or_t` is given as array-like, must have same length as `x`.")
103+
if N != len(dt_or_t): raise ValueError("If `dt_or_t` is given as array-like, must have same length as `x`.")
99104
t = dt_or_t
100105

101-
# The below does the approximate equivalent of this code, but sparsely in O(N sigma^2), since the rbf falls off rapidly
106+
# For each vector along the axis of differentiation, the below does the approximate equivalent of this code,
107+
# but sparsely in O(N sigma^2), since the rbf falls off rapidly
102108
# t_i, t_j = np.meshgrid(t,t)
103109
# r = t_j - t_i # radius
104110
# rbf = np.exp(-(r**2) / (2 * sigma**2)) # radial basis function kernel, O(N^2) entries
@@ -118,9 +124,12 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01):
118124
dv = -radius / sigma**2 * v # take derivative of radial basis function, because d/dt coef*f(t) = coef*df/dt
119125
rows.append(n); cols.append(j); vals.append(v); dvals.append(dv)
120126

121-
rbf = sparse.csr_matrix((vals, (rows, cols)), shape=(len(t), len(t))) # Build sparse kernels, O(N sigma) entries
122-
drbfdt = sparse.csr_matrix((dvals, (rows, cols)), shape=(len(t), len(t)))
123-
rbf_regularized = rbf + lmbd*sparse.eye(len(t), format="csr") # identity matrix gives a little extra height at the centers
124-
alpha = sparse.linalg.spsolve(rbf_regularized, x) # solve sparse system targeting the noisy data, O(N sigma^2)
127+
rbf = sparse.csr_matrix((vals, (rows, cols)), shape=(N, N)) # Build sparse kernels, O(N sigma) entries
128+
drbfdt = sparse.csr_matrix((dvals, (rows, cols)), shape=(N, N))
129+
rbf_regularized = rbf + lmbd*sparse.eye(N, format="csr") # identity matrix gives a little extra height at the centers
130+
alpha = sparse.linalg.spsolve(rbf_regularized, x_flattened) # solve sparse system targeting the noisy data,
131+
# can take matrix target, O(N sigma^2) for each vector
132+
x_hat_flattened = rbf @ alpha # find samples of reconstructions using the smooth bases
133+
dxdt_hat_flattened = drbfdt @ alpha
125134

126-
return rbf @ alpha, drbfdt @ alpha # find samples of reconstructions using the smooth bases
135+
return np.moveaxis(x_hat_flattened.reshape(plump), 0, axis), np.moveaxis(dxdt_hat_flattened.reshape(plump), 0, axis)

pynumdiff/tests/test_diff_methods.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,8 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
328328
(polydiff, {'degree': 2, 'window_size': 5}),
329329
(savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3}),
330330
(rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True}),
331+
(spectraldiff, {'high_freq_cutoff': 0.25, 'pad_to_zero_dxdt': False}),
332+
(rbfdiff, {'sigma': 0.5, 'lmbd': 1e-6}),
331333
(splinediff, {'degree': 9, 's': 1e-6}),
332334
(robustdiff, {'order':2, 'log_q':7, 'log_r':2})
333335
]
@@ -343,6 +345,8 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
343345
polydiff: [(1, -1), (1, 0)],
344346
savgoldiff: [(0, -1), (1, 1)],
345347
rtsdiff: [(1, -1), (1, 0)],
348+
spectraldiff: [(2, 1), (3, 2)], # lot of Gibbs ringing in 2nd order derivatives along t1 with t_1^2 sin(3 pi t_2 / 2)
349+
rbfdiff: [(0, -1), (1, 0)],
346350
splinediff: [(0, -1), (1, 0)],
347351
robustdiff: [(-2, -3), (0, -1)]
348352
}

0 commit comments

Comments
 (0)