@@ -779,15 +779,10 @@ def _univariate_inner_filter_step(self, y, Z_row, d_row, sigma_H, nan_flag, a, P
779779 return a_filtered , P_filtered , pt .atleast_1d (y_hat ), pt .atleast_2d (F ), ll_inner
780780
781781 def kalman_step (self , * args ):
782-
783782 y , a , P , c , d , T , Z , R , H , Q = self .unpack_args (args )
784783
785- nan_mask = pt .isnan (y )
786-
787- W = pt .set_subtensor (pt .eye (y .shape [0 ])[nan_mask , nan_mask ], 0.0 )
788- Z_masked = W .dot (Z )
789- H_masked = W .dot (H )
790- y_masked = pt .set_subtensor (y [nan_mask ], 0.0 )
784+ nan_mask = pt .or_ (pt .isnan (y ), pt .eq (y , self .missing_fill_value ))
785+ y_masked , Z_masked , H_masked , all_nan_flag = self .handle_missing_values (y , Z , H )
791786
792787 result = pytensor .scan (
793788 self ._univariate_inner_filter_step ,
@@ -808,6 +803,10 @@ def kalman_step(self, *args):
808803 P_filtered = stabilize (0.5 * (P_filtered + P_filtered .mT ), self .cov_jitter )
809804 a_hat , P_hat = self .predict (a = a_filtered , P = P_filtered , c = c , T = T , R = R , Q = Q )
810805
811- ll = - 0.5 * ((pt .neq (ll_inner , 0 ).sum ()) * MVN_CONST + ll_inner .sum ())
806+ ll = pt .switch (
807+ all_nan_flag ,
808+ 0.0 ,
809+ - 0.5 * ((pt .neq (ll_inner , 0 ).sum ()) * MVN_CONST + ll_inner .sum ()),
810+ )
812811
813812 return a_filtered , a_hat , obs_mu , P_filtered , P_hat , obs_cov , ll
0 commit comments