diff --git a/pymc_extras/statespace/filters/kalman_filter.py b/pymc_extras/statespace/filters/kalman_filter.py index 2936098c4..8891ba926 100644 --- a/pymc_extras/statespace/filters/kalman_filter.py +++ b/pymc_extras/statespace/filters/kalman_filter.py @@ -298,17 +298,20 @@ def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]: return filter_results def handle_missing_values( - self, y, Z, H - ) -> tuple[TensorVariable, TensorVariable, TensorVariable, float]: + self, y, Z, H, d + ) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, float]: """ - Handle missing values in the observation data `y` + Handle missing values in the observation data ``y``. - Adjusts the design matrix `Z` and the observation noise covariance matrix `H` by removing rows and/or columns - associated with the data that is not observed at this iteration. Missing values are replaced with zeros to prevent - propagating NaNs through the computation. + Adjust the design matrix ``Z``, the observation noise covariance matrix ``H``, and the observation + intercept ``d`` by zeroing the rows associated with observations that are missing at this iteration. + The missing entries of ``y`` are replaced with zeros to prevent propagating NaNs through the + computation. With ``y``, ``Z @ a``, and ``d`` all zero on the missing rows, the innovation + :math:`v = y - (Z a + d)` is exactly zero there, so missing observations contribute nothing to the + state update. - Return a binary flag tensor `all_nan_flag`,indicating if all values in the observation data are missing. This - flag is used for numerical adjustments in the update method. + Return a binary flag tensor ``all_nan_flag`` indicating whether every component of the observation + is missing. This flag is used for numerical adjustments in the update method. Parameters ---------- @@ -318,21 +321,28 @@ def handle_missing_values( The design matrix. H : TensorVariable The observation noise covariance matrix. 
+ d : TensorVariable + The observation intercept. Returns ------- y_masked : TensorVariable Observation vector with missing values replaced by zeros. - Z_masked: TensorVariable - Design matrix adjusted to exclude the missing states from the information set of observed variables in the - update step + Z_masked : TensorVariable + Design matrix with the rows corresponding to missing observations zeroed out. + + H_masked : TensorVariable + Observation noise covariance matrix with the rows *and columns* corresponding to missing + observations zeroed out, so the result remains symmetric. - H_masked: TensorVariable - Noise covariance matrix, adjusted to exclude the missing states + d_masked : TensorVariable + Observation intercept with the entries corresponding to missing observations zeroed out. Without + this masking, missing rows of the innovation become :math:`-d`, injecting a fake observation + into the state update and distorting the log-likelihood by :math:`d^2 / \\text{jitter}`. - all_nan_flag: float - 1 if the entire state vector is missing + all_nan_flag : float + 1 if every component of the observation is missing. References ---------- @@ -344,10 +354,11 @@ def handle_missing_values( W = pt.diag(pt.bitwise_not(nan_mask).astype(pytensor.config.floatX)) Z_masked = W.dot(Z) - H_masked = W.dot(H) + H_masked = W.dot(H).dot(W.mT) + d_masked = W.dot(d) y_masked = pt.set_subtensor(y[nan_mask], 0.0) - return y_masked, Z_masked, H_masked, all_nan_flag + return y_masked, Z_masked, H_masked, d_masked, all_nan_flag @staticmethod def predict(a, P, c, T, R, Q) -> tuple[TensorVariable, TensorVariable]: @@ -517,10 +528,12 @@ def kalman_step(self, *args) -> tuple: 2nd ed, Oxford University Press, 2012. 
""" y, a, P, c, d, T, Z, R, H, Q = self.unpack_args(args) - y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H) + y_masked, Z_masked, H_masked, d_masked, all_nan_flag = self.handle_missing_values( + y, Z, H, d + ) a_filtered, P_filtered, obs_mu, obs_cov, ll = self.update( - y=y_masked, a=a, d=d, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag + y=y_masked, a=a, d=d_masked, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag ) P_filtered = stabilize(P_filtered, self.cov_jitter) @@ -652,10 +665,7 @@ def update(self, a, P, y, d, Z, H, all_nan_flag): y_hat = Z.dot(a) + d v = y - y_hat - - H_chol = pytensor.ifelse( - pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True, on_error="nan") - ) + H_chol = pt.linalg.cholesky(stabilize(H, self.cov_jitter), lower=True) # The following notation comes from https://ipnpr.jpl.nasa.gov/progress_report/42-233/42-233A.pdf # Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred], @@ -787,11 +797,13 @@ def kalman_step(self, *args): y, a, P, c, d, T, Z, R, H, Q = self.unpack_args(args) nan_mask = pt.or_(pt.isnan(y), pt.eq(y, self.missing_fill_value)) - y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H) + y_masked, Z_masked, H_masked, d_masked, all_nan_flag = self.handle_missing_values( + y, Z, H, d + ) result = pytensor.scan( self._univariate_inner_filter_step, - sequences=[y_masked, Z_masked, d, pt.diag(H_masked), nan_mask], + sequences=[y_masked, Z_masked, d_masked, pt.diag(H_masked), nan_mask], outputs_info=[a, P, None, None, None], name="univariate_inner_scan", return_updates=False, diff --git a/tests/statespace/filters/test_kalman_filter.py b/tests/statespace/filters/test_kalman_filter.py index c8eba632a..c61633433 100644 --- a/tests/statespace/filters/test_kalman_filter.py +++ b/tests/statespace/filters/test_kalman_filter.py @@ -223,6 +223,62 @@ def test_missing_data(filter_name, p, rng): ), f"Shape of {name} does not match expected" 
+@pytest.mark.parametrize("filter_name", filter_names) +def test_missing_value_with_nonzero_obs_intercept(filter_name, rng): + """ + With non-zero observation intercept ``d``, masking must zero ``d`` at missing rows so the + innovation does not become ``-d`` and contaminate the log-likelihood. Verify by comparing + against the equivalent ``(y - d, 0)`` parameterization, under which the filter is invariant. + """ + p, m, r, n = 3, 5, 1, 10 + data, a0, P0, c, d, T, Z, R, H, Q = make_test_inputs(p, m, r, n, rng, missing_data=2) + + d_nonzero = np.array([1.5, -0.7, 2.1], dtype=floatX) + + # Reference: absorb d into the data (NaN entries stay NaN under subtraction). + data_absorbed = data - d_nonzero + out_ref = get_filter_function(filter_name)( + data_absorbed, a0, P0, c, np.zeros_like(d_nonzero), T, Z, R, H, Q + ) + out_d = get_filter_function(filter_name)(data, a0, P0, c, d_nonzero, T, Z, R, H, Q) + + for idx, name in enumerate(output_names): + assert_allclose( + out_d[idx], + out_ref[idx], + atol=ATOL, + rtol=RTOL, + err_msg=f"{name} differs between (d, y) and (0, y - d) with missing observations", + ) + + +@pytest.mark.parametrize("filter_name", filter_names) +def test_missing_value_with_nondiagonal_obs_cov(filter_name, rng): + """ + With non-diagonal ``H`` and a missing observation at position ``j``, the cross-covariances + ``H[:, j]`` and ``H[j, :]`` cannot influence any observed quantity. Verify by comparing + against a run where those rows and columns have been zeroed by hand — the two must agree. 
+ """ + p, m, r, n = 2, 5, 1, 10 + data, a0, P0, c, d, T, Z, R, H, Q = make_test_inputs(p, m, r, n, rng) + data[:, 1] = np.nan + + H_full = np.array([[1.0, 0.4], [0.4, 1.0]], dtype=floatX) + H_zeroed = np.array([[1.0, 0.0], [0.0, 0.0]], dtype=floatX) + + out_full = get_filter_function(filter_name)(data, a0, P0, c, d, T, Z, R, H_full, Q) + out_zeroed = get_filter_function(filter_name)(data, a0, P0, c, d, T, Z, R, H_zeroed, Q) + + for idx, name in enumerate(output_names): + assert_allclose( + out_full[idx], + out_zeroed[idx], + atol=ATOL, + rtol=RTOL, + err_msg=f"{name} depends on H entries at masked positions", + ) + + @pytest.mark.parametrize("filter_name", filter_names) @pytest.mark.parametrize("output_idx", [(0, 2), (3, 5)], ids=["smoothed_states", "smoothed_covs"]) def test_last_smoother_is_last_filtered(filter_name, output_idx, rng):