Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
d7a4b56
fix: make stepk match galsim
beckermr Apr 23, 2026
8e1c000
test: remove regression test
beckermr Apr 23, 2026
11586e7
fix: fudge maxk just a bit bigger for now
beckermr Apr 23, 2026
6fd5ac6
fix: remove commented code
beckermr Apr 23, 2026
723a6fe
fix: use smallest dimension
beckermr Apr 23, 2026
22855a3
fix: lower fudge for maxk
beckermr Apr 23, 2026
dff7164
doc: get comment right
beckermr Apr 23, 2026
1d8aee0
fix: pad out maxk by some number of pixels
beckermr Apr 24, 2026
1832461
test: some test changes sicne there are more bugs here
beckermr Apr 24, 2026
ae2a60d
fix: adjust tests
beckermr Apr 24, 2026
02740b5
test: ensure jax-galsim is more conservative
beckermr Apr 24, 2026
15eac76
Update tests/jax/test_interpolatedimage_utils.py
beckermr Apr 24, 2026
26d1209
fix: address code review
beckermr Apr 24, 2026
62cebd6
Merge branch 'stepk-fix' of https://github.com/GalSim-developers/JAX-…
beckermr Apr 24, 2026
5a4050c
style: pre-commit
beckermr Apr 24, 2026
d556039
fix: stupid copy paste bug
beckermr Apr 24, 2026
513a250
fix: clean out unused code
beckermr Apr 24, 2026
930e189
fix: clean out unused code
beckermr Apr 24, 2026
5fa99dc
test: add explicit test for behavior of force stepk and maxk
beckermr Apr 24, 2026
323c12a
Update jax_galsim/interpolatedimage.py
beckermr Apr 24, 2026
b8a319b
doc: notes on purpose of separate ii class
beckermr Apr 24, 2026
75cf5a3
Merge branch 'stepk-fix' of https://github.com/GalSim-developers/JAX-…
beckermr Apr 24, 2026
2f1bdb4
style: pre-the-commit
beckermr Apr 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 37 additions & 53 deletions jax_galsim/interpolatedimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from jax_galsim import fits
from jax_galsim.bounds import BoundsI
from jax_galsim.core.utils import (
compute_major_minor_from_jacobian,
ensure_hashable,
implements,
)
Expand Down Expand Up @@ -175,8 +174,6 @@ def __init__(
depixelize=depixelize,
offset=offset,
gsparams=GSParams.check(gsparams),
_force_stepk=_force_stepk,
_force_maxk=_force_maxk,
hdu=hdu,
_recenter_image=_recenter_image,
)
Expand Down Expand Up @@ -222,14 +219,19 @@ def _maxk(self):
if self._jax_aux_data["_force_maxk"] > 0:
return self._jax_aux_data["_force_maxk"]
else:
return super()._maxk
# galsim uses a different way to handle the WCS effects on maxk
# for interpolated images. IDK why. - MRB
return self._original.maxk / self._original._wcs._maxScale()

@property
def _stepk(self):
if self._jax_aux_data["_force_stepk"] > 0:
return self._jax_aux_data["_force_stepk"]
else:
return super()._stepk
# galsim uses a different way to handle the WCS effects on stepk
# for interpolated images. IDK why. - MRB
# super()._stepk
return self._original.stepk / self._original._wcs._minScale()

@property
@implements(_galsim.interpolatedimage.InterpolatedImage.x_interpolant)
Expand Down Expand Up @@ -367,6 +369,16 @@ def _zeropad_image(arr, npad):

@register_pytree_node_class
class _InterpolatedImageImpl(GSObject):
"""Internal class for handling interpolated images.

An interpolated image carries an intrinsic WCS with it that can be anything
from a pixel-scale to a full Jacobian.

We use this internal class to separate the underlying image bits from the
WCS handling bits. For those, we inherit from the Transform class so that
we can reuse its methods.
"""

_cache_noise_pad = {}

_has_hard_edges = False
Expand Down Expand Up @@ -395,15 +407,12 @@ def __init__(
depixelize=False,
offset=None,
gsparams=None,
_force_stepk=0.0,
_force_maxk=0.0,
hdu=None,
_recenter_image=True,
):
# this class does a ton of munging of the inputs that I don't want to reconstruct when
# flattening and unflattening the class.
# thus I am going to make some refs here so we have it when we need it
self._workspace = {}
self._jax_children = (
image,
dict(
Expand All @@ -428,8 +437,6 @@ def __init__(
use_true_center=use_true_center,
depixelize=depixelize,
gsparams=gsparams,
_force_stepk=_force_stepk,
_force_maxk=_force_maxk,
_recenter_image=_recenter_image,
hdu=hdu,
)
Expand Down Expand Up @@ -530,13 +537,10 @@ def tree_unflatten(cls, aux_data, children):
return ret

def __getstate__(self):
d = self.__dict__.copy()
d.pop("_workspace")
return d
return self.__dict__.copy()

def __setstate__(self, d):
self.__dict__ = d
self._workspace = {}

@property
def x_interpolant(self):
Expand Down Expand Up @@ -683,31 +687,15 @@ def _kim(self):

@property
def _maxk(self):
if self._jax_aux_data["_force_maxk"]:
_, minor = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2)))
return self._jax_aux_data["_force_maxk"] * minor
else:
return self._getMaxK(self._jax_aux_data["calculate_maxk"])
return self._getMaxK(self._jax_aux_data["calculate_maxk"])

