@@ -79,14 +79,17 @@ def sigma_points(state: GaussianState, lam: Array) -> Array:
7979 """
8080 n = state .mean .shape [0 ]
8181 scaled_cov = (n + lam ) * state .cov
82- L = jnp .linalg .cholesky (scaled_cov ) # (n, n)
82+ chol = jnp .linalg .cholesky (scaled_cov ) # (n, n)
8383
8484 # Build sigma points: [mean, mean + L_i, mean - L_i]
85- offsets = jnp .concatenate ([
86- jnp .zeros ((1 , n )),
87- L , # rows of L as positive offsets
88- - L , # rows of L as negative offsets
89- ], axis = 0 ) # (2n+1, n)
85+ offsets = jnp .concatenate (
86+ [
87+ jnp .zeros ((1 , n )),
88+ chol , # rows of L as positive offsets
89+ - chol , # rows of L as negative offsets
90+ ],
91+ axis = 0 ,
92+ ) # (2n+1, n)
9093
9194 return state .mean [None , :] + offsets
9295
@@ -162,29 +165,29 @@ def update(
162165 # Predicted observation mean
163166 y_pred = jnp .sum (weights .wm [:, None ] * pts_obs , axis = 0 ) # (m,)
164167
165- # Innovation covariance S = sum wc * (y_diff)(y_diff)' + R
168+ # Innovation covariance s = sum wc * (y_diff)(y_diff)' + R
166169 y_diff = pts_obs - y_pred [None , :] # (2n+1, m)
167- S = jnp .sum (weights .wc [:, None , None ] * (y_diff [:, :, None ] * y_diff [:, None , :]), axis = 0 )
168- S = S + model .R # (m, m)
170+ s = jnp .sum (weights .wc [:, None , None ] * (y_diff [:, :, None ] * y_diff [:, None , :]), axis = 0 )
171+ s = s + model .R # (m, m)
169172
170- # Cross-covariance P_xy = sum wc * (x_diff)(y_diff)'
173+ # Cross-covariance p_xy = sum wc * (x_diff)(y_diff)'
171174 x_diff = pts - predicted .mean [None , :] # (2n+1, n)
172- P_xy = jnp .sum (weights .wc [:, None , None ] * (x_diff [:, :, None ] * y_diff [:, None , :]), axis = 0 )
175+ p_xy = jnp .sum (weights .wc [:, None , None ] * (x_diff [:, :, None ] * y_diff [:, None , :]), axis = 0 )
173176 # (n, m)
174177
175- # Kalman gain K = P_xy @ S ^{-1}
176- K = jnp .linalg .solve (S .T , P_xy .T ).T # (n, m)
178+ # Kalman gain k = p_xy @ s ^{-1}
179+ k_gain = jnp .linalg .solve (s .T , p_xy .T ).T # (n, m)
177180
178181 # Innovation
179182 e = y - y_pred # (m,)
180183
181- filtered_mean = predicted .mean + K @ e
182- filtered_cov = predicted .cov - K @ S @ K .T
184+ filtered_mean = predicted .mean + k_gain @ e
185+ filtered_cov = predicted .cov - k_gain @ s @ k_gain .T
183186
184- # Log-likelihood: log N(e; 0, S )
187+ # Log-likelihood: log N(e; 0, s )
185188 m = observation .shape [- 1 ]
186- log_det = jnp .linalg .slogdet (S )[1 ]
187- mahal = e @ jnp .linalg .solve (S , e )
189+ log_det = jnp .linalg .slogdet (s )[1 ]
190+ mahal = e @ jnp .linalg .solve (s , e )
188191 ll = - 0.5 * (m * jnp .log (2.0 * jnp .pi ) + log_det + mahal )
189192
190193 # Handle missing observations
@@ -248,8 +251,12 @@ def scan(
248251 ) -> FilterResult :
249252 """Run full forward UKF via jax.lax.scan."""
250253 return _ukf_filter_impl (
251- model , observations , initial_state ,
252- self .alpha , self .beta , self .kappa ,
254+ model ,
255+ observations ,
256+ initial_state ,
257+ self .alpha ,
258+ self .beta ,
259+ self .kappa ,
253260 )
254261
255262
@@ -326,9 +333,7 @@ def _ukf_scan(
326333 log_likelihood = jnp .array (0.0 ),
327334 )
328335
329- def _scan_step (
330- carry : _ScanCarry , obs : Array
331- ) -> tuple [_ScanCarry , _ScanOutput ]:
336+ def _scan_step (carry : _ScanCarry , obs : Array ) -> tuple [_ScanCarry , _ScanOutput ]:
332337 predicted = predict (carry .filtered , model , weights )
333338 filtered , ll = update (predicted , obs , model , weights )
334339 new_carry = _ScanCarry (
0 commit comments