22from warnings import warn
33import numpy as np
44from scipy .linalg import expm , sqrtm
5+ from collections .abc import Iterable
56try : import cvxpy
67except ImportError : pass
78
8- from pynumdiff .utils .utility import huber_const
9+ from pynumdiff .utils .utility import huber_const , wrap_angle , ensure_iterable
910
1011
11- def kalman_filter (y , xhat0 , P0 , A , Q , C , R , B = None , u = None , save_P = True ):
12+ def kalman_filter (y , xhat0 , P0 , A , Q , C , R , B = None , u = None , save_P = True ,
13+ circular_vars = None , circular_units = 'rad' ):
1214 """Run the forward pass of a Kalman filter. Expects discrete-time matrices; use :func:`scipy.linalg.expm`
1315 in the caller to convert from continuous time if needed.
1416
@@ -24,6 +26,10 @@ def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True):
2426 :param np.array u: optional control inputs, stacked in the direction of axis 0
2527 :param bool save_P: whether to save history of error covariance and a priori state estimates, used with rts
2628 smoothing but nonstandard to compute for ordinary filtering
29+ :param bool or None circular_vars: bool indicating whether the measurement y is a circular (angular) variable
30+ that is wrapped. This will use a circular innovation calculation for the Kalman filter. The smoothed result
31+ will be returned in an unwrapped form.
32+ :param string circular_units: 'rad' or 'deg' to specify whether wrapping is in degrees or radians.
2733
2834 :return: - **xhat_pre** (np.array) -- a priori estimates of xhat, with axis=0 the batch dimension, so xhat[n] gets the nth step
2935 - **xhat_post** (np.array) -- a posteriori estimates of xhat
@@ -57,7 +63,10 @@ def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True):
5763 P = P_ .copy ()
5864 if not np .isnan (y [n ]): # handle missing data
5965 K = P_ @ C .T @ np .linalg .inv (C @ P_ @ C .T + R )
60- xhat += K @ (y [n ] - C @ xhat_ )
66+ innovation = y [n ] - C @ xhat_
67+ if circular_vars is not None and circular_vars is not False :
68+ innovation [0 ] = wrap_angle (innovation [0 ], circular_units )
69+ xhat += K @ innovation
6170 P -= K @ C @ P_
6271 # the [n]th index of pre variables holds _{n|n-1} info; the [n]th index of post variables holds _{n|n} info
6372 xhat_post [n ] = xhat
@@ -94,7 +103,8 @@ def rts_smooth(A, xhat_pre, xhat_post, P_pre, P_post, compute_P_smooth=True):
94103 return xhat_smooth if not compute_P_smooth else (xhat_smooth , P_smooth )
95104
96105
97- def rtsdiff (x , dt_or_t , order , log_qr_ratio , forwardbackward , axis = 0 ):
106+ def rtsdiff (x , dt_or_t , order , log_qr_ratio , forwardbackward , axis = 0 ,
107+ circular_vars = None , circular_units = 'rad' ):
98108 """Perform Rauch-Tung-Striebel smoothing with a naive constant derivative model. Makes use of :code:`kalman_filter`
99109 and :code:`rts_smooth`, which are made public. :code:`constant_X` methods in this module call this function.
100110
@@ -109,13 +119,24 @@ def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward, axis=0):
109119 :param bool forwardbackward: indicates whether to run smoother forwards and backwards
110120 (usually achieves better estimate at end points)
111121 :param int axis: data dimension along which differentiation is performed
122+ :param list[bool] circular_vars: set list element to bool for any axes of x that are a circular (angular) variable
123+ that is wrapped. This will use a circular innovation calculation for the Kalman filter. The smoothed result
124+ will be returned in an unwrapped form.
125+ :param string circular_units: 'rad' or 'deg' to specify whether wrapping is in degrees or radians.
112126
113127 :return: - **x_hat** (np.array) -- estimated (smoothed) x, same shape as input :code:`x`
114128 - **dxdt_hat** (np.array) -- estimated derivative of x, same shape as input :code:`x`
115129 """
116130 N = x .shape [axis ]
117131 if not np .isscalar (dt_or_t ) and N != len (dt_or_t ):
118132 raise ValueError ("If `dt_or_t` is given as array-like, must have same length as x along `axis`." )
133+
134+ # turn circular_vars into something with the same shape as the number of differentiated axes in x
135+ if len (x .shape ) > 1 :
136+ n = int (np .prod (x .shape [:axis ] + x .shape [axis + 1 :]))
137+ circular_vars = ensure_iterable (circular_vars , n )
138+ else :
139+ circular_vars = ensure_iterable (circular_vars , 1 )
119140
120141 q = 10 ** int (log_qr_ratio / 2 ) # even-ish split of the powers across 0
121142 r = q / (10 ** log_qr_ratio )
@@ -143,19 +164,21 @@ def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward, axis=0):
143164 x_hat = np .empty_like (x ); dxdt_hat = np .empty_like (x )
144165 if forwardbackward : w = np .linspace (0 , 1 , N ) # weights used to combine forward and backward results
145166
146- for vec_idx in np .ndindex (x .shape [:axis ] + x .shape [axis + 1 :]): # works properly for 1D case too
167+ for i , vec_idx in enumerate ( np .ndindex (x .shape [:axis ] + x .shape [axis + 1 :]) ): # works properly for 1D case too
147168 s = vec_idx [:axis ] + (slice (None ),) + vec_idx [axis :] # for indexing the vector we wish to differentiate
148169 xhat0 = np .zeros (order + 1 ); xhat0 [0 ] = x [s ][0 ] if not np .isnan (x [s ][0 ]) else 0 # The first estimate is the first seen state. See #110
149170
150- xhat_pre , xhat_post , P_pre , P_post = kalman_filter (x [s ], xhat0 , P0 , A_d , Q_d , C , R )
171+ xhat_pre , xhat_post , P_pre , P_post = kalman_filter (x [s ], xhat0 , P0 , A_d , Q_d , C , R ,
172+ circular_vars = circular_vars [i ], circular_units = circular_units )
151173 xhat_smooth = rts_smooth (A_d , xhat_pre , xhat_post , P_pre , P_post , compute_P_smooth = False )
152174 x_hat [s ] = xhat_smooth [:,0 ] # first dimension is time, so slice first and second states at all times
153175 dxdt_hat [s ] = xhat_smooth [:,1 ]
154176
155177 if forwardbackward :
156178 xhat0 [0 ] = x [s ][- 1 ] if not np .isnan (x [s ][- 1 ]) else 0
157179 xhat_pre , xhat_post , P_pre , P_post = kalman_filter (x [s ][::- 1 ], xhat0 , P0 , A_d_bwd ,
158- Q_d if Q_d .ndim == 2 else Q_d [::- 1 ], C , R ) # Use same Q matrices as before, because noise should still grow in reverse time
180+ Q_d if Q_d .ndim == 2 else Q_d [::- 1 ], C , R ,
181+ circular_vars = circular_vars [i ], circular_units = circular_units ) # Use same Q matrices as before, because noise should still grow in reverse time
159182 xhat_smooth = rts_smooth (A_d_bwd , xhat_pre , xhat_post , P_pre , P_post , compute_P_smooth = False )
160183
161184 x_hat [s ] = x_hat [s ] * w + xhat_smooth [:, 0 ][::- 1 ] * (1 - w )
0 commit comments