Skip to content

Commit 67e6877

Browse files
committed
cleaning more
1 parent 71a082d commit 67e6877

4 files changed

Lines changed: 37 additions & 25 deletions

File tree

notebooks/7_circular_domain.ipynb

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,8 @@
8080
],
8181
"source": [
8282
"x_hat, dxdt_hat = pynumdiff.kalman_smooth.rtsdiff(x, dt, order=1, log_qr_ratio=5, axis=0, circular=False)\n",
83-
"x_hat_wrapped = (x_hat + np.pi) % (2*np.pi) - np.pi\n",
8483
"\n",
85-
"evaluate.plot(x, dt, x_hat_wrapped, dxdt_hat, x_truth, dxdt_truth);"
84+
"evaluate.plot(x, dt, x_hat, dxdt_hat, x_truth, dxdt_truth);"
8685
]
8786
},
8887
{
@@ -112,9 +111,8 @@
112111
],
113112
"source": [
114113
"x_hat, dxdt_hat = pynumdiff.kalman_smooth.rtsdiff(x, dt, order=1, log_qr_ratio=3, circular=True)\n",
115-
"x_hat_wrapped = (x_hat + np.pi) % (2*np.pi) - np.pi\n",
116114
"\n",
117-
"evaluate.plot(x, dt, x_hat_wrapped, dxdt_hat, x_truth, dxdt_truth);"
115+
"evaluate.plot(x, dt, x_hat, dxdt_hat, x_truth, dxdt_truth);"
118116
]
119117
}
120118
],

