Skip to content

Commit 87f41c4

Browse files
authored
Fix ordering of time-varying statespace matices for UnivariateFilter's kalman step (#653)
* Added unpack_args to UnivariateFilter kalman_step function so ordering of time-varying statespace matrices is correct * added missing value handling for log likelihood and statespace matrices in UnivariateFilter kalman_step
1 parent cca9207 commit 87f41c4

1 file changed

Lines changed: 9 additions & 7 deletions

File tree

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -778,13 +778,11 @@ def _univariate_inner_filter_step(self, y, Z_row, d_row, sigma_H, nan_flag, a, P
778778

779779
return a_filtered, P_filtered, pt.atleast_1d(y_hat), pt.atleast_2d(F), ll_inner
780780

781-
def kalman_step(self, y, a, P, c, d, T, Z, R, H, Q):
782-
nan_mask = pt.isnan(y)
781+
def kalman_step(self, *args):
782+
y, a, P, c, d, T, Z, R, H, Q = self.unpack_args(args)
783783

784-
W = pt.set_subtensor(pt.eye(y.shape[0])[nan_mask, nan_mask], 0.0)
785-
Z_masked = W.dot(Z)
786-
H_masked = W.dot(H)
787-
y_masked = pt.set_subtensor(y[nan_mask], 0.0)
784+
nan_mask = pt.or_(pt.isnan(y), pt.eq(y, self.missing_fill_value))
785+
y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H)
788786

789787
result = pytensor.scan(
790788
self._univariate_inner_filter_step,
@@ -805,6 +803,10 @@ def kalman_step(self, y, a, P, c, d, T, Z, R, H, Q):
805803
P_filtered = stabilize(0.5 * (P_filtered + P_filtered.mT), self.cov_jitter)
806804
a_hat, P_hat = self.predict(a=a_filtered, P=P_filtered, c=c, T=T, R=R, Q=Q)
807805

808-
ll = -0.5 * ((pt.neq(ll_inner, 0).sum()) * MVN_CONST + ll_inner.sum())
806+
ll = pt.switch(
807+
all_nan_flag,
808+
0.0,
809+
-0.5 * ((pt.neq(ll_inner, 0).sum()) * MVN_CONST + ll_inner.sum()),
810+
)
809811

810812
return a_filtered, a_hat, obs_mu, P_filtered, P_hat, obs_cov, ll

0 commit comments

Comments
 (0)