Skip to content

Commit db13800

Browse files
authored
Misc comments, enforce diagonal, etc (#165)
* Re-use the indexed column instead of indexing it repeatedly
* Add a comment noting that Prec_r is always diagonal
* Fail if Prec_eps is not diagonal - in that case the math does not hold
* Only compute and log the posterior precision log-determinant when logging at level INFO is enabled
1 parent f931175 commit db13800

2 files changed

Lines changed: 19 additions & 15 deletions

File tree

graphite_maps/enif.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def transport(
169169

170170
# Work out residuals and associate unexplained variance
171171
residuals = self.response_residual(U, Y)
172+
172173
# Due to observation error
173174
eps = self.generate_observation_noise(
174175
n,
@@ -178,9 +179,9 @@ def transport(
178179

179180
# Update in canonical parametrization
180181
canonical_updated = self.update_canonical(
181-
canonical,
182-
residual_noisy,
183-
d,
182+
canonical=canonical,
183+
residual_noisy=residual_noisy,
184+
d=d,
184185
)
185186

186187
# Bring realizations back
@@ -253,6 +254,11 @@ def Prec_residual_noisy(self) -> sparray:
253254
if self.unexplained_variance is None:
254255
raise ValueError("`unexplained_variance` is not set.")
255256

257+
# The equation below is only valid if Prec_eps is diagonal
258+
row_idx, col_idx, _ = sp.sparse.find(self.Prec_eps)
259+
if np.any(row_idx != col_idx):
260+
raise ValueError("Precision matrix 'Prec_eps' must be diagonal")
261+
256262
eps_variances = 1.0 / self.Prec_eps.diagonal()
257263
residual_noisy_var = self.unexplained_variance + eps_variances
258264
Prec_r = diags_array(1.0 / residual_noisy_var, offsets=0, format="csc")
@@ -327,7 +333,7 @@ def update_canonical(
327333
logdet_value = 2.0 * np.sum(np.log(chol_LLT.L().diagonal()))
328334
log.info("Prior precision log-determinant: %.3f", logdet_value)
329335

330-
Prec_r = self.Prec_residual_noisy()
336+
Prec_r = self.Prec_residual_noisy() # This is a diagonal matrix
331337

332338
# posterior canonical params
333339
# this is equation (46), but transposed to update each row (realizations)
@@ -338,9 +344,10 @@ def update_canonical(
338344
# posterior precision, equation (47)
339345
self.Prec_u = self.Prec_u + self.H.T @ Prec_r @ self.H # Eqn (47)
340346

341-
chol_LLT = cholesky(self.Prec_u, ordering_method="metis")
342-
logdet_value = 2.0 * np.sum(np.log(chol_LLT.L().diagonal()))
343-
log.info("Posterior precision log-determinant: %.3f", logdet_value)
347+
if log.isEnabledFor(logging.INFO):
348+
chol_LLT = cholesky(self.Prec_u, ordering_method="metis")
349+
logdet_value = 2.0 * np.sum(np.log(chol_LLT.L().diagonal()))
350+
log.info("Posterior precision log-determinant: %.3f", logdet_value)
344351

345352
return updated_canonical
346353

graphite_maps/linear_regression.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def boost_linear_regression(
193193
"""
194194
n_samples, n_features = X.shape
195195
coefficients = np.zeros(n_features)
196-
residuals = y.copy()
196+
residuals = y.copy() # residuals = y - X @ coef = y - X @ 0 = y
197197
residuals_loo = y.copy()
198198

199199
# A stricter criterion is the loo-adjustment: mse(residuals_loo)-mse
@@ -220,15 +220,12 @@ def boost_linear_regression(
220220
# Inlined influence: psi / M
221221
# where influence = -residuals * x and
222222
# M = -mean(x^2) = -1 since x standardized
223-
influence = (residuals - beta_estimate * X[:, best_feature]) * X[
224-
:, best_feature
225-
]
223+
X_best = X[:, best_feature]
224+
influence = (residuals - beta_estimate * X_best) * X_best
226225
beta_estimate_loo = beta_estimate - influence / n_samples
227226

228227
# residuals_full = residuals - beta_estimate * X[:, best_feature]
229-
residuals_full_loo = (
230-
residuals_loo - learning_rate * beta_estimate_loo * X[:, best_feature]
231-
)
228+
residuals_full_loo = residuals_loo - learning_rate * beta_estimate_loo * X_best
232229

233230
if mse(residuals_loo) < mse(residuals_full_loo):
234231
break
@@ -246,7 +243,7 @@ def boost_linear_regression(
246243
break
247244
else:
248245
# Update
249-
residuals -= coef_change * X[:, best_feature]
246+
residuals -= coef_change * X_best
250247
coefficients[best_feature] += coef_change
251248

252249
# loo update

0 commit comments

Comments
 (0)