Skip to content

Commit 7d66fb8

Browse files
More stable KF update (#684)
* More stable KF update * cleanup * use cho_solve
1 parent c247de4 commit 7d66fb8

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -600,16 +600,18 @@ def update(self, a, P, y, d, Z, H, all_nan_flag):
600600
PZT = P.dot(Z.mT)
601601
F = Z.dot(PZT) + stabilize(H, self.cov_jitter)
602602

603-
K = pt.linalg.solve(F.mT, PZT.mT, assume_a="pos", check_finite=False).mT
603+
F_chol = pt.linalg.cholesky(F, lower=True)
604+
605+
K = pt.linalg.cho_solve((F_chol, True), PZT.mT).mT
604606
I_KZ = pt.eye(self.n_states) - K.dot(Z)
605607

606608
a_filtered = a + K @ v
607609
P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)
608610

609-
F_inv_v = pt.linalg.solve(F, v, assume_a="pos", check_finite=False)
611+
F_inv_v = pt.linalg.cho_solve((F_chol, True), v)
610612
inner_term = v.T @ F_inv_v
611613

612-
F_logdet = pt.log(pt.linalg.det(F))
614+
F_logdet = 2 * pt.log(pt.diag(F_chol)).sum()
613615

614616
ll = pt.switch(
615617
all_nan_flag,

0 commit comments

Comments
 (0)