Skip to content

Commit 795f004

Browse files
authored
fix: nan in derivs wrt params for Moffat (#194)
* fix: nan in derivs wrt params for Moffat * Apply suggestion from @beckermr * Apply suggestion from @beckermr * Apply suggestion from @beckermr * fix: need safe_sqrt import * fix: put this back since we no longer support trunc for moffat * Apply suggestion from @beckermr * Apply suggestion from @beckermr
1 parent b47e400 commit 795f004

4 files changed

Lines changed: 67 additions & 5 deletions

File tree

jax_galsim/bessel.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,9 @@ def _temme_series_kve(v, z):
251251
z_sq = z * z
252252
logzo2 = jnp.log(z / 2.0)
253253
mu = -v * logzo2
254-
sinc_v = jnp.where(v == 0.0, 1.0, jnp.sin(jnp.pi * v) / (jnp.pi * v))
255-
sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu) / mu)
254+
sinc_v = jnp.sinc(v)
255+
mu_msk = jnp.where(mu == 0.0, 1.0, mu)
256+
sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu_msk) / mu_msk)
256257

257258
initial_f = (coeff1 * jnp.cosh(mu) + coeff2 * (-logzo2) * sinhc_mu) / sinc_v
258259
initial_p = 0.5 * jnp.exp(mu) / gamma1pv_inv
@@ -711,6 +712,7 @@ def t3(x): # x>8
711712
return factor * (rc * (cx + sx) - y * rs * (sx - cx))
712713

713714
x = jnp.abs(x)
715+
x_ = jnp.where(x != 0, x, 1.0)
714716
return jnp.select(
715-
[x == 0, x <= 4, x <= 8, x > 8], [1, t1(x), t2(x), t3(x)], default=x
717+
[x == 0, x <= 4, x <= 8, x > 8], [1, t1(x_), t2(x_), t3(x_)], default=x
716718
).reshape(orig_shape)

jax_galsim/core/math.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import jax
2+
import jax.numpy as jnp
3+
4+
5+
@jax.jit
6+
def safe_sqrt(x):
7+
"""Numerically safe sqrt operation with zero derivative at zero."""
8+
msk = x > 0
9+
x_msk = jnp.where(msk, x, 1.0)
10+
return jnp.where(msk, jnp.sqrt(x_msk), 0.0)

jax_galsim/moffat.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from jax_galsim.bessel import kv
1010
from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
11+
from jax_galsim.core.math import safe_sqrt
1112
from jax_galsim.core.utils import (
1213
ensure_hashable,
1314
has_tracers,
@@ -303,9 +304,12 @@ def _xValue(self, pos):
303304

304305
def _kValue_untrunc(self, k):
305306
"""Non truncated version of _kValue"""
307+
k_msk = jnp.where(k > 0, k, 1.0)
306308
return jnp.where(
307309
k > 0,
308-
self._knorm_bis * jnp.power(k, self.beta - 1.0) * _Knu(self.beta - 1.0, k),
310+
self._knorm_bis
311+
* jnp.power(k_msk, self.beta - 1.0)
312+
* _Knu(self.beta - 1.0, k_msk),
309313
self._knorm,
310314
)
311315

@@ -314,7 +318,7 @@ def _kValue(self, kpos):
314318
"""computation of the Moffat response in k-space with switch of truncated/untracated case
315319
kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image)
316320
"""
317-
k = jnp.sqrt((kpos.x**2 + kpos.y**2) * self._r0_sq)
321+
k = safe_sqrt((kpos.x**2 + kpos.y**2) * self._r0_sq)
318322
out_shape = jnp.shape(k)
319323
k = jnp.atleast_1d(k)
320324
res = self._kValue_untrunc(k)

tests/jax/test_deriv_gsobject.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import jax
2+
import jax.numpy as jnp
3+
import numpy as np
4+
import pytest
5+
6+
import jax_galsim as jgs
7+
8+
9+
@pytest.mark.parametrize(
10+
"gsobj,params,args, kwargs",
11+
[
12+
(jgs.Spergel, ["scale_radius", "half_light_radius"], [1.0], {}),
13+
(jgs.Exponential, ["scale_radius", "half_light_radius"], [], {}),
14+
(jgs.Gaussian, ["sigma", "fwhm", "half_light_radius"], [], {}),
15+
(jgs.Moffat, ["scale_radius", "half_light_radius", "fwhm"], [2.0], {}),
16+
],
17+
)
18+
def test_deriv_gsobject_radii(params, gsobj, args, kwargs):
19+
val = 2.0
20+
eps = 1e-5
21+
22+
for param in params:
23+
print("\nparam:", param, flush=True)
24+
25+
def _run(val_):
26+
kwargs_ = {param: val_}
27+
kwargs_.update(kwargs)
28+
return jnp.max(
29+
gsobj(
30+
*args,
31+
**kwargs_,
32+
gsparams=jgs.GSParams(minimum_fft_size=8, maximum_fft_size=8),
33+
)
34+
.drawImage(nx=5, ny=5, scale=0.2, method="fft")
35+
.array[2, 2]
36+
** 2
37+
)
38+
39+
gfunc = jax.jit(jax.grad(_run))
40+
gval = gfunc(val)
41+
42+
gfdiff = (_run(val + eps) - _run(val - eps)) / 2.0 / eps
43+
44+
atol = 1e-5
45+
46+
np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=atol)

0 commit comments

Comments
 (0)