Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions jax_galsim/moffat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ def _Knu(nu, x):
@implements(
_galsim.Moffat,
lax_description="""\
The LAX version of the Moffat profile
The JAX-GalSim version of the Moffat profile

- does not support truncation or beta < 1.1
- does not support gsparams.maxk_thresholds > 0.1
- does not support autodiff with respect to the `beta` parameter
for Fourier-space evaluations
""",
)
@register_pytree_node_class
Expand Down Expand Up @@ -333,8 +335,10 @@ def _kValue_untrunc_interp_coeffs(self):
k_min = 0
k_max = self._maxk
k = jnp.linspace(k_min, k_max, n_pts)

beta = jax.lax.stop_gradient(self.beta)
vals = self._kValue_untrunc_func(
self.beta,
beta,
k,
self._knorm_bis,
self._knorm,
Expand All @@ -343,9 +347,7 @@ def _kValue_untrunc_interp_coeffs(self):

# slope to match the interpolant onto an asymptotic expansion of kv
# that is kv(x) ~ sqrt(pi/2/x) * exp(-x) * (1 + slp/x)
aval = self._kValue_untrunc_asymp_func(
self.beta, k[-1], self._knorm_bis, self._r0
)
aval = self._kValue_untrunc_asymp_func(beta, k[-1], self._knorm_bis, self._r0)
slp = (vals[-1] / aval - 1) * k[-1] * self._r0

return k, vals, akima_interp_coeffs(k, vals), slp
Expand All @@ -355,6 +357,9 @@ def _kValue(self, kpos):
"""computation of the Moffat response in k-space with interpolant + expansions
kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image)
"""
# we cannot compute gradients with respect to beta
beta = jax.lax.stop_gradient(self.beta)

k = safe_sqrt(kpos.x**2 + kpos.y**2)
out_shape = jnp.shape(k)
k = jnp.atleast_1d(k)
Expand All @@ -364,7 +369,12 @@ def _kValue(self, kpos):
k_msk = jnp.where(k > 0, k, k_[1])
res = jnp.where(
k > k_[-1],
self._kValue_untrunc_asymp_func(self.beta, k_msk, self._knorm_bis, self._r0)
self._kValue_untrunc_asymp_func(
beta,
k_msk,
self._knorm_bis,
self._r0,
)
* (1.0 + slp / k_msk / self._r0),
res,
)
Expand Down
7 changes: 5 additions & 2 deletions jax_galsim/spergel.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,10 @@ def _spergel_hlr_pade(x):

\hat{I}(k) = flux / (1 + (k r_0)^2)^{1+\nu}

where :math:`r_0` is the ``scale_radius``, and :math: `\nu` mandatory to be in [-0.85,4.0]
where :math:`r_0` is the ``scale_radius``, and :math:`\nu` mandatory to be in [-0.85,4.0]

The JAX-GalSim implementation does not support autodiff with respect to :math:`\nu` for
real-space evaluations.
""",
)
@register_pytree_node_class
Expand Down Expand Up @@ -411,7 +414,7 @@ def _max_sb(self):
@jax.jit
def _xValue(self, pos):
r = jnp.sqrt(pos.x**2 + pos.y**2) * self._inv_r0
res = jnp.where(r == 0, self._xnorm0, fz_nu(r, self.nu))
res = jnp.where(r == 0, self._xnorm0, fz_nu(r, jax.lax.stop_gradient(self.nu)))
return self._xnorm * res

@jax.jit
Expand Down
28 changes: 27 additions & 1 deletion tests/jax/test_deriv_gsobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _run(val_):
**kwargs_,
gsparams=jgs.GSParams(minimum_fft_size=8, maximum_fft_size=8),
)
.drawImage(nx=5, ny=5, scale=0.2, method="fft")
.drawImage(nx=5, ny=5, scale=0.2)
.array[2, 2]
** 2
)
Expand All @@ -44,3 +44,29 @@ def _run(val_):
atol = 1e-5

np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=atol)


def test_deriv_gsobject_spergel_nu():
val = 2.0
eps = 1e-5

def _run(val_):
return jnp.max(
jgs.Spergel(
val,
scale_radius=2.0,
gsparams=jgs.GSParams(minimum_fft_size=8, maximum_fft_size=8),
)
.drawImage(nx=5, ny=5, scale=0.2)
.array[2, 2]
** 2
)

gfunc = jax.jit(jax.grad(_run))
gval = gfunc(val)

gfdiff = (_run(val + eps) - _run(val - eps)) / 2.0 / eps

atol = 1e-5

np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=atol)