66from pynumdiff .utils import utility
77
88
9- def splinediff (x , dt_or_t , params = None , options = None , degree = 3 , s = None , num_iterations = 1 ):
9+ def splinediff (x , dt_or_t , params = None , options = None , degree = 3 , s = None , num_iterations = 1 , axis = 0 ):
1010 """Find smoothed data and derivative estimates by fitting a smoothing spline to the data with
1111 scipy.interpolate.UnivariateSpline. Variable step size is supported with equal ease as uniform step size.
1212
@@ -20,6 +20,7 @@ def splinediff(x, dt_or_t, params=None, options=None, degree=3, s=None, num_iter
2020 :param float s: positive smoothing factor used to choose the number of knots. Number of knots will be increased
2121 until the smoothing condition is satisfied: :math:`\\ sum_t (x[t] - \\ text{spline}[t])^2 \\ leq s`
2222 :param int num_iterations: how many times to apply smoothing
23+ :param int axis: data dimension along which differentiation is performed
2324
2425 :return: - **x_hat** (np.array) -- estimated (smoothed) x
2526 - **dxdt_hat** (np.array) -- estimated derivative of x
@@ -31,22 +32,30 @@ def splinediff(x, dt_or_t, params=None, options=None, degree=3, s=None, num_iter
3132 if options is not None :
3233 if 'iterate' in options and options ['iterate' ]: num_iterations = params [2 ]
3334
35+ x = np .moveaxis (np .asarray (x ), axis , 0 )
36+ n = x .shape [0 ]
37+
3438 if np .isscalar (dt_or_t ):
35- t = np .arange (len ( x )) * dt_or_t
39+ t = np .arange (n ) * dt_or_t
3640 else : # support variable step size for this function
37- if len ( x ) != len (dt_or_t ): raise ValueError ("If `dt_or_t` is given as array-like, must have same length as `x`." )
41+ if n != len (dt_or_t ): raise ValueError ("If `dt_or_t` is given as array-like, must have same length as `x`." )
3842 t = dt_or_t
3943
40- x_hat = x
41- for _ in range (num_iterations ):
42- obs = ~ np .isnan (x_hat ) # UnivariateSpline can't handle NaN, so fit only on observed points
43- spline = scipy .interpolate .UnivariateSpline (t [obs ], x_hat [obs ], k = degree , s = s )
44- x_hat = spline (t ) # evaluate at all t, filling in NaN positions by interpolation
45-
46- dspline = spline .derivative ()
47- dxdt_hat = dspline (t )
48-
49- return x_hat , dxdt_hat
44+ x_hat = np .empty_like (x )
45+ dxdt_hat = np .empty_like (x )
46+
47+ for idx in np .ndindex (x .shape [1 :]):
48+ sl = (slice (None ),) + idx
49+ xi = x [sl ]
50+ for _ in range (num_iterations ):
51+ obs = ~ np .isnan (xi ) # UnivariateSpline can't handle NaN, so fit only on observed points
52+ spline = scipy .interpolate .UnivariateSpline (t [obs ], xi [obs ], k = degree , s = s )
53+ xi = spline (t ) # evaluate at all t, filling in NaN positions by interpolation
54+ dspline = spline .derivative ()
55+ x_hat [sl ] = xi
56+ dxdt_hat [sl ] = dspline (t )
57+
58+ return np .moveaxis (x_hat , 0 , axis ), np .moveaxis (dxdt_hat , 0 , axis )
5059
5160
5261def polydiff (x , dt_or_t , params = None , options = None , degree = None , window_size = None , step_size = 1 ,
0 commit comments