Skip to content

Commit 3c2760b

Browse files
authored
feat: add stop gradient calls for certain cases for beta/nu in Moffat and Spergel plus docs (#199)
* feat: add stop gradient calls for certain cases * Apply suggestions from code review
1 parent 8842c01 commit 3c2760b

3 files changed

Lines changed: 48 additions & 9 deletions

File tree

jax_galsim/moffat.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ def _Knu(nu, x):
2929
@implements(
3030
_galsim.Moffat,
3131
lax_description="""\
32-
The LAX version of the Moffat profile
32+
The JAX-GalSim version of the Moffat profile
3333
3434
- does not support truncation or beta < 1.1
3535
- does not support gsparams.maxk_thresholds > 0.1
36+
- does not support autodiff with respect to the `beta` parameter
37+
for Fourier-space evaluations
3638
""",
3739
)
3840
@register_pytree_node_class
@@ -333,8 +335,10 @@ def _kValue_untrunc_interp_coeffs(self):
333335
k_min = 0
334336
k_max = self._maxk
335337
k = jnp.linspace(k_min, k_max, n_pts)
338+
339+
beta = jax.lax.stop_gradient(self.beta)
336340
vals = self._kValue_untrunc_func(
337-
self.beta,
341+
beta,
338342
k,
339343
self._knorm_bis,
340344
self._knorm,
@@ -343,9 +347,7 @@ def _kValue_untrunc_interp_coeffs(self):
343347

344348
# slope to match the interpolant onto an asymptotic expansion of kv
345349
# that is kv(x) ~ sqrt(pi/2/x) * exp(-x) * (1 + slp/x)
346-
aval = self._kValue_untrunc_asymp_func(
347-
self.beta, k[-1], self._knorm_bis, self._r0
348-
)
350+
aval = self._kValue_untrunc_asymp_func(beta, k[-1], self._knorm_bis, self._r0)
349351
slp = (vals[-1] / aval - 1) * k[-1] * self._r0
350352

351353
return k, vals, akima_interp_coeffs(k, vals), slp
@@ -355,6 +357,9 @@ def _kValue(self, kpos):
355357
"""computation of the Moffat response in k-space with interpolant + expansions
356358
kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image)
357359
"""
360+
# we cannot compute gradients with respect to beta
361+
beta = jax.lax.stop_gradient(self.beta)
362+
358363
k = safe_sqrt(kpos.x**2 + kpos.y**2)
359364
out_shape = jnp.shape(k)
360365
k = jnp.atleast_1d(k)
@@ -364,7 +369,12 @@ def _kValue(self, kpos):
364369
k_msk = jnp.where(k > 0, k, k_[1])
365370
res = jnp.where(
366371
k > k_[-1],
367-
self._kValue_untrunc_asymp_func(self.beta, k_msk, self._knorm_bis, self._r0)
372+
self._kValue_untrunc_asymp_func(
373+
beta,
374+
k_msk,
375+
self._knorm_bis,
376+
self._r0,
377+
)
368378
* (1.0 + slp / k_msk / self._r0),
369379
res,
370380
)

jax_galsim/spergel.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,10 @@ def _spergel_hlr_pade(x):
254254
255255
\hat{I}(k) = flux / (1 + (k r_0)^2)^{1+\nu}
256256
257-
where :math:`r_0` is the ``scale_radius``, and :math: `\nu` mandatory to be in [-0.85,4.0]
257+
where :math:`r_0` is the ``scale_radius``, and :math:`\nu` mandatory to be in [-0.85,4.0]
258+
259+
The JAX-GalSim implementation does not support autodiff with respect to :math:`\nu` for
260+
real-space evaluations.
258261
""",
259262
)
260263
@register_pytree_node_class
@@ -411,7 +414,7 @@ def _max_sb(self):
411414
@jax.jit
412415
def _xValue(self, pos):
413416
r = jnp.sqrt(pos.x**2 + pos.y**2) * self._inv_r0
414-
res = jnp.where(r == 0, self._xnorm0, fz_nu(r, self.nu))
417+
res = jnp.where(r == 0, self._xnorm0, fz_nu(r, jax.lax.stop_gradient(self.nu)))
415418
return self._xnorm * res
416419

417420
@jax.jit

tests/jax/test_deriv_gsobject.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _run(val_):
3131
**kwargs_,
3232
gsparams=jgs.GSParams(minimum_fft_size=8, maximum_fft_size=8),
3333
)
34-
.drawImage(nx=5, ny=5, scale=0.2, method="fft")
34+
.drawImage(nx=5, ny=5, scale=0.2)
3535
.array[2, 2]
3636
** 2
3737
)
@@ -44,3 +44,29 @@ def _run(val_):
4444
atol = 1e-5
4545

4646
np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=atol)
47+
48+
49+
def test_deriv_gsobject_spergel_nu():
50+
val = 2.0
51+
eps = 1e-5
52+
53+
def _run(val_):
54+
return jnp.max(
55+
jgs.Spergel(
56+
val,
57+
scale_radius=2.0,
58+
gsparams=jgs.GSParams(minimum_fft_size=8, maximum_fft_size=8),
59+
)
60+
.drawImage(nx=5, ny=5, scale=0.2)
61+
.array[2, 2]
62+
** 2
63+
)
64+
65+
gfunc = jax.jit(jax.grad(_run))
66+
gval = gfunc(val)
67+
68+
gfdiff = (_run(val + eps) - _run(val - eps)) / 2.0 / eps
69+
70+
atol = 1e-5
71+
72+
np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=atol)

0 commit comments

Comments
 (0)