From 727b28865647c2d141064c8b1debbbece68761f6 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Wed, 18 Feb 2026 13:30:27 +0100 Subject: [PATCH 1/5] Add HSGP Implementation --- gpjax/kernels/__init__.py | 5 +- gpjax/kernels/approximations/__init__.py | 3 +- gpjax/kernels/approximations/hsgp.py | 181 +++++++++++ gpjax/kernels/computations/__init__.py | 2 + gpjax/kernels/computations/hsgp.py | 49 +++ gpjax/kernels/stationary/base.py | 9 +- gpjax/kernels/stationary/matern12.py | 11 +- gpjax/kernels/stationary/matern32.py | 12 +- gpjax/kernels/stationary/matern52.py | 17 +- gpjax/kernels/stationary/rbf.py | 14 +- gpjax/kernels/stationary/utils.py | 51 ++++ tests/test_kernels/test_hsgp.py | 317 ++++++++++++++++++++ tests/test_kernels/test_spectral_density.py | 121 ++++++++ 13 files changed, 774 insertions(+), 18 deletions(-) create mode 100644 gpjax/kernels/approximations/hsgp.py create mode 100644 gpjax/kernels/computations/hsgp.py create mode 100644 tests/test_kernels/test_hsgp.py create mode 100644 tests/test_kernels/test_spectral_density.py diff --git a/gpjax/kernels/__init__.py b/gpjax/kernels/__init__.py index 5ad342b86..cf8b06042 100644 --- a/gpjax/kernels/__init__.py +++ b/gpjax/kernels/__init__.py @@ -19,7 +19,7 @@ from gpjax.kernels.additive import ( OrthogonalAdditiveKernel, ) -from gpjax.kernels.approximations import RFF +from gpjax.kernels.approximations import HSGP, RFF from gpjax.kernels.base import ( AbstractKernel, Constant, @@ -32,6 +32,7 @@ DenseKernelComputation, DiagonalKernelComputation, EigenKernelComputation, + HSGPComputation, ) from gpjax.kernels.multioutput import ( ICMKernel, @@ -57,6 +58,7 @@ ) __all__ = [ + "HSGP", "RBF", "RFF", "AbstractKernel", @@ -68,6 +70,7 @@ "DiagonalKernelComputation", "EigenKernelComputation", "GraphKernel", + "HSGPComputation", "ICMKernel", "LCMKernel", "Linear", diff --git a/gpjax/kernels/approximations/__init__.py b/gpjax/kernels/approximations/__init__.py index 9b08299e5..1f6a804e9 100644 --- a/gpjax/kernels/approximations/__init__.py +++ b/gpjax/kernels/approximations/__init__.py @@ -1,3 +1,4 @@ +from gpjax.kernels.approximations.hsgp import HSGP from gpjax.kernels.approximations.rff import RFF -__all__ = ["RFF"] +__all__ = ["HSGP", "RFF"] diff --git a/gpjax/kernels/approximations/hsgp.py b/gpjax/kernels/approximations/hsgp.py new file mode 100644 index 000000000..e6cacc8b2 --- /dev/null +++ b/gpjax/kernels/approximations/hsgp.py @@ -0,0 +1,181 @@ +"""Hilbert Space Gaussian Process (HSGP) kernel approximation. + +Reference: + Solin & Sarkka (2019). "Hilbert Space Methods for Reduced-Rank + Gaussian Process Regression." Statistics and Computing. +""" + +import beartype.typing as tp +import jax.numpy as jnp +from jaxtyping import Float + +from gpjax.kernels.base import AbstractKernel +from gpjax.kernels.computations.hsgp import HSGPComputation +from gpjax.kernels.stationary.base import StationaryKernel +from gpjax.typing import Array + + +class HSGP(AbstractKernel): + r"""Hilbert Space Gaussian Process approximation (1D). + + Approximates a stationary kernel by projecting onto the eigenbasis of the + Laplacian operator with Dirichlet boundary conditions on :math:`[-L, L]`. + The approximate covariance function is: + + .. math:: + \tilde{k}(x, x') = \sum_{j=1}^{m} S(\sqrt{\lambda_j})\, + \phi_j(x)\,\phi_j(x') + + where :math:`\lambda_j = (j\pi / 2L)^2` are the eigenvalues, + :math:`\phi_j(x) = L^{-1/2}\sin(j\pi(x + L) / 2L)` are the + eigenfunctions, and :math:`S` is the spectral density of the base + kernel. + + The linearised form decomposes the GP function as: + + .. math:: + f(x) \approx \Phi(x)\,\mathrm{diag}(\sqrt{S})\,\beta, + \quad \beta \sim \mathcal{N}(0, I_m) + + which is a Bayesian linear model with :math:`m` basis functions. + + Args: + base_kernel: The stationary kernel to approximate. Must have a + ``spectral_density`` property that returns a + :class:`~gpjax.kernels.stationary.utils.SpectralDensity`. + num_basis_fns: Number of basis functions :math:`m`. + domain_half_width: Half-width of the approximation domain :math:`L`. + Inputs should lie well inside :math:`[-L, L]` (after centering). + center: Center of the data domain. If ``None``, it is set + automatically on the first call to :meth:`eigenfunctions` as the + midpoint of the observed input range. + compute_engine: Computation engine (default + :class:`~gpjax.kernels.computations.hsgp.HSGPComputation`). + + Example: + >>> import gpjax as gpx + >>> base = gpx.kernels.Matern52(n_dims=1) + >>> hsgp = gpx.kernels.HSGP(base, num_basis_fns=20, domain_half_width=5.0) + >>> K = hsgp.gram(X) # approximate Gram matrix + """ + + compute_engine: HSGPComputation + + def __init__( + self, + base_kernel: StationaryKernel, + num_basis_fns: int, + domain_half_width: float, + center: tp.Union[float, None] = None, + compute_engine: HSGPComputation = HSGPComputation(), + ): + if not isinstance(base_kernel, StationaryKernel): + raise TypeError( + "HSGP can only approximate stationary kernels. " + f"Got {type(base_kernel).__name__}." + ) + # Verify that the base kernel has a usable spectral density. + _ = base_kernel.spectral_density + + super().__init__( + active_dims=base_kernel.active_dims, + n_dims=base_kernel.n_dims, + compute_engine=compute_engine, + ) + self.base_kernel = base_kernel + self.num_basis_fns = num_basis_fns + self.domain_half_width = domain_half_width + self._center = center + self.name = f"{self.base_kernel.name} (HSGP)" + + # ------------------------------------------------------------------ + # Basis functions + # ------------------------------------------------------------------ + + def eigenvalues(self) -> Float[Array, " m"]: + r"""Square roots of the Laplacian eigenvalues. + + .. math:: + \sqrt{\lambda_j} = \frac{j\pi}{2L}, \quad j = 1, \ldots, m + + Returns: + Array of shape ``(m,)``. + """ + j = jnp.arange(1, self.num_basis_fns + 1) + return j * jnp.pi / (2.0 * self.domain_half_width) + + def eigenfunctions(self, x: Float[Array, "N 1"]) -> Float[Array, "N m"]: + r"""Laplacian eigenfunctions evaluated at *x*. + + .. math:: + \phi_j(x) = \frac{1}{\sqrt{L}} + \sin\!\Bigl(\frac{j\pi\,(x + L)}{2L}\Bigr) + + where *x* is first shifted by the stored center. + + Args: + x: Input locations of shape ``(N, 1)``. + + Returns: + Basis matrix :math:`\Phi` of shape ``(N, m)``. + """ + if self._center is None: + self._center = float((x.max() + x.min()) / 2.0) + + L = self.domain_half_width + x_centered = x - self._center # [N, 1] + sqrt_eigenvalues = self.eigenvalues() # [m] + return jnp.sin((x_centered + L) * sqrt_eigenvalues) / jnp.sqrt(L) + + def spectral_weights(self) -> Float[Array, " m"]: + r"""Spectral density evaluated at the eigenvalue square roots. + + .. math:: + S(\sqrt{\lambda_j}) + + Returns: + Array of shape ``(m,)`` with the spectral weights. + """ + omega = self.eigenvalues() + sd = self.base_kernel.spectral_density + return sd( + omega, + self.base_kernel.variance[...], + self.base_kernel.lengthscale[...], + ) + + def compute_basis( + self, x: Float[Array, "N 1"] + ) -> tuple[Float[Array, "N m"], Float[Array, " m"]]: + r"""Linearised HSGP decomposition. + + Returns :math:`(\Phi, \sqrt{S})` such that + + .. math:: + f(x) \approx \Phi\,\mathrm{diag}(\sqrt{S})\,\beta, + \quad \beta \sim \mathcal{N}(0, I_m). + + Args: + x: Input locations of shape ``(N, 1)``. + + Returns: + Tuple ``(phi, sqrt_psd)`` where ``phi`` has shape ``(N, m)`` + and ``sqrt_psd`` has shape ``(m,)``. + """ + phi = self.eigenfunctions(x) + spd = self.spectral_weights() + return phi, jnp.sqrt(spd) + + # ------------------------------------------------------------------ + # AbstractKernel interface + # ------------------------------------------------------------------ + + def __call__( + self, + x: Float[Array, " D"], + y: Float[Array, " D"], + ) -> None: + raise RuntimeError( + "HSGP does not support pointwise kernel evaluation. " + "Use gram(), cross_covariance(), or compute_basis() instead." + ) diff --git a/gpjax/kernels/computations/__init__.py b/gpjax/kernels/computations/__init__.py index 60eb5a1b3..1de8b5f5f 100644 --- a/gpjax/kernels/computations/__init__.py +++ b/gpjax/kernels/computations/__init__.py @@ -21,6 +21,7 @@ from gpjax.kernels.computations.dense import DenseKernelComputation from gpjax.kernels.computations.diagonal import DiagonalKernelComputation from gpjax.kernels.computations.eigen import EigenKernelComputation +from gpjax.kernels.computations.hsgp import HSGPComputation __all__ = [ "AbstractKernelComputation", @@ -29,4 +30,5 @@ "DenseKernelComputation", "DiagonalKernelComputation", "EigenKernelComputation", + "HSGPComputation", ] diff --git a/gpjax/kernels/computations/hsgp.py b/gpjax/kernels/computations/hsgp.py new file mode 100644 index 000000000..3e0877d32 --- /dev/null +++ b/gpjax/kernels/computations/hsgp.py @@ -0,0 +1,49 @@ +"""Compute engine for the Hilbert Space GP approximation.""" + +import typing as tp + +import jax.numpy as jnp +from jaxtyping import Float + +import gpjax +from gpjax.kernels.computations.base import AbstractKernelComputation +from gpjax.linalg import Dense, Diagonal +from gpjax.linalg.utils import psd +from gpjax.typing import Array + +K = tp.TypeVar("K", bound="gpjax.kernels.approximations.HSGP") + + +class HSGPComputation(AbstractKernelComputation): + r"""Compute engine for the HSGP kernel approximation. + + Computes the Gram matrix via the low-rank decomposition: + + .. math:: + \tilde{K} = \Phi \Lambda \Phi^\top + + where :math:`\Phi` is the matrix of Laplacian eigenfunctions and + :math:`\Lambda = \mathrm{diag}(S(\sqrt{\lambda_1}), \ldots, S(\sqrt{\lambda_m}))` + contains the spectral density evaluated at the eigenvalue square roots. + """ + + def gram(self, kernel: K, x: Float[Array, "N D"]) -> Dense: + Kxx = self._gram(kernel, x) + return psd(Dense(Kxx)) + + def _gram(self, kernel: K, x: Float[Array, "N D"]) -> Float[Array, "N N"]: + phi, sqrt_psd = kernel.compute_basis(x) + weighted = phi * sqrt_psd[None, :] + return weighted @ weighted.T + + def _cross_covariance( + self, kernel: K, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + phi_x, sqrt_psd = kernel.compute_basis(x) + phi_y, _ = kernel.compute_basis(y) + return (phi_x * sqrt_psd[None, :]) @ (phi_y * sqrt_psd[None, :]).T + + def diagonal(self, kernel: K, x: Float[Array, "N D"]) -> Diagonal: + phi, sqrt_psd = kernel.compute_basis(x) + weighted = phi * sqrt_psd[None, :] + return psd(Diagonal(jnp.sum(weighted**2, axis=1))) diff --git a/gpjax/kernels/stationary/base.py b/gpjax/kernels/stationary/base.py index 5de7572e2..3cd476236 100644 --- a/gpjax/kernels/stationary/base.py +++ b/gpjax/kernels/stationary/base.py @@ -18,13 +18,13 @@ from flax import nnx import jax.numpy as jnp from jaxtyping import Float -import numpyro.distributions as npd from gpjax.kernels.base import AbstractKernel from gpjax.kernels.computations import ( AbstractKernelComputation, DenseKernelComputation, ) +from gpjax.kernels.stationary.utils import SpectralDensity from gpjax.parameters import ( NonNegativeReal, PositiveReal, @@ -95,11 +95,12 @@ def __init__( self.variance = tp.cast(NonNegativeReal[ScalarFloat], self.variance) @property - def spectral_density(self) -> npd.Normal | npd.StudentT: + def spectral_density(self) -> SpectralDensity: r"""The spectral density of the kernel. - Returns: - Callable[[Float[Array, "D"]], Float[Array, "D"]]: The spectral density function. + Returns a :class:`~gpjax.kernels.stationary.utils.SpectralDensity` + object that supports both ``sample()`` (for RFF) and + ``__call__(omega, variance, lengthscale)`` (for HSGP). """ raise NotImplementedError( f"Kernel {self.name} does not have a spectral density." diff --git a/gpjax/kernels/stationary/matern12.py b/gpjax/kernels/stationary/matern12.py index 66f992ae1..e4ab87565 100644 --- a/gpjax/kernels/stationary/matern12.py +++ b/gpjax/kernels/stationary/matern12.py @@ -15,10 +15,10 @@ import jax.numpy as jnp from jaxtyping import Float -import numpyro.distributions as npd from gpjax.kernels.stationary.base import StationaryKernel from gpjax.kernels.stationary.utils import ( + SpectralDensity, build_student_t_distribution, euclidean_distance, ) @@ -48,5 +48,10 @@ def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: return K.squeeze() @property - def spectral_density(self) -> npd.StudentT: - return build_student_t_distribution(nu=1) + def spectral_density(self) -> SpectralDensity: + def _matern12_spectral_density(omega, variance, lengthscale): + return variance * (2.0 / lengthscale) / (1.0 / lengthscale**2 + omega**2) + + return SpectralDensity( + build_student_t_distribution(nu=1), _matern12_spectral_density + ) diff --git a/gpjax/kernels/stationary/matern32.py b/gpjax/kernels/stationary/matern32.py index e47e031d9..eead4aa92 100644 --- a/gpjax/kernels/stationary/matern32.py +++ b/gpjax/kernels/stationary/matern32.py @@ -15,10 +15,10 @@ import jax.numpy as jnp from jaxtyping import Float -import numpyro.distributions as npd from gpjax.kernels.stationary.base import StationaryKernel from gpjax.kernels.stationary.utils import ( + SpectralDensity, build_student_t_distribution, euclidean_distance, ) @@ -54,5 +54,11 @@ def __call__( return K.squeeze() @property - def spectral_density(self) -> npd.StudentT: - return build_student_t_distribution(nu=3) + def spectral_density(self) -> SpectralDensity: + def _matern32_spectral_density(omega, variance, lengthscale): + alpha = jnp.sqrt(3.0) / lengthscale + return variance * 4.0 * alpha**3 / (3.0 / lengthscale**2 + omega**2) ** 2 + + return SpectralDensity( + build_student_t_distribution(nu=3), _matern32_spectral_density + ) diff --git a/gpjax/kernels/stationary/matern52.py b/gpjax/kernels/stationary/matern52.py index 84ca61069..c4431ae74 100644 --- a/gpjax/kernels/stationary/matern52.py +++ b/gpjax/kernels/stationary/matern52.py @@ -15,10 +15,10 @@ import jax.numpy as jnp from jaxtyping import Float -import numpyro.distributions as npd from gpjax.kernels.stationary.base import StationaryKernel from gpjax.kernels.stationary.utils import ( + SpectralDensity, build_student_t_distribution, euclidean_distance, ) @@ -53,5 +53,16 @@ def __call__( return K.squeeze() @property - def spectral_density(self) -> npd.StudentT: - return build_student_t_distribution(nu=5) + def spectral_density(self) -> SpectralDensity: + def _matern52_spectral_density(omega, variance, lengthscale): + alpha = jnp.sqrt(5.0) / lengthscale + return ( + variance + * (16.0 / 3.0) + * alpha**5 + / (5.0 / lengthscale**2 + omega**2) ** 3 + ) + + return SpectralDensity( + build_student_t_distribution(nu=5), _matern52_spectral_density + ) diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index 44ea74d0e..ff71fcc3e 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -18,7 +18,7 @@ import numpyro.distributions as npd from gpjax.kernels.stationary.base import StationaryKernel -from gpjax.kernels.stationary.utils import squared_distance +from gpjax.kernels.stationary.utils import SpectralDensity, squared_distance from gpjax.typing import ( Array, ScalarFloat, @@ -44,5 +44,13 @@ def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: return K.squeeze() @property - def spectral_density(self) -> npd.Normal: - return npd.Normal(0.0, 1.0) + def spectral_density(self) -> SpectralDensity: + def _rbf_spectral_density(omega, variance, lengthscale): + return ( + variance + * jnp.sqrt(2.0 * jnp.pi) + * lengthscale + * jnp.exp(-0.5 * lengthscale**2 * omega**2) + ) + + return SpectralDensity(npd.Normal(0.0, 1.0), _rbf_spectral_density) diff --git a/gpjax/kernels/stationary/utils.py b/gpjax/kernels/stationary/utils.py index bbe7b0a7d..8697123d7 100644 --- a/gpjax/kernels/stationary/utils.py +++ b/gpjax/kernels/stationary/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import beartype.typing as tp import jax.numpy as jnp from jaxtyping import Float import numpyro.distributions as npd @@ -39,6 +40,56 @@ def build_student_t_distribution(nu: int) -> npd.StudentT: return dist +class SpectralDensity: + """Spectral density of a stationary kernel. + + Wraps a NumPyro distribution (for sampling, used by RFF) and adds + evaluation of the spectral density function S(omega) at arbitrary + frequencies (used by HSGP). + + Args: + distribution: A NumPyro distribution to delegate ``sample()`` to. + evaluate_fn: A callable ``(omega, variance, lengthscale) -> S(omega)`` + that computes the un-normalized spectral density at the given + frequencies incorporating kernel variance and lengthscale. + """ + + def __init__( + self, + distribution: npd.Distribution, + evaluate_fn: tp.Callable, + ): + self._distribution = distribution + self._evaluate_fn = evaluate_fn + + def sample(self, key, sample_shape): + """Draw samples from the spectral density distribution. + + This delegates to the wrapped NumPyro distribution and is used by + Random Fourier Features (RFF). + """ + return self._distribution.sample(key=key, sample_shape=sample_shape) + + def __call__(self, omega, variance, lengthscale): + """Evaluate S(omega) incorporating kernel variance and lengthscale. + + Parameters + ---------- + omega : Array + Frequencies at which to evaluate the spectral density. + variance : ScalarFloat + Kernel variance parameter (sigma^2). + lengthscale : ScalarFloat + Kernel lengthscale parameter (ell). + + Returns + ------- + Array + Spectral density values S(omega). + """ + return self._evaluate_fn(omega, variance, lengthscale) + + def squared_distance(x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: r"""Compute the squared distance between a pair of inputs. diff --git a/tests/test_kernels/test_hsgp.py b/tests/test_kernels/test_hsgp.py new file mode 100644 index 000000000..23028946a --- /dev/null +++ b/tests/test_kernels/test_hsgp.py @@ -0,0 +1,317 @@ +"""Tests for the Hilbert Space Gaussian Process (HSGP) kernel approximation.""" + +from gpjax.kernels.approximations.hsgp import HSGP +from gpjax.kernels.stationary import RBF, Matern12, Matern32, Matern52 +from gpjax.linalg.operators import Dense +from jax import config +import jax.numpy as jnp +import jax.random as jr +import numpy.testing as npt +import pytest + +config.update("jax_enable_x64", True) + + +class TestEigenvalues: + def test_eigenvalue_count(self): + kernel = HSGP( + base_kernel=RBF(n_dims=1), num_basis_fns=10, domain_half_width=3.0 + ) + evals = kernel.eigenvalues() + assert evals.shape == (10,) + + def test_eigenvalue_formula(self): + """sqrt(lambda_j) = j * pi / (2 * L).""" + L = 5.0 + m = 4 + kernel = HSGP(base_kernel=RBF(n_dims=1), num_basis_fns=m, domain_half_width=L) + evals = kernel.eigenvalues() + expected = jnp.arange(1, m + 1) * jnp.pi / (2.0 * L) + npt.assert_allclose(evals, expected) + + def test_eigenvalues_increase(self): + kernel = HSGP( + base_kernel=RBF(n_dims=1), num_basis_fns=20, domain_half_width=3.0 + ) + evals = kernel.eigenvalues() + assert jnp.all(jnp.diff(evals) > 0) + + +class TestEigenfunctions: + def test_shape(self): + kernel = HSGP( + base_kernel=RBF(n_dims=1), + num_basis_fns=10, + domain_half_width=3.0, + center=0.0, + ) + x = jnp.linspace(-2, 2, 50)[:, None] + phi = kernel.eigenfunctions(x) + assert phi.shape == (50, 10) + + def test_orthonormality(self): + """Eigenfunctions should be approximately orthonormal over [-L, L].""" + L = 3.0 + m = 5 + kernel = HSGP( + base_kernel=RBF(n_dims=1), num_basis_fns=m, domain_half_width=L, center=0.0 + ) + # Dense grid for numerical integration + n_quad = 10000 + x = jnp.linspace(-L, L, n_quad)[:, None] + phi = kernel.eigenfunctions(x) + dx = 2.0 * L / n_quad + gram = phi.T @ phi * dx # Approximate integral + npt.assert_allclose(gram, jnp.eye(m), atol=1e-2) + + def test_zero_at_boundaries(self): + """Eigenfunctions should be zero at x = -L and x = L.""" + L = 3.0 + kernel = HSGP( + base_kernel=RBF(n_dims=1), num_basis_fns=5, domain_half_width=L, center=0.0 + ) + boundaries = jnp.array([[-L], [L]]) + phi = kernel.eigenfunctions(boundaries) + npt.assert_allclose(phi, 0.0, atol=1e-12) + + +class TestCentering: + def test_explicit_center(self): + kernel = HSGP( + base_kernel=RBF(n_dims=1), + num_basis_fns=5, + domain_half_width=3.0, + center=1.0, + ) + assert kernel._center == 1.0 + + def test_auto_center(self): + kernel = HSGP(base_kernel=RBF(n_dims=1), num_basis_fns=5, domain_half_width=3.0) + x = jnp.linspace(2.0, 8.0, 50)[:, None] + _ = kernel.eigenfunctions(x) + assert kernel._center == pytest.approx(5.0) + + def test_auto_center_persists(self): + """Once set, auto-center should not change on subsequent calls.""" + kernel = HSGP(base_kernel=RBF(n_dims=1), num_basis_fns=5, domain_half_width=3.0) + x1 = jnp.linspace(0, 10, 50)[:, None] + x2 = jnp.linspace(-5, 5, 50)[:, None] + _ = kernel.eigenfunctions(x1) + center_after_first = kernel._center + _ = kernel.eigenfunctions(x2) + assert kernel._center == center_after_first + + +class TestComputeBasis: + def test_returns_tuple(self): + kernel = HSGP( + base_kernel=RBF(n_dims=1), + num_basis_fns=10, + domain_half_width=3.0, + center=0.0, + ) + x = jnp.linspace(-2, 2, 50)[:, None] + phi, sqrt_psd = kernel.compute_basis(x) + assert phi.shape == (50, 10) + assert sqrt_psd.shape == (10,) + + def test_sqrt_psd_positive(self): + kernel = HSGP( + base_kernel=RBF(n_dims=1), + num_basis_fns=10, + domain_half_width=3.0, + center=0.0, + ) + x = jnp.linspace(-2, 2, 50)[:, None] + _, sqrt_psd = kernel.compute_basis(x) + assert jnp.all(sqrt_psd > 0) + + +class TestGram: + @pytest.mark.parametrize("KernelClass", [RBF, Matern12, Matern32, Matern52]) + def test_gram_shape_and_psd(self, KernelClass): + base = KernelClass(n_dims=1) + hsgp = HSGP( + base_kernel=base, num_basis_fns=20, domain_half_width=5.0, center=0.0 + ) + x = jnp.linspace(-3, 3, 30)[:, None] + linop = hsgp.gram(x) + assert isinstance(linop, Dense) + K = linop.to_dense() + assert K.shape == (30, 30) + # PSD check + evals, _ = jnp.linalg.eigh(K + 1e-6 * jnp.eye(30)) + assert jnp.all(evals > 0) + + @pytest.mark.parametrize("KernelClass", [RBF, Matern12, Matern32, Matern52]) + def test_cross_covariance_shape(self, KernelClass): + base = KernelClass(n_dims=1) + hsgp = HSGP( + base_kernel=base, num_basis_fns=20, domain_half_width=5.0, center=0.0 + ) + x1 = jnp.linspace(-3, 3, 30)[:, None] + x2 = jnp.linspace(-2, 2, 20)[:, None] + Kxy = hsgp.cross_covariance(x1, x2) + assert Kxy.shape == (30, 20) + + def test_gram_symmetric(self): + base = RBF(n_dims=1) + hsgp = HSGP( + base_kernel=base, num_basis_fns=20, domain_half_width=5.0, center=0.0 + ) + x = jnp.linspace(-3, 3, 30)[:, None] + K = hsgp.gram(x).to_dense() + npt.assert_allclose(K, K.T, atol=1e-12) + + @pytest.mark.parametrize("KernelClass", [RBF, Matern12, Matern32, Matern52]) + def test_diagonal(self, KernelClass): + base = KernelClass(n_dims=1) + hsgp = HSGP( + base_kernel=base, num_basis_fns=20, domain_half_width=5.0, center=0.0 + ) + x = jnp.linspace(-3, 3, 30)[:, None] + diag = hsgp.diagonal(x) + K = hsgp.gram(x).to_dense() + npt.assert_allclose(jnp.diag(diag.to_dense()), jnp.diag(K), atol=1e-10) + + +class TestConvergence: + @pytest.mark.parametrize("KernelClass", [RBF, Matern32, Matern52]) + def test_gram_converges_to_exact(self, KernelClass): + """With large m and appropriate L, HSGP Gram should converge to exact.""" + base = KernelClass(n_dims=1) + x = jnp.linspace(-1, 1, 30)[:, None] + exact = base.gram(x).to_dense() + + hsgp_coarse = HSGP( + base_kernel=base, num_basis_fns=10, domain_half_width=5.0, center=0.0 + ) + hsgp_fine = HSGP( + base_kernel=base, num_basis_fns=80, domain_half_width=5.0, center=0.0 + ) + + err_coarse = jnp.linalg.norm(exact - hsgp_coarse.gram(x).to_dense(), ord="fro") + err_fine = jnp.linalg.norm(exact - hsgp_fine.gram(x).to_dense(), ord="fro") + + # Finer approximation should have smaller error + assert err_fine < err_coarse + + def test_rbf_close_to_exact(self): + """RBF with large m should be very close to exact.""" + base = RBF(n_dims=1) + x = jnp.linspace(-1, 1, 20)[:, None] + exact = base.gram(x).to_dense() + + hsgp = HSGP( + base_kernel=base, num_basis_fns=100, domain_half_width=5.0, center=0.0 + ) + approx = hsgp.gram(x).to_dense() + max_err = jnp.max(jnp.abs(exact - approx)) + assert max_err < 0.01 + + +class TestValidation: + def test_nonstationary_kernel_rejected(self): + from gpjax.kernels.nonstationary import Linear + + with pytest.raises(TypeError): + HSGP(base_kernel=Linear(1), num_basis_fns=10, domain_half_width=3.0) + + def test_pointwise_call_raises(self): + hsgp = HSGP(base_kernel=RBF(n_dims=1), num_basis_fns=10, domain_half_width=3.0) + with pytest.raises(RuntimeError): + hsgp(jnp.array([1.0]), jnp.array([2.0])) + + +class TestIntegration: + def test_prior_posterior_pipeline(self): + """HSGP should work as a drop-in kernel in the Prior/Posterior pipeline.""" + from gpjax.dataset import Dataset + from gpjax.gps import Prior + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + from gpjax.objectives import conjugate_mll + + key = jr.key(42) + n = 50 + x = jnp.linspace(-3, 3, n)[:, None] + y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) + D = Dataset(X=x, y=y) + + base_kernel = RBF(n_dims=1) + hsgp = HSGP( + base_kernel=base_kernel, num_basis_fns=20, domain_half_width=5.0, center=0.0 + ) + prior = Prior(kernel=hsgp, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=n) + posterior = prior * likelihood + + # MLL should return a finite scalar + mll = conjugate_mll(posterior, D) + assert jnp.isfinite(mll) + + def test_predict(self): + """HSGP posterior should produce finite mean and variance.""" + from gpjax.dataset import Dataset + from gpjax.gps import Prior + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + + key = jr.key(42) + n = 50 + x = jnp.linspace(-3, 3, n)[:, None] + y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) + D = Dataset(X=x, y=y) + + base_kernel = RBF(n_dims=1) + hsgp = HSGP( + base_kernel=base_kernel, num_basis_fns=20, domain_half_width=5.0, center=0.0 + ) + prior = Prior(kernel=hsgp, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=n) + posterior = prior * likelihood + + x_test = jnp.linspace(-2.5, 2.5, 30)[:, None] + pred = posterior.predict(x_test, D) + + assert jnp.all(jnp.isfinite(pred.mean)) + assert jnp.all(jnp.isfinite(pred.covariance())) + + def test_mll_differentiable(self): + """conjugate_mll with HSGP must be differentiable w.r.t. kernel params.""" + from flax import nnx + from gpjax.dataset import Dataset + from gpjax.gps import Prior + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + from gpjax.objectives import conjugate_mll + import jax + + key = jr.key(42) + n = 50 + x = jnp.linspace(-3, 3, n)[:, None] + y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) + D = Dataset(X=x, y=y) + + base_kernel = RBF(n_dims=1) + hsgp = HSGP( + base_kernel=base_kernel, num_basis_fns=20, domain_half_width=5.0, center=0.0 + ) + prior = Prior(kernel=hsgp, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=n) + posterior = prior * likelihood + + # Split into graphdef and state + graphdef, state = nnx.split(posterior) + + def loss(state): + model = nnx.merge(graphdef, state) + return -conjugate_mll(model, D) + + grad_fn = jax.grad(loss) + grads = grad_fn(state) + + # Gradients should be finite + flat_grads = jax.tree.leaves(grads) + for g in flat_grads: + assert jnp.all(jnp.isfinite(g)), f"Non-finite gradient: {g}" diff --git a/tests/test_kernels/test_spectral_density.py b/tests/test_kernels/test_spectral_density.py new file mode 100644 index 000000000..0e8c883b0 --- /dev/null +++ b/tests/test_kernels/test_spectral_density.py @@ -0,0 +1,121 @@ +"""Tests for the SpectralDensity class and kernel spectral density evaluation.""" + +from gpjax.kernels.stationary import RBF, Matern12, Matern32, Matern52 +from gpjax.kernels.stationary.utils import SpectralDensity +from jax import config +import jax.numpy as jnp +import jax.random as jr +import numpy.testing as npt +import pytest + +config.update("jax_enable_x64", True) + + +def test_spectral_density_has_sample(): + """SpectralDensity must expose sample() for RFF compatibility.""" + kernel = RBF(n_dims=1) + sd = kernel.spectral_density + assert isinstance(sd, SpectralDensity) + samples = sd.sample(key=jr.key(0), sample_shape=(10, 1)) + assert samples.shape == (10, 1) + + +def test_spectral_density_callable(): + """SpectralDensity must be callable with (omega, variance, lengthscale).""" + kernel = RBF(n_dims=1) + sd = kernel.spectral_density + omega = jnp.array([0.5, 1.0, 2.0]) + result = sd(omega, jnp.array(1.0), jnp.array(1.0)) + assert result.shape == (3,) + assert jnp.all(result > 0) + + +def test_rbf_spectral_density_formula(): + """Verify the RBF spectral density against the known closed form. + + S(w) = variance * sqrt(2*pi) * lengthscale * exp(-0.5 * lengthscale^2 * w^2) + """ + kernel = RBF(n_dims=1, variance=2.0, lengthscale=0.5) + sd = kernel.spectral_density + + omega = jnp.array([0.0, 1.0, 3.0]) + variance = jnp.array(2.0) + lengthscale = jnp.array(0.5) + + result = sd(omega, variance, lengthscale) + expected = ( + variance + * jnp.sqrt(2 * jnp.pi) + * lengthscale + * jnp.exp(-0.5 * lengthscale**2 * omega**2) + ) + npt.assert_allclose(result, expected, atol=1e-12) + + +def test_rbf_spectral_density_peak_at_zero(): + """RBF spectral density peaks at omega=0 and decays monotonically.""" + kernel = RBF(n_dims=1) + sd = kernel.spectral_density + omega = jnp.linspace(0, 10, 100) + values = sd(omega, jnp.array(1.0), jnp.array(1.0)) + # Peak at omega=0 + assert values[0] == jnp.max(values) + # Monotonically decreasing + assert jnp.all(jnp.diff(values) <= 0) + + +def test_matern12_spectral_density_formula(): + """Verify Matern12 (nu=1/2) spectral density. + + S(w) = variance * 2/ell * 1/(1/ell^2 + w^2) + """ + kernel = Matern12(n_dims=1, variance=1.5, lengthscale=0.8) + sd = kernel.spectral_density + omega = jnp.array([0.0, 1.0, 5.0]) + v, ell = jnp.array(1.5), jnp.array(0.8) + + result = sd(omega, v, ell) + expected = v * (2.0 / ell) / (1.0 / ell**2 + omega**2) + npt.assert_allclose(result, expected, atol=1e-12) + + +def test_matern32_spectral_density_formula(): + """Verify Matern32 (nu=3/2) spectral density. + + S(w) = variance * 4*(sqrt(3)/ell)^3 / (3/ell^2 + w^2)^2 + """ + kernel = Matern32(n_dims=1, variance=2.0, lengthscale=1.5) + sd = kernel.spectral_density + omega = jnp.array([0.0, 1.0, 5.0]) + v, ell = jnp.array(2.0), jnp.array(1.5) + + result = sd(omega, v, ell) + alpha = jnp.sqrt(3.0) / ell + expected = v * 4.0 * alpha**3 / (3.0 / ell**2 + omega**2) ** 2 + npt.assert_allclose(result, expected, atol=1e-12) + + +def test_matern52_spectral_density_formula(): + """Verify Matern52 (nu=5/2) spectral density. + + S(w) = variance * (16/3)*(sqrt(5)/ell)^5 / (5/ell^2 + w^2)^3 + """ + kernel = Matern52(n_dims=1, variance=3.0, lengthscale=2.0) + sd = kernel.spectral_density + omega = jnp.array([0.0, 1.0, 5.0]) + v, ell = jnp.array(3.0), jnp.array(2.0) + + result = sd(omega, v, ell) + alpha = jnp.sqrt(5.0) / ell + expected = v * (16.0 / 3.0) * alpha**5 / (5.0 / ell**2 + omega**2) ** 3 + npt.assert_allclose(result, expected, atol=1e-12) + + +@pytest.mark.parametrize("KernelClass", [RBF, Matern12, Matern32, Matern52]) +def test_spectral_density_positive(KernelClass): + """All spectral densities must be positive for all omega.""" + kernel = KernelClass(n_dims=1) + sd = kernel.spectral_density + omega = jnp.linspace(0, 20, 200) + values = sd(omega, jnp.array(1.0), jnp.array(1.0)) + assert jnp.all(values > 0) From 446260b37d865a2983deb788ed7a4a6a278607e6 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Wed, 18 Feb 2026 19:37:23 +0100 Subject: [PATCH 2/5] Add Woodbury identity --- examples/hsgp.py | 357 +++++++++++ gpjax/gps.py | 159 ++++- gpjax/kernels/approximations/hsgp.py | 40 +- gpjax/kernels/computations/basis_functions.py | 54 +- gpjax/kernels/computations/hsgp.py | 45 +- gpjax/kernels/stationary/base.py | 7 +- gpjax/kernels/stationary/matern12.py | 12 +- gpjax/kernels/stationary/matern32.py | 14 +- gpjax/kernels/stationary/matern52.py | 14 +- gpjax/kernels/stationary/rbf.py | 11 +- gpjax/kernels/stationary/utils.py | 104 +-- gpjax/linalg/__init__.py | 2 + gpjax/linalg/operations.py | 6 + gpjax/linalg/operators.py | 47 ++ gpjax/linalg/woodbury.py | 121 ++++ gpjax/objectives.py | 25 +- pyproject.toml | 1 + tests/test_kernels/test_approximations.py | 4 +- tests/test_kernels/test_computations.py | 3 +- tests/test_kernels/test_hsgp.py | 600 +++++++++--------- tests/test_kernels/test_spectral_density.py | 146 +++-- tests/test_linalg/__init__.py | 0 tests/test_linalg/test_lowrank.py | 81 +++ .../test_operators.py} | 0 tests/test_linalg/test_woodbury.py | 379 +++++++++++ uv.lock | 33 +- 26 files changed, 1713 insertions(+), 552 deletions(-) create mode 100644 examples/hsgp.py create mode 100644 gpjax/linalg/woodbury.py create mode 100644 tests/test_linalg/__init__.py create mode 100644 tests/test_linalg/test_lowrank.py rename tests/{test_linalg.py => test_linalg/test_operators.py} (100%) create mode 100644 tests/test_linalg/test_woodbury.py diff --git a/examples/hsgp.py b/examples/hsgp.py new file mode 100644 index 000000000..d14a0c2c3 --- /dev/null +++ b/examples/hsgp.py @@ -0,0 +1,357 @@ +# -*- coding: utf-8 -*- +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: .venv +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Hilbert Space Gaussian Processes +# +# Standard GP inference requires inverting an $n \times n$ Gram matrix at +# $\mathcal{O}(n^3)$ cost, which limits practical use to a few thousand points. +# The Hilbert Space Gaussian Process (HSGP) approximation +# ([Solin and Sarkka, 2020](https://link.springer.com/article/10.1007/s11222-019-09886-w)) +# projects the GP onto a truncated basis of Laplacian eigenfunctions, reducing +# the cost to $\mathcal{O}(nm^2)$ for setup and $\mathcal{O}(m^3)$ per +# optimisation step, where $m \ll n$ is the number of basis functions. See +# also the PyMC tutorial by +# [Orduz and Capretto](https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Basic.html) +# and the standalone introduction by +# [Orduz (2022)](https://juanitorduz.github.io/hsgp_intro/). + +# %% +from jax import config + +config.update("jax_enable_x64", True) + +from examples.utils import use_mpl_style +import jax.numpy as jnp +import jax.random as jr +from jaxtyping import install_import_hook +import matplotlib as mpl +import matplotlib.pyplot as plt + +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx + +key = jr.key(42) +use_mpl_style() +cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] + +# %% [markdown] +# ## The HSGP approximation +# +# For a stationary kernel $k$ with spectral density $S$, the HSGP approximates +# the covariance on a bounded domain $[-L, L]$: +# +# $$k(x, x') \approx \sum_{j=1}^{m} S\!\left(\sqrt{\lambda_j}\right) \phi_j(x)\,\phi_j(x').$$ +# +# The eigenfunctions $\phi_j(x) = L^{-1/2}\sin\!\bigl(j\pi(x + L) / 2L\bigr)$ +# are sinusoids of increasing frequency, and the eigenvalues +# $\lambda_j = (j\pi / 2L)^2$ grow quadratically. The spectral density $S$ +# weights each eigenfunction's contribution. In matrix form the approximate +# Gram matrix is $\tilde{K} = \Phi\,\Lambda\,\Phi^\top$ where +# $\Lambda = \mathrm{diag}\!\bigl(S(\sqrt{\lambda_1}), \ldots, S(\sqrt{\lambda_m})\bigr)$. +# +# Two parameters govern approximation quality: +# +# - **$m$** (number of basis functions): more terms capture shorter +# lengthscales, at higher computational cost. +# - **$L$** (domain half-width): must enclose all inputs after centering, +# but setting it much larger than necessary wastes basis functions. + +# %% [markdown] +# ## Eigenfunctions and spectral weights +# +# The HSGP has two ingredients: Laplacian eigenfunctions (which depend only on +# the domain, not the kernel) and spectral weights (which encode the kernel's +# properties). We visualise both below. + +# %% +x_grid = jnp.linspace(-3.0, 3.0, 300)[:, None] + +base_rbf = gpx.kernels.RBF(n_dims=1) +hsgp_rbf = gpx.kernels.HSGP( + base_kernel=base_rbf, num_basis_fns=20, domain_half_width=4.0, center=0.0 +) + +phi = hsgp_rbf.eigenfunctions(x_grid) + +fig, ax = plt.subplots(figsize=(7.5, 2.5)) +for j in range(6): + ax.plot(x_grid, phi[:, j], label=f"$\\phi_{{{j + 1}}}$", color=cols[j % len(cols)]) +ax.set_xlabel("$x$") +ax.set_ylabel("$\\phi_j(x)$") +ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5), fontsize=8) +ax.set_title("Laplacian eigenfunctions") + +# %% [markdown] +# The eigenfunctions are sinusoids of increasing frequency, forming an +# orthonormal basis on $[-L, L]$. Because they are independent of the kernel +# hyperparameters, they need only be evaluated once for a given domain. +# +# The spectral weights $S(\sqrt{\lambda_j})$ determine each eigenfunction's +# contribution. We compare weights for the RBF kernel (Gaussian spectral +# density, rapid decay) and the Matern-5/2 (heavier-tailed, slower decay). + +# %% +base_m52 = gpx.kernels.Matern52(n_dims=1) +hsgp_m52 = gpx.kernels.HSGP( + base_kernel=base_m52, num_basis_fns=20, domain_half_width=4.0, center=0.0 +) + +weights_rbf = hsgp_rbf.spectral_weights() +weights_m52 = hsgp_m52.spectral_weights() + +fig, axes = plt.subplots(1, 2, figsize=(7.5, 2.5), sharey=True) +basis_indices = jnp.arange(1, 21) + +axes[0].stem(basis_indices, weights_rbf, linefmt=cols[0], markerfmt="o", basefmt=" ") +axes[0].set_title("RBF") +axes[0].set_xlabel("Basis index $j$") +axes[0].set_ylabel("$S(\\sqrt{\\lambda_j})$") + +axes[1].stem(basis_indices, weights_m52, linefmt=cols[1], markerfmt="o", basefmt=" ") +axes[1].set_title("Matern-5/2") +axes[1].set_xlabel("Basis index $j$") + +# %% [markdown] +# The RBF weights decay rapidly, so few basis functions suffice. The +# Matern-5/2 retains appreciable weight at higher frequencies and needs more +# terms for the same approximation quality. + +# %% [markdown] +# ## Approximation quality +# +# How faithfully does the HSGP Gram matrix reproduce the exact kernel? We +# compare exact and approximate Gram matrices at $m = 10$, $25$, and $50$ for +# both the RBF and Matern-5/2 kernels. + +# %% +x_small = jnp.linspace(-3.0, 3.0, 80)[:, None] +m_values = [10, 25, 50] + + +def gram_comparison(base_kernel, x, m_values): + """Return exact Gram and HSGP Gram matrices for several values of m.""" + gram_exact = base_kernel.gram(x).to_dense() + gram_approximations = [] + for num_basis in m_values: + hsgp = gpx.kernels.HSGP( + base_kernel=base_kernel, + num_basis_fns=num_basis, + domain_half_width=4.0, + center=0.0, + ) + gram_approximations.append(hsgp.gram(x).to_dense()) + return gram_exact, gram_approximations + + +gram_exact_rbf, gram_hsgps_rbf = gram_comparison(base_rbf, x_small, m_values) + +fig, axes = plt.subplots(1, 4, figsize=(10, 2.5)) +vmin, vmax = float(gram_exact_rbf.min()), float(gram_exact_rbf.max()) +titles = ["Exact"] + [f"HSGP ($m = {m}$)" for m in m_values] +matrices = [gram_exact_rbf, *gram_hsgps_rbf] + +for ax, matrix, title in zip(axes, matrices, titles, strict=True): + ax.imshow(matrix, vmin=vmin, vmax=vmax, cmap="inferno") + ax.set_title(title, fontsize=9) + ax.set_xticks([]) + ax.set_yticks([]) + +fig.suptitle("RBF kernel", fontsize=10, y=1.02) + +# %% [markdown] +# For the RBF kernel, the approximation is near-indistinguishable from exact +# by $m = 25$, reflecting the rapid spectral weight decay. + +# %% +gram_exact_m52, gram_hsgps_m52 = gram_comparison(base_m52, x_small, m_values) + +fig, axes = plt.subplots(1, 4, figsize=(10, 2.5)) +vmin, vmax = float(gram_exact_m52.min()), float(gram_exact_m52.max()) +titles = ["Exact"] + [f"HSGP ($m = {m}$)" for m in m_values] +matrices = [gram_exact_m52, *gram_hsgps_m52] + +for ax, matrix, title in zip(axes, matrices, titles, strict=True): + ax.imshow(matrix, vmin=vmin, vmax=vmax, cmap="inferno") + ax.set_title(title, fontsize=9) + ax.set_xticks([]) + ax.set_yticks([]) + +fig.suptitle("Matern-5/2 kernel", fontsize=10, y=1.02) + +# %% [markdown] +# The Matern-5/2 converges more slowly, with visible discrepancies at +# $m = 10$. Rougher kernels retain more high-frequency content. In practice, +# $m$ between 20 and 50 suffices for most one-dimensional problems. + +# %% [markdown] +# ## Regression with real data +# +# We apply the HSGP to daily nitrogen dioxide (NO$_2$) measurements from +# Marylebone Road in central London, one of the UK's most heavily trafficked +# roadside sites. The data come from the +# [Automatic Urban and Rural Network (AURN)](https://uk-air.defra.gov.uk/networks/site-info?site_id=MY1). +# We aggregate to daily means over approximately ten years (2016--2025), +# giving several thousand data points where exact GP inference is impractical. + +# %% +from pathlib import Path +import tempfile +from urllib.request import urlretrieve +import warnings + +import pandas as pd +import rdata + +SITE = "MY1" +YEARS = range(2016, 2026) + +frames = [] +for year in YEARS: + url = f"https://uk-air.defra.gov.uk/openair/R_data/{SITE}_{year}.RData" + tmp = Path(tempfile.gettempdir()) / f"{SITE}_{year}.RData" + urlretrieve(url, tmp) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + parsed = rdata.read_rda(str(tmp)) + frames.append(parsed[next(iter(parsed.keys()))]) + +hourly = pd.concat(frames, ignore_index=True) +hourly["date"] = pd.to_datetime(hourly["date"], unit="s") +daily = hourly[["date", "NO2"]].set_index("date").resample("D").mean().dropna() + +# %% +fig, ax = plt.subplots(figsize=(7.5, 2.5)) +ax.plot(daily.index, daily["NO2"], linewidth=0.5, color=cols[0]) +ax.set_xlabel("Date") +ax.set_ylabel("NO$_2$ ($\\mu$g m$^{-3}$)") +ax.set_title("Daily mean NO$_2$ at Marylebone Road, London") + +# %% [markdown] +# Clear features include an annual cycle from heating-season emissions, an +# abrupt dip during the 2020 COVID-19 lockdowns, and a gradual downward trend +# from tightening vehicle emission standards. + +# %% +# Convert dates to a numeric index and standardise for numerical stability. +t = (daily.index - daily.index[0]).days.values.astype(float) +y_obs = daily["NO2"].values + +t_mean, t_std = t.mean(), t.std() +y_mean, y_std = y_obs.mean(), y_obs.std() + +t_norm = ((t - t_mean) / t_std)[:, None] +y_norm = ((y_obs - y_mean) / y_std)[:, None] + +D = gpx.Dataset(X=t_norm, y=y_norm) + +# %% [markdown] +# We build a GP prior with an HSGP kernel using a Matern-3/2 base kernel. We +# use $m = 30$ basis functions and set $L$ to 1.2 times the half-range of +# the normalised inputs, providing a modest buffer beyond the data boundary. + +# %% +data_half_range = float((t_norm.max() - t_norm.min()) / 2.0) +L = 1.2 * data_half_range + +base_kernel = gpx.kernels.Matern32(n_dims=1) +hsgp_kernel = gpx.kernels.HSGP( + base_kernel=base_kernel, + num_basis_fns=30, + domain_half_width=L, + center=float(t_norm.mean()), +) + +prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=hsgp_kernel) +likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n) +posterior = prior * likelihood + +# %% [markdown] +# We optimise the negative conjugate marginal log-likelihood using L-BFGS-B, +# learning the kernel lengthscale, variance, and observation noise. + +# %% +opt_posterior, history = gpx.fit_scipy( + model=posterior, + objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d), + train_data=D, + trainable=gpx.parameters.Parameter, +) + +print( + f"Learned lengthscale: {opt_posterior.prior.kernel.base_kernel.lengthscale[...]:.4f}" +) +print( + f"Learned variance: {opt_posterior.prior.kernel.base_kernel.variance[...]:.4f}" +) +print(f"Learned obs. noise: {opt_posterior.likelihood.obs_stddev[...] ** 2:.4f}") + +# %% [markdown] +# With optimised hyperparameters, we compute the posterior predictive on a +# fine grid spanning the training domain. + +# %% +t_test = jnp.linspace(t_norm.min() - 0.05, t_norm.max() + 0.05, 500)[:, None] + +latent_dist = opt_posterior.predict( + t_test, train_data=D, return_covariance_type="diagonal" +) +predictive_dist = opt_posterior.likelihood(latent_dist) + +pred_mean = predictive_dist.mean * y_std + y_mean +pred_std = jnp.sqrt(predictive_dist.variance) * y_std + +# Convert test grid back to dates for plotting. +t_test_days = (t_test.squeeze() * t_std + t_mean).astype(int) +dates_test = daily.index[0] + pd.to_timedelta(t_test_days, unit="D") + +# %% +fig, ax = plt.subplots(figsize=(7.5, 3.0)) +ax.scatter(daily.index, y_obs, s=1, alpha=0.3, color=cols[0], label="Observations") +ax.plot(dates_test, pred_mean, color=cols[1], linewidth=1.5, label="Predictive mean") +ax.fill_between( + dates_test, + pred_mean - 2 * pred_std, + pred_mean + 2 * pred_std, + alpha=0.2, + color=cols[1], + label="Two sigma", +) +ax.set_xlabel("Date") +ax.set_ylabel("NO$_2$ ($\\mu$g m$^{-3}$)") +ax.legend(loc="upper right", fontsize=8) +ax.set_title("HSGP posterior for daily NO$_2$ at Marylebone Road") + +# %% [markdown] +# The HSGP posterior captures seasonal structure and the long-term downward +# trend, with credible intervals that widen where data are sparser. The +# computation completed in seconds on a single CPU; an exact GP over the same +# dataset would require forming and decomposing a dense Gram matrix with over +# ten million entries. + +# %% [markdown] +# ## References +# +# - Solin, A. and Sarkka, S. (2020). +# [Hilbert Space Methods for Reduced-Rank Gaussian Process Regression](https://link.springer.com/article/10.1007/s11222-019-09886-w). +# *Statistics and Computing*. +# - Orduz, J. and Martin, O. A. +# [PyMC HSGP tutorial](https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Basic.html). +# - Orduz, J. (2022). +# [Introduction to the HSGP](https://juanitorduz.github.io/hsgp_intro/). +# - UK [Automatic Urban and Rural Network](https://uk-air.defra.gov.uk/networks/site-info?site_id=MY1) +# air quality data via the openair data service. diff --git a/gpjax/gps.py b/gpjax/gps.py index ae25c8777..bd813aee5 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -46,7 +46,9 @@ from gpjax.linalg.operations import ( lower_cholesky, ) +from gpjax.linalg.operators import LowRank from gpjax.linalg.utils import add_jitter +from gpjax.linalg.woodbury import woodbury_solve from gpjax.mean_functions import AbstractMeanFunction from gpjax.parameters import ( Parameter, @@ -380,7 +382,9 @@ def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]: ####################### # GP Posteriors -#######################from gpjax.linalg.operators import LinearOperator +####################### + + class AbstractPosterior(nnx.Module, tp.Generic[P, L]): r"""Abstract Gaussian process posterior. @@ -587,60 +591,151 @@ def predict( kernel = self.prior.kernel x, y = train_data.X, train_data.y - P = self.likelihood.num_outputs + num_outputs = self.likelihood.num_outputs - # Prepare targets via likelihood protocol (identity for single-output, - # output-major reshape for multi-output) mx = self.prior.mean_function(x) y_flat, mx_flat = self.likelihood.prepare_targets(y, mx) noise = self.likelihood.noise_vector(train_data.n) Kxx = kernel.gram(x) - Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter) - Sigma_dense = Kxx_dense + jnp.diag(noise) - Sigma = psd(Dense(Sigma_dense)) - L_sigma = lower_cholesky(Sigma) - Kxt = kernel.cross_covariance(x, test_inputs) - L_inv_Kxt = solve(L_sigma, Kxt) - L_inv_y_diff = solve(L_sigma, y_flat - mx_flat) - mean_t_raw = self.prior.mean_function(test_inputs) - mean_t = jnp.tile(mean_t_raw, (P, 1)) if P > 1 else mean_t_raw - mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff) + # Posterior test-point mean: m(x*) + Kxt^T Sigma^{-1} (y - m(x)) + mean_test_raw = self.prior.mean_function(test_inputs) + mean_test = ( + jnp.tile(mean_test_raw, (num_outputs, 1)) + if num_outputs > 1 + else mean_test_raw + ) - # Diagonal covariance not yet supported for multi-output - if return_covariance_type == "diagonal" and P > 1: + if return_covariance_type == "diagonal" and num_outputs > 1: warnings.warn( - "Diagonal covariance is not yet supported for multi-output GPs. " - "Returning full covariance.", + "Diagonal covariance is not yet supported for " + "multi-output GPs. Returning full covariance.", stacklevel=2, ) return_covariance_type = "dense" - def _return_full_covariance(L_inv_Kxt, t): - Ktt = kernel.gram(t) - covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt) - covariance = add_jitter(covariance, self.prior.jitter) - covariance = psd(Dense(covariance)) - return covariance + if isinstance(Kxx, LowRank): + mean, cov = self._predict_lowrank( + kernel, + Kxx, + Kxt, + y_flat, + mx_flat, + noise, + mean_test, + test_inputs, + return_covariance_type, + ) + else: + mean, cov = self._predict_dense( + kernel, + Kxx, + Kxt, + y_flat, + mx_flat, + noise, + mean_test, + test_inputs, + return_covariance_type, + ) - def _return_diagonal_covariance(L_inv_Kxt, t): - Ktt = kernel.diagonal(t).diagonal - covariance = Ktt - jnp.einsum("ij, ji->i", L_inv_Kxt.T, L_inv_Kxt) + return GaussianDistribution(loc=jnp.atleast_1d(mean.squeeze()), scale=cov) + + def _predict_lowrank( + self, + kernel, + Kxx, + Kxt, + y_flat, + mx_flat, + noise, + mean_test, + test_inputs, + return_covariance_type, + ): + r"""Posterior prediction using the Woodbury identity. + + When K_xx = W W^T (low-rank), the marginal covariance is + Sigma = W W^T + D with D = diag(noise + jitter). The Woodbury + identity gives Sigma^{-1} in O(N m^2) instead of O(N^3). + """ + W = Kxx.factor + noise_with_jitter = noise + self.jitter + + # alpha = Sigma^{-1} (y - m(x)) + alpha = woodbury_solve(W, noise_with_jitter, (y_flat - mx_flat).squeeze()) + mean = mean_test + jnp.matmul(Kxt.T, alpha[:, None]) + + # Sigma^{-1} Kxt for the covariance update + Sigma_inv_Kxt = woodbury_solve(W, noise_with_jitter, Kxt) + + def _full_covariance(Sigma_inv_Kxt, test_inputs): + Ktt = kernel.gram(test_inputs).to_dense() + covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt) + return psd(Dense(add_jitter(covariance, self.prior.jitter))) + + def _diagonal_covariance(Sigma_inv_Kxt, test_inputs): + Ktt_diag = kernel.diagonal(test_inputs).diagonal + covariance = Ktt_diag - jnp.einsum("ij,ij->j", Kxt, Sigma_inv_Kxt) covariance += self.prior.jitter - covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze())))) - return covariance + return psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze())))) cov = jax.lax.cond( return_covariance_type == "dense", - _return_full_covariance, - _return_diagonal_covariance, + _full_covariance, + _diagonal_covariance, + Sigma_inv_Kxt, + test_inputs, + ) + + return mean, cov + + def _predict_dense( + self, + kernel, + Kxx, + Kxt, + y_flat, + mx_flat, + noise, + mean_test, + test_inputs, + return_covariance_type, + ): + r"""Posterior prediction via the standard Cholesky decomposition. + + Sigma = K_xx + D, L L^T = Sigma, then solves use L^{-1}. + """ + Sigma_dense = add_jitter(Kxx.to_dense(), self.jitter) + jnp.diag(noise) + L_sigma = lower_cholesky(psd(Dense(Sigma_dense))) + + L_inv_Kxt = solve(L_sigma, Kxt) + L_inv_diff = solve(L_sigma, y_flat - mx_flat) + + mean = mean_test + jnp.matmul(L_inv_Kxt.T, L_inv_diff) + + def _full_covariance(L_inv_Kxt, test_inputs): + Ktt = kernel.gram(test_inputs).to_dense() + covariance = Ktt - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt) + return psd(Dense(add_jitter(covariance, self.prior.jitter))) + + def _diagonal_covariance(L_inv_Kxt, test_inputs): + Ktt_diag = kernel.diagonal(test_inputs).diagonal + covariance = Ktt_diag - jnp.einsum("ij, ji->i", L_inv_Kxt.T, L_inv_Kxt) + covariance += self.prior.jitter + return psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze())))) + + cov = jax.lax.cond( + return_covariance_type == "dense", + _full_covariance, + _diagonal_covariance, L_inv_Kxt, test_inputs, ) - return GaussianDistribution(loc=jnp.atleast_1d(mean.squeeze()), scale=cov) + return mean, cov def sample_approx( self, diff --git a/gpjax/kernels/approximations/hsgp.py b/gpjax/kernels/approximations/hsgp.py index e6cacc8b2..2c5d6c487 100644 --- a/gpjax/kernels/approximations/hsgp.py +++ b/gpjax/kernels/approximations/hsgp.py @@ -40,11 +40,11 @@ class HSGP(AbstractKernel): which is a Bayesian linear model with :math:`m` basis functions. Args: - base_kernel: The stationary kernel to approximate. Must have a - ``spectral_density`` property that returns a + base_kernel: Stationary kernel to approximate. Must provide a + ``spectral_density`` property returning a :class:`~gpjax.kernels.stationary.utils.SpectralDensity`. num_basis_fns: Number of basis functions :math:`m`. - domain_half_width: Half-width of the approximation domain :math:`L`. + domain_half_width: Half-width :math:`L` of the approximation domain. Inputs should lie well inside :math:`[-L, L]` (after centering). center: Center of the data domain. If ``None``, it is set automatically on the first call to :meth:`eigenfunctions` as the @@ -74,7 +74,6 @@ def __init__( "HSGP can only approximate stationary kernels. " f"Got {type(base_kernel).__name__}." ) - # Verify that the base kernel has a usable spectral density. _ = base_kernel.spectral_density super().__init__( @@ -88,10 +87,6 @@ def __init__( self._center = center self.name = f"{self.base_kernel.name} (HSGP)" - # ------------------------------------------------------------------ - # Basis functions - # ------------------------------------------------------------------ - def eigenvalues(self) -> Float[Array, " m"]: r"""Square roots of the Laplacian eigenvalues. @@ -101,8 +96,8 @@ def eigenvalues(self) -> Float[Array, " m"]: Returns: Array of shape ``(m,)``. """ - j = jnp.arange(1, self.num_basis_fns + 1) - return j * jnp.pi / (2.0 * self.domain_half_width) + indices = jnp.arange(1, self.num_basis_fns + 1) + return indices * jnp.pi / (2.0 * self.domain_half_width) def eigenfunctions(self, x: Float[Array, "N 1"]) -> Float[Array, "N m"]: r"""Laplacian eigenfunctions evaluated at *x*. @@ -122,10 +117,12 @@ def eigenfunctions(self, x: Float[Array, "N 1"]) -> Float[Array, "N m"]: if self._center is None: self._center = float((x.max() + x.min()) / 2.0) - L = self.domain_half_width - x_centered = x - self._center # [N, 1] - sqrt_eigenvalues = self.eigenvalues() # [m] - return jnp.sin((x_centered + L) * sqrt_eigenvalues) / jnp.sqrt(L) + half_width = self.domain_half_width + x_centered = x - self._center + sqrt_eigenvalues = self.eigenvalues() + return jnp.sin((x_centered + half_width) * sqrt_eigenvalues) / jnp.sqrt( + half_width + ) def spectral_weights(self) -> Float[Array, " m"]: r"""Spectral density evaluated at the eigenvalue square roots. @@ -137,8 +134,7 @@ def spectral_weights(self) -> Float[Array, " m"]: Array of shape ``(m,)`` with the spectral weights. """ omega = self.eigenvalues() - sd = self.base_kernel.spectral_density - return sd( + return self.base_kernel.spectral_density( omega, self.base_kernel.variance[...], self.base_kernel.lengthscale[...], @@ -159,16 +155,12 @@ def compute_basis( x: Input locations of shape ``(N, 1)``. Returns: - Tuple ``(phi, sqrt_psd)`` where ``phi`` has shape ``(N, m)`` - and ``sqrt_psd`` has shape ``(m,)``. + Tuple ``(phi, sqrt_spectral_weights)`` where ``phi`` has shape + ``(N, m)`` and ``sqrt_spectral_weights`` has shape ``(m,)``. """ phi = self.eigenfunctions(x) - spd = self.spectral_weights() - return phi, jnp.sqrt(spd) - - # ------------------------------------------------------------------ - # AbstractKernel interface - # ------------------------------------------------------------------ + sqrt_spectral_weights = jnp.sqrt(self.spectral_weights()) + return phi, sqrt_spectral_weights def __call__( self, diff --git a/gpjax/kernels/computations/basis_functions.py b/gpjax/kernels/computations/basis_functions.py index d4a06c2a2..07c8818d8 100644 --- a/gpjax/kernels/computations/basis_functions.py +++ b/gpjax/kernels/computations/basis_functions.py @@ -7,67 +7,69 @@ from gpjax.kernels.computations.base import AbstractKernelComputation from gpjax.linalg import ( Diagonal, + LowRank, ) +from gpjax.linalg.utils import psd from gpjax.typing import Array K = tp.TypeVar("K", bound="gpjax.kernels.approximations.RFF") -# TODO: Use low rank linear operator! - class BasisFunctionComputation(AbstractKernelComputation): - r"""Compute engine class for finite basis function approximations to a kernel.""" + r"""Compute engine for finite basis function approximations (RFF).""" + + def gram(self, kernel: K, x: Float[Array, "N D"]) -> LowRank: + features = self.compute_features(kernel, x) + weighted_features = features * jnp.sqrt(self.scaling(kernel)) + return psd(LowRank(weighted_features)) def _cross_covariance( self, kernel: K, x: Float[Array, "N D"], y: Float[Array, "M D"] ) -> Float[Array, "N M"]: - z1 = self.compute_features(kernel, x) - z2 = self.compute_features(kernel, y) - return self.scaling(kernel) * jnp.matmul(z1, z2.T) + features_x = self.compute_features(kernel, x) + features_y = self.compute_features(kernel, y) + return self.scaling(kernel) * jnp.matmul(features_x, features_y.T) def _gram(self, kernel: K, inputs: Float[Array, "N D"]) -> Float[Array, "N N"]: - z1 = self.compute_features(kernel, inputs) - return self.scaling(kernel) * jnp.matmul(z1, z1.T) + features = self.compute_features(kernel, inputs) + return self.scaling(kernel) * jnp.matmul(features, features.T) def diagonal(self, kernel: K, inputs: Float[Array, "N D"]) -> Diagonal: - r"""For a given kernel, compute the elementwise diagonal of the - NxN gram matrix on an input matrix of shape NxD. + r"""Diagonal of the approximate Gram matrix. Args: - kernel (AbstractKernel): the kernel function. - inputs (Float[Array, "N D"]): The input matrix. + kernel: The RFF kernel. + inputs: Input matrix of shape ``(N, D)``. - Returns - ------- - Diagonal: The computed diagonal variance entries. + Returns: + Diagonal variance entries. """ return super().diagonal(kernel.base_kernel, inputs) def compute_features( self, kernel: K, x: Float[Array, "N D"] ) -> Float[Array, "N L"]: - r"""Compute the features for the inputs. + r"""Compute random Fourier features :math:`[\cos(z), \sin(z)]`. Args: - kernel: the kernel function. - x: the inputs to the kernel function of shape `(N, D)`. + kernel: The RFF kernel. + x: Inputs of shape ``(N, D)``. Returns: - A matrix of shape $N \times L$ representing the random fourier features where $L = 2M$. + Feature matrix of shape ``(N, 2M)`` where ``M = num_basis_fns``. """ frequencies = kernel.frequencies - scaling_factor = kernel.base_kernel.lengthscale[...] - z = jnp.matmul(x, (frequencies / scaling_factor).T) - z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1) - return z + lengthscale = kernel.base_kernel.lengthscale[...] + projected = jnp.matmul(x, (frequencies / lengthscale).T) + return jnp.concatenate([jnp.cos(projected), jnp.sin(projected)], axis=-1) def scaling(self, kernel: K) -> Float[Array, ""]: - r"""Compute the scaling factor for the covariance matrix. + r"""Variance scaling factor: :math:`\sigma^2 / M`. Args: - kernel: the kernel function. + kernel: The RFF kernel. Returns: - A scalar array representing the scaling factor. + Scalar scaling factor. """ return kernel.base_kernel.variance[...] / kernel.num_basis_fns diff --git a/gpjax/kernels/computations/hsgp.py b/gpjax/kernels/computations/hsgp.py index 3e0877d32..8baadb90f 100644 --- a/gpjax/kernels/computations/hsgp.py +++ b/gpjax/kernels/computations/hsgp.py @@ -7,7 +7,7 @@ import gpjax from gpjax.kernels.computations.base import AbstractKernelComputation -from gpjax.linalg import Dense, Diagonal +from gpjax.linalg import Diagonal, LowRank from gpjax.linalg.utils import psd from gpjax.typing import Array @@ -17,33 +17,36 @@ class HSGPComputation(AbstractKernelComputation): r"""Compute engine for the HSGP kernel approximation. - Computes the Gram matrix via the low-rank decomposition: - - .. math:: - \tilde{K} = \Phi \Lambda \Phi^\top - - where :math:`\Phi` is the matrix of Laplacian eigenfunctions and - :math:`\Lambda = \mathrm{diag}(S(\sqrt{\lambda_1}), \ldots, S(\sqrt{\lambda_m}))` - contains the spectral density evaluated at the eigenvalue square roots. + Builds the Gram matrix from the low-rank decomposition + :math:`\tilde{K} = \Phi \Lambda \Phi^\top` where + :math:`\Phi` contains Laplacian eigenfunctions and + :math:`\Lambda = \mathrm{diag}(S(\sqrt{\lambda_1}), \ldots, + S(\sqrt{\lambda_m}))`. """ - def gram(self, kernel: K, x: Float[Array, "N D"]) -> Dense: - Kxx = self._gram(kernel, x) - return psd(Dense(Kxx)) + def _weighted_basis( + self, kernel: K, x: Float[Array, "N D"] + ) -> tuple[Float[Array, "N m"], Float[Array, "N m"]]: + r"""Return ``(phi, weighted_phi)`` where ``weighted_phi = phi * sqrt(S)``.""" + phi, sqrt_spectral_weights = kernel.compute_basis(x) + weighted_phi = phi * sqrt_spectral_weights[None, :] + return phi, weighted_phi + + def gram(self, kernel: K, x: Float[Array, "N D"]) -> LowRank: + _, weighted_phi = self._weighted_basis(kernel, x) + return psd(LowRank(weighted_phi)) def _gram(self, kernel: K, x: Float[Array, "N D"]) -> Float[Array, "N N"]: - phi, sqrt_psd = kernel.compute_basis(x) - weighted = phi * sqrt_psd[None, :] - return weighted @ weighted.T + _, weighted_phi = self._weighted_basis(kernel, x) + return weighted_phi @ weighted_phi.T def _cross_covariance( self, kernel: K, x: Float[Array, "N D"], y: Float[Array, "M D"] ) -> Float[Array, "N M"]: - phi_x, sqrt_psd = kernel.compute_basis(x) - phi_y, _ = kernel.compute_basis(y) - return (phi_x * sqrt_psd[None, :]) @ (phi_y * sqrt_psd[None, :]).T + _, weighted_phi_x = self._weighted_basis(kernel, x) + _, weighted_phi_y = self._weighted_basis(kernel, y) + return weighted_phi_x @ weighted_phi_y.T def diagonal(self, kernel: K, x: Float[Array, "N D"]) -> Diagonal: - phi, sqrt_psd = kernel.compute_basis(x) - weighted = phi * sqrt_psd[None, :] - return psd(Diagonal(jnp.sum(weighted**2, axis=1))) + _, weighted_phi = self._weighted_basis(kernel, x) + return psd(Diagonal(jnp.sum(weighted_phi**2, axis=1))) diff --git a/gpjax/kernels/stationary/base.py b/gpjax/kernels/stationary/base.py index 3cd476236..9108abc56 100644 --- a/gpjax/kernels/stationary/base.py +++ b/gpjax/kernels/stationary/base.py @@ -96,10 +96,11 @@ def __init__( @property def spectral_density(self) -> SpectralDensity: - r"""The spectral density of the kernel. + r"""Spectral density :math:`S(\omega)` of this kernel. - Returns a :class:`~gpjax.kernels.stationary.utils.SpectralDensity` - object that supports both ``sample()`` (for RFF) and + Subclasses override this to return a + :class:`~gpjax.kernels.stationary.utils.SpectralDensity` that + supports ``sample()`` (for RFF) and ``__call__(omega, variance, lengthscale)`` (for HSGP). """ raise NotImplementedError( diff --git a/gpjax/kernels/stationary/matern12.py b/gpjax/kernels/stationary/matern12.py index e4ab87565..e452f8144 100644 --- a/gpjax/kernels/stationary/matern12.py +++ b/gpjax/kernels/stationary/matern12.py @@ -49,9 +49,13 @@ def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: @property def spectral_density(self) -> SpectralDensity: - def _matern12_spectral_density(omega, variance, lengthscale): + r"""Matern-1/2 spectral density. + + .. math:: + S(\omega) = \sigma^2 \,\frac{2/\ell}{1/\ell^2 + \omega^2} + """ + + def _evaluate(omega, variance, lengthscale): return variance * (2.0 / lengthscale) / (1.0 / lengthscale**2 + omega**2) - return SpectralDensity( - build_student_t_distribution(nu=1), _matern12_spectral_density - ) + return SpectralDensity(build_student_t_distribution(nu=1), _evaluate) diff --git a/gpjax/kernels/stationary/matern32.py b/gpjax/kernels/stationary/matern32.py index eead4aa92..a288f3ac4 100644 --- a/gpjax/kernels/stationary/matern32.py +++ b/gpjax/kernels/stationary/matern32.py @@ -55,10 +55,16 @@ def __call__( @property def spectral_density(self) -> SpectralDensity: - def _matern32_spectral_density(omega, variance, lengthscale): + r"""Matern-3/2 spectral density. + + .. math:: + S(\omega) = \sigma^2 \, + \frac{4\,\alpha^3}{(3/\ell^2 + \omega^2)^2}, + \quad \alpha = \sqrt{3}/\ell + """ + + def _evaluate(omega, variance, lengthscale): alpha = jnp.sqrt(3.0) / lengthscale return variance * 4.0 * alpha**3 / (3.0 / lengthscale**2 + omega**2) ** 2 - return SpectralDensity( - build_student_t_distribution(nu=3), _matern32_spectral_density - ) + return SpectralDensity(build_student_t_distribution(nu=3), _evaluate) diff --git a/gpjax/kernels/stationary/matern52.py b/gpjax/kernels/stationary/matern52.py index c4431ae74..1c1e7c943 100644 --- a/gpjax/kernels/stationary/matern52.py +++ b/gpjax/kernels/stationary/matern52.py @@ -54,7 +54,15 @@ def __call__( @property def spectral_density(self) -> SpectralDensity: - def _matern52_spectral_density(omega, variance, lengthscale): + r"""Matern-5/2 spectral density. + + .. math:: + S(\omega) = \sigma^2 \, + \frac{(16/3)\,\alpha^5}{(5/\ell^2 + \omega^2)^3}, + \quad \alpha = \sqrt{5}/\ell + """ + + def _evaluate(omega, variance, lengthscale): alpha = jnp.sqrt(5.0) / lengthscale return ( variance @@ -63,6 +71,4 @@ def _matern52_spectral_density(omega, variance, lengthscale): / (5.0 / lengthscale**2 + omega**2) ** 3 ) - return SpectralDensity( - build_student_t_distribution(nu=5), _matern52_spectral_density - ) + return SpectralDensity(build_student_t_distribution(nu=5), _evaluate) diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index ff71fcc3e..e1da294dc 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -45,7 +45,14 @@ def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: @property def spectral_density(self) -> SpectralDensity: - def _rbf_spectral_density(omega, variance, lengthscale): + r"""RBF spectral density. + + .. math:: + S(\omega) = \sigma^2 \sqrt{2\pi}\,\ell\, + \exp\!\bigl(-\tfrac{1}{2}\ell^2 \omega^2\bigr) + """ + + def _evaluate(omega, variance, lengthscale): return ( variance * jnp.sqrt(2.0 * jnp.pi) @@ -53,4 +60,4 @@ def _rbf_spectral_density(omega, variance, lengthscale): * jnp.exp(-0.5 * lengthscale**2 * omega**2) ) - return SpectralDensity(npd.Normal(0.0, 1.0), _rbf_spectral_density) + return SpectralDensity(npd.Normal(0.0, 1.0), _evaluate) diff --git a/gpjax/kernels/stationary/utils.py b/gpjax/kernels/stationary/utils.py index 8697123d7..e0e3261b1 100644 --- a/gpjax/kernels/stationary/utils.py +++ b/gpjax/kernels/stationary/utils.py @@ -24,95 +24,97 @@ def build_student_t_distribution(nu: int) -> npd.StudentT: - r"""Build a Student's t distribution with a fixed smoothness parameter. - - For a fixed half-integer smoothness parameter, compute the spectral density of a - Matérn kernel; a Student's t distribution. + r"""Student's t distribution for Matern spectral densities. Args: - nu (int): The smoothness parameter of the Matérn kernel. + nu: Degrees of freedom (equals the Matern smoothness parameter + :math:`\nu` mapped to the nearest integer: 1, 3, or 5). - Returns - ------- - tfp.Distribution: A Student's t distribution with the same smoothness parameter. + Returns: + A standard Student's t distribution with ``df=nu``. """ - dist = npd.StudentT(df=nu, loc=0.0, scale=1.0) - return dist + return npd.StudentT(df=nu, loc=0.0, scale=1.0) class SpectralDensity: - """Spectral density of a stationary kernel. + r"""Spectral density :math:`S(\omega)` of a stationary kernel. + + This class serves two roles: - Wraps a NumPyro distribution (for sampling, used by RFF) and adds - evaluation of the spectral density function S(omega) at arbitrary - frequencies (used by HSGP). + 1. **Sampling** (for RFF): delegates to a wrapped NumPyro distribution + via :meth:`sample`, drawing frequency samples :math:`\omega`. + 2. **Evaluation** (for HSGP): computes :math:`S(\omega)` at arbitrary + frequencies via :meth:`__call__`, incorporating kernel variance and + lengthscale. Args: - distribution: A NumPyro distribution to delegate ``sample()`` to. - evaluate_fn: A callable ``(omega, variance, lengthscale) -> S(omega)`` - that computes the un-normalized spectral density at the given - frequencies incorporating kernel variance and lengthscale. + distribution: NumPyro distribution to sample from (used by RFF). + evaluate_fn: Callable ``(omega, variance, lengthscale) -> S(omega)`` + that evaluates the spectral density at given frequencies. """ def __init__( self, distribution: npd.Distribution, - evaluate_fn: tp.Callable, + evaluate_fn: tp.Callable[ + [Float[Array, " M"], ScalarFloat, ScalarFloat], Float[Array, " M"] + ], ): self._distribution = distribution self._evaluate_fn = evaluate_fn - def sample(self, key, sample_shape): - """Draw samples from the spectral density distribution. + def sample(self, key: Array, sample_shape: tuple[int, ...]) -> Float[Array, "..."]: + """Draw frequency samples from the spectral distribution (used by RFF). - This delegates to the wrapped NumPyro distribution and is used by - Random Fourier Features (RFF). + Args: + key: JAX PRNG key. + sample_shape: Shape of the sample array. + + Returns: + Sampled frequencies. """ return self._distribution.sample(key=key, sample_shape=sample_shape) - def __call__(self, omega, variance, lengthscale): - """Evaluate S(omega) incorporating kernel variance and lengthscale. - - Parameters - ---------- - omega : Array - Frequencies at which to evaluate the spectral density. - variance : ScalarFloat - Kernel variance parameter (sigma^2). - lengthscale : ScalarFloat - Kernel lengthscale parameter (ell). - - Returns - ------- - Array - Spectral density values S(omega). + def __call__( + self, + omega: Float[Array, " M"], + variance: ScalarFloat, + lengthscale: ScalarFloat, + ) -> Float[Array, " M"]: + r"""Evaluate :math:`S(\omega)` at the given frequencies (used by HSGP). + + Args: + omega: Frequencies at which to evaluate the spectral density. + variance: Kernel variance :math:`\sigma^2`. + lengthscale: Kernel lengthscale :math:`\ell`. + + Returns: + Spectral density values :math:`S(\omega)`. """ return self._evaluate_fn(omega, variance, lengthscale) def squared_distance(x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: - r"""Compute the squared distance between a pair of inputs. + r"""Squared Euclidean distance :math:`\lVert x - y \rVert^2`. Args: - x (Float[Array, " D"]): First input. - y (Float[Array, " D"]): Second input. + x: First input vector. + y: Second input vector. - Returns - ------- - ScalarFloat: The squared distance between the inputs. + Returns: + The squared distance between the inputs. """ return jnp.sum((x - y) ** 2) def euclidean_distance(x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: - r"""Compute the euclidean distance between a pair of inputs. + r"""Euclidean distance :math:`\lVert x - y \rVert`, clamped for stability. Args: - x (Float[Array, " D"]): First input. - y (Float[Array, " D"]): Second input. + x: First input vector. + y: Second input vector. - Returns - ------- - ScalarFloat: The euclidean distance between the inputs. + Returns: + The Euclidean distance between the inputs. """ return jnp.sqrt(jnp.maximum(squared_distance(x, y), 1e-36)) diff --git a/gpjax/linalg/__init__.py b/gpjax/linalg/__init__.py index 489ac2cbd..366caebac 100644 --- a/gpjax/linalg/__init__.py +++ b/gpjax/linalg/__init__.py @@ -13,6 +13,7 @@ Identity, Kronecker, LinearOperator, + LowRank, Triangular, ) from gpjax.linalg.utils import ( @@ -28,6 +29,7 @@ "Identity", "Kronecker", "LinearOperator", + "LowRank", "Triangular", "diag", "logdet", diff --git a/gpjax/linalg/operations.py b/gpjax/linalg/operations.py index d287d49a9..1ccb69fd2 100644 --- a/gpjax/linalg/operations.py +++ b/gpjax/linalg/operations.py @@ -12,6 +12,7 @@ Identity, Kronecker, LinearOperator, + LowRank, Triangular, ) from gpjax.typing import ScalarFloat @@ -219,6 +220,10 @@ def _handle_blockdiag(A): def _handle_dense(A): return jnp.diag(A.array) + def _handle_lowrank(A): + # diag(W W^T)_i = sum_j W_{ij}^2 + return jnp.sum(A.factor**2, axis=1) + def _handle_default(A): return jnp.diag(A.to_dense()) @@ -229,6 +234,7 @@ def _handle_default(A): Kronecker: _handle_kronecker, BlockDiag: _handle_blockdiag, Dense: _handle_dense, + LowRank: _handle_lowrank, } handler = dispatch_table.get(type(A), _handle_default) diff --git a/gpjax/linalg/operators.py b/gpjax/linalg/operators.py index db49a9e43..98f14dabe 100644 --- a/gpjax/linalg/operators.py +++ b/gpjax/linalg/operators.py @@ -407,6 +407,52 @@ def _kronecker_tree_unflatten(aux_data, children): jtu.register_pytree_node(Kronecker, _kronecker_tree_flatten, _kronecker_tree_unflatten) + +class LowRank(LinearOperator): + r"""Low-rank matrix K = W W^T where W has shape (N, m) with m << N. + + Basis-function kernel approximations (HSGP, RFF) produce a kernel + matrix that factors as K = W W^T. Storing only the (N, m) factor W + rather than the full (N, N) matrix enables O(N m^2) inference via the + Woodbury identity instead of O(N^3) Cholesky-based inference. + """ + + def __init__(self, factor: Float[Array, "N m"]): + super().__init__() + self.factor = factor + + @property + def shape(self) -> tuple[int, int]: + num_data = self.factor.shape[0] + return (num_data, num_data) + + @property + def rank(self) -> int: + return self.factor.shape[1] + + @property + def dtype(self) -> jnp.dtype: + return self.factor.dtype + + def to_dense(self) -> Float[Array, "N N"]: + return self.factor @ self.factor.T + + @property + def T(self) -> "LowRank": + # W W^T is symmetric, so the transpose is itself. + return self + + +def _lowrank_tree_flatten(lowrank): + return (lowrank.factor,), None + + +def _lowrank_tree_unflatten(aux_data, children): + return LowRank(children[0]) + + +jtu.register_pytree_node(LowRank, _lowrank_tree_flatten, _lowrank_tree_unflatten) + __all__ = [ "BlockDiag", "Dense", @@ -414,5 +460,6 @@ def _kronecker_tree_unflatten(aux_data, children): "Identity", "Kronecker", "LinearOperator", + "LowRank", "Triangular", ] diff --git a/gpjax/linalg/woodbury.py b/gpjax/linalg/woodbury.py new file mode 100644 index 000000000..c96d74d36 --- /dev/null +++ b/gpjax/linalg/woodbury.py @@ -0,0 +1,121 @@ +r"""Woodbury identity helpers for efficient low-rank + diagonal solves. + +When a kernel is approximated by a low-rank factorisation K \approx W W^T +(e.g. via HSGP or RFF), the marginal covariance becomes + + \Sigma = W W^T + D, where D = diag(noise), W is (N, m), m << N. + +Naively inverting \Sigma costs O(N^3). The *Woodbury matrix identity* reduces +this to O(N m^2 + m^3) by working with the small (m x m) *capacitance matrix* + + A = I_m + W^T D^{-1} W. + +The three operations exposed here --- solve, log-determinant, and quadratic +form --- are the building blocks for the GP marginal likelihood and posterior +prediction when the kernel has low-rank structure. +""" + +import jax.numpy as jnp +import jax.scipy as jsp +from jaxtyping import Float + +from gpjax.typing import Array, ScalarFloat + + +def _capacitance_cholesky( + W: Float[Array, "N m"], + noise_inv: Float[Array, " N"], +) -> Float[Array, "m m"]: + r"""Cholesky factor of the capacitance matrix A = I_m + W^T D^{-1} W. + + This is the shared computation underlying all Woodbury operations. + + Args: + W: Factor matrix, shape (N, m). + noise_inv: Element-wise reciprocal of the diagonal noise, shape (N,). + + Returns: + Lower-triangular Cholesky factor L_A such that A = L_A L_A^T. + """ + m = W.shape[1] + Dinv_W = noise_inv[:, None] * W + A = jnp.eye(m) + W.T @ Dinv_W + return jnp.linalg.cholesky(A) + + +def woodbury_solve( + W: Float[Array, "N m"], + noise: Float[Array, " N"], + b: Float[Array, "N ..."], +) -> Float[Array, "N ..."]: + r"""Solve (W W^T + D) x = b via the Woodbury identity. + + .. math:: + + x = D^{-1} b - D^{-1} W \, A^{-1} \, W^T D^{-1} b + + where D = diag(noise) and A = I_m + W^T D^{-1} W. + + Args: + W: Factor matrix, shape (N, m). + noise: Diagonal noise vector, shape (N,). + b: Right-hand side, shape (N,) or (N, K). + + Returns: + Solution x with same shape as b. + """ + noise_inv = 1.0 / noise + Dinv_W = noise_inv[:, None] * W + Dinv_b = noise_inv * b if b.ndim == 1 else noise_inv[:, None] * b + + L_A = _capacitance_cholesky(W, noise_inv) + + WtDinv_b = Dinv_W.T @ b + forward = jsp.linalg.solve_triangular(L_A, WtDinv_b, lower=True) + Ainv_WtDinv_b = jsp.linalg.solve_triangular(L_A.T, forward, lower=False) + + return Dinv_b - Dinv_W @ Ainv_WtDinv_b + + +def woodbury_logdet( + W: Float[Array, "N m"], + noise: Float[Array, " N"], +) -> ScalarFloat: + r"""Log-determinant of W W^T + D via the matrix determinant lemma. + + .. math:: + + \log|\Sigma| = \log|A| + \sum_i \log(\text{noise}_i) + + where A = I_m + W^T D^{-1} W. + + Args: + W: Factor matrix, shape (N, m). + noise: Diagonal noise vector, shape (N,). + + Returns: + Log-determinant as a scalar. + """ + noise_inv = 1.0 / noise + L_A = _capacitance_cholesky(W, noise_inv) + logdet_A = 2.0 * jnp.sum(jnp.log(jnp.diag(L_A))) + return logdet_A + jnp.sum(jnp.log(noise)) + + +def woodbury_quad( + W: Float[Array, "N m"], + noise: Float[Array, " N"], + diff: Float[Array, " N"], +) -> ScalarFloat: + r"""Quadratic form diff^T \Sigma^{-1} diff where \Sigma = W W^T + D. + + Args: + W: Factor matrix, shape (N, m). + noise: Diagonal noise vector, shape (N,). + diff: Vector, shape (N,). + + Returns: + Quadratic form as a scalar. + """ + solved = woodbury_solve(W, noise, diff) + return diff @ solved diff --git a/gpjax/objectives.py b/gpjax/objectives.py index 236383b93..7c4d36004 100644 --- a/gpjax/objectives.py +++ b/gpjax/objectives.py @@ -23,7 +23,9 @@ psd, solve, ) +from gpjax.linalg.operators import LowRank from gpjax.linalg.utils import add_jitter +from gpjax.linalg.woodbury import woodbury_logdet, woodbury_quad from gpjax.typing import ( Array, ScalarFloat, @@ -120,12 +122,25 @@ def conjugate_mll(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat: noise = posterior.likelihood.noise_vector(data.n) Kxx = kernel.gram(x) - Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter) - Sigma_dense = Kxx_dense + jnp.diag(noise) - Sigma = psd(Dense(Sigma_dense)) + diff = jnp.atleast_1d((y_flat - mx_flat).squeeze()) + + if isinstance(Kxx, LowRank): + # Low-rank path: Sigma = W W^T + D. + # Use the Woodbury identity for O(N m^2) instead of O(N^3) Cholesky. + W = Kxx.factor + noise_with_jitter = noise + posterior.prior.jitter + num_datapoints = x.shape[0] * posterior.likelihood.num_outputs + + log_det = woodbury_logdet(W, noise_with_jitter) + quadratic = woodbury_quad(W, noise_with_jitter, diff) - mll = GaussianDistribution(jnp.atleast_1d(mx_flat.squeeze()), Sigma) - return mll.log_prob(jnp.atleast_1d(y_flat.squeeze())).squeeze() + return -0.5 * (num_datapoints * jnp.log(2.0 * jnp.pi) + log_det + quadratic) + + # Dense path: Sigma = Kxx + D. Standard Cholesky-based log-probability. + Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter) + Sigma = psd(Dense(Kxx_dense + jnp.diag(noise))) + marginal = GaussianDistribution(jnp.atleast_1d(mx_flat.squeeze()), Sigma) + return marginal.log_prob(diff).squeeze() def conjugate_loocv(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat: diff --git a/pyproject.toml b/pyproject.toml index acd01cae4..ee91b5cca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ docs = [ "markdown-katex>=202406.1035", "scikit-learn>=1.5.1", "ucimlrepo>=0.0.7", + "rdata>=1.0.0", ] [tool.uv] diff --git a/tests/test_kernels/test_approximations.py b/tests/test_kernels/test_approximations.py index c78d66120..1771b631d 100644 --- a/tests/test_kernels/test_approximations.py +++ b/tests/test_kernels/test_approximations.py @@ -13,7 +13,7 @@ RationalQuadratic, StationaryKernel, ) -from gpjax.linalg.operators import Dense +from gpjax.linalg.operators import LowRank import jax from jax import config import jax.numpy as jnp @@ -52,7 +52,7 @@ def test_gram( linop = approximate.gram(x) # Check the return type - assert isinstance(linop, Dense) + assert isinstance(linop, LowRank) Kxx = linop.to_dense() + jnp.eye(n_data) * _jitter diff --git a/tests/test_kernels/test_computations.py b/tests/test_kernels/test_computations.py index d1883ab2b..45cfdcde2 100644 --- a/tests/test_kernels/test_computations.py +++ b/tests/test_kernels/test_computations.py @@ -15,6 +15,7 @@ PSD, Dense, Diagonal, + LowRank, ) import jax.numpy as jnp import jax.random as jr @@ -67,7 +68,7 @@ def test_gram_shape(self, rff_kernel): """Gram matrix should be (N, N).""" x = jr.normal(jr.key(1), (10, 2)) gram = rff_kernel.gram(x) - assert isinstance(gram, Dense) + assert isinstance(gram, LowRank) assert PSD in gram.annotations assert gram.shape == (10, 10) diff --git a/tests/test_kernels/test_hsgp.py b/tests/test_kernels/test_hsgp.py index 23028946a..82bbf28e8 100644 --- a/tests/test_kernels/test_hsgp.py +++ b/tests/test_kernels/test_hsgp.py @@ -2,7 +2,8 @@ from gpjax.kernels.approximations.hsgp import HSGP from gpjax.kernels.stationary import RBF, Matern12, Matern32, Matern52 -from gpjax.linalg.operators import Dense +from gpjax.linalg.operators import LowRank +import jax from jax import config import jax.numpy as jnp import jax.random as jr @@ -11,307 +12,298 @@ config.update("jax_enable_x64", True) +STATIONARY_KERNELS = [RBF, Matern12, Matern32, Matern52] -class TestEigenvalues: - def test_eigenvalue_count(self): - kernel = HSGP( - base_kernel=RBF(n_dims=1), num_basis_fns=10, domain_half_width=3.0 - ) - evals = kernel.eigenvalues() - assert evals.shape == (10,) - - def test_eigenvalue_formula(self): - """sqrt(lambda_j) = j * pi / (2 * L).""" - L = 5.0 - m = 4 - kernel = HSGP(base_kernel=RBF(n_dims=1), num_basis_fns=m, domain_half_width=L) - evals = kernel.eigenvalues() - expected = jnp.arange(1, m + 1) * jnp.pi / (2.0 * L) - npt.assert_allclose(evals, expected) - - def test_eigenvalues_increase(self): - kernel = HSGP( - base_kernel=RBF(n_dims=1), num_basis_fns=20, domain_half_width=3.0 - ) - evals = kernel.eigenvalues() - assert jnp.all(jnp.diff(evals) > 0) - - -class TestEigenfunctions: - def test_shape(self): - kernel = HSGP( - base_kernel=RBF(n_dims=1), - num_basis_fns=10, - domain_half_width=3.0, - center=0.0, - ) - x = jnp.linspace(-2, 2, 50)[:, None] - phi = kernel.eigenfunctions(x) - assert phi.shape == (50, 10) - - def test_orthonormality(self): - """Eigenfunctions should be approximately orthonormal over [-L, L].""" - L = 3.0 - m = 5 - kernel = HSGP( - base_kernel=RBF(n_dims=1), num_basis_fns=m, domain_half_width=L, center=0.0 - ) - # Dense grid for numerical integration - n_quad = 10000 - x = jnp.linspace(-L, L, n_quad)[:, None] - phi = kernel.eigenfunctions(x) - dx = 2.0 * L / n_quad - gram = phi.T @ phi * dx # Approximate integral - npt.assert_allclose(gram, jnp.eye(m), atol=1e-2) - - def test_zero_at_boundaries(self): - """Eigenfunctions should be zero at x = -L and x = L.""" - L = 3.0 - kernel = HSGP( - base_kernel=RBF(n_dims=1), num_basis_fns=5, domain_half_width=L, center=0.0 - ) - boundaries = jnp.array([[-L], [L]]) - phi = kernel.eigenfunctions(boundaries) - npt.assert_allclose(phi, 0.0, atol=1e-12) - - -class TestCentering: - def test_explicit_center(self): - kernel = HSGP( - base_kernel=RBF(n_dims=1), - num_basis_fns=5, - domain_half_width=3.0, - center=1.0, - ) - assert kernel._center == 1.0 - - def test_auto_center(self): - kernel = HSGP(base_kernel=RBF(n_dims=1), num_basis_fns=5, domain_half_width=3.0) - x = jnp.linspace(2.0, 8.0, 50)[:, None] - _ = kernel.eigenfunctions(x) - assert kernel._center == pytest.approx(5.0) - - def test_auto_center_persists(self): - """Once set, auto-center should not change on subsequent calls.""" - kernel = HSGP(base_kernel=RBF(n_dims=1), num_basis_fns=5, domain_half_width=3.0) - x1 = jnp.linspace(0, 10, 50)[:, None] - x2 = jnp.linspace(-5, 5, 50)[:, None] - _ = kernel.eigenfunctions(x1) - center_after_first = kernel._center - _ = kernel.eigenfunctions(x2) - assert kernel._center == center_after_first - - -class TestComputeBasis: - def test_returns_tuple(self): - kernel = HSGP( - base_kernel=RBF(n_dims=1), - num_basis_fns=10, - domain_half_width=3.0, - center=0.0, - ) - x = jnp.linspace(-2, 2, 50)[:, None] - phi, sqrt_psd = kernel.compute_basis(x) - assert phi.shape == (50, 10) - assert sqrt_psd.shape == (10,) - - def test_sqrt_psd_positive(self): - kernel = HSGP( - base_kernel=RBF(n_dims=1), - num_basis_fns=10, - domain_half_width=3.0, - center=0.0, - ) - x = jnp.linspace(-2, 2, 50)[:, None] - _, sqrt_psd = kernel.compute_basis(x) - assert jnp.all(sqrt_psd > 0) - - -class TestGram: - @pytest.mark.parametrize("KernelClass", [RBF, Matern12, Matern32, Matern52]) - def test_gram_shape_and_psd(self, KernelClass): - base = KernelClass(n_dims=1) - hsgp = HSGP( - base_kernel=base, num_basis_fns=20, domain_half_width=5.0, center=0.0 - ) - x = jnp.linspace(-3, 3, 30)[:, None] - linop = hsgp.gram(x) - assert isinstance(linop, Dense) - K = linop.to_dense() - assert K.shape == (30, 30) - # PSD check - evals, _ = jnp.linalg.eigh(K + 1e-6 * jnp.eye(30)) - assert jnp.all(evals > 0) - - @pytest.mark.parametrize("KernelClass", [RBF, Matern12, Matern32, Matern52]) - def test_cross_covariance_shape(self, KernelClass): - base = KernelClass(n_dims=1) - hsgp = HSGP( - base_kernel=base, num_basis_fns=20, domain_half_width=5.0, center=0.0 - ) - x1 = jnp.linspace(-3, 3, 30)[:, None] - x2 = jnp.linspace(-2, 2, 20)[:, None] - Kxy = hsgp.cross_covariance(x1, x2) - assert Kxy.shape == (30, 20) - - def test_gram_symmetric(self): - base = RBF(n_dims=1) - hsgp = HSGP( - base_kernel=base, num_basis_fns=20, domain_half_width=5.0, center=0.0 - ) - x = jnp.linspace(-3, 3, 30)[:, None] - K = hsgp.gram(x).to_dense() - npt.assert_allclose(K, K.T, atol=1e-12) - - @pytest.mark.parametrize("KernelClass", [RBF, Matern12, Matern32, Matern52]) - def test_diagonal(self, KernelClass): - base = KernelClass(n_dims=1) - hsgp = HSGP( - base_kernel=base, num_basis_fns=20, domain_half_width=5.0, center=0.0 - ) - x = jnp.linspace(-3, 3, 30)[:, None] - diag = hsgp.diagonal(x) - K = hsgp.gram(x).to_dense() - npt.assert_allclose(jnp.diag(diag.to_dense()), jnp.diag(K), atol=1e-10) - - -class TestConvergence: - @pytest.mark.parametrize("KernelClass", [RBF, Matern32, Matern52]) - def test_gram_converges_to_exact(self, KernelClass): - """With large m and appropriate L, HSGP Gram should converge to exact.""" - base = KernelClass(n_dims=1) - x = jnp.linspace(-1, 1, 30)[:, None] - exact = base.gram(x).to_dense() - - hsgp_coarse = HSGP( - base_kernel=base, num_basis_fns=10, domain_half_width=5.0, center=0.0 - ) - hsgp_fine = HSGP( - base_kernel=base, num_basis_fns=80, domain_half_width=5.0, center=0.0 - ) - - err_coarse = jnp.linalg.norm(exact - hsgp_coarse.gram(x).to_dense(), ord="fro") - err_fine = jnp.linalg.norm(exact - hsgp_fine.gram(x).to_dense(), ord="fro") - - # Finer approximation should have smaller error - assert err_fine < err_coarse - - def test_rbf_close_to_exact(self): - """RBF with large m should be very close to exact.""" - base = RBF(n_dims=1) - x = jnp.linspace(-1, 1, 20)[:, None] - exact = base.gram(x).to_dense() - - hsgp = HSGP( - base_kernel=base, num_basis_fns=100, domain_half_width=5.0, center=0.0 - ) - approx = hsgp.gram(x).to_dense() - max_err = jnp.max(jnp.abs(exact - approx)) - assert max_err < 0.01 - - -class TestValidation: - def test_nonstationary_kernel_rejected(self): - from gpjax.kernels.nonstationary import Linear - - with pytest.raises(TypeError): - HSGP(base_kernel=Linear(1), num_basis_fns=10, domain_half_width=3.0) - - def test_pointwise_call_raises(self): - hsgp = HSGP(base_kernel=RBF(n_dims=1), num_basis_fns=10, domain_half_width=3.0) - with pytest.raises(RuntimeError): - hsgp(jnp.array([1.0]), jnp.array([2.0])) - - -class TestIntegration: - def test_prior_posterior_pipeline(self): - """HSGP should work as a drop-in kernel in the Prior/Posterior pipeline.""" - from gpjax.dataset import Dataset - from gpjax.gps import Prior - from gpjax.likelihoods import Gaussian - from gpjax.mean_functions import Zero - from gpjax.objectives import conjugate_mll - - key = jr.key(42) - n = 50 - x = jnp.linspace(-3, 3, n)[:, None] - y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) - D = Dataset(X=x, y=y) - - base_kernel = RBF(n_dims=1) - hsgp = HSGP( - base_kernel=base_kernel, num_basis_fns=20, domain_half_width=5.0, center=0.0 - ) - prior = Prior(kernel=hsgp, mean_function=Zero()) - likelihood = Gaussian(num_datapoints=n) - posterior = prior * likelihood - - # MLL should return a finite scalar - mll = conjugate_mll(posterior, D) - assert jnp.isfinite(mll) - - def test_predict(self): - """HSGP posterior should produce finite mean and variance.""" - from gpjax.dataset import Dataset - from gpjax.gps import Prior - from gpjax.likelihoods import Gaussian - from gpjax.mean_functions import Zero - - key = jr.key(42) - n = 50 - x = jnp.linspace(-3, 3, n)[:, None] - y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) - D = Dataset(X=x, y=y) - - base_kernel = RBF(n_dims=1) - hsgp = HSGP( - base_kernel=base_kernel, num_basis_fns=20, domain_half_width=5.0, center=0.0 - ) - prior = Prior(kernel=hsgp, mean_function=Zero()) - likelihood = Gaussian(num_datapoints=n) - posterior = prior * likelihood - - x_test = jnp.linspace(-2.5, 2.5, 30)[:, None] - pred = posterior.predict(x_test, D) - - assert jnp.all(jnp.isfinite(pred.mean)) - assert jnp.all(jnp.isfinite(pred.covariance())) - - def test_mll_differentiable(self): - """conjugate_mll with HSGP must be differentiable w.r.t. kernel params.""" - from flax import nnx - from gpjax.dataset import Dataset - from gpjax.gps import Prior - from gpjax.likelihoods import Gaussian - from gpjax.mean_functions import Zero - from gpjax.objectives import conjugate_mll - import jax - - key = jr.key(42) - n = 50 - x = jnp.linspace(-3, 3, n)[:, None] - y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) - D = Dataset(X=x, y=y) - - base_kernel = RBF(n_dims=1) - hsgp = HSGP( - base_kernel=base_kernel, num_basis_fns=20, domain_half_width=5.0, center=0.0 - ) - prior = Prior(kernel=hsgp, mean_function=Zero()) - likelihood = Gaussian(num_datapoints=n) - posterior = prior * likelihood - - # Split into graphdef and state - graphdef, state = nnx.split(posterior) - - def loss(state): - model = nnx.merge(graphdef, state) - return -conjugate_mll(model, D) - - grad_fn = jax.grad(loss) - grads = grad_fn(state) - - # Gradients should be finite - flat_grads = jax.tree.leaves(grads) - for g in flat_grads: - assert jnp.all(jnp.isfinite(g)), f"Non-finite gradient: {g}" + +def _make_hsgp( + kernel_class=RBF, + num_basis_fns: int = 20, + domain_half_width: float = 5.0, + center: float = 0.0, +) -> HSGP: + """Create an HSGP with sensible defaults for testing.""" + base_kernel = kernel_class(n_dims=1) + return HSGP( + base_kernel=base_kernel, + num_basis_fns=num_basis_fns, + domain_half_width=domain_half_width, + center=center, + ) + + +# ────────────────────────────────────────────────────────────────────── +# Eigenvalues +# ────────────────────────────────────────────────────────────────────── + + +def test_eigenvalue_count(): + hsgp = _make_hsgp(num_basis_fns=10, domain_half_width=3.0) + eigenvalues = hsgp.eigenvalues() + assert eigenvalues.shape == (10,) + + +def test_eigenvalue_formula(): + """sqrt(lambda_j) = j * pi / (2 * L).""" + half_width = 5.0 + num_basis = 4 + hsgp = _make_hsgp(num_basis_fns=num_basis, domain_half_width=half_width) + eigenvalues = hsgp.eigenvalues() + indices = jnp.arange(1, num_basis + 1) + expected = indices * jnp.pi / (2.0 * half_width) + npt.assert_allclose(eigenvalues, expected) + + +def test_eigenvalues_are_strictly_increasing(): + hsgp = _make_hsgp(num_basis_fns=20, domain_half_width=3.0) + eigenvalues = hsgp.eigenvalues() + assert jnp.all(jnp.diff(eigenvalues) > 0) + + +# ────────────────────────────────────────────────────────────────────── +# Eigenfunctions +# ────────────────────────────────────────────────────────────────────── + + +def test_eigenfunction_shape(): + hsgp = _make_hsgp(num_basis_fns=10, domain_half_width=3.0) + inputs = jnp.linspace(-2, 2, 50)[:, None] + basis_matrix = hsgp.eigenfunctions(inputs) + assert basis_matrix.shape == (50, 10) + + +def test_eigenfunctions_are_approximately_orthonormal(): + """Eigenfunctions should be approximately orthonormal over [-L, L].""" + half_width = 3.0 + num_basis = 5 + hsgp = _make_hsgp(num_basis_fns=num_basis, domain_half_width=half_width, center=0.0) + num_quadrature_points = 10_000 + inputs = jnp.linspace(-half_width, half_width, num_quadrature_points)[:, None] + basis_matrix = hsgp.eigenfunctions(inputs) + spacing = 2.0 * half_width / num_quadrature_points + gram_matrix = basis_matrix.T @ basis_matrix * spacing + npt.assert_allclose(gram_matrix, jnp.eye(num_basis), atol=1e-2) + + +def test_eigenfunctions_vanish_at_boundaries(): + """Eigenfunctions should be zero at x = -L and x = L.""" + half_width = 3.0 + hsgp = _make_hsgp(num_basis_fns=5, domain_half_width=half_width, center=0.0) + boundary_points = jnp.array([[-half_width], [half_width]]) + basis_at_boundary = hsgp.eigenfunctions(boundary_points) + npt.assert_allclose(basis_at_boundary, 0.0, atol=1e-12) + + +# ────────────────────────────────────────────────────────────────────── +# Centering +# ────────────────────────────────────────────────────────────────────── + + +def test_explicit_center_is_stored(): + hsgp = HSGP( + base_kernel=RBF(n_dims=1), + num_basis_fns=5, + domain_half_width=3.0, + center=1.0, + ) + assert hsgp._center == 1.0 + + +def test_auto_center_uses_midpoint_of_input_range(): + hsgp = _make_hsgp(num_basis_fns=5, domain_half_width=3.0, center=None) + inputs = jnp.linspace(2.0, 8.0, 50)[:, None] + _ = hsgp.eigenfunctions(inputs) + assert hsgp._center == pytest.approx(5.0) + + +def test_auto_center_persists_across_calls(): + """Once set, auto-center should not change on subsequent calls.""" + hsgp = _make_hsgp(num_basis_fns=5, domain_half_width=3.0, center=None) + first_inputs = jnp.linspace(0, 10, 50)[:, None] + second_inputs = jnp.linspace(-5, 5, 50)[:, None] + _ = hsgp.eigenfunctions(first_inputs) + center_after_first_call = hsgp._center + _ = hsgp.eigenfunctions(second_inputs) + assert hsgp._center == center_after_first_call + + +# ────────────────────────────────────────────────────────────────────── +# compute_basis +# ────────────────────────────────────────────────────────────────────── + + +def test_compute_basis_returns_correct_shapes(): + hsgp = _make_hsgp(num_basis_fns=10, domain_half_width=3.0) + inputs = jnp.linspace(-2, 2, 50)[:, None] + basis_matrix, sqrt_spectral_weights = hsgp.compute_basis(inputs) + assert basis_matrix.shape == (50, 10) + assert sqrt_spectral_weights.shape == (10,) + + +def test_compute_basis_sqrt_psd_is_positive(): + hsgp = _make_hsgp(num_basis_fns=10, domain_half_width=3.0) + inputs = jnp.linspace(-2, 2, 50)[:, None] + _, sqrt_spectral_weights = hsgp.compute_basis(inputs) + assert jnp.all(sqrt_spectral_weights > 0) + + +# ────────────────────────────────────────────────────────────────────── +# Gram, cross-covariance, and diagonal +# ────────────────────────────────────────────────────────────────────── + + +@pytest.mark.parametrize("KernelClass", STATIONARY_KERNELS) +def test_gram_returns_psd_low_rank_matrix(KernelClass): + hsgp = _make_hsgp(kernel_class=KernelClass) + inputs = jnp.linspace(-3, 3, 30)[:, None] + gram_operator = hsgp.gram(inputs) + + assert isinstance(gram_operator, LowRank) + gram_dense = gram_operator.to_dense() + assert gram_dense.shape == (30, 30) + + eigenvalues_of_gram, _ = jnp.linalg.eigh(gram_dense + 1e-6 * jnp.eye(30)) + assert jnp.all(eigenvalues_of_gram > 0) + + +@pytest.mark.parametrize("KernelClass", STATIONARY_KERNELS) +def test_cross_covariance_shape(KernelClass): + hsgp = _make_hsgp(kernel_class=KernelClass) + inputs_a = jnp.linspace(-3, 3, 30)[:, None] + inputs_b = jnp.linspace(-2, 2, 20)[:, None] + cross_covariance = hsgp.cross_covariance(inputs_a, inputs_b) + assert cross_covariance.shape == (30, 20) + + +def test_gram_is_symmetric(): + hsgp = _make_hsgp(kernel_class=RBF) + inputs = jnp.linspace(-3, 3, 30)[:, None] + gram_dense = hsgp.gram(inputs).to_dense() + npt.assert_allclose(gram_dense, gram_dense.T, atol=1e-12) + + +@pytest.mark.parametrize("KernelClass", STATIONARY_KERNELS) +def test_diagonal_matches_gram_diagonal(KernelClass): + hsgp = _make_hsgp(kernel_class=KernelClass) + inputs = jnp.linspace(-3, 3, 30)[:, None] + diagonal_dense = hsgp.diagonal(inputs).to_dense() + gram_dense = hsgp.gram(inputs).to_dense() + npt.assert_allclose(jnp.diag(diagonal_dense), jnp.diag(gram_dense), atol=1e-10) + + +# ────────────────────────────────────────────────────────────────────── +# Convergence to exact kernel +# ────────────────────────────────────────────────────────────────────── + + +@pytest.mark.parametrize("KernelClass", [RBF, Matern32, Matern52]) +def test_gram_error_decreases_with_more_basis_functions(KernelClass): + """With more basis functions, the HSGP Gram matrix should converge to exact.""" + base_kernel = KernelClass(n_dims=1) + inputs = jnp.linspace(-1, 1, 30)[:, None] + exact_gram = base_kernel.gram(inputs).to_dense() + + hsgp_coarse = _make_hsgp(kernel_class=KernelClass, num_basis_fns=10) + hsgp_fine = _make_hsgp(kernel_class=KernelClass, num_basis_fns=80) + + error_coarse = jnp.linalg.norm(exact_gram - hsgp_coarse.gram(inputs).to_dense()) + error_fine = jnp.linalg.norm(exact_gram - hsgp_fine.gram(inputs).to_dense()) + + assert error_fine < error_coarse + + +def test_rbf_gram_closely_matches_exact(): + """RBF with many basis functions should be very close to exact.""" + base_kernel = RBF(n_dims=1) + inputs = jnp.linspace(-1, 1, 20)[:, None] + exact_gram = base_kernel.gram(inputs).to_dense() + + hsgp = _make_hsgp(kernel_class=RBF, num_basis_fns=100) + approximate_gram = hsgp.gram(inputs).to_dense() + + max_absolute_error = jnp.max(jnp.abs(exact_gram - approximate_gram)) + assert max_absolute_error < 0.01 + + +# ────────────────────────────────────────────────────────────────────── +# Input validation +# ────────────────────────────────────────────────────────────────────── + + +def test_nonstationary_kernel_is_rejected(): + from gpjax.kernels.nonstationary import Linear + + with pytest.raises(TypeError): + HSGP(base_kernel=Linear(1), num_basis_fns=10, domain_half_width=3.0) + + +def test_pointwise_call_raises(): + hsgp = _make_hsgp(num_basis_fns=10, domain_half_width=3.0) + with pytest.raises(RuntimeError): + hsgp(jnp.array([1.0]), jnp.array([2.0])) + + +# ────────────────────────────────────────────────────────────────────── +# Integration with Prior/Posterior pipeline +# ────────────────────────────────────────────────────────────────────── + + +@pytest.fixture +def training_data(): + """Synthetic sinusoidal regression dataset.""" + key = jr.key(42) + num_points = 50 + inputs = jnp.linspace(-3, 3, num_points)[:, None] + targets = jnp.sin(inputs) + 0.1 * jr.normal(key, (num_points, 1)) + + from gpjax.dataset import Dataset + + return Dataset(X=inputs, y=targets), num_points + + +@pytest.fixture +def hsgp_posterior(training_data): + """Conjugate posterior with an HSGP-RBF kernel.""" + from gpjax.gps import Prior + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + + _, num_points = training_data + hsgp = _make_hsgp(kernel_class=RBF, num_basis_fns=20) + prior = Prior(kernel=hsgp, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=num_points) + return prior * likelihood + + +def test_conjugate_mll_returns_finite_scalar(training_data, hsgp_posterior): + from gpjax.objectives import conjugate_mll + + dataset, _ = training_data + mll = conjugate_mll(hsgp_posterior, dataset) + assert jnp.isfinite(mll) + + +def test_posterior_predict_returns_finite_moments(training_data, hsgp_posterior): + dataset, _ = training_data + test_inputs = jnp.linspace(-2.5, 2.5, 30)[:, None] + prediction = hsgp_posterior.predict(test_inputs, dataset) + + assert jnp.all(jnp.isfinite(prediction.mean)) + assert jnp.all(jnp.isfinite(prediction.covariance())) + + +def test_conjugate_mll_is_differentiable(training_data, hsgp_posterior): + """conjugate_mll with HSGP must be differentiable w.r.t. kernel parameters.""" + from flax import nnx + from gpjax.objectives import conjugate_mll + + dataset, _ = training_data + graphdef, state = nnx.split(hsgp_posterior) + + def negative_mll(state): + model = nnx.merge(graphdef, state) + return -conjugate_mll(model, dataset) + + gradients = jax.grad(negative_mll)(state) + flat_gradients = jax.tree.leaves(gradients) + for grad_leaf in flat_gradients: + assert jnp.all(jnp.isfinite(grad_leaf)), f"Non-finite gradient: {grad_leaf}" diff --git a/tests/test_kernels/test_spectral_density.py b/tests/test_kernels/test_spectral_density.py index 0e8c883b0..fd29581f2 100644 --- a/tests/test_kernels/test_spectral_density.py +++ b/tests/test_kernels/test_spectral_density.py @@ -10,112 +10,122 @@ config.update("jax_enable_x64", True) +UNIT_VARIANCE = jnp.array(1.0) +UNIT_LENGTHSCALE = jnp.array(1.0) +TEST_FREQUENCIES = jnp.array([0.0, 1.0, 5.0]) -def test_spectral_density_has_sample(): + +# ────────────────────────────────────────────────────────────────────── +# SpectralDensity interface +# ────────────────────────────────────────────────────────────────────── + + +def test_spectral_density_has_sample_method(): """SpectralDensity must expose sample() for RFF compatibility.""" - kernel = RBF(n_dims=1) - sd = kernel.spectral_density - assert isinstance(sd, SpectralDensity) - samples = sd.sample(key=jr.key(0), sample_shape=(10, 1)) + spectral_density = RBF(n_dims=1).spectral_density + assert isinstance(spectral_density, SpectralDensity) + samples = spectral_density.sample(key=jr.key(0), sample_shape=(10, 1)) assert samples.shape == (10, 1) -def test_spectral_density_callable(): +def test_spectral_density_is_callable(): """SpectralDensity must be callable with (omega, variance, lengthscale).""" - kernel = RBF(n_dims=1) - sd = kernel.spectral_density - omega = jnp.array([0.5, 1.0, 2.0]) - result = sd(omega, jnp.array(1.0), jnp.array(1.0)) - assert result.shape == (3,) - assert jnp.all(result > 0) + spectral_density = RBF(n_dims=1).spectral_density + frequencies = jnp.array([0.5, 1.0, 2.0]) + values = spectral_density(frequencies, UNIT_VARIANCE, UNIT_LENGTHSCALE) + assert values.shape == (3,) + assert jnp.all(values > 0) -def test_rbf_spectral_density_formula(): - """Verify the RBF spectral density against the known closed form. +# ────────────────────────────────────────────────────────────────────── +# Closed-form spectral density formulae +# ────────────────────────────────────────────────────────────────────── - S(w) = variance * sqrt(2*pi) * lengthscale * exp(-0.5 * lengthscale^2 * w^2) - """ - kernel = RBF(n_dims=1, variance=2.0, lengthscale=0.5) - sd = kernel.spectral_density - omega = jnp.array([0.0, 1.0, 3.0]) +def test_rbf_spectral_density_matches_closed_form(): + """S(w) = variance * sqrt(2*pi) * lengthscale * exp(-0.5 * lengthscale^2 * w^2).""" variance = jnp.array(2.0) lengthscale = jnp.array(0.5) + spectral_density = RBF(n_dims=1, variance=2.0, lengthscale=0.5).spectral_density - result = sd(omega, variance, lengthscale) + result = spectral_density(TEST_FREQUENCIES, variance, lengthscale) expected = ( variance * jnp.sqrt(2 * jnp.pi) * lengthscale - * jnp.exp(-0.5 * lengthscale**2 * omega**2) + * jnp.exp(-0.5 * lengthscale**2 * TEST_FREQUENCIES**2) ) npt.assert_allclose(result, expected, atol=1e-12) -def test_rbf_spectral_density_peak_at_zero(): +def test_rbf_spectral_density_peaks_at_zero_and_decays(): """RBF spectral density peaks at omega=0 and decays monotonically.""" - kernel = RBF(n_dims=1) - sd = kernel.spectral_density - omega = jnp.linspace(0, 10, 100) - values = sd(omega, jnp.array(1.0), jnp.array(1.0)) - # Peak at omega=0 + spectral_density = RBF(n_dims=1).spectral_density + frequencies = jnp.linspace(0, 10, 100) + values = spectral_density(frequencies, UNIT_VARIANCE, UNIT_LENGTHSCALE) assert values[0] == jnp.max(values) - # Monotonically decreasing assert jnp.all(jnp.diff(values) <= 0) -def test_matern12_spectral_density_formula(): - """Verify Matern12 (nu=1/2) spectral density. - - S(w) = variance * 2/ell * 1/(1/ell^2 + w^2) - """ - kernel = Matern12(n_dims=1, variance=1.5, lengthscale=0.8) - sd = kernel.spectral_density - omega = jnp.array([0.0, 1.0, 5.0]) - v, ell = jnp.array(1.5), jnp.array(0.8) +def test_matern12_spectral_density_matches_closed_form(): + """S(w) = variance * 2/ell * 1/(1/ell^2 + w^2).""" + variance = jnp.array(1.5) + lengthscale = jnp.array(0.8) + spectral_density = Matern12( + n_dims=1, variance=1.5, lengthscale=0.8 + ).spectral_density - result = sd(omega, v, ell) - expected = v * (2.0 / ell) / (1.0 / ell**2 + omega**2) + result = spectral_density(TEST_FREQUENCIES, variance, lengthscale) + expected = ( + variance * (2.0 / lengthscale) / (1.0 / lengthscale**2 + TEST_FREQUENCIES**2) + ) npt.assert_allclose(result, expected, atol=1e-12) -def test_matern32_spectral_density_formula(): - """Verify Matern32 (nu=3/2) spectral density. - - S(w) = variance * 4*(sqrt(3)/ell)^3 / (3/ell^2 + w^2)^2 - """ - kernel = Matern32(n_dims=1, variance=2.0, lengthscale=1.5) - sd = kernel.spectral_density - omega = jnp.array([0.0, 1.0, 5.0]) - v, ell = jnp.array(2.0), jnp.array(1.5) +def test_matern32_spectral_density_matches_closed_form(): + """S(w) = variance * 4*(sqrt(3)/ell)^3 / (3/ell^2 + w^2)^2.""" + variance = jnp.array(2.0) + lengthscale = jnp.array(1.5) + spectral_density = Matern32( + n_dims=1, variance=2.0, lengthscale=1.5 + ).spectral_density - result = sd(omega, v, ell) - alpha = jnp.sqrt(3.0) / ell - expected = v * 4.0 * alpha**3 / (3.0 / ell**2 + omega**2) ** 2 + result = spectral_density(TEST_FREQUENCIES, variance, lengthscale) + alpha = jnp.sqrt(3.0) / lengthscale + expected = ( + variance * 4.0 * alpha**3 / (3.0 / lengthscale**2 + TEST_FREQUENCIES**2) ** 2 + ) npt.assert_allclose(result, expected, atol=1e-12) -def test_matern52_spectral_density_formula(): - """Verify Matern52 (nu=5/2) spectral density. +def test_matern52_spectral_density_matches_closed_form(): + """S(w) = variance * (16/3)*(sqrt(5)/ell)^5 / (5/ell^2 + w^2)^3.""" + variance = jnp.array(3.0) + lengthscale = jnp.array(2.0) + spectral_density = Matern52( + n_dims=1, variance=3.0, lengthscale=2.0 + ).spectral_density - S(w) = variance * (16/3)*(sqrt(5)/ell)^5 / (5/ell^2 + w^2)^3 - """ - kernel = Matern52(n_dims=1, variance=3.0, lengthscale=2.0) - sd = kernel.spectral_density - omega = jnp.array([0.0, 1.0, 5.0]) - v, ell = jnp.array(3.0), jnp.array(2.0) - - result = sd(omega, v, ell) - alpha = jnp.sqrt(5.0) / ell - expected = v * (16.0 / 3.0) * alpha**5 / (5.0 / ell**2 + omega**2) ** 3 + result = spectral_density(TEST_FREQUENCIES, variance, lengthscale) + alpha = jnp.sqrt(5.0) / lengthscale + expected = ( + variance + * (16.0 / 3.0) + * alpha**5 + / (5.0 / lengthscale**2 + TEST_FREQUENCIES**2) ** 3 + ) npt.assert_allclose(result, expected, atol=1e-12) +# ────────────────────────────────────────────────────────────────────── +# Positivity across all stationary kernels +# ────────────────────────────────────────────────────────────────────── + + @pytest.mark.parametrize("KernelClass", [RBF, Matern12, Matern32, Matern52]) -def test_spectral_density_positive(KernelClass): +def test_spectral_density_is_positive_everywhere(KernelClass): """All spectral densities must be positive for all omega.""" - kernel = KernelClass(n_dims=1) - sd = kernel.spectral_density - omega = jnp.linspace(0, 20, 200) - values = sd(omega, jnp.array(1.0), jnp.array(1.0)) + spectral_density = KernelClass(n_dims=1).spectral_density + frequencies = jnp.linspace(0, 20, 200) + values = spectral_density(frequencies, UNIT_VARIANCE, UNIT_LENGTHSCALE) assert jnp.all(values > 0) diff --git a/tests/test_linalg/__init__.py b/tests/test_linalg/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_linalg/test_lowrank.py b/tests/test_linalg/test_lowrank.py new file mode 100644 index 000000000..0a3657166 --- /dev/null +++ b/tests/test_linalg/test_lowrank.py @@ -0,0 +1,81 @@ +"""Tests for the LowRank linear operator.""" + +import jax +from jax import config +import jax.numpy as jnp +import jax.random as jr +import numpy.testing as npt + +config.update("jax_enable_x64", True) + +from gpjax.linalg.operations import diag +from gpjax.linalg.operators import LowRank + + +class TestLowRankProperties: + def test_shape(self): + W = jnp.ones((10, 3)) + op = LowRank(W) + assert op.shape == (10, 10) + + def test_rank(self): + W = jnp.ones((10, 3)) + op = LowRank(W) + assert op.rank == 3 + + def test_dtype(self): + W = jnp.ones((10, 3), dtype=jnp.float64) + op = LowRank(W) + assert op.dtype == jnp.float64 + + def test_to_dense(self): + key = jr.key(0) + W = jr.normal(key, (10, 3)) + op = LowRank(W) + npt.assert_allclose(op.to_dense(), W @ W.T, atol=1e-12) + + def test_transpose_is_self(self): + W = jnp.ones((10, 3)) + op = LowRank(W) + assert op.T is op + + def test_diag_dispatch(self): + key = jr.key(0) + W = jr.normal(key, (10, 3)) + op = LowRank(W) + expected = jnp.sum(W**2, axis=1) + npt.assert_allclose(diag(op), expected, atol=1e-12) + + +class TestLowRankPyTree: + def test_flatten_unflatten_roundtrip(self): + key = jr.key(0) + W = jr.normal(key, (10, 3)) + op = LowRank(W) + leaves, treedef = jax.tree.flatten(op) + restored = treedef.unflatten(leaves) + npt.assert_allclose(restored.factor, op.factor) + + def test_jit_compatible(self): + key = jr.key(0) + W = jr.normal(key, (10, 3)) + op = LowRank(W) + + @jax.jit + def fn(op): + return op.to_dense() + + result = fn(op) + npt.assert_allclose(result, W @ W.T, atol=1e-12) + + def test_grad_through_to_dense(self): + key = jr.key(0) + W = jr.normal(key, (10, 3)) + + @jax.grad + def fn(W): + op = LowRank(W) + return jnp.sum(op.to_dense()) + + grads = fn(W) + assert jnp.all(jnp.isfinite(grads)) diff --git a/tests/test_linalg.py b/tests/test_linalg/test_operators.py similarity index 100% rename from tests/test_linalg.py rename to tests/test_linalg/test_operators.py diff --git a/tests/test_linalg/test_woodbury.py b/tests/test_linalg/test_woodbury.py new file mode 100644 index 000000000..0c32ff9f1 --- /dev/null +++ b/tests/test_linalg/test_woodbury.py @@ -0,0 +1,379 @@ +"""Tests for Woodbury identity helper functions.""" + +import jax +from jax import config +import jax.numpy as jnp +import jax.random as jr +import numpy.testing as npt +import pytest + +config.update("jax_enable_x64", True) + +from gpjax.linalg.woodbury import woodbury_logdet, woodbury_quad, woodbury_solve + + +def _dense_reference(W, noise): + """Build the dense matrix W W^T + diag(noise) for reference.""" + return W @ W.T + jnp.diag(noise) + + +class TestWoodburySolve: + @pytest.mark.parametrize("N,m", [(20, 3), (50, 5), (100, 10)]) + def test_matches_dense_vector(self, N, m): + key = jr.key(0) + k1, k2, _k3 = jr.split(key, 3) + W = jr.normal(k1, (N, m)) + noise = jnp.ones(N) * 0.1 + b = jr.normal(k2, (N,)) + + result = woodbury_solve(W, noise, b) + expected = jnp.linalg.solve(_dense_reference(W, noise), b) + npt.assert_allclose(result, expected, atol=1e-6) + + @pytest.mark.parametrize("N,m", [(20, 3), (50, 5)]) + def test_matches_dense_matrix(self, N, m): + key = jr.key(0) + k1, k2 = jr.split(key) + W = jr.normal(k1, (N, m)) + noise = jnp.ones(N) * 0.1 + B = jr.normal(k2, (N, 7)) + + result = woodbury_solve(W, noise, B) + expected = jnp.linalg.solve(_dense_reference(W, noise), B) + npt.assert_allclose(result, expected, atol=1e-6) + + def test_heterogeneous_noise(self): + key = jr.key(0) + k1, k2, k3 = jr.split(key, 3) + N, m = 30, 4 + W = jr.normal(k1, (N, m)) + noise = jnp.abs(jr.normal(k2, (N,))) + 0.01 + b = jr.normal(k3, (N,)) + + result = woodbury_solve(W, noise, b) + expected = jnp.linalg.solve(_dense_reference(W, noise), b) + npt.assert_allclose(result, expected, atol=1e-6) + + def test_jit(self): + key = jr.key(0) + k1, k2 = jr.split(key) + W = jr.normal(k1, (20, 3)) + noise = jnp.ones(20) * 0.1 + b = jr.normal(k2, (20,)) + + result = jax.jit(woodbury_solve)(W, noise, b) + expected = jnp.linalg.solve(_dense_reference(W, noise), b) + npt.assert_allclose(result, expected, atol=1e-6) + + def test_grad(self): + key = jr.key(0) + k1, k2 = jr.split(key) + W = jr.normal(k1, (20, 3)) + noise = jnp.ones(20) * 0.1 + b = jr.normal(k2, (20,)) + + @jax.grad + def fn(W): + return jnp.sum(woodbury_solve(W, noise, b)) + + grads = fn(W) + assert jnp.all(jnp.isfinite(grads)) + + +class TestWoodburyLogdet: + @pytest.mark.parametrize("N,m", [(20, 3), (50, 5), (100, 10)]) + def test_matches_dense(self, N, m): + key = jr.key(0) + W = jr.normal(key, (N, m)) + noise = jnp.ones(N) * 0.1 + + result = woodbury_logdet(W, noise) + expected = jnp.linalg.slogdet(_dense_reference(W, noise))[1] + npt.assert_allclose(result, expected, atol=1e-6) + + def test_heterogeneous_noise(self): + key = jr.key(0) + k1, k2 = jr.split(key) + N, m = 30, 4 + W = jr.normal(k1, (N, m)) + noise = jnp.abs(jr.normal(k2, (N,))) + 0.01 + + result = woodbury_logdet(W, noise) + expected = jnp.linalg.slogdet(_dense_reference(W, noise))[1] + npt.assert_allclose(result, expected, atol=1e-6) + + def test_jit(self): + key = jr.key(0) + W = jr.normal(key, (20, 3)) + noise = jnp.ones(20) * 0.1 + + result = jax.jit(woodbury_logdet)(W, noise) + expected = jnp.linalg.slogdet(_dense_reference(W, noise))[1] + npt.assert_allclose(result, expected, atol=1e-6) + + def test_grad(self): + key = jr.key(0) + W = jr.normal(key, (20, 3)) + noise = jnp.ones(20) * 0.1 + + @jax.grad + def fn(W): + return woodbury_logdet(W, noise) + + grads = fn(W) + assert jnp.all(jnp.isfinite(grads)) + + +class TestWoodburyQuad: + @pytest.mark.parametrize("N,m", [(20, 3), (50, 5)]) + def test_matches_dense(self, N, m): + key = jr.key(0) + k1, k2 = jr.split(key) + W = jr.normal(k1, (N, m)) + noise = jnp.ones(N) * 0.1 + diff = jr.normal(k2, (N,)) + + result = woodbury_quad(W, noise, diff) + Sigma = _dense_reference(W, noise) + expected = diff @ jnp.linalg.solve(Sigma, diff) + npt.assert_allclose(result, expected, atol=1e-6) + + def test_jit(self): + key = jr.key(0) + k1, k2 = jr.split(key) + W = jr.normal(k1, (20, 3)) + noise = jnp.ones(20) * 0.1 + diff = jr.normal(k2, (20,)) + + result = jax.jit(woodbury_quad)(W, noise, diff) + Sigma = _dense_reference(W, noise) + expected = diff @ jnp.linalg.solve(Sigma, diff) + npt.assert_allclose(result, expected, atol=1e-6) + + def test_grad(self): + key = jr.key(0) + k1, k2 = jr.split(key) + W = jr.normal(k1, (20, 3)) + noise = jnp.ones(20) * 0.1 + diff = jr.normal(k2, (20,)) + + @jax.grad + def fn(W): + return woodbury_quad(W, noise, diff) + + grads = fn(W) + assert jnp.all(jnp.isfinite(grads)) + + +class TestNumericalStability: + def test_large_N_small_m(self): + key = jr.key(0) + k1, k2 = jr.split(key) + N, m = 1000, 10 + W = jr.normal(k1, (N, m)) * 0.5 + noise = jnp.ones(N) * 0.1 + b = jr.normal(k2, (N,)) + + result = woodbury_solve(W, noise, b) + assert jnp.all(jnp.isfinite(result)) + + ld = woodbury_logdet(W, noise) + assert jnp.isfinite(ld) + + def test_small_noise(self): + key = jr.key(0) + k1, k2 = jr.split(key) + N, m = 50, 5 + W = jr.normal(k1, (N, m)) + noise = jnp.ones(N) * 1e-8 + b = jr.normal(k2, (N,)) + + result = woodbury_solve(W, noise, b) + assert jnp.all(jnp.isfinite(result)) + + +class TestConjugateMllIntegration: + """Verify Woodbury path in conjugate_mll matches dense path.""" + + @pytest.mark.parametrize("KernelClass", ["RBF", "Matern52"]) + def test_mll_matches_dense(self, KernelClass): + from gpjax.dataset import Dataset + from gpjax.gps import Prior + from gpjax.kernels.approximations.hsgp import HSGP + from gpjax.kernels.stationary import RBF, Matern52 + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + from gpjax.objectives import conjugate_mll + + kernel_cls = {"RBF": RBF, "Matern52": Matern52}[KernelClass] + + key = jr.key(42) + n = 50 + x = jnp.linspace(-3, 3, n)[:, None] + y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) + D = Dataset(X=x, y=y) + + base_kernel = kernel_cls(n_dims=1) + hsgp = HSGP( + base_kernel=base_kernel, + num_basis_fns=40, + domain_half_width=5.0, + center=0.0, + ) + prior = Prior(kernel=hsgp, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=n) + posterior = prior * likelihood + + # Woodbury path (automatic via LowRank gram) + mll_woodbury = conjugate_mll(posterior, D) + + # Dense reference: manually compute via dense gram + dense_gram = hsgp.gram(x).to_dense() + from gpjax.distributions import GaussianDistribution + from gpjax.linalg import Dense, psd + from gpjax.linalg.utils import add_jitter + + noise = likelihood.noise_vector(n) + mx = prior.mean_function(x) + y_flat, mx_flat = likelihood.prepare_targets(y, mx) + Kxx_dense = add_jitter(dense_gram, prior.jitter) + Sigma_dense = Kxx_dense + jnp.diag(noise) + Sigma = psd(Dense(Sigma_dense)) + mll_dense = ( + GaussianDistribution(jnp.atleast_1d(mx_flat.squeeze()), Sigma) + .log_prob(jnp.atleast_1d(y_flat.squeeze())) + .squeeze() + ) + + npt.assert_allclose(mll_woodbury, mll_dense, atol=1e-5) + + def test_mll_differentiable_woodbury(self): + from flax import nnx + from gpjax.dataset import Dataset + from gpjax.gps import Prior + from gpjax.kernels.approximations.hsgp import HSGP + from gpjax.kernels.stationary import RBF + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + from gpjax.objectives import conjugate_mll + + key = jr.key(42) + n = 50 + x = jnp.linspace(-3, 3, n)[:, None] + y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) + D = Dataset(X=x, y=y) + + base_kernel = RBF(n_dims=1) + hsgp = HSGP( + base_kernel=base_kernel, + num_basis_fns=20, + domain_half_width=5.0, + center=0.0, + ) + prior = Prior(kernel=hsgp, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=n) + posterior = prior * likelihood + + graphdef, state = nnx.split(posterior) + + def loss(state): + model = nnx.merge(graphdef, state) + return -conjugate_mll(model, D) + + grad_fn = jax.grad(loss) + grads = grad_fn(state) + + flat_grads = jax.tree.leaves(grads) + for g in flat_grads: + assert jnp.all(jnp.isfinite(g)) + + +class TestPredictIntegration: + """Verify Woodbury predict path matches dense predict path.""" + + def test_predict_mean_matches_dense(self): + from gpjax.dataset import Dataset + from gpjax.gps import Prior + from gpjax.kernels.approximations.hsgp import HSGP + from gpjax.kernels.stationary import RBF + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + + key = jr.key(42) + n = 50 + x = jnp.linspace(-3, 3, n)[:, None] + y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) + D = Dataset(X=x, y=y) + + base_kernel = RBF(n_dims=1) + hsgp = HSGP( + base_kernel=base_kernel, + num_basis_fns=30, + domain_half_width=5.0, + center=0.0, + ) + prior = Prior(kernel=hsgp, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=n) + posterior = prior * likelihood + + x_test = jnp.linspace(-2.5, 2.5, 20)[:, None] + pred = posterior.predict(x_test, D) + + assert jnp.all(jnp.isfinite(pred.mean)) + assert jnp.all(jnp.isfinite(pred.covariance())) + + def test_predict_diagonal(self): + from gpjax.dataset import Dataset + from gpjax.gps import Prior + from gpjax.kernels.approximations.hsgp import HSGP + from gpjax.kernels.stationary import RBF + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + + key = jr.key(42) + n = 50 + x = jnp.linspace(-3, 3, n)[:, None] + y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) + D = Dataset(X=x, y=y) + + base_kernel = RBF(n_dims=1) + hsgp = HSGP( + base_kernel=base_kernel, + num_basis_fns=30, + domain_half_width=5.0, + center=0.0, + ) + prior = Prior(kernel=hsgp, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=n) + posterior = prior * likelihood + + x_test = jnp.linspace(-2.5, 2.5, 20)[:, None] + pred = posterior.predict(x_test, D, return_covariance_type="diagonal") + + assert jnp.all(jnp.isfinite(pred.mean)) + assert jnp.all(jnp.isfinite(pred.variance)) + + def test_predict_rff(self): + from gpjax.dataset import Dataset + from gpjax.gps import Prior + from gpjax.kernels.approximations import RFF + from gpjax.kernels.stationary import RBF + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + + key = jr.key(42) + n = 50 + x = jnp.linspace(-3, 3, n)[:, None] + y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) + D = Dataset(X=x, y=y) + + rff = RFF(base_kernel=RBF(n_dims=1), num_basis_fns=50, key=jr.key(0)) + prior = Prior(kernel=rff, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=n) + posterior = prior * likelihood + + x_test = jnp.linspace(-2.5, 2.5, 20)[:, None] + pred = posterior.predict(x_test, D) + + assert jnp.all(jnp.isfinite(pred.mean)) + assert jnp.all(jnp.isfinite(pred.covariance())) diff --git a/uv.lock b/uv.lock index daa54f696..47ed38cd2 100644 --- a/uv.lock +++ b/uv.lock @@ -17,7 +17,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-02-06T09:56:46.847909Z" +exclude-newer = "2026-02-11T12:41:49.179683Z" exclude-newer-span = "P7D" [[package]] @@ -985,6 +985,7 @@ docs = [ { name = "networkx" }, { name = "pandas" }, { name = "pymdown-extensions" }, + { name = "rdata" }, { name = "scikit-learn" }, { name = "seaborn" }, { name = "ucimlrepo" }, @@ -1041,6 +1042,7 @@ requires-dist = [ { name = "optax", specifier = ">0.2.1" }, { name = "pandas", marker = "extra == 'docs'", specifier = ">=1.5.3" }, { name = "pymdown-extensions", marker = "extra == 'docs'", specifier = ">=10.7.1" }, + { name = "rdata", marker = "extra == 'docs'", specifier = ">=1.0.0" }, { name = "scikit-learn", marker = "extra == 'docs'", specifier = ">=1.5.1" }, { name = "seaborn", marker = "extra == 'docs'", specifier = ">=0.12.2" }, { name = "tensorstore", marker = "sys_platform == 'darwin'", specifier = "!=0.1.76" }, @@ -3111,6 +3113,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/01/1b/5dbe84eefc86f48473947e2f41711aded97eecef1231f4558f1f02713c12/pyzmq-27.1.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:c9f7f6e13dff2e44a6afeaf2cf54cee5929ad64afaf4d40b50f93c58fc687355", size = 544862, upload-time = "2025-09-08T23:09:56.509Z" }, ] +[[package]] +name = "rdata" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pandas" }, + { name = "typing-extensions" }, + { name = "xarray" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/e5/c5626257359fde4662f9c2f658bcb37873ffad1490c8991640a361b8cf8e/rdata-1.0.0.tar.gz", hash = "sha256:d924d7657c9eeaee4b86a0ae87999f5beb0070ed47bd491c85bc79ee9644596d", size = 56618, upload-time = "2025-08-15T17:17:10.832Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/b6/aec7624eb1db90e58b7fa34ab6491ee06385020bbb9598d4b8a28051da81/rdata-1.0.0-py3-none-any.whl", hash = "sha256:b671e31676e158dc215297595d5085aee1674b91be3c26e38638ca056167d402", size = 72291, upload-time = "2025-08-15T17:17:09.404Z" }, +] + [[package]] name = "readme-renderer" version = "44.0" @@ -4121,6 +4138,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/0e/fa3b193432cfc60c93b42f3be03365f5f909d2b3ea410295cf36df739e31/widgetsnbextension-4.0.15-py3-none-any.whl", hash = "sha256:8156704e4346a571d9ce73b84bee86a29906c9abfd7223b7228a28899ccf3366", size = 2196503, upload-time = "2025-11-01T21:15:53.565Z" }, ] +[[package]] +name = "xarray" +version = "2026.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/85/113ff1e2cde9e8a5b13c2f0ef4e9f5cd6ca3a036b6452f4dd523419289b5/xarray-2026.1.0.tar.gz", hash = "sha256:0c9814761f9d9a9545df37292d3fda89f83201f3e02ae0f09f03313d9cfdd5e2", size = 3107024, upload-time = "2026-01-28T17:49:03.822Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7d/8e/952a351c10df395d9bab850f611f4368834ae9104d6449049f5a49e00925/xarray-2026.1.0-py3-none-any.whl", hash = "sha256:5fcc03d3ed8dfb662aa254efe6cd65efc70014182bbc2126e4b90d291d970d41", size = 1403009, upload-time = "2026-01-28T17:49:01.538Z" }, +] + [[package]] name = "xdoctest" version = "1.3.0" From 832f0bc4e6bd22df57b83f9bbfeacd760715aac0 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Wed, 18 Feb 2026 19:39:24 +0100 Subject: [PATCH 3/5] Add doc reference --- mkdocs.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/mkdocs.yml b/mkdocs.yml index 527f18c0c..3c7d4bf3a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -31,6 +31,7 @@ nav: - Multi-Output GPs: _examples/multioutput.md - Accelerated Multi-Output GPs: _examples/oilmm.md - Orthogonal Additive GPs: _examples/oak.md + - Hilbert Space GPs: _examples/hsgp.md - 🧪 Experimental: - Numpyro Integration: _examples/numpyro_integration.md - Spatial Linear Gaussian Process: _examples/spatial_linear_gp.md From 0e65622ba842aaeebdb97e11df5c8472e4fdb892 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Wed, 18 Feb 2026 20:05:36 +0100 Subject: [PATCH 4/5] Fix flaky test --- gpjax/kernels/approximations/hsgp.py | 2 ++ tests/test_heteroscedastic.py | 10 ++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/gpjax/kernels/approximations/hsgp.py b/gpjax/kernels/approximations/hsgp.py index 2c5d6c487..3c3b977bf 100644 --- a/gpjax/kernels/approximations/hsgp.py +++ b/gpjax/kernels/approximations/hsgp.py @@ -54,6 +54,8 @@ class HSGP(AbstractKernel): Example: >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> X = jnp.linspace(-1, 1, 50)[:, None] >>> base = gpx.kernels.Matern52(n_dims=1) >>> hsgp = gpx.kernels.HSGP(base, num_basis_fns=20, domain_half_width=5.0) >>> K = hsgp.gram(X) # approximate Gram matrix diff --git a/tests/test_heteroscedastic.py b/tests/test_heteroscedastic.py index 94c99cf5e..f46cb73cd 100644 --- a/tests/test_heteroscedastic.py +++ b/tests/test_heteroscedastic.py @@ -167,11 +167,13 @@ def test_softplus_transform_numerical_accuracy(mean: float, variance: float, see # E[1/sigma^2] mc_inv_variance = jnp.mean(1.0 / transformed_samples) - # Allow for some MC error and quadrature approximation error + # Allow for some MC error and quadrature approximation error. + # atol covers near-zero values where rtol alone is unreliable. rtol = 0.15 - assert jnp.allclose(moments.variance, mc_variance, rtol=rtol) - assert jnp.allclose(moments.log_variance, mc_log_variance, rtol=rtol) - assert jnp.allclose(moments.inv_variance, mc_inv_variance, rtol=rtol) + atol = 0.01 + assert jnp.allclose(moments.variance, mc_variance, rtol=rtol, atol=atol) + assert jnp.allclose(moments.log_variance, mc_log_variance, rtol=rtol, atol=atol) + assert jnp.allclose(moments.inv_variance, mc_inv_variance, rtol=rtol, atol=atol) def test_heteroscedastic_variational_predict(prior, noise_prior, dataset): From 032fc68f329687776da1254226098d7d3051265c Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 20 Feb 2026 20:05:23 +0100 Subject: [PATCH 5/5] Fix notebook exposition --- examples/hsgp.py | 73 +++++++++++++++++++++--------------------------- pyproject.toml | 2 +- uv.lock | 3 +- 3 files changed, 35 insertions(+), 43 deletions(-) diff --git a/examples/hsgp.py b/examples/hsgp.py index d14a0c2c3..6f308a3f2 100644 --- a/examples/hsgp.py +++ b/examples/hsgp.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # --- # jupyter: # jupytext: @@ -135,18 +134,21 @@ # ## Approximation quality # # How faithfully does the HSGP Gram matrix reproduce the exact kernel? We -# compare exact and approximate Gram matrices at $m = 10$, $25$, and $50$ for -# both the RBF and Matern-5/2 kernels. +# compare the *approximation error* $|K - \tilde{K}|$ at $m = 10$, $25$, and +# $50$ for both the RBF and Matern-5/2 kernels. Plotting the error on a +# shared log scale makes the convergence difference clearly visible. # %% +from matplotlib.colors import LogNorm + x_small = jnp.linspace(-3.0, 3.0, 80)[:, None] m_values = [10, 25, 50] -def gram_comparison(base_kernel, x, m_values): - """Return exact Gram and HSGP Gram matrices for several values of m.""" +def gram_errors(base_kernel, x, m_values): + """Return absolute error matrices |K_exact - K_hsgp| for several values of m.""" gram_exact = base_kernel.gram(x).to_dense() - gram_approximations = [] + errors = [] for num_basis in m_values: hsgp = gpx.kernels.HSGP( base_kernel=base_kernel, @@ -154,49 +156,38 @@ def gram_comparison(base_kernel, x, m_values): domain_half_width=4.0, center=0.0, ) - gram_approximations.append(hsgp.gram(x).to_dense()) - return gram_exact, gram_approximations - - -gram_exact_rbf, gram_hsgps_rbf = gram_comparison(base_rbf, x_small, m_values) - -fig, axes = plt.subplots(1, 4, figsize=(10, 2.5)) -vmin, vmax = float(gram_exact_rbf.min()), float(gram_exact_rbf.max()) -titles = ["Exact"] + [f"HSGP ($m = {m}$)" for m in m_values] -matrices = [gram_exact_rbf, *gram_hsgps_rbf] + errors.append(jnp.abs(gram_exact - hsgp.gram(x).to_dense())) + return gram_exact, errors -for ax, matrix, title in zip(axes, matrices, titles, strict=True): - ax.imshow(matrix, vmin=vmin, vmax=vmax, cmap="inferno") - ax.set_title(title, fontsize=9) - ax.set_xticks([]) - ax.set_yticks([]) -fig.suptitle("RBF kernel", fontsize=10, y=1.02) +gram_exact_rbf, errors_rbf = gram_errors(base_rbf, x_small, m_values) +gram_exact_m52, errors_m52 = gram_errors(base_m52, x_small, m_values) -# %% [markdown] -# For the RBF kernel, the approximation is near-indistinguishable from exact -# by $m = 25$, reflecting the rapid spectral weight decay. - -# %% -gram_exact_m52, gram_hsgps_m52 = gram_comparison(base_m52, x_small, m_values) +log_norm = LogNorm(vmin=1e-6, vmax=1.0) +fig, axes = plt.subplots(2, 3, figsize=(8, 5), sharex=True, sharey=True) -fig, axes = plt.subplots(1, 4, figsize=(10, 2.5)) -vmin, vmax = float(gram_exact_m52.min()), float(gram_exact_m52.max()) -titles = ["Exact"] + [f"HSGP ($m = {m}$)" for m in m_values] -matrices = [gram_exact_m52, *gram_hsgps_m52] +for j, m in enumerate(m_values): + im = axes[0, j].imshow(errors_rbf[j], norm=log_norm, cmap="inferno") + axes[0, j].set_title(f"$m = {m}$", fontsize=9) + axes[0, j].set_xticks([]) + axes[0, j].set_yticks([]) -for ax, matrix, title in zip(axes, matrices, titles, strict=True): - ax.imshow(matrix, vmin=vmin, vmax=vmax, cmap="inferno") - ax.set_title(title, fontsize=9) - ax.set_xticks([]) - ax.set_yticks([]) + axes[1, j].imshow(errors_m52[j], norm=log_norm, cmap="inferno") + axes[1, j].set_xticks([]) + axes[1, j].set_yticks([]) -fig.suptitle("Matern-5/2 kernel", fontsize=10, y=1.02) +axes[0, 0].set_ylabel("RBF", fontsize=10) +axes[1, 0].set_ylabel("Matern-5/2", fontsize=10) +fig.colorbar(im, ax=axes, label="$|K - \\tilde{K}|$", shrink=0.8) +fig.suptitle("HSGP approximation error", fontsize=11, y=1.0) # %% [markdown] -# The Matern-5/2 converges more slowly, with visible discrepancies at -# $m = 10$. Rougher kernels retain more high-frequency content. In practice, -# $m$ between 20 and 50 suffices for most one-dimensional problems. +# The RBF error shrinks rapidly toward machine precision in the interior, +# reflecting the Gaussian spectral density's fast decay. The Matern-5/2 +# retains appreciable error at $m = 10$ across a wider region, consistent +# with its heavier spectral tail. By $m = 50$ both kernels are well +# approximated in the interior; the residual error along the edges is the +# unavoidable boundary artefact of the Dirichlet eigenbasis. # %% [markdown] # ## Regression with real data diff --git a/pyproject.toml b/pyproject.toml index ee91b5cca..dfb9f7655 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "gpjax" -dynamic = ["version"] +version = "0.13.6" description = 'Gaussian processes in JAX.' readme = "README.md" requires-python = ">=3.11" diff --git a/uv.lock b/uv.lock index 47ed38cd2..4042addeb 100644 --- a/uv.lock +++ b/uv.lock @@ -17,7 +17,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-02-11T12:41:49.179683Z" +exclude-newer = "2026-02-13T18:59:31.785479Z" exclude-newer-span = "P7D" [[package]] @@ -951,6 +951,7 @@ wheels = [ [[package]] name = "gpjax" +version = "0.13.6" source = { editable = "." } dependencies = [ { name = "beartype" },