Skip to content

Commit a588eac

Browse files
committed
fix ruff
1 parent 9bb3b90 commit a588eac

5 files changed

Lines changed: 63 additions & 76 deletions

File tree

src/dynaris/core/nonlinear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from __future__ import annotations
44

5+
from collections.abc import Callable
56
from dataclasses import dataclass
6-
from typing import Callable
77

88
import jax
99
import jax.numpy as jnp

src/dynaris/filters/ekf.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def predict(state: GaussianState, model: NonlinearSSM) -> GaussianState:
5050
R_t = F_t @ C_{t-1} @ F_t' + Q
5151
"""
5252
mean = model.f(state.mean)
53-
F_jac = jax.jacfwd(model.f)(state.mean) # (n, n)
54-
cov = F_jac @ state.cov @ F_jac.T + model.Q
53+
f_jac = jax.jacfwd(model.f)(state.mean) # (n, n)
54+
cov = f_jac @ state.cov @ f_jac.T + model.Q
5555
return GaussianState(mean=mean, cov=cov)
5656

5757

@@ -70,22 +70,22 @@ def update(
7070
"""
7171
y = observation
7272
y_pred = model.h(predicted.mean) # (m,)
73-
H_jac = jax.jacfwd(model.h)(predicted.mean) # (m, n)
73+
h_jac = jax.jacfwd(model.h)(predicted.mean) # (m, n)
7474

7575
e = y - y_pred # innovation (m,)
76-
S = H_jac @ predicted.cov @ H_jac.T + model.R # innovation covariance (m, m)
76+
s = h_jac @ predicted.cov @ h_jac.T + model.R # innovation covariance (m, m)
7777

7878
# Kalman gain: K = P @ H' @ S^{-1}
79-
K = jnp.linalg.solve(S.T, (predicted.cov @ H_jac.T).T).T # (n, m)
79+
k_gain = jnp.linalg.solve(s.T, (predicted.cov @ h_jac.T).T).T # (n, m)
8080

81-
filtered_mean = predicted.mean + K @ e
81+
filtered_mean = predicted.mean + k_gain @ e
8282
identity = jnp.eye(predicted.mean.shape[-1])
83-
filtered_cov = (identity - K @ H_jac) @ predicted.cov
83+
filtered_cov = (identity - k_gain @ h_jac) @ predicted.cov
8484

8585
# Log-likelihood: log N(e; 0, S)
8686
m = observation.shape[-1]
87-
log_det = jnp.linalg.slogdet(S)[1]
88-
mahal = e @ jnp.linalg.solve(S, e)
87+
log_det = jnp.linalg.slogdet(s)[1]
88+
mahal = e @ jnp.linalg.solve(s, e)
8989
ll = -0.5 * (m * jnp.log(2.0 * jnp.pi) + log_det + mahal)
9090

9191
# Handle missing observations: if any element is NaN, skip update
@@ -181,9 +181,7 @@ def h(x):
181181
log_likelihood=jnp.array(0.0),
182182
)
183183

