diff --git a/examples/hsgp.py b/examples/hsgp.py new file mode 100644 index 000000000..6f308a3f2 --- /dev/null +++ b/examples/hsgp.py @@ -0,0 +1,348 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: .venv +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Hilbert Space Gaussian Processes +# +# Standard GP inference requires inverting an $n \times n$ Gram matrix at +# $\mathcal{O}(n^3)$ cost, which limits practical use to a few thousand points. +# The Hilbert Space Gaussian Process (HSGP) approximation +# ([Solin and Sarkka, 2020](https://link.springer.com/article/10.1007/s11222-019-09886-w)) +# projects the GP onto a truncated basis of Laplacian eigenfunctions, reducing +# the cost to $\mathcal{O}(nm^2)$ for setup and $\mathcal{O}(m^3)$ per +# optimisation step, where $m \ll n$ is the number of basis functions. See +# also the PyMC tutorial by +# [Orduz and Capretto](https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Basic.html) +# and the standalone introduction by +# [Orduz (2022)](https://juanitorduz.github.io/hsgp_intro/). + +# %% +from jax import config + +config.update("jax_enable_x64", True) + +from examples.utils import use_mpl_style +import jax.numpy as jnp +import jax.random as jr +from jaxtyping import install_import_hook +import matplotlib as mpl +import matplotlib.pyplot as plt + +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx + +key = jr.key(42) +use_mpl_style() +cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] + +# %% [markdown] +# ## The HSGP approximation +# +# For a stationary kernel $k$ with spectral density $S$, the HSGP approximates +# the covariance on a bounded domain $[-L, L]$: +# +# $$k(x, x') \approx \sum_{j=1}^{m} S\!\left(\sqrt{\lambda_j}\right) \phi_j(x)\,\phi_j(x').$$ +# +# The eigenfunctions $\phi_j(x) = L^{-1/2}\sin\!\bigl(j\pi(x + L) / 2L\bigr)$ +# are sinusoids of increasing frequency, and the eigenvalues +# $\lambda_j = (j\pi / 2L)^2$ grow quadratically. The spectral density $S$ +# weights each eigenfunction's contribution. In matrix form the approximate +# Gram matrix is $\tilde{K} = \Phi\,\Lambda\,\Phi^\top$ where +# $\Lambda = \mathrm{diag}\!\bigl(S(\sqrt{\lambda_1}), \ldots, S(\sqrt{\lambda_m})\bigr)$. +# +# Two parameters govern approximation quality: +# +# - **$m$** (number of basis functions): more terms capture shorter +# lengthscales, at higher computational cost. +# - **$L$** (domain half-width): must enclose all inputs after centering, +# but setting it much larger than necessary wastes basis functions. + +# %% [markdown] +# ## Eigenfunctions and spectral weights +# +# The HSGP has two ingredients: Laplacian eigenfunctions (which depend only on +# the domain, not the kernel) and spectral weights (which encode the kernel's +# properties). We visualise both below. + +# %% +x_grid = jnp.linspace(-3.0, 3.0, 300)[:, None] + +base_rbf = gpx.kernels.RBF(n_dims=1) +hsgp_rbf = gpx.kernels.HSGP( + base_kernel=base_rbf, num_basis_fns=20, domain_half_width=4.0, center=0.0 +) + +phi = hsgp_rbf.eigenfunctions(x_grid) + +fig, ax = plt.subplots(figsize=(7.5, 2.5)) +for j in range(6): + ax.plot(x_grid, phi[:, j], label=f"$\\phi_{{{j + 1}}}$", color=cols[j % len(cols)]) +ax.set_xlabel("$x$") +ax.set_ylabel("$\\phi_j(x)$") +ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5), fontsize=8) +ax.set_title("Laplacian eigenfunctions") + +# %% [markdown] +# The eigenfunctions are sinusoids of increasing frequency, forming an +# orthonormal basis on $[-L, L]$. Because they are independent of the kernel +# hyperparameters, they need only be evaluated once for a given domain. +# +# The spectral weights $S(\sqrt{\lambda_j})$ determine each eigenfunction's +# contribution. We compare weights for the RBF kernel (Gaussian spectral +# density, rapid decay) and the Matern-5/2 (heavier-tailed, slower decay). + +# %% +base_m52 = gpx.kernels.Matern52(n_dims=1) +hsgp_m52 = gpx.kernels.HSGP( + base_kernel=base_m52, num_basis_fns=20, domain_half_width=4.0, center=0.0 +) + +weights_rbf = hsgp_rbf.spectral_weights() +weights_m52 = hsgp_m52.spectral_weights() + +fig, axes = plt.subplots(1, 2, figsize=(7.5, 2.5), sharey=True) +basis_indices = jnp.arange(1, 21) + +axes[0].stem(basis_indices, weights_rbf, linefmt=cols[0], markerfmt="o", basefmt=" ") +axes[0].set_title("RBF") +axes[0].set_xlabel("Basis index $j$") +axes[0].set_ylabel("$S(\\sqrt{\\lambda_j})$") + +axes[1].stem(basis_indices, weights_m52, linefmt=cols[1], markerfmt="o", basefmt=" ") +axes[1].set_title("Matern-5/2") +axes[1].set_xlabel("Basis index $j$") + +# %% [markdown] +# The RBF weights decay rapidly, so few basis functions suffice. The +# Matern-5/2 retains appreciable weight at higher frequencies and needs more +# terms for the same approximation quality. + +# %% [markdown] +# ## Approximation quality +# +# How faithfully does the HSGP Gram matrix reproduce the exact kernel? We +# compare the *approximation error* $|K - \tilde{K}|$ at $m = 10$, $25$, and +# $50$ for both the RBF and Matern-5/2 kernels. Plotting the error on a +# shared log scale makes the convergence difference clearly visible. + +# %% +from matplotlib.colors import LogNorm + +x_small = jnp.linspace(-3.0, 3.0, 80)[:, None] +m_values = [10, 25, 50] + + +def gram_errors(base_kernel, x, m_values): + """Return absolute error matrices |K_exact - K_hsgp| for several values of m.""" + gram_exact = base_kernel.gram(x).to_dense() + errors = [] + for num_basis in m_values: + hsgp = gpx.kernels.HSGP( + base_kernel=base_kernel, + num_basis_fns=num_basis, + domain_half_width=4.0, + center=0.0, + ) + errors.append(jnp.abs(gram_exact - hsgp.gram(x).to_dense())) + return gram_exact, errors + + +gram_exact_rbf, errors_rbf = gram_errors(base_rbf, x_small, m_values) +gram_exact_m52, errors_m52 = gram_errors(base_m52, x_small, m_values) + +log_norm = LogNorm(vmin=1e-6, vmax=1.0) +fig, axes = plt.subplots(2, 3, figsize=(8, 5), sharex=True, sharey=True) + +for j, m in enumerate(m_values): + im = axes[0, j].imshow(errors_rbf[j], norm=log_norm, cmap="inferno") + axes[0, j].set_title(f"$m = {m}$", fontsize=9) + axes[0, j].set_xticks([]) + axes[0, j].set_yticks([]) + + axes[1, j].imshow(errors_m52[j], norm=log_norm, cmap="inferno") + axes[1, j].set_xticks([]) + axes[1, j].set_yticks([]) + +axes[0, 0].set_ylabel("RBF", fontsize=10) +axes[1, 0].set_ylabel("Matern-5/2", fontsize=10) +fig.colorbar(im, ax=axes, label="$|K - \\tilde{K}|$", shrink=0.8) +fig.suptitle("HSGP approximation error", fontsize=11, y=1.0) + +# %% [markdown] +# The RBF error shrinks rapidly toward machine precision in the interior, +# reflecting the Gaussian spectral density's fast decay. The Matern-5/2 +# retains appreciable error at $m = 10$ across a wider region, consistent +# with its heavier spectral tail. By $m = 50$ both kernels are well +# approximated in the interior; the residual error along the edges is the +# unavoidable boundary artefact of the Dirichlet eigenbasis. + +# %% [markdown] +# ## Regression with real data +# +# We apply the HSGP to daily nitrogen dioxide (NO$_2$) measurements from +# Marylebone Road in central London, one of the UK's most heavily trafficked +# roadside sites. The data come from the +# [Automatic Urban and Rural Network (AURN)](https://uk-air.defra.gov.uk/networks/site-info?site_id=MY1). +# We aggregate to daily means over approximately ten years (2016--2025), +# giving several thousand data points where exact GP inference is impractical. + +# %% +from pathlib import Path +import tempfile +from urllib.request import urlretrieve +import warnings + +import pandas as pd +import rdata + +SITE = "MY1" +YEARS = range(2016, 2026) + +frames = [] +for year in YEARS: + url = f"https://uk-air.defra.gov.uk/openair/R_data/{SITE}_{year}.RData" + tmp = Path(tempfile.gettempdir()) / f"{SITE}_{year}.RData" + urlretrieve(url, tmp) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + parsed = rdata.read_rda(str(tmp)) + frames.append(parsed[next(iter(parsed.keys()))]) + +hourly = pd.concat(frames, ignore_index=True) +hourly["date"] = pd.to_datetime(hourly["date"], unit="s") +daily = hourly[["date", "NO2"]].set_index("date").resample("D").mean().dropna() + +# %% +fig, ax = plt.subplots(figsize=(7.5, 2.5)) +ax.plot(daily.index, daily["NO2"], linewidth=0.5, color=cols[0]) +ax.set_xlabel("Date") +ax.set_ylabel("NO$_2$ ($\\mu$g m$^{-3}$)") +ax.set_title("Daily mean NO$_2$ at Marylebone Road, London") + +# %% [markdown] +# Clear features include an annual cycle from heating-season emissions, an +# abrupt dip during the 2020 COVID-19 lockdowns, and a gradual downward trend +# from tightening vehicle emission standards. + +# %% +# Convert dates to a numeric index and standardise for numerical stability. +t = (daily.index - daily.index[0]).days.values.astype(float) +y_obs = daily["NO2"].values + +t_mean, t_std = t.mean(), t.std() +y_mean, y_std = y_obs.mean(), y_obs.std() + +t_norm = ((t - t_mean) / t_std)[:, None] +y_norm = ((y_obs - y_mean) / y_std)[:, None] + +D = gpx.Dataset(X=t_norm, y=y_norm) + +# %% [markdown] +# We build a GP prior with an HSGP kernel using a Matern-3/2 base kernel. We +# use $m = 30$ basis functions and set $L$ to 1.2 times the half-range of +# the normalised inputs, providing a modest buffer beyond the data boundary. + +# %% +data_half_range = float((t_norm.max() - t_norm.min()) / 2.0) +L = 1.2 * data_half_range + +base_kernel = gpx.kernels.Matern32(n_dims=1) +hsgp_kernel = gpx.kernels.HSGP( + base_kernel=base_kernel, + num_basis_fns=30, + domain_half_width=L, + center=float(t_norm.mean()), +) + +prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=hsgp_kernel) +likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n) +posterior = prior * likelihood + +# %% [markdown] +# We optimise the negative conjugate marginal log-likelihood using L-BFGS-B, +# learning the kernel lengthscale, variance, and observation noise. + +# %% +opt_posterior, history = gpx.fit_scipy( + model=posterior, + objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d), + train_data=D, + trainable=gpx.parameters.Parameter, +) + +print( + f"Learned lengthscale: {opt_posterior.prior.kernel.base_kernel.lengthscale[...]:.4f}" +) +print( + f"Learned variance: {opt_posterior.prior.kernel.base_kernel.variance[...]:.4f}" +) +print(f"Learned obs. noise: {opt_posterior.likelihood.obs_stddev[...] ** 2:.4f}") + +# %% [markdown] +# With optimised hyperparameters, we compute the posterior predictive on a +# fine grid spanning the training domain. + +# %% +t_test = jnp.linspace(t_norm.min() - 0.05, t_norm.max() + 0.05, 500)[:, None] + +latent_dist = opt_posterior.predict( + t_test, train_data=D, return_covariance_type="diagonal" +) +predictive_dist = opt_posterior.likelihood(latent_dist) + +pred_mean = predictive_dist.mean * y_std + y_mean +pred_std = jnp.sqrt(predictive_dist.variance) * y_std + +# Convert test grid back to dates for plotting. +t_test_days = (t_test.squeeze() * t_std + t_mean).astype(int) +dates_test = daily.index[0] + pd.to_timedelta(t_test_days, unit="D") + +# %% +fig, ax = plt.subplots(figsize=(7.5, 3.0)) +ax.scatter(daily.index, y_obs, s=1, alpha=0.3, color=cols[0], label="Observations") +ax.plot(dates_test, pred_mean, color=cols[1], linewidth=1.5, label="Predictive mean") +ax.fill_between( + dates_test, + pred_mean - 2 * pred_std, + pred_mean + 2 * pred_std, + alpha=0.2, + color=cols[1], + label="Two sigma", +) +ax.set_xlabel("Date") +ax.set_ylabel("NO$_2$ ($\\mu$g m$^{-3}$)") +ax.legend(loc="upper right", fontsize=8) +ax.set_title("HSGP posterior for daily NO$_2$ at Marylebone Road") + +# %% [markdown] +# The HSGP posterior captures seasonal structure and the long-term downward +# trend, with credible intervals that widen where data are sparser. The +# computation completed in seconds on a single CPU; an exact GP over the same +# dataset would require forming and decomposing a dense Gram matrix with over +# ten million entries. + +# %% [markdown] +# ## References +# +# - Solin, A. and Sarkka, S. (2020). +# [Hilbert Space Methods for Reduced-Rank Gaussian Process Regression](https://link.springer.com/article/10.1007/s11222-019-09886-w). +# *Statistics and Computing*. +# - Orduz, J. and Martin, O. A. +# [PyMC HSGP tutorial](https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Basic.html). +# - Orduz, J. (2022). +# [Introduction to the HSGP](https://juanitorduz.github.io/hsgp_intro/). +# - UK [Automatic Urban and Rural Network](https://uk-air.defra.gov.uk/networks/site-info?site_id=MY1) +# air quality data via the openair data service. diff --git a/gpjax/gps.py b/gpjax/gps.py index ae25c8777..bd813aee5 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -46,7 +46,9 @@ from gpjax.linalg.operations import ( lower_cholesky, ) +from gpjax.linalg.operators import LowRank from gpjax.linalg.utils import add_jitter +from gpjax.linalg.woodbury import woodbury_solve from gpjax.mean_functions import AbstractMeanFunction from gpjax.parameters import ( Parameter, @@ -380,7 +382,9 @@ def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]: ####################### # GP Posteriors -#######################from gpjax.linalg.operators import LinearOperator +####################### + + class AbstractPosterior(nnx.Module, tp.Generic[P, L]): r"""Abstract Gaussian process posterior. @@ -587,60 +591,151 @@ def predict( kernel = self.prior.kernel x, y = train_data.X, train_data.y - P = self.likelihood.num_outputs + num_outputs = self.likelihood.num_outputs - # Prepare targets via likelihood protocol (identity for single-output, - # output-major reshape for multi-output) mx = self.prior.mean_function(x) y_flat, mx_flat = self.likelihood.prepare_targets(y, mx) noise = self.likelihood.noise_vector(train_data.n) Kxx = kernel.gram(x) - Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter) - Sigma_dense = Kxx_dense + jnp.diag(noise) - Sigma = psd(Dense(Sigma_dense)) - L_sigma = lower_cholesky(Sigma) - Kxt = kernel.cross_covariance(x, test_inputs) - L_inv_Kxt = solve(L_sigma, Kxt) - L_inv_y_diff = solve(L_sigma, y_flat - mx_flat) - mean_t_raw = self.prior.mean_function(test_inputs) - mean_t = jnp.tile(mean_t_raw, (P, 1)) if P > 1 else mean_t_raw - mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff) + # Posterior test-point mean: m(x*) + Kxt^T Sigma^{-1} (y - m(x)) + mean_test_raw = self.prior.mean_function(test_inputs) + mean_test = ( + jnp.tile(mean_test_raw, (num_outputs, 1)) + if num_outputs > 1 + else mean_test_raw + ) - # Diagonal covariance not yet supported for multi-output - if return_covariance_type == "diagonal" and P > 1: + if return_covariance_type == "diagonal" and num_outputs > 1: warnings.warn( - "Diagonal covariance is not yet supported for multi-output GPs. " - "Returning full covariance.", + "Diagonal covariance is not yet supported for " + "multi-output GPs. Returning full covariance.", stacklevel=2, ) return_covariance_type = "dense" - def _return_full_covariance(L_inv_Kxt, t): - Ktt = kernel.gram(t) - covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt) - covariance = add_jitter(covariance, self.prior.jitter) - covariance = psd(Dense(covariance)) - return covariance + if isinstance(Kxx, LowRank): + mean, cov = self._predict_lowrank( + kernel, + Kxx, + Kxt, + y_flat, + mx_flat, + noise, + mean_test, + test_inputs, + return_covariance_type, + ) + else: + mean, cov = self._predict_dense( + kernel, + Kxx, + Kxt, + y_flat, + mx_flat, + noise, + mean_test, + test_inputs, + return_covariance_type, + ) - def _return_diagonal_covariance(L_inv_Kxt, t): - Ktt = kernel.diagonal(t).diagonal - covariance = Ktt - jnp.einsum("ij, ji->i", L_inv_Kxt.T, L_inv_Kxt) + return GaussianDistribution(loc=jnp.atleast_1d(mean.squeeze()), scale=cov) + + def _predict_lowrank( + self, + kernel, + Kxx, + Kxt, + y_flat, + mx_flat, + noise, + mean_test, + test_inputs, + return_covariance_type, + ): + r"""Posterior prediction using the Woodbury identity. + + When K_xx = W W^T (low-rank), the marginal covariance is + Sigma = W W^T + D with D = diag(noise + jitter). The Woodbury + identity gives Sigma^{-1} in O(N m^2) instead of O(N^3). + """ + W = Kxx.factor + noise_with_jitter = noise + self.jitter + + # alpha = Sigma^{-1} (y - m(x)) + alpha = woodbury_solve(W, noise_with_jitter, (y_flat - mx_flat).squeeze()) + mean = mean_test + jnp.matmul(Kxt.T, alpha[:, None]) + + # Sigma^{-1} Kxt for the covariance update + Sigma_inv_Kxt = woodbury_solve(W, noise_with_jitter, Kxt) + + def _full_covariance(Sigma_inv_Kxt, test_inputs): + Ktt = kernel.gram(test_inputs).to_dense() + covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt) + return psd(Dense(add_jitter(covariance, self.prior.jitter))) + + def _diagonal_covariance(Sigma_inv_Kxt, test_inputs): + Ktt_diag = kernel.diagonal(test_inputs).diagonal + covariance = Ktt_diag - jnp.einsum("ij,ij->j", Kxt, Sigma_inv_Kxt) covariance += self.prior.jitter - covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze())))) - return covariance + return psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze())))) cov = jax.lax.cond( return_covariance_type == "dense", - _return_full_covariance, - _return_diagonal_covariance, + _full_covariance, + _diagonal_covariance, + Sigma_inv_Kxt, + test_inputs, + ) + + return mean, cov + + def _predict_dense( + self, + kernel, + Kxx, + Kxt, + y_flat, + mx_flat, + noise, + mean_test, + test_inputs, + return_covariance_type, + ): + r"""Posterior prediction via the standard Cholesky decomposition. + + Sigma = K_xx + D, L L^T = Sigma, then solves use L^{-1}. + """ + Sigma_dense = add_jitter(Kxx.to_dense(), self.jitter) + jnp.diag(noise) + L_sigma = lower_cholesky(psd(Dense(Sigma_dense))) + + L_inv_Kxt = solve(L_sigma, Kxt) + L_inv_diff = solve(L_sigma, y_flat - mx_flat) + + mean = mean_test + jnp.matmul(L_inv_Kxt.T, L_inv_diff) + + def _full_covariance(L_inv_Kxt, test_inputs): + Ktt = kernel.gram(test_inputs).to_dense() + covariance = Ktt - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt) + return psd(Dense(add_jitter(covariance, self.prior.jitter))) + + def _diagonal_covariance(L_inv_Kxt, test_inputs): + Ktt_diag = kernel.diagonal(test_inputs).diagonal + covariance = Ktt_diag - jnp.einsum("ij, ji->i", L_inv_Kxt.T, L_inv_Kxt) + covariance += self.prior.jitter + return psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze())))) + + cov = jax.lax.cond( + return_covariance_type == "dense", + _full_covariance, + _diagonal_covariance, L_inv_Kxt, test_inputs, ) - return GaussianDistribution(loc=jnp.atleast_1d(mean.squeeze()), scale=cov) + return mean, cov def sample_approx( self, diff --git a/gpjax/kernels/__init__.py b/gpjax/kernels/__init__.py index 5ad342b86..cf8b06042 100644 --- a/gpjax/kernels/__init__.py +++ b/gpjax/kernels/__init__.py @@ -19,7 +19,7 @@ from gpjax.kernels.additive import ( OrthogonalAdditiveKernel, ) -from gpjax.kernels.approximations import RFF +from gpjax.kernels.approximations import HSGP, RFF from gpjax.kernels.base import ( AbstractKernel, Constant, @@ -32,6 +32,7 @@ DenseKernelComputation, DiagonalKernelComputation, EigenKernelComputation, + HSGPComputation, ) from gpjax.kernels.multioutput import ( ICMKernel, @@ -57,6 +58,7 @@ ) __all__ = [ + "HSGP", "RBF", "RFF", "AbstractKernel", @@ -68,6 +70,7 @@ "DiagonalKernelComputation", "EigenKernelComputation", "GraphKernel", + "HSGPComputation", "ICMKernel", "LCMKernel", "Linear", diff --git a/gpjax/kernels/approximations/__init__.py b/gpjax/kernels/approximations/__init__.py index 9b08299e5..1f6a804e9 100644 --- a/gpjax/kernels/approximations/__init__.py +++ b/gpjax/kernels/approximations/__init__.py @@ -1,3 +1,4 @@ +from gpjax.kernels.approximations.hsgp import HSGP from gpjax.kernels.approximations.rff import RFF -__all__ = ["RFF"] +__all__ = ["HSGP", "RFF"] diff --git a/gpjax/kernels/approximations/hsgp.py b/gpjax/kernels/approximations/hsgp.py new file mode 100644 index 000000000..3c3b977bf --- /dev/null +++ b/gpjax/kernels/approximations/hsgp.py @@ -0,0 +1,175 @@ +"""Hilbert Space Gaussian Process (HSGP) kernel approximation. + +Reference: + Solin & Sarkka (2019). "Hilbert Space Methods for Reduced-Rank + Gaussian Process Regression." Statistics and Computing. +""" + +import beartype.typing as tp +import jax.numpy as jnp +from jaxtyping import Float + +from gpjax.kernels.base import AbstractKernel +from gpjax.kernels.computations.hsgp import HSGPComputation +from gpjax.kernels.stationary.base import StationaryKernel +from gpjax.typing import Array + + +class HSGP(AbstractKernel): + r"""Hilbert Space Gaussian Process approximation (1D). + + Approximates a stationary kernel by projecting onto the eigenbasis of the + Laplacian operator with Dirichlet boundary conditions on :math:`[-L, L]`. + The approximate covariance function is: + + .. math:: + \tilde{k}(x, x') = \sum_{j=1}^{m} S(\sqrt{\lambda_j})\, + \phi_j(x)\,\phi_j(x') + + where :math:`\lambda_j = (j\pi / 2L)^2` are the eigenvalues, + :math:`\phi_j(x) = L^{-1/2}\sin(j\pi(x + L) / 2L)` are the + eigenfunctions, and :math:`S` is the spectral density of the base + kernel. + + The linearised form decomposes the GP function as: + + .. math:: + f(x) \approx \Phi(x)\,\mathrm{diag}(\sqrt{S})\,\beta, + \quad \beta \sim \mathcal{N}(0, I_m) + + which is a Bayesian linear model with :math:`m` basis functions. + + Args: + base_kernel: Stationary kernel to approximate. Must provide a + ``spectral_density`` property returning a + :class:`~gpjax.kernels.stationary.utils.SpectralDensity`. + num_basis_fns: Number of basis functions :math:`m`. + domain_half_width: Half-width :math:`L` of the approximation domain. + Inputs should lie well inside :math:`[-L, L]` (after centering). + center: Center of the data domain. If ``None``, it is set + automatically on the first call to :meth:`eigenfunctions` as the + midpoint of the observed input range. + compute_engine: Computation engine (default + :class:`~gpjax.kernels.computations.hsgp.HSGPComputation`). + + Example: + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> X = jnp.linspace(-1, 1, 50)[:, None] + >>> base = gpx.kernels.Matern52(n_dims=1) + >>> hsgp = gpx.kernels.HSGP(base, num_basis_fns=20, domain_half_width=5.0) + >>> K = hsgp.gram(X) # approximate Gram matrix + """ + + compute_engine: HSGPComputation + + def __init__( + self, + base_kernel: StationaryKernel, + num_basis_fns: int, + domain_half_width: float, + center: tp.Union[float, None] = None, + compute_engine: HSGPComputation = HSGPComputation(), + ): + if not isinstance(base_kernel, StationaryKernel): + raise TypeError( + "HSGP can only approximate stationary kernels. " + f"Got {type(base_kernel).__name__}." + ) + _ = base_kernel.spectral_density + + super().__init__( + active_dims=base_kernel.active_dims, + n_dims=base_kernel.n_dims, + compute_engine=compute_engine, + ) + self.base_kernel = base_kernel + self.num_basis_fns = num_basis_fns + self.domain_half_width = domain_half_width + self._center = center + self.name = f"{self.base_kernel.name} (HSGP)" + + def eigenvalues(self) -> Float[Array, " m"]: + r"""Square roots of the Laplacian eigenvalues. + + .. math:: + \sqrt{\lambda_j} = \frac{j\pi}{2L}, \quad j = 1, \ldots, m + + Returns: + Array of shape ``(m,)``. + """ + indices = jnp.arange(1, self.num_basis_fns + 1) + return indices * jnp.pi / (2.0 * self.domain_half_width) + + def eigenfunctions(self, x: Float[Array, "N 1"]) -> Float[Array, "N m"]: + r"""Laplacian eigenfunctions evaluated at *x*. + + .. math:: + \phi_j(x) = \frac{1}{\sqrt{L}} + \sin\!\Bigl(\frac{j\pi\,(x + L)}{2L}\Bigr) + + where *x* is first shifted by the stored center. + + Args: + x: Input locations of shape ``(N, 1)``. + + Returns: + Basis matrix :math:`\Phi` of shape ``(N, m)``. + """ + if self._center is None: + self._center = float((x.max() + x.min()) / 2.0) + + half_width = self.domain_half_width + x_centered = x - self._center + sqrt_eigenvalues = self.eigenvalues() + return jnp.sin((x_centered + half_width) * sqrt_eigenvalues) / jnp.sqrt( + half_width + ) + + def spectral_weights(self) -> Float[Array, " m"]: + r"""Spectral density evaluated at the eigenvalue square roots. + + .. math:: + S(\sqrt{\lambda_j}) + + Returns: + Array of shape ``(m,)`` with the spectral weights. + """ + omega = self.eigenvalues() + return self.base_kernel.spectral_density( + omega, + self.base_kernel.variance[...], + self.base_kernel.lengthscale[...], + ) + + def compute_basis( + self, x: Float[Array, "N 1"] + ) -> tuple[Float[Array, "N m"], Float[Array, " m"]]: + r"""Linearised HSGP decomposition. + + Returns :math:`(\Phi, \sqrt{S})` such that + + .. math:: + f(x) \approx \Phi\,\mathrm{diag}(\sqrt{S})\,\beta, + \quad \beta \sim \mathcal{N}(0, I_m). + + Args: + x: Input locations of shape ``(N, 1)``. + + Returns: + Tuple ``(phi, sqrt_spectral_weights)`` where ``phi`` has shape + ``(N, m)`` and ``sqrt_spectral_weights`` has shape ``(m,)``. + """ + phi = self.eigenfunctions(x) + sqrt_spectral_weights = jnp.sqrt(self.spectral_weights()) + return phi, sqrt_spectral_weights + + def __call__( + self, + x: Float[Array, " D"], + y: Float[Array, " D"], + ) -> None: + raise RuntimeError( + "HSGP does not support pointwise kernel evaluation. " + "Use gram(), cross_covariance(), or compute_basis() instead." + ) diff --git a/gpjax/kernels/computations/__init__.py b/gpjax/kernels/computations/__init__.py index 60eb5a1b3..1de8b5f5f 100644 --- a/gpjax/kernels/computations/__init__.py +++ b/gpjax/kernels/computations/__init__.py @@ -21,6 +21,7 @@ from gpjax.kernels.computations.dense import DenseKernelComputation from gpjax.kernels.computations.diagonal import DiagonalKernelComputation from gpjax.kernels.computations.eigen import EigenKernelComputation +from gpjax.kernels.computations.hsgp import HSGPComputation __all__ = [ "AbstractKernelComputation", @@ -29,4 +30,5 @@ "DenseKernelComputation", "DiagonalKernelComputation", "EigenKernelComputation", + "HSGPComputation", ] diff --git a/gpjax/kernels/computations/basis_functions.py b/gpjax/kernels/computations/basis_functions.py index d4a06c2a2..07c8818d8 100644 --- a/gpjax/kernels/computations/basis_functions.py +++ b/gpjax/kernels/computations/basis_functions.py @@ -7,67 +7,69 @@ from gpjax.kernels.computations.base import AbstractKernelComputation from gpjax.linalg import ( Diagonal, + LowRank, ) +from gpjax.linalg.utils import psd from gpjax.typing import Array K = tp.TypeVar("K", bound="gpjax.kernels.approximations.RFF") -# TODO: Use low rank linear operator! - class BasisFunctionComputation(AbstractKernelComputation): - r"""Compute engine class for finite basis function approximations to a kernel.""" + r"""Compute engine for finite basis function approximations (RFF).""" + + def gram(self, kernel: K, x: Float[Array, "N D"]) -> LowRank: + features = self.compute_features(kernel, x) + weighted_features = features * jnp.sqrt(self.scaling(kernel)) + return psd(LowRank(weighted_features)) def _cross_covariance( self, kernel: K, x: Float[Array, "N D"], y: Float[Array, "M D"] ) -> Float[Array, "N M"]: - z1 = self.compute_features(kernel, x) - z2 = self.compute_features(kernel, y) - return self.scaling(kernel) * jnp.matmul(z1, z2.T) + features_x = self.compute_features(kernel, x) + features_y = self.compute_features(kernel, y) + return self.scaling(kernel) * jnp.matmul(features_x, features_y.T) def _gram(self, kernel: K, inputs: Float[Array, "N D"]) -> Float[Array, "N N"]: - z1 = self.compute_features(kernel, inputs) - return self.scaling(kernel) * jnp.matmul(z1, z1.T) + features = self.compute_features(kernel, inputs) + return self.scaling(kernel) * jnp.matmul(features, features.T) def diagonal(self, kernel: K, inputs: Float[Array, "N D"]) -> Diagonal: - r"""For a given kernel, compute the elementwise diagonal of the - NxN gram matrix on an input matrix of shape NxD. + r"""Diagonal of the approximate Gram matrix. Args: - kernel (AbstractKernel): the kernel function. - inputs (Float[Array, "N D"]): The input matrix. + kernel: The RFF kernel. + inputs: Input matrix of shape ``(N, D)``. - Returns - ------- - Diagonal: The computed diagonal variance entries. + Returns: + Diagonal variance entries. """ return super().diagonal(kernel.base_kernel, inputs) def compute_features( self, kernel: K, x: Float[Array, "N D"] ) -> Float[Array, "N L"]: - r"""Compute the features for the inputs. + r"""Compute random Fourier features :math:`[\cos(z), \sin(z)]`. Args: - kernel: the kernel function. - x: the inputs to the kernel function of shape `(N, D)`. + kernel: The RFF kernel. + x: Inputs of shape ``(N, D)``. Returns: - A matrix of shape $N \times L$ representing the random fourier features where $L = 2M$. + Feature matrix of shape ``(N, 2M)`` where ``M = num_basis_fns``. """ frequencies = kernel.frequencies - scaling_factor = kernel.base_kernel.lengthscale[...] - z = jnp.matmul(x, (frequencies / scaling_factor).T) - z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1) - return z + lengthscale = kernel.base_kernel.lengthscale[...] + projected = jnp.matmul(x, (frequencies / lengthscale).T) + return jnp.concatenate([jnp.cos(projected), jnp.sin(projected)], axis=-1) def scaling(self, kernel: K) -> Float[Array, ""]: - r"""Compute the scaling factor for the covariance matrix. + r"""Variance scaling factor: :math:`\sigma^2 / M`. Args: - kernel: the kernel function. + kernel: The RFF kernel. Returns: - A scalar array representing the scaling factor. + Scalar scaling factor. """ return kernel.base_kernel.variance[...] / kernel.num_basis_fns diff --git a/gpjax/kernels/computations/hsgp.py b/gpjax/kernels/computations/hsgp.py new file mode 100644 index 000000000..8baadb90f --- /dev/null +++ b/gpjax/kernels/computations/hsgp.py @@ -0,0 +1,52 @@ +"""Compute engine for the Hilbert Space GP approximation.""" + +import typing as tp + +import jax.numpy as jnp +from jaxtyping import Float + +import gpjax +from gpjax.kernels.computations.base import AbstractKernelComputation +from gpjax.linalg import Diagonal, LowRank +from gpjax.linalg.utils import psd +from gpjax.typing import Array + +K = tp.TypeVar("K", bound="gpjax.kernels.approximations.HSGP") + + +class HSGPComputation(AbstractKernelComputation): + r"""Compute engine for the HSGP kernel approximation. + + Builds the Gram matrix from the low-rank decomposition + :math:`\tilde{K} = \Phi \Lambda \Phi^\top` where + :math:`\Phi` contains Laplacian eigenfunctions and + :math:`\Lambda = \mathrm{diag}(S(\sqrt{\lambda_1}), \ldots, + S(\sqrt{\lambda_m}))`. + """ + + def _weighted_basis( + self, kernel: K, x: Float[Array, "N D"] + ) -> tuple[Float[Array, "N m"], Float[Array, "N m"]]: + r"""Return ``(phi, weighted_phi)`` where ``weighted_phi = phi * sqrt(S)``.""" + phi, sqrt_spectral_weights = kernel.compute_basis(x) + weighted_phi = phi * sqrt_spectral_weights[None, :] + return phi, weighted_phi + + def gram(self, kernel: K, x: Float[Array, "N D"]) -> LowRank: + _, weighted_phi = self._weighted_basis(kernel, x) + return psd(LowRank(weighted_phi)) + + def _gram(self, kernel: K, x: Float[Array, "N D"]) -> Float[Array, "N N"]: + _, weighted_phi = self._weighted_basis(kernel, x) + return weighted_phi @ weighted_phi.T + + def _cross_covariance( + self, kernel: K, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + _, weighted_phi_x = self._weighted_basis(kernel, x) + _, weighted_phi_y = self._weighted_basis(kernel, y) + return weighted_phi_x @ weighted_phi_y.T + + def diagonal(self, kernel: K, x: Float[Array, "N D"]) -> Diagonal: + _, weighted_phi = self._weighted_basis(kernel, x) + return psd(Diagonal(jnp.sum(weighted_phi**2, axis=1))) diff --git a/gpjax/kernels/stationary/base.py b/gpjax/kernels/stationary/base.py index 5de7572e2..9108abc56 100644 --- a/gpjax/kernels/stationary/base.py +++ b/gpjax/kernels/stationary/base.py @@ -18,13 +18,13 @@ from flax import nnx import jax.numpy as jnp from jaxtyping import Float -import numpyro.distributions as npd from gpjax.kernels.base import AbstractKernel from gpjax.kernels.computations import ( AbstractKernelComputation, DenseKernelComputation, ) +from gpjax.kernels.stationary.utils import SpectralDensity from gpjax.parameters import ( NonNegativeReal, PositiveReal, @@ -95,11 +95,13 @@ def __init__( self.variance = tp.cast(NonNegativeReal[ScalarFloat], self.variance) @property - def spectral_density(self) -> npd.Normal | npd.StudentT: - r"""The spectral density of the kernel. + def spectral_density(self) -> SpectralDensity: + r"""Spectral density :math:`S(\omega)` of this kernel. - Returns: - Callable[[Float[Array, "D"]], Float[Array, "D"]]: The spectral density function. + Subclasses override this to return a + :class:`~gpjax.kernels.stationary.utils.SpectralDensity` that + supports ``sample()`` (for RFF) and + ``__call__(omega, variance, lengthscale)`` (for HSGP). """ raise NotImplementedError( f"Kernel {self.name} does not have a spectral density." diff --git a/gpjax/kernels/stationary/matern12.py b/gpjax/kernels/stationary/matern12.py index 66f992ae1..e452f8144 100644 --- a/gpjax/kernels/stationary/matern12.py +++ b/gpjax/kernels/stationary/matern12.py @@ -15,10 +15,10 @@ import jax.numpy as jnp from jaxtyping import Float -import numpyro.distributions as npd from gpjax.kernels.stationary.base import StationaryKernel from gpjax.kernels.stationary.utils import ( + SpectralDensity, build_student_t_distribution, euclidean_distance, ) @@ -48,5 +48,14 @@ def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: return K.squeeze() @property - def spectral_density(self) -> npd.StudentT: - return build_student_t_distribution(nu=1) + def spectral_density(self) -> SpectralDensity: + r"""Matern-1/2 spectral density. + + .. math:: + S(\omega) = \sigma^2 \,\frac{2/\ell}{1/\ell^2 + \omega^2} + """ + + def _evaluate(omega, variance, lengthscale): + return variance * (2.0 / lengthscale) / (1.0 / lengthscale**2 + omega**2) + + return SpectralDensity(build_student_t_distribution(nu=1), _evaluate) diff --git a/gpjax/kernels/stationary/matern32.py b/gpjax/kernels/stationary/matern32.py index e47e031d9..a288f3ac4 100644 --- a/gpjax/kernels/stationary/matern32.py +++ b/gpjax/kernels/stationary/matern32.py @@ -15,10 +15,10 @@ import jax.numpy as jnp from jaxtyping import Float -import numpyro.distributions as npd from gpjax.kernels.stationary.base import StationaryKernel from gpjax.kernels.stationary.utils import ( + SpectralDensity, build_student_t_distribution, euclidean_distance, ) @@ -54,5 +54,17 @@ def __call__( return K.squeeze() @property - def spectral_density(self) -> npd.StudentT: - return build_student_t_distribution(nu=3) + def spectral_density(self) -> SpectralDensity: + r"""Matern-3/2 spectral density. + + .. math:: + S(\omega) = \sigma^2 \, + \frac{4\,\alpha^3}{(3/\ell^2 + \omega^2)^2}, + \quad \alpha = \sqrt{3}/\ell + """ + + def _evaluate(omega, variance, lengthscale): + alpha = jnp.sqrt(3.0) / lengthscale + return variance * 4.0 * alpha**3 / (3.0 / lengthscale**2 + omega**2) ** 2 + + return SpectralDensity(build_student_t_distribution(nu=3), _evaluate) diff --git a/gpjax/kernels/stationary/matern52.py b/gpjax/kernels/stationary/matern52.py index 84ca61069..1c1e7c943 100644 --- a/gpjax/kernels/stationary/matern52.py +++ b/gpjax/kernels/stationary/matern52.py @@ -15,10 +15,10 @@ import jax.numpy as jnp from jaxtyping import Float -import numpyro.distributions as npd from gpjax.kernels.stationary.base import StationaryKernel from gpjax.kernels.stationary.utils import ( + SpectralDensity, build_student_t_distribution, euclidean_distance, ) @@ -53,5 +53,22 @@ def __call__( return K.squeeze() @property - def spectral_density(self) -> npd.StudentT: - return build_student_t_distribution(nu=5) + def spectral_density(self) -> SpectralDensity: + r"""Matern-5/2 spectral density. + + .. math:: + S(\omega) = \sigma^2 \, + \frac{(16/3)\,\alpha^5}{(5/\ell^2 + \omega^2)^3}, + \quad \alpha = \sqrt{5}/\ell + """ + + def _evaluate(omega, variance, lengthscale): + alpha = jnp.sqrt(5.0) / lengthscale + return ( + variance + * (16.0 / 3.0) + * alpha**5 + / (5.0 / lengthscale**2 + omega**2) ** 3 + ) + + return SpectralDensity(build_student_t_distribution(nu=5), _evaluate) diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index 44ea74d0e..e1da294dc 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -18,7 +18,7 @@ import numpyro.distributions as npd from gpjax.kernels.stationary.base import StationaryKernel -from gpjax.kernels.stationary.utils import squared_distance +from gpjax.kernels.stationary.utils import SpectralDensity, squared_distance from gpjax.typing import ( Array, ScalarFloat, @@ -44,5 +44,20 @@ def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: return K.squeeze() @property - def spectral_density(self) -> npd.Normal: - return npd.Normal(0.0, 1.0) + def spectral_density(self) -> SpectralDensity: + r"""RBF spectral density. + + .. math:: + S(\omega) = \sigma^2 \sqrt{2\pi}\,\ell\, + \exp\!\bigl(-\tfrac{1}{2}\ell^2 \omega^2\bigr) + """ + + def _evaluate(omega, variance, lengthscale): + return ( + variance + * jnp.sqrt(2.0 * jnp.pi) + * lengthscale + * jnp.exp(-0.5 * lengthscale**2 * omega**2) + ) + + return SpectralDensity(npd.Normal(0.0, 1.0), _evaluate) diff --git a/gpjax/kernels/stationary/utils.py b/gpjax/kernels/stationary/utils.py index bbe7b0a7d..e0e3261b1 100644 --- a/gpjax/kernels/stationary/utils.py +++ b/gpjax/kernels/stationary/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import beartype.typing as tp import jax.numpy as jnp from jaxtyping import Float import numpyro.distributions as npd @@ -23,45 +24,97 @@ def build_student_t_distribution(nu: int) -> npd.StudentT: - r"""Build a Student's t distribution with a fixed smoothness parameter. - - For a fixed half-integer smoothness parameter, compute the spectral density of a - MatΓ©rn kernel; a Student's t distribution. + r"""Student's t distribution for Matern spectral densities. Args: - nu (int): The smoothness parameter of the MatΓ©rn kernel. + nu: Degrees of freedom (equals the Matern smoothness parameter + :math:`\nu` mapped to the nearest integer: 1, 3, or 5). + + Returns: + A standard Student's t distribution with ``df=nu``. + """ + return npd.StudentT(df=nu, loc=0.0, scale=1.0) + - Returns - ------- - tfp.Distribution: A Student's t distribution with the same smoothness parameter. +class SpectralDensity: + r"""Spectral density :math:`S(\omega)` of a stationary kernel. + + This class serves two roles: + + 1. **Sampling** (for RFF): delegates to a wrapped NumPyro distribution + via :meth:`sample`, drawing frequency samples :math:`\omega`. + 2. **Evaluation** (for HSGP): computes :math:`S(\omega)` at arbitrary + frequencies via :meth:`__call__`, incorporating kernel variance and + lengthscale. + + Args: + distribution: NumPyro distribution to sample from (used by RFF). + evaluate_fn: Callable ``(omega, variance, lengthscale) -> S(omega)`` + that evaluates the spectral density at given frequencies. """ - dist = npd.StudentT(df=nu, loc=0.0, scale=1.0) - return dist + + def __init__( + self, + distribution: npd.Distribution, + evaluate_fn: tp.Callable[ + [Float[Array, " M"], ScalarFloat, ScalarFloat], Float[Array, " M"] + ], + ): + self._distribution = distribution + self._evaluate_fn = evaluate_fn + + def sample(self, key: Array, sample_shape: tuple[int, ...]) -> Float[Array, "..."]: + """Draw frequency samples from the spectral distribution (used by RFF). + + Args: + key: JAX PRNG key. + sample_shape: Shape of the sample array. + + Returns: + Sampled frequencies. + """ + return self._distribution.sample(key=key, sample_shape=sample_shape) + + def __call__( + self, + omega: Float[Array, " M"], + variance: ScalarFloat, + lengthscale: ScalarFloat, + ) -> Float[Array, " M"]: + r"""Evaluate :math:`S(\omega)` at the given frequencies (used by HSGP). + + Args: + omega: Frequencies at which to evaluate the spectral density. + variance: Kernel variance :math:`\sigma^2`. + lengthscale: Kernel lengthscale :math:`\ell`. + + Returns: + Spectral density values :math:`S(\omega)`. + """ + return self._evaluate_fn(omega, variance, lengthscale) def squared_distance(x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: - r"""Compute the squared distance between a pair of inputs. + r"""Squared Euclidean distance :math:`\lVert x - y \rVert^2`. Args: - x (Float[Array, " D"]): First input. - y (Float[Array, " D"]): Second input. + x: First input vector. + y: Second input vector. - Returns - ------- - ScalarFloat: The squared distance between the inputs. + Returns: + The squared distance between the inputs. """ return jnp.sum((x - y) ** 2) def euclidean_distance(x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: - r"""Compute the euclidean distance between a pair of inputs. + r"""Euclidean distance :math:`\lVert x - y \rVert`, clamped for stability. Args: - x (Float[Array, " D"]): First input. - y (Float[Array, " D"]): Second input. + x: First input vector. + y: Second input vector. - Returns - ------- - ScalarFloat: The euclidean distance between the inputs. + Returns: + The Euclidean distance between the inputs. """ return jnp.sqrt(jnp.maximum(squared_distance(x, y), 1e-36)) diff --git a/gpjax/linalg/__init__.py b/gpjax/linalg/__init__.py index 489ac2cbd..366caebac 100644 --- a/gpjax/linalg/__init__.py +++ b/gpjax/linalg/__init__.py @@ -13,6 +13,7 @@ Identity, Kronecker, LinearOperator, + LowRank, Triangular, ) from gpjax.linalg.utils import ( @@ -28,6 +29,7 @@ "Identity", "Kronecker", "LinearOperator", + "LowRank", "Triangular", "diag", "logdet", diff --git a/gpjax/linalg/operations.py b/gpjax/linalg/operations.py index d287d49a9..1ccb69fd2 100644 --- a/gpjax/linalg/operations.py +++ b/gpjax/linalg/operations.py @@ -12,6 +12,7 @@ Identity, Kronecker, LinearOperator, + LowRank, Triangular, ) from gpjax.typing import ScalarFloat @@ -219,6 +220,10 @@ def _handle_blockdiag(A): def _handle_dense(A): return jnp.diag(A.array) + def _handle_lowrank(A): + # diag(W W^T)_i = sum_j W_{ij}^2 + return jnp.sum(A.factor**2, axis=1) + def _handle_default(A): return jnp.diag(A.to_dense()) @@ -229,6 +234,7 @@ def _handle_default(A): Kronecker: _handle_kronecker, BlockDiag: _handle_blockdiag, Dense: _handle_dense, + LowRank: _handle_lowrank, } handler = dispatch_table.get(type(A), _handle_default) diff --git a/gpjax/linalg/operators.py b/gpjax/linalg/operators.py index db49a9e43..98f14dabe 100644 --- a/gpjax/linalg/operators.py +++ b/gpjax/linalg/operators.py @@ -407,6 +407,52 @@ def _kronecker_tree_unflatten(aux_data, children): jtu.register_pytree_node(Kronecker, _kronecker_tree_flatten, _kronecker_tree_unflatten) + +class LowRank(LinearOperator): + r"""Low-rank matrix K = W W^T where W has shape (N, m) with m << N. + + Basis-function kernel approximations (HSGP, RFF) produce a kernel + matrix that factors as K = W W^T. Storing only the (N, m) factor W + rather than the full (N, N) matrix enables O(N m^2) inference via the + Woodbury identity instead of O(N^3) Cholesky-based inference. + """ + + def __init__(self, factor: Float[Array, "N m"]): + super().__init__() + self.factor = factor + + @property + def shape(self) -> tuple[int, int]: + num_data = self.factor.shape[0] + return (num_data, num_data) + + @property + def rank(self) -> int: + return self.factor.shape[1] + + @property + def dtype(self) -> jnp.dtype: + return self.factor.dtype + + def to_dense(self) -> Float[Array, "N N"]: + return self.factor @ self.factor.T + + @property + def T(self) -> "LowRank": + # W W^T is symmetric, so the transpose is itself. + return self + + +def _lowrank_tree_flatten(lowrank): + return (lowrank.factor,), None + + +def _lowrank_tree_unflatten(aux_data, children): + return LowRank(children[0]) + + +jtu.register_pytree_node(LowRank, _lowrank_tree_flatten, _lowrank_tree_unflatten) + __all__ = [ "BlockDiag", "Dense", @@ -414,5 +460,6 @@ def _kronecker_tree_unflatten(aux_data, children): "Identity", "Kronecker", "LinearOperator", + "LowRank", "Triangular", ] diff --git a/gpjax/linalg/woodbury.py b/gpjax/linalg/woodbury.py new file mode 100644 index 000000000..c96d74d36 --- /dev/null +++ b/gpjax/linalg/woodbury.py @@ -0,0 +1,121 @@ +r"""Woodbury identity helpers for efficient low-rank + diagonal solves. + +When a kernel is approximated by a low-rank factorisation K \approx W W^T +(e.g. via HSGP or RFF), the marginal covariance becomes + + \Sigma = W W^T + D, where D = diag(noise), W is (N, m), m << N. + +Naively inverting \Sigma costs O(N^3). The *Woodbury matrix identity* reduces +this to O(N m^2 + m^3) by working with the small (m x m) *capacitance matrix* + + A = I_m + W^T D^{-1} W. + +The three operations exposed here --- solve, log-determinant, and quadratic +form --- are the building blocks for the GP marginal likelihood and posterior +prediction when the kernel has low-rank structure. +""" + +import jax.numpy as jnp +import jax.scipy as jsp +from jaxtyping import Float + +from gpjax.typing import Array, ScalarFloat + + +def _capacitance_cholesky( + W: Float[Array, "N m"], + noise_inv: Float[Array, " N"], +) -> Float[Array, "m m"]: + r"""Cholesky factor of the capacitance matrix A = I_m + W^T D^{-1} W. + + This is the shared computation underlying all Woodbury operations. + + Args: + W: Factor matrix, shape (N, m). + noise_inv: Element-wise reciprocal of the diagonal noise, shape (N,). + + Returns: + Lower-triangular Cholesky factor L_A such that A = L_A L_A^T. + """ + m = W.shape[1] + Dinv_W = noise_inv[:, None] * W + A = jnp.eye(m) + W.T @ Dinv_W + return jnp.linalg.cholesky(A) + + +def woodbury_solve( + W: Float[Array, "N m"], + noise: Float[Array, " N"], + b: Float[Array, "N ..."], +) -> Float[Array, "N ..."]: + r"""Solve (W W^T + D) x = b via the Woodbury identity. + + .. math:: + + x = D^{-1} b - D^{-1} W \, A^{-1} \, W^T D^{-1} b + + where D = diag(noise) and A = I_m + W^T D^{-1} W. + + Args: + W: Factor matrix, shape (N, m). + noise: Diagonal noise vector, shape (N,). + b: Right-hand side, shape (N,) or (N, K). + + Returns: + Solution x with same shape as b. + """ + noise_inv = 1.0 / noise + Dinv_W = noise_inv[:, None] * W + Dinv_b = noise_inv * b if b.ndim == 1 else noise_inv[:, None] * b + + L_A = _capacitance_cholesky(W, noise_inv) + + WtDinv_b = Dinv_W.T @ b + forward = jsp.linalg.solve_triangular(L_A, WtDinv_b, lower=True) + Ainv_WtDinv_b = jsp.linalg.solve_triangular(L_A.T, forward, lower=False) + + return Dinv_b - Dinv_W @ Ainv_WtDinv_b + + +def woodbury_logdet( + W: Float[Array, "N m"], + noise: Float[Array, " N"], +) -> ScalarFloat: + r"""Log-determinant of W W^T + D via the matrix determinant lemma. + + .. math:: + + \log|\Sigma| = \log|A| + \sum_i \log(\text{noise}_i) + + where A = I_m + W^T D^{-1} W. + + Args: + W: Factor matrix, shape (N, m). + noise: Diagonal noise vector, shape (N,). + + Returns: + Log-determinant as a scalar. + """ + noise_inv = 1.0 / noise + L_A = _capacitance_cholesky(W, noise_inv) + logdet_A = 2.0 * jnp.sum(jnp.log(jnp.diag(L_A))) + return logdet_A + jnp.sum(jnp.log(noise)) + + +def woodbury_quad( + W: Float[Array, "N m"], + noise: Float[Array, " N"], + diff: Float[Array, " N"], +) -> ScalarFloat: + r"""Quadratic form diff^T \Sigma^{-1} diff where \Sigma = W W^T + D. + + Args: + W: Factor matrix, shape (N, m). + noise: Diagonal noise vector, shape (N,). + diff: Vector, shape (N,). + + Returns: + Quadratic form as a scalar. + """ + solved = woodbury_solve(W, noise, diff) + return diff @ solved diff --git a/gpjax/objectives.py b/gpjax/objectives.py index 236383b93..7c4d36004 100644 --- a/gpjax/objectives.py +++ b/gpjax/objectives.py @@ -23,7 +23,9 @@ psd, solve, ) +from gpjax.linalg.operators import LowRank from gpjax.linalg.utils import add_jitter +from gpjax.linalg.woodbury import woodbury_logdet, woodbury_quad from gpjax.typing import ( Array, ScalarFloat, @@ -120,12 +122,25 @@ def conjugate_mll(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat: noise = posterior.likelihood.noise_vector(data.n) Kxx = kernel.gram(x) - Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter) - Sigma_dense = Kxx_dense + jnp.diag(noise) - Sigma = psd(Dense(Sigma_dense)) + diff = jnp.atleast_1d((y_flat - mx_flat).squeeze()) + + if isinstance(Kxx, LowRank): + # Low-rank path: Sigma = W W^T + D. + # Use the Woodbury identity for O(N m^2) instead of O(N^3) Cholesky. + W = Kxx.factor + noise_with_jitter = noise + posterior.prior.jitter + num_datapoints = x.shape[0] * posterior.likelihood.num_outputs + + log_det = woodbury_logdet(W, noise_with_jitter) + quadratic = woodbury_quad(W, noise_with_jitter, diff) - mll = GaussianDistribution(jnp.atleast_1d(mx_flat.squeeze()), Sigma) - return mll.log_prob(jnp.atleast_1d(y_flat.squeeze())).squeeze() + return -0.5 * (num_datapoints * jnp.log(2.0 * jnp.pi) + log_det + quadratic) + + # Dense path: Sigma = Kxx + D. Standard Cholesky-based log-probability. + Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter) + Sigma = psd(Dense(Kxx_dense + jnp.diag(noise))) + marginal = GaussianDistribution(jnp.atleast_1d(mx_flat.squeeze()), Sigma) + return marginal.log_prob(diff).squeeze() def conjugate_loocv(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat: diff --git a/mkdocs.yml b/mkdocs.yml index 527f18c0c..3c7d4bf3a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -31,6 +31,7 @@ nav: - Multi-Output GPs: _examples/multioutput.md - Accelerated Multi-Output GPs: _examples/oilmm.md - Orthogonal Additive GPs: _examples/oak.md + - Hilbert Space GPs: _examples/hsgp.md - πŸ§ͺ Experimental: - Numpyro Integration: _examples/numpyro_integration.md - Spatial Linear Gaussian Process: _examples/spatial_linear_gp.md diff --git a/pyproject.toml b/pyproject.toml index acd01cae4..dfb9f7655 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "gpjax" -dynamic = ["version"] +version = "0.13.6" description = 'Gaussian processes in JAX.' readme = "README.md" requires-python = ">=3.11" @@ -57,6 +57,7 @@ docs = [ "markdown-katex>=202406.1035", "scikit-learn>=1.5.1", "ucimlrepo>=0.0.7", + "rdata>=1.0.0", ] [tool.uv] diff --git a/tests/test_heteroscedastic.py b/tests/test_heteroscedastic.py index 94c99cf5e..f46cb73cd 100644 --- a/tests/test_heteroscedastic.py +++ b/tests/test_heteroscedastic.py @@ -167,11 +167,13 @@ def test_softplus_transform_numerical_accuracy(mean: float, variance: float, see # E[1/sigma^2] mc_inv_variance = jnp.mean(1.0 / transformed_samples) - # Allow for some MC error and quadrature approximation error + # Allow for some MC error and quadrature approximation error. + # atol covers near-zero values where rtol alone is unreliable. rtol = 0.15 - assert jnp.allclose(moments.variance, mc_variance, rtol=rtol) - assert jnp.allclose(moments.log_variance, mc_log_variance, rtol=rtol) - assert jnp.allclose(moments.inv_variance, mc_inv_variance, rtol=rtol) + atol = 0.01 + assert jnp.allclose(moments.variance, mc_variance, rtol=rtol, atol=atol) + assert jnp.allclose(moments.log_variance, mc_log_variance, rtol=rtol, atol=atol) + assert jnp.allclose(moments.inv_variance, mc_inv_variance, rtol=rtol, atol=atol) def test_heteroscedastic_variational_predict(prior, noise_prior, dataset): diff --git a/tests/test_kernels/test_approximations.py b/tests/test_kernels/test_approximations.py index c78d66120..1771b631d 100644 --- a/tests/test_kernels/test_approximations.py +++ b/tests/test_kernels/test_approximations.py @@ -13,7 +13,7 @@ RationalQuadratic, StationaryKernel, ) -from gpjax.linalg.operators import Dense +from gpjax.linalg.operators import LowRank import jax from jax import config import jax.numpy as jnp @@ -52,7 +52,7 @@ def test_gram( linop = approximate.gram(x) # Check the return type - assert isinstance(linop, Dense) + assert isinstance(linop, LowRank) Kxx = linop.to_dense() + jnp.eye(n_data) * _jitter diff --git a/tests/test_kernels/test_computations.py b/tests/test_kernels/test_computations.py index d1883ab2b..45cfdcde2 100644 --- a/tests/test_kernels/test_computations.py +++ b/tests/test_kernels/test_computations.py @@ -15,6 +15,7 @@ PSD, Dense, Diagonal, + LowRank, ) import jax.numpy as jnp import jax.random as jr @@ -67,7 +68,7 @@ def test_gram_shape(self, rff_kernel): """Gram matrix should be (N, N).""" x = jr.normal(jr.key(1), (10, 2)) gram = rff_kernel.gram(x) - assert isinstance(gram, Dense) + assert isinstance(gram, LowRank) assert PSD in gram.annotations assert gram.shape == (10, 10) diff --git a/tests/test_kernels/test_hsgp.py b/tests/test_kernels/test_hsgp.py new file mode 100644 index 000000000..82bbf28e8 --- /dev/null +++ b/tests/test_kernels/test_hsgp.py @@ -0,0 +1,309 @@ +"""Tests for the Hilbert Space Gaussian Process (HSGP) kernel approximation.""" + +from gpjax.kernels.approximations.hsgp import HSGP +from gpjax.kernels.stationary import RBF, Matern12, Matern32, Matern52 +from gpjax.linalg.operators import LowRank +import jax +from jax import config +import jax.numpy as jnp +import jax.random as jr +import numpy.testing as npt +import pytest + +config.update("jax_enable_x64", True) + +STATIONARY_KERNELS = [RBF, Matern12, Matern32, Matern52] + + +def _make_hsgp( + kernel_class=RBF, + num_basis_fns: int = 20, + domain_half_width: float = 5.0, + center: float = 0.0, +) -> HSGP: + """Create an HSGP with sensible defaults for testing.""" + base_kernel = kernel_class(n_dims=1) + return HSGP( + base_kernel=base_kernel, + num_basis_fns=num_basis_fns, + domain_half_width=domain_half_width, + center=center, + ) + + +# ────────────────────────────────────────────────────────────────────── +# Eigenvalues +# ────────────────────────────────────────────────────────────────────── + + +def test_eigenvalue_count(): + hsgp = _make_hsgp(num_basis_fns=10, domain_half_width=3.0) + eigenvalues = hsgp.eigenvalues() + assert eigenvalues.shape == (10,) + + +def test_eigenvalue_formula(): + """sqrt(lambda_j) = j * pi / (2 * L).""" + half_width = 5.0 + num_basis = 4 + hsgp = _make_hsgp(num_basis_fns=num_basis, domain_half_width=half_width) + eigenvalues = hsgp.eigenvalues() + indices = jnp.arange(1, num_basis + 1) + expected = indices * jnp.pi / (2.0 * half_width) + npt.assert_allclose(eigenvalues, expected) + + +def test_eigenvalues_are_strictly_increasing(): + hsgp = _make_hsgp(num_basis_fns=20, domain_half_width=3.0) + eigenvalues = hsgp.eigenvalues() + assert jnp.all(jnp.diff(eigenvalues) > 0) + + +# ────────────────────────────────────────────────────────────────────── +# Eigenfunctions +# ────────────────────────────────────────────────────────────────────── + + +def test_eigenfunction_shape(): + hsgp = _make_hsgp(num_basis_fns=10, domain_half_width=3.0) + inputs = jnp.linspace(-2, 2, 50)[:, None] + basis_matrix = hsgp.eigenfunctions(inputs) + assert basis_matrix.shape == (50, 10) + + +def test_eigenfunctions_are_approximately_orthonormal(): + """Eigenfunctions should be approximately orthonormal over [-L, L].""" + half_width = 3.0 + num_basis = 5 + hsgp = _make_hsgp(num_basis_fns=num_basis, domain_half_width=half_width, center=0.0) + num_quadrature_points = 10_000 + inputs = jnp.linspace(-half_width, half_width, num_quadrature_points)[:, None] + basis_matrix = hsgp.eigenfunctions(inputs) + spacing = 2.0 * half_width / num_quadrature_points + gram_matrix = basis_matrix.T @ basis_matrix * spacing + npt.assert_allclose(gram_matrix, jnp.eye(num_basis), atol=1e-2) + + +def test_eigenfunctions_vanish_at_boundaries(): + """Eigenfunctions should be zero at x = -L and x = L.""" + half_width = 3.0 + hsgp = _make_hsgp(num_basis_fns=5, domain_half_width=half_width, center=0.0) + boundary_points = jnp.array([[-half_width], [half_width]]) + basis_at_boundary = hsgp.eigenfunctions(boundary_points) + npt.assert_allclose(basis_at_boundary, 0.0, atol=1e-12) + + +# ────────────────────────────────────────────────────────────────────── +# Centering +# ────────────────────────────────────────────────────────────────────── + + +def test_explicit_center_is_stored(): + hsgp = HSGP( + base_kernel=RBF(n_dims=1), + num_basis_fns=5, + domain_half_width=3.0, + center=1.0, + ) + assert hsgp._center == 1.0 + + +def test_auto_center_uses_midpoint_of_input_range(): + hsgp = _make_hsgp(num_basis_fns=5, domain_half_width=3.0, center=None) + inputs = jnp.linspace(2.0, 8.0, 50)[:, None] + _ = hsgp.eigenfunctions(inputs) + assert hsgp._center == pytest.approx(5.0) + + +def test_auto_center_persists_across_calls(): + """Once set, auto-center should not change on subsequent calls.""" + hsgp = _make_hsgp(num_basis_fns=5, domain_half_width=3.0, center=None) + first_inputs = jnp.linspace(0, 10, 50)[:, None] + second_inputs = jnp.linspace(-5, 5, 50)[:, None] + _ = hsgp.eigenfunctions(first_inputs) + center_after_first_call = hsgp._center + _ = hsgp.eigenfunctions(second_inputs) + assert hsgp._center == center_after_first_call + + +# ────────────────────────────────────────────────────────────────────── +# compute_basis +# ────────────────────────────────────────────────────────────────────── + + +def test_compute_basis_returns_correct_shapes(): + hsgp = _make_hsgp(num_basis_fns=10, domain_half_width=3.0) + inputs = jnp.linspace(-2, 2, 50)[:, None] + basis_matrix, sqrt_spectral_weights = hsgp.compute_basis(inputs) + assert basis_matrix.shape == (50, 10) + assert sqrt_spectral_weights.shape == (10,) + + +def test_compute_basis_sqrt_psd_is_positive(): + hsgp = _make_hsgp(num_basis_fns=10, domain_half_width=3.0) + inputs = jnp.linspace(-2, 2, 50)[:, None] + _, sqrt_spectral_weights = hsgp.compute_basis(inputs) + assert jnp.all(sqrt_spectral_weights > 0) + + +# ────────────────────────────────────────────────────────────────────── +# Gram, cross-covariance, and diagonal +# ────────────────────────────────────────────────────────────────────── + + +@pytest.mark.parametrize("KernelClass", STATIONARY_KERNELS) +def test_gram_returns_psd_low_rank_matrix(KernelClass): + hsgp = _make_hsgp(kernel_class=KernelClass) + inputs = jnp.linspace(-3, 3, 30)[:, None] + gram_operator = hsgp.gram(inputs) + + assert isinstance(gram_operator, LowRank) + gram_dense = gram_operator.to_dense() + assert gram_dense.shape == (30, 30) + + eigenvalues_of_gram, _ = jnp.linalg.eigh(gram_dense + 1e-6 * jnp.eye(30)) + assert jnp.all(eigenvalues_of_gram > 0) + + +@pytest.mark.parametrize("KernelClass", STATIONARY_KERNELS) +def test_cross_covariance_shape(KernelClass): + hsgp = _make_hsgp(kernel_class=KernelClass) + inputs_a = jnp.linspace(-3, 3, 30)[:, None] + inputs_b = jnp.linspace(-2, 2, 20)[:, None] + cross_covariance = hsgp.cross_covariance(inputs_a, inputs_b) + assert cross_covariance.shape == (30, 20) + + +def test_gram_is_symmetric(): + hsgp = _make_hsgp(kernel_class=RBF) + inputs = jnp.linspace(-3, 3, 30)[:, None] + gram_dense = hsgp.gram(inputs).to_dense() + npt.assert_allclose(gram_dense, gram_dense.T, atol=1e-12) + + +@pytest.mark.parametrize("KernelClass", STATIONARY_KERNELS) +def test_diagonal_matches_gram_diagonal(KernelClass): + hsgp = _make_hsgp(kernel_class=KernelClass) + inputs = jnp.linspace(-3, 3, 30)[:, None] + diagonal_dense = hsgp.diagonal(inputs).to_dense() + gram_dense = hsgp.gram(inputs).to_dense() + npt.assert_allclose(jnp.diag(diagonal_dense), jnp.diag(gram_dense), atol=1e-10) + + +# ────────────────────────────────────────────────────────────────────── +# Convergence to exact kernel +# ────────────────────────────────────────────────────────────────────── + + +@pytest.mark.parametrize("KernelClass", [RBF, Matern32, Matern52]) +def test_gram_error_decreases_with_more_basis_functions(KernelClass): + """With more basis functions, the HSGP Gram matrix should converge to exact.""" + base_kernel = KernelClass(n_dims=1) + inputs = jnp.linspace(-1, 1, 30)[:, None] + exact_gram = base_kernel.gram(inputs).to_dense() + + hsgp_coarse = _make_hsgp(kernel_class=KernelClass, num_basis_fns=10) + hsgp_fine = _make_hsgp(kernel_class=KernelClass, num_basis_fns=80) + + error_coarse = jnp.linalg.norm(exact_gram - hsgp_coarse.gram(inputs).to_dense()) + error_fine = jnp.linalg.norm(exact_gram - hsgp_fine.gram(inputs).to_dense()) + + assert error_fine < error_coarse + + +def test_rbf_gram_closely_matches_exact(): + """RBF with many basis functions should be very close to exact.""" + base_kernel = RBF(n_dims=1) + inputs = jnp.linspace(-1, 1, 20)[:, None] + exact_gram = base_kernel.gram(inputs).to_dense() + + hsgp = _make_hsgp(kernel_class=RBF, num_basis_fns=100) + approximate_gram = hsgp.gram(inputs).to_dense() + + max_absolute_error = jnp.max(jnp.abs(exact_gram - approximate_gram)) + assert max_absolute_error < 0.01 + + +# ────────────────────────────────────────────────────────────────────── +# Input validation +# ────────────────────────────────────────────────────────────────────── + + +def test_nonstationary_kernel_is_rejected(): + from gpjax.kernels.nonstationary import Linear + + with pytest.raises(TypeError): + HSGP(base_kernel=Linear(1), num_basis_fns=10, domain_half_width=3.0) + + +def test_pointwise_call_raises(): + hsgp = _make_hsgp(num_basis_fns=10, domain_half_width=3.0) + with pytest.raises(RuntimeError): + hsgp(jnp.array([1.0]), jnp.array([2.0])) + + +# ────────────────────────────────────────────────────────────────────── +# Integration with Prior/Posterior pipeline +# ────────────────────────────────────────────────────────────────────── + + +@pytest.fixture +def training_data(): + """Synthetic sinusoidal regression dataset.""" + key = jr.key(42) + num_points = 50 + inputs = jnp.linspace(-3, 3, num_points)[:, None] + targets = jnp.sin(inputs) + 0.1 * jr.normal(key, (num_points, 1)) + + from gpjax.dataset import Dataset + + return Dataset(X=inputs, y=targets), num_points + + +@pytest.fixture +def hsgp_posterior(training_data): + """Conjugate posterior with an HSGP-RBF kernel.""" + from gpjax.gps import Prior + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + + _, num_points = training_data + hsgp = _make_hsgp(kernel_class=RBF, num_basis_fns=20) + prior = Prior(kernel=hsgp, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=num_points) + return prior * likelihood + + +def test_conjugate_mll_returns_finite_scalar(training_data, hsgp_posterior): + from gpjax.objectives import conjugate_mll + + dataset, _ = training_data + mll = conjugate_mll(hsgp_posterior, dataset) + assert jnp.isfinite(mll) + + +def test_posterior_predict_returns_finite_moments(training_data, hsgp_posterior): + dataset, _ = training_data + test_inputs = jnp.linspace(-2.5, 2.5, 30)[:, None] + prediction = hsgp_posterior.predict(test_inputs, dataset) + + assert jnp.all(jnp.isfinite(prediction.mean)) + assert jnp.all(jnp.isfinite(prediction.covariance())) + + +def test_conjugate_mll_is_differentiable(training_data, hsgp_posterior): + """conjugate_mll with HSGP must be differentiable w.r.t. kernel parameters.""" + from flax import nnx + from gpjax.objectives import conjugate_mll + + dataset, _ = training_data + graphdef, state = nnx.split(hsgp_posterior) + + def negative_mll(state): + model = nnx.merge(graphdef, state) + return -conjugate_mll(model, dataset) + + gradients = jax.grad(negative_mll)(state) + flat_gradients = jax.tree.leaves(gradients) + for grad_leaf in flat_gradients: + assert jnp.all(jnp.isfinite(grad_leaf)), f"Non-finite gradient: {grad_leaf}" diff --git a/tests/test_kernels/test_spectral_density.py b/tests/test_kernels/test_spectral_density.py new file mode 100644 index 000000000..fd29581f2 --- /dev/null +++ b/tests/test_kernels/test_spectral_density.py @@ -0,0 +1,131 @@ +"""Tests for the SpectralDensity class and kernel spectral density evaluation.""" + +from gpjax.kernels.stationary import RBF, Matern12, Matern32, Matern52 +from gpjax.kernels.stationary.utils import SpectralDensity +from jax import config +import jax.numpy as jnp +import jax.random as jr +import numpy.testing as npt +import pytest + +config.update("jax_enable_x64", True) + +UNIT_VARIANCE = jnp.array(1.0) +UNIT_LENGTHSCALE = jnp.array(1.0) +TEST_FREQUENCIES = jnp.array([0.0, 1.0, 5.0]) + + +# ────────────────────────────────────────────────────────────────────── +# SpectralDensity interface +# ────────────────────────────────────────────────────────────────────── + + +def test_spectral_density_has_sample_method(): + """SpectralDensity must expose sample() for RFF compatibility.""" + spectral_density = RBF(n_dims=1).spectral_density + assert isinstance(spectral_density, SpectralDensity) + samples = spectral_density.sample(key=jr.key(0), sample_shape=(10, 1)) + assert samples.shape == (10, 1) + + +def test_spectral_density_is_callable(): + """SpectralDensity must be callable with (omega, variance, lengthscale).""" + spectral_density = RBF(n_dims=1).spectral_density + frequencies = jnp.array([0.5, 1.0, 2.0]) + values = spectral_density(frequencies, UNIT_VARIANCE, UNIT_LENGTHSCALE) + assert values.shape == (3,) + assert jnp.all(values > 0) + + +# ────────────────────────────────────────────────────────────────────── +# Closed-form spectral density formulae +# ────────────────────────────────────────────────────────────────────── + + +def test_rbf_spectral_density_matches_closed_form(): + """S(w) = variance * sqrt(2*pi) * lengthscale * exp(-0.5 * lengthscale^2 * w^2).""" + variance = jnp.array(2.0) + lengthscale = jnp.array(0.5) + spectral_density = RBF(n_dims=1, variance=2.0, lengthscale=0.5).spectral_density + + result = spectral_density(TEST_FREQUENCIES, variance, lengthscale) + expected = ( + variance + * jnp.sqrt(2 * jnp.pi) + * lengthscale + * jnp.exp(-0.5 * lengthscale**2 * TEST_FREQUENCIES**2) + ) + npt.assert_allclose(result, expected, atol=1e-12) + + +def test_rbf_spectral_density_peaks_at_zero_and_decays(): + """RBF spectral density peaks at omega=0 and decays monotonically.""" + spectral_density = RBF(n_dims=1).spectral_density + frequencies = jnp.linspace(0, 10, 100) + values = spectral_density(frequencies, UNIT_VARIANCE, UNIT_LENGTHSCALE) + assert values[0] == jnp.max(values) + assert jnp.all(jnp.diff(values) <= 0) + + +def test_matern12_spectral_density_matches_closed_form(): + """S(w) = variance * 2/ell * 1/(1/ell^2 + w^2).""" + variance = jnp.array(1.5) + lengthscale = jnp.array(0.8) + spectral_density = Matern12( + n_dims=1, variance=1.5, lengthscale=0.8 + ).spectral_density + + result = spectral_density(TEST_FREQUENCIES, variance, lengthscale) + expected = ( + variance * (2.0 / lengthscale) / (1.0 / lengthscale**2 + TEST_FREQUENCIES**2) + ) + npt.assert_allclose(result, expected, atol=1e-12) + + +def test_matern32_spectral_density_matches_closed_form(): + """S(w) = variance * 4*(sqrt(3)/ell)^3 / (3/ell^2 + w^2)^2.""" + variance = jnp.array(2.0) + lengthscale = jnp.array(1.5) + spectral_density = Matern32( + n_dims=1, variance=2.0, lengthscale=1.5 + ).spectral_density + + result = spectral_density(TEST_FREQUENCIES, variance, lengthscale) + alpha = jnp.sqrt(3.0) / lengthscale + expected = ( + variance * 4.0 * alpha**3 / (3.0 / lengthscale**2 + TEST_FREQUENCIES**2) ** 2 + ) + npt.assert_allclose(result, expected, atol=1e-12) + + +def test_matern52_spectral_density_matches_closed_form(): + """S(w) = variance * (16/3)*(sqrt(5)/ell)^5 / (5/ell^2 + w^2)^3.""" + variance = jnp.array(3.0) + lengthscale = jnp.array(2.0) + spectral_density = Matern52( + n_dims=1, variance=3.0, lengthscale=2.0 + ).spectral_density + + result = spectral_density(TEST_FREQUENCIES, variance, lengthscale) + alpha = jnp.sqrt(5.0) / lengthscale + expected = ( + variance + * (16.0 / 3.0) + * alpha**5 + / (5.0 / lengthscale**2 + TEST_FREQUENCIES**2) ** 3 + ) + npt.assert_allclose(result, expected, atol=1e-12) + + +# ────────────────────────────────────────────────────────────────────── +# Positivity across all stationary kernels +# ────────────────────────────────────────────────────────────────────── + + +@pytest.mark.parametrize("KernelClass", [RBF, Matern12, Matern32, Matern52]) +def test_spectral_density_is_positive_everywhere(KernelClass): + """All spectral densities must be positive for all omega.""" + spectral_density = KernelClass(n_dims=1).spectral_density + frequencies = jnp.linspace(0, 20, 200) + values = spectral_density(frequencies, UNIT_VARIANCE, UNIT_LENGTHSCALE) + assert jnp.all(values > 0) diff --git a/tests/test_linalg/__init__.py b/tests/test_linalg/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_linalg/test_lowrank.py b/tests/test_linalg/test_lowrank.py new file mode 100644 index 000000000..0a3657166 --- /dev/null +++ b/tests/test_linalg/test_lowrank.py @@ -0,0 +1,81 @@ +"""Tests for the LowRank linear operator.""" + +import jax +from jax import config +import jax.numpy as jnp +import jax.random as jr +import numpy.testing as npt + +config.update("jax_enable_x64", True) + +from gpjax.linalg.operations import diag +from gpjax.linalg.operators import LowRank + + +class TestLowRankProperties: + def test_shape(self): + W = jnp.ones((10, 3)) + op = LowRank(W) + assert op.shape == (10, 10) + + def test_rank(self): + W = jnp.ones((10, 3)) + op = LowRank(W) + assert op.rank == 3 + + def test_dtype(self): + W = jnp.ones((10, 3), dtype=jnp.float64) + op = LowRank(W) + assert op.dtype == jnp.float64 + + def test_to_dense(self): + key = jr.key(0) + W = jr.normal(key, (10, 3)) + op = LowRank(W) + npt.assert_allclose(op.to_dense(), W @ W.T, atol=1e-12) + + def test_transpose_is_self(self): + W = jnp.ones((10, 3)) + op = LowRank(W) + assert op.T is op + + def test_diag_dispatch(self): + key = jr.key(0) + W = jr.normal(key, (10, 3)) + op = LowRank(W) + expected = jnp.sum(W**2, axis=1) + npt.assert_allclose(diag(op), expected, atol=1e-12) + + +class TestLowRankPyTree: + def test_flatten_unflatten_roundtrip(self): + key = jr.key(0) + W = jr.normal(key, (10, 3)) + op = LowRank(W) + leaves, treedef = jax.tree.flatten(op) + restored = treedef.unflatten(leaves) + npt.assert_allclose(restored.factor, op.factor) + + def test_jit_compatible(self): + key = jr.key(0) + W = jr.normal(key, (10, 3)) + op = LowRank(W) + + @jax.jit + def fn(op): + return op.to_dense() + + result = fn(op) + npt.assert_allclose(result, W @ W.T, atol=1e-12) + + def test_grad_through_to_dense(self): + key = jr.key(0) + W = jr.normal(key, (10, 3)) + + @jax.grad + def fn(W): + op = LowRank(W) + return jnp.sum(op.to_dense()) + + grads = fn(W) + assert jnp.all(jnp.isfinite(grads)) diff --git a/tests/test_linalg.py b/tests/test_linalg/test_operators.py similarity index 100% rename from tests/test_linalg.py rename to tests/test_linalg/test_operators.py diff --git a/tests/test_linalg/test_woodbury.py b/tests/test_linalg/test_woodbury.py new file mode 100644 index 000000000..0c32ff9f1 --- /dev/null +++ b/tests/test_linalg/test_woodbury.py @@ -0,0 +1,379 @@ +"""Tests for Woodbury identity helper functions.""" + +import jax +from jax import config +import jax.numpy as jnp +import jax.random as jr +import numpy.testing as npt +import pytest + +config.update("jax_enable_x64", True) + +from gpjax.linalg.woodbury import woodbury_logdet, woodbury_quad, woodbury_solve + + +def _dense_reference(W, noise): + """Build the dense matrix W W^T + diag(noise) for reference.""" + return W @ W.T + jnp.diag(noise) + + +class TestWoodburySolve: + @pytest.mark.parametrize("N,m", [(20, 3), (50, 5), (100, 10)]) + def test_matches_dense_vector(self, N, m): + key = jr.key(0) + k1, k2, _k3 = jr.split(key, 3) + W = jr.normal(k1, (N, m)) + noise = jnp.ones(N) * 0.1 + b = jr.normal(k2, (N,)) + + result = woodbury_solve(W, noise, b) + expected = jnp.linalg.solve(_dense_reference(W, noise), b) + npt.assert_allclose(result, expected, atol=1e-6) + + @pytest.mark.parametrize("N,m", [(20, 3), (50, 5)]) + def test_matches_dense_matrix(self, N, m): + key = jr.key(0) + k1, k2 = jr.split(key) + W = jr.normal(k1, (N, m)) + noise = jnp.ones(N) * 0.1 + B = jr.normal(k2, (N, 7)) + + result = woodbury_solve(W, noise, B) + expected = jnp.linalg.solve(_dense_reference(W, noise), B) + npt.assert_allclose(result, expected, atol=1e-6) + + def test_heterogeneous_noise(self): + key = jr.key(0) + k1, k2, k3 = jr.split(key, 3) + N, m = 30, 4 + W = jr.normal(k1, (N, m)) + noise = jnp.abs(jr.normal(k2, (N,))) + 0.01 + b = jr.normal(k3, (N,)) + + result = woodbury_solve(W, noise, b) + expected = jnp.linalg.solve(_dense_reference(W, noise), b) + npt.assert_allclose(result, expected, atol=1e-6) + + def test_jit(self): + key = jr.key(0) + k1, k2 = jr.split(key) + W = jr.normal(k1, (20, 3)) + noise = jnp.ones(20) * 0.1 + b = jr.normal(k2, (20,)) + + result = jax.jit(woodbury_solve)(W, noise, b) + expected = jnp.linalg.solve(_dense_reference(W, noise), b) + npt.assert_allclose(result, expected, atol=1e-6) + + def test_grad(self): + key = jr.key(0) + k1, k2 = jr.split(key) + W = jr.normal(k1, (20, 3)) + noise = jnp.ones(20) * 0.1 + b = jr.normal(k2, (20,)) + + @jax.grad + def fn(W): + return jnp.sum(woodbury_solve(W, noise, b)) + + grads = fn(W) + assert jnp.all(jnp.isfinite(grads)) + + +class TestWoodburyLogdet: + @pytest.mark.parametrize("N,m", [(20, 3), (50, 5), (100, 10)]) + def test_matches_dense(self, N, m): + key = jr.key(0) + W = jr.normal(key, (N, m)) + noise = jnp.ones(N) * 0.1 + + result = woodbury_logdet(W, noise) + expected = jnp.linalg.slogdet(_dense_reference(W, noise))[1] + npt.assert_allclose(result, expected, atol=1e-6) + + def test_heterogeneous_noise(self): + key = jr.key(0) + k1, k2 = jr.split(key) + N, m = 30, 4 + W = jr.normal(k1, (N, m)) + noise = jnp.abs(jr.normal(k2, (N,))) + 0.01 + + result = woodbury_logdet(W, noise) + expected = jnp.linalg.slogdet(_dense_reference(W, noise))[1] + npt.assert_allclose(result, expected, atol=1e-6) + + def test_jit(self): + key = jr.key(0) + W = jr.normal(key, (20, 3)) + noise = jnp.ones(20) * 0.1 + + result = jax.jit(woodbury_logdet)(W, noise) + expected = jnp.linalg.slogdet(_dense_reference(W, noise))[1] + npt.assert_allclose(result, expected, atol=1e-6) + + def test_grad(self): + key = jr.key(0) + W = jr.normal(key, (20, 3)) + noise = jnp.ones(20) * 0.1 + + @jax.grad + def fn(W): + return woodbury_logdet(W, noise) + + grads = fn(W) + assert jnp.all(jnp.isfinite(grads)) + + +class TestWoodburyQuad: + @pytest.mark.parametrize("N,m", [(20, 3), (50, 5)]) + def test_matches_dense(self, N, m): + key = jr.key(0) + k1, k2 = jr.split(key) + W = jr.normal(k1, (N, m)) + noise = jnp.ones(N) * 0.1 + diff = jr.normal(k2, (N,)) + + result = woodbury_quad(W, noise, diff) + Sigma = _dense_reference(W, noise) + expected = diff @ jnp.linalg.solve(Sigma, diff) + npt.assert_allclose(result, expected, atol=1e-6) + + def test_jit(self): + key = jr.key(0) + k1, k2 = jr.split(key) + W = jr.normal(k1, (20, 3)) + noise = jnp.ones(20) * 0.1 + diff = jr.normal(k2, (20,)) + + result = jax.jit(woodbury_quad)(W, noise, diff) + Sigma = _dense_reference(W, noise) + expected = diff @ jnp.linalg.solve(Sigma, diff) + npt.assert_allclose(result, expected, atol=1e-6) + + def test_grad(self): + key = jr.key(0) + k1, k2 = jr.split(key) + W = jr.normal(k1, (20, 3)) + noise = jnp.ones(20) * 0.1 + diff = jr.normal(k2, (20,)) + + @jax.grad + def fn(W): + return woodbury_quad(W, noise, diff) + + grads = fn(W) + assert jnp.all(jnp.isfinite(grads)) + + +class TestNumericalStability: + def test_large_N_small_m(self): + key = jr.key(0) + k1, k2 = jr.split(key) + N, m = 1000, 10 + W = jr.normal(k1, (N, m)) * 0.5 + noise = jnp.ones(N) * 0.1 + b = jr.normal(k2, (N,)) + + result = woodbury_solve(W, noise, b) + assert jnp.all(jnp.isfinite(result)) + + ld = woodbury_logdet(W, noise) + assert jnp.isfinite(ld) + + def test_small_noise(self): + key = jr.key(0) + k1, k2 = jr.split(key) + N, m = 50, 5 + W = jr.normal(k1, (N, m)) + noise = jnp.ones(N) * 1e-8 + b = jr.normal(k2, (N,)) + + result = woodbury_solve(W, noise, b) + assert jnp.all(jnp.isfinite(result)) + + +class TestConjugateMllIntegration: + """Verify Woodbury path in conjugate_mll matches dense path.""" + + @pytest.mark.parametrize("KernelClass", ["RBF", "Matern52"]) + def test_mll_matches_dense(self, KernelClass): + from gpjax.dataset import Dataset + from gpjax.gps import Prior + from gpjax.kernels.approximations.hsgp import HSGP + from gpjax.kernels.stationary import RBF, Matern52 + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + from gpjax.objectives import conjugate_mll + + kernel_cls = {"RBF": RBF, "Matern52": Matern52}[KernelClass] + + key = jr.key(42) + n = 50 + x = jnp.linspace(-3, 3, n)[:, None] + y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) + D = Dataset(X=x, y=y) + + base_kernel = kernel_cls(n_dims=1) + hsgp = HSGP( + base_kernel=base_kernel, + num_basis_fns=40, + domain_half_width=5.0, + center=0.0, + ) + prior = Prior(kernel=hsgp, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=n) + posterior = prior * likelihood + + # Woodbury path (automatic via LowRank gram) + mll_woodbury = conjugate_mll(posterior, D) + + # Dense reference: manually compute via dense gram + dense_gram = hsgp.gram(x).to_dense() + from gpjax.distributions import GaussianDistribution + from gpjax.linalg import Dense, psd + from gpjax.linalg.utils import add_jitter + + noise = likelihood.noise_vector(n) + mx = prior.mean_function(x) + y_flat, mx_flat = likelihood.prepare_targets(y, mx) + Kxx_dense = add_jitter(dense_gram, prior.jitter) + Sigma_dense = Kxx_dense + jnp.diag(noise) + Sigma = psd(Dense(Sigma_dense)) + mll_dense = ( + GaussianDistribution(jnp.atleast_1d(mx_flat.squeeze()), Sigma) + .log_prob(jnp.atleast_1d(y_flat.squeeze())) + .squeeze() + ) + + npt.assert_allclose(mll_woodbury, mll_dense, atol=1e-5) + + def test_mll_differentiable_woodbury(self): + from flax import nnx + from gpjax.dataset import Dataset + from gpjax.gps import Prior + from gpjax.kernels.approximations.hsgp import HSGP + from gpjax.kernels.stationary import RBF + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + from gpjax.objectives import conjugate_mll + + key = jr.key(42) + n = 50 + x = jnp.linspace(-3, 3, n)[:, None] + y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) + D = Dataset(X=x, y=y) + + base_kernel = RBF(n_dims=1) + hsgp = HSGP( + base_kernel=base_kernel, + num_basis_fns=20, + domain_half_width=5.0, + center=0.0, + ) + prior = Prior(kernel=hsgp, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=n) + posterior = prior * likelihood + + graphdef, state = nnx.split(posterior) + + def loss(state): + model = nnx.merge(graphdef, state) + return -conjugate_mll(model, D) + + grad_fn = jax.grad(loss) + grads = grad_fn(state) + + flat_grads = jax.tree.leaves(grads) + for g in flat_grads: + assert jnp.all(jnp.isfinite(g)) + + +class TestPredictIntegration: + """Verify Woodbury predict path matches dense predict path.""" + + def test_predict_mean_matches_dense(self): + from gpjax.dataset import Dataset + from gpjax.gps import Prior + from gpjax.kernels.approximations.hsgp import HSGP + from gpjax.kernels.stationary import RBF + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + + key = jr.key(42) + n = 50 + x = jnp.linspace(-3, 3, n)[:, None] + y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) + D = Dataset(X=x, y=y) + + base_kernel = RBF(n_dims=1) + hsgp = HSGP( + base_kernel=base_kernel, + num_basis_fns=30, + domain_half_width=5.0, + center=0.0, + ) + prior = Prior(kernel=hsgp, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=n) + posterior = prior * likelihood + + x_test = jnp.linspace(-2.5, 2.5, 20)[:, None] + pred = posterior.predict(x_test, D) + + assert jnp.all(jnp.isfinite(pred.mean)) + assert jnp.all(jnp.isfinite(pred.covariance())) + + def test_predict_diagonal(self): + from gpjax.dataset import Dataset + from gpjax.gps import Prior + from gpjax.kernels.approximations.hsgp import HSGP + from gpjax.kernels.stationary import RBF + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + + key = jr.key(42) + n = 50 + x = jnp.linspace(-3, 3, n)[:, None] + y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) + D = Dataset(X=x, y=y) + + base_kernel = RBF(n_dims=1) + hsgp = HSGP( + base_kernel=base_kernel, + num_basis_fns=30, + domain_half_width=5.0, + center=0.0, + ) + prior = Prior(kernel=hsgp, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=n) + posterior = prior * likelihood + + x_test = jnp.linspace(-2.5, 2.5, 20)[:, None] + pred = posterior.predict(x_test, D, return_covariance_type="diagonal") + + assert jnp.all(jnp.isfinite(pred.mean)) + assert jnp.all(jnp.isfinite(pred.variance)) + + def test_predict_rff(self): + from gpjax.dataset import Dataset + from gpjax.gps import Prior + from gpjax.kernels.approximations import RFF + from gpjax.kernels.stationary import RBF + from gpjax.likelihoods import Gaussian + from gpjax.mean_functions import Zero + + key = jr.key(42) + n = 50 + x = jnp.linspace(-3, 3, n)[:, None] + y = jnp.sin(x) + 0.1 * jr.normal(key, (n, 1)) + D = Dataset(X=x, y=y) + + rff = RFF(base_kernel=RBF(n_dims=1), num_basis_fns=50, key=jr.key(0)) + prior = Prior(kernel=rff, mean_function=Zero()) + likelihood = Gaussian(num_datapoints=n) + posterior = prior * likelihood + + x_test = jnp.linspace(-2.5, 2.5, 20)[:, None] + pred = posterior.predict(x_test, D) + + assert jnp.all(jnp.isfinite(pred.mean)) + assert jnp.all(jnp.isfinite(pred.covariance())) diff --git a/uv.lock b/uv.lock index daa54f696..4042addeb 100644 --- a/uv.lock +++ b/uv.lock @@ -17,7 +17,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-02-06T09:56:46.847909Z" +exclude-newer = "2026-02-13T18:59:31.785479Z" exclude-newer-span = "P7D" [[package]] @@ -951,6 +951,7 @@ wheels = [ [[package]] name = "gpjax" +version = "0.13.6" source = { editable = "." } dependencies = [ { name = "beartype" }, @@ -985,6 +986,7 @@ docs = [ { name = "networkx" }, { name = "pandas" }, { name = "pymdown-extensions" }, + { name = "rdata" }, { name = "scikit-learn" }, { name = "seaborn" }, { name = "ucimlrepo" }, @@ -1041,6 +1043,7 @@ requires-dist = [ { name = "optax", specifier = ">0.2.1" }, { name = "pandas", marker = "extra == 'docs'", specifier = ">=1.5.3" }, { name = "pymdown-extensions", marker = "extra == 'docs'", specifier = ">=10.7.1" }, + { name = "rdata", marker = "extra == 'docs'", specifier = ">=1.0.0" }, { name = "scikit-learn", marker = "extra == 'docs'", specifier = ">=1.5.1" }, { name = "seaborn", marker = "extra == 'docs'", specifier = ">=0.12.2" }, { name = "tensorstore", marker = "sys_platform == 'darwin'", specifier = "!=0.1.76" }, @@ -3111,6 +3114,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/01/1b/5dbe84eefc86f48473947e2f41711aded97eecef1231f4558f1f02713c12/pyzmq-27.1.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:c9f7f6e13dff2e44a6afeaf2cf54cee5929ad64afaf4d40b50f93c58fc687355", size = 544862, upload-time = "2025-09-08T23:09:56.509Z" }, ] +[[package]] +name = "rdata" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pandas" }, + { name = "typing-extensions" }, + { name = "xarray" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/e5/c5626257359fde4662f9c2f658bcb37873ffad1490c8991640a361b8cf8e/rdata-1.0.0.tar.gz", hash = "sha256:d924d7657c9eeaee4b86a0ae87999f5beb0070ed47bd491c85bc79ee9644596d", size = 56618, upload-time = "2025-08-15T17:17:10.832Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/b6/aec7624eb1db90e58b7fa34ab6491ee06385020bbb9598d4b8a28051da81/rdata-1.0.0-py3-none-any.whl", hash = "sha256:b671e31676e158dc215297595d5085aee1674b91be3c26e38638ca056167d402", size = 72291, upload-time = "2025-08-15T17:17:09.404Z" }, +] + [[package]] name = "readme-renderer" version = "44.0" @@ -4121,6 +4139,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/0e/fa3b193432cfc60c93b42f3be03365f5f909d2b3ea410295cf36df739e31/widgetsnbextension-4.0.15-py3-none-any.whl", hash = "sha256:8156704e4346a571d9ce73b84bee86a29906c9abfd7223b7228a28899ccf3366", size = 2196503, upload-time = "2025-11-01T21:15:53.565Z" }, ] +[[package]] +name = "xarray" +version = "2026.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/85/113ff1e2cde9e8a5b13c2f0ef4e9f5cd6ca3a036b6452f4dd523419289b5/xarray-2026.1.0.tar.gz", hash = "sha256:0c9814761f9d9a9545df37292d3fda89f83201f3e02ae0f09f03313d9cfdd5e2", size = 3107024, upload-time = "2026-01-28T17:49:03.822Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7d/8e/952a351c10df395d9bab850f611f4368834ae9104d6449049f5a49e00925/xarray-2026.1.0-py3-none-any.whl", hash = "sha256:5fcc03d3ed8dfb662aa254efe6cd65efc70014182bbc2126e4b90d291d970d41", size = 1403009, upload-time = "2026-01-28T17:49:01.538Z" }, +] + [[package]] name = "xdoctest" version = "1.3.0"