🚛 The old repository of this project (ndimsplinejax@nmoteki) has moved here because my old GitHub account, nmoteki, is no longer maintained.
An interpolant is an efficiently computable mathematical function that models a discrete dataset. Interpolants are indispensable for incorporating observational data into physical simulations or statistical inference without appreciable bias. This contrasts with regression models (e.g., multilayer perceptrons), which almost always suffer from under- or over-fitting to some extent.
Many interpolation codes are available; however, when I started this project in mid-2022, I could not find any multidimensional interpolant compatible with both just-in-time (JIT) compilation and automatic differentiation. My research needed such an interpolant to apply a recent Hamiltonian Monte Carlo code to a Bayesian inverse problem in which the forward model is accessible only through a pre-computed discrete look-up table. In that case, the forward model, a light-scattering simulator for nonspherical particles, is too computationally expensive to run in place. So I developed NdimSpline_JAX, and I am sharing the code in the hope that it is useful to scientists and engineers.
- `compute_coefs` computes the natural-cubic spline coefficients from scalar y data on an N-dimensional Cartesian grid, using a separable tensor-product approach with $O(N \cdot M^{N+1})$ complexity.
- `make_interpolant` creates a JIT- and Autograd-compatible interpolant that uses localized evaluation (only $4^N$ coefficients per query point).
- On each axis, the x grid points must be equidistant. The grid interval may differ between axes.
- Works for any number of dimensions N (no hard-coded dimension limit).
- The requirement of equidistant grid points on each axis should not be a serious limitation in practice: non-equidistant gridded data can be mapped, exactly or approximately, onto an equidistant grid by a mathematical transformation of each variable.
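As an illustration of the transformation idea, the sketch below checks whether a transformed axis is equidistant and, if so, returns the bounds and number of intervals to use for that axis. The helper `equidistant_axis` and the log-spaced example grid are hypothetical; they are not part of NdimSpline_JAX.

```python
import numpy as np

def equidistant_axis(grid, transform=np.log, rtol=1e-9):
    """Return (a, b, n) for transform(grid) if it is equidistant, else raise ValueError.

    Hypothetical helper: `grid` holds the original (non-equidistant) points of one axis.
    """
    u = transform(np.asarray(grid, dtype=float))
    h = np.diff(u)                                   # spacing in the transformed variable
    if not np.allclose(h, h[0], rtol=rtol, atol=0.0):
        raise ValueError("transformed grid is not equidistant")
    return float(u[0]), float(u[-1]), len(u) - 1

# Example: an axis tabulated on a log-spaced grid is equidistant in log(d)
d_grid = np.geomspace(1e-2, 1e1, 31)
a_u, b_u, n_u = equidistant_axis(d_grid)             # bounds and intervals for the log(d) axis
print(n_u)  # 30
```

Queries must then be mapped the same way before evaluation (e.g., evaluate at `jnp.log(d)` rather than `d`). Using `jnp` for the map keeps it traceable by JIT and autodiff, so gradients with respect to `d` follow from the chain rule.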
Benchmarked on 5D data with n = [10, 10, 10, 10, 10] (11 grid points per axis), CPU (WSL2), float64. All JIT timings are post-warmup averages over 1000 calls.
| Operation | v0.1.2 (old) | v1.0.1 (new) | Speedup |
|---|---|---|---|
| Coefficient computation | 96.6 s | 0.61 s | 158x |
| JIT eval | 3.1 ms | 106 µs | 30x |
| JIT grad | 113.5 ms | 111 µs | 1,022x |
| JIT value_and_grad | 113.4 ms | 108 µs | 1,046x |
| Operation | v0.1.2 (old) | v1.0.1 (new) |
|---|---|---|
| Coefficient computation (peak memory) | 16.8 MB | 1.6 MB |
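The JIT timings above are averages taken after a warm-up call. The sketch below shows one way to take such a measurement; `mean_call_time` is a hypothetical helper (not part of NdimSpline_JAX), and the jitted `f` is only a stand-in, to be replaced by the compiled interpolant or its gradient. Calling `jax.block_until_ready` matters because JAX dispatches work asynchronously, so without it the loop would mostly time the dispatch.

```python
import time

import jax
import jax.numpy as jnp

def mean_call_time(f, *args, n_calls=1000):
    """Average wall-clock seconds per call of f(*args), excluding compilation."""
    jax.block_until_ready(f(*args))          # warm-up call: triggers JIT compilation
    start = time.perf_counter()
    for _ in range(n_calls):
        jax.block_until_ready(f(*args))      # wait until the result is actually computed
    return (time.perf_counter() - start) / n_calls

# Stand-in target; substitute the compiled interpolant or its gradient here
f = jax.jit(jax.value_and_grad(lambda x: jnp.sum(jnp.sin(x) ** 2)))
x = jnp.array([0.7, 1.0, 1.5, 2.0, 2.5])
print(f"{mean_call_time(f, x) * 1e6:.1f} µs per call")
```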
Key improvements:
- Coefficient computation: a separable tensor-product TDMA (tridiagonal matrix algorithm) solver replaces the dense `scipy.linalg.solve`, reducing the cost from $O(M^{3N})$ to $O(N \cdot M^{N+1})$.
- Evaluation / gradient: a localized `dynamic_slice` extracts only $4^5 = 1{,}024$ coefficients per query ($4^N$ in general) instead of scanning all $(n+3)^5 = 371{,}293$ entries.
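To illustrate why only $4^N$ coefficients are needed per query, here is a minimal JAX sketch of localized tensor-product evaluation. It assumes the coefficient layout of the Habermann and Kindermann (2007) B-spline formulation as we understand it: $n_j+3$ coefficients per axis and the cubic B-spline kernel $u(t)$ with support $|t|<2$, so at most four basis functions per axis are nonzero at any point. The function names are hypothetical, and this is not the code inside `make_interpolant`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64, as in the benchmark above

def bspline_kernel(t):
    """Cubic B-spline kernel u(t): 4 - 6|t|^2 + 3|t|^3 for |t| <= 1, (2 - |t|)^3 for 1 < |t| < 2, else 0."""
    at = jnp.abs(t)
    return jnp.where(at <= 1.0, 4.0 - 6.0 * at**2 + 3.0 * at**3,
                     jnp.where(at < 2.0, (2.0 - at) ** 3, 0.0))

def eval_local(x, a, b, n, c):
    """Evaluate the tensor-product spline at x using only a 4 x ... x 4 window of c.

    Assumed layout: c.shape == (n[0] + 3, ..., n[N-1] + 3), and the basis function with
    0-based index k on axis j is u(t_j + 1 - k), where t_j = (x_j - a_j) / h_j.
    """
    starts, weights = [], []
    for j in range(c.ndim):
        h = (b[j] - a[j]) / n[j]                                   # grid interval on axis j
        t = (x[j] - a[j]) / h                                      # position in units of h
        l = jnp.clip(jnp.floor(t), 0, n[j] - 1).astype(jnp.int32)  # cell index
        f = t - l                                                  # offset within the cell, in [0, 1]
        starts.append(l)                                           # window covers k = l, ..., l + 3
        weights.append(bspline_kernel(f + 1.0 - jnp.arange(4.0)))  # u(f+1), u(f), u(f-1), u(f-2)
    window = jax.lax.dynamic_slice(c, starts, (4,) * c.ndim)      # only 4^N coefficients
    for w in weights:                                              # contract one axis at a time
        window = jnp.tensordot(w, window, axes=1)
    return window

# Usage on the 5-D benchmark grid (all coefficients set to 1 for illustration)
a, b, n = [0, 0, 0, 0, 0], [1, 2, 3, 4, 5], [10, 10, 10, 10, 10]
c = jnp.ones([nj + 3 for nj in n])                                 # 13^5 = 371,293 entries
x = jnp.array([0.7, 1.0, 1.5, 2.0, 2.5])
print(eval_local(x, a, b, n, c))                                   # ≈ 6^5 = 7776
print(jax.jit(jax.grad(lambda z: eval_local(z, a, b, n, c)))(x))   # ≈ 0 (constant function)
```

With all coefficients equal to 1, the sketch returns $6^N$ everywhere in the domain because the four kernel weights on each axis always sum to 6, which makes a convenient sanity check.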
The following shows how to use the modules on your local computer. The author tested the code with Python 3.12.8 on a Windows 11 machine and Python 3.12.9 on WSL (Ubuntu).
- A Python >= 3.12 environment on Linux, macOS, or WSL2 on Windows
- The `jax` module, and optionally the `ipykernel` module if you run the Jupyter Notebook files
```sh
git clone https://github.com/NobuhiroMoteki/NdimSpline_JAX.git
```

For a step-by-step environment setup on Windows (WSL2 + Ubuntu) and native Linux (Ubuntu) machines, see installation_guide.md.
Here is the workflow for an example with a 5-dimensional x-space (N = 5):
- Define the grid information and prepare observation data.

```python
import numpy as np

a = [0, 0, 0, 0, 0]        # lower bounds for each dimension
b = [1, 2, 3, 4, 5]        # upper bounds for each dimension
n = [10, 10, 10, 10, 10]   # number of grid intervals per dimension
N = len(a)

# Generate gridded data (replace with your own data in actual use)
grids = [np.linspace(a[j], b[j], n[j] + 1) for j in range(N)]
mesh = np.meshgrid(*grids, indexing="ij")
y_data = np.ones_like(mesh[0])
for j in range(N):
    y_data *= np.sin(mesh[j])
```
- Compute spline coefficients and create the interpolant.

```python
import jax.numpy as jnp
from ndim_spline_jax import compute_coefs, make_interpolant

c = compute_coefs(N, jnp.array(y_data))
s = make_interpolant(a, b, n, c)
```
- Evaluate, differentiate, and JIT-compile.

```python
from jax import jit, grad, value_and_grad

x = jnp.array([0.7, 1.0, 1.5, 2.0, 2.5])  # must satisfy a <= x <= b

print(s(x))                    # evaluate
print(grad(s)(x))              # gradient
print(value_and_grad(s)(x))    # value and gradient

s_jit = jit(s)                 # JIT-compiled (much faster after warm-up)
print(s_jit(x))
print(jit(grad(s))(x))
print(jit(value_and_grad(s))(x))
```
To run this example, open caller.ipynb in Jupyter Notebook or execute the caller.py script.
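Judging from step 1 of the workflow, `y_data` is expected as an array of shape `(n[0] + 1, ..., n[N-1] + 1)` laid out as `np.meshgrid(..., indexing="ij")` produces it. If a pre-computed look-up table is stored instead as flat rows `(x_1, ..., x_N, y)` in arbitrary order, a sketch like the following can place the values onto that grid. `table_to_grid` is a hypothetical helper, not part of the package.

```python
import numpy as np

def table_to_grid(X, y, a, b, n, atol=1e-9):
    """Place flat look-up-table rows onto an 'ij'-indexed grid array.

    X : (P, N) grid coordinates, y : (P,) values, rows in any order.
    Returns an array of shape (n[0] + 1, ..., n[N-1] + 1).
    """
    X = np.asarray(X, dtype=float)
    shape = tuple(nj + 1 for nj in n)
    if X.shape[0] != np.prod(shape):
        raise ValueError("table size does not match the grid")
    idx = []
    for j in range(X.shape[1]):
        h = (b[j] - a[j]) / n[j]                      # grid interval on axis j
        k = np.rint((X[:, j] - a[j]) / h)             # nearest grid index
        if not np.allclose(a[j] + k * h, X[:, j], atol=atol) or k.min() < 0 or k.max() > n[j]:
            raise ValueError(f"axis {j}: coordinates are not on the equidistant grid")
        idx.append(k.astype(int))
    out = np.full(shape, np.nan)
    out[tuple(idx)] = y                               # scatter values to their grid nodes
    if np.isnan(out).any():
        raise ValueError("some grid nodes are missing (or duplicated) in the table")
    return out

# Round trip on a small 3-D example: flatten a gridded array, shuffle the rows, rebuild it
a, b, n = [0, 0, 0], [1, 2, 3], [4, 5, 6]
grids = [np.linspace(a[j], b[j], n[j] + 1) for j in range(3)]
mesh = np.meshgrid(*grids, indexing="ij")
y_grid = mesh[0] + 10 * mesh[1] + 100 * mesh[2]
X = np.stack([m.ravel() for m in mesh], axis=1)
perm = np.random.default_rng(0).permutation(X.shape[0])
rebuilt = table_to_grid(X[perm], y_grid.ravel()[perm], a, b, n)
print(np.allclose(rebuilt, y_grid))  # True
```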
For a detailed description of the mathematical theory (B-spline formulation, tridiagonal system, Kronecker factorization, localized evaluation) and the mapping to the implementation, see docs/theory_note.pdf (source: docs/theory_note.tex).
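For a concrete picture of two ingredients named in the theory note, the NumPy sketch below shows a tridiagonal (TDMA / Thomas) solve for the 1D natural-cubic-spline coefficients, and the Kronecker factorization that turns the N-D problem into one batched 1D solve per axis. It assumes the B-spline layout of Habermann and Kindermann (2007) as we understand it ($n+3$ coefficients per axis; the natural end conditions reduce the boundary rows to $6c = y$). `compute_coefs_sketch` and its helpers are hypothetical and are not the package's `compute_coefs`.

```python
import numpy as np

def thomas_solve(lower, diag, upper, rhs):
    """Thomas algorithm (TDMA) for a tridiagonal system, batched over trailing axes of rhs.

    Row i reads lower[i]*x[i-1] + diag[i]*x[i] + upper[i]*x[i+1] = rhs[i];
    lower[0] and upper[-1] are ignored. No pivoting (fine for diagonally dominant rows).
    """
    m = len(diag)
    cp = np.zeros(m)                       # modified super-diagonal
    dp = np.array(rhs, dtype=float)        # modified right-hand side (a copy)
    cp[0] = upper[0] / diag[0]
    dp[0] = dp[0] / diag[0]
    for i in range(1, m):                  # forward elimination
        denom = diag[i] - lower[i] * cp[i - 1]
        cp[i] = upper[i] / denom if i < m - 1 else 0.0
        dp[i] = (dp[i] - lower[i] * dp[i - 1]) / denom
    for i in range(m - 2, -1, -1):         # back substitution, in place
        dp[i] = dp[i] - cp[i] * dp[i + 1]
    return dp

def natural_coefs_along_axis(y, axis):
    """Map n+1 samples to n+3 B-spline coefficients along one axis (natural end conditions)."""
    v = np.moveaxis(np.asarray(y, dtype=float), axis, 0)
    m = v.shape[0]                         # n + 1 grid points on this axis
    lower, diag, upper = np.ones(m), np.full(m, 4.0), np.ones(m)
    # Interior rows: c[i] + 4 c[i+1] + c[i+2] = y[i] (0-based c of length n+3).
    # At each end, s'' = 0 gives c[0] = 2 c[1] - c[2], so that row reduces to 6 c[1] = y[0]
    # (and symmetrically 6 c[n+1] = y[n]).
    diag[0] = diag[-1] = 6.0
    upper[0] = lower[-1] = 0.0
    inner = thomas_solve(lower, diag, upper, v)   # c[1], ..., c[n+1]
    first = 2.0 * inner[:1] - inner[1:2]          # c[0]
    last = 2.0 * inner[-1:] - inner[-2:-1]        # c[n+2]
    return np.moveaxis(np.concatenate([first, inner, last], axis=0), 0, axis)

def compute_coefs_sketch(y):
    """Kronecker factorization: apply the 1D coefficient map along each axis in turn."""
    c = np.asarray(y, dtype=float)
    for axis in range(c.ndim):
        c = natural_coefs_along_axis(c, axis)
    return c

y = np.random.default_rng(0).standard_normal((6, 8, 5))  # n_j + 1 samples per axis
c = compute_coefs_sketch(y)
print(c.shape)  # (8, 10, 7): n_j + 3 coefficients per axis
```

Because the N-D coefficient map is a Kronecker product of 1D maps, applying them axis by axis gives the same result as a dense solve of the full system, while each pass only works along one axis of the array.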
The ./jupyter_notebooks subfolder contains .ipynb files for the individual dimensional cases, which may help users understand or customize the code.
- Maths of multidimensional natural-cubic spline interpolation: Habermann and Kindermann 2007, Multidimensional Spline Interpolation: Theory and Applications, DOI: 10.1007/s10614-007-9092-4.
- Google/JAX reference documentation: https://jax.readthedocs.io/en/latest/
- An introduction to Google/JAX for scientists (in Japanese): https://github.com/HajimeKawahara/playjax
Distributed under the MIT License. See LICENSE.txt for more information.
Nobuhiro Moteki - nobuhiro.moteki@gmail.com
This project was conceived and carried out as part of N. Moteki's research on atmospheric chemical composition at the NOAA Earth System Science Laboratory, supported by JSPS KAKENHI grant 19KK0289.