diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 23e5588a..22d4e16e 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -29,10 +29,12 @@ def _Knu(nu, x): @implements( _galsim.Moffat, lax_description="""\ -The LAX version of the Moffat profile +The JAX-GalSim version of the Moffat profile - does not support truncation or beta < 1.1 - does not support gsparams.maxk_thresholds > 0.1 +- does not support autodiff with respect to the `beta` parameter + for Fourier-space evaluations """, ) @register_pytree_node_class @@ -333,8 +335,10 @@ def _kValue_untrunc_interp_coeffs(self): k_min = 0 k_max = self._maxk k = jnp.linspace(k_min, k_max, n_pts) + + beta = jax.lax.stop_gradient(self.beta) vals = self._kValue_untrunc_func( - self.beta, + beta, k, self._knorm_bis, self._knorm, @@ -343,9 +347,7 @@ def _kValue_untrunc_interp_coeffs(self): # slope to match the interpolant onto an asymptotic expansion of kv # that is kv(x) ~ sqrt(pi/2/x) * exp(-x) * (1 + slp/x) - aval = self._kValue_untrunc_asymp_func( - self.beta, k[-1], self._knorm_bis, self._r0 - ) + aval = self._kValue_untrunc_asymp_func(beta, k[-1], self._knorm_bis, self._r0) slp = (vals[-1] / aval - 1) * k[-1] * self._r0 return k, vals, akima_interp_coeffs(k, vals), slp @@ -355,6 +357,9 @@ def _kValue(self, kpos): """computation of the Moffat response in k-space with interpolant + expansions kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image) """ + # we cannot compute gradients with respect to beta + beta = jax.lax.stop_gradient(self.beta) + k = safe_sqrt(kpos.x**2 + kpos.y**2) out_shape = jnp.shape(k) k = jnp.atleast_1d(k) @@ -364,7 +369,12 @@ def _kValue(self, kpos): k_msk = jnp.where(k > 0, k, k_[1]) res = jnp.where( k > k_[-1], - self._kValue_untrunc_asymp_func(self.beta, k_msk, self._knorm_bis, self._r0) + self._kValue_untrunc_asymp_func( + beta, + k_msk, + self._knorm_bis, + self._r0, + ) * (1.0 + slp / k_msk / self._r0), res, ) diff --git a/jax_galsim/spergel.py b/jax_galsim/spergel.py index b0bf22a8..9ffd9e86 100644 --- a/jax_galsim/spergel.py +++ b/jax_galsim/spergel.py @@ -254,7 +254,10 @@ def _spergel_hlr_pade(x): \hat{I}(k) = flux / (1 + (k r_0)^2)^{1+\nu} - where :math:`r_0` is the ``scale_radius``, and :math: `\nu` mandatory to be in [-0.85,4.0] + where :math:`r_0` is the ``scale_radius``, and :math:`\nu` mandatory to be in [-0.85,4.0] + + The JAX-GalSim implementation does not support autodiff with respect to :math:`\nu` for + real-space evaluations. """, ) @register_pytree_node_class @@ -411,7 +414,7 @@ def _max_sb(self): @jax.jit def _xValue(self, pos): r = jnp.sqrt(pos.x**2 + pos.y**2) * self._inv_r0 - res = jnp.where(r == 0, self._xnorm0, fz_nu(r, self.nu)) + res = jnp.where(r == 0, self._xnorm0, fz_nu(r, jax.lax.stop_gradient(self.nu))) return self._xnorm * res @jax.jit diff --git a/tests/jax/test_deriv_gsobject.py b/tests/jax/test_deriv_gsobject.py index 73dbd22e..1c24beb9 100644 --- a/tests/jax/test_deriv_gsobject.py +++ b/tests/jax/test_deriv_gsobject.py @@ -31,7 +31,7 @@ def _run(val_): **kwargs_, gsparams=jgs.GSParams(minimum_fft_size=8, maximum_fft_size=8), ) - .drawImage(nx=5, ny=5, scale=0.2, method="fft") + .drawImage(nx=5, ny=5, scale=0.2) .array[2, 2] ** 2 ) @@ -44,3 +44,29 @@ def _run(val_): atol = 1e-5 np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=atol) + + +def test_deriv_gsobject_spergel_nu(): + val = 2.0 + eps = 1e-5 + + def _run(val_): + return jnp.max( + jgs.Spergel( + val, + scale_radius=2.0, + gsparams=jgs.GSParams(minimum_fft_size=8, maximum_fft_size=8), + ) + .drawImage(nx=5, ny=5, scale=0.2) + .array[2, 2] + ** 2 + ) + + gfunc = jax.jit(jax.grad(_run)) + gval = gfunc(val) + + gfdiff = (_run(val + eps) - _run(val - eps)) / 2.0 / eps + + atol = 1e-5 + + np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=atol)