Skip to content

Commit afe0495

Browse files
pavelkomarovclaude
andcommitted
Add axis parameter to robustdiff for multidimensional support (#76)
Uses np.moveaxis + np.ndindex pattern consistent with rtsdiff. Matrices A_d, Q_d, C, R are precomputed once outside the loop and shared across all dimensions, so the expensive matrix exponential and condition checks are not repeated per dimension. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 4755e00 commit afe0495

2 files changed

Lines changed: 27 additions & 23 deletions

File tree

pynumdiff/kalman_smooth.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def constant_jerk(x, dt, params=None, options=None, r=None, q=None, forwardbackw
254254
return rtsdiff(x, dt, 3, np.log10(q/r), forwardbackward)
255255

256256

257-
def robustdiff(x, dt_or_t, order, log_q, log_r, proc_huberM=6, meas_huberM=0):
257+
def robustdiff(x, dt_or_t, order, log_q, log_r, proc_huberM=6, meas_huberM=0, axis=0):
258258
"""Perform outlier-robust differentiation by solving the Maximum A Priori optimization problem:
259259
:math:`\\text{argmin}_{\\{x_n\\}} \\sum_{n=0}^{N-1} V(R^{-1/2}(y_n - C x_n)) + \\sum_{n=1}^{N-1} J(Q_{n-1}^{-1/2}(x_n - A_{n-1} x_{n-1}))`,
260260
where :math:`A,Q,C,R` come from an assumed constant derivative model and :math:`V,J` are the :math:`\\ell_1` norm or Huber
@@ -287,38 +287,40 @@ def robustdiff(x, dt_or_t, order, log_q, log_r, proc_huberM=6, meas_huberM=0):
287287
:param float log_r: base 10 logarithm of measurement noise variance, so :code:`r = 10**log_r`
288288
:param float proc_huberM: quadratic-to-linear transition point for process loss
289289
:param float meas_huberM: quadratic-to-linear transition point for measurement loss
290+
:param int axis: data dimension along which differentiation is performed
290291
291-
:return: - **x_hat** (np.array) -- estimated (smoothed) x
292-
- **dxdt_hat** (np.array) -- estimated derivative of x
292+
:return: - **x_hat** (np.array) -- estimated (smoothed) x, same shape as input :code:`x`
293+
- **dxdt_hat** (np.array) -- estimated derivative of x, same shape as input :code:`x`
293294
"""
294-
equispaced = np.isscalar(dt_or_t)
295-
if not equispaced and len(x) != len(dt_or_t):
296-
raise ValueError("If `dt_or_t` is given as array-like, must have same length as `x`.")
295+
x = np.moveaxis(np.asarray(x), axis, 0)
296+
N = x.shape[0]
297+
if not np.isscalar(dt_or_t) and N != len(dt_or_t):
298+
raise ValueError("If `dt_or_t` is given as array-like, must have same length as `x` along `axis`.")
297299

298300
A_c = np.diag(np.ones(order), 1) # continuous-time A just has 1s on the first diagonal (where 0th is main diagonal)
299301
Q_c = np.zeros(A_c.shape); Q_c[-1,-1] = 10**log_q # continuous-time uncertainty around the last derivative
300302
C = np.zeros((1, order+1)); C[0,0] = 1 # we measure only y = noisy x
301303
R = np.array([[10**log_r]]) # 1 observed state, so this is 1x1
302304
M = np.block([[A_c, Q_c], [np.zeros(A_c.shape), -A_c.T]]) # exponentiate per step
303305

304-
if equispaced:
305-
# convert to discrete time using matrix exponential
306-
eM = expm(M * dt_or_t) # Note this could handle variable dt, similar to rtsdiff
307-
A_d = eM[:order+1, :order+1]
308-
Q_d = eM[:order+1, order+1:] @ A_d.T
309-
if np.linalg.cond(Q_d) > 1e12: Q_d += np.eye(order + 1)*1e-12 # for numerical stability with convex solver. Doesn't change answers appreciably (or at all).
310-
else: # support variable step size for this function
311-
A_d = np.empty((len(x)-1, order+1, order+1)) # stack all the evolution matrices
312-
Q_d = np.empty((len(x)-1, order+1, order+1))
313-
314-
for n,dt in enumerate(np.diff(dt_or_t)): # for each variable time step
306+
if np.isscalar(dt_or_t):
307+
eM = expm(M * dt_or_t)
308+
A_d = eM[:order+1, :order+1]; Q_d = eM[:order+1, order+1:] @ A_d.T
309+
if np.linalg.cond(Q_d) > 1e12: Q_d += np.eye(order + 1)*1e-12 # for numerical stability with convex solver
310+
else:
311+
A_d = np.empty((N-1, order+1, order+1)); Q_d = np.empty_like(A_d)
312+
for n, dt in enumerate(np.diff(dt_or_t)):
315313
eM = expm(M * dt)
316-
A_d[n] = eM[:order+1, :order+1] # extract discrete time A matrix
317-
Q_d[n] = eM[:order+1, order+1:] @ A_d[n].T # extract discrete time Q matrix
314+
A_d[n] = eM[:order+1, :order+1]; Q_d[n] = eM[:order+1, order+1:] @ A_d[n].T
318315
if np.linalg.cond(Q_d[n]) > 1e12: Q_d[n] += np.eye(order + 1)*1e-12
319316

320-
x_states = convex_smooth(x, A_d, Q_d, C, R, proc_huberM=proc_huberM, meas_huberM=meas_huberM) # outsource solution of the convex optimization problem
321-
return x_states[:,0], x_states[:,1]
317+
x_hat = np.empty_like(x); dxdt_hat = np.empty_like(x)
318+
for idx in np.ndindex(x.shape[1:]):
319+
s = (slice(None),) + idx
320+
x_states = convex_smooth(x[s], A_d, Q_d, C, R, proc_huberM=proc_huberM, meas_huberM=meas_huberM)
321+
x_hat[s] = x_states[:, 0]; dxdt_hat[s] = x_states[:, 1]
322+
323+
return np.moveaxis(x_hat, 0, axis), np.moveaxis(dxdt_hat, 0, axis)
322324

323325

324326
def convex_smooth(y, A, Q, C, R, B=None, u=None, proc_huberM=6, meas_huberM=0):

pynumdiff/tests/test_diff_methods.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,8 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
326326
(butterdiff, {'filter_order': 3, 'cutoff_freq': 1 - 1e-6}),
327327
(finitediff, {}),
328328
(savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3}),
329-
(rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True})
329+
(rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True}),
330+
(robustdiff, {'order':2, 'log_q':7, 'log_r':2}),
330331
]
331332

332333
# Similar to the error_bounds table, index by method first. But then we test against only one 2D function,
@@ -338,7 +339,8 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
338339
butterdiff: [(0, -1), (1, -1)],
339340
finitediff: [(0, -1), (1, -1)],
340341
savgoldiff: [(0, -1), (1, 1)],
341-
rtsdiff: [(1, -1), (1, 0)]
342+
rtsdiff: [(1, -1), (1, 0)],
343+
robustdiff: [(2, 1), (3, 3)],
342344
}
343345

344346
@mark.parametrize("multidim_method_and_params", multidim_methods_and_params)

0 commit comments

Comments
 (0)