@@ -24,9 +24,9 @@ def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True, innovat
2424 :param np.array u: optional control inputs, stacked in the direction of axis 0
2525 :param bool save_P: whether to save history of error covariance and a priori state estimates, used with rts
2626 smoothing but nonstandard to compute for ordinary filtering
27- :param callable innovation_fn: optional function :code:`(y_n, pred)` returning the innovation :code:`y_n - pred`,
28- where :code:`pred = C @ xhat_`. When :code:`None` (default), standard subtraction is used. Use this to handle
29- circular quantities, e.g. :code:`lambda y, pred: (y - pred + np.pi) % (2*np.pi) - np.pi` for angular measurements in radians .
27+ :param callable innovation_fn: optional function taking measurements and predicted measurements and returning the innovation.
28+ When :code:`None`, traditional subtraction is used. This is exposed to handle cases like wrapped domains, where alternative
29+ displacement measures may be more appropriate. See e.g. the function passed by :code:`rtsdiff` with :code:`circular=True` .
3030
3131 :return: - **xhat_pre** (np.array) -- a priori estimates of xhat, with axis=0 the batch dimension, so xhat[n] gets the nth step
3232 - **xhat_post** (np.array) -- a posteriori estimates of xhat
@@ -60,7 +60,7 @@ def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True, innovat
6060 P = P_ .copy ()
6161 if not np .isnan (y [n ]): # handle missing data
6262 K = P_ @ C .T @ np .linalg .inv (C @ P_ @ C .T + R )
63- innovation = innovation_fn ( y [n ], C @ xhat_ ) if innovation_fn is not None else y [n ] - C @ xhat_
63+ innovation = y [n ] - C @ xhat_ if innovation_fn is None else innovation_fn ( y [n ], C @ xhat_ )
6464 xhat += K @ innovation
6565 P -= K @ C @ P_
6666 # the [n]th index of pre variables holds _{n|n-1} info; the [n]th index of post variables holds _{n|n} info
@@ -116,7 +116,8 @@ def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward=False, axis=0, circ
116116 :param bool circular: if :code:`True`, treat the measured quantity as a circular variable in radians, wrapping the
117117 innovation to :math:`[-\\ pi, \\ pi]`. The input :code:`x` must be in radians; convert degrees with :code:`np.deg2rad`.
118118
119- :return: - **x_hat** (np.array) -- estimated (smoothed) x, same shape as input :code:`x`
119+ :return: - **x_hat** (np.array) -- estimated (smoothed) x, same shape as input :code:`x`.
120+ When :code:`circular=True`, wrapped to :math:`[-\\ pi, \\ pi]`.
120121 - **dxdt_hat** (np.array) -- estimated derivative of x, same shape as input :code:`x`
121122 """
122123 N = x .shape [axis ]
@@ -153,22 +154,24 @@ def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward=False, axis=0, circ
153154
154155 for vec_idx in np .ndindex (x .shape [:axis ] + x .shape [axis + 1 :]): # works properly for 1D case too
155156 s = vec_idx [:axis ] + (slice (None ),) + vec_idx [axis :] # for indexing the vector we wish to differentiate
156- xhat0 = np .zeros (order + 1 ); xhat0 [0 ] = x [s ][0 ] if not np .isnan (x [s ][0 ]) else 0 # The first estimate is the first seen state. See #110
157+ xhat0 = np .zeros (order + 1 )
158+ if not np .isnan (x [s ][0 ]): xhat0 [0 ] = x [s ][0 ] # The first estimate is the first seen state. See #110
157159
158160 xhat_pre , xhat_post , P_pre , P_post = kalman_filter (x [s ], xhat0 , P0 , A_d , Q_d , C , R , innovation_fn = innovation_fn )
159161 xhat_smooth = rts_smooth (A_d , xhat_pre , xhat_post , P_pre , P_post , compute_P_smooth = False )
160162 x_hat [s ] = xhat_smooth [:,0 ] # first dimension is time, so slice first and second states at all times
161163 dxdt_hat [s ] = xhat_smooth [:,1 ]
162164
163165 if forwardbackward :
164- xhat0 [ 0 ] = x [s ][- 1 ] if not np . isnan ( x [s ][- 1 ]) else 0
166+ if not np . isnan ( x [s ][- 1 ]): xhat0 [ 0 ] = x [s ][- 1 ]
165167 xhat_pre , xhat_post , P_pre , P_post = kalman_filter (x [s ][::- 1 ], xhat0 , P0 , A_d_bwd ,
166168 Q_d if Q_d .ndim == 2 else Q_d [::- 1 ], C , R , innovation_fn = innovation_fn ) # Use same Q matrices as before, because noise should still grow in reverse time
167169 xhat_smooth = rts_smooth (A_d_bwd , xhat_pre , xhat_post , P_pre , P_post , compute_P_smooth = False )
168170
169171 x_hat [s ] = x_hat [s ] * w + xhat_smooth [:, 0 ][::- 1 ] * (1 - w )
170172 dxdt_hat [s ] = dxdt_hat [s ] * w + xhat_smooth [:, 1 ][::- 1 ] * (1 - w )
171173
174+ if circular : x_hat = (x_hat + np .pi ) % (2 * np .pi ) - np .pi # wrap output to match the input domain
172175 return x_hat , dxdt_hat
173176
174177
0 commit comments