Skip to content

Commit fafb87c

Browse files
authored
Merge pull request #192 from Ashton-Graves/robustdiff-variable-step
Robustdiff variable step
2 parents c1fe325 + d8a6c82 commit fafb87c

5 files changed

Lines changed: 106 additions & 125 deletions

File tree

notebooks/1_basic_tutorial.ipynb

Lines changed: 56 additions & 98 deletions
Large diffs are not rendered by default.

pynumdiff/kalman_smooth.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def constant_jerk(x, dt, params=None, options=None, r=None, q=None, forwardbackw
259259
return rtsdiff(x, dt, 3, np.log10(q/r), forwardbackward)
260260

261261

262-
def robustdiff(x, dt, order, log_q, log_r, proc_huberM=6, meas_huberM=0):
262+
def robustdiff(x, dt_or_t, order, log_q, log_r, proc_huberM=6, meas_huberM=0):
263263
"""Perform outlier-robust differentiation by solving the Maximum A Posteriori optimization problem:
264264
:math:`\\text{argmin}_{\\{x_n\\}} \\sum_{n=0}^{N-1} V(R^{-1/2}(y_n - C x_n)) + \\sum_{n=1}^{N-1} J(Q^{-1/2}(x_n - A x_{n-1}))`,
265265
where :math:`A,Q,C,R` come from an assumed constant derivative model and :math:`V,J` are the :math:`\\ell_1` norm or Huber
@@ -295,16 +295,31 @@ def robustdiff(x, dt, order, log_q, log_r, proc_huberM=6, meas_huberM=0):
295295
:return: - **x_hat** (np.array) -- estimated (smoothed) x
296296
- **dxdt_hat** (np.array) -- estimated derivative of x
297297
"""
298+
equispaced = np.isscalar(dt_or_t)
299+
if not equispaced and len(x) != len(dt_or_t):
300+
raise ValueError("If `dt_or_t` is given as array-like, must have same length as `x`.")
301+
298302
A_c = np.diag(np.ones(order), 1) # continuous-time A just has 1s on the first diagonal (where 0th is main diagonal)
299303
Q_c = np.zeros(A_c.shape); Q_c[-1,-1] = 10**log_q # continuous-time uncertainty around the last derivative
300304
C = np.zeros((1, order+1)); C[0,0] = 1 # we measure only y = noisy x
301305
R = np.array([[10**log_r]]) # 1 observed state, so this is 1x1
306+
M = np.block([[A_c, Q_c], [np.zeros(A_c.shape), -A_c.T]]) # exponentiate per step
302307

303-
# convert to discrete time using matrix exponential
304-
eM = expm(np.block([[A_c, Q_c], [np.zeros(A_c.shape), -A_c.T]]) * dt) # Note this could handle variable dt, similar to rtsdiff
305-
A_d = eM[:order+1, :order+1]
306-
Q_d = eM[:order+1, order+1:] @ A_d.T
307-
if np.linalg.cond(Q_d) > 1e12: Q_d += np.eye(order + 1)*1e-12 # for numerical stability with convex solver. Doesn't change answers appreciably (or at all).
308+
if equispaced:
309+
# convert to discrete time using matrix exponential
310+
eM = expm(M * dt_or_t) # constant step, so a single matrix exponential serves every transition
311+
A_d = eM[:order+1, :order+1]
312+
Q_d = eM[:order+1, order+1:] @ A_d.T
313+
if np.linalg.cond(Q_d) > 1e12: Q_d += np.eye(order + 1)*1e-12 # for numerical stability with convex solver. Doesn't change answers appreciably (or at all).
314+
else: # support variable step size for this function
315+
A_d = np.empty((len(x)-1, order+1, order+1)) # stack all the evolution matrices
316+
Q_d = np.empty((len(x)-1, order+1, order+1))
317+
318+
for i, dt in enumerate(np.diff(dt_or_t)): # for each variable time step
319+
eM = expm(M * dt)
320+
A_d[i] = eM[:order+1, :order+1] # extract discrete time A matrix
321+
Q_d[i] = eM[:order+1, order+1:] @ A_d[i].T # extract discrete time Q matrix
322+
if np.linalg.cond(Q_d[i]) > 1e12: Q_d[i] += np.eye(order + 1)*1e-12
308323

309324
x_states = convex_smooth(x, A_d, Q_d, C, R, proc_huberM=proc_huberM, meas_huberM=meas_huberM) # outsource solution of the convex optimization problem
310325
return x_states[:,0], x_states[:,1]
@@ -327,12 +342,20 @@ def convex_smooth(y, A, Q, C, R, B=None, u=None, proc_huberM=6, meas_huberM=0):
327342
:return: (np.array) -- state estimates (N x state_dim)
328343
"""
329344
N = len(y)
330-
x_states = cvxpy.Variable((A.shape[0], N)) # each column is [position, velocity, acceleration, ...] at step n
345+
state_dim = A.shape[-1]
346+
x_states = cvxpy.Variable((state_dim, N)) # each column is [position, velocity, acceleration, ...] at step n
331347
control = isinstance(B, np.ndarray) and isinstance(u, np.ndarray) # whether there is a control input
332348

333-
# It is extremely important to run time that CVXPY expressions be in vectorized form
334-
proc_resids = np.linalg.inv(sqrtm(Q)) @ (x_states[:,1:] - A @ x_states[:,:-1] - (0 if not control else B @ u[1:].T)) # all Q^(-1/2)(x_n - (A x_{n-1} + B u_n))
335-
meas_resids = np.linalg.inv(sqrtm(R)) @ (y.reshape(C.shape[0],-1) - C @ x_states) # all R^(-1/2)(y_n - C x_n)
349+
if A.ndim == 3: # It is extremely important to runtime that CVXPY expressions be in vectorized form
350+
Ax = cvxpy.einsum('nij,jn->in', A, x_states[:, :-1]) # multiply each A matrix by the corresponding x_states at that time step
351+
Q_inv_sqrts = np.array([np.linalg.inv(sqrtm(Q[n])) for n in range(N-1)]) # precompute Q^(-1/2) for each time step
352+
proc_resids = cvxpy.einsum('nij,jn->in', Q_inv_sqrts, x_states[:,1:] - Ax - (0 if not control else B @ u[1:].T))
353+
else: # all Q^(-1/2)(x_n - (A x_{n-1} + B u_n))
354+
proc_resids = np.linalg.inv(sqrtm(Q)) @ (x_states[:,1:] - A @ x_states[:,:-1] - (0 if not control else B @ u[1:].T))
355+
356+
obs = ~np.isnan(y) # boolean mask of non-NaN observations
357+
meas_resids = np.linalg.inv(sqrtm(R)) @ (y[obs].reshape(C.shape[0],-1) - C @ x_states[:,obs]) # all R^(-1/2)(y_n - C x_n)
358+
336359
# Process terms: sum of J(proc_resids)
337360
objective = 0.5*cvxpy.sum_squares(proc_resids) if proc_huberM == float('inf') \
338361
else np.sqrt(2)*cvxpy.sum(cvxpy.abs(proc_resids)) if proc_huberM < 1e-3 \
@@ -345,8 +368,8 @@ def convex_smooth(y, A, Q, C, R, B=None, u=None, proc_huberM=6, meas_huberM=0):
345368
# function https://www.cvxpy.org/api_reference/cvxpy.atoms.elementwise.html#huber, so correct with a factor of 0.5.
346369

347370
problem = cvxpy.Problem(cvxpy.Minimize(objective))
348-
try: problem.solve(solver=cvxpy.CLARABEL)
371+
try: problem.solve(solver=cvxpy.CLARABEL, canon_backend=cvxpy.SCIPY_CANON_BACKEND)
349372
except cvxpy.error.SolverError: pass # Could try another solver here, like SCS, but slows things down
350373

351-
if x_states.value is None: return np.full((N, A.shape[0]), np.nan) # There can be solver failure, even without error
374+
if x_states.value is None: return np.full((N, state_dim), np.nan) # There can be solver failure, even without error
352375
return x_states.value.T

pynumdiff/tests/test_diff_methods.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
def iterated_second_order(*args, **kwargs): return second_order(*args, **kwargs)
1414
def iterated_fourth_order(*args, **kwargs): return fourth_order(*args, **kwargs)
1515
def spline_irreg_step(*args, **kwargs): return splinediff(*args, **kwargs)
16+
def robust_irreg_step(*args, **kwargs): return robustdiff(*args, **kwargs)
1617
def polydiff_irreg_step(*args, **kwargs): return polydiff(*args, **kwargs)
17-
irreg_list = [spline_irreg_step, polydiff_irreg_step, rbfdiff, rtsdiff] # methods to test with irregular time steps
18+
irreg_list = [spline_irreg_step, polydiff_irreg_step, rbfdiff, rtsdiff, robust_irreg_step] # methods to test with irregular time steps
1819

1920
dt = 0.1
2021
t = np.linspace(0, 3, 31) # sample locations, including the endpoint
@@ -55,6 +56,7 @@ def polydiff_irreg_step(*args, **kwargs): return polydiff(*args, **kwargs)
5556
(constant_jerk, {'r':1e-4, 'q':1e5}), (constant_jerk, [1e-4, 1e5]),
5657
(rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True}),
5758
(robustdiff, {'order':3, 'log_q':7, 'log_r':2}),
59+
(robust_irreg_step, {'order':3, 'log_q':7, 'log_r':2}),
5860
(velocity, {'gamma':0.5}), (velocity, [0.5]),
5961
(acceleration, {'gamma':1}), (acceleration, [1]),
6062
(jerk, {'gamma':10}), (jerk, [10]),
@@ -231,6 +233,12 @@ def polydiff_irreg_step(*args, **kwargs): return polydiff(*args, **kwargs)
231233
[(-7, -7), (-2, -2), (0, -1), (1, 1)],
232234
[(0, 0), (2, 2), (0, 0), (2, 2)],
233235
[(1, 1), (3, 3), (1, 1), (3, 3)]],
236+
robust_irreg_step: [[(-15, -15), (-13, -14), (0, -1), (1, 1)],
237+
[(-14, -14), (-13, -13), (0, -1), (1, 1)],
238+
[(-14, -14), (-13, -13), (0, -1), (1, 1)],
239+
[(-8, -8), (-2, -2), (0, -1), (1, 1)],
240+
[(0, 0), (2, 2), (0, 0), (2, 2)],
241+
[(1, 1), (3, 3), (1, 1), (3, 3)]],
234242
lineardiff: [[(-3, -4), (-3, -3), (0, -1), (1, 0)],
235243
[(-1, -2), (0, 0), (0, -1), (1, 0)],
236244
[(-1, -1), (0, 0), (0, -1), (1, 1)],

pynumdiff/utils/evaluate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
import matplotlib.pyplot as plt
44
from scipy import stats
5+
from scipy.special import huber
56

67
from pynumdiff.utils import utility
78

@@ -95,7 +96,7 @@ def robust_rme(u, v, padding=0, M=6):
9596
s = slice(padding, len(u)-padding) # slice out data we want to measure
9697

9798
sigma = stats.median_abs_deviation(u[s] - v[s], scale='normal') # M is in units of this robust scatter metric
98-
return np.sqrt(2*np.mean(utility.huber(u[s] - v[s], M*sigma)))
99+
return np.sqrt(2*np.mean(huber(M*sigma, u[s] - v[s])))
99100

100101

101102
def rmse(u, v, padding=0):

pynumdiff/utils/utility.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,15 @@
22
import numpy as np
33
from scipy.integrate import cumulative_trapezoid
44
from scipy.optimize import minimize
5+
from scipy.special import huber
56
from scipy.stats import median_abs_deviation, norm
67
from scipy.ndimage import convolve1d
78

89

9-
def huber(x, M):
10-
"""Huber loss function, for outlier-robust applications,
11-
`see here <https://www.cvxpy.org/api_reference/cvxpy.atoms.elementwise.html#huber>`_
12-
13-
:param np.array[float] x: data points on which to evaluate the Huber function pointwise
14-
:param float M: where the loss turns from quadratic to linear
15-
:return: (np.array[float]) -- pointwise evaluations of the Huber function
16-
"""
17-
absx = np.abs(x)
18-
return np.where(absx <= M, 0.5*x**2, M*(absx - 0.5*M))
19-
2010
def huber_const(M):
2111
"""Scale that makes :code:`sum(huber())` interpolate :math:`\\sqrt{2}\\|\\cdot\\|_1` and :math:`\\frac{1}{2}\\|\\cdot\\|_2^2`,
22-
from https://jmlr.org/papers/volume14/aravkin13a/aravkin13a.pdf, with correction for missing sqrt
12+
from https://jmlr.org/papers/volume14/aravkin13a/aravkin13a.pdf, with correction for missing sqrt. Here :code:`huber`
13+
refers to `scipy.special.huber <https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.huber.html>`_.
2314
2415
:param float M: Huber parameter, where the function turns from quadratic to linear
2516
:return: (float) -- appropriate scale factor to normalize the Huber function
@@ -65,7 +56,7 @@ def estimate_integration_constant(x, x_hat, M=6, axis=0):
6556
elif M < 1e-3: # small M looks like l1 loss, and Huber gets too flat to work well
6657
return np.median(x - x_hat, axis=axis).reshape(s) # Solves the l1 distance minimization, argmin_c ||x_hat + c - x||_1
6758
else:
68-
return minimize(lambda c: np.sum(huber(x_hat + c.reshape(s) - x, M*sigma)), # fn to minimize in 1st argument
59+
return minimize(lambda c: np.sum(huber(M*sigma, x_hat + c.reshape(s) - x)), # fn to minimize in 1st argument
6960
np.zeros(np.prod(s)), method='SLSQP').x.reshape(s) # initial guess is zeros; vector result must be reshaped
7061

7162

0 commit comments

Comments
 (0)