Skip to content

Commit 89b9d56

Browse files
authored
fix: use adaptive precision for moffat computation (#181)
* fix: use adaptive precision for moffat computation * feat: add pytest mark and option to run specific float32 tests * restart tests
1 parent 9fb8019 commit 89b9d56

4 files changed

Lines changed: 70 additions & 5 deletions

File tree

.github/workflows/python_package.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,13 @@ jobs:
7474
--clean-durations \
7575
--retries 1
7676
77+
- name: Test with pytest in float32
78+
if: ${{ matrix.group == '1' }}
79+
run: |
80+
pytest \
81+
-vv \
82+
--test-in-float32
83+
7784
- name: Upload test durations
7885
uses: actions/upload-artifact@v6
7986
with:

jax_galsim/bessel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def si(x):
174174

175175
def _sqrt1px2(x):
176176
"""Numerically stable computation of sqrt(1 + x^2)."""
177-
eps = jnp.finfo(jnp.float64).eps
177+
eps = jnp.finfo(x.dtype).eps
178178
return jnp.where(
179179
jnp.abs(x) * jnp.sqrt(eps) <= 1.0,
180180
jnp.exp(0.5 * jnp.log1p(x * x)),
@@ -244,7 +244,7 @@ def _temme_series_kve(v, z):
244244
Assumes |v| < 0.5 and |z| <= 2 for fast convergence.
245245
Returns exponentially scaled values: Kv(v,z)*exp(z).
246246
"""
247-
tol = jnp.finfo(jnp.float64).eps
247+
tol = jnp.finfo(z.dtype).eps
248248

249249
coeff1, coeff2, gamma1pv_inv, gamma1mv_inv = _evaluate_temme_coeffs(v)
250250

@@ -308,7 +308,7 @@ def _continued_fraction_kve(v, z):
308308
Assumes |v| < 0.5 and |z| > 2.
309309
Returns exponentially scaled values: Kv(v,z)*exp(z).
310310
"""
311-
tol = jnp.finfo(jnp.float64).eps
311+
tol = jnp.finfo(z.dtype).eps
312312
max_iterations = 1000
313313

314314
initial_numerator = v * v - 0.25

tests/conftest.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Define the accuracy for running the tests
2+
import sys
3+
24
import jax
35

4-
jax.config.update("jax_enable_x64", True)
6+
if "--test-in-float32" not in sys.argv:
7+
jax.config.update("jax_enable_x64", True)
58

69
import inspect # noqa: E402
710
import os # noqa: E402
8-
import sys # noqa: E402
911
from functools import lru_cache, partial # noqa: E402
1012
from unittest.mock import patch # noqa: E402
1113

@@ -216,3 +218,29 @@ def pytest_runtest_logreport(report):
216218
report.outcome = "allowed failure"
217219

218220
yield report
221+
222+
223+
def pytest_addoption(parser):
224+
parser.addoption(
225+
"--test-in-float32",
226+
action="store_true",
227+
default=False,
228+
help="Run tests in float32 instead of float64.",
229+
)
230+
231+
232+
def pytest_configure(config):
233+
config.addinivalue_line(
234+
"markers", "test_in_float32: mark test to run in float32 tests."
235+
)
236+
237+
238+
def pytest_report_header(config):
239+
return [f"JAX float64 enabled: {jax.config.read('jax_enable_x64')}"]
240+
241+
242+
def pytest_runtest_setup(item):
243+
if "test_in_float32" not in item.keywords and item.config.getoption(
244+
"--test-in-float32"
245+
):
246+
pytest.skip("Skipped test not marked to be run in float32.")

tests/jax/test_moffat_comp_galsim.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import galsim as _galsim
2+
import jax.numpy as jnp
23
import numpy as np
4+
import pytest
35

46
import jax_galsim as galsim
57

@@ -54,3 +56,31 @@ def test_moffat_comp_galsim_maxk():
5456
psf.kValue(1.0, 0.0), gpsf.kValue(1.0, 0.0), rtol=1e-5
5557
)
5658
np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.25, atol=0)
59+
60+
61+
@pytest.mark.test_in_float32
62+
def test_moffat_conv_nan_float32():
63+
# test case from https://github.com/GalSim-developers/JAX-GalSim/issues/179
64+
gal_flux = 1.0e5 # counts
65+
gal_r0 = 2.7 # arcsec
66+
g1 = 0.1 #
67+
g2 = 0.2 #
68+
psf_beta = 5 #
69+
psf_re = 1.0 # arcsec
70+
pixel_scale = 0.2 # arcsec / pixel
71+
72+
# Define the galaxy profile.
73+
gal = galsim.Exponential(flux=gal_flux, scale_radius=gal_r0)
74+
75+
# Shear the galaxy by some value.
76+
gal = gal.shear(g1=g1, g2=g2)
77+
78+
# Define the PSF profile.
79+
psf = galsim.Moffat(beta=psf_beta, flux=1.0, half_light_radius=psf_re)
80+
81+
# Final profile is the convolution of these.
82+
final = galsim.Convolve([gal, psf])
83+
84+
img_arr = final.drawImage(scale=pixel_scale, dtype=jnp.float32).array
85+
86+
assert jnp.all(jnp.isfinite(img_arr))

0 commit comments

Comments
 (0)