Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
68c6938
fix: redo bounds again for dynamic usages
beckermr May 20, 2026
9d9dc80
style: please the dog
beckermr May 20, 2026
956f778
fix: only apply ~ to array bool
beckermr May 20, 2026
03c7869
fix: try dynamic children
beckermr May 20, 2026
231fdb2
fix: ensure bounds iuncludes tests are done properly
beckermr May 20, 2026
97b8823
fix: cast in a different way
beckermr May 20, 2026
997dadf
fix: finish dynamic bounds impl
beckermr May 20, 2026
0774b6a
fix: be sure to test everything
beckermr May 20, 2026
80f7327
fix: get bools right
beckermr May 20, 2026
9012a3e
fix: wrong tracing branch; remove test
beckermr May 20, 2026
1abee57
refactor: only need one of these
beckermr May 20, 2026
6718b24
Apply suggestion from @beckermr
beckermr May 20, 2026
f8d2c00
Apply suggestion from @beckermr
beckermr May 20, 2026
2b7634f
Apply suggestion from @beckermr
beckermr May 20, 2026
e2a4fac
test: update to latest test submodule
beckermr May 20, 2026
93eb30f
fix: ensure we handle branches on bounds eq properly
beckermr May 20, 2026
e44d8df
fix: this needs to be a float
beckermr May 20, 2026
ef586a3
fix: ensure we can FFT OK
beckermr May 20, 2026
618e137
Update jax_galsim/bounds.py
beckermr May 20, 2026
01f893c
Update jax_galsim/bounds.py
beckermr May 20, 2026
9d6f2fa
fix: use latest submodule
beckermr May 20, 2026
13e0266
test: update to latest submodule
beckermr May 20, 2026
2a9c0e7
fix: use to_galsim for fpacking
beckermr May 20, 2026
ac4e000
fix: do not convert all bounds props to arrays
beckermr May 21, 2026
004ada8
fix: put back variable pytree def
beckermr May 21, 2026
7f97f11
fix: make sure to send fits headers to galsim
beckermr May 21, 2026
a5fc11e
Apply suggestion from @beckermr
beckermr May 21, 2026
a0d0ba6
Merge branch 'typing-inits-cleanup' into bounds-fix-again-all-static
beckermr May 21, 2026
996a5fb
Merge branch 'typing-inits-cleanup' into bounds-fix-again-all-static
beckermr May 21, 2026
2143616
Merge branch 'main' into bounds-fix-again-all-static
beckermr May 21, 2026
fea3aec
fix: add back isStatic method
beckermr May 21, 2026
1a8f579
test: ensure api tests are correct
beckermr May 21, 2026
053de86
style: remove extra blank space changes
beckermr May 21, 2026
51efa74
Apply suggestion from @beckermr
beckermr May 21, 2026
cb9d084
Apply suggestion from @beckermr
beckermr May 21, 2026
63fe4b3
Apply suggestion from @beckermr
beckermr May 21, 2026
d0aab7e
fix: allow python bool for static bounds eq
beckermr May 21, 2026
9f99b92
fix: start on eq using jax bool values in most cases
beckermr May 21, 2026
0485a14
Apply suggestion from @beckermr
beckermr May 21, 2026
f642c26
fix: remove extra prints
beckermr May 21, 2026
f3d8e25
Merge branch 'bounds-fix-again-all-static' of https://github.com/GalS…
beckermr May 21, 2026
7d176a7
fix: return JAX bools for rest of things
beckermr May 21, 2026
1bf4a5d
fix: bool conversion in pos comp
beckermr May 21, 2026
4703099
fix: bool conversion in pos comp
beckermr May 21, 2026
de2bb72
test: update to latest submodule
beckermr May 21, 2026
29df9c3
doc: add docs
beckermr May 21, 2026
2d53ab0
doc: be a bit more specific
beckermr May 21, 2026
1c94cf3
fix: dead code
beckermr May 21, 2026
95f0ae9
test: add test of bool eq api
beckermr May 21, 2026
0d07661
test: add test of bool eq api
beckermr May 21, 2026
dea6915
test: add tests of bounds and vmap
beckermr May 21, 2026
1671b37
test: add tests of bounds and vmap
beckermr May 21, 2026
946958c
Apply suggestion from @beckermr
beckermr May 22, 2026
860a419
test: more tests for bounds and vmap
beckermr May 22, 2026
f312d87
Apply suggestion from @beckermr
beckermr May 22, 2026
8e7280f
test: ensure we have proper logic here
beckermr May 22, 2026
a9ef07f
Apply suggestions from code review
beckermr May 22, 2026
401e540
doc: clarify
beckermr May 22, 2026
0a1ddd1
test: more tests for bnds int and vmap
beckermr May 22, 2026
aa8d587
test: add tests for type conversion
beckermr May 22, 2026
4922f21
test: more tests for bounds set
beckermr May 22, 2026
2d48748
test: fix bounds api tests
beckermr May 22, 2026
64af33a
fix: sharpen tests
beckermr May 22, 2026
a7efe2d
doc: add comment
beckermr May 22, 2026
eb2748a
fix: rename for clarity
beckermr May 22, 2026
215b412
test: add tests of includes and simpler bounds init
beckermr May 23, 2026
863a99e
test: run faster?
beckermr May 23, 2026
137edd1
test: use more splits
beckermr May 23, 2026
6a5a37d
test: fewer splits, more workers
beckermr May 23, 2026
c7348ef
test: less logging
beckermr May 23, 2026
7b35ae8
test: try four xdist only
beckermr May 23, 2026
b71db8b
Apply suggestion from @beckermr
beckermr May 23, 2026
642b4ae
Apply suggestion from @beckermr
beckermr May 23, 2026
caa1e42
fix: implement _BoundsD/I
beckermr May 23, 2026
2a1a36d
Merge branch 'bounds-fix-again-all-static' of https://github.com/GalS…
beckermr May 23, 2026
d9af535
test: just a bit faster
beckermr May 23, 2026
afe868f
fix: do not cat this giant file
beckermr May 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
615 changes: 362 additions & 253 deletions jax_galsim/bounds.py

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import jax.numpy as jnp
import numpy as np

