Skip to content

Commit 129195a

Browse files
Floris van BreugelFloris van Breugel
authored andcommitted
derivatives of circular (wrapped) variables implemented with kalman approach (rtsdiff)
1 parent 37b39af commit 129195a

3 files changed

Lines changed: 445 additions & 7 deletions

File tree

notebooks/7_circular_variables.ipynb

Lines changed: 365 additions & 0 deletions
Large diffs are not rendered by default.

pynumdiff/kalman_smooth.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
from warnings import warn
33
import numpy as np
44
from scipy.linalg import expm, sqrtm
5+
from collections.abc import Iterable
56
try: import cvxpy
67
except ImportError: pass
78

8-
from pynumdiff.utils.utility import huber_const
9+
from pynumdiff.utils.utility import huber_const, wrap_angle, ensure_iterable
910

1011

11-
def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True):
12+
def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True,
13+
circular_vars=None, circular_units='rad'):
1214
"""Run the forward pass of a Kalman filter. Expects discrete-time matrices; use :func:`scipy.linalg.expm`
1315
in the caller to convert from continuous time if needed.
1416
@@ -24,6 +26,10 @@ def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True):
2426
:param np.array u: optional control inputs, stacked in the direction of axis 0
2527
:param bool save_P: whether to save history of error covariance and a priori state estimates, used with rts
2628
smoothing but nonstandard to compute for ordinary filtering
29+
:param bool or None circular_vars: bool indicating whether the measurement y is a circular (angular) variable
30+
that is wrapped. This will use a circular innovation calculation for the Kalman filter. The smoothed result
31+
will be returned in an unwrapped form.
32+
:param string circular_units: 'rad' or 'deg' to specify whether wrapping is in degrees or radians.
2733
2834
:return: - **xhat_pre** (np.array) -- a priori estimates of xhat, with axis=0 the batch dimension, so xhat[n] gets the nth step
2935
- **xhat_post** (np.array) -- a posteriori estimates of xhat
@@ -57,7 +63,10 @@ def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True):
5763
P = P_.copy()
5864
if not np.isnan(y[n]): # handle missing data
5965
K = P_ @ C.T @ np.linalg.inv(C @ P_ @ C.T + R)
60-
xhat += K @ (y[n] - C @ xhat_)
66+
innovation = y[n] - C @ xhat_
67+
if circular_vars is not None and circular_vars is not False:
68+
innovation[0] = wrap_angle(innovation[0], circular_units)
69+
xhat += K @ innovation
6170
P -= K @ C @ P_
6271
# the [n]th index of pre variables holds _{n|n-1} info; the [n]th index of post variables holds _{n|n} info
6372
xhat_post[n] = xhat
@@ -94,7 +103,8 @@ def rts_smooth(A, xhat_pre, xhat_post, P_pre, P_post, compute_P_smooth=True):
94103
return xhat_smooth if not compute_P_smooth else (xhat_smooth, P_smooth)
95104

96105

97-
def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward, axis=0):
106+
def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward, axis=0,
107+
circular_vars=None, circular_units='rad'):
98108
"""Perform Rauch-Tung-Striebel smoothing with a naive constant derivative model. Makes use of :code:`kalman_filter`
99109
and :code:`rts_smooth`, which are made public. :code:`constant_X` methods in this module call this function.
100110
@@ -109,13 +119,24 @@ def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward, axis=0):
109119
:param bool forwardbackward: indicates whether to run smoother forwards and backwards
110120
(usually achieves better estimate at end points)
111121
:param int axis: data dimension along which differentiation is performed
122+
:param list[bool] circular_vars: set list element to bool for any axes of x that are a circular (angular) variable
123+
that is wrapped. This will use a circular innovation calculation for the Kalman filter. The smoothed result
124+
will be returned in an unwrapped form.
125+
:param string circular_units: 'rad' or 'deg' to specify whether wrapping is in degrees or radians.
112126
113127
:return: - **x_hat** (np.array) -- estimated (smoothed) x, same shape as input :code:`x`
114128
- **dxdt_hat** (np.array) -- estimated derivative of x, same shape as input :code:`x`
115129
"""
116130
N = x.shape[axis]
117131
if not np.isscalar(dt_or_t) and N != len(dt_or_t):
118132
raise ValueError("If `dt_or_t` is given as array-like, must have same length as x along `axis`.")
133+
134+
# turn circular_vars into something with the same shape as the number of differentiated axes in x
135+
if len(x.shape) > 1:
136+
n = int(np.prod(x.shape[:axis] + x.shape[axis+1:]))
137+
circular_vars = ensure_iterable(circular_vars, n)
138+
else:
139+
circular_vars = ensure_iterable(circular_vars, 1)
119140

120141
q = 10**int(log_qr_ratio/2) # even-ish split of the powers across 0
121142
r = q/(10**log_qr_ratio)
@@ -143,19 +164,21 @@ def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward, axis=0):
143164
x_hat = np.empty_like(x); dxdt_hat = np.empty_like(x)
144165
if forwardbackward: w = np.linspace(0, 1, N) # weights used to combine forward and backward results
145166

146-
for vec_idx in np.ndindex(x.shape[:axis] + x.shape[axis+1:]): # works properly for 1D case too
167+
for i, vec_idx in enumerate(np.ndindex(x.shape[:axis] + x.shape[axis+1:])): # works properly for 1D case too
147168
s = vec_idx[:axis] + (slice(None),) + vec_idx[axis:] # for indexing the vector we wish to differentiate
148169
xhat0 = np.zeros(order+1); xhat0[0] = x[s][0] if not np.isnan(x[s][0]) else 0 # The first estimate is the first seen state. See #110
149170

150-
xhat_pre, xhat_post, P_pre, P_post = kalman_filter(x[s], xhat0, P0, A_d, Q_d, C, R)
171+
xhat_pre, xhat_post, P_pre, P_post = kalman_filter(x[s], xhat0, P0, A_d, Q_d, C, R,
172+
circular_vars=circular_vars[i], circular_units=circular_units)
151173
xhat_smooth = rts_smooth(A_d, xhat_pre, xhat_post, P_pre, P_post, compute_P_smooth=False)
152174
x_hat[s] = xhat_smooth[:,0] # first dimension is time, so slice first and second states at all times
153175
dxdt_hat[s] = xhat_smooth[:,1]
154176

155177
if forwardbackward:
156178
xhat0[0] = x[s][-1] if not np.isnan(x[s][-1]) else 0
157179
xhat_pre, xhat_post, P_pre, P_post = kalman_filter(x[s][::-1], xhat0, P0, A_d_bwd,
158-
Q_d if Q_d.ndim == 2 else Q_d[::-1], C, R) # Use same Q matrices as before, because noise should still grow in reverse time
180+
Q_d if Q_d.ndim == 2 else Q_d[::-1], C, R,
181+
circular_vars=circular_vars[i], circular_units=circular_units) # Use same Q matrices as before, because noise should still grow in reverse time
159182
xhat_smooth = rts_smooth(A_d_bwd, xhat_pre, xhat_post, P_pre, P_post, compute_P_smooth=False)
160183

161184
x_hat[s] = x_hat[s] * w + xhat_smooth[:, 0][::-1] * (1-w)

pynumdiff/utils/utility.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,53 @@ def peakdet(x, delta, t=None):
200200
lookformax = True # now searching for a max
201201

202202
return np.array(maxtab), np.array(mintab)
203+
204+
205+
def wrap_angle(angle, units='rad', range='symmetric'):
206+
"""Wrap an angle to a specified range.
207+
208+
:param np.array angle: angular values
209+
:param string units: either 'rad' or 'deg'
210+
:param string range: either 'symmetric' for [-pi, pi] / [-180, 180],
211+
or 'positive' for [0, 2pi] / [0, 360]
212+
213+
:return: - **angle** -- the angular values wrapped as requested
214+
"""
215+
if units == 'rad':
216+
period = 2 * np.pi
217+
if range == 'symmetric':
218+
return (angle + np.pi) % period - np.pi
219+
elif range == 'positive':
220+
return angle % period
221+
else:
222+
raise ValueError(f"Invalid range '{range}'. Expected 'symmetric' or 'positive'.")
223+
224+
elif units == 'deg':
225+
period = 360.
226+
if range == 'symmetric':
227+
return (angle + 180.) % period - 180.
228+
elif range == 'positive':
229+
return angle % period
230+
else:
231+
raise ValueError(f"Invalid range '{range}'. Expected 'symmetric' or 'positive'.")
232+
233+
else:
234+
raise ValueError(f"Invalid units '{units}'. Expected 'rad' or 'deg'.")
235+
236+
237+
def ensure_iterable(v, length):
238+
"""Ensure v is a list of the specified length.
239+
240+
If v is not iterable (e.g. a scalar), it is broadcast
241+
into a list by repeating it `length` times. If it is already iterable,
242+
it is returned as-is.
243+
244+
:param v: a scalar or iterable
245+
:param int length: desired length of the output list when broadcasting a scalar
246+
:return: v repeated `length` times if scalar, otherwise unchanged
247+
248+
:return: - **v** -- list or iterable
249+
"""
250+
if not isinstance(v, Iterable):
251+
return [v] * length
252+
return v

0 commit comments

Comments
 (0)