Skip to content

Commit c247de4

Browse files
Correctly mask obs_intercept and obs_cov when data is missing (#682)
* Correctly mask obs_intercept and obs_cov when data is missing
* Stabilize H before factorization in SquareRootFilter
1 parent e4e39c7 commit c247de4

2 files changed

Lines changed: 93 additions & 25 deletions

File tree

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -298,17 +298,20 @@ def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
298298
return filter_results
299299

300300
def handle_missing_values(
301-
self, y, Z, H
302-
) -> tuple[TensorVariable, TensorVariable, TensorVariable, float]:
301+
self, y, Z, H, d
302+
) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, float]:
303303
"""
304-
Handle missing values in the observation data `y`
304+
Handle missing values in the observation data ``y``.
305305
306-
Adjusts the design matrix `Z` and the observation noise covariance matrix `H` by removing rows and/or columns
307-
associated with the data that is not observed at this iteration. Missing values are replaced with zeros to prevent
308-
propagating NaNs through the computation.
306+
Adjust the design matrix ``Z``, the observation noise covariance matrix ``H``, and the observation
307+
intercept ``d`` by zeroing the rows associated with observations that are missing at this iteration.
308+
The missing entries of ``y`` are replaced with zeros to prevent propagating NaNs through the
309+
computation. With ``y``, ``Z @ a``, and ``d`` all zero on the missing rows, the innovation
310+
:math:`v = y - (Z a + d)` is exactly zero there, so missing observations contribute nothing to the
311+
state update.
309312
310-
Return a binary flag tensor `all_nan_flag`,indicating if all values in the observation data are missing. This
311-
flag is used for numerical adjustments in the update method.
313+
Return a binary flag tensor ``all_nan_flag`` indicating whether every component of the observation
314+
is missing. This flag is used for numerical adjustments in the update method.
312315
313316
Parameters
314317
----------
@@ -318,21 +321,28 @@ def handle_missing_values(
318321
The design matrix.
319322
H : TensorVariable
320323
The observation noise covariance matrix.
324+
d : TensorVariable
325+
The observation intercept.
321326
322327
Returns
323328
-------
324329
y_masked : TensorVariable
325330
Observation vector with missing values replaced by zeros.
326331
327-
Z_masked: TensorVariable
328-
Design matrix adjusted to exclude the missing states from the information set of observed variables in the
329-
update step
332+
Z_masked : TensorVariable
333+
Design matrix with the rows corresponding to missing observations zeroed out.
334+
335+
H_masked : TensorVariable
336+
Observation noise covariance matrix with the rows *and columns* corresponding to missing
337+
observations zeroed out, so the result remains symmetric.
330338
331-
H_masked: TensorVariable
332-
Noise covariance matrix, adjusted to exclude the missing states
339+
d_masked : TensorVariable
340+
Observation intercept with the entries corresponding to missing observations zeroed out. Without
341+
this masking, missing rows of the innovation become :math:`-d`, injecting a fake observation
342+
into the state update and inflating the log-likelihood by :math:`d^2 / \\text{jitter}`.
333343
334-
all_nan_flag: float
335-
1 if the entire state vector is missing
344+
all_nan_flag : float
345+
1 if every component of the observation is missing.
336346
337347
References
338348
----------
@@ -344,10 +354,11 @@ def handle_missing_values(
344354
W = pt.diag(pt.bitwise_not(nan_mask).astype(pytensor.config.floatX))
345355

346356
Z_masked = W.dot(Z)
347-
H_masked = W.dot(H)
357+
H_masked = W.dot(H).dot(W.mT)
358+
d_masked = W.dot(d)
348359
y_masked = pt.set_subtensor(y[nan_mask], 0.0)
349360

350-
return y_masked, Z_masked, H_masked, all_nan_flag
361+
return y_masked, Z_masked, H_masked, d_masked, all_nan_flag
351362

352363
@staticmethod
353364
def predict(a, P, c, T, R, Q) -> tuple[TensorVariable, TensorVariable]:
@@ -517,10 +528,12 @@ def kalman_step(self, *args) -> tuple:
517528
2nd ed, Oxford University Press, 2012.
518529
"""
519530
y, a, P, c, d, T, Z, R, H, Q = self.unpack_args(args)
520-
y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H)
531+
y_masked, Z_masked, H_masked, d_masked, all_nan_flag = self.handle_missing_values(
532+
y, Z, H, d
533+
)
521534

522535
a_filtered, P_filtered, obs_mu, obs_cov, ll = self.update(
523-
y=y_masked, a=a, d=d, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag
536+
y=y_masked, a=a, d=d_masked, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag
524537
)
525538

526539
P_filtered = stabilize(P_filtered, self.cov_jitter)
@@ -652,10 +665,7 @@ def update(self, a, P, y, d, Z, H, all_nan_flag):
652665