@property
def _stepk(self):
if self._jax_aux_data["_force_stepk"]:
_, minor = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2)))
return self._jax_aux_data["_force_stepk"] * minor
else:
return self._getStepK(self._jax_aux_data["calculate_stepk"])
return self._getStepK(self._jax_aux_data["calculate_stepk"])

def _getStepK(self, calculate_stepk):
# GalSim cannot automatically know what stepK and maxK are appropriate for the
# input image. So it is usually worth it to do a manual calculation (below).
#
# However, there is also a hidden option to force it to use specific values of stepK and
# maxK (caveat user!). The values of _force_stepk and _force_maxk should be provided in
# terms of physical scale, e.g., for images that have a scale length of 0.1 arcsec, the
# stepK and maxK should be provided in units of 1/arcsec. Then we convert to the 1/pixel
# units required by the C++ layer below. Also note that profile recentering for even-sized
# images (see the ._adjust_offset step below) leads to automatic reduction of stepK slightly
# below what is provided here, while maxK is preserved.
if calculate_stepk:
if calculate_stepk is True:
im = self.image
Expand Down Expand Up @@ -1173,9 +1161,9 @@ def _flux_frac(a, x, y, cenx, ceny):
dx = jnp.reshape(dx, (a.shape[0], a.shape[1], 1))
dy = y - ceny
dy = jnp.reshape(dy, (a.shape[0], a.shape[1], 1))
d = jnp.arange(a.shape[0])
d = jnp.arange(min(a.shape[0], a.shape[1]))
d = jnp.reshape(d, (1, 1, -1))
msk = (jnp.abs(dx) <= d) & (jnp.abs(dx) <= d)
msk = (jnp.abs(dx) <= d) & (jnp.abs(dy) <= d)
Comment thread
beckermr marked this conversation as resolved.
res = jnp.sum(
jnp.where(
msk,
Expand All @@ -1184,7 +1172,6 @@ def _flux_frac(a, x, y, cenx, ceny):
),
axis=(0, 1),
)
res = jnp.where(res > 0, res, -jnp.inf)
return res


Expand All @@ -1193,28 +1180,20 @@ def _calculate_size_containing_flux(image, thresh):
cenx, ceny = image.center.x, image.center.y
x, y = image.get_pixel_centers()
fluxes = _flux_frac(image.array, x, y, cenx, ceny)
msk = fluxes >= -jnp.inf
fluxes = jnp.where(msk, fluxes, jnp.max(fluxes))
d = jnp.arange(image.array.shape[0]) + 1.0
# below we use a linear interpolation table to find the maximum size
# in pixels that contains a given flux (called thresh here)
# expfac controls how much we oversample the interpolation table
# in order to return a more accurate result
# we have it hard coded at 4 to compromise between speed and accuracy
expfac = 4.0
dint = jnp.arange(image.array.shape[0] * expfac) / expfac + 1.0
fluxes = jnp.interp(dint, d, fluxes)
msk = fluxes <= thresh
# we add 1 since the flux fraction computation above starts at
# one pixel and jnp.arange starts at zero
d = jnp.arange(min(image.array.shape[0], image.array.shape[1])) + 1.0
Comment thread
beckermr marked this conversation as resolved.
p = jnp.sign(thresh)
msk = (p * fluxes) >= (p * thresh)
return (
jnp.argmax(
jnp.argmin(
jnp.where(
msk,
dint,
-jnp.inf,
d,
jnp.inf,
)
)
/ expfac
+ 1.0
+ 0.5
)


Expand Down Expand Up @@ -1247,6 +1226,11 @@ def _find_maxk(kim, max_maxk, thresh):
# maxk from the image (computed by _inner_comp_find_maxk)
# by max_maxk from above
return jnp.minimum(
_inner_comp_find_maxk(kim.array, thresh, kx, ky),
# jax-galsim tends to be less conservative for maxk
# since compared to galsim, it does NOT require 5 rows
# of pixels in a row below the threshold.
# thus we add pixels here to ensure the galsim tests pass.
# it turns out one worked ok so that is what we did. - MRB
_inner_comp_find_maxk(kim.array, thresh, kx, ky) + 1 * kim.scale,
max_maxk,
)
Loading