Skip to content

Commit 1da8021

Browse files
authored
fix: better accuracy lanczos fourier-space interp (#234)
* fix: better accuracy lanczos fourier-space interp * perf: 2.3x is ok * test: put back submodule
1 parent 1e20520 commit 1da8021

2 files changed

Lines changed: 11 additions & 7 deletions

File tree

jax_galsim/interpolant.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,12 +1498,16 @@ def _xval_noraise(self, x):
14981498
# it gets recompiled as needed for combinations of n, conserve_dc, du, and krange
14991499
@functools.partial(jax.jit, static_argnames=("n", "conserve_dc", "du", "krange"))
15001500
def _interp_kval(k, n, conserve_dc, du, krange):
1501-
_idata = _lanczos_kval_interp_table(
1502-
n,
1503-
du,
1504-
krange,
1505-
conserve_dc,
1506-
)
1501+
with jax.ensure_compile_time_eval():
1502+
_idata = _lanczos_kval_interp_table(
1503+
n,
1504+
# jax-galsim uses a slightly less accurate interpolation
1505+
# function (akima vs cubic spline) and so needs a smaller spacing
1506+
# 2.3x appears to be ok
1507+
du / 2.3,
1508+
krange,
1509+
conserve_dc,
1510+
)
15071511
return akima_interp(jnp.abs(k), *_idata, fixed_spacing=True)
15081512

15091513
def _kval_noraise(self, k):

tests/GalSim

0 commit comments

Comments
 (0)