Skip to content

Commit df59646

Browse files
pavelkomarovclaude
andcommitted
Add axis parameter to splinediff for multidimensional support (#76)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 4755e00 commit df59646

2 files changed

Lines changed: 26 additions & 15 deletions

File tree

pynumdiff/polynomial_fit.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pynumdiff.utils import utility
77

88

9-
def splinediff(x, dt_or_t, params=None, options=None, degree=3, s=None, num_iterations=1):
9+
def splinediff(x, dt_or_t, params=None, options=None, degree=3, s=None, num_iterations=1, axis=0):
1010
"""Find smoothed data and derivative estimates by fitting a smoothing spline to the data with
1111
scipy.interpolate.UnivariateSpline. Variable step size is supported with equal ease as uniform step size.
1212
@@ -20,6 +20,7 @@ def splinediff(x, dt_or_t, params=None, options=None, degree=3, s=None, num_iter
2020
:param float s: positive smoothing factor used to choose the number of knots. Number of knots will be increased
2121
until the smoothing condition is satisfied: :math:`\\sum_t (x[t] - \\text{spline}[t])^2 \\leq s`
2222
:param int num_iterations: how many times to apply smoothing
23+
:param int axis: data dimension along which differentiation is performed
2324
2425
:return: - **x_hat** (np.array) -- estimated (smoothed) x
2526
- **dxdt_hat** (np.array) -- estimated derivative of x
@@ -31,22 +32,30 @@ def splinediff(x, dt_or_t, params=None, options=None, degree=3, s=None, num_iter
3132
if options is not None:
3233
if 'iterate' in options and options['iterate']: num_iterations = params[2]
3334

35+
x = np.moveaxis(np.asarray(x), axis, 0)
36+
n = x.shape[0]
37+
3438
if np.isscalar(dt_or_t):
35-
t = np.arange(len(x))*dt_or_t
39+
t = np.arange(n) * dt_or_t
3640
else: # support variable step size for this function
37-
if len(x) != len(dt_or_t): raise ValueError("If `dt_or_t` is given as array-like, must have same length as `x`.")
41+
if n != len(dt_or_t): raise ValueError("If `dt_or_t` is given as array-like, must have same length as `x`.")
3842
t = dt_or_t
3943

40-
x_hat = x
41-
for _ in range(num_iterations):
42-
obs = ~np.isnan(x_hat) # UnivariateSpline can't handle NaN, so fit only on observed points
43-
spline = scipy.interpolate.UnivariateSpline(t[obs], x_hat[obs], k=degree, s=s)
44-
x_hat = spline(t) # evaluate at all t, filling in NaN positions by interpolation
45-
46-
dspline = spline.derivative()
47-
dxdt_hat = dspline(t)
48-
49-
return x_hat, dxdt_hat
44+
x_hat = np.empty_like(x)
45+
dxdt_hat = np.empty_like(x)
46+
47+
for idx in np.ndindex(x.shape[1:]):
48+
sl = (slice(None),) + idx
49+
xi = x[sl]
50+
for _ in range(num_iterations):
51+
obs = ~np.isnan(xi) # UnivariateSpline can't handle NaN, so fit only on observed points
52+
spline = scipy.interpolate.UnivariateSpline(t[obs], xi[obs], k=degree, s=s)
53+
xi = spline(t) # evaluate at all t, filling in NaN positions by interpolation
54+
dspline = spline.derivative()
55+
x_hat[sl] = xi
56+
dxdt_hat[sl] = dspline(t)
57+
58+
return np.moveaxis(x_hat, 0, axis), np.moveaxis(dxdt_hat, 0, axis)
5059

5160

5261
def polydiff(x, dt_or_t, params=None, options=None, degree=None, window_size=None, step_size=1,

pynumdiff/tests/test_diff_methods.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,8 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
326326
(butterdiff, {'filter_order': 3, 'cutoff_freq': 1 - 1e-6}),
327327
(finitediff, {}),
328328
(savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3}),
329-
(rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True})
329+
(rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True}),
330+
(splinediff, {'degree': 5, 's': 2}),
330331
]
331332

332333
# Similar to the error_bounds table, index by method first. But then we test against only one 2D function,
@@ -338,7 +339,8 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
338339
butterdiff: [(0, -1), (1, -1)],
339340
finitediff: [(0, -1), (1, -1)],
340341
savgoldiff: [(0, -1), (1, 1)],
341-
rtsdiff: [(1, -1), (1, 0)]
342+
rtsdiff: [(1, -1), (1, 0)],
343+
splinediff: [(3, 1), (3, 2)],
342344
}
343345

344346
@mark.parametrize("multidim_method_and_params", multidim_methods_and_params)

0 commit comments

Comments
 (0)