@@ -298,17 +298,20 @@ def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
298298 return filter_results
299299
300300 def handle_missing_values (
301- self , y , Z , H
302- ) -> tuple [TensorVariable , TensorVariable , TensorVariable , float ]:
301+ self , y , Z , H , d
302+ ) -> tuple [TensorVariable , TensorVariable , TensorVariable , TensorVariable , float ]:
303303 """
304- Handle missing values in the observation data `y`
304+ Handle missing values in the observation data ``y``.
305305
306- Adjusts the design matrix `Z` and the observation noise covariance matrix `H` by removing rows and/or columns
307- associated with the data that is not observed at this iteration. Missing values are replaced with zeros to prevent
308- propagating NaNs through the computation.
306+ Adjust the design matrix ``Z``, the observation noise covariance matrix ``H``, and the observation
307+ intercept ``d`` by zeroing the rows associated with observations that are missing at this iteration.
308+ The missing entries of ``y`` are replaced with zeros to prevent propagating NaNs through the
309+ computation. With ``y``, ``Z @ a``, and ``d`` all zero on the missing rows, the innovation
310+ :math:`v = y - (Z a + d)` is exactly zero there, so missing observations contribute nothing to the
311+ state update.
309312
310- Return a binary flag tensor `all_nan_flag`, indicating if all values in the observation data are missing. This
311- flag is used for numerical adjustments in the update method.
313+ Return a binary flag tensor `` all_nan_flag`` indicating whether every component of the observation
314+ is missing. This flag is used for numerical adjustments in the update method.
312315
313316 Parameters
314317 ----------
@@ -318,21 +321,28 @@ def handle_missing_values(
318321 The design matrix.
319322 H : TensorVariable
320323 The observation noise covariance matrix.
324+ d : TensorVariable
325+ The observation intercept.
321326
322327 Returns
323328 -------
324329 y_masked : TensorVariable
325330 Observation vector with missing values replaced by zeros.
326331
327- Z_masked: TensorVariable
328- Design matrix adjusted to exclude the missing states from the information set of observed variables in the
329- update step
332+ Z_masked : TensorVariable
333+ Design matrix with the rows corresponding to missing observations zeroed out.
334+
335+ H_masked : TensorVariable
336+ Observation noise covariance matrix with the rows *and columns* corresponding to missing
337+ observations zeroed out, so the result remains symmetric.
330338
331- H_masked: TensorVariable
332- Noise covariance matrix, adjusted to exclude the missing states
339+ d_masked : TensorVariable
340+ Observation intercept with the entries corresponding to missing observations zeroed out. Without
341+ this masking, missing rows of the innovation become :math:`-d`, injecting a fake observation
342+ into the state update and inflating the log-likelihood by :math:`d^2 / \\ text{jitter}`.
333343
334- all_nan_flag: float
335- 1 if the entire state vector is missing
344+ all_nan_flag : float
345+ 1 if every component of the observation is missing.
336346
337347 References
338348 ----------
@@ -344,10 +354,11 @@ def handle_missing_values(
344354 W = pt .diag (pt .bitwise_not (nan_mask ).astype (pytensor .config .floatX ))
345355
346356 Z_masked = W .dot (Z )
347- H_masked = W .dot (H )
357+ H_masked = W .dot (H ).dot (W .mT )
358+ d_masked = W .dot (d )
348359 y_masked = pt .set_subtensor (y [nan_mask ], 0.0 )
349360
350- return y_masked , Z_masked , H_masked , all_nan_flag
361+ return y_masked , Z_masked , H_masked , d_masked , all_nan_flag
351362
352363 @staticmethod
353364 def predict (a , P , c , T , R , Q ) -> tuple [TensorVariable , TensorVariable ]:
@@ -517,10 +528,12 @@ def kalman_step(self, *args) -> tuple:
517528 2nd ed, Oxford University Press, 2012.
518529 """
519530 y , a , P , c , d , T , Z , R , H , Q = self .unpack_args (args )
520- y_masked , Z_masked , H_masked , all_nan_flag = self .handle_missing_values (y , Z , H )
531+ y_masked , Z_masked , H_masked , d_masked , all_nan_flag = self .handle_missing_values (
532+ y , Z , H , d
533+ )
521534
522535 a_filtered , P_filtered , obs_mu , obs_cov , ll = self .update (
523- y = y_masked , a = a , d = d , P = P , Z = Z_masked , H = H_masked , all_nan_flag = all_nan_flag
536+ y = y_masked , a = a , d = d_masked , P = P , Z = Z_masked , H = H_masked , all_nan_flag = all_nan_flag
524537 )
525538
526539 P_filtered = stabilize (P_filtered , self .cov_jitter )
@@ -652,10 +665,7 @@ def update(self, a, P, y, d, Z, H, all_nan_flag):
652665
653666 y_hat = Z .dot (a ) + d
654667 v = y - y_hat
655-
656- H_chol = pytensor .ifelse (
657- pt .all (pt .eq (H , 0.0 )), H , pt .linalg .cholesky (H , lower = True , on_error = "nan" )
658- )
668+ H_chol = pt .linalg .cholesky (stabilize (H , self .cov_jitter ), lower = True )
659669
660670 # The following notation comes from https://ipnpr.jpl.nasa.gov/progress_report/42-233/42-233A.pdf
661671 # Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred],
@@ -787,11 +797,13 @@ def kalman_step(self, *args):
787797 y , a , P , c , d , T , Z , R , H , Q = self .unpack_args (args )
788798
789799 nan_mask = pt .or_ (pt .isnan (y ), pt .eq (y , self .missing_fill_value ))
790- y_masked , Z_masked , H_masked , all_nan_flag = self .handle_missing_values (y , Z , H )
800+ y_masked , Z_masked , H_masked , d_masked , all_nan_flag = self .handle_missing_values (
801+ y , Z , H , d
802+ )
791803
792804 result = pytensor .scan (
793805 self ._univariate_inner_filter_step ,
794- sequences = [y_masked , Z_masked , d , pt .diag (H_masked ), nan_mask ],
806+ sequences = [y_masked , Z_masked , d_masked , pt .diag (H_masked ), nan_mask ],
795807 outputs_info = [a , P , None , None , None ],
796808 name = "univariate_inner_scan" ,
797809 return_updates = False ,
0 commit comments