diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index a3edbed5..d652e8a9 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -19,7 +19,6 @@ from jax_galsim import fits from jax_galsim.bounds import BoundsI from jax_galsim.core.utils import ( - compute_major_minor_from_jacobian, ensure_hashable, implements, ) @@ -175,8 +174,6 @@ def __init__( depixelize=depixelize, offset=offset, gsparams=GSParams.check(gsparams), - _force_stepk=_force_stepk, - _force_maxk=_force_maxk, hdu=hdu, _recenter_image=_recenter_image, ) @@ -222,14 +219,19 @@ def _maxk(self): if self._jax_aux_data["_force_maxk"] > 0: return self._jax_aux_data["_force_maxk"] else: - return super()._maxk + # galsim uses a different way to handle the WCS effects on maxk + # for interpolated images. IDK why. - MRB + return self._original.maxk / self._original._wcs._maxScale() @property def _stepk(self): if self._jax_aux_data["_force_stepk"] > 0: return self._jax_aux_data["_force_stepk"] else: - return super()._stepk + # galsim uses a different way to handle the WCS effects on stepk + # for interpolated images. IDK why. - MRB + # super()._stepk + return self._original.stepk / self._original._wcs._minScale() @property @implements(_galsim.interpolatedimage.InterpolatedImage.x_interpolant) @@ -367,6 +369,16 @@ def _zeropad_image(arr, npad): @register_pytree_node_class class _InterpolatedImageImpl(GSObject): + """Internal class for handling interpolated images. + + An interpolated image carries an intrinsic WCS with it that can be anything + from a pixel-scale to a full Jacobian. + + We use this internal class to separate the underlying image bits from the + WCS handling bits. For those, we inherit from the Transform class so that + we can reuse its methods. + """ + _cache_noise_pad = {} _has_hard_edges = False @@ -395,15 +407,12 @@ def __init__( depixelize=False, offset=None, gsparams=None, - _force_stepk=0.0, - _force_maxk=0.0, hdu=None, _recenter_image=True, ): # this class does a ton of munging of the inputs that I don't want to reconstruct when # flattening and unflattening the class. # thus I am going to make some refs here so we have it when we need it - self._workspace = {} self._jax_children = ( image, dict( @@ -428,8 +437,6 @@ def __init__( use_true_center=use_true_center, depixelize=depixelize, gsparams=gsparams, - _force_stepk=_force_stepk, - _force_maxk=_force_maxk, _recenter_image=_recenter_image, hdu=hdu, ) @@ -530,13 +537,10 @@ def tree_unflatten(cls, aux_data, children): return ret def __getstate__(self): - d = self.__dict__.copy() - d.pop("_workspace") - return d + return self.__dict__.copy() def __setstate__(self, d): self.__dict__ = d - self._workspace = {} @property def x_interpolant(self): @@ -683,31 +687,15 @@ def _kim(self): @property def _maxk(self): - if self._jax_aux_data["_force_maxk"]: - _, minor = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2))) - return self._jax_aux_data["_force_maxk"] * minor - else: - return self._getMaxK(self._jax_aux_data["calculate_maxk"]) + return self._getMaxK(self._jax_aux_data["calculate_maxk"]) @property def _stepk(self): - if self._jax_aux_data["_force_stepk"]: - _, minor = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2))) - return self._jax_aux_data["_force_stepk"] * minor - else: - return self._getStepK(self._jax_aux_data["calculate_stepk"]) + return self._getStepK(self._jax_aux_data["calculate_stepk"]) def _getStepK(self, calculate_stepk): # GalSim cannot automatically know what stepK and maxK are appropriate for the # input image. So it is usually worth it to do a manual calculation (below). - # - # However, there is also a hidden option to force it to use specific values of stepK and - # maxK (caveat user!). The values of _force_stepk and _force_maxk should be provided in - # terms of physical scale, e.g., for images that have a scale length of 0.1 arcsec, the - # stepK and maxK should be provided in units of 1/arcsec. Then we convert to the 1/pixel - # units required by the C++ layer below. Also note that profile recentering for even-sized - # images (see the ._adjust_offset step below) leads to automatic reduction of stepK slightly - # below what is provided here, while maxK is preserved. if calculate_stepk: if calculate_stepk is True: im = self.image @@ -1173,9 +1161,9 @@ def _flux_frac(a, x, y, cenx, ceny): dx = jnp.reshape(dx, (a.shape[0], a.shape[1], 1)) dy = y - ceny dy = jnp.reshape(dy, (a.shape[0], a.shape[1], 1)) - d = jnp.arange(a.shape[0]) + d = jnp.arange(min(a.shape[0], a.shape[1])) d = jnp.reshape(d, (1, 1, -1)) - msk = (jnp.abs(dx) <= d) & (jnp.abs(dx) <= d) + msk = (jnp.abs(dx) <= d) & (jnp.abs(dy) <= d) res = jnp.sum( jnp.where( msk, @@ -1184,7 +1172,6 @@ def _flux_frac(a, x, y, cenx, ceny): ), axis=(0, 1), ) - res = jnp.where(res > 0, res, -jnp.inf) return res @@ -1193,28 +1180,20 @@ def _calculate_size_containing_flux(image, thresh): cenx, ceny = image.center.x, image.center.y x, y = image.get_pixel_centers() fluxes = _flux_frac(image.array, x, y, cenx, ceny) - msk = fluxes >= -jnp.inf - fluxes = jnp.where(msk, fluxes, jnp.max(fluxes)) - d = jnp.arange(image.array.shape[0]) + 1.0 - # below we use a linear interpolation table to find the maximum size - # in pixels that contains a given flux (called thresh here) - # expfac controls how much we oversample the interpolation table - # in order to return a more accurate result - # we have it hard coded at 4 to compromise between speed and accuracy - expfac = 4.0 - dint = jnp.arange(image.array.shape[0] * expfac) / expfac + 1.0 - fluxes = jnp.interp(dint, d, fluxes) - msk = fluxes <= thresh + # we add 1 since the flux fraction computation above starts at + # one pixel and jnp.arange starts at zero + d = jnp.arange(min(image.array.shape[0], image.array.shape[1])) + 1.0 + p = jnp.sign(thresh) + msk = (p * fluxes) >= (p * thresh) return ( - jnp.argmax( + jnp.argmin( jnp.where( msk, - dint, - -jnp.inf, + d, + jnp.inf, ) ) - / expfac - + 1.0 + + 0.5 ) @@ -1247,6 +1226,11 @@ def _find_maxk(kim, max_maxk, thresh): # maxk from the image (computed by _inner_comp_find_maxk) # by max_maxk from above return jnp.minimum( - _inner_comp_find_maxk(kim.array, thresh, kx, ky), + # jax-galsim tends to be less conservative for maxk + # since compared to galsim, it does NOT require 5 rows + # of pixels in a row below the threshold. + # thus we add pixels here to ensure the galsim tests pass. + # it turns out one worked ok so that is what we did. - MRB + _inner_comp_find_maxk(kim.array, thresh, kx, ky) + 1 * kim.scale, max_maxk, ) diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index d349cd88..97f07663 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -16,9 +16,10 @@ from jax_galsim.interpolatedimage import ( _draw_with_interpolant_kval, _draw_with_interpolant_xval, - _flux_frac, ) +FRAC_TEST_TO_KEEP = 0.5 + @pytest.mark.parametrize( "interp", @@ -138,9 +139,8 @@ def test_interpolatedimage_utils_stepk_maxk(): gii = _galsim.InterpolatedImage(gimage_in, scale=scale) jgii = jax_galsim.InterpolatedImage(jgimage_in, scale=scale) - rtol = 1e-1 - np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=rtol, atol=0) - np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=rtol, atol=0) + np.testing.assert_allclose(jgii.stepk, gii.stepk, rtol=0, atol=1e-6) + np.testing.assert_allclose(jgii.maxk, gii.maxk, rtol=0, atol=1e-6) @pytest.mark.parametrize("x_interp", ["lanczos15", "quintic"]) @@ -208,7 +208,7 @@ def test_interpolatedimage_utils_comp_to_galsim( ) rng = np.random.RandomState(seed=seed) - if rng.uniform() < 0.75: + if rng.uniform() < FRAC_TEST_TO_KEEP: pytest.skip( "Skipping `test_interpolatedimage_utils_comp_to_galsim` case at random to save time." ) @@ -233,8 +233,10 @@ def test_interpolatedimage_utils_comp_to_galsim( x_interpolant=x_interp, ) - np.testing.assert_allclose(gii.stepk, jgii.stepk, rtol=0.5, atol=0) - np.testing.assert_allclose(gii.maxk, jgii.maxk, rtol=0.5, atol=0) + np.testing.assert_allclose(jgii.stepk, gii.stepk, rtol=0, atol=1e-6) + # FIXME: match maxk + np.testing.assert_allclose(jgii.maxk, gii.maxk, rtol=0.5, atol=0) + assert jgii.maxk >= gii.maxk kxvals = [ (0, 0), (-5, -5), @@ -253,8 +255,8 @@ def test_interpolatedimage_utils_comp_to_galsim( if method == "kValue": dk = jgii._original._kim.scale * rng.uniform(low=0.5, high=1.5) np.testing.assert_allclose( - gii.kValue(x * dk, y * dk), jgii.kValue(x * dk, y * dk), + gii.kValue(x * dk, y * dk), err_msg=f"kValue mismatch: wcs={wcs}, x={x}, y={y}", ) else: @@ -262,8 +264,8 @@ def test_interpolatedimage_utils_comp_to_galsim( low=0.5, high=1.5 ) np.testing.assert_allclose( - gii.xValue(x * dx, y * dx), jgii.xValue(x * dx, y * dx), + gii.xValue(x * dx, y * dx), err_msg=f"xValue mismatch: wcs={wcs}, x={x}, y={y}", ) @@ -311,13 +313,13 @@ def test_interpolatedimage_utils_jax_galsim_fft_vs_galsim_fft(n): rng = np.random.RandomState(42) arr = rng.normal(size=(n, n)) - gim = jax_galsim.Image(arr, scale=1) + gim = _galsim.Image(arr, scale=1) gkim = gim.calculate_fft() gxkim = gkim.calculate_inverse_fft() - np.testing.assert_allclose(gim.array, gxkim[gim.bounds].array) - np.testing.assert_allclose(gim.array, im.array) - np.testing.assert_allclose(gkim.array, kim.array) - np.testing.assert_allclose(gxkim.array, xkim.array) + np.testing.assert_allclose(gxkim[gim.bounds].array, gim.array, rtol=0, atol=1e-12) + np.testing.assert_allclose(im.array, gim.array, rtol=0, atol=1e-12) + np.testing.assert_allclose(kim.array, gkim.array, rtol=0, atol=1e-12) + np.testing.assert_allclose(xkim.array, gxkim.array, rtol=0, atol=1e-12) @pytest.mark.parametrize( @@ -359,74 +361,271 @@ def test_interpolatedimage_interpolant_sample(interp): np.testing.assert_allclose(fdev[~msk], 0, rtol=0, atol=15.0, err_msg=f"{interp}") -def test_interpolatedimage_flux_frac(): - obj = jax_galsim.Gaussian(half_light_radius=0.9).shear(g1=0.1, g2=0.2) - img = obj.drawImage(nx=55, ny=55, scale=0.05, method="no_pixel") - true_val = [ - 0.02186161, - 0.06551123, - 0.10894079, - 0.15200604, - 0.19456641, - 0.23648629, - 0.27763629, - 0.31789470, - 0.35714823, - 0.39529300, - 0.43223542, - 0.46789303, - 0.50219434, - 0.53507960, - 0.56650090, - 0.59642231, - 0.62481892, - 0.65167749, - 0.67699528, - 0.70077991, - 0.72304893, - 0.74382806, - 0.76315117, - 0.78105938, - 0.79759991, - 0.81282544, - 0.82679272, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - 0.83956224, - ] +@pytest.mark.parametrize("x_interp", ["lanczos15", "quintic"]) +@pytest.mark.parametrize("normalization", ["sb", "flux"]) +@pytest.mark.parametrize("use_true_center", [True, False]) +@pytest.mark.parametrize( + "wcs", + [ + _galsim.PixelScale(0.2), + _galsim.JacobianWCS(0.21, 0.03, -0.04, 0.23), + _galsim.AffineTransform(-0.03, 0.21, 0.18, 0.01, _galsim.PositionD(0.3, -0.4)), + ], +) +@pytest.mark.parametrize( + "offset_x", + [ + -4.35, + -0.45, + 0.0, + 0.67, + 3.78, + ], +) +@pytest.mark.parametrize( + "offset_y", + [ + -2.12, + -0.33, + 0.0, + 0.12, + 1.45, + ], +) +@pytest.mark.parametrize( + "ref_array", + [ + _galsim.Gaussian(fwhm=0.9) + .shear(g1=0.3, g2=-0.2) + .drawImage(nx=33, ny=33, scale=0.2) + .array, + _galsim.Gaussian(fwhm=0.9) + .shear(g1=-0.03, g2=0.1) + .drawImage(nx=32, ny=32, scale=0.2) + .array, + ], +) +def test_interpolatedimage_utils_comp_stepk_maxk_to_galsim( + ref_array, + offset_x, + offset_y, + wcs, + use_true_center, + normalization, + x_interp, +): + seed = max( + abs( + int( + hashlib.sha1( + f"{ref_array}{offset_x}{offset_y}{wcs}{use_true_center}{normalization}{x_interp}".encode( + "utf-8" + ) + ).hexdigest(), + 16, + ) + ) + % (10**7), + 1, + ) - x, y = img.get_pixel_centers() - cenx = img.center.x - ceny = img.center.y - val = _flux_frac(img.array, x, y, cenx, ceny) - np.testing.assert_allclose( - val, - true_val, - rtol=0, - atol=1e-6, + rng = np.random.RandomState(seed=seed) + if rng.uniform() < FRAC_TEST_TO_KEEP: + pytest.skip( + "Skipping `test_interpolatedimage_utils_comp_stepk_maxk_to_galsim` case at random to save time." + ) + + nse = rng.uniform(size=ref_array.shape) * ref_array.max() * 0.05 + + gimage_in = _galsim.Image(ref_array + nse, scale=0.2) + jgimage_in = jax_galsim.Image(ref_array + nse, scale=0.2) + + np.testing.assert_allclose(gimage_in.center.x, jgimage_in.center.x) + np.testing.assert_allclose(gimage_in.center.y, jgimage_in.center.y) + + gii = _galsim.InterpolatedImage( + gimage_in, + wcs=wcs, + offset=_galsim.PositionD(offset_x, offset_y), + use_true_center=use_true_center, + normalization=normalization, + x_interpolant=x_interp, + flux=20, + ) + jgii = jax_galsim.InterpolatedImage( + jgimage_in, + wcs=jax_galsim.BaseWCS.from_galsim(wcs), + offset=jax_galsim.PositionD(offset_x, offset_y), + use_true_center=use_true_center, + normalization=normalization, + x_interpolant=x_interp, + flux=20, ) + + gthresh = (1.0 - gii.gsparams.folding_threshold) * gii._image_flux + gR = _galsim._galsim.CalculateSizeContainingFlux(gii._image._image, gthresh) + + from jax_galsim.interpolatedimage import _calculate_size_containing_flux + + jgthresh = ( + 1.0 - jgii._original.gsparams.folding_threshold + ) * jgii._original._image_flux + jgR = _calculate_size_containing_flux(jgii._original.image, jgthresh) + + lgR = _galsim_stepk_loop(gii._image, gthresh) + ljgR = _galsim_stepk_loop(jgii._original.image, jgthresh) + + np.testing.assert_allclose(jgii._original.image.center.x, gii._image.center.x) + np.testing.assert_allclose(jgii._original.image.center.y, gii._image.center.y) + np.testing.assert_allclose(jgii._original.image(0, 0), gii._image(0, 0)) + np.testing.assert_allclose(jgii._original.image.array.sum(), gii._image.array.sum()) + np.testing.assert_allclose(jgthresh, gthresh, rtol=0, atol=1e-6) + np.testing.assert_allclose(jgR, gR, rtol=0, atol=1e-6) + np.testing.assert_allclose(ljgR, gR, rtol=0, atol=1e-6) + np.testing.assert_allclose(gR, lgR, rtol=0, atol=1e-6) + + np.testing.assert_allclose(jgii.stepk, gii.stepk, rtol=0, atol=1e-6) + # FIXME: make maxk match + np.testing.assert_allclose(jgii.maxk, gii.maxk, rtol=0.5, atol=0) + + +# this is a copy of the galsim C++ algorithm in a pure python +# loop to help with debugging and testing +def _galsim_stepk_loop(im, target_flux): + if target_flux > 0: + p = 1.0 + else: + p = -1.0 + + b = im.bounds + dmax = int(min((b.getXMax() - b.getXMin()) / 2, (b.getYMax() - b.getYMin()) / 2)) + + flux = im(0, 0) + d = 1 + while d <= dmax: + # Add the left, right, top and bottom sides of box + for x in range(-d, d): + # Note: All 4 corners are added exactly once by including x=-d but omitting + # x=d from the loop. + flux += im(x, -d) # bottom + flux += im(d, x) # right + flux += im(-x, d) # top + flux += im(-d, -x) # left + + if p * flux >= p * target_flux: + break + + d += 1 + + return d + 0.5 + + +@pytest.mark.parametrize("x_interp", ["lanczos15", "quintic"]) +@pytest.mark.parametrize("normalization", ["sb", "flux"]) +@pytest.mark.parametrize("use_true_center", [True, False]) +@pytest.mark.parametrize( + "wcs", + [ + _galsim.PixelScale(0.2), + _galsim.JacobianWCS(0.21, 0.03, -0.04, 0.23), + _galsim.AffineTransform(-0.03, 0.21, 0.18, 0.01, _galsim.PositionD(0.3, -0.4)), + ], +) +@pytest.mark.parametrize( + "offset_x", + [ + -4.35, + -0.45, + 0.0, + 0.67, + 3.78, + ], +) +@pytest.mark.parametrize( + "offset_y", + [ + -2.12, + -0.33, + 0.0, + 0.12, + 1.45, + ], +) +@pytest.mark.parametrize( + "ref_array", + [ + _galsim.Gaussian(fwhm=0.9).drawImage(nx=33, ny=33, scale=0.2).array, + _galsim.Gaussian(fwhm=0.9).drawImage(nx=32, ny=32, scale=0.2).array, + ], +) +@pytest.mark.parametrize("method", ["kValue", "xValue"]) +def test_interpolatedimage_utils_force_stepk_maxk( + method, + ref_array, + offset_x, + offset_y, + wcs, + use_true_center, + normalization, + x_interp, +): + seed = max( + abs( + int( + hashlib.sha1( + f"{method}{ref_array}{offset_x}{offset_y}{wcs}{use_true_center}{normalization}{x_interp}".encode( + "utf-8" + ) + ).hexdigest(), + 16, + ) + ) + % (10**7), + 1, + ) + + rng = np.random.RandomState(seed=seed) + if rng.uniform() < FRAC_TEST_TO_KEEP: + pytest.skip( + "Skipping `test_interpolatedimage_utils_force_stepk_maxk` case at random to save time." + ) + + gimage_in = _galsim.Image(ref_array, scale=0.2) + jgimage_in = jax_galsim.Image(ref_array, scale=0.2) + + gii = _galsim.InterpolatedImage( + gimage_in, + wcs=wcs, + offset=_galsim.PositionD(offset_x, offset_y), + use_true_center=use_true_center, + normalization=normalization, + x_interpolant=x_interp, + ) + maxk = gii.maxk * 1.04 + stepk = gii.stepk / 1.04 + + gii = _galsim.InterpolatedImage( + gimage_in, + wcs=wcs, + offset=_galsim.PositionD(offset_x, offset_y), + use_true_center=use_true_center, + normalization=normalization, + x_interpolant=x_interp, + _force_maxk=maxk, + _force_stepk=stepk, + ) + jgii = jax_galsim.InterpolatedImage( + jgimage_in, + wcs=jax_galsim.BaseWCS.from_galsim(wcs), + offset=jax_galsim.PositionD(offset_x, offset_y), + use_true_center=use_true_center, + normalization=normalization, + x_interpolant=x_interp, + _force_maxk=maxk, + _force_stepk=stepk, + ) + + np.testing.assert_allclose(gii.maxk, maxk) + np.testing.assert_allclose(gii.stepk, stepk) + np.testing.assert_allclose(jgii.maxk, maxk) + np.testing.assert_allclose(jgii.stepk, stepk)