|
11 | 11 | except: |
12 | 12 | _HAS_NUMBA = False |
13 | 13 | if _HAS_NUMBA: |
14 | | - from pipedream_solver.nutils import interpolate_sample, _kalman_semi_implicit |
| 14 | + from pipedream_solver.nutils import interpolate_sample, _kalman_semi_implicit, _square_root_kalman_semi_implicit |
15 | 15 | else: |
16 | | - from pipedream_solver.utils import interpolate_sample, _kalman_semi_implicit |
| 16 | + from pipedream_solver.utils import interpolate_sample, _kalman_semi_implicit, _square_root_kalman_semi_implicit |
17 | 17 |
|
18 | 18 | eps = np.finfo(float).eps |
19 | 19 |
|
@@ -86,7 +86,7 @@ class Simulation(): |
86 | 86 | def __init__(self, model, Q_in=None, H_bc=None, Q_Ik=None, t_start=None, |
87 | 87 | t_end=None, dt=None, max_iter=None, min_dt=1, max_dt=200, |
88 | 88 | tol=0.01, min_rel_change=1e-10, max_rel_change=1e10, safety_factor=0.9, |
89 | | - Qcov=None, Rcov=None, C=None, H=None, interpolation_method='linear'): |
| 89 | + Pxx = None, Qcov=None, Rcov=None, C=None, H=None, interpolation_method='linear'): |
90 | 90 | self.model = model |
91 | 91 | if Q_in is not None: |
92 | 92 | self.Q_in = Q_in.copy(deep=True) |
@@ -204,7 +204,12 @@ def __init__(self, model, Q_in=None, H_bc=None, Q_Ik=None, t_start=None, |
204 | 204 | else: |
205 | 205 | assert isinstance(H, np.ndarray) |
206 | 206 | self.H = H |
207 | | - self.P_x_k_k = self.C @ self.Qcov @ self.C.T |
| 207 | + if Pxx is None: |
| 208 | + self.P_x_k_k = self.C @ self.Qcov @ self.C.T |
| 209 | + else: |
| 210 | + self.P_x_k_k = Pxx.copy() |
| 211 | + self.A_1 = None |
| 212 | + self.P_zz = None |
208 | 213 | # Progress bar checkpoints |
209 | 214 | if np.isfinite(self.t_end): |
210 | 215 | self._checkpoints = np.linspace(self.t_start, self.t_end) |
@@ -447,7 +452,7 @@ def filter_step_size(self, tol=0.5, dts=None, errs=None, coeffs=[0.5, 0.5, 0, 0. |
447 | 452 | return dt_np1 |
448 | 453 |
|
449 | 454 | def kalman_filter(self, Z, H=None, C=None, Qcov=None, Rcov=None, P_x_k_k=None, |
450 | | - dt=None, **kwargs): |
| 455 | + dt=None, SR=False, **kwargs): |
451 | 456 | """ |
452 | 457 | Apply Kalman Filter to fuse observed data into model. |
453 | 458 |
|
@@ -481,9 +486,15 @@ def kalman_filter(self, Z, H=None, C=None, Qcov=None, Rcov=None, P_x_k_k=None, |
481 | 486 | if Rcov is None: |
482 | 487 | Rcov = self.Rcov |
483 | 488 | A_1, A_2, b = self.model._semi_implicit_system(_dt=dt) |
484 | | - b_hat, P_x_k_k = _kalman_semi_implicit(Z, P_x_k_k, A_1, A_2, b, H, C, |
| 489 | + if SR == False: |
| 490 | + b_hat, P_x_k_k, P_zz = _kalman_semi_implicit(Z, P_x_k_k, A_1, A_2, b, H, C, |
| 491 | + Qcov, Rcov) |
| 492 | + else: |
| 493 | + b_hat, P_x_k_k, P_zz = _square_root_kalman_semi_implicit(Z, P_x_k_k, A_1, A_2, b, H, C, |
485 | 494 | Qcov, Rcov) |
486 | 495 | self.P_x_k_k = P_x_k_k |
| 496 | + self.P_zz = P_zz |
| 497 | + self.A_1 = A_1 |
487 | 498 | self.model.b = b_hat |
488 | 499 | self.model.iter_count -= 1 |
489 | 500 | self.model.t -= dt |
|
0 commit comments