STATIC_SCALAR_TYPES = (int, float, np.integer, np.floating)


def check_is_int_then_cast(val, msg):
"""Check if `val` is an integer, raise if not, otherwise cast to int."""
val = cast_to_float(val)

if isinstance(val, (int, float, np.integer, np.floating)):
if isinstance(val, STATIC_SCALAR_TYPES):
# for simple inputs, we can check direct in python
if val != int(val):
raise TypeError(msg)
Expand Down Expand Up @@ -43,9 +45,7 @@ def cast_numpy_array_to_native_byte_order(arr):


def _cast_to_type(x, typ, accept_strings=False):
if isinstance(x, (int, float, np.integer, np.floating)) or (
accept_strings and isinstance(x, str)
):
if isinstance(x, STATIC_SCALAR_TYPES) or (accept_strings and isinstance(x, str)):
return typ(x)
else:
return jnp.astype(x, typ)
Expand Down
183 changes: 92 additions & 91 deletions jax_galsim/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True):
b = kwargs.pop("bounds")
if not isinstance(b, BoundsI):
raise TypeError("bounds must be a galsim.BoundsI instance")

Comment thread
beckermr marked this conversation as resolved.
Outdated
if check_bounds and b.isDefined():
if b.deltax != array.shape[1]:
raise _galsim.GalSimIncompatibleValuesError(
Expand Down Expand Up @@ -571,6 +572,7 @@ def resize(self, bounds, wcs=None):
raise GalSimImmutableError("Cannot modify an immutable Image", self)
if not isinstance(bounds, BoundsI):
raise TypeError("bounds must be a galsim.BoundsI instance")

Comment thread
beckermr marked this conversation as resolved.
Outdated
self._array = self._make_empty(shape=bounds.numpyShape(), dtype=self.dtype)
self._bounds = bounds
if wcs is not None:
Expand All @@ -580,20 +582,24 @@ def resize(self, bounds, wcs=None):
def subImage(self, bounds):
if not isinstance(bounds, BoundsI):
raise TypeError("bounds must be a galsim.BoundsI instance")

Comment thread
beckermr marked this conversation as resolved.
Outdated
if not self.bounds.isDefined():
raise _galsim.GalSimUndefinedBoundsError(
"Attempt to access subImage of undefined image"
)
inc_val = jnp.array(self.bounds.includes(bounds))
inc_val = equinox.error_if(
inc_val,
jnp.any(~inc_val),
"Attempt to access subImage not (fully) in image",
)

if (
self.bounds.isStatic()
and bounds.isStatic()
and not self.bounds.includes(bounds)
isinstance(self.bounds.xmin, int)
and isinstance(self.bounds.ymin, int)
and isinstance(bounds.xmin, int)
and isinstance(bounds.ymin, int)
):
raise _galsim.GalSimBoundsError(
"Attempt to access subImage not (fully) in image", bounds, self.bounds
)

if self.bounds.isStatic() and bounds.isStatic():
i1 = bounds.ymin - self.ymin
i2 = bounds.ymax - self.ymin + 1
j1 = bounds.xmin - self.xmin
Expand All @@ -619,18 +625,19 @@ def setSubImage(self, bounds, rhs):
raise GalSimImmutableError("Cannot modify an immutable Image", self)
if not isinstance(bounds, BoundsI):
raise TypeError("bounds must be a galsim.BoundsI instance")

Comment thread
beckermr marked this conversation as resolved.
Outdated
if not self.bounds.isDefined():
raise _galsim.GalSimUndefinedBoundsError(
"Attempt to access values of an undefined image"
)
if (
self.bounds.isStatic()
and bounds.isStatic()
and not self.bounds.includes(bounds)
):
raise _galsim.GalSimBoundsError(
"Attempt to access subImage not (fully) in image", bounds, self.bounds
)

inc_val = jnp.array(self.bounds.includes(bounds))
inc_val = equinox.error_if(
inc_val,
jnp.any(~inc_val),
"Attempt to access subImage not (fully) in image",
)

if not isinstance(rhs, Image):
raise TypeError("Trying to copyFrom a non-image")
if bounds.numpyShape() != rhs.bounds.numpyShape():
Expand All @@ -640,7 +647,12 @@ def setSubImage(self, bounds, rhs):
rhs=rhs,
)

