Skip to content

Commit d992626

Browse files
pavelkomarovclaude
andcommitted
Add axis parameter to rbfdiff for multidimensional support (#76)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 4755e00 commit d992626

2 files changed

Lines changed: 21 additions & 10 deletions

File tree

pynumdiff/basis_fit.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_e
7474
return x_hat, dxdt_hat
7575

7676

77-
def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01):
77+
def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01, axis=0):
7878
"""Find smoothed function and derivative estimates by fitting noisy data with radial-basis-functions. Naively,
7979
fill a matrix with basis function samples and solve a linear inverse problem against the data, but truncate tiny
8080
values to make columns sparse. Each basis function "hill" is topped with a "tower" of height :code:`lmbd` to reach
@@ -85,14 +85,21 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01):
8585
:math:`\\Delta t` if given as a single float, or data locations if given as an array of same length as :code:`x`.
8686
:param float sigma: controls width of radial basis functions
8787
:param float lmbd: controls smoothness
88+
:param int axis: data dimension along which differentiation is performed
8889
8990
:return: - **x_hat** (np.array) -- estimated (smoothed) x
9091
- **dxdt_hat** (np.array) -- estimated derivative of x
9192
"""
93+
x = np.asarray(x)
94+
x = np.moveaxis(x, axis, 0) # bring target axis to front
95+
orig_shape = x.shape
96+
N = orig_shape[0]
97+
x_2d = x.reshape(N, -1) # (N, M) — build matrix once, solve for all M columns
98+
9299
if np.isscalar(dt_or_t):
93-
t = np.arange(len(x))*dt_or_t
100+
t = np.arange(N)*dt_or_t
94101
else: # support variable step size for this function
95-
if len(x) != len(dt_or_t): raise ValueError("If `dt_or_t` is given as array-like, must have same length as `x`.")
102+
if N != len(dt_or_t): raise ValueError("If `dt_or_t` is given as array-like, must have same length as `x`.")
96103
t = dt_or_t
97104

98105
# The below does the approximate equivalent of this code, but sparsely in O(N sigma^2), since the rbf falls off rapidly
@@ -115,9 +122,11 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01):
115122
dv = -radius / sigma**2 * v # take derivative of radial basis function, because d/dt coef*f(t) = coef*df/dt
116123
rows.append(n); cols.append(j); vals.append(v); dvals.append(dv)
117124

118-
rbf = sparse.csr_matrix((vals, (rows, cols)), shape=(len(t), len(t))) # Build sparse kernels, O(N sigma) entries
119-
drbfdt = sparse.csr_matrix((dvals, (rows, cols)), shape=(len(t), len(t)))
120-
rbf_regularized = rbf + lmbd*sparse.eye(len(t), format="csr") # identity matrix gives a little extra height at the centers
121-
alpha = sparse.linalg.spsolve(rbf_regularized, x) # solve sparse system targeting the noisy data, O(N sigma^2)
125+
rbf = sparse.csr_matrix((vals, (rows, cols)), shape=(N, N)) # Build sparse kernels, O(N sigma) entries
126+
drbfdt = sparse.csr_matrix((dvals, (rows, cols)), shape=(N, N))
127+
rbf_regularized = rbf + lmbd*sparse.eye(N, format="csr") # identity matrix gives a little extra height at the centers
128+
alpha = sparse.linalg.spsolve(rbf_regularized, x_2d) # solve sparse system targeting the noisy data, O(N sigma^2)
122129

123-
return rbf @ alpha, drbfdt @ alpha # find samples of reconstructions using the smooth bases
130+
x_hat = np.moveaxis((rbf @ alpha).reshape(orig_shape), 0, axis) # find samples of reconstructions using the smooth bases
131+
dxdt_hat = np.moveaxis((drbfdt @ alpha).reshape(orig_shape), 0, axis)
132+
return x_hat, dxdt_hat

pynumdiff/tests/test_diff_methods.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,8 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
326326
(butterdiff, {'filter_order': 3, 'cutoff_freq': 1 - 1e-6}),
327327
(finitediff, {}),
328328
(savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3}),
329-
(rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True})
329+
(rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True}),
330+
(rbfdiff, {'sigma': 0.5, 'lmbd': 0.001}),
330331
]
331332

332333
# Similar to the error_bounds table, index by method first. But then we test against only one 2D function,
@@ -338,7 +339,8 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
338339
butterdiff: [(0, -1), (1, -1)],
339340
finitediff: [(0, -1), (1, -1)],
340341
savgoldiff: [(0, -1), (1, 1)],
341-
rtsdiff: [(1, -1), (1, 0)]
342+
rtsdiff: [(1, -1), (1, 0)],
343+
rbfdiff: [(2, 1), (2, 1)],
342344
}
343345

344346
@mark.parametrize("multidim_method_and_params", multidim_methods_and_params)

0 commit comments

Comments
 (0)