184-
def _scan_step(
185-
carry: _ScanCarry, obs: Array
186-
) -> tuple[_ScanCarry, _ScanOutput]:
184+
def _scan_step(carry: _ScanCarry, obs: Array) -> tuple[_ScanCarry, _ScanOutput]:
187185
predicted = predict(carry.filtered, model)
188186
filtered, ll = update(predicted, obs, model)
189187
new_carry = _ScanCarry(

src/dynaris/filters/ukf.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,17 @@ def sigma_points(state: GaussianState, lam: Array) -> Array:
7979
"""
8080
n = state.mean.shape[0]
8181
scaled_cov = (n + lam) * state.cov
82-
L = jnp.linalg.cholesky(scaled_cov) # (n, n)
82+
chol = jnp.linalg.cholesky(scaled_cov) # (n, n)
8383

8484
# Build sigma points: [mean, mean + L_i, mean - L_i]
85-
offsets = jnp.concatenate([
86-
jnp.zeros((1, n)),
87-
L, # rows of L as positive offsets
88-
-L, # rows of L as negative offsets
89-
], axis=0) # (2n+1, n)
85+
offsets = jnp.concatenate(
86+
[
87+
jnp.zeros((1, n)),
88+
chol, # rows of L as positive offsets
89+
-chol, # rows of L as negative offsets
90+
],
91+
axis=0,
92+
) # (2n+1, n)
9093

9194
return state.mean[None, :] + offsets
9295

@@ -162,29 +165,29 @@ def update(
162165
# Predicted observation mean
163166
y_pred = jnp.sum(weights.wm[:, None] * pts_obs, axis=0) # (m,)
164167

165-
# Innovation covariance S = sum wc * (y_diff)(y_diff)' + R
168+
# Innovation covariance s = sum wc * (y_diff)(y_diff)' + R
166169
y_diff = pts_obs - y_pred[None, :] # (2n+1, m)
167-
S = jnp.sum(weights.wc[:, None, None] * (y_diff[:, :, None] * y_diff[:, None, :]), axis=0)
168-
S = S + model.R # (m, m)
170+
s = jnp.sum(weights.wc[:, None, None] * (y_diff[:, :, None] * y_diff[:, None, :]), axis=0)
171+
s = s + model.R # (m, m)
169172

170-
# Cross-covariance P_xy = sum wc * (x_diff)(y_diff)'
173+
# Cross-covariance p_xy = sum wc * (x_diff)(y_diff)'
171174
x_diff = pts - predicted.mean[None, :] # (2n+1, n)
172-
P_xy = jnp.sum(weights.wc[:, None, None] * (x_diff[:, :, None] * y_diff[:, None, :]), axis=0)
175+
p_xy = jnp.sum(weights.wc[:, None, None] * (x_diff[:, :, None] * y_diff[:, None, :]), axis=0)
173176
# (n, m)
174177

175-
# Kalman gain K = P_xy @ S^{-1}
176-
K = jnp.linalg.solve(S.T, P_xy.T).T # (n, m)
178+
# Kalman gain k = p_xy @ s^{-1}
179+
k_gain = jnp.linalg.solve(s.T, p_xy.T).T # (n, m)
177180

178181
# Innovation
179182
e = y - y_pred # (m,)
180183

181-
filtered_mean = predicted.mean + K @ e
182-
filtered_cov = predicted.cov - K @ S @ K.T
184+
filtered_mean = predicted.mean + k_gain @ e
185+
filtered_cov = predicted.cov - k_gain @ s @ k_gain.T
183186

184-
# Log-likelihood: log N(e; 0, S)
187+
# Log-likelihood: log N(e; 0, s)
185188
m = observation.shape[-1]
186-
log_det = jnp.linalg.slogdet(S)[1]
187-
mahal = e @ jnp.linalg.solve(S, e)
189+
log_det = jnp.linalg.slogdet(s)[1]
190+
mahal = e @ jnp.linalg.solve(s, e)
188191
ll = -0.5 * (m * jnp.log(2.0 * jnp.pi) + log_det + mahal)
189192

190193
# Handle missing observations
@@ -248,8 +251,12 @@ def scan(
248251
) -> FilterResult:
249252
"""Run full forward UKF via jax.lax.scan."""
250253
return _ukf_filter_impl(
251-
model, observations, initial_state,
252-
self.alpha, self.beta, self.kappa,
254+
model,
255+
observations,
256+
initial_state,
257+
self.alpha,
258+
self.beta,
259+
self.kappa,
253260
)
254261

255262

@@ -326,9 +333,7 @@ def _ukf_scan(
326333
log_likelihood=jnp.array(0.0),
327334
)
328335

329-
def _scan_step(
330-
carry: _ScanCarry, obs: Array
331-
) -> tuple[_ScanCarry, _ScanOutput]:
336+
def _scan_step(carry: _ScanCarry, obs: Array) -> tuple[_ScanCarry, _ScanOutput]:
332337
predicted = predict(carry.filtered, model, weights)
333338
filtered, ll = update(predicted, obs, model, weights)
334339
new_carry = _ScanCarry(

tests/test_filters/test_ekf.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,10 @@
2323
# ---------------------------------------------------------------------------
2424

2525

26-
def _linear_nonlinear_model(
27-
sigma_level: float = 1.0, sigma_obs: float = 1.0
28-
) -> NonlinearSSM:
26+
def _linear_nonlinear_model(sigma_level: float = 1.0, sigma_obs: float = 1.0) -> NonlinearSSM:
2927
"""Local-level model as a NonlinearSSM (identity transition/observation)."""
30-
Q = jnp.array([[sigma_level**2]])
31-
R = jnp.array([[sigma_obs**2]])
28+
q = jnp.array([[sigma_level**2]])
29+
r = jnp.array([[sigma_obs**2]])
3230

3331
def f(x: Array) -> Array:
3432
return x
@@ -39,8 +37,8 @@ def h(x: Array) -> Array:
3937
return NonlinearSSM(
4038
transition_fn=f,
4139
observation_fn=h,
42-
transition_cov=Q,
43-
observation_cov=R,
40+
transition_cov=q,
41+
observation_cov=r,
4442
state_dim=1,
4543
obs_dim=1,
4644
)
@@ -72,15 +70,15 @@ def test_predict_identity_transition() -> None:
7270

7371
def test_predict_nonlinear_transition() -> None:
7472
"""Test with a nonlinear transition: f(x) = x + 0.1 * sin(x)."""
75-
Q = jnp.array([[0.5]])
73+
q = jnp.array([[0.5]])
7674

7775
def f(x: Array) -> Array:
7876
return x + 0.1 * jnp.sin(x)
7977

8078
model = NonlinearSSM(
8179
transition_fn=f,
8280
observation_fn=lambda x: x,
83-
transition_cov=Q,
81+
transition_cov=q,
8482
observation_cov=jnp.array([[1.0]]),
8583
state_dim=1,
8684
obs_dim=1,
@@ -140,15 +138,11 @@ def test_ekf_matches_kalman_on_linear_model() -> None:
140138
ekf_result = ekf_filter(nl_model, observations, initial_state=init)
141139
kf_result = kalman_filter(lin_model, observations, initial_state=init)
142140

143-
np.testing.assert_allclose(
144-
ekf_result.filtered_states, kf_result.filtered_states, atol=1e-4
145-
)
141+
np.testing.assert_allclose(ekf_result.filtered_states, kf_result.filtered_states, atol=1e-4)
146142
np.testing.assert_allclose(
147143
ekf_result.filtered_covariances, kf_result.filtered_covariances, atol=1e-3
148144
)
149-
np.testing.assert_allclose(
150-
ekf_result.log_likelihood, kf_result.log_likelihood, atol=1e-2
151-
)
145+
np.testing.assert_allclose(ekf_result.log_likelihood, kf_result.log_likelihood, atol=1e-2)
152146

153147

154148
# ---------------------------------------------------------------------------
@@ -198,9 +192,7 @@ def test_ekf_filter_with_missing_obs() -> None:
198192
assert jnp.isfinite(result.log_likelihood)
199193

200194
# At NaN points, predicted == filtered
201-
np.testing.assert_allclose(
202-
result.filtered_states[10], result.predicted_states[10], atol=1e-5
203-
)
195+
np.testing.assert_allclose(result.filtered_states[10], result.predicted_states[10], atol=1e-5)
204196

205197

206198
# ---------------------------------------------------------------------------
@@ -280,7 +272,7 @@ def h(x: Array) -> Array:
280272
# Simulate observations from a known trajectory
281273
true_state = jnp.array([3.0, 4.0])
282274
obs_list = []
283-
for t in range(50):
275+
for _t in range(50):
284276
true_state = f(true_state) + jax.random.normal(key, (2,)) * 0.01
285277
key, _ = jax.random.split(key)
286278
obs = h(true_state) + jax.random.normal(key, (2,)) * 0.1
@@ -315,13 +307,13 @@ def test_grad_through_ekf() -> None:
315307
observations = NILE[:20].reshape(-1, 1)
316308

317309
def neg_ll(log_sigma_level: Array, log_sigma_obs: Array) -> Array:
318-
Q = jnp.exp(log_sigma_level) * jnp.eye(1)
319-
R = jnp.exp(log_sigma_obs) * jnp.eye(1)
310+
q_mat = jnp.exp(log_sigma_level) * jnp.eye(1)
311+
r_mat = jnp.exp(log_sigma_obs) * jnp.eye(1)
320312
model = NonlinearSSM(
321313
transition_fn=lambda x: x,
322314
observation_fn=lambda x: x,
323-
transition_cov=Q,
324-
observation_cov=R,
315+
transition_cov=q_mat,
316+
observation_cov=r_mat,
325317
state_dim=1,
326318
obs_dim=1,
327319
)

tests/test_filters/test_ukf.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from dynaris.filters.ukf import (
1717
UnscentedKalmanFilter,
1818
compute_weights,
19+
predict,
1920
sigma_points,
2021
ukf_filter,
21-
predict,
2222
update,
2323
)
2424

@@ -30,9 +30,7 @@
3030
# ---------------------------------------------------------------------------
3131

3232

33-
def _linear_nonlinear_model(
34-
sigma_level: float = 1.0, sigma_obs: float = 1.0
35-
) -> NonlinearSSM:
33+
def _linear_nonlinear_model(sigma_level: float = 1.0, sigma_obs: float = 1.0) -> NonlinearSSM:
3634
"""Local-level model as a NonlinearSSM."""
3735
return NonlinearSSM(
3836
transition_fn=lambda x: x,
@@ -96,9 +94,7 @@ def test_sigma_points_symmetric() -> None:
9694
w = compute_weights(n=1)
9795
pts = sigma_points(state, w.lam)
9896
# Points 1 and 2 should be equidistant from the mean
99-
np.testing.assert_allclose(
100-
pts[1] - state.mean, -(pts[2] - state.mean), atol=1e-6
101-
)
97+
np.testing.assert_allclose(pts[1] - state.mean, -(pts[2] - state.mean), atol=1e-6)
10298

10399

104100
def test_sigma_points_weighted_mean_recovers_mean() -> None:
@@ -202,9 +198,7 @@ def test_ukf_matches_kalman_on_linear_model() -> None:
202198
np.testing.assert_allclose(
203199
ukf_result.filtered_states[10:], kf_result.filtered_states[10:], atol=0.5
204200
)
205-
np.testing.assert_allclose(
206-
ukf_result.log_likelihood, kf_result.log_likelihood, atol=5.0
207-
)
201+
np.testing.assert_allclose(ukf_result.log_likelihood, kf_result.log_likelihood, atol=5.0)
208202

209203

210204
# ---------------------------------------------------------------------------
@@ -251,9 +245,7 @@ def test_ukf_filter_with_missing_obs() -> None:
251245
result = ukf_filter(model, observations)
252246
assert jnp.all(jnp.isfinite(result.filtered_states))
253247
assert jnp.isfinite(result.log_likelihood)
254-
np.testing.assert_allclose(
255-
result.filtered_states[10], result.predicted_states[10], atol=1e-5
256-
)
248+
np.testing.assert_allclose(result.filtered_states[10], result.predicted_states[10], atol=1e-5)
257249

258250

259251
# ---------------------------------------------------------------------------
@@ -366,13 +358,13 @@ def test_grad_through_ukf() -> None:
366358
observations = NILE[:20].reshape(-1, 1)
367359

368360
def neg_ll(log_sigma_level: Array, log_sigma_obs: Array) -> Array:
369-
Q = jnp.exp(log_sigma_level) * jnp.eye(1)
370-
R = jnp.exp(log_sigma_obs) * jnp.eye(1)
361+
q = jnp.exp(log_sigma_level) * jnp.eye(1)
362+
r = jnp.exp(log_sigma_obs) * jnp.eye(1)
371363
model = NonlinearSSM(
372364
transition_fn=lambda x: x,
373365
observation_fn=lambda x: x,
374-
transition_cov=Q,
375-
observation_cov=R,
366+
transition_cov=q,
367+
observation_cov=r,
376368
state_dim=1,
377369
obs_dim=1,
378370
)

0 commit comments

Comments
 (0)