if self.bounds.isStatic() and bounds.isStatic():
if (
isinstance(self.bounds.xmin, int)
and isinstance(self.bounds.ymin, int)
and isinstance(bounds.xmin, int)
and isinstance(bounds.ymin, int)
):
i1 = bounds.ymin - self.ymin
i2 = bounds.ymax - self.ymin + 1
j1 = bounds.xmin - self.xmin
Expand Down Expand Up @@ -724,7 +736,7 @@ def wrap(self, bounds, hermitian=False):

def _raise_if_nonzero(bnds, x_or_y, msg):
if x_or_y == "x":
if bnds.isStatic():
if isinstance(bnds.xmin, int):
if bnds.xmin != 0:
raise _galsim.GalSimIncompatibleValuesError(
msg,
Expand All @@ -738,7 +750,7 @@ def _raise_if_nonzero(bnds, x_or_y, msg):
msg,
)
else:
if bnds.isStatic():
if isinstance(bnds.ymin, int):
if bnds.ymin != 0:
raise _galsim.GalSimIncompatibleValuesError(
msg,
Expand Down Expand Up @@ -860,14 +872,12 @@ def calculate_fft(self):
),
)

# galsim branches here if the image has the correct bounds, but JAX can't branch
# on calls that generate different size arrays
# so we always make a new image
full_bounds = BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2)
if self.bounds == full_bounds:
# Then the image is already in the shape we need.
ximage = self
else:
# Then we pad out with zeros
ximage = Image(full_bounds, dtype=self.dtype, init_value=0)
ximage[self.bounds] = self[self.bounds]
ximage = Image(full_bounds, dtype=self.dtype, init_value=0)
ximage[self.bounds] = self[self.bounds]