pynumdiff/kalman_smooth.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True, innovat
2424
:param np.array u: optional control inputs, stacked in the direction of axis 0
2525
:param bool save_P: whether to save history of error covariance and a priori state estimates, used with rts
2626
smoothing but nonstandard to compute for ordinary filtering
27-
:param callable innovation_fn: optional function :code:`(y_n, pred)` returning the innovation :code:`y_n - pred`,
28-
where :code:`pred = C @ xhat_`. When :code:`None` (default), standard subtraction is used. Use this to handle
29-
circular quantities, e.g. :code:`lambda y, pred: (y - pred + np.pi) % (2*np.pi) - np.pi` for angular measurements in radians.
27+
:param callable innovation_fn: optional function taking measurements and predicted measurements and returning the innovation.
28+
When :code:`None`, traditional subtraction is used. This is exposed to handle cases like wrapped domains, where alternative
29+
displacement measures may be more appropriate. See e.g. the function passed by :code:`rtsdiff` with :code:`circular=True`.
3030
3131
:return: - **xhat_pre** (np.array) -- a priori estimates of xhat, with axis=0 the batch dimension, so xhat[n] gets the nth step
3232
- **xhat_post** (np.array) -- a posteriori estimates of xhat
@@ -60,7 +60,7 @@ def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True, innovat
6060
P = P_.copy()
6161
if not np.isnan(y[n]): # handle missing data
6262
K = P_ @ C.T @ np.linalg.inv(C @ P_ @ C.T + R)
63-
innovation = innovation_fn(y[n], C @ xhat_) if innovation_fn is not None else y[n] - C @ xhat_
63+
innovation = y[n] - C @ xhat_ if innovation_fn is None else innovation_fn(y[n], C @ xhat_)
6464
xhat += K @ innovation
6565
P -= K @ C @ P_
6666
# the [n]th index of pre variables holds _{n|n-1} info; the [n]th index of post variables holds _{n|n} info
@@ -116,7 +116,8 @@ def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward=False, axis=0, circ
116116
:param bool circular: if :code:`True`, treat the measured quantity as a circular variable in radians, wrapping the
117117
innovation to :math:`[-\\pi, \\pi]`. The input :code:`x` must be in radians; convert degrees with :code:`np.deg2rad`.
118118
119-
:return: - **x_hat** (np.array) -- estimated (smoothed) x, same shape as input :code:`x`
119+
:return: - **x_hat** (np.array) -- estimated (smoothed) x, same shape as input :code:`x`.
120+
When :code:`circular=True`, wrapped to :math:`[-\\pi, \\pi]`.
120121
- **dxdt_hat** (np.array) -- estimated derivative of x, same shape as input :code:`x`
121122
"""
122123
N = x.shape[axis]
@@ -153,22 +154,24 @@ def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward=False, axis=0, circ
153154

154155
for vec_idx in np.ndindex(x.shape[:axis] + x.shape[axis+1:]): # works properly for 1D case too
155156
s = vec_idx[:axis] + (slice(None),) + vec_idx[axis:] # for indexing the vector we wish to differentiate
156-
xhat0 = np.zeros(order+1); xhat0[0] = x[s][0] if not np.isnan(x[s][0]) else 0 # The first estimate is the first seen state. See #110
157+
xhat0 = np.zeros(order+1)
158+
if not np.isnan(x[s][0]): xhat0[0] = x[s][0] # The first estimate is the first seen state. See #110
157159

158160
xhat_pre, xhat_post, P_pre, P_post = kalman_filter(x[s], xhat0, P0, A_d, Q_d, C, R, innovation_fn=innovation_fn)
159161
xhat_smooth = rts_smooth(A_d, xhat_pre, xhat_post, P_pre, P_post, compute_P_smooth=False)
160162
x_hat[s] = xhat_smooth[:,0] # first dimension is time, so slice first and second states at all times
161163
dxdt_hat[s] = xhat_smooth[:,1]
162164

163165
if forwardbackward:
164-
xhat0[0] = x[s][-1] if not np.isnan(x[s][-1]) else 0
166+
if not np.isnan(x[s][-1]): xhat0[0] = x[s][-1]
165167
xhat_pre, xhat_post, P_pre, P_post = kalman_filter(x[s][::-1], xhat0, P0, A_d_bwd,
166168
Q_d if Q_d.ndim == 2 else Q_d[::-1], C, R, innovation_fn=innovation_fn) # Use same Q matrices as before, because noise should still grow in reverse time
167169
xhat_smooth = rts_smooth(A_d_bwd, xhat_pre, xhat_post, P_pre, P_post, compute_P_smooth=False)
168170

169171
x_hat[s] = x_hat[s] * w + xhat_smooth[:, 0][::-1] * (1-w)
170172
dxdt_hat[s] = dxdt_hat[s] * w + xhat_smooth[:, 1][::-1] * (1-w)
171173

174+
if circular: x_hat = (x_hat + np.pi) % (2*np.pi) - np.pi # wrap output to match the input domain
172175
return x_hat, dxdt_hat
173176

174177

pynumdiff/tests/test_diff_methods.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -405,22 +405,34 @@ def test_multidimensionality(multidim_method_and_params, request):
405405
legend = ax3.legend(bbox_to_anchor=(0.7, 0.8)); legend.legend_handles[0].set_facecolor(pyplot.cm.viridis(0.6))
406406
fig.suptitle(f'{diff_method.__name__}', fontsize=16)
407407

408-
def test_circular_rtsdiff():
408+
def test_circular_rtsdiff(request):
409409
"""Ensure rtsdiff with circular=True correctly differentiates a wrapping angle signal in radians"""
410-
np.random.seed(42)
411-
N = 200
412-
dt_circ = 0.05
413-
t_circ = np.arange(N) * dt_circ
414-
true_dtheta = 2.0 # constant angular velocity in rad/s
415-
theta_true = true_dtheta * t_circ # linearly increasing angle, crosses 2*pi boundaries
416-
theta_noisy = np.angle(np.exp(1j * (theta_true + 0.1 * np.random.randn(N)))) # wrap to [-pi, pi] and add noise
410+
dtheta = 5 # constant angular velocity in rad/s
411+
theta = dtheta * t # linearly increasing angle, crosses 2*pi boundaries
412+
theta_noisy = np.angle(np.exp(1j * (theta + noise))) # add noise and wrap to [-pi, pi]
417413

418-
_, dxdt_hat = rtsdiff(theta_noisy, dt_circ, order=1, log_qr_ratio=1, forwardbackward=True, circular=True)
414+
theta_hat_naive, dxdt_hat_naive = rtsdiff(theta_noisy, dt, order=1, log_qr_ratio=1, circular=False)
415+
theta_hat, dxdt_hat = rtsdiff(theta_noisy, dt, order=1, log_qr_ratio=1, circular=True)
416+
417+
naive_rmse = np.sqrt(np.mean((dxdt_hat_naive - dtheta)**2))
418+
wrapped_rmse = np.sqrt(np.mean((dxdt_hat - dtheta)**2))
419+
assert wrapped_rmse < naive_rmse
419420

420-
# The interior of the signal (away from endpoints) should recover the true angular velocity well
421-
interior = slice(10, N-10)
422-
rmse = np.sqrt(np.mean((dxdt_hat[interior] - true_dtheta)**2))
423-
assert rmse < 0.5, f"RMSE of angular velocity estimate too large: {rmse:.3f} rad/s"
421+
if request.config.getoption("--plot"):
422+
from matplotlib import pyplot
423+
fig, (ax1, ax2) = pyplot.subplots(2, 1, figsize=(10, 6), sharex=True)
424+
ax1.plot(t, theta_noisy, 'k+', label=r'$\theta$ noisy (wrapped)')
425+
ax1.plot(t, theta_hat_naive, 'C1--', label=r'$\hat{\theta}$ with circular=False')
426+
ax1.plot(t, theta_hat, 'C0', label=r'$\hat{\theta}$ with circular=True')
427+
ax1.set_ylabel(r'$\theta$ (rad)')
428+
ax1.legend()
429+
ax2.axhline(dtheta, color='C2', xmin=0.045, xmax=0.955, label=r'true $\dot{\theta}$')
430+
ax2.plot(t, dxdt_hat_naive, 'C1--', label=r'$\hat{\dot{\theta}}$ circular=False')
431+
ax2.plot(t, dxdt_hat, 'C0', label=r'$\hat{\dot{\theta}}$ circular=True')
432+
ax2.set_ylabel(r'$\dot{\theta}$ (rad/time)')
433+
ax2.set_xlabel('t')
434+
ax2.legend()
435+
fig.suptitle('rtsdiff with circular domain', fontsize=16)
424436

425437

426438
# List of methods that can handle missing values

pynumdiff/utils/utility.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from scipy.ndimage import convolve1d
88

99

10-
1110
def huber_const(M):
1211
"""Scale that makes :code:`sum(huber())` interpolate :math:`\\sqrt{2}\\|\\cdot\\|_1` and :math:`\\frac{1}{2}\\|\\cdot\\|_2^2`,
1312
from https://jmlr.org/papers/volume14/aravkin13a/aravkin13a.pdf, with correction for missing sqrt. Here :code:`huber`

0 commit comments

Comments
 (0)