@@ -169,6 +169,7 @@ def transport(
169169
170170 # Work out residuals and associate unexplained variance
171171 residuals = self .response_residual (U , Y )
172+
172173 # Due to observation error
173174 eps = self .generate_observation_noise (
174175 n ,
@@ -178,9 +179,9 @@ def transport(
178179
179180 # Update in canonical parametrization
180181 canonical_updated = self .update_canonical (
181- canonical ,
182- residual_noisy ,
183- d ,
182+ canonical = canonical ,
183+ residual_noisy = residual_noisy ,
184+ d = d ,
184185 )
185186
186187 # Bring realizations back
@@ -253,6 +254,11 @@ def Prec_residual_noisy(self) -> sparray:
253254 if self .unexplained_variance is None :
254255 raise ValueError ("`unexplained_variance` is not set." )
255256
257+ # The equation below is only valid if Prec_eps is diagonal
258+ row_idx , col_idx , _ = sp .sparse .find (self .Prec_eps )
259+ if np .any (row_idx != col_idx ):
260+ raise ValueError ("Precision matrix 'Prec_eps' must be diagonal" )
261+
256262 eps_variances = 1.0 / self .Prec_eps .diagonal ()
257263 residual_noisy_var = self .unexplained_variance + eps_variances
258264 Prec_r = diags_array (1.0 / residual_noisy_var , offsets = 0 , format = "csc" )
@@ -327,7 +333,7 @@ def update_canonical(
327333 logdet_value = 2.0 * np .sum (np .log (chol_LLT .L ().diagonal ()))
328334 log .info ("Prior precision log-determinant: %.3f" , logdet_value )
329335
330- Prec_r = self .Prec_residual_noisy ()
336+ Prec_r = self .Prec_residual_noisy () # This is a diagonal matrix
331337
332338 # posterior canonical params
333339 # this is equation (46), but transposed to update each row (realizations)
@@ -338,9 +344,10 @@ def update_canonical(
338344 # posterior precision, equation (47)
339345 self .Prec_u = self .Prec_u + self .H .T @ Prec_r @ self .H # Eqn (47)
340346
341- chol_LLT = cholesky (self .Prec_u , ordering_method = "metis" )
342- logdet_value = 2.0 * np .sum (np .log (chol_LLT .L ().diagonal ()))
343- log .info ("Posterior precision log-determinant: %.3f" , logdet_value )
347+ if log .isEnabledFor (logging .INFO ):
348+ chol_LLT = cholesky (self .Prec_u , ordering_method = "metis" )
349+ logdet_value = 2.0 * np .sum (np .log (chol_LLT .L ().diagonal ()))
350+ log .info ("Posterior precision log-determinant: %.3f" , logdet_value )
344351
345352 return updated_canonical
346353
0 commit comments