dx = self.scale
# dk = 2pi / (N dk)
Expand Down Expand Up @@ -902,34 +912,36 @@ def calculate_inverse_fft(self):
raise _galsim.GalSimError(
"calculate_inverse_fft requires that the image has a PixelScale wcs."
)
if not self.bounds.includes(0, 0):
raise _galsim.GalSimBoundsError(
"calculate_inverse_fft requires that the image includes (0,0)",
PositionI(0, 0),
self.bounds,
)

inc_val = jnp.array(self.bounds.includes(0, 0))
inc_val = equinox.error_if(
inc_val,
jnp.any(~inc_val),
"calculate_inverse_fft requires that the image includes (0,0)",
)

No2 = max(
max(self.bounds.xmax, -self.bounds.ymin),
self.bounds.ymax,
)

target_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2)
if self.bounds == target_bounds:
# Then the image is already in the shape we need.
kimage = self
else:
# Then we can pad out with zeros and wrap to get this in the form we need.
full_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2 + 1)
kimage = Image(full_bounds, dtype=self.dtype, init_value=0)
posx_bounds = BoundsI(
xmin=0,
xmax=self.bounds.xmax,
ymin=self.bounds.ymin,
ymax=self.bounds.ymax,
)
kimage[posx_bounds] = self[posx_bounds]
kimage = kimage._wrap(target_bounds, True, False, 2 * No2)

# galsim branches here if the image has the correct bounds, but JAX can't branch
# on calls that generate different size arrays
# so we always make a new image

# Then we can pad out with zeros and wrap to get this in the form we need.
full_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2 + 1)
kimage = Image(full_bounds, dtype=self.dtype, init_value=0)
posx_bounds = BoundsI(
xmin=0,
xmax=self.bounds.xmax,
ymin=self.bounds.ymin,
ymax=self.bounds.ymax,
)
kimage[posx_bounds] = self[posx_bounds]
kimage = kimage._wrap(target_bounds, True, False, 2 * No2)

dk = self.scale
# dx = 2pi / (N dk)
Expand Down Expand Up @@ -1067,12 +1079,13 @@ def getValue(self, x, y):
raise _galsim.GalSimUndefinedBoundsError(
"Attempt to access values of an undefined image"
)
if not self.bounds.includes(x, y):
raise _galsim.GalSimBoundsError(
"Attempt to access position not in bounds of image.",
PositionI(x, y),
self.bounds,
)
inc_val = jnp.array(self.bounds.includes(x, y))
inc_val = equinox.error_if(
inc_val,
jnp.any(~inc_val),
"Attempt to access position not in bounds of image.",
)

return self._getValue(x, y)

@implements(_galsim.Image._getValue)
Expand All @@ -1090,10 +1103,13 @@ def setValue(self, *args, **kwargs):
pos, value = parse_pos_args(
args, kwargs, "x", "y", integer=True, others=["value"]
)
if not self.bounds.includes(pos):
raise _galsim.GalSimBoundsError(
"Attempt to set position not in bounds of image", pos, self.bounds
)
inc_val = jnp.array(self.bounds.includes(pos))
inc_val = equinox.error_if(
inc_val,
jnp.any(~inc_val),
"Attempt to set position not in bounds of image",
)

self._setValue(pos.x, pos.y, value)

@implements(_galsim.Image._setValue)
Expand All @@ -1111,10 +1127,13 @@ def addValue(self, *args, **kwargs):
pos, value = parse_pos_args(
args, kwargs, "x", "y", integer=True, others=["value"]
)
if not self.bounds.includes(pos):
raise _galsim.GalSimBoundsError(
"Attempt to set position not in bounds of image", pos, self.bounds
)
inc_val = jnp.array(self.bounds.includes(pos))
inc_val = equinox.error_if(
inc_val,
jnp.any(~inc_val),
"Attempt to set position not in bounds of image",
)

self._addValue(pos.x, pos.y, value)

