Skip to content

Commit e19a3a9

Browse files
committed
new obs-vs.-pred plotting function, added all perf. metrics to example 1
1 parent f0c3048 commit e19a3a9

3 files changed

Lines changed: 209 additions & 100 deletions

File tree

examples/SingleTaskTest.ipynb

Lines changed: 147 additions & 98 deletions
Large diffs are not rendered by default.

fvgp/gp.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,14 +1546,15 @@ def mpiw(self, x_test, interval=0.95):
15461546
-------
15471547
MPIW : float
15481548
"""
1549-
sigma = np.sqrt(self.posterior_covariance(x_test, add_noise=True)["v(x)"])
1549+
v = self.posterior_covariance(x_test, add_noise=True)["v(x)"]
1550+
sigma = np.sqrt(np.clip(v, 0.0, None))
15501551
z = norm.ppf(1 - (1 - interval) / 2)
15511552
return np.mean(2 * z * sigma)
15521553

15531554
def interval_score(self, x_test, y_test, interval=0.95):
15541555
"""
15551556
This function calculates the Interval Score (also known as the Winkler Score).
1556-
It penalises both missed coverage and unnecessarily wide prediction intervals,
1557+
It penalizes both missed coverage and unnecessarily wide prediction intervals,
15571558
combining the concerns of :py:meth:`picp` and :py:meth:`mpiw` into a single
15581559
scalar. Lower is better.
15591560
@@ -1663,6 +1664,62 @@ def msll(self, x_test, y_test):
16631664

16641665
return nlpd_gp - nlpd_baseline
16651666

1667+
def plot_observed_vs_predicted(self, x_test, y_test, title=None, ax=None):
1668+
"""
1669+
Scatter plot of observed vs. predicted values with a reference diagonal
1670+
and 1-sigma predictive error bars (noise-inflated posterior variance).
1671+
Useful for a quick visual check of model fit on a held-out test set.
1672+
1673+
Parameters
1674+
----------
1675+
x_test : np.ndarray
1676+
Test input positions, shape (V, D).
1677+
y_test : np.ndarray
1678+
Observed test values, shape (V,) or (V, No) for multi-output.
1679+
title : str, optional
1680+
Plot title.
1681+
ax : matplotlib.axes.Axes, optional
1682+
Existing axes to draw on; if ``None``, a fresh figure + axes is created.
1683+
1684+
Returns
1685+
-------
1686+
None
1687+
If matplotlib is not installed a ``UserWarning`` is emitted; otherwise
1688+
the plot is drawn on the supplied or freshly-created axes.
1689+
"""
1690+
try:
1691+
import matplotlib.pyplot as plt
1692+
except ImportError:
1693+
warnings.warn(
1694+
"matplotlib is not installed; cannot create observed-vs-predicted plot. "
1695+
"Install with `pip install matplotlib` (or `pip install -e .[plotting]`) "
1696+
"to enable plotting."
1697+
)
1698+
return
1699+
1700+
y_pred = self.posterior_mean(x_test)["m(x)"]
1701+
y_var = self.posterior_covariance(x_test, add_noise=True)["v(x)"]
1702+
y_obs_flat = np.asarray(y_test).reshape(-1)
1703+
y_pred_flat = np.asarray(y_pred).reshape(-1)
1704+
y_sigma_flat = np.sqrt(np.clip(np.asarray(y_var).reshape(-1), 0.0, None))
1705+
1706+
if ax is None:
1707+
_, ax = plt.subplots(figsize=(6, 6))
1708+
ax.errorbar(y_obs_flat, y_pred_flat, yerr=y_sigma_flat,
1709+
fmt="o", alpha=0.6, markersize=4, capsize=2,
1710+
elinewidth=0.8, label="prediction ± 1σ")
1711+
1712+
lo = float(min(y_obs_flat.min(), (y_pred_flat - y_sigma_flat).min()))
1713+
hi = float(max(y_obs_flat.max(), (y_pred_flat + y_sigma_flat).max()))
1714+
ax.plot([lo, hi], [lo, hi], "k--", linewidth=1, label="y = x")
1715+
1716+
ax.set_xlabel("Observed")
1717+
ax.set_ylabel("Predicted")
1718+
if title is not None:
1719+
ax.set_title(title)
1720+
ax.set_aspect("equal", adjustable="box")
1721+
ax.legend(loc="best")
1722+
16661723
@staticmethod
16671724
def gaussian_1d(x, mu, sigma):
16681725
"""

fvgp/gp_posterior.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ def posterior_covariance(self, x_pred, x_out=None, variance_only=False, add_nois
204204
"or double check the hyperparameter optimization bounds. This will not "
205205
"terminate the algorithm, but expect anomalies.")
206206
logger.debug("Negative variances encountered.")
207+
# Always clip tiny negatives (numerical roundoff in iterative solvers
208+
# leaves the diagonal slightly below zero); downstream sqrt would NaN.
209+
if np.any(v < 0.0):
207210
v[v < 0.0] = 0.0
208211
if not variance_only: np.fill_diagonal(S, v)
209212

0 commit comments

Comments
 (0)