diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index 9dd49dc1..f7b14acb 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -251,8 +251,9 @@ def _temme_series_kve(v, z): z_sq = z * z logzo2 = jnp.log(z / 2.0) mu = -v * logzo2 - sinc_v = jnp.where(v == 0.0, 1.0, jnp.sin(jnp.pi * v) / (jnp.pi * v)) - sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu) / mu) + sinc_v = jnp.sinc(v) + mu_msk = jnp.where(mu == 0.0, 1.0, mu) + sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu_msk) / mu_msk) initial_f = (coeff1 * jnp.cosh(mu) + coeff2 * (-logzo2) * sinhc_mu) / sinc_v initial_p = 0.5 * jnp.exp(mu) / gamma1pv_inv @@ -711,6 +712,7 @@ def t3(x): # x>8 return factor * (rc * (cx + sx) - y * rs * (sx - cx)) x = jnp.abs(x) + x_ = jnp.where(x != 0, x, 1.0) return jnp.select( - [x == 0, x <= 4, x <= 8, x > 8], [1, t1(x), t2(x), t3(x)], default=x + [x == 0, x <= 4, x <= 8, x > 8], [1, t1(x_), t2(x_), t3(x_)], default=x ).reshape(orig_shape) diff --git a/jax_galsim/core/math.py b/jax_galsim/core/math.py new file mode 100644 index 00000000..e4cde10e --- /dev/null +++ b/jax_galsim/core/math.py @@ -0,0 +1,10 @@ +import jax +import jax.numpy as jnp + + +@jax.jit +def safe_sqrt(x): + """Numerically safe sqrt operation with zero derivative at zero.""" + msk = x > 0 + x_msk = jnp.where(msk, x, 1.0) + return jnp.where(msk, jnp.sqrt(x_msk), 0.0) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 6d5c9dc6..dc9880f9 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -8,6 +8,7 @@ from jax_galsim.bessel import kv from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue +from jax_galsim.core.math import safe_sqrt from jax_galsim.core.utils import ( ensure_hashable, has_tracers, @@ -303,9 +304,12 @@ def _xValue(self, pos): def _kValue_untrunc(self, k): """Non truncated version of _kValue""" + k_msk = jnp.where(k > 0, k, 1.0) return jnp.where( k > 0, - self._knorm_bis * jnp.power(k, self.beta - 1.0) * _Knu(self.beta - 1.0, k), + self._knorm_bis + * jnp.power(k_msk, self.beta - 1.0) + * _Knu(self.beta - 1.0, k_msk), self._knorm, ) @@ -314,7 +318,7 @@ def _kValue(self, kpos): """computation of the Moffat response in k-space with switch of truncated/untracated case kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image) """ - k = jnp.sqrt((kpos.x**2 + kpos.y**2) * self._r0_sq) + k = safe_sqrt((kpos.x**2 + kpos.y**2) * self._r0_sq) out_shape = jnp.shape(k) k = jnp.atleast_1d(k) res = self._kValue_untrunc(k) diff --git a/tests/jax/test_deriv_gsobject.py b/tests/jax/test_deriv_gsobject.py new file mode 100644 index 00000000..73dbd22e --- /dev/null +++ b/tests/jax/test_deriv_gsobject.py @@ -0,0 +1,46 @@ +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +import jax_galsim as jgs + + +@pytest.mark.parametrize( + "gsobj,params,args, kwargs", + [ + (jgs.Spergel, ["scale_radius", "half_light_radius"], [1.0], {}), + (jgs.Exponential, ["scale_radius", "half_light_radius"], [], {}), + (jgs.Gaussian, ["sigma", "fwhm", "half_light_radius"], [], {}), + (jgs.Moffat, ["scale_radius", "half_light_radius", "fwhm"], [2.0], {}), + ], +) +def test_deriv_gsobject_radii(params, gsobj, args, kwargs): + val = 2.0 + eps = 1e-5 + + for param in params: + print("\nparam:", param, flush=True) + + def _run(val_): + kwargs_ = {param: val_} + kwargs_.update(kwargs) + return jnp.max( + gsobj( + *args, + **kwargs_, + gsparams=jgs.GSParams(minimum_fft_size=8, maximum_fft_size=8), + ) + .drawImage(nx=5, ny=5, scale=0.2, method="fft") + .array[2, 2] + ** 2 + ) + + gfunc = jax.jit(jax.grad(_run)) + gval = gfunc(val) + + gfdiff = (_run(val + eps) - _run(val - eps)) / 2.0 / eps + + atol = 1e-5 + + np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=atol)