@implements(_galsim.Image._addValue)
Expand Down Expand Up @@ -1242,16 +1261,8 @@ def rot_180(self):
def tree_flatten(self):
"""Flatten the image into a list of values."""
# Define the children nodes of the PyTree that need tracing
if self.bounds.isStatic():
children = (self.array, self.wcs)
aux_data = {
"dtype": self.dtype,
"bounds": self.bounds,
"isconst": self.isconst,
}
else:
children = (self.array, self.wcs, self.bounds)
aux_data = {"dtype": self.dtype, "isconst": self.isconst}
children = (self.array, self.wcs, self.bounds)
aux_data = {"dtype": self.dtype, "isconst": self.isconst}
# other routines may add these attributes to images on the fly
# we have to include them here so that JAX knows how to handle them in jitting etc.
if hasattr(self, "added_flux"):
Expand All @@ -1269,26 +1280,16 @@ def tree_unflatten(cls, aux_data, children):
obj = object.__new__(cls)
obj._array = children[0]
obj.wcs = children[1]
if "bounds" in aux_data:
obj._bounds = aux_data["bounds"]
obj._dtype = aux_data["dtype"]
obj._is_const = aux_data["isconst"]
if len(children) > 2:
obj.added_flux = children[2]
if "header" in aux_data:
obj.header = aux_data["header"]
if len(children) > 3:
obj.photons = children[3]
else:
obj._bounds = children[2]
obj._dtype = aux_data["dtype"]
obj._is_const = aux_data["isconst"]
if len(children) > 3:
obj.added_flux = children[3]
if "header" in aux_data:
obj.header = aux_data["header"]
if len(children) > 4:
obj.photons = children[4]
obj._bounds = children[2]
obj._dtype = aux_data["dtype"]
obj._is_const = aux_data["isconst"]
if len(children) > 3:
obj.added_flux = children[3]
if "header" in aux_data:
obj.header = aux_data["header"]
if len(children) > 4:
obj.photons = children[4]

return obj

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions jax_galsim/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import implements
from jax_galsim.core.utils import STATIC_SCALAR_TYPES, implements

try:
from jax.extend.random import wrap_key_data
Expand Down Expand Up @@ -95,7 +95,7 @@ def generates_in_pairs(self):
def seed(self, seed=None):
if seed is None:
self._seed(seed=seed)
elif isinstance(seed, (int, float, np.integer, np.floating)):
elif isinstance(seed, STATIC_SCALAR_TYPES):
if seed == int(seed):
self._seed(seed=int(seed))
else:
Expand Down
3 changes: 2 additions & 1 deletion jax_galsim/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from jax_galsim.angle import AngleUnit, arcsec, radians
from jax_galsim.celestial import CelestialCoord
from jax_galsim.core.utils import (
STATIC_SCALAR_TYPES,
cast_to_float,
ensure_hashable,
implements,
Expand All @@ -22,7 +23,7 @@
# this kind of casting is only done for writing FITS headers
# and should never be done anywhere else in the code base
def _cast_to_static_numeric_scalar(x, msg=None):
if isinstance(x, (int, float, np.integer, np.floating)):
if isinstance(x, STATIC_SCALAR_TYPES):
return x

if isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)):
Expand Down
2 changes: 1 addition & 1 deletion tests/GalSim
1 change: 0 additions & 1 deletion tests/jax/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ def _reg_fun(p):
if issubclass(cls, jax_galsim.Bounds) and method in [
"xmax",
"ymax",
"isStatic",
]:
continue

Expand Down
21 changes: 21 additions & 0 deletions tests/jax/test_bounds_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import jax
import jax.numpy as jnp
import numpy as np

import jax_galsim


@jax.vmap
@jax.jit
def _make_bounds_float(xmin, ymin, xmax, ymax):
bds = jax_galsim.BoundsD(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)
return bds, bds.isDefined()


def test_bounds_jax_vmap_isdefined_float():
xmin = jnp.array([9, 10, 11, 12])
xmax = jnp.array([12, 11, 10, 9])
ymin = jnp.array([9, 11, 10, 12])
ymax = jnp.array([10, 10, 10, 10])
bds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax)
np.testing.assert_array_equal(bds.isDefined(), isdef, strict=True)
Loading