653666
y_hat = Z.dot(a) + d
654667
v = y - y_hat
655-
656-
H_chol = pytensor.ifelse(
657-
pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True, on_error="nan")
658-
)
668+
H_chol = pt.linalg.cholesky(stabilize(H, self.cov_jitter), lower=True)
659669

660670
# The following notation comes from https://ipnpr.jpl.nasa.gov/progress_report/42-233/42-233A.pdf
661671
# Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred],
@@ -787,11 +797,13 @@ def kalman_step(self, *args):
787797
y, a, P, c, d, T, Z, R, H, Q = self.unpack_args(args)
788798

789799
nan_mask = pt.or_(pt.isnan(y), pt.eq(y, self.missing_fill_value))
790-
y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H)
800+
y_masked, Z_masked, H_masked, d_masked, all_nan_flag = self.handle_missing_values(
801+
y, Z, H, d
802+
)
791803

792804
result = pytensor.scan(
793805
self._univariate_inner_filter_step,
794-
sequences=[y_masked, Z_masked, d, pt.diag(H_masked), nan_mask],
806+
sequences=[y_masked, Z_masked, d_masked, pt.diag(H_masked), nan_mask],
795807
outputs_info=[a, P, None, None, None],
796808
name="univariate_inner_scan",
797809
return_updates=False,

tests/statespace/filters/test_kalman_filter.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,62 @@ def test_missing_data(filter_name, p, rng):
223223
), f"Shape of {name} does not match expected"
224224

225225

226+
@pytest.mark.parametrize("filter_name", filter_names)
227+
def test_missing_value_with_nonzero_obs_intercept(filter_name, rng):
228+
"""
229+
With non-zero observation intercept ``d``, masking must zero ``d`` at missing rows so the
230+
innovation does not become ``-d`` and contaminate the log-likelihood. Verify by comparing
231+
against the equivalent ``(y - d, 0)`` parameterization, under which the filter is invariant.
232+
"""
233+
p, m, r, n = 3, 5, 1, 10
234+
data, a0, P0, c, d, T, Z, R, H, Q = make_test_inputs(p, m, r, n, rng, missing_data=2)
235+
236+
d_nonzero = np.array([1.5, -0.7, 2.1], dtype=floatX)
237+
238+
# Reference: absorb d into the data (NaN entries stay NaN under subtraction).
239+
data_absorbed = data - d_nonzero
240+
out_ref = get_filter_function(filter_name)(
241+
data_absorbed, a0, P0, c, np.zeros_like(d_nonzero), T, Z, R, H, Q
242+
)
243+
out_d = get_filter_function(filter_name)(data, a0, P0, c, d_nonzero, T, Z, R, H, Q)
244+
245+
for idx, name in enumerate(output_names):
246+
assert_allclose(
247+
out_d[idx],
248+
out_ref[idx],
249+
atol=ATOL,
250+
rtol=RTOL,
251+
err_msg=f"{name} differs between (d, y) and (0, y - d) with missing observations",
252+
)
253+
254+
255+
@pytest.mark.parametrize("filter_name", filter_names)
256+
def test_missing_value_with_nondiagonal_obs_cov(filter_name, rng):
257+
"""
258+
With non-diagonal ``H`` and a missing observation at position ``j``, the cross-covariances
259+
``H[:, j]`` and ``H[j, :]`` cannot influence any observed quantity. Verify by comparing
260+
against a run where those rows and columns have been zeroed by hand — the two must agree.
261+
"""
262+
p, m, r, n = 2, 5, 1, 10
263+
data, a0, P0, c, d, T, Z, R, H, Q = make_test_inputs(p, m, r, n, rng)
264+
data[:, 1] = np.nan
265+
266+
H_full = np.array([[1.0, 0.4], [0.4, 1.0]], dtype=floatX)
267+
H_zeroed = np.array([[1.0, 0.0], [0.0, 0.0]], dtype=floatX)
268+
269+
out_full = get_filter_function(filter_name)(data, a0, P0, c, d, T, Z, R, H_full, Q)
270+
out_zeroed = get_filter_function(filter_name)(data, a0, P0, c, d, T, Z, R, H_zeroed, Q)
271+
272+
for idx, name in enumerate(output_names):
273+
assert_allclose(
274+
out_full[idx],
275+
out_zeroed[idx],
276+
atol=ATOL,
277+
rtol=RTOL,
278+
err_msg=f"{name} depends on H entries at masked positions",
279+
)
280+
281+
226282
@pytest.mark.parametrize("filter_name", filter_names)
227283
@pytest.mark.parametrize("output_idx", [(0, 2), (3, 5)], ids=["smoothed_states", "smoothed_covs"])
228284
def test_last_smoother_is_last_filtered(filter_name, output_idx, rng):

0 commit comments

Comments
 (0)