Skip to content

Commit f73b15c

Browse files
sebcroftjessegrabowski
authored andcommitted
added missing value handling for log likelihood and statespace matrices in UnivariateFilter kalman_step
1 parent 7eb9d90 commit f73b15c

1 file changed

Lines changed: 7 additions & 8 deletions

File tree

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -779,15 +779,10 @@ def _univariate_inner_filter_step(self, y, Z_row, d_row, sigma_H, nan_flag, a, P
779779
return a_filtered, P_filtered, pt.atleast_1d(y_hat), pt.atleast_2d(F), ll_inner
780780

781781
def kalman_step(self, *args):
782-
783782
y, a, P, c, d, T, Z, R, H, Q = self.unpack_args(args)
784783

785-
nan_mask = pt.isnan(y)
786-
787-
W = pt.set_subtensor(pt.eye(y.shape[0])[nan_mask, nan_mask], 0.0)
788-
Z_masked = W.dot(Z)
789-
H_masked = W.dot(H)
790-
y_masked = pt.set_subtensor(y[nan_mask], 0.0)
784+
nan_mask = pt.or_(pt.isnan(y), pt.eq(y, self.missing_fill_value))
785+
y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H)
791786

792787
result = pytensor.scan(
793788
self._univariate_inner_filter_step,
@@ -808,6 +803,10 @@ def kalman_step(self, *args):
808803
P_filtered = stabilize(0.5 * (P_filtered + P_filtered.mT), self.cov_jitter)
809804
a_hat, P_hat = self.predict(a=a_filtered, P=P_filtered, c=c, T=T, R=R, Q=Q)
810805

811-
ll = -0.5 * ((pt.neq(ll_inner, 0).sum()) * MVN_CONST + ll_inner.sum())
806+
ll = pt.switch(
807+
all_nan_flag,
808+
0.0,
809+
-0.5 * ((pt.neq(ll_inner, 0).sum()) * MVN_CONST + ll_inner.sum()),
810+
)
812811

813812
return a_filtered, a_hat, obs_mu, P_filtered, P_hat, obs_cov, ll

0 commit comments

Comments
 (0)