Skip to content

Commit 4755e00

Browse files
authored
Merge pull request #193 from florisvb/fix/191-splinediff-nan
Extend splinediff to handle missing data (NaN), add missing-data tests
2 parents 58f4918 + f574d2e commit 4755e00

2 files changed

Lines changed: 29 additions & 5 deletions

File tree

pynumdiff/polynomial_fit.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ def splinediff(x, dt_or_t, params=None, options=None, degree=3, s=None, num_iter
1010
"""Find smoothed data and derivative estimates by fitting a smoothing spline to the data with
1111
scipy.interpolate.UnivariateSpline. Variable step size is supported with equal ease as uniform step size.
1212
13-
:param np.array[float] x: data to differentiate
13+
:param np.array[float] x: data to differentiate. May contain NaN values (missing data); NaNs are excluded from
14+
fitting and imputed by spline interpolation.
1415
:param float or array[float] dt_or_t: This function supports variable step size. This parameter is either the constant
1516
:math:`\\Delta t` if given as a single float, or data locations if given as an array of same length as :code:`x`.
1617
:param list params: (**deprecated**, prefer :code:`degree`, :code:`cutoff_freq`, and :code:`num_iterations`)
@@ -38,8 +39,9 @@ def splinediff(x, dt_or_t, params=None, options=None, degree=3, s=None, num_iter
3839

3940
x_hat = x
4041
for _ in range(num_iterations):
41-
spline = scipy.interpolate.UnivariateSpline(t, x_hat, k=degree, s=s)
42-
x_hat = spline(t)
42+
obs = ~np.isnan(x_hat) # UnivariateSpline can't handle NaN, so fit only on observed points
43+
spline = scipy.interpolate.UnivariateSpline(t[obs], x_hat[obs], k=degree, s=s)
44+
x_hat = spline(t) # evaluate at all t, filling in NaN positions by interpolation
4345

4446
dspline = spline.derivative()
4547
dxdt_hat = dspline(t)

pynumdiff/tests/test_diff_methods.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,8 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
325325
(kerneldiff, {'kernel': 'gaussian', 'window_size': 5}),
326326
(butterdiff, {'filter_order': 3, 'cutoff_freq': 1 - 1e-6}),
327327
(finitediff, {}),
328-
(savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3})
328+
(savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3}),
329+
(rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True})
329330
]
330331

331332
# Similar to the error_bounds table, index by method first. But then we test against only one 2D function,
@@ -336,7 +337,8 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
336337
kerneldiff: [(2, 1), (3, 2)],
337338
butterdiff: [(0, -1), (1, -1)],
338339
finitediff: [(0, -1), (1, -1)],
339-
savgoldiff: [(0, -1), (1, 1)]
340+
savgoldiff: [(0, -1), (1, 1)],
341+
rtsdiff: [(1, -1), (1, 0)]
340342
}
341343

342344
@mark.parametrize("multidim_method_and_params", multidim_methods_and_params)
@@ -390,3 +392,23 @@ def test_multidimensionality(multidim_method_and_params, request):
390392
ax3.plot_wireframe(T1, T2, computed_laplacian, label='computed')
391393
legend = ax3.legend(bbox_to_anchor=(0.7, 0.8)); legend.legend_handles[0].set_facecolor(pyplot.cm.viridis(0.6))
392394
fig.suptitle(f'{diff_method.__name__}', fontsize=16)
395+
396+
# Differentiation methods expected to tolerate NaN (missing) samples in their input, each paired with the
# keyword arguments used when exercising it in the missing-data test.
nan_methods_and_params = [
    (splinediff, {'degree': 5, 's': 2}),
    (polydiff, {'degree': 2, 'window_size': 3}),
    (rtsdiff, {'order': 2, 'log_qr_ratio': 7, 'forwardbackward': True}),
    (robustdiff, {'order': 3, 'log_q': 7, 'log_r': 2}),
]
403+
404+
@mark.parametrize("diff_method_and_params", nan_methods_and_params)
def test_missing_data(diff_method_and_params):
    """Check that NaN-tolerant methods produce entirely finite estimates when samples are missing.

    A sine signal has four of its samples replaced by NaN and is passed to the method under test; neither
    the smoothed signal nor the derivative estimate it returns may contain NaN or infinity.

    :param tuple diff_method_and_params: (differentiation function, dict of keyword arguments for it)
    """
    method, kwargs = diff_method_and_params
    missing_idx = [5, 10, 15, 20]  # positions of the samples to knock out

    # `t` and `dt` are not defined here; they are resolved from the enclosing (module) scope.
    signal = np.sin(t)
    signal[missing_idx] = np.nan
    smoothed, derivative = method(signal, dt, **kwargs)

    # Same order as before: smoothed signal first, then the derivative estimate.
    for estimate in (smoothed, derivative):
        assert np.all(np.isfinite(estimate))

0 commit comments

Comments
 (0)