From 68c6938cb68459b33f38ca61b26f8038767b2c4c Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 19 May 2026 22:29:34 -0500 Subject: [PATCH 01/72] fix: redo bounds again for dynamic usages --- jax_galsim/bounds.py | 653 ++++++++++++++++++++++++----------- tests/GalSim | 2 +- tests/jax/test_bounds_jax.py | 49 +++ 3 files changed, 493 insertions(+), 211 deletions(-) create mode 100644 tests/jax/test_bounds_jax.py diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 69ab8746..f9fda0e0 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -1,12 +1,11 @@ +import equinox import galsim as _galsim import jax import jax.numpy as jnp -import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( cast_to_float, - cast_to_int, check_is_int_then_cast, ensure_hashable, implements, @@ -14,19 +13,15 @@ from jax_galsim.position import Position, PositionD, PositionI BOUNDS_LAX_DESCR = """\ -The JAX implementation - -- will not always test whether the bounds are valid - -Further, the JAX implementation adds a new method, ``isStatic`` to the -``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance +The JAX implementation adds a new method, ``isStatic`` to the +``Bounds`` class. If JAX-GalSim detects that a ``BoundsI`` instance has been instantiated with static, known values, ``isStatic()`` will -return ``True``. +return ``True``, otherwise it is ``False``. For ``BoundsD``, ``isStatic()`` +always returns ``False``. -``BoundsI`` objects in JAX-Galsim support an additional initialization -call ``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)``. In this case, -the values for ``deltax/y`` indicate the width of the bounds and must be -static constants. +``BoundsI`` objects in JAX-Galsim must have a fixed width. To help support +this requirement, JAX-Galsim supports an additional initialization call +``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)``. When calling ``jax.vmap`` over ``BoundsI`` objects, only ``x/ymin`` are vectorized over. This restriction allows for code that renders @@ -46,11 +41,13 @@ def __init__(self): ) def _parse_args(self, *args, **kwargs): + do_isdefined = True + if len(kwargs) == 0: if len(args) == 4: - self._isdefined = True self.xmin, self.xmax, self.ymin, self.ymax = args elif len(args) == 0: + do_isdefined = False self._isdefined = False self.xmin = 0 self.ymin = 0 @@ -70,7 +67,6 @@ def _parse_args(self, *args, **kwargs): self.ymin = args[0].ymin self.deltay = args[0].deltay + offset elif isinstance(args[0], Position): - self._isdefined = True self.xmin = self.xmax = args[0].x self.ymin = self.ymax = args[0].y else: @@ -78,10 +74,8 @@ def _parse_args(self, *args, **kwargs): "Single argument to %s must be either a Bounds or a Position" % (self.__class__.__name__) ) - self._isdefined = True elif len(args) == 2: if isinstance(args[0], Position) and isinstance(args[1], Position): - self._isdefined = True self.xmin = min(args[0].x, args[1].x) self.xmax = max(args[0].x, args[1].x) self.ymin = min(args[0].y, args[1].y) @@ -103,7 +97,6 @@ def _parse_args(self, *args, **kwargs): ) else: try: - self._isdefined = True self.xmin = kwargs.pop("xmin") self.ymin = kwargs.pop("ymin") except KeyError: @@ -128,17 +121,7 @@ def _parse_args(self, *args, **kwargs): if kwargs: raise TypeError("Got unexpected keyword arguments %s" % kwargs.keys()) - # for simple inputs, we can check if the bounds are valid - if isinstance(self, BoundsD): - max_delta = 0 - else: - max_delta = 1 - if ( - isinstance(self.deltax, (int, float, np.integer, np.floating)) - and isinstance(self.deltay, (int, float, np.integer, np.floating)) - and (self.deltax < max_delta or self.deltay < max_delta) - ): - self._isdefined = False + return do_isdefined @implements(_galsim.Bounds.area) def area(self): @@ -166,18 +149,32 @@ def origin(self): @property @implements(_galsim.Bounds.center) def center(self): - if not self.isDefined(): - raise _galsim.GalSimUndefinedBoundsError( - "center is invalid for an undefined Bounds" + if not isinstance(self._isdefined, jnp.ndarray): + if not self.isDefined(): + raise _galsim.GalSimUndefinedBoundsError( + "center is invalid for an undefined Bounds" + ) + else: + self._isdefined = equinox.error_if( + self._isdefined, + jnp.any(~self._isdefined), + "center is invalid for an undefined Bounds", ) return self._center @property @implements(_galsim.Bounds.true_center) def true_center(self): - if not self.isDefined(): - raise _galsim.GalSimUndefinedBoundsError( - "true_center is invalid for an undefined Bounds" + if not isinstance(self._isdefined, jnp.ndarray): + if not self.isDefined(): + raise _galsim.GalSimUndefinedBoundsError( + "true_center is invalid for an undefined Bounds" + ) + else: + self._isdefined = equinox.error_if( + self._isdefined, + jnp.any(~self._isdefined), + "true_center is invalid for an undefined Bounds", ) return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) @@ -262,61 +259,24 @@ def shift(self, delta): ) def __and__(self, other): - if not isinstance(other, self.__class__): - raise TypeError("other must be a %s instance" % self.__class__.__name__) - if not self.isDefined() or not other.isDefined(): - return self.__class__() - else: - xmin = jnp.maximum(self.xmin, other.xmin) - xmax = jnp.minimum(self.xmax, other.xmax) - ymin = jnp.maximum(self.ymin, other.ymin) - ymax = jnp.minimum(self.ymax, other.ymax) - if xmin > xmax or ymin > ymax: - return self.__class__() - else: - return self.__class__(xmin, xmax, ymin, ymax) + raise NotImplementedError( + "Subclasses of `Bounds` must implement the `__and__` method!" + ) def __add__(self, other): - if isinstance(other, self.__class__): - if not other.isDefined(): - return self - elif self.isDefined(): - xmin = jnp.minimum(self.xmin, other.xmin) - xmax = jnp.maximum(self.xmax, other.xmax) - ymin = jnp.minimum(self.ymin, other.ymin) - ymax = jnp.maximum(self.ymax, other.ymax) - return self.__class__(xmin, xmax, ymin, ymax) - else: - return other - elif isinstance(other, self._pos_class): - if self.isDefined(): - xmin = jnp.minimum(self.xmin, other.x) - xmax = jnp.maximum(self.xmax, other.x) - ymin = jnp.minimum(self.ymin, other.y) - ymax = jnp.maximum(self.ymax, other.y) - return self.__class__(xmin, xmax, ymin, ymax) - else: - return self.__class__(other) - else: - raise TypeError( - "other must be either a %s or a %s" - % (self.__class__.__name__, self._pos_class.__name__) - ) - - def _getinitargs(self): - if self.isDefined(): - return (self.xmin, self.xmax, self.ymin, self.ymax) - else: - return () + raise NotImplementedError( + "Subclasses of `Bounds` must implement the `__add__` method!" + ) def __eq__(self, other): - return self is other or ( - isinstance(other, self.__class__) - and self._getinitargs() == other._getinitargs() + raise NotImplementedError( + "Subclasses of `Bounds` must implement the `__eq__` method!" ) def __ne__(self, other): - return not self.__eq__(other) + raise NotImplementedError( + "Subclasses of `Bounds` must implement the `__ne__` method!" + ) def __hash__(self): return hash( @@ -333,10 +293,7 @@ def tree_flatten(self): """This function flattens the Bounds into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - if self.isDefined(): - children = (self.xmin, self.deltax, self.ymin, self.deltay) - else: - children = tuple() + children = (self.xmin, self.deltax, self.ymin, self.deltay, self._isdefined) # Define auxiliary static data that doesn’t need to be traced aux_data = None return (children, aux_data) @@ -344,15 +301,16 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" - if children: - return cls( - xmin=children[0], - deltax=children[1], - ymin=children[2], - deltay=children[3], - ) - else: - return cls() + ret = cls.__new__(cls) + ret.xmin = children[0] + ret.deltax = children[1] + ret.ymin = children[2] + ret.deltay = children[3] + ret._isdefined = children[4] + ret._isstatic = False + ret._isstaticshape = False + + return ret @classmethod def from_galsim(cls, galsim_bounds): @@ -401,6 +359,163 @@ def isStatic(self): ``False`` for ``BoundsD``.""" return self._isstatic + def isStaticShape(self): + """Returns ``True`` if the ``BoundsI`` instance + has static, known dimensions. Always returns + ``False`` for ``BoundsD``.""" + return self._isstaticshape + + +def _bounds_and_op_static(self, other): + if not self.isDefined() or not other.isDefined(): + return self.__class__() + else: + xmin = max(self.xmin, other.xmin) + xmax = min(self.xmax, other.xmax) + ymin = max(self.ymin, other.ymin) + ymax = min(self.ymax, other.ymax) + if xmin > xmax or ymin > ymax: + return self.__class__() + else: + return self.__class__(xmin, xmax, ymin, ymax) + + +def _bounds_and_op_dynamic(self, other): + xmin = jnp.maximum(self.xmin, other.xmin) + xmax = jnp.minimum(self.xmax, other.xmax) + ymin = jnp.maximum(self.ymin, other.ymin) + ymax = jnp.minimum(self.ymax, other.ymax) + + is_defined = self.isDefined() & other.isDefined() & (ymin <= ymax) & (xmin <= xmax) + xmin = jnp.where( + is_defined, + xmin, + 0.0, + ) + xmax = jnp.where( + is_defined, + xmax, + 0.0, + ) + ymin = jnp.where( + is_defined, + ymin, + 0.0, + ) + ymax = jnp.where( + is_defined, + ymax, + 0.0, + ) + + cls = self.__class__ + ret = cls.__new__(cls) + ret.xmin = xmin + ret.deltax = xmax - xmin + ret.ymin = ymin + ret.deltay = ymax - ymin + ret._isdefined = is_defined + ret._isstatic = False + ret._isstaticshape = False + + return ret + + +def _bounds_bounds_add_op_static(self, other): + if not other.isDefined(): + return self + elif self.isDefined(): + xmin = min(self.xmin, other.xmin) + xmax = max(self.xmax, other.xmax) + ymin = min(self.ymin, other.ymin) + ymax = max(self.ymax, other.ymax) + return self.__class__(xmin, xmax, ymin, ymax) + else: + return other + + +def _bounds_bounds_add_op_dynamic(self, other, min_delta): + def _ret_correct_attr(self_isdef, self_attr, other_isdef, other_attr, op): + return jnp.where( + ~other_isdef, + self_attr, + jnp.where(self_isdef, op(self_attr, other_attr), other_attr), + ) + + xmin = _ret_correct_attr( + self._isdefined, self.xmin, other._isdefined, other.xmin, jnp.minimum + ) + xmax = _ret_correct_attr( + self._isdefined, self.xmax, other._isdefined, other.xmax, jnp.maximum + ) + ymin = _ret_correct_attr( + self._isdefined, self.ymin, other._isdefined, other.ymin, jnp.minimum + ) + ymax = _ret_correct_attr( + self._isdefined, self.ymax, other._isdefined, other.ymax, jnp.maximum + ) + + cls = self.__class__ + ret = cls.__new__(cls) + + ret.xmin = xmin + ret.deltax = xmax - xmin + min_delta + ret.ymin = ymin + ret.deltay = ymax - ymin + min_delta + ret._isdefined = jnp.where( + ~other._isdefined, + self._isdefined, + jnp.where( + self._isdefined, + (ret.deltax >= min_delta) & (ret.deltay >= min_delta), + other._isdefined, + ), + ) + ret._isstatic = False + ret._isstaticshape = False + + return ret + + +def _bounds_pos_add_op_dynamic(self, other, min_delta): + xmin = jnp.where( + self._isdefined, + jnp.minimum(self.xmin, other.x), + other.x, + ) + xmax = jnp.where( + self._isdefined, + jnp.maximum(self.xmax, other.x), + other.x, + ) + ymin = jnp.where( + self._isdefined, + jnp.minimum(self.ymin, other.y), + other.y, + ) + ymax = jnp.where( + self._isdefined, + jnp.maximum(self.ymax, other.y), + other.y, + ) + + cls = self.__class__ + ret = cls.__new__(cls) + + ret.xmin = xmin + ret.deltax = xmax - xmin + min_delta + ret.ymin = ymin + ret.deltay = ymax - ymin + min_delta + ret._isdefined = jnp.where( + self._isdefined, + (ret.deltax >= min_delta) & (ret.deltay >= min_delta), + jnp.array(True), + ) + ret._isstatic = False + ret._isstaticshape = False + + return ret + @implements(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class @@ -409,18 +524,22 @@ class BoundsD(Bounds): def __init__(self, *args, **kwargs): self._isstatic = False - self._parse_args(*args, **kwargs) + self._isstaticshape = False + do_isdefined = self._parse_args(*args, **kwargs) self.xmin = cast_to_float(self.xmin) self.deltax = cast_to_float(self.deltax) self.ymin = cast_to_float(self.ymin) self.deltay = cast_to_float(self.deltay) + if do_isdefined: + self._isdefined = (self.deltax >= 0) & (self.deltay >= 0) + self._isdefined = jnp.array(self._isdefined) def _check_scalar(self, x, name): try: if ( isinstance(x, jax.Array) and x.shape == () - and x.dtype.name in ["float32", "float64", "float"] + and jnp.issubdtype(x.dtype, jnp.floating) ): return elif x == float(x): @@ -453,7 +572,17 @@ def _center(self): return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) def __repr__(self): - if self.isDefined(): + # sometimes we will encounter a tracer here + # and so we suppress any boolean conversion errors + try: + if jnp.any(self.isDefined()): + print_full = True + else: + print_full = False + except Exception: + print_full = True + + if print_full: return "galsim.%s(%r, %r, %r, %r)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -465,7 +594,17 @@ def __repr__(self): return "galsim.%s()" % (self.__class__.__name__) def __str__(self): - if self.isDefined(): + # sometimes we will encounter a tracer here + # and so we suppress any boolean conversion errors + try: + if jnp.any(self.isDefined()): + print_full = True + else: + print_full = False + except Exception: + print_full = True + + if print_full: return "galsim.%s(%s,%s,%s,%s)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -487,6 +626,45 @@ def __hash__(self): ) ) + def _getinitargs(self): + # defined only for galsim test suite + return (self.xmin, self.xmax, self.ymin, self.ymax) + + def __eq__(self, other): + if self is other: + return True + elif isinstance(other, self.__class__): + return ( + self.isDefined() + & other.isDefined() + & (self.xmin == other.xmin) + & (self.ymin == other.ymin) + & (self.xmax == other.xmax) + & (self.ymax == other.ymax) + ) | ((~self.isDefined()) & (~other.isDefined())) + else: + return False + + def __ne__(self, other): + return ~self.__eq__(other) + + def __and__(self, other): + if not isinstance(other, self.__class__): + raise TypeError("other must be a %s instance" % self.__class__.__name__) + + return _bounds_and_op_dynamic(self, other) + + def __add__(self, other): + if isinstance(other, self.__class__): + return _bounds_bounds_add_op_dynamic(self, other, 0) + elif isinstance(other, self._pos_class): + return _bounds_pos_add_op_dynamic(self, other, 0) + else: + raise TypeError( + "other must be either a %s or a %s" + % (self.__class__.__name__, self._pos_class.__name__) + ) + @implements(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class @@ -494,47 +672,48 @@ class BoundsI(Bounds): _pos_class = PositionI def __init__(self, *args, **kwargs): - # initial setting to let stuff pass through freely - self._isstatic = True - self._parse_args(*args, **kwargs) - self.deltax = cast_to_float(self.deltax) - self.deltay = cast_to_float(self.deltay) - if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): - raise TypeError("BoundsI must be initialized with integer values") - self.deltax = cast_to_int(self.deltax) - self.deltay = cast_to_int(self.deltay) - - if not ( - isinstance( - self._xmin, - (int, float, np.floating, np.integer), - ) - and isinstance( - self._ymin, - (int, float, np.floating, np.integer), - ) - ): - self._isstatic = False - # validate inputs are ints - self._xmin = check_is_int_then_cast( - self._xmin, "BoundsI must be initialized with integer values" + self.deltax = check_is_int_then_cast( + self.deltax, "BoundsI must be initialized with integer values" ) - self._ymin = check_is_int_then_cast( - self._ymin, "BoundsI must be initialized with integer values" + self.deltay = check_is_int_then_cast( + self.deltay, "BoundsI must be initialized with integer values" + ) + self.xmin = check_is_int_then_cast( + self.xmin, "BoundsI must be initialized with integer values" + ) + self.ymin = check_is_int_then_cast( + self.ymin, "BoundsI must be initialized with integer values" ) - if self.deltax < 1 and self.deltay < 1: - self._isdefined = False + if isinstance(self.deltax, int) and isinstance(self.deltay, int): + self._isstaticshape = True + else: + self._isstaticshape = False + + if ( + isinstance(self.xmin, int) + and isinstance(self.ymin, int) + and isinstance(self.deltax, int) + and isinstance(self.deltay, int) + ): + self._isstatic = True + else: + self._isstatic = False + + if self.isStaticShape(): + self._isdefined = self.deltax >= 1 and self.deltay >= 1 + else: + self._isdefined = (self.deltax >= 1) & (self.deltay >= 1) def _check_scalar(self, x, name): try: if ( isinstance(x, jax.Array) and x.shape == () - and x.dtype.name in ["int32", "int64", "int"] + and jnp.issubdtype(x.dtype, jnp.integer) ): return elif x == int(x): @@ -545,24 +724,17 @@ def _check_scalar(self, x, name): def numpyShape(self): "A simple utility function to get the numpy shape that corresponds to this `Bounds` object." - if self.isDefined(): - return self.deltay, self.deltax - else: - return 0, 0 - - @property - def xmin(self): - if self._isstatic: - return self._xmin - else: - return jnp.astype(self._xmin, jnp.int_) - - @xmin.setter - def xmin(self, value): - if self._isstatic: - self._xmin = value + if self._isstaticshape: + if self._isdefined: + return self.deltay, self.deltax + else: + return 0, 0 else: - self._xmin = jnp.astype(value, jnp.float_) + return jax.lax.cond( + self._isdefined, + lambda: (self.deltay, self.deltax), + lambda: (jnp.zeros_like(self.deltay), jnp.zeros_like(self.deltax)), + ) @property def xmax(self): @@ -572,20 +744,6 @@ def xmax(self): def xmax(self, value): self.deltax = value - self.xmin + 1 - @property - def ymin(self): - if self._isstatic: - return self._ymin - else: - return jnp.astype(self._ymin, jnp.int_) - - @ymin.setter - def ymin(self, value): - if self._isstatic: - self._ymin = value - else: - self._ymin = jnp.astype(value, jnp.float_) - @property def ymax(self): return self.ymin + self.deltay - 1 @@ -596,10 +754,17 @@ def ymax(self, value): def _area(self): # Remember the + 1 this time to include the pixels on both edges of the bounds. - if not self.isDefined(): - return 0 + if self._isstaticshape: + if self._isdefined: + return self.deltax * self.deltay + else: + return 0 else: - return self.deltax * self.deltay + return jax.lax.cond( + self._isdefined, + lambda: self.deltax * self.deltay, + lambda: 0.0, + ) @property def _center(self): @@ -617,54 +782,62 @@ def tree_flatten(self): """This function flattens the Bounds into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - if self.isDefined(): - if self._isstatic: - # Define the children nodes of the PyTree that need tracing - children = tuple() - - # Define auxiliary static data that doesn’t need to be traced - aux_data = { - "xmin": self._xmin, - "ymin": self._ymin, - "deltax": self.deltax, - "deltay": self.deltay, - } - else: - children = (self._xmin, self._ymin) - # Define auxiliary static data that doesn’t need to be traced - aux_data = {"deltax": self.deltax, "deltay": self.deltay} - else: - children = tuple() - aux_data = None + aux_data = {"isstatic": self._isstatic, "isstaticshape": self._isstaticshape} + if self._isstaticshape: + aux_data["deltax"] = self.deltax + aux_data["deltay"] = self.deltay + aux_data["isdefined"] = self._isdefined + + if self._isstatic: + aux_data["xmin"] = self.xmin + aux_data["ymin"] = self.ymin + + children = tuple( + jnp.broadcast_arrays( + self.xmin, self.deltax, self.ymin, self.deltay, self._isdefined + ) + ) return (children, aux_data) @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" - if aux_data is not None: - ret = cls.__new__(cls) - if "xmin" in aux_data and "ymin" in aux_data: - ret._isstatic = True - ret._xmin = aux_data["xmin"] - ret._ymin = aux_data["ymin"] - else: - ret._isstatic = False - ret._xmin = children[0] - ret._ymin = children[1] + ret = cls.__new__(cls) + + if aux_data["isstaticshape"]: ret.deltax = aux_data["deltax"] ret.deltay = aux_data["deltay"] - if ret.deltax < 1 and ret.deltay < 1: - ret._isdefined = False - else: - ret._isdefined = True + ret._isdefined = aux_data["isdefined"] + else: + ret.deltax = children[1] + ret.deltay = children[3] + ret._isdefined = children[4] + + if aux_data["isstatic"]: + ret.xmin = aux_data["xmin"] + ret.ymin = aux_data["ymin"] else: - ret = cls() + ret.xmin = children[0] + ret.ymin = children[2] + + ret._isstatic = aux_data["isstatic"] + ret._isstaticshape = aux_data["isstaticshape"] return ret def __repr__(self): - if self.isDefined(): + # sometimes we will encounter a tracer here + # and so we suppress any boolean conversion errors + try: + if jnp.any(self.isDefined()): + print_full = True + else: + print_full = False + except Exception: + print_full = True + + if print_full: return "galsim.%s(xmin=%r, deltax=%r, ymin=%r, deltay=%r)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -676,7 +849,17 @@ def __repr__(self): return "galsim.%s()" % (self.__class__.__name__) def __str__(self): - if self.isDefined(): + # sometimes we will encounter a tracer here + # and so we suppress any boolean conversion errors + try: + if jnp.any(self.isDefined()): + print_full = True + else: + print_full = False + except Exception: + print_full = True + + if print_full: return "galsim.%s(xmin=%s, deltax=%s, ymin=%s, deltay=%s)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -687,17 +870,6 @@ def __str__(self): else: return "galsim.%s()" % (self.__class__.__name__) - def _getinitargs(self): - if self.isDefined(): - return (self.xmin, self.deltax, self.ymin, self.deltay) - else: - return () - - def __eq__(self, other): - return self is other or ( - isinstance(other, BoundsI) and self._getinitargs() == other._getinitargs() - ) - def __hash__(self): return hash( ( @@ -708,3 +880,64 @@ def __hash__(self): ensure_hashable(self.deltay), ) ) + + def __eq__(self, other): + if self is other: + return True + elif isinstance(other, self.__class__): + if self.isStatic() and other.isStatic(): + min_eq = (self.xmin == other.xmin) and (self.ymin == other.ymin) + self_isdef = self.isDefined() + other_isdef = other.isDefined() + shape_eq = (self.deltax == other.deltax) and ( + self.deltay == other.deltay + ) + return (self_isdef and other_isdef and shape_eq and min_eq) or ( + (not self_isdef) and (not other_isdef) + ) + else: + min_eq = jnp.array(self.xmin == other.xmin) & jnp.array( + self.ymin == other.ymin + ) + self_isdef = jnp.array(self.isDefined()) + other_isdef = jnp.array(other.isDefined()) + shape_eq = jnp.array(self.deltax == other.deltax) & jnp.array( + self.deltay == other.deltay + ) + return (self_isdef & other_isdef & shape_eq & min_eq) | ( + (~self_isdef) & (~other_isdef) + ) + else: + return False + + def __ne__(self, other): + if not isinstance(other, self.__class__): + return True + + if self.isStatic() and other.isStatic(): + return not self.__eq__(other) + else: + return ~self.__eq__(other) + + def __and__(self, other): + if not isinstance(other, self.__class__): + raise TypeError("other must be a %s instance" % self.__class__.__name__) + + if self.isStatic() and other.isStatic(): + return _bounds_and_op_static(self, other) + else: + return _bounds_and_op_dynamic(self, other) + + def __add__(self, other): + if isinstance(other, self.__class__): + if self.isStatic() and other.isStatic(): + return _bounds_bounds_add_op_static(self, other) + else: + return _bounds_bounds_add_op_dynamic(self, other, 1) + elif isinstance(other, self._pos_class): + return _bounds_pos_add_op_dynamic(self, other, 1) + else: + raise TypeError( + "other must be either a %s or a %s" + % (self.__class__.__name__, self._pos_class.__name__) + ) diff --git a/tests/GalSim b/tests/GalSim index 549616e8..f3d81a1d 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 549616e8ca4bb84142fae6cdb0a006669f92454b +Subproject commit f3d81a1d18a30651d8769818731d4c4ac3541478 diff --git a/tests/jax/test_bounds_jax.py b/tests/jax/test_bounds_jax.py new file mode 100644 index 00000000..c43068c5 --- /dev/null +++ b/tests/jax/test_bounds_jax.py @@ -0,0 +1,49 @@ +import jax +import jax.numpy as jnp +import numpy as np + +import jax_galsim + + +@jax.vmap +@jax.jit +def _make_bounds_int(xmin, ymin, xmax, ymax): + bds = jax_galsim.BoundsI(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax) + return bds, bds.isDefined() + + +def test_bounds_jax_vmap_isdefined_int(): + xmin = jnp.array([9, 10, 11, 12]) + xmax = jnp.array([12, 11, 10, 9]) + ymin = jnp.array([9, 11, 10, 12]) + ymax = jnp.array([10, 10, 11, 10]) + bds, isdef = _make_bounds_int(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bds.isDefined(), isdef, strict=True) + + # turn a bounds of arrays into a list of bounds + # see https://github.com/jax-ml/jax/discussions/35711 + list_of_bnds = jax.tree.transpose( + jax.tree.structure(bds), + None, + jax.tree.map(list, bds) + ) + assert list_of_bnds[0] != list_of_bnds[2] + assert list_of_bnds[1] == list_of_bnds[2] + assert list_of_bnds[2] == list_of_bnds[3] + assert all(not bnds.isStatic() for bnds in list_of_bnds) + + +@jax.vmap +@jax.jit +def _make_bounds_float(xmin, ymin, xmax, ymax): + bds = jax_galsim.BoundsD(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax) + return bds, bds.isDefined() + + +def test_bounds_jax_vmap_isdefined_float(): + xmin = jnp.array([9, 10, 11, 12]) + xmax = jnp.array([12, 11, 10, 9]) + ymin = jnp.array([9, 11, 10, 12]) + ymax = jnp.array([10, 10, 10, 10]) + bds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bds.isDefined(), isdef, strict=True) From 9d9dc80862b334626d50656acf827bd6dc513845 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 19 May 2026 22:31:25 -0500 Subject: [PATCH 02/72] style: please the dog --- tests/jax/test_bounds_jax.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/jax/test_bounds_jax.py b/tests/jax/test_bounds_jax.py index c43068c5..f9481487 100644 --- a/tests/jax/test_bounds_jax.py +++ b/tests/jax/test_bounds_jax.py @@ -23,9 +23,7 @@ def test_bounds_jax_vmap_isdefined_int(): # turn a bounds of arrays into a list of bounds # see https://github.com/jax-ml/jax/discussions/35711 list_of_bnds = jax.tree.transpose( - jax.tree.structure(bds), - None, - jax.tree.map(list, bds) + jax.tree.structure(bds), None, jax.tree.map(list, bds) ) assert list_of_bnds[0] != list_of_bnds[2] assert list_of_bnds[1] == list_of_bnds[2] From 956f778bc6e3ae0e69b75875acf43b5a9c44739e Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 19 May 2026 22:34:33 -0500 Subject: [PATCH 03/72] fix: only apply ~ to array bool --- jax_galsim/bounds.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index f9fda0e0..b6273baa 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -632,7 +632,7 @@ def _getinitargs(self): def __eq__(self, other): if self is other: - return True + return jnp.array(True) elif isinstance(other, self.__class__): return ( self.isDefined() @@ -643,7 +643,7 @@ def __eq__(self, other): & (self.ymax == other.ymax) ) | ((~self.isDefined()) & (~other.isDefined())) else: - return False + return jnp.array(False) def __ne__(self, other): return ~self.__eq__(other) @@ -883,7 +883,10 @@ def __hash__(self): def __eq__(self, other): if self is other: - return True + if self.isStatic() and other.isStatic(): + return True + else: + return jnp.array(True) elif isinstance(other, self.__class__): if self.isStatic() and other.isStatic(): min_eq = (self.xmin == other.xmin) and (self.ymin == other.ymin) From 03c78694655d0d6f2425449b5e5e9551119f4130 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 06:03:10 -0500 Subject: [PATCH 04/72] fix: try dynamic children --- jax_galsim/bounds.py | 42 +++++++++++++++++++++++------------------- tests/jax/test_api.py | 1 + 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index b6273baa..7965aceb 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -783,20 +783,22 @@ def tree_flatten(self): nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing aux_data = {"isstatic": self._isstatic, "isstaticshape": self._isstaticshape} + + if self._isstatic: + aux_data["xmin"] = self.xmin + aux_data["ymin"] = self.ymin + if self._isstaticshape: aux_data["deltax"] = self.deltax aux_data["deltay"] = self.deltay aux_data["isdefined"] = self._isdefined if self._isstatic: - aux_data["xmin"] = self.xmin - aux_data["ymin"] = self.ymin - - children = tuple( - jnp.broadcast_arrays( - self.xmin, self.deltax, self.ymin, self.deltay, self._isdefined - ) - ) + children = tuple() + elif self._isstaticshape: + children = (self.xmin, self.ymin) + else: + children = (self.xmin, self.deltax, self.ymin, self.deltay, self._isdefined) return (children, aux_data) @@ -804,26 +806,28 @@ def tree_flatten(self): def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" ret = cls.__new__(cls) + ret._isstatic = aux_data["isstatic"] + ret._isstaticshape = aux_data["isstaticshape"] - if aux_data["isstaticshape"]: + if ret._isstatic: + ret.xmin = aux_data["xmin"] + ret.ymin = aux_data["ymin"] + ret.deltax = aux_data["deltax"] + ret.deltay = aux_data["deltay"] + ret._isdefined = aux_data["isdefined"] + elif ret._isstaticshape: + ret.xmin = children[0] + ret.ymin = children[1] ret.deltax = aux_data["deltax"] ret.deltay = aux_data["deltay"] ret._isdefined = aux_data["isdefined"] else: + ret.xmin = children[0] ret.deltax = children[1] + ret.ymin = children[2] ret.deltay = children[3] ret._isdefined = children[4] - if aux_data["isstatic"]: - ret.xmin = aux_data["xmin"] - ret.ymin = aux_data["ymin"] - else: - ret.xmin = children[0] - ret.ymin = children[2] - - ret._isstatic = aux_data["isstatic"] - ret._isstaticshape = aux_data["isstaticshape"] - return ret def __repr__(self): diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index e76b081c..c3afed10 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -360,6 +360,7 @@ def _reg_fun(p): "xmax", "ymax", "isStatic", + "isStaticShape", ]: continue From 231fdb266c1c706a40718062388668568e9ef08a Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 07:24:24 -0500 Subject: [PATCH 05/72] fix: ensure bounds iuncludes tests are done properly --- jax_galsim/bounds.py | 226 +++++++++++++++++++++++++++---------------- 1 file changed, 141 insertions(+), 85 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 7965aceb..89534107 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -180,41 +180,9 @@ def true_center(self): @implements(_galsim.Bounds.includes) def includes(self, *args): - if len(args) == 1: - if isinstance(args[0], Bounds): - b = args[0] - return ( - self.isDefined() - & b.isDefined() - & (self.xmin <= b.xmin) - & (self.xmax >= b.xmax) - & (self.ymin <= b.ymin) - & (self.ymax >= b.ymax) - ) - elif isinstance(args[0], Position): - p = args[0] - return ( - self.isDefined() - & (self.xmin <= p.x) - & (self.ymin <= p.y) - & (p.x <= self.xmax) - & (p.y <= self.ymax) - ) - else: - raise TypeError("Invalid argument %s" % args[0]) - elif len(args) == 2: - x, y = args - return ( - self.isDefined() - & (self.xmin <= float(x)) - & (self.ymin <= float(y)) - & (float(x) <= self.xmax) - & (float(y) <= self.ymax) - ) - elif len(args) == 0: - raise TypeError("include takes at least 1 argument (0 given)") - else: - raise TypeError("include takes at most 2 arguments (%d given)" % len(args)) + raise NotImplementedError( + "Subclasses of `Bounds` must implement the `includes` method!" + ) @implements(_galsim.Bounds.expand) def expand(self, factor_x, factor_y=None): @@ -571,6 +539,44 @@ def _area(self): def _center(self): return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) + @implements(_galsim.Bounds.includes) + def includes(self, *args): + if len(args) == 1: + if isinstance(args[0], Bounds): + b = args[0] + return ( + self.isDefined() + & b.isDefined() + & (self.xmin <= b.xmin) + & (self.xmax >= b.xmax) + & (self.ymin <= b.ymin) + & (self.ymax >= b.ymax) + ) + elif isinstance(args[0], Position): + p = args[0] + return ( + self.isDefined() + & (self.xmin <= p.x) + & (self.ymin <= p.y) + & (p.x <= self.xmax) + & (p.y <= self.ymax) + ) + else: + raise TypeError("Invalid argument %s" % args[0]) + elif len(args) == 2: + x, y = args + return ( + self.isDefined() + & (self.xmin <= cast_to_float(x)) + & (self.ymin <= cast_to_float(y)) + & (cast_to_float(x) <= self.xmax) + & (cast_to_float(y) <= self.ymax) + ) + elif len(args) == 0: + raise TypeError("include takes at least 1 argument (0 given)") + else: + raise TypeError("include takes at most 2 arguments (%d given)" % len(args)) + def __repr__(self): # sometimes we will encounter a tracer here # and so we suppress any boolean conversion errors @@ -778,57 +784,55 @@ def _center(self): self.ymin + self.deltay // 2, ) - def tree_flatten(self): - """This function flattens the Bounds into a list of children - nodes that will be traced by JAX and auxiliary static data.""" - # Define the children nodes of the PyTree that need tracing - aux_data = {"isstatic": self._isstatic, "isstaticshape": self._isstaticshape} - - if self._isstatic: - aux_data["xmin"] = self.xmin - aux_data["ymin"] = self.ymin - - if self._isstaticshape: - aux_data["deltax"] = self.deltax - aux_data["deltay"] = self.deltay - aux_data["isdefined"] = self._isdefined - - if self._isstatic: - children = tuple() - elif self._isstaticshape: - children = (self.xmin, self.ymin) - else: - children = (self.xmin, self.deltax, self.ymin, self.deltay, self._isdefined) - - return (children, aux_data) - - @classmethod - def tree_unflatten(cls, aux_data, children): - """Recreates an instance of the class from flatten representation""" - ret = cls.__new__(cls) - ret._isstatic = aux_data["isstatic"] - ret._isstaticshape = aux_data["isstaticshape"] - - if ret._isstatic: - ret.xmin = aux_data["xmin"] - ret.ymin = aux_data["ymin"] - ret.deltax = aux_data["deltax"] - ret.deltay = aux_data["deltay"] - ret._isdefined = aux_data["isdefined"] - elif ret._isstaticshape: - ret.xmin = children[0] - ret.ymin = children[1] - ret.deltax = aux_data["deltax"] - ret.deltay = aux_data["deltay"] - ret._isdefined = aux_data["isdefined"] + @implements(_galsim.Bounds.includes) + def includes(self, *args): + if len(args) == 1: + if isinstance(args[0], Bounds): + b = args[0] + if self.isStatic() and b.isStatic(): + return ( + self.isDefined() + and b.isDefined() + and (self.xmin <= b.xmin) + and (self.xmax >= b.xmax) + and (self.ymin <= b.ymin) + and (self.ymax >= b.ymax) + ) + else: + return ( + jnp.array(self.isDefined()) + & jnp.array(b.isDefined()) + & jnp.array(self.xmin <= b.xmin) + & jnp.array(self.xmax >= b.xmax) + & jnp.array(self.ymin <= b.ymin) + & jnp.array(self.ymax >= b.ymax) + ) + elif isinstance(args[0], Position): + p = args[0] + return ( + jnp.array(self.isDefined()) + & (self.xmin <= p.x) + & (self.ymin <= p.y) + & (p.x <= self.xmax) + & (p.y <= self.ymax) + ) + else: + raise TypeError("Invalid argument %s" % args[0]) + elif len(args) == 2: + x, y = args + x = cast_to_float(jnp.array(x)) + y = cast_to_float(jnp.array(y)) + return ( + jnp.array(self.isDefined()) + & (self.xmin <= x) + & (self.ymin <= y) + & (x <= self.xmax) + & (y <= self.ymax) + ) + elif len(args) == 0: + raise TypeError("include takes at least 1 argument (0 given)") else: - ret.xmin = children[0] - ret.deltax = children[1] - ret.ymin = children[2] - ret.deltay = children[3] - ret._isdefined = children[4] - - return ret + raise TypeError("include takes at most 2 arguments (%d given)" % len(args)) def __repr__(self): # sometimes we will encounter a tracer here @@ -948,3 +952,55 @@ def __add__(self, other): "other must be either a %s or a %s" % (self.__class__.__name__, self._pos_class.__name__) ) + + def tree_flatten(self): + """This function flattens the Bounds into a list of children + nodes that will be traced by JAX and auxiliary static data.""" + # Define the children nodes of the PyTree that need tracing + aux_data = {"isstatic": self._isstatic, "isstaticshape": self._isstaticshape} + + if self._isstatic: + aux_data["xmin"] = self.xmin + aux_data["ymin"] = self.ymin + + if self._isstaticshape: + aux_data["deltax"] = self.deltax + aux_data["deltay"] = self.deltay + aux_data["isdefined"] = self._isdefined + + if self._isstatic: + children = tuple() + elif self._isstaticshape: + children = (self.xmin, self.ymin) + else: + children = (self.xmin, self.deltax, self.ymin, self.deltay, self._isdefined) + + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + ret = cls.__new__(cls) + ret._isstatic = aux_data["isstatic"] + ret._isstaticshape = aux_data["isstaticshape"] + + if ret._isstatic: + ret.xmin = aux_data["xmin"] + ret.ymin = aux_data["ymin"] + ret.deltax = aux_data["deltax"] + ret.deltay = aux_data["deltay"] + ret._isdefined = aux_data["isdefined"] + elif ret._isstaticshape: + ret.xmin = children[0] + ret.ymin = children[1] + ret.deltax = aux_data["deltax"] + ret.deltay = aux_data["deltay"] + ret._isdefined = aux_data["isdefined"] + else: + ret.xmin = children[0] + ret.deltax = children[1] + ret.ymin = children[2] + ret.deltay = children[3] + ret._isdefined = children[4] + + return ret From 97b882387f72b1e3b4f996e443b248fd80c0d58b Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 07:25:50 -0500 Subject: [PATCH 06/72] fix: cast in a different way --- jax_galsim/bounds.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 89534107..3c618681 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -811,23 +811,23 @@ def includes(self, *args): p = args[0] return ( jnp.array(self.isDefined()) - & (self.xmin <= p.x) - & (self.ymin <= p.y) - & (p.x <= self.xmax) - & (p.y <= self.ymax) + & jnp.array(self.xmin <= p.x) + & jnp.array(self.ymin <= p.y) + & jnp.array(p.x <= self.xmax) + & jnp.array(p.y <= self.ymax) ) else: raise TypeError("Invalid argument %s" % args[0]) elif len(args) == 2: x, y = args - x = cast_to_float(jnp.array(x)) - y = cast_to_float(jnp.array(y)) + x = cast_to_float(x) + y = cast_to_float(y) return ( jnp.array(self.isDefined()) - & (self.xmin <= x) - & (self.ymin <= y) - & (x <= self.xmax) - & (y <= self.ymax) + & jnp.array(self.xmin <= x) + & jnp.array(self.ymin <= y) + & jnp.array(x <= self.xmax) + & jnp.array(y <= self.ymax) ) elif len(args) == 0: raise TypeError("include takes at least 1 argument (0 given)") From 997dadfd8d053c38f20096d6d889224ec193e2ed Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 11:49:05 -0500 Subject: [PATCH 07/72] fix: finish dynamic bounds impl --- jax_galsim/bounds.py | 52 ++++++++++++++++------ jax_galsim/image.py | 100 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 134 insertions(+), 18 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 3c618681..533c5186 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -2,6 +2,7 @@ import galsim as _galsim import jax import jax.numpy as jnp +import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( @@ -809,26 +810,49 @@ def includes(self, *args): ) elif isinstance(args[0], Position): p = args[0] - return ( - jnp.array(self.isDefined()) - & jnp.array(self.xmin <= p.x) - & jnp.array(self.ymin <= p.y) - & jnp.array(p.x <= self.xmax) - & jnp.array(p.y <= self.ymax) - ) + ok_types = (int, float, np.integer, np.floating) + if ( + self._isstatic + and isinstance(p.x, ok_types) + and isinstance(p.y, ok_types) + ): + return ( + self.isDefined() + and (self.xmin <= p.x) + and (self.ymin <= p.y) + and (p.x <= self.xmax) + and (p.y <= self.ymax) + ) + else: + return ( + jnp.array(self.isDefined()) + & jnp.array(self.xmin <= p.x) + & jnp.array(self.ymin <= p.y) + & jnp.array(p.x <= self.xmax) + & jnp.array(p.y <= self.ymax) + ) else: raise TypeError("Invalid argument %s" % args[0]) elif len(args) == 2: x, y = args x = cast_to_float(x) y = cast_to_float(y) - return ( - jnp.array(self.isDefined()) - & jnp.array(self.xmin <= x) - & jnp.array(self.ymin <= y) - & jnp.array(x <= self.xmax) - & jnp.array(y <= self.ymax) - ) + if self._isstatic and isinstance(x, float) and isinstance(y, float): + return ( + self.isDefined() + and (self.xmin <= x) + and (self.ymin <= y) + and (x <= self.xmax) + and (y <= self.ymax) + ) + else: + return ( + jnp.array(self.isDefined()) + & jnp.array(self.xmin <= x) + & jnp.array(self.ymin <= y) + & jnp.array(x <= self.xmax) + & jnp.array(y <= self.ymax) + ) elif len(args) == 0: raise TypeError("include takes at least 1 argument (0 given)") else: diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 7c78a178..8032eb67 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -269,6 +269,13 @@ def __init__(self, *args, **kwargs): raise TypeError("wcs parameters must be a galsim.BaseWCS instance") self.wcs = wcs + # raise an error if bounds doesn't have a fixed width + if not self._bounds.isStaticShape(): + raise RuntimeError( + "JAX-GalSim `Image` objects must have a `BoundsI` instance with " + "a static shape (i.e., `image.bounds.isStaticShape() is True`)." + ) + @staticmethod def _get_xmin_ymin(array, kwargs, check_bounds=True): """A helper function for parsing xmin, ymin, bounds options with a given array""" @@ -280,6 +287,14 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): b = kwargs.pop("bounds") if not isinstance(b, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") + + # raise an error if bounds doesn't have a fixed width + if not b.isStaticShape(): + raise RuntimeError( + "JAX-GalSim `Image` objects must have a `BoundsI` instance with " + "a static shape (i.e., `image.bounds.isStaticShape() is True`)." + ) + if check_bounds and b.isDefined(): if b.deltax != array.shape[1]: raise _galsim.GalSimIncompatibleValuesError( @@ -571,6 +586,14 @@ def resize(self, bounds, wcs=None): raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") + + # raise an error if bounds doesn't have a fixed width + if not bounds.isStaticShape(): + raise RuntimeError( + "JAX-GalSim `Image` objects must have a `BoundsI` instance with " + "a static shape (i.e., `image.bounds.isStaticShape() is True`)." + ) + self._array = self._make_empty(shape=bounds.numpyShape(), dtype=self.dtype) self._bounds = bounds if wcs is not None: @@ -580,6 +603,14 @@ def resize(self, bounds, wcs=None): def subImage(self, bounds): if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") + + # raise an error if bounds doesn't have a fixed width + if not bounds.isStaticShape(): + raise RuntimeError( + "JAX-GalSim `Image` objects must have a `BoundsI` instance with " + "a static shape (i.e., `image.bounds.isStaticShape() is True`)." + ) + if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access subImage of undefined image" @@ -592,6 +623,13 @@ def subImage(self, bounds): raise _galsim.GalSimBoundsError( "Attempt to access subImage not (fully) in image", bounds, self.bounds ) + else: + inc_val = jnp.array(self.bounds.includes(bounds)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to access subImage not (fully) in image", + ) if self.bounds.isStatic() and bounds.isStatic(): i1 = bounds.ymin - self.ymin @@ -619,6 +657,14 @@ def setSubImage(self, bounds, rhs): raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") + + # raise an error if bounds doesn't have a fixed width + if not bounds.isStaticShape(): + raise RuntimeError( + "JAX-GalSim `Image` objects must have a `BoundsI` instance with " + "a static shape (i.e., `image.bounds.isStaticShape() is True`)." + ) + if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" @@ -631,6 +677,14 @@ def setSubImage(self, bounds, rhs): raise _galsim.GalSimBoundsError( "Attempt to access subImage not (fully) in image", bounds, self.bounds ) + else: + inc_val = jnp.array(self.bounds.includes(bounds)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to access subImage not (fully) in image", + ) + if not isinstance(rhs, Image): raise TypeError("Trying to copyFrom a non-image") if bounds.numpyShape() != rhs.bounds.numpyShape(): @@ -722,6 +776,13 @@ def wrap(self, bounds, hermitian=False): if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") + # raise an error if bounds doesn't have a fixed width + if not bounds.isStaticShape(): + raise RuntimeError( + "JAX-GalSim `Image` objects must have a `BoundsI` instance with " + "a static shape (i.e., `image.bounds.isStaticShape() is True`)." + ) + def _raise_if_nonzero(bnds, x_or_y, msg): if x_or_y == "x": if bnds.isStatic(): @@ -902,12 +963,19 @@ def calculate_inverse_fft(self): raise _galsim.GalSimError( "calculate_inverse_fft requires that the image has a PixelScale wcs." ) - if not self.bounds.includes(0, 0): + if self.bounds.isStatic() and not self.bounds.includes(0, 0): raise _galsim.GalSimBoundsError( "calculate_inverse_fft requires that the image includes (0,0)", PositionI(0, 0), self.bounds, ) + else: + inc_val = jnp.array(self.bounds.includes(0, 0)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "calculate_inverse_fft requires that the image includes (0,0)", + ) No2 = max( max(self.bounds.xmax, -self.bounds.ymin), @@ -1067,12 +1135,20 @@ def getValue(self, x, y): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" ) - if not self.bounds.includes(x, y): + if self.bounds.isStatic() and not self.bounds.includes(x, y): raise _galsim.GalSimBoundsError( "Attempt to access position not in bounds of image.", PositionI(x, y), self.bounds, ) + else: + inc_val = jnp.array(self.bounds.includes(x, y)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to access position not in bounds of image.", + ) + return self._getValue(x, y) @implements(_galsim.Image._getValue) @@ -1090,10 +1166,18 @@ def setValue(self, *args, **kwargs): pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) - if not self.bounds.includes(pos): + if self.bounds.isStatic() and not self.bounds.includes(pos): raise _galsim.GalSimBoundsError( "Attempt to set position not in bounds of image", pos, self.bounds ) + else: + inc_val = jnp.array(self.bounds.includes(pos)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to set position not in bounds of image", + ) + self._setValue(pos.x, pos.y, value) @implements(_galsim.Image._setValue) @@ -1111,10 +1195,18 @@ def addValue(self, *args, **kwargs): pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) - if not self.bounds.includes(pos): + if self.bounds.isStatic() and not self.bounds.includes(pos): raise _galsim.GalSimBoundsError( "Attempt to set position not in bounds of image", pos, self.bounds ) + else: + inc_val = jnp.array(self.bounds.includes(pos)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to set position not in bounds of image", + ) + self._addValue(pos.x, pos.y, value) @implements(_galsim.Image._addValue) From 0774b6a49ce9dac4165dc24c20ab4644668c0acb Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 12:05:03 -0500 Subject: [PATCH 08/72] fix: be sure to test everything --- jax_galsim/bounds.py | 4 ++-- jax_galsim/core/utils.py | 8 ++++---- jax_galsim/image.py | 12 +++++++++--- jax_galsim/position.py | 8 ++++++++ jax_galsim/random.py | 4 ++-- jax_galsim/wcs.py | 3 ++- 6 files changed, 27 insertions(+), 12 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 533c5186..c8ba4954 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -2,10 +2,10 @@ import galsim as _galsim import jax import jax.numpy as jnp -import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( + STATIC_SCALAR_TYPES, cast_to_float, check_is_int_then_cast, ensure_hashable, @@ -810,7 +810,7 @@ def includes(self, *args): ) elif isinstance(args[0], Position): p = args[0] - ok_types = (int, float, np.integer, np.floating) + ok_types = STATIC_SCALAR_TYPES if ( self._isstatic and isinstance(p.x, ok_types) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 839d0658..9a06975c 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -8,12 +8,14 @@ import jax.numpy as jnp import numpy as np +STATIC_SCALAR_TYPES = (int, float, np.integer, np.floating) + def check_is_int_then_cast(val, msg): """Check if `val` is an integer, raise if not, otherwise cast to int.""" val = cast_to_float(val) - if isinstance(val, (int, float, np.integer, np.floating)): + if isinstance(val, STATIC_SCALAR_TYPES): # for simple inputs, we can check direct in python if val != int(val): raise TypeError(msg) @@ -43,9 +45,7 @@ def cast_numpy_array_to_native_byte_order(arr): def _cast_to_type(x, typ, accept_strings=False): - if isinstance(x, (int, float, np.integer, np.floating)) or ( - accept_strings and isinstance(x, str) - ): + if isinstance(x, STATIC_SCALAR_TYPES) or (accept_strings and isinstance(x, str)): return typ(x) else: return jnp.astype(x, typ) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 8032eb67..c4a60ed3 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -7,6 +7,7 @@ from jax_galsim.bounds import Bounds, BoundsD, BoundsI from jax_galsim.core.utils import ( + STATIC_SCALAR_TYPES, cast_numpy_array_to_native_byte_order, ensure_hashable, implements, @@ -1135,7 +1136,12 @@ def getValue(self, x, y): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" ) - if self.bounds.isStatic() and not self.bounds.includes(x, y): + if ( + self.bounds.isStatic() + and isinstance(x, STATIC_SCALAR_TYPES) + and isinstance(y, STATIC_SCALAR_TYPES) + and not self.bounds.includes(x, y) + ): raise _galsim.GalSimBoundsError( "Attempt to access position not in bounds of image.", PositionI(x, y), @@ -1166,7 +1172,7 @@ def setValue(self, *args, **kwargs): pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) - if self.bounds.isStatic() and not self.bounds.includes(pos): + if self.bounds.isStatic() and pos.isStatic() and not self.bounds.includes(pos): raise _galsim.GalSimBoundsError( "Attempt to set position not in bounds of image", pos, self.bounds ) @@ -1195,7 +1201,7 @@ def addValue(self, *args, **kwargs): pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) - if self.bounds.isStatic() and not self.bounds.includes(pos): + if self.bounds.isStatic() and pos.isStatic() and not self.bounds.includes(pos): raise _galsim.GalSimBoundsError( "Attempt to set position not in bounds of image", pos, self.bounds ) diff --git a/jax_galsim/position.py b/jax_galsim/position.py index cf36dba8..6b5ffc0d 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -4,6 +4,7 @@ from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( + STATIC_SCALAR_TYPES, cast_to_float, check_is_int_then_cast, ensure_hashable, @@ -182,6 +183,13 @@ def to_galsim(self): cast(self.y), ) + def isStatic(self): + """Returns ``True`` if the ``Position`` instance + ``x`` and ``y`` values are not arrays""" + return isinstance(self.x, STATIC_SCALAR_TYPES) and isinstance( + self.y, STATIC_SCALAR_TYPES + ) + @implements(_galsim.PositionD) @register_pytree_node_class diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 2cf2db27..b5e730f3 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -9,7 +9,7 @@ import numpy as np from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import implements +from jax_galsim.core.utils import STATIC_SCALAR_TYPES, implements try: from jax.extend.random import wrap_key_data @@ -95,7 +95,7 @@ def generates_in_pairs(self): def seed(self, seed=None): if seed is None: self._seed(seed=seed) - elif isinstance(seed, (int, float, np.integer, np.floating)): + elif isinstance(seed, STATIC_SCALAR_TYPES): if seed == int(seed): self._seed(seed=int(seed)) else: diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index 6dcec13b..ec1320be 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -7,6 +7,7 @@ from jax_galsim.angle import AngleUnit, arcsec, radians from jax_galsim.celestial import CelestialCoord from jax_galsim.core.utils import ( + STATIC_SCALAR_TYPES, cast_to_float, ensure_hashable, implements, @@ -22,7 +23,7 @@ # this kind of casting is only done for writing FITS headers # and should never be done anywhere else in the code base def _cast_to_static_numeric_scalar(x, msg=None): - if isinstance(x, (int, float, np.integer, np.floating)): + if isinstance(x, STATIC_SCALAR_TYPES): return x if isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)): From 80f7327e6658f7ab181723758877bd25bd921020 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 13:45:08 -0500 Subject: [PATCH 09/72] fix: get bools right --- jax_galsim/bounds.py | 563 ++++++++++++++--------------------------- jax_galsim/image.py | 202 ++++----------- jax_galsim/position.py | 8 - 3 files changed, 246 insertions(+), 527 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index c8ba4954..818590a1 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -5,7 +5,6 @@ from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( - STATIC_SCALAR_TYPES, cast_to_float, check_is_int_then_cast, ensure_hashable, @@ -14,22 +13,26 @@ from jax_galsim.position import Position, PositionD, PositionI BOUNDS_LAX_DESCR = """\ -The JAX implementation adds a new method, ``isStatic`` to the -``Bounds`` class. If JAX-GalSim detects that a ``BoundsI`` instance -has been instantiated with static, known values, ``isStatic()`` will -return ``True``, otherwise it is ``False``. For ``BoundsD``, ``isStatic()`` -always returns ``False``. - -``BoundsI`` objects in JAX-Galsim must have a fixed width. To help support -this requirement, JAX-Galsim supports an additional initialization call -``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)``. - -When calling ``jax.vmap`` over ``BoundsI`` objects, only ``x/ymin`` -are vectorized over. This restriction allows for code that renders -objects in fixed sized stamps with variable locations, a common -operation. ``BoundsI`` objects which are static (i.e., ``isStatic()`` -returns ``True``) are treated as constants with respect to ``vmap``, -``jit``, and other JAX transforms. +The JAX-GalSim implementation of the ``BoundsI/D`` classes have some key differences +from GalSim. + +- ``BoundsI`` instances must have statically known shapes, but may have non-static + start locations (i.e., ``xmin`` and ``ymin`` may be JAX arrays, traced in JIT operations, etc.). + This restriction mirrors the JAX restriction that arrays have fixed shapes when traced + for function transformations like ``jax.vmap``, ``jax.jit``, etc. +- Upon initialization, if a ``BoundsI`` object has a non-static shape, JAX-GalSim will attempt to convert + it to a static shape by extracting the dimensions from the array via ``.item()``. This operation will + cause JAX to raise an error if the code is being traced. +- JAX-GalSim does not support the use of the `&/+` dunder methods (i.e., ``__and__`` and ``__add__``) + for ``BoundsI`` objects when tracing code. +- JAX-Galsim supports an additional initialization signature ``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)`` + to help users specify the widths ``deltax`` and ``deltay`` statically at initialization. +- When calling ``jax.vmap``, ``jax.jit`` etc. with ``BoundsI`` objects, ``xmin`` and ``ymin`` are + traced by JAX. The combination of this feature with statically known shapes allows for code that renders + objects in fixed sized stamps with variable locations, a common operation. +- For ``BoundsD``, all ``x(y)min(max)`` values are traced as arrays. +- ``Bounds`` objects always return a JAX boolean values for various method calls, except for + ``BoundsI.isDefined()`` which is always a Python boolean value. """ @@ -77,10 +80,10 @@ def _parse_args(self, *args, **kwargs): ) elif len(args) == 2: if isinstance(args[0], Position) and isinstance(args[1], Position): - self.xmin = min(args[0].x, args[1].x) - self.xmax = max(args[0].x, args[1].x) - self.ymin = min(args[0].y, args[1].y) - self.ymax = max(args[0].y, args[1].y) + self.xmin = jnp.minimum(args[0].x, args[1].x) + self.xmax = jnp.maximum(args[0].x, args[1].x) + self.ymin = jnp.minimum(args[0].y, args[1].y) + self.ymax = jnp.maximum(args[0].y, args[1].y) else: raise TypeError( "Two arguments to %s must be Positions" @@ -228,24 +231,41 @@ def shift(self, delta): ) def __and__(self, other): - raise NotImplementedError( - "Subclasses of `Bounds` must implement the `__and__` method!" - ) + if not isinstance(other, self.__class__): + raise TypeError("other must be a %s instance" % self.__class__.__name__) + + return _bounds_and_op_dynamic(self, other) def __add__(self, other): - raise NotImplementedError( - "Subclasses of `Bounds` must implement the `__add__` method!" - ) + if isinstance(other, self.__class__): + return _bounds_bounds_add_op_dynamic(self, other) + elif isinstance(other, self._pos_class): + return _bounds_pos_add_op_dynamic(self, other) + else: + raise TypeError( + "other must be either a %s or a %s" + % (self.__class__.__name__, self._pos_class.__name__) + ) def __eq__(self, other): - raise NotImplementedError( - "Subclasses of `Bounds` must implement the `__eq__` method!" - ) + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + self_isdef = jnp.array(self.isDefined()) + other_isdef = jnp.array(other.isDefined()) + return ( + self_isdef + & other_isdef + & jnp.array(self.xmin == other.xmin) + & jnp.array(self.ymin == other.ymin) + & jnp.array(self.xmax == other.xmax) + & jnp.array(self.ymax == other.ymax) + ) | ((~self_isdef) & (~other_isdef)) + else: + return jnp.array(False) def __ne__(self, other): - raise NotImplementedError( - "Subclasses of `Bounds` must implement the `__ne__` method!" - ) + return ~self.__eq__(other) def __hash__(self): return hash( @@ -276,8 +296,6 @@ def tree_unflatten(cls, aux_data, children): ret.ymin = children[2] ret.deltay = children[3] ret._isdefined = children[4] - ret._isstatic = False - ret._isstaticshape = False return ret @@ -322,32 +340,6 @@ def to_galsim(self): else: return gs_class() - def isStatic(self): - """Returns ``True`` if the ``BoundsI`` instance - has static, known dimensions and location. Always returns - ``False`` for ``BoundsD``.""" - return self._isstatic - - def isStaticShape(self): - """Returns ``True`` if the ``BoundsI`` instance - has static, known dimensions. Always returns - ``False`` for ``BoundsD``.""" - return self._isstaticshape - - -def _bounds_and_op_static(self, other): - if not self.isDefined() or not other.isDefined(): - return self.__class__() - else: - xmin = max(self.xmin, other.xmin) - xmax = min(self.xmax, other.xmax) - ymin = max(self.ymin, other.ymin) - ymax = min(self.ymax, other.ymax) - if xmin > xmax or ymin > ymax: - return self.__class__() - else: - return self.__class__(xmin, xmax, ymin, ymax) - def _bounds_and_op_dynamic(self, other): xmin = jnp.maximum(self.xmin, other.xmin) @@ -355,7 +347,12 @@ def _bounds_and_op_dynamic(self, other): ymin = jnp.maximum(self.ymin, other.ymin) ymax = jnp.minimum(self.ymax, other.ymax) - is_defined = self.isDefined() & other.isDefined() & (ymin <= ymax) & (xmin <= xmax) + is_defined = ( + jnp.array(self.isDefined()) + & jnp.array(other.isDefined()) + & jnp.array(ymin <= ymax) + & jnp.array(xmin <= xmax) + ) xmin = jnp.where( is_defined, xmin, @@ -378,37 +375,34 @@ def _bounds_and_op_dynamic(self, other): ) cls = self.__class__ - ret = cls.__new__(cls) - ret.xmin = xmin - ret.deltax = xmax - xmin - ret.ymin = ymin - ret.deltay = ymax - ymin - ret._isdefined = is_defined - ret._isstatic = False - ret._isstaticshape = False + if isinstance(self, BoundsI): + # we use the class constructor here to ensure we properly convert + # bounds shape to static ints + ret = cls( + xmin=xmin, + deltax=xmax - xmin + 1, + ymin=ymin, + deltay=ymax - ymin + 1, + ) + # we have to do a conversion to static bools here too + ret._isdefined = bool(is_defined.item()) + else: + ret = cls.__new__(cls) + ret.xmin = xmin + ret.deltax = xmax - xmin + ret.ymin = ymin + ret.deltay = ymax - ymin + ret._isdefined = is_defined return ret -def _bounds_bounds_add_op_static(self, other): - if not other.isDefined(): - return self - elif self.isDefined(): - xmin = min(self.xmin, other.xmin) - xmax = max(self.xmax, other.xmax) - ymin = min(self.ymin, other.ymin) - ymax = max(self.ymax, other.ymax) - return self.__class__(xmin, xmax, ymin, ymax) - else: - return other - - -def _bounds_bounds_add_op_dynamic(self, other, min_delta): +def _bounds_bounds_add_op_dynamic(self, other): def _ret_correct_attr(self_isdef, self_attr, other_isdef, other_attr, op): return jnp.where( - ~other_isdef, + ~jnp.array(other_isdef), self_attr, - jnp.where(self_isdef, op(self_attr, other_attr), other_attr), + jnp.where(jnp.array(self_isdef), op(self_attr, other_attr), other_attr), ) xmin = _ret_correct_attr( @@ -425,28 +419,46 @@ def _ret_correct_attr(self_isdef, self_attr, other_isdef, other_attr, op): ) cls = self.__class__ - ret = cls.__new__(cls) - - ret.xmin = xmin - ret.deltax = xmax - xmin + min_delta - ret.ymin = ymin - ret.deltay = ymax - ymin + min_delta - ret._isdefined = jnp.where( - ~other._isdefined, - self._isdefined, - jnp.where( - self._isdefined, - (ret.deltax >= min_delta) & (ret.deltay >= min_delta), - other._isdefined, - ), - ) - ret._isstatic = False - ret._isstaticshape = False + if isinstance(self, BoundsI): + # we use the class constructor here to ensure we properly convert + # bounds shape to static ints + ret = cls( + xmin=xmin, + deltax=xmax - xmin + 1, + ymin=ymin, + deltay=ymax - ymin + 1, + ) + is_defined = jnp.where( + ~jnp.array(other._isdefined), + jnp.array(self._isdefined), + jnp.where( + jnp.array(self._isdefined), + jnp.array(ret.deltax >= 1) & jnp.array(ret.deltay >= 1), + jnp.array(other._isdefined), + ), + ) + # we have to do a conversion to static bools here too + ret._isdefined = bool(is_defined.item()) + else: + ret = cls.__new__(cls) + ret.xmin = xmin + ret.deltax = xmax - xmin + ret.ymin = ymin + ret.deltay = ymax - ymin + ret._isdefined = jnp.where( + ~jnp.array(other._isdefined), + jnp.array(self._isdefined), + jnp.where( + jnp.array(self._isdefined), + jnp.array(ret.deltax >= 0) & jnp.array(ret.deltay >= 0), + jnp.array(other._isdefined), + ), + ) return ret -def _bounds_pos_add_op_dynamic(self, other, min_delta): +def _bounds_pos_add_op_dynamic(self, other): xmin = jnp.where( self._isdefined, jnp.minimum(self.xmin, other.x), @@ -469,19 +481,33 @@ def _bounds_pos_add_op_dynamic(self, other, min_delta): ) cls = self.__class__ - ret = cls.__new__(cls) - - ret.xmin = xmin - ret.deltax = xmax - xmin + min_delta - ret.ymin = ymin - ret.deltay = ymax - ymin + min_delta - ret._isdefined = jnp.where( - self._isdefined, - (ret.deltax >= min_delta) & (ret.deltay >= min_delta), - jnp.array(True), - ) - ret._isstatic = False - ret._isstaticshape = False + if isinstance(self, BoundsI): + # we use the class constructor here to ensure we properly convert + # bounds shape to static ints + ret = cls( + xmin=xmin, + deltax=xmax - xmin + 1, + ymin=ymin, + deltay=ymax - ymin + 1, + ) + is_defined = jnp.where( + jnp.array(self._isdefined), + jnp.array(ret.deltax >= 0) & jnp.array(ret.deltay >= 0), + jnp.array(True), + ) + # we have to do a conversion to static bools here too + ret._isdefined = bool(is_defined.item()) + else: + ret = cls.__new__(cls) + ret.xmin = xmin + ret.deltax = xmax - xmin + ret.ymin = ymin + ret.deltay = ymax - ymin + ret._isdefined = jnp.where( + self._isdefined, + jnp.array(ret.deltax >= 0) & jnp.array(ret.deltay >= 0), + jnp.array(True), + ) return ret @@ -492,8 +518,6 @@ class BoundsD(Bounds): _pos_class = PositionD def __init__(self, *args, **kwargs): - self._isstatic = False - self._isstaticshape = False do_isdefined = self._parse_args(*args, **kwargs) self.xmin = cast_to_float(self.xmin) self.deltax = cast_to_float(self.deltax) @@ -637,41 +661,6 @@ def _getinitargs(self): # defined only for galsim test suite return (self.xmin, self.xmax, self.ymin, self.ymax) - def __eq__(self, other): - if self is other: - return jnp.array(True) - elif isinstance(other, self.__class__): - return ( - self.isDefined() - & other.isDefined() - & (self.xmin == other.xmin) - & (self.ymin == other.ymin) - & (self.xmax == other.xmax) - & (self.ymax == other.ymax) - ) | ((~self.isDefined()) & (~other.isDefined())) - else: - return jnp.array(False) - - def __ne__(self, other): - return ~self.__eq__(other) - - def __and__(self, other): - if not isinstance(other, self.__class__): - raise TypeError("other must be a %s instance" % self.__class__.__name__) - - return _bounds_and_op_dynamic(self, other) - - def __add__(self, other): - if isinstance(other, self.__class__): - return _bounds_bounds_add_op_dynamic(self, other, 0) - elif isinstance(other, self._pos_class): - return _bounds_pos_add_op_dynamic(self, other, 0) - else: - raise TypeError( - "other must be either a %s or a %s" - % (self.__class__.__name__, self._pos_class.__name__) - ) - @implements(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class @@ -695,25 +684,16 @@ def __init__(self, *args, **kwargs): self.ymin, "BoundsI must be initialized with integer values" ) - if isinstance(self.deltax, int) and isinstance(self.deltay, int): - self._isstaticshape = True - else: - self._isstaticshape = False - - if ( - isinstance(self.xmin, int) - and isinstance(self.ymin, int) - and isinstance(self.deltax, int) - and isinstance(self.deltay, int) - ): - self._isstatic = True - else: - self._isstatic = False + # attempt to convert widths to static values + # this will raise if values are being traced + # we let that error propagate instead of reraising + # our own. + if not isinstance(self.deltax, int): + self.deltax = int(self.deltax.item()) + if not isinstance(self.deltay, int): + self.deltay = int(self.deltay.item()) - if self.isStaticShape(): - self._isdefined = self.deltax >= 1 and self.deltay >= 1 - else: - self._isdefined = (self.deltax >= 1) & (self.deltay >= 1) + self._isdefined = self.deltax >= 1 and self.deltay >= 1 def _check_scalar(self, x, name): try: @@ -731,17 +711,10 @@ def _check_scalar(self, x, name): def numpyShape(self): "A simple utility function to get the numpy shape that corresponds to this `Bounds` object." - if self._isstaticshape: - if self._isdefined: - return self.deltay, self.deltax - else: - return 0, 0 + if self._isdefined: + return self.deltay, self.deltax else: - return jax.lax.cond( - self._isdefined, - lambda: (self.deltay, self.deltax), - lambda: (jnp.zeros_like(self.deltay), jnp.zeros_like(self.deltax)), - ) + return 0, 0 @property def xmax(self): @@ -761,17 +734,10 @@ def ymax(self, value): def _area(self): # Remember the + 1 this time to include the pixels on both edges of the bounds. - if self._isstaticshape: - if self._isdefined: - return self.deltax * self.deltay - else: - return 0 + if self._isdefined: + return self.deltax * self.deltay else: - return jax.lax.cond( - self._isdefined, - lambda: self.deltax * self.deltay, - lambda: 0.0, - ) + return 0 @property def _center(self): @@ -790,86 +756,43 @@ def includes(self, *args): if len(args) == 1: if isinstance(args[0], Bounds): b = args[0] - if self.isStatic() and b.isStatic(): - return ( - self.isDefined() - and b.isDefined() - and (self.xmin <= b.xmin) - and (self.xmax >= b.xmax) - and (self.ymin <= b.ymin) - and (self.ymax >= b.ymax) - ) - else: - return ( - jnp.array(self.isDefined()) - & jnp.array(b.isDefined()) - & jnp.array(self.xmin <= b.xmin) - & jnp.array(self.xmax >= b.xmax) - & jnp.array(self.ymin <= b.ymin) - & jnp.array(self.ymax >= b.ymax) - ) + return ( + jnp.array(self.isDefined()) + & jnp.array(b.isDefined()) + & jnp.array(self.xmin <= b.xmin) + & jnp.array(self.xmax >= b.xmax) + & jnp.array(self.ymin <= b.ymin) + & jnp.array(self.ymax >= b.ymax) + ) elif isinstance(args[0], Position): p = args[0] - ok_types = STATIC_SCALAR_TYPES - if ( - self._isstatic - and isinstance(p.x, ok_types) - and isinstance(p.y, ok_types) - ): - return ( - self.isDefined() - and (self.xmin <= p.x) - and (self.ymin <= p.y) - and (p.x <= self.xmax) - and (p.y <= self.ymax) - ) - else: - return ( - jnp.array(self.isDefined()) - & jnp.array(self.xmin <= p.x) - & jnp.array(self.ymin <= p.y) - & jnp.array(p.x <= self.xmax) - & jnp.array(p.y <= self.ymax) - ) + return ( + jnp.array(self.isDefined()) + & jnp.array(self.xmin <= p.x) + & jnp.array(self.ymin <= p.y) + & jnp.array(p.x <= self.xmax) + & jnp.array(p.y <= self.ymax) + ) else: raise TypeError("Invalid argument %s" % args[0]) elif len(args) == 2: x, y = args x = cast_to_float(x) y = cast_to_float(y) - if self._isstatic and isinstance(x, float) and isinstance(y, float): - return ( - self.isDefined() - and (self.xmin <= x) - and (self.ymin <= y) - and (x <= self.xmax) - and (y <= self.ymax) - ) - else: - return ( - jnp.array(self.isDefined()) - & jnp.array(self.xmin <= x) - & jnp.array(self.ymin <= y) - & jnp.array(x <= self.xmax) - & jnp.array(y <= self.ymax) - ) + return ( + jnp.array(self.isDefined()) + & jnp.array(self.xmin <= x) + & jnp.array(self.ymin <= y) + & jnp.array(x <= self.xmax) + & jnp.array(y <= self.ymax) + ) elif len(args) == 0: raise TypeError("include takes at least 1 argument (0 given)") else: raise TypeError("include takes at most 2 arguments (%d given)" % len(args)) def __repr__(self): - # sometimes we will encounter a tracer here - # and so we suppress any boolean conversion errors - try: - if jnp.any(self.isDefined()): - print_full = True - else: - print_full = False - except Exception: - print_full = True - - if print_full: + if self._isdefined: return "galsim.%s(xmin=%r, deltax=%r, ymin=%r, deltay=%r)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -881,17 +804,7 @@ def __repr__(self): return "galsim.%s()" % (self.__class__.__name__) def __str__(self): - # sometimes we will encounter a tracer here - # and so we suppress any boolean conversion errors - try: - if jnp.any(self.isDefined()): - print_full = True - else: - print_full = False - except Exception: - print_full = True - - if print_full: + if self._isdefined: return "galsim.%s(xmin=%s, deltax=%s, ymin=%s, deltay=%s)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -913,91 +826,17 @@ def __hash__(self): ) ) - def __eq__(self, other): - if self is other: - if self.isStatic() and other.isStatic(): - return True - else: - return jnp.array(True) - elif isinstance(other, self.__class__): - if self.isStatic() and other.isStatic(): - min_eq = (self.xmin == other.xmin) and (self.ymin == other.ymin) - self_isdef = self.isDefined() - other_isdef = other.isDefined() - shape_eq = (self.deltax == other.deltax) and ( - self.deltay == other.deltay - ) - return (self_isdef and other_isdef and shape_eq and min_eq) or ( - (not self_isdef) and (not other_isdef) - ) - else: - min_eq = jnp.array(self.xmin == other.xmin) & jnp.array( - self.ymin == other.ymin - ) - self_isdef = jnp.array(self.isDefined()) - other_isdef = jnp.array(other.isDefined()) - shape_eq = jnp.array(self.deltax == other.deltax) & jnp.array( - self.deltay == other.deltay - ) - return (self_isdef & other_isdef & shape_eq & min_eq) | ( - (~self_isdef) & (~other_isdef) - ) - else: - return False - - def __ne__(self, other): - if not isinstance(other, self.__class__): - return True - - if self.isStatic() and other.isStatic(): - return not self.__eq__(other) - else: - return ~self.__eq__(other) - - def __and__(self, other): - if not isinstance(other, self.__class__): - raise TypeError("other must be a %s instance" % self.__class__.__name__) - - if self.isStatic() and other.isStatic(): - return _bounds_and_op_static(self, other) - else: - return _bounds_and_op_dynamic(self, other) - - def __add__(self, other): - if isinstance(other, self.__class__): - if self.isStatic() and other.isStatic(): - return _bounds_bounds_add_op_static(self, other) - else: - return _bounds_bounds_add_op_dynamic(self, other, 1) - elif isinstance(other, self._pos_class): - return _bounds_pos_add_op_dynamic(self, other, 1) - else: - raise TypeError( - "other must be either a %s or a %s" - % (self.__class__.__name__, self._pos_class.__name__) - ) - def tree_flatten(self): """This function flattens the Bounds into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - aux_data = {"isstatic": self._isstatic, "isstaticshape": self._isstaticshape} + children = (self.xmin, self.ymin) - if self._isstatic: - aux_data["xmin"] = self.xmin - aux_data["ymin"] = self.ymin - - if self._isstaticshape: - aux_data["deltax"] = self.deltax - aux_data["deltay"] = self.deltay - aux_data["isdefined"] = self._isdefined - - if self._isstatic: - children = tuple() - elif self._isstaticshape: - children = (self.xmin, self.ymin) - else: - children = (self.xmin, self.deltax, self.ymin, self.deltay, self._isdefined) + # untraced aux data + aux_data = {} + aux_data["deltax"] = self.deltax + aux_data["deltay"] = self.deltay + aux_data["isdefined"] = self._isdefined return (children, aux_data) @@ -1005,26 +844,10 @@ def tree_flatten(self): def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" ret = cls.__new__(cls) - ret._isstatic = aux_data["isstatic"] - ret._isstaticshape = aux_data["isstaticshape"] - - if ret._isstatic: - ret.xmin = aux_data["xmin"] - ret.ymin = aux_data["ymin"] - ret.deltax = aux_data["deltax"] - ret.deltay = aux_data["deltay"] - ret._isdefined = aux_data["isdefined"] - elif ret._isstaticshape: - ret.xmin = children[0] - ret.ymin = children[1] - ret.deltax = aux_data["deltax"] - ret.deltay = aux_data["deltay"] - ret._isdefined = aux_data["isdefined"] - else: - ret.xmin = children[0] - ret.deltax = children[1] - ret.ymin = children[2] - ret.deltay = children[3] - ret._isdefined = children[4] + ret.xmin = children[0] + ret.ymin = children[1] + ret.deltax = aux_data["deltax"] + ret.deltay = aux_data["deltay"] + ret._isdefined = aux_data["isdefined"] return ret diff --git a/jax_galsim/image.py b/jax_galsim/image.py index c4a60ed3..3cc813b8 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -7,7 +7,6 @@ from jax_galsim.bounds import Bounds, BoundsD, BoundsI from jax_galsim.core.utils import ( - STATIC_SCALAR_TYPES, cast_numpy_array_to_native_byte_order, ensure_hashable, implements, @@ -270,13 +269,6 @@ def __init__(self, *args, **kwargs): raise TypeError("wcs parameters must be a galsim.BaseWCS instance") self.wcs = wcs - # raise an error if bounds doesn't have a fixed width - if not self._bounds.isStaticShape(): - raise RuntimeError( - "JAX-GalSim `Image` objects must have a `BoundsI` instance with " - "a static shape (i.e., `image.bounds.isStaticShape() is True`)." - ) - @staticmethod def _get_xmin_ymin(array, kwargs, check_bounds=True): """A helper function for parsing xmin, ymin, bounds options with a given array""" @@ -289,13 +281,6 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): if not isinstance(b, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - # raise an error if bounds doesn't have a fixed width - if not b.isStaticShape(): - raise RuntimeError( - "JAX-GalSim `Image` objects must have a `BoundsI` instance with " - "a static shape (i.e., `image.bounds.isStaticShape() is True`)." - ) - if check_bounds and b.isDefined(): if b.deltax != array.shape[1]: raise _galsim.GalSimIncompatibleValuesError( @@ -588,13 +573,6 @@ def resize(self, bounds, wcs=None): if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - # raise an error if bounds doesn't have a fixed width - if not bounds.isStaticShape(): - raise RuntimeError( - "JAX-GalSim `Image` objects must have a `BoundsI` instance with " - "a static shape (i.e., `image.bounds.isStaticShape() is True`)." - ) - self._array = self._make_empty(shape=bounds.numpyShape(), dtype=self.dtype) self._bounds = bounds if wcs is not None: @@ -605,34 +583,23 @@ def subImage(self, bounds): if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - # raise an error if bounds doesn't have a fixed width - if not bounds.isStaticShape(): - raise RuntimeError( - "JAX-GalSim `Image` objects must have a `BoundsI` instance with " - "a static shape (i.e., `image.bounds.isStaticShape() is True`)." - ) - if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access subImage of undefined image" ) + inc_val = jnp.array(self.bounds.includes(bounds)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to access subImage not (fully) in image", + ) + if ( - self.bounds.isStatic() - and bounds.isStatic() - and not self.bounds.includes(bounds) + isinstance(self.bounds.xmin, int) + and isinstance(self.bounds.ymin, int) + and isinstance(bounds.xmin, int) + and isinstance(bounds.ymin, int) ): - raise _galsim.GalSimBoundsError( - "Attempt to access subImage not (fully) in image", bounds, self.bounds - ) - else: - inc_val = jnp.array(self.bounds.includes(bounds)) - inc_val = equinox.error_if( - inc_val, - jnp.any(~inc_val), - "Attempt to access subImage not (fully) in image", - ) - - if self.bounds.isStatic() and bounds.isStatic(): i1 = bounds.ymin - self.ymin i2 = bounds.ymax - self.ymin + 1 j1 = bounds.xmin - self.xmin @@ -659,32 +626,17 @@ def setSubImage(self, bounds, rhs): if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - # raise an error if bounds doesn't have a fixed width - if not bounds.isStaticShape(): - raise RuntimeError( - "JAX-GalSim `Image` objects must have a `BoundsI` instance with " - "a static shape (i.e., `image.bounds.isStaticShape() is True`)." - ) - if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" ) - if ( - self.bounds.isStatic() - and bounds.isStatic() - and not self.bounds.includes(bounds) - ): - raise _galsim.GalSimBoundsError( - "Attempt to access subImage not (fully) in image", bounds, self.bounds - ) - else: - inc_val = jnp.array(self.bounds.includes(bounds)) - inc_val = equinox.error_if( - inc_val, - jnp.any(~inc_val), - "Attempt to access subImage not (fully) in image", - ) + + inc_val = jnp.array(self.bounds.includes(bounds)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to access subImage not (fully) in image", + ) if not isinstance(rhs, Image): raise TypeError("Trying to copyFrom a non-image") @@ -695,7 +647,12 @@ def setSubImage(self, bounds, rhs): rhs=rhs, ) - if self.bounds.isStatic() and bounds.isStatic(): + if ( + isinstance(self.bounds.xmin, int) + and isinstance(self.bounds.ymin, int) + and isinstance(bounds.xmin, int) + and isinstance(bounds.ymin, int) + ): i1 = bounds.ymin - self.ymin i2 = bounds.ymax - self.ymin + 1 j1 = bounds.xmin - self.xmin @@ -777,16 +734,9 @@ def wrap(self, bounds, hermitian=False): if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - # raise an error if bounds doesn't have a fixed width - if not bounds.isStaticShape(): - raise RuntimeError( - "JAX-GalSim `Image` objects must have a `BoundsI` instance with " - "a static shape (i.e., `image.bounds.isStaticShape() is True`)." - ) - def _raise_if_nonzero(bnds, x_or_y, msg): if x_or_y == "x": - if bnds.isStatic(): + if isinstance(bnds.xmin, int): if bnds.xmin != 0: raise _galsim.GalSimIncompatibleValuesError( msg, @@ -800,7 +750,7 @@ def _raise_if_nonzero(bnds, x_or_y, msg): msg, ) else: - if bnds.isStatic(): + if isinstance(bnds.ymin, int): if bnds.ymin != 0: raise _galsim.GalSimIncompatibleValuesError( msg, @@ -964,19 +914,13 @@ def calculate_inverse_fft(self): raise _galsim.GalSimError( "calculate_inverse_fft requires that the image has a PixelScale wcs." ) - if self.bounds.isStatic() and not self.bounds.includes(0, 0): - raise _galsim.GalSimBoundsError( - "calculate_inverse_fft requires that the image includes (0,0)", - PositionI(0, 0), - self.bounds, - ) - else: - inc_val = jnp.array(self.bounds.includes(0, 0)) - inc_val = equinox.error_if( - inc_val, - jnp.any(~inc_val), - "calculate_inverse_fft requires that the image includes (0,0)", - ) + + inc_val = jnp.array(self.bounds.includes(0, 0)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "calculate_inverse_fft requires that the image includes (0,0)", + ) No2 = max( max(self.bounds.xmax, -self.bounds.ymin), @@ -1136,24 +1080,12 @@ def getValue(self, x, y): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" ) - if ( - self.bounds.isStatic() - and isinstance(x, STATIC_SCALAR_TYPES) - and isinstance(y, STATIC_SCALAR_TYPES) - and not self.bounds.includes(x, y) - ): - raise _galsim.GalSimBoundsError( - "Attempt to access position not in bounds of image.", - PositionI(x, y), - self.bounds, - ) - else: - inc_val = jnp.array(self.bounds.includes(x, y)) - inc_val = equinox.error_if( - inc_val, - jnp.any(~inc_val), - "Attempt to access position not in bounds of image.", - ) + inc_val = jnp.array(self.bounds.includes(x, y)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to access position not in bounds of image.", + ) return self._getValue(x, y) @@ -1172,17 +1104,12 @@ def setValue(self, *args, **kwargs): pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) - if self.bounds.isStatic() and pos.isStatic() and not self.bounds.includes(pos): - raise _galsim.GalSimBoundsError( - "Attempt to set position not in bounds of image", pos, self.bounds - ) - else: - inc_val = jnp.array(self.bounds.includes(pos)) - inc_val = equinox.error_if( - inc_val, - jnp.any(~inc_val), - "Attempt to set position not in bounds of image", - ) + inc_val = jnp.array(self.bounds.includes(pos)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to set position not in bounds of image", + ) self._setValue(pos.x, pos.y, value) @@ -1201,17 +1128,12 @@ def addValue(self, *args, **kwargs): pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) - if self.bounds.isStatic() and pos.isStatic() and not self.bounds.includes(pos): - raise _galsim.GalSimBoundsError( - "Attempt to set position not in bounds of image", pos, self.bounds - ) - else: - inc_val = jnp.array(self.bounds.includes(pos)) - inc_val = equinox.error_if( - inc_val, - jnp.any(~inc_val), - "Attempt to set position not in bounds of image", - ) + inc_val = jnp.array(self.bounds.includes(pos)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to set position not in bounds of image", + ) self._addValue(pos.x, pos.y, value) @@ -1340,16 +1262,8 @@ def rot_180(self): def tree_flatten(self): """Flatten the image into a list of values.""" # Define the children nodes of the PyTree that need tracing - if self.bounds.isStatic(): - children = (self.array, self.wcs) - aux_data = { - "dtype": self.dtype, - "bounds": self.bounds, - "isconst": self.isconst, - } - else: - children = (self.array, self.wcs, self.bounds) - aux_data = {"dtype": self.dtype, "isconst": self.isconst} + children = (self.array, self.wcs, self.bounds) + aux_data = {"dtype": self.dtype, "isconst": self.isconst} # other routines may add these attributes to images on the fly # we have to include them here so that JAX knows how to handle them in jitting etc. if hasattr(self, "added_flux"): @@ -1377,16 +1291,6 @@ def tree_unflatten(cls, aux_data, children): obj.header = aux_data["header"] if len(children) > 3: obj.photons = children[3] - else: - obj._bounds = children[2] - obj._dtype = aux_data["dtype"] - obj._is_const = aux_data["isconst"] - if len(children) > 3: - obj.added_flux = children[3] - if "header" in aux_data: - obj.header = aux_data["header"] - if len(children) > 4: - obj.photons = children[4] return obj @classmethod diff --git a/jax_galsim/position.py b/jax_galsim/position.py index 6b5ffc0d..cf36dba8 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -4,7 +4,6 @@ from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( - STATIC_SCALAR_TYPES, cast_to_float, check_is_int_then_cast, ensure_hashable, @@ -183,13 +182,6 @@ def to_galsim(self): cast(self.y), ) - def isStatic(self): - """Returns ``True`` if the ``Position`` instance - ``x`` and ``y`` values are not arrays""" - return isinstance(self.x, STATIC_SCALAR_TYPES) and isinstance( - self.y, STATIC_SCALAR_TYPES - ) - @implements(_galsim.PositionD) @register_pytree_node_class From 9012a3e674a56c417e46ab081ce6ace8a4de9232 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 13:50:31 -0500 Subject: [PATCH 10/72] fix: wrong tracing branch; remove test --- jax_galsim/image.py | 20 ++++++++++---------- tests/jax/test_bounds_jax.py | 26 -------------------------- 2 files changed, 10 insertions(+), 36 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 3cc813b8..e35ddb97 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1281,16 +1281,16 @@ def tree_unflatten(cls, aux_data, children): obj = object.__new__(cls) obj._array = children[0] obj.wcs = children[1] - if "bounds" in aux_data: - obj._bounds = aux_data["bounds"] - obj._dtype = aux_data["dtype"] - obj._is_const = aux_data["isconst"] - if len(children) > 2: - obj.added_flux = children[2] - if "header" in aux_data: - obj.header = aux_data["header"] - if len(children) > 3: - obj.photons = children[3] + obj._bounds = children[2] + obj._dtype = aux_data["dtype"] + obj._is_const = aux_data["isconst"] + if len(children) > 3: + obj.added_flux = children[3] + if "header" in aux_data: + obj.header = aux_data["header"] + if len(children) > 4: + obj.photons = children[4] + return obj @classmethod diff --git a/tests/jax/test_bounds_jax.py b/tests/jax/test_bounds_jax.py index f9481487..19e318d5 100644 --- a/tests/jax/test_bounds_jax.py +++ b/tests/jax/test_bounds_jax.py @@ -5,32 +5,6 @@ import jax_galsim -@jax.vmap -@jax.jit -def _make_bounds_int(xmin, ymin, xmax, ymax): - bds = jax_galsim.BoundsI(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax) - return bds, bds.isDefined() - - -def test_bounds_jax_vmap_isdefined_int(): - xmin = jnp.array([9, 10, 11, 12]) - xmax = jnp.array([12, 11, 10, 9]) - ymin = jnp.array([9, 11, 10, 12]) - ymax = jnp.array([10, 10, 11, 10]) - bds, isdef = _make_bounds_int(xmin, ymin, xmax, ymax) - np.testing.assert_array_equal(bds.isDefined(), isdef, strict=True) - - # turn a bounds of arrays into a list of bounds - # see https://github.com/jax-ml/jax/discussions/35711 - list_of_bnds = jax.tree.transpose( - jax.tree.structure(bds), None, jax.tree.map(list, bds) - ) - assert list_of_bnds[0] != list_of_bnds[2] - assert list_of_bnds[1] == list_of_bnds[2] - assert list_of_bnds[2] == list_of_bnds[3] - assert all(not bnds.isStatic() for bnds in list_of_bnds) - - @jax.vmap @jax.jit def _make_bounds_float(xmin, ymin, xmax, ymax): From 1abee57c357371b2765a87fe9566bdf72eb0ab3d Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 13:56:16 -0500 Subject: [PATCH 11/72] refactor: only need one of these --- jax_galsim/bounds.py | 126 ++++++++++++++----------------------------- 1 file changed, 41 insertions(+), 85 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 818590a1..85ea71c8 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -184,9 +184,43 @@ def true_center(self): @implements(_galsim.Bounds.includes) def includes(self, *args): - raise NotImplementedError( - "Subclasses of `Bounds` must implement the `includes` method!" - ) + if len(args) == 1: + if isinstance(args[0], Bounds): + b = args[0] + return ( + jnp.array(self.isDefined()) + & jnp.array(b.isDefined()) + & jnp.array(self.xmin <= b.xmin) + & jnp.array(self.xmax >= b.xmax) + & jnp.array(self.ymin <= b.ymin) + & jnp.array(self.ymax >= b.ymax) + ) + elif isinstance(args[0], Position): + p = args[0] + return ( + jnp.array(self.isDefined()) + & jnp.array(self.xmin <= p.x) + & jnp.array(self.ymin <= p.y) + & jnp.array(p.x <= self.xmax) + & jnp.array(p.y <= self.ymax) + ) + else: + raise TypeError("Invalid argument %s" % args[0]) + elif len(args) == 2: + x, y = args + x = cast_to_float(x) + y = cast_to_float(y) + return ( + jnp.array(self.isDefined()) + & jnp.array(self.xmin <= x) + & jnp.array(self.ymin <= y) + & jnp.array(x <= self.xmax) + & jnp.array(y <= self.ymax) + ) + elif len(args) == 0: + raise TypeError("include takes at least 1 argument (0 given)") + else: + raise TypeError("include takes at most 2 arguments (%d given)" % len(args)) @implements(_galsim.Bounds.expand) def expand(self, factor_x, factor_y=None): @@ -519,10 +553,10 @@ class BoundsD(Bounds): def __init__(self, *args, **kwargs): do_isdefined = self._parse_args(*args, **kwargs) - self.xmin = cast_to_float(self.xmin) - self.deltax = cast_to_float(self.deltax) - self.ymin = cast_to_float(self.ymin) - self.deltay = cast_to_float(self.deltay) + self.xmin = cast_to_float(jnp.array(self.xmin)) + self.deltax = cast_to_float(jnp.array(self.deltax)) + self.ymin = cast_to_float(jnp.array(self.ymin)) + self.deltay = cast_to_float(jnp.array(self.deltay)) if do_isdefined: self._isdefined = (self.deltax >= 0) & (self.deltay >= 0) self._isdefined = jnp.array(self._isdefined) @@ -564,44 +598,6 @@ def _area(self): def _center(self): return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) - @implements(_galsim.Bounds.includes) - def includes(self, *args): - if len(args) == 1: - if isinstance(args[0], Bounds): - b = args[0] - return ( - self.isDefined() - & b.isDefined() - & (self.xmin <= b.xmin) - & (self.xmax >= b.xmax) - & (self.ymin <= b.ymin) - & (self.ymax >= b.ymax) - ) - elif isinstance(args[0], Position): - p = args[0] - return ( - self.isDefined() - & (self.xmin <= p.x) - & (self.ymin <= p.y) - & (p.x <= self.xmax) - & (p.y <= self.ymax) - ) - else: - raise TypeError("Invalid argument %s" % args[0]) - elif len(args) == 2: - x, y = args - return ( - self.isDefined() - & (self.xmin <= cast_to_float(x)) - & (self.ymin <= cast_to_float(y)) - & (cast_to_float(x) <= self.xmax) - & (cast_to_float(y) <= self.ymax) - ) - elif len(args) == 0: - raise TypeError("include takes at least 1 argument (0 given)") - else: - raise TypeError("include takes at most 2 arguments (%d given)" % len(args)) - def __repr__(self): # sometimes we will encounter a tracer here # and so we suppress any boolean conversion errors @@ -751,46 +747,6 @@ def _center(self): self.ymin + self.deltay // 2, ) - @implements(_galsim.Bounds.includes) - def includes(self, *args): - if len(args) == 1: - if isinstance(args[0], Bounds): - b = args[0] - return ( - jnp.array(self.isDefined()) - & jnp.array(b.isDefined()) - & jnp.array(self.xmin <= b.xmin) - & jnp.array(self.xmax >= b.xmax) - & jnp.array(self.ymin <= b.ymin) - & jnp.array(self.ymax >= b.ymax) - ) - elif isinstance(args[0], Position): - p = args[0] - return ( - jnp.array(self.isDefined()) - & jnp.array(self.xmin <= p.x) - & jnp.array(self.ymin <= p.y) - & jnp.array(p.x <= self.xmax) - & jnp.array(p.y <= self.ymax) - ) - else: - raise TypeError("Invalid argument %s" % args[0]) - elif len(args) == 2: - x, y = args - x = cast_to_float(x) - y = cast_to_float(y) - return ( - jnp.array(self.isDefined()) - & jnp.array(self.xmin <= x) - & jnp.array(self.ymin <= y) - & jnp.array(x <= self.xmax) - & jnp.array(y <= self.ymax) - ) - elif len(args) == 0: - raise TypeError("include takes at least 1 argument (0 given)") - else: - raise TypeError("include takes at most 2 arguments (%d given)" % len(args)) - def __repr__(self): if self._isdefined: return "galsim.%s(xmin=%r, deltax=%r, ymin=%r, deltay=%r)" % ( From 6718b24a7c9f0b8de2984566f6974bbf0fa7a35f Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Wed, 20 May 2026 13:58:08 -0500 Subject: [PATCH 12/72] Apply suggestion from @beckermr --- jax_galsim/bounds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 85ea71c8..b69c26df 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -526,7 +526,7 @@ def _bounds_pos_add_op_dynamic(self, other): ) is_defined = jnp.where( jnp.array(self._isdefined), - jnp.array(ret.deltax >= 0) & jnp.array(ret.deltay >= 0), + jnp.array(ret.deltax >= 1) & jnp.array(ret.deltay >= 1), jnp.array(True), ) # we have to do a conversion to static bools here too From f8d2c003fcecc3e0976f6de1136549e2c4b46d45 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Wed, 20 May 2026 13:58:37 -0500 Subject: [PATCH 13/72] Apply suggestion from @beckermr --- tests/jax/test_api.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index c3afed10..b0435ecb 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -359,8 +359,6 @@ def _reg_fun(p): if issubclass(cls, jax_galsim.Bounds) and method in [ "xmax", "ymax", - "isStatic", - "isStaticShape", ]: continue From 2b7634f091e14a2c69e5ec1baa468bac38c4b70a Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Wed, 20 May 2026 14:01:33 -0500 Subject: [PATCH 14/72] Apply suggestion from @beckermr --- jax_galsim/bounds.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index b69c26df..59801a68 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -730,10 +730,10 @@ def ymax(self, value): def _area(self): # Remember the + 1 this time to include the pixels on both edges of the bounds. - if self._isdefined: - return self.deltax * self.deltay - else: + if not self._isdefined: return 0 + else: + return self.deltax * self.deltay @property def _center(self): From e2a4fac3c75d941980bc1f90cec3fc32f8f31068 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 14:59:02 -0500 Subject: [PATCH 15/72] test: update to latest test submodule --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index f3d81a1d..007a87ae 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit f3d81a1d18a30651d8769818731d4c4ac3541478 +Subproject commit 007a87aeed4d6b77f03745fbf977df3e35918eb0 From 93eb30f8a76bbae23fe016d5dda83f6019310cd9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 15:25:06 -0500 Subject: [PATCH 16/72] fix: ensure we handle branches on bounds eq properly --- jax_galsim/bounds.py | 26 ++++++++++++++++++-------- jax_galsim/image.py | 43 +++++++++++++++++++++---------------------- 2 files changed, 39 insertions(+), 30 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 59801a68..00e390f3 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -287,14 +287,24 @@ def __eq__(self, other): elif isinstance(other, self.__class__): self_isdef = jnp.array(self.isDefined()) other_isdef = jnp.array(other.isDefined()) - return ( - self_isdef - & other_isdef - & jnp.array(self.xmin == other.xmin) - & jnp.array(self.ymin == other.ymin) - & jnp.array(self.xmax == other.xmax) - & jnp.array(self.ymax == other.ymax) - ) | ((~self_isdef) & (~other_isdef)) + if isinstance(self, BoundsD): + return ( + self_isdef + & other_isdef + & jnp.array(self.xmin == other.xmin) + & jnp.array(self.ymin == other.ymin) + & jnp.array(self.xmax == other.xmax) + & jnp.array(self.ymax == other.ymax) + ) | ((~self_isdef) & (~other_isdef)) + else: + return ( + self_isdef + & other_isdef + & jnp.array(self.xmin == other.xmin) + & jnp.array(self.ymin == other.ymin) + & jnp.array(self.deltax == other.deltax) + & jnp.array(self.deltay == other.deltay) + ) | ((~self_isdef) & (~other_isdef)) else: return jnp.array(False) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index e35ddb97..0bb6f13c 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -872,14 +872,12 @@ def calculate_fft(self): ), ) + # galsim branches here if the image has the correct bounds, but JAX can't branch + # on calls that generate different size arrays + # so we always make a new image full_bounds = BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2) - if self.bounds == full_bounds: - # Then the image is already in the shape we need. - ximage = self - else: - # Then we pad out with zeros - ximage = Image(full_bounds, dtype=self.dtype, init_value=0) - ximage[self.bounds] = self[self.bounds] + ximage = Image(full_bounds, dtype=self.dtype, init_value=0) + ximage[self.bounds] = self[self.bounds] dx = self.scale # dk = 2pi / (N dk) @@ -928,21 +926,22 @@ def calculate_inverse_fft(self): ) target_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2) - if self.bounds == target_bounds: - # Then the image is already in the shape we need. - kimage = self - else: - # Then we can pad out with zeros and wrap to get this in the form we need. - full_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2 + 1) - kimage = Image(full_bounds, dtype=self.dtype, init_value=0) - posx_bounds = BoundsI( - xmin=0, - xmax=self.bounds.xmax, - ymin=self.bounds.ymin, - ymax=self.bounds.ymax, - ) - kimage[posx_bounds] = self[posx_bounds] - kimage = kimage._wrap(target_bounds, True, False, 2 * No2) + + # galsim branches here if the image has the correct bounds, but JAX can't branch + # on calls that generate different size arrays + # so we always make a new image + + # Then we can pad out with zeros and wrap to get this in the form we need. + full_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2 + 1) + kimage = Image(full_bounds, dtype=self.dtype, init_value=0) + posx_bounds = BoundsI( + xmin=0, + xmax=self.bounds.xmax, + ymin=self.bounds.ymin, + ymax=self.bounds.ymax, + ) + kimage[posx_bounds] = self[posx_bounds] + kimage = kimage._wrap(target_bounds, True, False, 2 * No2) dk = self.scale # dx = 2pi / (N dk) From e44d8df3af00ec7172ba87ac5a78ec78f8b4bf5e Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 15:46:41 -0500 Subject: [PATCH 17/72] fix: this needs to be a float --- jax_galsim/bounds.py | 49 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 00e390f3..a32e1d03 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -6,6 +6,7 @@ from jax_galsim.core.utils import ( cast_to_float, + cast_to_int, check_is_int_then_cast, ensure_hashable, implements, @@ -722,21 +723,59 @@ def numpyShape(self): else: return 0, 0 + # we store xmin internally as a float even though it is an int + # so that autodiff works properly (needs floats in general) + @property + def xmin(self): + return jnp.astype(self._xmin, int) + + @xmin.setter + def xmin(self, value): + self._xmin = jnp.astype(value, float) + @property def xmax(self): - return self.xmin + self.deltax - 1 + return cast_to_int(self.xmin + self.deltax - 1) @xmax.setter def xmax(self, value): self.deltax = value - self.xmin + 1 + self.deltax = check_is_int_then_cast( + self.deltax, "BoundsI xmax must be set to an integer value" + ) + # attempt to convert widths to static values + # this will raise if values are being traced + # we let that error propagate instead of reraising + # our own. + if not isinstance(self.deltax, int): + self.deltax = int(self.deltax.item()) + + # we store ymin internally as a float even though it is an int + # so that autodiff works properly (needs floats in general) + @property + def ymin(self): + return jnp.astype(self._ymin, int) + + @ymin.setter + def ymin(self, value): + self._ymin = jnp.astype(value, float) @property def ymax(self): - return self.ymin + self.deltay - 1 + return cast_to_int(self.ymin + self.deltay - 1) @ymax.setter def ymax(self, value): self.deltay = value - self.ymin + 1 + self.deltay = check_is_int_then_cast( + self.deltay, "BoundsI ymax must be set to an integer value" + ) + # attempt to convert widths to static values + # this will raise if values are being traced + # we let that error propagate instead of reraising + # our own. + if not isinstance(self.deltay, int): + self.deltay = int(self.deltay.item()) def _area(self): # Remember the + 1 this time to include the pixels on both edges of the bounds. @@ -796,7 +835,7 @@ def tree_flatten(self): """This function flattens the Bounds into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - children = (self.xmin, self.ymin) + children = (self._xmin, self._ymin) # untraced aux data aux_data = {} @@ -810,8 +849,8 @@ def tree_flatten(self): def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" ret = cls.__new__(cls) - ret.xmin = children[0] - ret.ymin = children[1] + ret._xmin = children[0] + ret._ymin = children[1] ret.deltax = aux_data["deltax"] ret.deltay = aux_data["deltay"] ret._isdefined = aux_data["isdefined"] From ef586a3cc24ecadfa0bfafb36247b4cb216ca27b Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 16:12:59 -0500 Subject: [PATCH 18/72] fix: ensure we can FFT OK --- jax_galsim/image.py | 86 ++++++++++++++++++++++++++------------------- tests/GalSim | 2 +- 2 files changed, 50 insertions(+), 38 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 0bb6f13c..ff36cb52 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -859,25 +859,30 @@ def calculate_fft(self): "JAX-GalSim does not support forward FFTs of complex dtypes." ) - # TODO: figure out how to do FFT at fixed size and then reconstruct - # the result - No2 = max( - max( - -self.bounds.xmin, - self.bounds.xmax + 1, - ), - max( - -self.bounds.ymin, - self.bounds.ymax + 1, - ), - ) + # This has to be a static known constant since it is an array size + # so we ensure it is evaluated at compile-time and extract it + # from the array. + with jax.ensure_compile_time_eval(): + No2 = jnp.maximum( + jnp.maximum( + -self.bounds.xmin, + self.bounds.xmax + 1, + ), + jnp.maximum( + -self.bounds.ymin, + self.bounds.ymax + 1, + ), + ) - # galsim branches here if the image has the correct bounds, but JAX can't branch - # on calls that generate different size arrays - # so we always make a new image full_bounds = BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2) - ximage = Image(full_bounds, dtype=self.dtype, init_value=0) - ximage[self.bounds] = self[self.bounds] + if ( + self.bounds.deltax == full_bounds.deltax + and self.bounds.deltay == full_bounds.deltay + ): + ximage = self + else: + ximage = Image(full_bounds, dtype=self.dtype, init_value=0) + ximage[self.bounds] = self[self.bounds] dx = self.scale # dk = 2pi / (N dk) @@ -920,28 +925,35 @@ def calculate_inverse_fft(self): "calculate_inverse_fft requires that the image includes (0,0)", ) - No2 = max( - max(self.bounds.xmax, -self.bounds.ymin), - self.bounds.ymax, - ) + # This has to be a static known constant since it is an array size + # so we ensure it is evaluated at compile-time and extract it + # from the array. + with jax.ensure_compile_time_eval(): + No2 = jnp.maximum( + jnp.maximum(self.bounds.xmax, -self.bounds.ymin), + self.bounds.ymax, + ) + if not isinstance(No2, int): + No2 = int(No2.item()) target_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2) - - # galsim branches here if the image has the correct bounds, but JAX can't branch - # on calls that generate different size arrays - # so we always make a new image - - # Then we can pad out with zeros and wrap to get this in the form we need. - full_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2 + 1) - kimage = Image(full_bounds, dtype=self.dtype, init_value=0) - posx_bounds = BoundsI( - xmin=0, - xmax=self.bounds.xmax, - ymin=self.bounds.ymin, - ymax=self.bounds.ymax, - ) - kimage[posx_bounds] = self[posx_bounds] - kimage = kimage._wrap(target_bounds, True, False, 2 * No2) + if ( + self.bounds.deltax == target_bounds.deltax + and self.bounds.deltay == target_bounds.deltay + ): + kimage = self + else: + # Then we can pad out with zeros and wrap to get this in the form we need. + full_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2 + 1) + kimage = Image(full_bounds, dtype=self.dtype, init_value=0) + posx_bounds = BoundsI( + xmin=0, + xmax=self.bounds.xmax, + ymin=self.bounds.ymin, + ymax=self.bounds.ymax, + ) + kimage[posx_bounds] = self[posx_bounds] + kimage = kimage._wrap(target_bounds, True, False, 2 * No2) dk = self.scale # dx = 2pi / (N dk) diff --git a/tests/GalSim b/tests/GalSim index 007a87ae..fad1390f 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 007a87aeed4d6b77f03745fbf977df3e35918eb0 +Subproject commit fad1390f674e0ea85b00fcd6e1fc856f35e6fc49 From 618e137851287dbf403a91db9a8ab8fb1c700d8a Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Wed, 20 May 2026 16:13:34 -0500 Subject: [PATCH 19/72] Update jax_galsim/bounds.py --- jax_galsim/bounds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index a32e1d03..83873e5f 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -171,7 +171,7 @@ def center(self): @implements(_galsim.Bounds.true_center) def true_center(self): if not isinstance(self._isdefined, jnp.ndarray): - if not self.isDefined(): + if not self._isdefined: raise _galsim.GalSimUndefinedBoundsError( "true_center is invalid for an undefined Bounds" ) From 01f893cc6196e7280aba54af5be67edef3e45602 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Wed, 20 May 2026 16:13:44 -0500 Subject: [PATCH 20/72] Update jax_galsim/bounds.py --- jax_galsim/bounds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 83873e5f..ac6d74dd 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -155,7 +155,7 @@ def origin(self): @implements(_galsim.Bounds.center) def center(self): if not isinstance(self._isdefined, jnp.ndarray): - if not self.isDefined(): + if not self._isdefined: raise _galsim.GalSimUndefinedBoundsError( "center is invalid for an undefined Bounds" ) From 9d6f2fae49be9040f293caafaa98d750fcb73425 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 16:42:17 -0500 Subject: [PATCH 21/72] fix: use latest submodule --- jax_galsim/image.py | 2 ++ tests/GalSim | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index ff36cb52..8bee3e24 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -873,6 +873,8 @@ def calculate_fft(self): self.bounds.ymax + 1, ), ) + if not isinstance(No2, int): + No2 = int(No2.item()) full_bounds = BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2) if ( diff --git a/tests/GalSim b/tests/GalSim index fad1390f..f8bf84b7 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit fad1390f674e0ea85b00fcd6e1fc856f35e6fc49 +Subproject commit f8bf84b7baeac968f48206447abb8afdbd1cd451 From 13e0266bd88d8ca5af2577e2f6981bc60b0422c9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 16:43:42 -0500 Subject: [PATCH 22/72] test: update to latest submodule --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index f8bf84b7..b36f5f93 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit f8bf84b7baeac968f48206447abb8afdbd1cd451 +Subproject commit b36f5f9353571f7a35b9bc788e94d51e1ce9295b From 2a9c0e735a4364d347c74969f8660eb3182b28bd Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 16:51:17 -0500 Subject: [PATCH 23/72] fix: use to_galsim for fpacking --- jax_galsim/fits.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/jax_galsim/fits.py b/jax_galsim/fits.py index 8c7837f9..5c66e1c6 100644 --- a/jax_galsim/fits.py +++ b/jax_galsim/fits.py @@ -56,18 +56,7 @@ def readCube(*args, **kwargs): @contextmanager def _image_as_numpy(image): if isinstance(image, Image): - try: - orig_array = image._array - # convert to numpy so astropy doesn't complain - image._array = np.array(image.array, dtype=orig_array.dtype) - # some of these check for Image instances, so we hackily set the class - # on the way in - old_class = image.__class__ - image.__class__ = _galsim.Image - yield image - finally: - image.__class__ = old_class - image._array = orig_array + yield image.to_galsim() else: try: yield np.array(image, dtype=image.dtype) From ac4e000edee4a769f648b2f8a660c4d27d931cec Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 23:25:45 -0500 Subject: [PATCH 24/72] fix: do not convert all bounds props to arrays --- jax_galsim/bounds.py | 43 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index ac6d74dd..a6461b6b 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -24,6 +24,8 @@ - Upon initialization, if a ``BoundsI`` object has a non-static shape, JAX-GalSim will attempt to convert it to a static shape by extracting the dimensions from the array via ``.item()``. This operation will cause JAX to raise an error if the code is being traced. +- If a ``BoundsI`` object is declared with static ``xmin`` and ``ymin`` values, an error will be raised + if one attempts to convert those values to non-static values. - JAX-GalSim does not support the use of the `&/+` dunder methods (i.e., ``__and__`` and ``__add__``) for ``BoundsI`` objects when tracing code. - JAX-Galsim supports an additional initialization signature ``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)`` @@ -675,6 +677,10 @@ class BoundsI(Bounds): _pos_class = PositionI def __init__(self, *args, **kwargs): + # we set these variables to disable type checking and conversion + # for xmin/ymin while we initialize the object + self._isstatic = True + self._dotypechecking = False self._parse_args(*args, **kwargs) # validate inputs are ints @@ -702,6 +708,13 @@ def __init__(self, *args, **kwargs): self._isdefined = self.deltax >= 1 and self.deltay >= 1 + # now we compute these properties correctly and turn on type checking + if isinstance(self._xmin, int) and isinstance(self._ymin, int): + self._isstatic = True + else: + self._isstatic = False + self._dotypechecking = True + def _check_scalar(self, x, name): try: if ( @@ -727,11 +740,22 @@ def numpyShape(self): # so that autodiff works properly (needs floats in general) @property def xmin(self): - return jnp.astype(self._xmin, int) + if self._isstatic: + return self._xmin + else: + return jnp.astype(self._xmin, int) @xmin.setter def xmin(self, value): - self._xmin = jnp.astype(value, float) + value = check_is_int_then_cast(value, "BoundsI xmin values must be integers") + if self._isstatic: + if self._dotypechecking and isinstance(value, jnp.ndarray): + raise RuntimeError( + "Static `BoundsI` classes cannot be converted to dynamic ones." + ) + self._xmin = value + else: + self._xmin = jnp.astype(value, float) @property def xmax(self): @@ -754,11 +778,22 @@ def xmax(self, value): # so that autodiff works properly (needs floats in general) @property def ymin(self): - return jnp.astype(self._ymin, int) + if self._isstatic: + return self._ymin + else: + return jnp.astype(self._ymin, int) @ymin.setter def ymin(self, value): - self._ymin = jnp.astype(value, float) + value = check_is_int_then_cast(value, "BoundsI ymin values must be integers") + if self._isstatic: + if self._dotypechecking and isinstance(value, jnp.ndarray): + raise RuntimeError( + "Static `BoundsI` classes cannot be converted to dynamic ones." + ) + self._ymin = value + else: + self._ymin = jnp.astype(value, float) @property def ymax(self): From 004ada8480ae944171e4a6381fb539a3039ccb13 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 23:30:17 -0500 Subject: [PATCH 25/72] fix: put back variable pytree def --- jax_galsim/bounds.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index a6461b6b..72a675c3 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -869,11 +869,17 @@ def __hash__(self): def tree_flatten(self): """This function flattens the Bounds into a list of children nodes that will be traced by JAX and auxiliary static data.""" + aux_data = {"isstatic": self._isstatic} + # Define the children nodes of the PyTree that need tracing - children = (self._xmin, self._ymin) + if self._isstatic: + children = tuple() + aux_data["xmin"] = self._xmin + aux_data["ymin"] = self._ymin + else: + children = (self._xmin, self._ymin) # untraced aux data - aux_data = {} aux_data["deltax"] = self.deltax aux_data["deltay"] = self.deltay aux_data["isdefined"] = self._isdefined @@ -884,10 +890,14 @@ def tree_flatten(self): def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" ret = cls.__new__(cls) - ret._xmin = children[0] - ret._ymin = children[1] + if aux_data["isstatic"]: + ret._xmin = aux_data["xmin"] + ret._ymin = aux_data["ymin"] + else: + ret._xmin = children[0] + ret._ymin = children[1] ret.deltax = aux_data["deltax"] ret.deltay = aux_data["deltay"] ret._isdefined = aux_data["isdefined"] - + ret._isstatic = aux_data["isstatic"] return ret From 7f97f11deaddf3b17761c0e590dcf57b70f936c7 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 20 May 2026 23:35:36 -0500 Subject: [PATCH 26/72] fix: make sure to send fits headers to galsim --- jax_galsim/image.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 8bee3e24..afedf963 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1328,9 +1328,12 @@ def from_galsim(cls, galsim_image): def to_galsim(self): """Create a galsim `Image` from a `jax_galsim.Image` object.""" wcs = self.wcs.to_galsim() if self.wcs is not None else None - return _galsim.Image( + ret = _galsim.Image( np.asarray(self.array), bounds=self.bounds.to_galsim(), wcs=wcs ) + if hasattr(self, "header"): + ret.header = self.header + return ret @implements( _galsim.Image.FindAdaptiveMom, From a5fc11e51f0d3237d2a2bba3d033ca40a4e8893d Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 21 May 2026 06:51:59 -0500 Subject: [PATCH 27/72] Apply suggestion from @beckermr --- jax_galsim/image.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index afedf963..e6d924e2 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -859,6 +859,8 @@ def calculate_fft(self): "JAX-GalSim does not support forward FFTs of complex dtypes." ) + # TODO: figure out how to do FFT at fixed size and then reconstruct + # the result. - MRB # This has to be a static known constant since it is an array size # so we ensure it is evaluated at compile-time and extract it # from the array. From fea3aec42f829b4b32f5b707b64b454f87599eea Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 09:12:30 -0500 Subject: [PATCH 28/72] fix: add back isStatic method --- jax_galsim/bounds.py | 65 ++++++++++++++++++++++++++-------------- jax_galsim/core/utils.py | 2 +- jax_galsim/image.py | 18 +++-------- 3 files changed, 48 insertions(+), 37 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 72a675c3..e154cf97 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -2,6 +2,7 @@ import galsim as _galsim import jax import jax.numpy as jnp +import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( @@ -26,8 +27,11 @@ cause JAX to raise an error if the code is being traced. - If a ``BoundsI`` object is declared with static ``xmin`` and ``ymin`` values, an error will be raised if one attempts to convert those values to non-static values. -- JAX-GalSim does not support the use of the `&/+` dunder methods (i.e., ``__and__`` and ``__add__``) - for ``BoundsI`` objects when tracing code. +- ``Bounds`` classes in JAX-GalSim have an etxra method, ``isStatic`` that returns ``True`` if the object + was instantiated with static ``xmin`` and ``ymin`` values. This method always returns ``False`` for + ``BoundsD`` objects. +- JAX-GalSim does not support the use of the `&` and `+` operators (i.e., the dunder methods ``__and__`` + and ``__add__`` ) with ``BoundsI`` objects when tracing code. - JAX-Galsim supports an additional initialization signature ``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)`` to help users specify the widths ``deltax`` and ``deltay`` statically at initialization. - When calling ``jax.vmap``, ``jax.jit`` etc. with ``BoundsI`` objects, ``xmin`` and ``ymin`` are @@ -156,7 +160,7 @@ def origin(self): @property @implements(_galsim.Bounds.center) def center(self): - if not isinstance(self._isdefined, jnp.ndarray): + if isinstance(self._isdefined, bool): if not self._isdefined: raise _galsim.GalSimUndefinedBoundsError( "center is invalid for an undefined Bounds" @@ -172,7 +176,7 @@ def center(self): @property @implements(_galsim.Bounds.true_center) def true_center(self): - if not isinstance(self._isdefined, jnp.ndarray): + if isinstance(self._isdefined, bool): if not self._isdefined: raise _galsim.GalSimUndefinedBoundsError( "true_center is invalid for an undefined Bounds" @@ -331,7 +335,7 @@ def tree_flatten(self): # Define the children nodes of the PyTree that need tracing children = (self.xmin, self.deltax, self.ymin, self.deltay, self._isdefined) # Define auxiliary static data that doesn’t need to be traced - aux_data = None + aux_data = {"isstatic": self._isstatic} return (children, aux_data) @classmethod @@ -343,6 +347,7 @@ def tree_unflatten(cls, aux_data, children): ret.ymin = children[2] ret.deltay = children[3] ret._isdefined = children[4] + ret._isstatic = aux_data["isstatic"] return ret @@ -387,6 +392,12 @@ def to_galsim(self): else: return gs_class() + def isStatic(self): + """Returns ``True`` if the ``BoundsI`` instance + has static, known dimensions and location. Always returns + ``False`` for ``BoundsD``.""" + return self._isstatic + def _bounds_and_op_dynamic(self, other): xmin = jnp.maximum(self.xmin, other.xmin) @@ -432,7 +443,8 @@ def _bounds_and_op_dynamic(self, other): deltay=ymax - ymin + 1, ) # we have to do a conversion to static bools here too - ret._isdefined = bool(is_defined.item()) + with jax.ensure_compile_time_eval(): + ret._isdefined = bool(is_defined.item()) else: ret = cls.__new__(cls) ret.xmin = xmin @@ -440,6 +452,7 @@ def _bounds_and_op_dynamic(self, other): ret.ymin = ymin ret.deltay = ymax - ymin ret._isdefined = is_defined + ret._isstatic = False return ret @@ -485,7 +498,8 @@ def _ret_correct_attr(self_isdef, self_attr, other_isdef, other_attr, op): ), ) # we have to do a conversion to static bools here too - ret._isdefined = bool(is_defined.item()) + with jax.ensure_compile_time_eval(): + ret._isdefined = bool(is_defined.item()) else: ret = cls.__new__(cls) ret.xmin = xmin @@ -501,6 +515,7 @@ def _ret_correct_attr(self_isdef, self_attr, other_isdef, other_attr, op): jnp.array(other._isdefined), ), ) + ret._isstatic = False return ret @@ -543,7 +558,8 @@ def _bounds_pos_add_op_dynamic(self, other): jnp.array(True), ) # we have to do a conversion to static bools here too - ret._isdefined = bool(is_defined.item()) + with jax.ensure_compile_time_eval(): + ret._isdefined = bool(is_defined.item()) else: ret = cls.__new__(cls) ret.xmin = xmin @@ -555,6 +571,7 @@ def _bounds_pos_add_op_dynamic(self, other): jnp.array(ret.deltax >= 0) & jnp.array(ret.deltay >= 0), jnp.array(True), ) + ret._isstatic = False return ret @@ -573,13 +590,14 @@ def __init__(self, *args, **kwargs): if do_isdefined: self._isdefined = (self.deltax >= 0) & (self.deltay >= 0) self._isdefined = jnp.array(self._isdefined) + self._isstatic = False def _check_scalar(self, x, name): try: if ( - isinstance(x, jax.Array) + isinstance(x, (jax.Array, jnp.ndarray, np.ndarray)) and x.shape == () - and jnp.issubdtype(x.dtype, jnp.floating) + and jnp.issubdtype(jnp.array(x).dtype, jnp.floating) ): return elif x == float(x): @@ -701,10 +719,11 @@ def __init__(self, *args, **kwargs): # this will raise if values are being traced # we let that error propagate instead of reraising # our own. - if not isinstance(self.deltax, int): - self.deltax = int(self.deltax.item()) - if not isinstance(self.deltay, int): - self.deltay = int(self.deltay.item()) + with jax.ensure_compile_time_eval(): + if not isinstance(self.deltax, int): + self.deltax = int(self.deltax.item()) + if not isinstance(self.deltay, int): + self.deltay = int(self.deltay.item()) self._isdefined = self.deltax >= 1 and self.deltay >= 1 @@ -718,9 +737,9 @@ def __init__(self, *args, **kwargs): def _check_scalar(self, x, name): try: if ( - isinstance(x, jax.Array) + isinstance(x, (jax.Array, jnp.ndarray, np.ndarray)) and x.shape == () - and jnp.issubdtype(x.dtype, jnp.integer) + and jnp.issubdtype(jnp.array(x).dtype, jnp.integer) ): return elif x == int(x): @@ -749,7 +768,7 @@ def xmin(self): def xmin(self, value): value = check_is_int_then_cast(value, "BoundsI xmin values must be integers") if self._isstatic: - if self._dotypechecking and isinstance(value, jnp.ndarray): + if self._dotypechecking and not isinstance(value, int): raise RuntimeError( "Static `BoundsI` classes cannot be converted to dynamic ones." ) @@ -771,8 +790,9 @@ def xmax(self, value): # this will raise if values are being traced # we let that error propagate instead of reraising # our own. - if not isinstance(self.deltax, int): - self.deltax = int(self.deltax.item()) + with jax.ensure_compile_time_eval(): + if not isinstance(self.deltax, int): + self.deltax = int(self.deltax.item()) # we store ymin internally as a float even though it is an int # so that autodiff works properly (needs floats in general) @@ -787,7 +807,7 @@ def ymin(self): def ymin(self, value): value = check_is_int_then_cast(value, "BoundsI ymin values must be integers") if self._isstatic: - if self._dotypechecking and isinstance(value, jnp.ndarray): + if self._dotypechecking and not isinstance(value, int): raise RuntimeError( "Static `BoundsI` classes cannot be converted to dynamic ones." ) @@ -809,8 +829,9 @@ def ymax(self, value): # this will raise if values are being traced # we let that error propagate instead of reraising # our own. - if not isinstance(self.deltay, int): - self.deltay = int(self.deltay.item()) + with jax.ensure_compile_time_eval(): + if not isinstance(self.deltay, int): + self.deltay = int(self.deltay.item()) def _area(self): # Remember the + 1 this time to include the pixels on both edges of the bounds. diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 9a06975c..52d57ba5 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -25,7 +25,7 @@ def check_is_int_then_cast(val, msg): val = jnp.array(val) val = equinox.error_if( val, - np.any(val != jnp.trunc(val)), + jnp.any(val != jnp.trunc(val)), msg, ) val = val.astype(int) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index a55cbc14..ba52838f 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -594,12 +594,7 @@ def subImage(self, bounds): "Attempt to access subImage not (fully) in image", ) - if ( - isinstance(self.bounds.xmin, int) - and isinstance(self.bounds.ymin, int) - and isinstance(bounds.xmin, int) - and isinstance(bounds.ymin, int) - ): + if self.bounds.isStatic() and bounds.isStatic(): i1 = bounds.ymin - self.ymin i2 = bounds.ymax - self.ymin + 1 j1 = bounds.xmin - self.xmin @@ -647,12 +642,7 @@ def setSubImage(self, bounds, rhs): rhs=rhs, ) - if ( - isinstance(self.bounds.xmin, int) - and isinstance(self.bounds.ymin, int) - and isinstance(bounds.xmin, int) - and isinstance(bounds.ymin, int) - ): + if self.bounds.isStatic() and bounds.isStatic(): i1 = bounds.ymin - self.ymin i2 = bounds.ymax - self.ymin + 1 j1 = bounds.xmin - self.xmin @@ -736,7 +726,7 @@ def wrap(self, bounds, hermitian=False): def _raise_if_nonzero(bnds, x_or_y, msg): if x_or_y == "x": - if isinstance(bnds.xmin, int): + if bnds.isStatic(): if bnds.xmin != 0: raise _galsim.GalSimIncompatibleValuesError( msg, @@ -750,7 +740,7 @@ def _raise_if_nonzero(bnds, x_or_y, msg): msg, ) else: - if isinstance(bnds.ymin, int): + if bnds.isStatic(): if bnds.ymin != 0: raise _galsim.GalSimIncompatibleValuesError( msg, From 1a8f579059d2f5cc4f2ac2bd9a4e41086f07e3a6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 09:14:21 -0500 Subject: [PATCH 29/72] test: ensure api tests are correct --- tests/jax/test_api.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index b0435ecb..bf765f50 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -359,12 +359,9 @@ def _reg_fun(p): if issubclass(cls, jax_galsim.Bounds) and method in [ "xmax", "ymax", - ]: - continue - - if issubclass(cls, jax_galsim.BoundsI) and method in [ "xmin", "ymin", + "isStatic", ]: continue From 053de86d3d3cccca223aefa0b56b16f2cb53639e Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 21 May 2026 09:16:24 -0500 Subject: [PATCH 30/72] style: remove extra blank space changes Co-authored-by: Matthew R. Becker --- jax_galsim/image.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index ba52838f..8397d268 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -280,7 +280,6 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): b = kwargs.pop("bounds") if not isinstance(b, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - if check_bounds and b.isDefined(): if b.deltax != array.shape[1]: raise _galsim.GalSimIncompatibleValuesError( @@ -572,7 +571,6 @@ def resize(self, bounds, wcs=None): raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - self._array = self._make_empty(shape=bounds.numpyShape(), dtype=self.dtype) self._bounds = bounds if wcs is not None: @@ -582,7 +580,6 @@ def resize(self, bounds, wcs=None): def subImage(self, bounds): if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access subImage of undefined image" @@ -620,7 +617,6 @@ def setSubImage(self, bounds, rhs): raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" From 51efa747ac1720c7abfe50dd4b40f3bb1ae6db73 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 21 May 2026 09:18:07 -0500 Subject: [PATCH 31/72] Apply suggestion from @beckermr --- jax_galsim/image.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 8397d268..f5e583d2 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -871,6 +871,7 @@ def calculate_fft(self): ): ximage = self else: + # Then we pad out with zeros ximage = Image(full_bounds, dtype=self.dtype, init_value=0) ximage[self.bounds] = self[self.bounds] From cb9d084ce06b77acef5735c74aace4a76f74511c Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 21 May 2026 09:18:37 -0500 Subject: [PATCH 32/72] Apply suggestion from @beckermr --- jax_galsim/image.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index f5e583d2..9bcdc3ec 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -869,6 +869,7 @@ def calculate_fft(self): self.bounds.deltax == full_bounds.deltax and self.bounds.deltay == full_bounds.deltay ): + # Then the image is already in the shape we need. ximage = self else: # Then we pad out with zeros From 63fe4b336e54a9377697aa9cd9f188e927450a9f Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 21 May 2026 09:19:15 -0500 Subject: [PATCH 33/72] Apply suggestion from @beckermr --- jax_galsim/image.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 9bcdc3ec..bd6cc0b7 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -933,6 +933,7 @@ def calculate_inverse_fft(self): self.bounds.deltax == target_bounds.deltax and self.bounds.deltay == target_bounds.deltay ): + # Then the image is already in the shape we need. kimage = self else: # Then we can pad out with zeros and wrap to get this in the form we need. From d0aab7e043d965fdf02491cdc3eee88777c0445d Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 09:30:08 -0500 Subject: [PATCH 34/72] fix: allow python bool for static bounds eq --- jax_galsim/bounds.py | 92 +++++++++++++++++++++++++++++++------------- jax_galsim/image.py | 10 +---- 2 files changed, 68 insertions(+), 34 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index e154cf97..d6112b04 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -289,34 +289,14 @@ def __add__(self, other): ) def __eq__(self, other): - if self is other: - return jnp.array(True) - elif isinstance(other, self.__class__): - self_isdef = jnp.array(self.isDefined()) - other_isdef = jnp.array(other.isDefined()) - if isinstance(self, BoundsD): - return ( - self_isdef - & other_isdef - & jnp.array(self.xmin == other.xmin) - & jnp.array(self.ymin == other.ymin) - & jnp.array(self.xmax == other.xmax) - & jnp.array(self.ymax == other.ymax) - ) | ((~self_isdef) & (~other_isdef)) - else: - return ( - self_isdef - & other_isdef - & jnp.array(self.xmin == other.xmin) - & jnp.array(self.ymin == other.ymin) - & jnp.array(self.deltax == other.deltax) - & jnp.array(self.deltay == other.deltay) - ) | ((~self_isdef) & (~other_isdef)) - else: - return jnp.array(False) + raise NotImplementedError( + "The `__eq__` magic method must be implemented by subclasses of `Bounds`." + ) def __ne__(self, other): - return ~self.__eq__(other) + raise NotImplementedError( + "The `__ne__` magic method must be implemented by subclasses of `Bounds`." + ) def __hash__(self): return hash( @@ -673,6 +653,26 @@ def __str__(self): else: return "galsim.%s()" % (self.__class__.__name__) + def __eq__(self, other): + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + self_isdef = jnp.array(self.isDefined()) + other_isdef = jnp.array(other.isDefined()) + return ( + self_isdef + & other_isdef + & jnp.array(self.xmin == other.xmin) + & jnp.array(self.ymin == other.ymin) + & jnp.array(self.xmax == other.xmax) + & jnp.array(self.ymax == other.ymax) + ) | ((~self_isdef) & (~other_isdef)) + else: + return jnp.array(False) + + def __ne__(self, other): + return ~self.__eq__(other) + def __hash__(self): return hash( ( @@ -876,6 +876,46 @@ def __str__(self): else: return "galsim.%s()" % (self.__class__.__name__) + def __eq__(self, other): + if self is other: + if self._isstatic: + return True + else: + return jnp.array(True) + elif isinstance(other, self.__class__): + if self._isstatic and other._isstatic: + return ( + self._isdefined + and other._isdefined + and self.xmin == other.xmin + and self.ymin == other.ymin + and self.deltax == other.deltax + and self.deltay == other.deltay + ) or ((not self._isdefined) and (not other._isdefined)) + else: + self_isdef = jnp.array(self.isDefined()) + other_isdef = jnp.array(other.isDefined()) + return ( + self_isdef + & other_isdef + & jnp.array(self.xmin == other.xmin) + & jnp.array(self.ymin == other.ymin) + & jnp.array(self.deltax == other.deltax) + & jnp.array(self.deltay == other.deltay) + ) | ((~self_isdef) & (~other_isdef)) + else: + if self._isstatic: + return False + else: + return jnp.array(False) + + def __ne__(self, other): + eqval = self.__eq__(other) + if isinstance(eqval, bool): + return not eqval + else: + return ~eqval + def __hash__(self): return hash( ( diff --git a/jax_galsim/image.py b/jax_galsim/image.py index bd6cc0b7..6ea1521d 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -865,10 +865,7 @@ def calculate_fft(self): No2 = int(No2.item()) full_bounds = BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2) - if ( - self.bounds.deltax == full_bounds.deltax - and self.bounds.deltay == full_bounds.deltay - ): + if self.bounds == full_bounds: # Then the image is already in the shape we need. ximage = self else: @@ -929,10 +926,7 @@ def calculate_inverse_fft(self): No2 = int(No2.item()) target_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2) - if ( - self.bounds.deltax == target_bounds.deltax - and self.bounds.deltay == target_bounds.deltay - ): + if self.bounds == target_bounds: # Then the image is already in the shape we need. kimage = self else: From 9f99b9277ba852b3ff201ebac501049b4dd3e8c6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 13:23:38 -0500 Subject: [PATCH 35/72] fix: start on eq using jax bool values in most cases --- jax_galsim/angle.py | 14 +++++--- jax_galsim/celestial.py | 13 +++---- jax_galsim/convolve.py | 21 ++++++++---- jax_galsim/core/utils.py | 60 +++++++++++++++++++++------------ jax_galsim/gsobject.py | 13 ++++--- jax_galsim/image.py | 23 ++++++++----- jax_galsim/interpolant.py | 6 +++- jax_galsim/interpolatedimage.py | 29 ++++++++++------ jax_galsim/photon_array.py | 32 ++++++++++-------- jax_galsim/position.py | 13 +++---- jax_galsim/random.py | 17 +++++----- jax_galsim/shear.py | 9 +++-- jax_galsim/transform.py | 25 +++++++++----- 13 files changed, 173 insertions(+), 102 deletions(-) diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index fd3af1c3..c248c1a2 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -112,10 +112,13 @@ def __repr__(self): return "galsim.AngleUnit(%r)" % (ensure_hashable(self.value),) def __eq__(self, other): - return isinstance(other, AngleUnit) and jnp.array_equal(self.value, other.value) + if not isinstance(other, AngleUnit): + return jnp.array(False) + else: + return jnp.array_equal(self.value, other.value) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) def __hash__(self): return hash(("galsim.AngleUnit", ensure_hashable(self.value))) @@ -253,10 +256,13 @@ def __repr__(self): return "galsim.Angle(%r, galsim.radians)" % (ensure_hashable(self.rad),) def __eq__(self, other): - return isinstance(other, Angle) and jnp.array_equal(self.rad, other.rad) + if not isinstance(other, Angle): + return jnp.array(False) + else: + return jnp.array_equal(self.rad, other.rad) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) def __le__(self, other): if not isinstance(other, Angle): diff --git a/jax_galsim/celestial.py b/jax_galsim/celestial.py index 5645eda3..f5c3c08b 100644 --- a/jax_galsim/celestial.py +++ b/jax_galsim/celestial.py @@ -839,14 +839,15 @@ def __hash__(self): return hash(repr(self)) def __eq__(self, other): - return ( - isinstance(other, CelestialCoord) - and jnp.array_equal(self._ra.rad, other._ra.rad) - and jnp.array_equal(self._dec.rad, other._dec.rad) - ) + if not isinstance(other, CelestialCoord): + return jnp.array(False) + else: + return jnp.array_equal(self._ra.rad, other._ra.rad) & jnp.array_equal( + self._dec.rad, other._dec.rad + ) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) def tree_flatten(self): """This function flattens the CelestialCoord into a list of children diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index e6df0769..4f1f9043 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -162,13 +162,20 @@ def withGSParams(self, gsparams=None, **kwargs): return ret def __eq__(self, other): - return self is other or ( - isinstance(other, Convolution) - and self.obj_list == other.obj_list - and self.real_space == other.real_space - and self.gsparams == other.gsparams - and self._propagate_gsparams == other._propagate_gsparams - ) + if self is other: + return jnp.array(True) + elif isinstance(other, Convolution): + return ( + jnp.array(self.obj_list == other.obj_list) + & jnp.array(self.real_space == other.real_space) + & jnp.array(self.gsparams == other.gsparams) + & jnp.array(self._propagate_gsparams == other._propagate_gsparams) + ) + else: + return jnp.array(False) + + def __ne__(self, other): + return ~self.__eq__(other) def __hash__(self): return hash( diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 52d57ba5..b38dd7af 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -77,52 +77,70 @@ def cast_to_int(x, accept_strings=False): return _cast_to_type(x, int, accept_strings=accept_strings) -def is_equal_with_arrays(x, y): +def is_equal_with_arrays(x, y, currval=None, no_jax=False): """Return True if the data is equal, False otherwise. Handles jax.Array types.""" + print("x|y|currval:", repr(x), repr(y), repr(currval), flush=True) + + if no_jax: + arr_func = np.array + arr_eq_func = np.array_equal + else: + arr_func = jnp.array + arr_eq_func = jnp.array_equal + + if currval is None: + currval = arr_func(True) + if isinstance(x, list): + print("list branch", flush=True) if isinstance(y, list) and len(x) == len(y): for vx, vy in zip(x, y): - if not is_equal_with_arrays(vx, vy): - return False - return True + currval &= is_equal_with_arrays(vx, vy, currval=currval, no_jax=no_jax) else: - return False + currval &= arr_func(False) elif isinstance(x, tuple): + print("tuple branch", flush=True) if isinstance(y, tuple) and len(x) == len(y): for vx, vy in zip(x, y): - if not is_equal_with_arrays(vx, vy): - return False - return True + currval &= is_equal_with_arrays(vx, vy, currval=currval, no_jax=no_jax) else: - return False + currval &= arr_func(False) elif isinstance(x, set): + print("set branch", flush=True) if isinstance(y, set) and len(x) == len(y): for vx, vy in zip(sorted(x), sorted(y)): - if not is_equal_with_arrays(vx, vy): - return False - return True + currval &= is_equal_with_arrays(vx, vy, currval=currval, no_jax=no_jax) else: - return False + currval &= arr_func(False) elif isinstance(x, dict): + print("dict branch", flush=True) if isinstance(y, dict) and len(x) == len(y): for kx, vx in x.items(): - if kx not in y or (not is_equal_with_arrays(vx, y[kx])): - return False - return True + if kx not in y: + currval &= arr_func(False) + else: + currval &= is_equal_with_arrays( + vx, y[kx], currval=currval, no_jax=no_jax + ) else: - return False + currval &= jnp.array(False) elif isinstance(x, jax.Array) and jnp.ndim(x) > 0: + print("array branch", flush=True) if isinstance(y, jax.Array) and y.shape == x.shape: - return jnp.array_equal(x, y) + currval &= arr_eq_func(x, y) else: - return False + currval &= arr_func(False) elif (isinstance(x, jax.Array) and jnp.ndim(x) == 0) or ( isinstance(y, jax.Array) and jnp.ndim(y) == 0 ): + print("array scalar branch", flush=True) # this case covers comparing an array scalar to a python scalar or vice versa - return jnp.array_equal(x, y) + currval &= arr_eq_func(x, y) else: - return x == y + print("default branch", flush=True) + currval &= arr_func(x == y) + + return currval def _convert_to_numpy_nan(x): diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index a5cf51b7..9f3e8aa0 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -197,10 +197,15 @@ def __neg__(self): return -1.0 * self def __eq__(self, other): - return (self is other) or ( - (type(other) is self.__class__) - and is_equal_with_arrays(self.tree_flatten(), other.tree_flatten()) - ) + if self is other: + return jnp.array(True) + elif type(other) is self.__class__: + return is_equal_with_arrays(self.tree_flatten(), other.tree_flatten()) + else: + return jnp.array(False) + + def __ne__(self, other): + return ~self.__eq__(other) @implements(_galsim.GSObject.xValue) def xValue(self, *args, **kwargs): diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 6ea1521d..50e1a7ed 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1202,18 +1202,23 @@ def __eq__(self, other): # >>> assert galsim.ImageD(int_array) == galsim.ImageF(int_array) # passes # >>> assert galsim.ImageD(double_array) == galsim.ImageF(double_array) # fails - return self is other or ( - isinstance(other, Image) - and self.bounds == other.bounds - and self.wcs == other.wcs - and ( - not self.bounds.isDefined() or jnp.array_equal(self.array, other.array) + if self is other: + return jnp.array(True) + elif isinstance(other, Image): + return ( + jnp.array(self.bounds == other.bounds) + & jnp.array(self.wcs == other.wcs) + & ( + (~jnp.array(self.bounds.isDefined())) + | jnp.array_equal(self.array, other.array) + ) + & jnp.array(self.isconst == other.isconst) ) - and self.isconst == other.isconst - ) + else: + return jnp.array(False) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) @implements(_galsim.Image.transpose) def transpose(self): diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 1b489398..8cfe75bb 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -133,7 +133,11 @@ def _i(self): def __eq__(self, other): return (self is other) or ( type(other) is self.__class__ - and is_equal_with_arrays(self.tree_flatten()[1], other.tree_flatten()[1]) + and is_equal_with_arrays( + self.tree_flatten()[1], + other.tree_flatten()[1], + no_jax=True, + ) ) def __ne__(self, other): diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 9564bd36..7c2ed70c 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -328,17 +328,24 @@ def __str__(self): return "galsim.InterpolatedImage(image=%s, flux=%s)" % (self.image, self.flux) def __eq__(self, other): - return self is other or ( - isinstance(other, InterpolatedImage) - and self._xim == other._xim - and self.x_interpolant == other.x_interpolant - and self.k_interpolant == other.k_interpolant - and self.flux == other.flux - and self._original._offset == other._original._offset - and self.gsparams == other.gsparams - and self._stepk == other._stepk - and self._maxk == other._maxk - ) + if self is other: + return jnp.array(True) + elif isinstance(other, InterpolatedImage): + return ( + (self._xim == other._xim) + & (self.x_interpolant == other.x_interpolant) + & (self.k_interpolant == other.k_interpolant) + & jnp.array_equal(self.flux, other.flux) + & (self._original._offset == other._original._offset) + & jnp.array(self.gsparams == other.gsparams) + & jnp.array_equal(self._stepk, other._stepk) + & jnp.array_equal(self._maxk, other._maxk) + ) + else: + return jnp.array(False) + + def __ne__(self, other): + return ~self.__eq__(other) def tree_flatten(self): """This function flattens the InterpolatedImage into a list of children diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index e97ee2b7..e9b79411 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -847,22 +847,26 @@ def __str__(self): __hash__ = None def __eq__(self, other): - return self is other or ( - isinstance(other, PhotonArray) - and jnp.array_equal(self.x, other.x) - and jnp.array_equal(self.y, other.y) - and jnp.array_equal(self.flux, other.flux) - and jnp.array_equal(self._nokeep, other._nokeep) - and jnp.array_equal(self.dxdz, other.dxdz, equal_nan=True) - and jnp.array_equal(self.dydz, other.dydz, equal_nan=True) - and jnp.array_equal(self.wavelength, other.wavelength, equal_nan=True) - and jnp.array_equal(self.pupil_u, other.pupil_u, equal_nan=True) - and jnp.array_equal(self.pupil_v, other.pupil_v, equal_nan=True) - and jnp.array_equal(self.time, other.time, equal_nan=True) - ) + if self is other: + return jnp.array(True) + elif isinstance(other, PhotonArray): + return ( + jnp.array_equal(self.x, other.x) + & jnp.array_equal(self.y, other.y) + & jnp.array_equal(self.flux, other.flux) + & jnp.array_equal(self._nokeep, other._nokeep) + & jnp.array_equal(self.dxdz, other.dxdz, equal_nan=True) + & jnp.array_equal(self.dydz, other.dydz, equal_nan=True) + & jnp.array_equal(self.wavelength, other.wavelength, equal_nan=True) + & jnp.array_equal(self.pupil_u, other.pupil_u, equal_nan=True) + & jnp.array_equal(self.pupil_v, other.pupil_v, equal_nan=True) + & jnp.array_equal(self.time, other.time, equal_nan=True) + ) + else: + return jnp.array(False) def __ne__(self, other): - return not self == other + return ~self.__eq__(other) @implements( _galsim.PhotonArray.addTo, diff --git a/jax_galsim/position.py b/jax_galsim/position.py index cf36dba8..a277413f 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -114,14 +114,15 @@ def __str__(self): ) def __eq__(self, other): - return self is other or ( - isinstance(other, self.__class__) - and self.x == other.x - and self.y == other.y - ) + if self is other: + return jnp.array(True) + elif not isinstance(other, self.__class__): + return jnp.array(False) + else: + return jnp.array_equal(self.x, other.x) & jnp.array_equal(self.y, other.y) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) def __hash__(self): return hash( diff --git a/jax_galsim/random.py b/jax_galsim/random.py index b5e730f3..e5442480 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -9,7 +9,7 @@ import numpy as np from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import STATIC_SCALAR_TYPES, implements +from jax_galsim.core.utils import STATIC_SCALAR_TYPES, implements, is_equal_with_arrays try: from jax.extend.random import wrap_key_data @@ -244,16 +244,17 @@ def __copy__(self): return self.duplicate() def __eq__(self, other): - return self is other or ( - isinstance(other, self.__class__) - and jnp.array_equal( + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + return jnp.array_equal( jrandom.key_data(self._key), jrandom.key_data(other._key) - ) - and self._params == other._params - ) + ) & is_equal_with_arrays(self._params, other._params) + else: + return jnp.array(False) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) __hash__ = None diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index b0663890..33a58668 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -280,10 +280,15 @@ def __sub__(self, other): return self + (-other) def __eq__(self, other): - return self is other or (isinstance(other, Shear) and self._g == other._g) + if self is other: + return jnp.array(True) + elif not isinstance(other, Shear): + return jnp.array(False) + else: + return jnp.array_equal(self._g, other._g) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) @implements(_galsim.Shear.getMatrix) def getMatrix(self): diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index e78a67e3..9606701a 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -158,15 +158,22 @@ def withGSParams(self, gsparams=None, **kwargs): return self.tree_unflatten(aux, chld) def __eq__(self, other): - return self is other or ( - isinstance(other, Transformation) - and self._original == other._original - and jnp.array_equal(self._jac, other._jac) - and self._offset == other._params["offset"] - and self._flux_ratio == other._flux_ratio - and self._gsparams == other._gsparams - and self._propagate_gsparams == other._propagate_gsparams - ) + if self is other: + return jnp.array(True) + elif isinstance(other, Transformation): + return ( + (self._original == other._original) + & jnp.array_equal(self._jac, other._jac) + & (self._offset == other._params["offset"]) + & jnp.array_equal(self._flux_ratio, other._flux_ratio) + & jnp.array(self._gsparams == other._gsparams) + & jnp.array(self._propagate_gsparams == other._propagate_gsparams) + ) + else: + return jnp.array(False) + + def __ne__(self, other): + return ~self.__eq__(other) def __hash__(self): return hash( From 0485a14ecb2bee1f1c0887540f1c93897832547d Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 21 May 2026 13:25:34 -0500 Subject: [PATCH 36/72] Apply suggestion from @beckermr --- jax_galsim/interpolatedimage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 7c2ed70c..8d9815a5 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -333,8 +333,8 @@ def __eq__(self, other): elif isinstance(other, InterpolatedImage): return ( (self._xim == other._xim) - & (self.x_interpolant == other.x_interpolant) - & (self.k_interpolant == other.k_interpolant) + & jnp.array(self.x_interpolant == other.x_interpolant) + & jnp.array(self.k_interpolant == other.k_interpolant) & jnp.array_equal(self.flux, other.flux) & (self._original._offset == other._original._offset) & jnp.array(self.gsparams == other.gsparams) From f642c26bfa9287d656b7d0066ae25156b041e7be Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 13:26:28 -0500 Subject: [PATCH 37/72] fix: remove extra prints --- jax_galsim/core/utils.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index b38dd7af..d7394b6b 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -79,8 +79,6 @@ def cast_to_int(x, accept_strings=False): def is_equal_with_arrays(x, y, currval=None, no_jax=False): """Return True if the data is equal, False otherwise. Handles jax.Array types.""" - print("x|y|currval:", repr(x), repr(y), repr(currval), flush=True) - if no_jax: arr_func = np.array arr_eq_func = np.array_equal @@ -92,28 +90,24 @@ def is_equal_with_arrays(x, y, currval=None, no_jax=False): currval = arr_func(True) if isinstance(x, list): - print("list branch", flush=True) if isinstance(y, list) and len(x) == len(y): for vx, vy in zip(x, y): currval &= is_equal_with_arrays(vx, vy, currval=currval, no_jax=no_jax) else: currval &= arr_func(False) elif isinstance(x, tuple): - print("tuple branch", flush=True) if isinstance(y, tuple) and len(x) == len(y): for vx, vy in zip(x, y): currval &= is_equal_with_arrays(vx, vy, currval=currval, no_jax=no_jax) else: currval &= arr_func(False) elif isinstance(x, set): - print("set branch", flush=True) if isinstance(y, set) and len(x) == len(y): for vx, vy in zip(sorted(x), sorted(y)): currval &= is_equal_with_arrays(vx, vy, currval=currval, no_jax=no_jax) else: currval &= arr_func(False) elif isinstance(x, dict): - print("dict branch", flush=True) if isinstance(y, dict) and len(x) == len(y): for kx, vx in x.items(): if kx not in y: @@ -125,7 +119,6 @@ def is_equal_with_arrays(x, y, currval=None, no_jax=False): else: currval &= jnp.array(False) elif isinstance(x, jax.Array) and jnp.ndim(x) > 0: - print("array branch", flush=True) if isinstance(y, jax.Array) and y.shape == x.shape: currval &= arr_eq_func(x, y) else: @@ -133,11 +126,9 @@ def is_equal_with_arrays(x, y, currval=None, no_jax=False): elif (isinstance(x, jax.Array) and jnp.ndim(x) == 0) or ( isinstance(y, jax.Array) and jnp.ndim(y) == 0 ): - print("array scalar branch", flush=True) # this case covers comparing an array scalar to a python scalar or vice versa currval &= arr_eq_func(x, y) else: - print("default branch", flush=True) currval &= arr_func(x == y) return currval From 7d176a75d8f84c5c1dc19b502c71053dbdb7d160 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 14:43:42 -0500 Subject: [PATCH 38/72] fix: return JAX bools for rest of things --- jax_galsim/convolve.py | 19 ++++++++---- jax_galsim/fitswcs.py | 50 ++++++++++++++++++++----------- jax_galsim/interpolatedimage.py | 4 +-- jax_galsim/transform.py | 4 +-- jax_galsim/wcs.py | 53 +++++++++++++++++++++++---------- 5 files changed, 86 insertions(+), 44 deletions(-) diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 4f1f9043..c8c4c1ae 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -410,12 +410,19 @@ def withGSParams(self, gsparams=None, **kwargs): return ret def __eq__(self, other): - return self is other or ( - isinstance(other, Deconvolution) - and self.orig_obj == other.orig_obj - and self.gsparams == other.gsparams - and self._propagate_gsparams == other._propagate_gsparams - ) + if self is other: + return jnp.array(True) + elif isinstance(other, Deconvolution): + return ( + jnp.array(self.orig_obj == other.orig_obj) + & jnp.array(self.gsparams == other.gsparams) + & jnp.array(self._propagate_gsparams == other._propagate_gsparams) + ) + else: + return jnp.array(False) + + def __ne__(self, other): + return ~self.__eq__(other) def __hash__(self): return hash( diff --git a/jax_galsim/fitswcs.py b/jax_galsim/fitswcs.py index 01654860..7c8fab94 100644 --- a/jax_galsim/fitswcs.py +++ b/jax_galsim/fitswcs.py @@ -832,25 +832,39 @@ def copy(self): return copy.copy(self) def __eq__(self, other): - return self is other or ( - isinstance(other, GSFitsWCS) - and self.wcs_type == other.wcs_type - and jnp.array_equal(self.crpix, other.crpix) - and jnp.array_equal(self.cd, other.cd) - and self.center == other.center - and ( - (self.pv is None and other.pv is None) - or jnp.array_equal(self.pv, other.pv) + if self is other: + return jnp.array(True) + elif isinstance(other, GSFitsWCS): + is_eq = ( + jnp.array(self.wcs_type == other.wcs_type) + & jnp.array_equal(self.crpix, other.crpix) + & jnp.array_equal(self.cd, other.cd) + & jnp.array(self.center == other.center) ) - and ( - (self.ab is None and other.ab is None) - or jnp.array_equal(self.ab, other.ab) - ) - and ( - (self.abp is None and other.abp is None) - or jnp.array_equal(self.abp, other.abp) - ) - ) + if self.pv is None and other.pv is None: + pass + elif self.pv is not None and other.pv is not None: + is_eq &= jnp.array_equal(self.pv, other.pv) + else: + is_eq &= jnp.array(False) + + if self.ab is None and other.ab is None: + pass + elif self.ab is not None and other.ab is not None: + is_eq &= jnp.array_equal(self.ab, other.ab) + else: + is_eq &= jnp.array(False) + + if self.abp is None and other.abp is None: + pass + elif self.abp is not None and other.abp is not None: + is_eq &= jnp.array_equal(self.abp, other.abp) + else: + is_eq &= jnp.array(False) + + return is_eq + else: + return jnp.array(False) def __repr__(self): pv_repr = repr(ensure_hashable(self.pv)) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 8d9815a5..589c27a9 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -332,11 +332,11 @@ def __eq__(self, other): return jnp.array(True) elif isinstance(other, InterpolatedImage): return ( - (self._xim == other._xim) + jnp.array(self._xim == other._xim) & jnp.array(self.x_interpolant == other.x_interpolant) & jnp.array(self.k_interpolant == other.k_interpolant) & jnp.array_equal(self.flux, other.flux) - & (self._original._offset == other._original._offset) + & jnp.array(self._original._offset == other._original._offset) & jnp.array(self.gsparams == other.gsparams) & jnp.array_equal(self._stepk, other._stepk) & jnp.array_equal(self._maxk, other._maxk) diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 9606701a..7e9bc8dd 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -162,9 +162,9 @@ def __eq__(self, other): return jnp.array(True) elif isinstance(other, Transformation): return ( - (self._original == other._original) + jnp.array(self._original == other._original) & jnp.array_equal(self._jac, other._jac) - & (self._offset == other._params["offset"]) + & jnp.array(self._offset == other._params["offset"]) & jnp.array_equal(self._flux_ratio, other._flux_ratio) & jnp.array(self._gsparams == other._gsparams) & jnp.array(self._propagate_gsparams == other._propagate_gsparams) diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index c939449c..35fb0e56 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -557,7 +557,7 @@ def _makeSkyImage(self, image, sky_level, color): # Each class should define the __eq__ function. Then __ne__ is obvious. def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) @implements(_galsim.wcs.UniformWCS) @@ -599,12 +599,16 @@ def _makeSkyImage(self, image, sky_level, color): # Just check if the locals match and if the origins match. def __eq__(self, other): - return self is other or ( - isinstance(other, self.__class__) - and self._local_wcs == other._local_wcs - and self.origin == other.origin - and self.world_origin == other.world_origin - ) + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + return ( + jnp.array(self._local_wcs == other._local_wcs) + & jnp.array(self.origin == other.origin) + & jnp.array(self.world_origin == other.world_origin) + ) + else: + return jnp.array(False) @implements(_galsim.wcs.LocalWCS) @@ -828,7 +832,7 @@ def _radecToxy(self, ra, dec, units, color): # Each class should define the __eq__ function. Then __ne__ is obvious. def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) ######################################################################################### @@ -953,9 +957,12 @@ def copy(self): return PixelScale(self._scale) def __eq__(self, other): - return self is other or ( - isinstance(other, PixelScale) and self.scale == other.scale - ) + if self is other: + return jnp.array(True) + elif isinstance(other, PixelScale): + return jnp.array_equal(self.scale, other.scale) + else: + return jnp.array(False) def __repr__(self): return "galsim.PixelScale(%r)" % (ensure_hashable(self.scale),) @@ -1076,11 +1083,14 @@ def copy(self): return ShearWCS(self._scale, self._shear) def __eq__(self, other): - return self is other or ( - isinstance(other, ShearWCS) - and self.scale == other.scale - and self.shear == other.shear - ) + if self is other: + return jnp.array(True) + elif isinstance(other, ShearWCS): + return jnp.array(self.scale == other.scale) & jnp.array( + self.shear == other.shear + ) + else: + return jnp.array(False) def __repr__(self): return "galsim.ShearWCS(%r, %r)" % (ensure_hashable(self.scale), self.shear) @@ -1282,6 +1292,17 @@ def copy(self): return JacobianWCS(self.dudx, self.dudy, self.dvdx, self.dvdy) def __eq__(self, other): + if self is other: + return jnp.array(True) + elif isinstance(other, JacobianWCS): + return ( + jnp.array_equal(self.dudx, other.dudx) + & jnp.array_equal(self.dudy, other.dudy) + & jnp.array_equal(self.dvdx, other.dvdx) + & jnp.array_equal(self.dvdy, other.dvdy) + ) + else: + return jnp.array(False) return self is other or ( isinstance(other, JacobianWCS) and self.dudx == other.dudx From 1bf4a5dc46b86ccf203714645d12c9173a8d8769 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 14:49:55 -0500 Subject: [PATCH 39/72] fix: bool conversion in pos comp --- jax_galsim/gsobject.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 9f3e8aa0..2fcc94d2 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1057,12 +1057,12 @@ def drawKImage( ) # Can't both recenter a provided image and add to it. - if recenter and image.center != PositionI(0, 0) and add_to_image: - raise _galsim.GalSimIncompatibleValuesError( + if recenter and add_to_image: + zp = PositionI(0, 0) + equinox.error_if( + zp.x, + image.center != zp, "Cannot use add_to_image=True unless image is centered at (0,0) or recenter=False", - recenter=recenter, - image=image, - add_to_image=add_to_image, ) # Set the center to 0,0 if appropriate From 470309951200e52c770814c5d07c5ba9e265c994 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 14:51:55 -0500 Subject: [PATCH 40/72] fix: bool conversion in pos comp --- jax_galsim/gsobject.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 2fcc94d2..641bbdf3 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1059,8 +1059,8 @@ def drawKImage( # Can't both recenter a provided image and add to it. if recenter and add_to_image: zp = PositionI(0, 0) - equinox.error_if( - zp.x, + zp.x = equinox.error_if( + jnp.array(zp.x, dtype=int), image.center != zp, "Cannot use add_to_image=True unless image is centered at (0,0) or recenter=False", ) From de2bb7276f9194615cd987d6a3e344a5b59fe872 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 14:54:45 -0500 Subject: [PATCH 41/72] test: update to latest submodule --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index e102c876..63d576d1 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit e102c876b36c5cb1f1b8e9ab3d17cf6d22727803 +Subproject commit 63d576d1ffe836e965a1c6b547e127e5f457cbb9 From 29df9c38369ca6a0a64d8cc59702a5d6238cdff9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 15:12:30 -0500 Subject: [PATCH 42/72] doc: add docs --- docs/sharp-bits.rst | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index ff28546e..49c7fe6a 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -61,8 +61,21 @@ does not affect the original. # JAX-GalSim — real_part is a copy real_part = complex_image.real # independent array -Scalar Types, Array Types, and Casting --------------------------------------- +Fixed Array Shapes in JAX Function Transformations +-------------------------------------------------- + +JAX function transformations (e.g., ``jax.jit``, ``jax.vmap``, etc.) require statically known +array shapes in order to support tracing. To support this, the JAX-GalSim ``BoundsI`` class must +have a statically known shape. Further this class can be instantiated via the syntax +``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)`` where ``deltax/y`` are the statically defined +shape. ``BoundsI`` classes may have dynamically set ``x/ymin`` values. However, in this case the ``&`` +and ``+`` operations, which can change the shape of the ``BoundsI`` instance are not allowed in +JAX-traced code. ``BoundsI`` instances have a special method ``isStatic()`` which returns ``True`` +if the object was instantiated with statically know ``x/ymin`` values. A static ``BoundsI`` class +cannot be converted to a dynamic one via assignment and an attempt to do so will raise an exception. + +Scalar Types, Array Types, and Type Casting +------------------------------------------- With the use of JAX, there are now many possible types for numeric data. These include @@ -89,6 +102,23 @@ These rules allow JAX-GalSim to transparently handle JAX's tracing operations, b the code raising generic ``Exception`` instances instead of more specific ``GalSim`` exceptions in some cases. +Object Comparison with ``==`` +----------------------------- + +In JAX-GalSim, all objects which define arrays to be traced by JAX will return JAX boolean +array scalars (i.e., ``jax.numpy.array(True)`` or ``jax.numpy.array(False)``) as the result +of the ``==`` operator. Important cases of this rule are static ``BoundsI`` objects and +``Interpolant`` objects (and their subclasses), which return Python boolean values (i.e. +``True`` and ``False``). These difference can be a source of subtle bugs since the negation +of JAX array boolean values is typically done with ``~``, while for Python boolean values it is +done with ``not``. Mixing these two forms can cause unexpected and incorrect results since + +.. code-block:: python + + >>> ~True is False + :1: SyntaxWarning: "is" with 'int' literal. Did you mean "=="? + False + Random Number Generation ------------------------ From 2d53ab0058edfc564ed42bc064cefaaef2334f23 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 15:18:26 -0500 Subject: [PATCH 43/72] doc: be a bit more specific --- docs/sharp-bits.rst | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index 49c7fe6a..71e0c50e 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -102,16 +102,17 @@ These rules allow JAX-GalSim to transparently handle JAX's tracing operations, b the code raising generic ``Exception`` instances instead of more specific ``GalSim`` exceptions in some cases. -Object Comparison with ``==`` ------------------------------ +Object Comparison with the ``==`` Operator +------------------------------------------ In JAX-GalSim, all objects which define arrays to be traced by JAX will return JAX boolean array scalars (i.e., ``jax.numpy.array(True)`` or ``jax.numpy.array(False)``) as the result -of the ``==`` operator. Important cases of this rule are static ``BoundsI`` objects and -``Interpolant`` objects (and their subclasses), which return Python boolean values (i.e. -``True`` and ``False``). These difference can be a source of subtle bugs since the negation -of JAX array boolean values is typically done with ``~``, while for Python boolean values it is -done with ``not``. Mixing these two forms can cause unexpected and incorrect results since +of the ``==`` operator, otherwise they return Python boolean values. Important cases of this +rule are static ``BoundsI`` objects and ``Interpolant`` objects (and their subclasses), which +return Python boolean values (i.e. ``True`` and ``False``). These difference can be a source +of subtle bugs since the negation of JAX array boolean values is typically done with ``~``, +while for Python boolean values it is done with ``not``. Mixing these two forms can cause +unexpected and incorrect results since .. code-block:: python From 1c94cf3ee9540cc09626b31788d14bdbdfcafca4 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 15:23:00 -0500 Subject: [PATCH 44/72] fix: dead code --- jax_galsim/wcs.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index 35fb0e56..43498c77 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -1303,13 +1303,6 @@ def __eq__(self, other): ) else: return jnp.array(False) - return self is other or ( - isinstance(other, JacobianWCS) - and self.dudx == other.dudx - and self.dudy == other.dudy - and self.dvdx == other.dvdx - and self.dvdy == other.dvdy - ) def __repr__(self): return "galsim.JacobianWCS(%r, %r, %r, %r)" % ( From 95f0ae9f8a0d6058c0f58d83526d0092d56695d6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 16:35:54 -0500 Subject: [PATCH 45/72] test: add test of bool eq api --- jax_galsim/noise.py | 53 ++++++++++++++++++++++++++++++++++++++++--- tests/jax/test_api.py | 28 +++++++++++++++++++++++ 2 files changed, 78 insertions(+), 3 deletions(-) diff --git a/jax_galsim/noise.py b/jax_galsim/noise.py index 3956cf6a..3daf04c1 100644 --- a/jax_galsim/noise.py +++ b/jax_galsim/noise.py @@ -106,11 +106,15 @@ def _applyTo(self, image): raise NotImplementedError("Cannot call applyTo on a pure BaseNoise object") def __eq__(self, other): - # Quick and dirty. Just check reprs are equal. - return self is other or repr(self) == repr(other) + if self is other: + return jnp.array(True) + elif isinstance(other, BaseNoise): + return jnp.array(self._rng == other._rng) + else: + return jnp.array(False) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) __hash__ = None @@ -173,6 +177,16 @@ def __repr__(self): def __str__(self): return "galsim.GaussianNoise(sigma=%s)" % (ensure_hashable(self.sigma),) + def __eq__(self, other): + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + return jnp.array(self._rng == other._rng) & jnp.array_equal( + self._sigma, other._sigma + ) + else: + return jnp.array(False) + def tree_flatten(self): """This function flattens the GaussianNoise into a list of children nodes that will be traced by JAX and auxiliary static data.""" @@ -265,6 +279,16 @@ def __repr__(self): def __str__(self): return "galsim.PoissonNoise(sky_level=%s)" % (self.sky_level) + def __eq__(self, other): + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + return jnp.array(self._rng == other._rng) & jnp.array_equal( + self._sky_level, other._sky_level + ) + else: + return jnp.array(False) + def tree_flatten(self): """This function flattens the PoissonNoise into a list of children nodes that will be traced by JAX and auxiliary static data.""" @@ -429,6 +453,19 @@ def __str__(self): self.read_noise, ) + def __eq__(self, other): + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + return ( + jnp.array(self._rng == other._rng) + & jnp.array_equal(self._sky_level, other._sky_level) + & jnp.array_equal(self._gain, other._gain) + & jnp.array_equal(self._read_noise, other._read_noise) + ) + else: + return jnp.array(False) + def tree_flatten(self): """This function flattens the CCDNoise into a list of children nodes that will be traced by JAX and auxiliary static data.""" @@ -570,6 +607,16 @@ def __repr__(self): def __str__(self): return "galsim.VariableGaussianNoise(var_image=%s)" % (self.var_image) + def __eq__(self, other): + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + return jnp.array(self._rng == other._rng) & jnp.array( + self._var_image == other._var_image + ) + else: + return jnp.array(False) + def tree_flatten(self): """This function flattens the VariableGaussianNoise into a list of children nodes that will be traced by JAX and auxiliary static data.""" diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index bf765f50..c710097f 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -126,6 +126,16 @@ def _run_object_checks(obj, cls, kind): # check that we can hash the object hash(obj) + + # check that val jax array + if ( + hasattr(obj, "isStatic") + and obj.isStatic() + or isinstance(obj, jax_galsim.Sensor) + ): + assert isinstance(eval(repr(obj)) == obj, bool) + else: + assert isinstance(eval(repr(obj)) == obj, jnp.ndarray) elif kind == "to-from-galsim": gs_obj = obj.to_galsim() jgs_obj = obj.from_galsim(gs_obj) @@ -141,6 +151,14 @@ def _run_object_checks(obj, cls, kind): # check that we cannot hash the object assert obj.__hash__ is None + + # check that val jax array + if (hasattr(obj, "isStatic") and obj.isStatic()) or isinstance( + obj, jax_galsim.Sensor + ): + assert isinstance(eval(repr(obj)) == obj, bool) + else: + assert isinstance(eval(repr(obj)) == obj, jnp.ndarray) elif kind == "pickle-eval-repr-wcs": import jax_galsim as galsim # noqa: F401 @@ -152,6 +170,16 @@ def _run_object_checks(obj, cls, kind): # check that we cannot hash the object hash(obj) + + # check that val jax array + if ( + hasattr(obj, "isStatic") + and obj.isStatic() + or isinstance(obj, jax_galsim.Sensor) + ): + assert isinstance(eval(repr(obj)) == obj, bool) + else: + assert isinstance(eval(repr(obj)) == obj, jnp.ndarray) elif kind == "jax-compatible": # JAX tracing should be an identity assert cls.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj From 0d07661806e7178ddcc2f6e85b66a2747e7f9f99 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 16:37:18 -0500 Subject: [PATCH 46/72] test: add test of bool eq api --- tests/jax/test_api.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index c710097f..7071cdb4 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -1146,3 +1146,9 @@ def test_api_gsparams(): assert getattr(jgsp, k) == v assert getattr(gsp, k) == v assert getattr(jjgsp, k) == v + + assert jgsp == jjgsp + assert isinstance(jgsp == jjgsp, bool) + + kwargs["minimum_fft_size"] = 126 + assert jgsp != jax_galsim.GSParams(**kwargs) From dea69152a59b3d892fc71d7cdc9b4f34a74d0903 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 16:59:09 -0500 Subject: [PATCH 47/72] test: add tests of bounds and vmap --- tests/jax/test_bounds_jax.py | 95 ++++++++++++++++++++++++++++++++++-- 1 file changed, 91 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_bounds_jax.py b/tests/jax/test_bounds_jax.py index 19e318d5..44a0ac6d 100644 --- a/tests/jax/test_bounds_jax.py +++ b/tests/jax/test_bounds_jax.py @@ -8,8 +8,8 @@ @jax.vmap @jax.jit def _make_bounds_float(xmin, ymin, xmax, ymax): - bds = jax_galsim.BoundsD(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax) - return bds, bds.isDefined() + bnds = jax_galsim.BoundsD(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax) + return bnds, bnds.isDefined() def test_bounds_jax_vmap_isdefined_float(): @@ -17,5 +17,92 @@ def test_bounds_jax_vmap_isdefined_float(): xmax = jnp.array([12, 11, 10, 9]) ymin = jnp.array([9, 11, 10, 12]) ymax = jnp.array([10, 10, 10, 10]) - bds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) - np.testing.assert_array_equal(bds.isDefined(), isdef, strict=True) + bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + + +@jax.vmap +@jax.jit +def _and_bounds_empty_float(bnds): + bnds = bnds & jax_galsim.BoundsD() + return bnds, bnds.isDefined() + + +@jax.vmap +@jax.jit +def _and_bounds_float(bnds): + bnds = bnds & jax_galsim.BoundsD(xmin=10, xmax=11, ymin=10, ymax=11) + return bnds, bnds.isDefined() + + +@jax.vmap +@jax.jit +def _and_bounds_far_away_float(bnds): + bnds = bnds & jax_galsim.BoundsD(xmin=100, xmax=110, ymin=100, ymax=110) + return bnds, bnds.isDefined() + + +def test_bounds_jax_vmap_and_isdefined_float(): + xmin = jnp.array([9, 10, 11, 12]) + xmax = jnp.array([12, 11, 10, 9]) + ymin = jnp.array([9, 11, 10, 12]) + ymax = jnp.array([10, 10, 10, 10]) + + bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + bnds, isdef = _and_bounds_empty_float(bnds) + assert bnds.isDefined().shape == (4,) + assert not jnp.any(bnds.isDefined()) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + np.testing.assert_array_equal(bnds.isDefined(), False) + + bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + bnds, isdef = _and_bounds_float(bnds) + assert bnds.isDefined().shape == (4,) + np.testing.assert_array_equal( + bnds.isDefined(), jnp.array([True, False, False, False]) + ) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + assert bnds.xmin[0] == 10 + assert bnds.xmax[0] == 11 + assert bnds.ymin[0] == 10 + assert bnds.ymax[0] == 10 + + bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + bnds, isdef = _and_bounds_far_away_float(bnds) + assert bnds.isDefined().shape == (4,) + assert not jnp.any(bnds.isDefined()) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + np.testing.assert_array_equal(bnds.isDefined(), False) + + +@jax.vmap +@jax.jit +def _plus_bounds_far_away_float(bnds): + bnds = bnds + jax_galsim.BoundsD(xmin=100, xmax=110, ymin=100, ymax=110) + return bnds, bnds.isDefined() + + +def test_bounds_jax_vmap_plus_float(): + xmin = jnp.array([9, 10, 11, 12]) + xmax = jnp.array([12, 11, 10, 9]) + ymin = jnp.array([9, 11, 10, 12]) + ymax = jnp.array([10, 10, 10, 10]) + + bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + bnds, isdef = _plus_bounds_far_away_float(bnds) + assert bnds.isDefined().shape == (4,) + np.testing.assert_array_equal(bnds.isDefined(), True) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + assert bnds.xmin[0] == 9 + assert bnds.xmax[0] == 110 + assert bnds.ymin[0] == 9 + assert bnds.ymax[0] == 110 + + np.testing.assert_array_equal(bnds.xmin[1:], 100) + np.testing.assert_array_equal(bnds.xmax[1:], 110) + np.testing.assert_array_equal(bnds.ymin[1:], 100) + np.testing.assert_array_equal(bnds.ymax[1:], 110) From 1671b371a35fe57bd57a21e546b583c4ffac24a0 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 17:03:03 -0500 Subject: [PATCH 48/72] test: add tests of bounds and vmap --- tests/jax/test_bounds_jax.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/jax/test_bounds_jax.py b/tests/jax/test_bounds_jax.py index 44a0ac6d..5b21486c 100644 --- a/tests/jax/test_bounds_jax.py +++ b/tests/jax/test_bounds_jax.py @@ -85,6 +85,13 @@ def _plus_bounds_far_away_float(bnds): return bnds, bnds.isDefined() +@jax.vmap +@jax.jit +def _plus_bounds_pos_far_away_float(bnds): + bnds = bnds + jax_galsim.PositionD(x=100, y=110) + return bnds, bnds.isDefined() + + def test_bounds_jax_vmap_plus_float(): xmin = jnp.array([9, 10, 11, 12]) xmax = jnp.array([12, 11, 10, 9]) @@ -106,3 +113,19 @@ def test_bounds_jax_vmap_plus_float(): np.testing.assert_array_equal(bnds.xmax[1:], 110) np.testing.assert_array_equal(bnds.ymin[1:], 100) np.testing.assert_array_equal(bnds.ymax[1:], 110) + + bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + bnds, isdef = _plus_bounds_pos_far_away_float(bnds) + assert bnds.isDefined().shape == (4,) + np.testing.assert_array_equal(bnds.isDefined(), True) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + assert bnds.xmin[0] == 9 + assert bnds.xmax[0] == 100 + assert bnds.ymin[0] == 9 + assert bnds.ymax[0] == 110 + + np.testing.assert_array_equal(bnds.xmin[1:], 100) + np.testing.assert_array_equal(bnds.xmax[1:], 100) + np.testing.assert_array_equal(bnds.ymin[1:], 110) + np.testing.assert_array_equal(bnds.ymax[1:], 110) From 946958ce058c9b4e25076aeaf872de9e9fd857b6 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 22 May 2026 05:01:56 -0500 Subject: [PATCH 49/72] Apply suggestion from @beckermr --- tests/jax/test_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 7071cdb4..ddbf4434 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -173,8 +173,8 @@ def _run_object_checks(obj, cls, kind): # check that val jax array if ( - hasattr(obj, "isStatic") - and obj.isStatic() + (hasattr(obj, "isStatic") + and obj.isStatic()) or isinstance(obj, jax_galsim.Sensor) ): assert isinstance(eval(repr(obj)) == obj, bool) From 860a419446d3253650fcaaf3e7cdabbcbc08b746 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 22 May 2026 05:02:22 -0500 Subject: [PATCH 50/72] test: more tests for bounds and vmap --- tests/jax/test_api.py | 6 +-- tests/jax/test_bounds_jax.py | 80 ++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index ddbf4434..a0744209 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -172,10 +172,8 @@ def _run_object_checks(obj, cls, kind): hash(obj) # check that val jax array - if ( - (hasattr(obj, "isStatic") - and obj.isStatic()) - or isinstance(obj, jax_galsim.Sensor) + if (hasattr(obj, "isStatic") and obj.isStatic()) or isinstance( + obj, jax_galsim.Sensor ): assert isinstance(eval(repr(obj)) == obj, bool) else: diff --git a/tests/jax/test_bounds_jax.py b/tests/jax/test_bounds_jax.py index 5b21486c..652eb91b 100644 --- a/tests/jax/test_bounds_jax.py +++ b/tests/jax/test_bounds_jax.py @@ -1,6 +1,7 @@ import jax import jax.numpy as jnp import numpy as np +import pytest import jax_galsim @@ -129,3 +130,82 @@ def test_bounds_jax_vmap_plus_float(): np.testing.assert_array_equal(bnds.xmax[1:], 100) np.testing.assert_array_equal(bnds.ymin[1:], 110) np.testing.assert_array_equal(bnds.ymax[1:], 110) + + +@jax.vmap +@jax.jit +def _make_bounds_int(xmin, ymin): + bnds = jax_galsim.BoundsI(xmin=xmin, ymin=ymin, deltax=10, deltay=10) + return bnds, bnds.isDefined() + + +def test_bounds_jax_vmap_isdefined_int(): + xmin = jnp.array([9, 10, 11, 12]) + ymin = jnp.array([9, 11, 10, 12]) + bnds, isdef = _make_bounds_int(xmin, ymin) + np.testing.assert_array_equal(bnds.isDefined(), isdef[0], strict=True) + np.testing.assert_array_equal(bnds.isDefined(), True) + assert jnp.all(isdef) + + +@jax.vmap +@jax.jit +def _make_bounds_int_bad(xmin, ymin, delta): + bnds = jax_galsim.BoundsI(xmin=xmin, ymin=ymin, deltax=delta, deltay=delta) + return bnds, bnds.isDefined() + + +def test_bounds_jax_vmap_varying_shape_raises_int(): + xmin = jnp.array([9, 10, 11, 12]) + ymin = jnp.array([9, 11, 10, 12]) + delta = jnp.array([9, 11, 10, 12]) + with pytest.raises(Exception): + _make_bounds_int_bad(xmin, ymin, delta) + + +@jax.vmap +@jax.jit +def _and_bounds_empty_int(bnds): + bnds = bnds & jax_galsim.BoundsI() + return bnds, bnds.isDefined() + + +def test_bounds_jax_vmap_and_raises_isdefined_int(): + xmin = jnp.array([9, 10, 11, 12]) + ymin = jnp.array([9, 11, 10, 12]) + bnds, isdef = _make_bounds_int(xmin, ymin) + np.testing.assert_array_equal(bnds.isDefined(), isdef[0], strict=True) + np.testing.assert_array_equal(bnds.isDefined(), True) + assert jnp.all(isdef) + + with pytest.raises(Exception): + _and_bounds_empty_int(bnds) + + +@jax.vmap +@jax.jit +def _plus_bounds_far_away_int(bnds): + bnds = bnds + jax_galsim.BoundsI(xmin=100, deltax=110, ymin=100, deltay=110) + return bnds, bnds.isDefined() + + +@jax.vmap +@jax.jit +def _plus_bounds_pos_far_away_int(bnds): + bnds = bnds + jax_galsim.PositionI(x=100, y=110) + return bnds, bnds.isDefined() + + +def test_bounds_jax_vmap_plus_raises_int(): + xmin = jnp.array([9, 10, 11, 12]) + ymin = jnp.array([9, 11, 10, 12]) + bnds, isdef = _make_bounds_int(xmin, ymin) + np.testing.assert_array_equal(bnds.isDefined(), isdef[0], strict=True) + np.testing.assert_array_equal(bnds.isDefined(), True) + assert jnp.all(isdef) + + with pytest.raises(Exception): + _plus_bounds_far_away_int(bnds) + + with pytest.raises(Exception): + _plus_bounds_pos_far_away_float(bnds) From f312d871215b888e01c192e9f6cc5e0aa5a9ae96 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 22 May 2026 05:04:45 -0500 Subject: [PATCH 51/72] Apply suggestion from @beckermr --- tests/jax/test_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index a0744209..122fb6e5 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -129,8 +129,8 @@ def _run_object_checks(obj, cls, kind): # check that val jax array if ( - hasattr(obj, "isStatic") - and obj.isStatic() + (hasattr(obj, "isStatic") + and obj.isStatic()) or isinstance(obj, jax_galsim.Sensor) ): assert isinstance(eval(repr(obj)) == obj, bool) From 8e7280fcde90feb9e5b9d2b6548c8d5610effd1e Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 22 May 2026 05:05:15 -0500 Subject: [PATCH 52/72] test: ensure we have proper logic here --- tests/jax/test_api.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 122fb6e5..1aea9f11 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -128,10 +128,8 @@ def _run_object_checks(obj, cls, kind): hash(obj) # check that val jax array - if ( - (hasattr(obj, "isStatic") - and obj.isStatic()) - or isinstance(obj, jax_galsim.Sensor) + if (hasattr(obj, "isStatic") and obj.isStatic()) or isinstance( + obj, jax_galsim.Sensor ): assert isinstance(eval(repr(obj)) == obj, bool) else: From a9ef07fda0555c3f88257f8d4aa0867498b105a0 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 22 May 2026 05:06:54 -0500 Subject: [PATCH 53/72] Apply suggestions from code review Co-authored-by: Matthew R. Becker --- docs/sharp-bits.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index 71e0c50e..d8bc1fa0 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -107,8 +107,8 @@ Object Comparison with the ``==`` Operator In JAX-GalSim, all objects which define arrays to be traced by JAX will return JAX boolean array scalars (i.e., ``jax.numpy.array(True)`` or ``jax.numpy.array(False)``) as the result -of the ``==`` operator, otherwise they return Python boolean values. Important cases of this -rule are static ``BoundsI`` objects and ``Interpolant`` objects (and their subclasses), which +of the ``==`` operator. Otherwise the return value is a Python boolean. Important cases of this +rule are static ``BoundsI`` objects, ``Interpolant`` objects (and their subclasses), and ``GSParams`` objects, all of which return Python boolean values (i.e. ``True`` and ``False``). These difference can be a source of subtle bugs since the negation of JAX array boolean values is typically done with ``~``, while for Python boolean values it is done with ``not``. Mixing these two forms can cause From 401e5403c29c9a7ab7fc289b04d6e3e62b63d9a2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 22 May 2026 05:07:58 -0500 Subject: [PATCH 54/72] doc: clarify --- docs/sharp-bits.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index d8bc1fa0..0b1297f0 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -108,11 +108,11 @@ Object Comparison with the ``==`` Operator In JAX-GalSim, all objects which define arrays to be traced by JAX will return JAX boolean array scalars (i.e., ``jax.numpy.array(True)`` or ``jax.numpy.array(False)``) as the result of the ``==`` operator. Otherwise the return value is a Python boolean. Important cases of this -rule are static ``BoundsI`` objects, ``Interpolant`` objects (and their subclasses), and ``GSParams`` objects, all of which -return Python boolean values (i.e. ``True`` and ``False``). These difference can be a source -of subtle bugs since the negation of JAX array boolean values is typically done with ``~``, -while for Python boolean values it is done with ``not``. Mixing these two forms can cause -unexpected and incorrect results since +rule are static ``BoundsI`` objects, ``Interpolant`` objects (and their subclasses), and ``GSParams`` +objects, all of which return Python boolean values (i.e. ``True`` and ``False``). These difference +can be a source of subtle bugs since the negation of JAX array boolean values is typically done +with ``~``, while for Python boolean values it is done with ``not``. Mixing these two forms can +cause unexpected and incorrect results since .. code-block:: python From 0a1ddd15ab10123ee83c3545bd90be567da16846 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 22 May 2026 05:13:30 -0500 Subject: [PATCH 55/72] test: more tests for bnds int and vmap --- tests/jax/test_bounds_jax.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_bounds_jax.py b/tests/jax/test_bounds_jax.py index 652eb91b..64efecd2 100644 --- a/tests/jax/test_bounds_jax.py +++ b/tests/jax/test_bounds_jax.py @@ -135,7 +135,7 @@ def test_bounds_jax_vmap_plus_float(): @jax.vmap @jax.jit def _make_bounds_int(xmin, ymin): - bnds = jax_galsim.BoundsI(xmin=xmin, ymin=ymin, deltax=10, deltay=10) + bnds = jax_galsim.BoundsI(xmin=xmin, ymin=ymin, deltax=10, deltay=11) return bnds, bnds.isDefined() @@ -146,6 +146,12 @@ def test_bounds_jax_vmap_isdefined_int(): np.testing.assert_array_equal(bnds.isDefined(), isdef[0], strict=True) np.testing.assert_array_equal(bnds.isDefined(), True) assert jnp.all(isdef) + np.testing.assert_array_equal(bnds.xmin, xmin, strict=True) + np.testing.assert_array_equal(bnds.ymin, ymin, strict=True) + assert isinstance(bnds.deltax, int) + assert bnds.deltax == 10 + assert isinstance(bnds.deltay, int) + assert bnds.deltay == 11 @jax.vmap From aa8d58798d4526ff048f63e302646fc415874f1c Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 22 May 2026 05:35:01 -0500 Subject: [PATCH 56/72] test: add tests for type conversion --- jax_galsim/bounds.py | 90 ++++++++++++++++++++++++------------ tests/jax/test_bounds_jax.py | 32 +++++++++++++ 2 files changed, 93 insertions(+), 29 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index d6112b04..fc25ef5e 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -24,9 +24,11 @@ for function transformations like ``jax.vmap``, ``jax.jit``, etc. - Upon initialization, if a ``BoundsI`` object has a non-static shape, JAX-GalSim will attempt to convert it to a static shape by extracting the dimensions from the array via ``.item()``. This operation will - cause JAX to raise an error if the code is being traced. -- If a ``BoundsI`` object is declared with static ``xmin`` and ``ymin`` values, an error will be raised - if one attempts to convert those values to non-static values. + cause JAX to raise an error if the code is being traced. JAX-Galsim performs the same conversion operation + when the ``deltax`` or ``deltay`` properties are set to non-static values via assignment. +- If a ``BoundsI`` object is declared with static ``xmin`` and ``ymin`` values, and then one attempts to + convert them to non-static values via assignment, JAX-GalSim will attempt to convert the assigned values + back to static values. This operation will raise an error if the code is being traced. - ``Bounds`` classes in JAX-GalSim have an etxra method, ``isStatic`` that returns ``True`` if the object was instantiated with static ``xmin`` and ``ymin`` values. This method always returns ``False`` for ``BoundsD`` objects. @@ -768,31 +770,46 @@ def xmin(self): def xmin(self, value): value = check_is_int_then_cast(value, "BoundsI xmin values must be integers") if self._isstatic: - if self._dotypechecking and not isinstance(value, int): - raise RuntimeError( - "Static `BoundsI` classes cannot be converted to dynamic ones." - ) + if self._dotypechecking: + # attempt to convert widths to static values + # this will raise if values are being traced + # we let that error propagate instead of reraising + # our own. + with jax.ensure_compile_time_eval(): + if not isinstance(value, int): + value = int(value.item()) self._xmin = value else: self._xmin = jnp.astype(value, float) @property - def xmax(self): - return cast_to_int(self.xmin + self.deltax - 1) + def deltax(self): + return self._deltax - @xmax.setter - def xmax(self, value): - self.deltax = value - self.xmin + 1 - self.deltax = check_is_int_then_cast( - self.deltax, "BoundsI xmax must be set to an integer value" + @deltax.setter + def deltax(self, value): + value = check_is_int_then_cast( + value, "BoundsI deltax must be set to an integer value" ) # attempt to convert widths to static values # this will raise if values are being traced # we let that error propagate instead of reraising # our own. with jax.ensure_compile_time_eval(): - if not isinstance(self.deltax, int): - self.deltax = int(self.deltax.item()) + if not isinstance(value, int): + value = int(value.item()) + self._deltax = value + + @property + def xmax(self): + return cast_to_int(self.xmin + self.deltax - 1) + + @xmax.setter + def xmax(self, value): + value = check_is_int_then_cast( + value, "BoundsI xmax must be set to an integer value" + ) + self.deltax = value - self.xmin + 1 # we store ymin internally as a float even though it is an int # so that autodiff works properly (needs floats in general) @@ -807,31 +824,46 @@ def ymin(self): def ymin(self, value): value = check_is_int_then_cast(value, "BoundsI ymin values must be integers") if self._isstatic: - if self._dotypechecking and not isinstance(value, int): - raise RuntimeError( - "Static `BoundsI` classes cannot be converted to dynamic ones." - ) + if self._dotypechecking: + # attempt to convert widths to static values + # this will raise if values are being traced + # we let that error propagate instead of reraising + # our own. + with jax.ensure_compile_time_eval(): + if not isinstance(value, int): + value = int(value.item()) self._ymin = value else: self._ymin = jnp.astype(value, float) @property - def ymax(self): - return cast_to_int(self.ymin + self.deltay - 1) + def deltay(self): + return self._deltay - @ymax.setter - def ymax(self, value): - self.deltay = value - self.ymin + 1 - self.deltay = check_is_int_then_cast( - self.deltay, "BoundsI ymax must be set to an integer value" + @deltay.setter + def deltay(self, value): + value = check_is_int_then_cast( + value, "BoundsI deltay must be set to an integer value" ) # attempt to convert widths to static values # this will raise if values are being traced # we let that error propagate instead of reraising # our own. with jax.ensure_compile_time_eval(): - if not isinstance(self.deltay, int): - self.deltay = int(self.deltay.item()) + if not isinstance(value, int): + value = int(value.item()) + self._deltay = value + + @property + def ymax(self): + return cast_to_int(self.ymin + self.deltay - 1) + + @ymax.setter + def ymax(self, value): + value = check_is_int_then_cast( + value, "BoundsI ymax must be set to an integer value" + ) + self.deltay = value - self.ymin + 1 def _area(self): # Remember the + 1 this time to include the pixels on both edges of the bounds. diff --git a/tests/jax/test_bounds_jax.py b/tests/jax/test_bounds_jax.py index 64efecd2..8f1f4cac 100644 --- a/tests/jax/test_bounds_jax.py +++ b/tests/jax/test_bounds_jax.py @@ -215,3 +215,35 @@ def test_bounds_jax_vmap_plus_raises_int(): with pytest.raises(Exception): _plus_bounds_pos_far_away_float(bnds) + + +def test_bounds_jax_int_set(): + bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11) + + bnds.xmin = 11.0 + assert isinstance(bnds.xmin, int) + assert bnds.xmin == 11 + bnds.xmin = jnp.array(12, dtype=float) + assert isinstance(bnds.xmin, int) + assert bnds.xmin == 12 + + bnds.ymin = 12.0 + assert isinstance(bnds.ymin, int) + assert bnds.ymin == 12 + bnds.ymin = jnp.array(13, dtype=float) + assert isinstance(bnds.ymin, int) + assert bnds.ymin == 13 + + bnds.deltax = 11.0 + assert isinstance(bnds.deltax, int) + assert bnds.deltax == 11 + bnds.deltax = jnp.array(12, dtype=float) + assert isinstance(bnds.deltax, int) + assert bnds.deltax == 12 + + bnds.deltay = 12.0 + assert isinstance(bnds.deltay, int) + assert bnds.deltay == 12 + bnds.deltay = jnp.array(13, dtype=float) + assert isinstance(bnds.deltay, int) + assert bnds.deltay == 13 From 4922f21bd0bc977bcac879b06ea684d26ebe482e Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 22 May 2026 05:46:20 -0500 Subject: [PATCH 57/72] test: more tests for bounds set --- tests/jax/test_bounds_jax.py | 111 ++++++++++++++++++++++++++++++++++- 1 file changed, 110 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_bounds_jax.py b/tests/jax/test_bounds_jax.py index 8f1f4cac..5b5d05a6 100644 --- a/tests/jax/test_bounds_jax.py +++ b/tests/jax/test_bounds_jax.py @@ -1,3 +1,5 @@ +from functools import partial + import jax import jax.numpy as jnp import numpy as np @@ -217,7 +219,7 @@ def test_bounds_jax_vmap_plus_raises_int(): _plus_bounds_pos_far_away_float(bnds) -def test_bounds_jax_int_set(): +def test_bounds_jax_int_set_static(): bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11) bnds.xmin = 11.0 @@ -247,3 +249,110 @@ def test_bounds_jax_int_set(): bnds.deltay = jnp.array(13, dtype=float) assert isinstance(bnds.deltay, int) assert bnds.deltay == 13 + + +def test_bounds_jax_int_set_dynamic(): + bnds = jax_galsim.BoundsI( + xmin=jnp.array(1), ymin=jnp.array(2), deltax=10, deltay=11 + ) + + bnds.xmin = 11.0 + assert isinstance(bnds.xmin, jnp.ndarray) + assert bnds.xmin == 11 + bnds.xmin = jnp.array(12, dtype=float) + assert isinstance(bnds.xmin, jnp.ndarray) + assert bnds.xmin == 12 + + bnds.ymin = 12.0 + assert isinstance(bnds.ymin, jnp.ndarray) + assert bnds.ymin == 12 + bnds.ymin = jnp.array(13, dtype=float) + assert isinstance(bnds.ymin, jnp.ndarray) + assert bnds.ymin == 13 + + bnds.deltax = 11.0 + assert isinstance(bnds.deltax, int) + assert bnds.deltax == 11 + bnds.deltax = jnp.array(12, dtype=float) + assert isinstance(bnds.deltax, int) + assert bnds.deltax == 12 + + bnds.deltay = 12.0 + assert isinstance(bnds.deltay, int) + assert bnds.deltay == 12 + bnds.deltay = jnp.array(13, dtype=float) + assert isinstance(bnds.deltay, int) + assert bnds.deltay == 13 + + +def test_bounds_jax_int_set_jit_raises(): + @jax.jit + def _make_bnds_bad_xmin(xmin): + bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11) + bnds.xmin = xmin + return bnds + + @jax.jit + def _make_bnds_bad_ymin(ymin): + bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11) + bnds.ymin = ymin + return bnds + + with pytest.raises(Exception): + _make_bnds_bad_xmin(2) + + with pytest.raises(Exception): + _make_bnds_bad_ymin(2) + + @jax.jit + def _make_bnds_bad_deltax(deltax): + bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11) + bnds.deltax = deltax + return bnds + + @jax.jit + def _make_bnds_bad_deltay(deltay): + bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11) + bnds.deltay = deltay + return bnds + + with pytest.raises(Exception): + _make_bnds_bad_deltay(2) + + with pytest.raises(Exception): + _make_bnds_bad_deltay(2) + + +def test_bounds_jax_int_set_jit(): + @jax.jit + def _make_bnds_set_xmin(xmin): + bnds = jax_galsim.BoundsI(xmin=jnp.array(1), ymin=1, deltax=10, deltay=11) + bnds.xmin = xmin + return bnds + + @jax.jit + def _make_bnds_set_ymin(ymin): + bnds = jax_galsim.BoundsI(xmin=jnp.array(1), ymin=1, deltax=10, deltay=11) + bnds.ymin = ymin + return bnds + + bnds = _make_bnds_set_xmin(2) + assert isinstance(bnds.ymin, jnp.ndarray) + assert bnds.xmin == 2 + assert isinstance(bnds.xmin, jnp.ndarray) + + bnds = _make_bnds_set_ymin(2) + assert isinstance(bnds.ymin, jnp.ndarray) + assert bnds.ymin == 2 + assert isinstance(bnds.ymin, jnp.ndarray) + + @partial(jax.jit, static_argnames=("xmin",)) + def _make_bnds_set_xmin_static(xmin): + bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11) + bnds.xmin = xmin + return bnds + + bnds = _make_bnds_set_xmin_static(2) + assert isinstance(bnds.xmin, int) + assert isinstance(bnds.ymin, int) + assert bnds.xmin == 2 From 2d48748d2ca9b315c1c9e32831b1cf0f79832c1d Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 22 May 2026 05:48:28 -0500 Subject: [PATCH 58/72] test: fix bounds api tests --- tests/jax/test_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 1aea9f11..13833cf0 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -386,6 +386,8 @@ def _reg_fun(p): "xmin", "ymin", "isStatic", + "deltax", + "deltay", ]: continue From 64af33a70a231b66aa41c8529dcff600a26dbea4 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 22 May 2026 05:50:10 -0500 Subject: [PATCH 59/72] fix: sharpen tests --- tests/jax/test_api.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 13833cf0..40e4490b 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -379,18 +379,22 @@ def _reg_fun(p): ): continue - # jax-galsim Bounds classes do not store xmax, ymax - if issubclass(cls, jax_galsim.Bounds) and method in [ - "xmax", - "ymax", + # jax-galsim BoundsI classes do not store xmin, ymin + if issubclass(cls, jax_galsim.BoundsI) and method in [ "xmin", "ymin", - "isStatic", "deltax", "deltay", ]: continue + if issubclass(cls, jax_galsim.Bounds) and method in [ + "xmax", + "ymax", + "isStatic", + ]: + continue + assert method in dir(gscls), ( cls.__name__ + "." + method + " not in galsim." + gscls.__name__ ) From a7efe2d43d6077fff06966be3771fc0e93f85f03 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 22 May 2026 05:50:40 -0500 Subject: [PATCH 60/72] doc: add comment --- tests/jax/test_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 40e4490b..f314b784 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -388,6 +388,8 @@ def _reg_fun(p): ]: continue + # jax-galsim Bounds classes do not store xmin, ymin + # and have extra method if issubclass(cls, jax_galsim.Bounds) and method in [ "xmax", "ymax", From eb2748ae7f0c5a38f196814f6ee6477436dd8f2f Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 22 May 2026 06:49:54 -0500 Subject: [PATCH 61/72] fix: rename for clarity --- jax_galsim/bounds.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index fc25ef5e..d5c1ea85 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -700,7 +700,7 @@ def __init__(self, *args, **kwargs): # we set these variables to disable type checking and conversion # for xmin/ymin while we initialize the object self._isstatic = True - self._dotypechecking = False + self._dotypeconversion = False self._parse_args(*args, **kwargs) # validate inputs are ints @@ -734,7 +734,7 @@ def __init__(self, *args, **kwargs): self._isstatic = True else: self._isstatic = False - self._dotypechecking = True + self._dotypeconversion = True def _check_scalar(self, x, name): try: @@ -770,7 +770,7 @@ def xmin(self): def xmin(self, value): value = check_is_int_then_cast(value, "BoundsI xmin values must be integers") if self._isstatic: - if self._dotypechecking: + if self._dotypeconversion: # attempt to convert widths to static values # this will raise if values are being traced # we let that error propagate instead of reraising @@ -824,7 +824,7 @@ def ymin(self): def ymin(self, value): value = check_is_int_then_cast(value, "BoundsI ymin values must be integers") if self._isstatic: - if self._dotypechecking: + if self._dotypeconversion: # attempt to convert widths to static values # this will raise if values are being traced # we let that error propagate instead of reraising From 215b412b35138e68e6b9a77d4e63bf65ab0a8dec Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 23 May 2026 07:42:23 -0400 Subject: [PATCH 62/72] test: add tests of includes and simpler bounds init --- jax_galsim/bounds.py | 61 +++++------------ tests/jax/test_bounds_jax.py | 123 +++++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 44 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index d5c1ea85..7d74791b 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -697,43 +697,20 @@ class BoundsI(Bounds): _pos_class = PositionI def __init__(self, *args, **kwargs): - # we set these variables to disable type checking and conversion - # for xmin/ymin while we initialize the object + # we set these variables to disable array to python int + # or python int to array conversions for xmin/ymin while we + # initialize the object. + # the setter methods below validate that the inputs are ints, + # so we skip that in the init. + # the class always converts deltax/deltay to python ints and + # an error will be raised if that cannot be done. self._isstatic = True self._dotypeconversion = False self._parse_args(*args, **kwargs) - # validate inputs are ints - self.deltax = check_is_int_then_cast( - self.deltax, "BoundsI must be initialized with integer values" - ) - self.deltay = check_is_int_then_cast( - self.deltay, "BoundsI must be initialized with integer values" - ) - self.xmin = check_is_int_then_cast( - self.xmin, "BoundsI must be initialized with integer values" - ) - self.ymin = check_is_int_then_cast( - self.ymin, "BoundsI must be initialized with integer values" - ) - - # attempt to convert widths to static values - # this will raise if values are being traced - # we let that error propagate instead of reraising - # our own. - with jax.ensure_compile_time_eval(): - if not isinstance(self.deltax, int): - self.deltax = int(self.deltax.item()) - if not isinstance(self.deltay, int): - self.deltay = int(self.deltay.item()) - + # now we compute these properties correctly and turn on type conversion self._isdefined = self.deltax >= 1 and self.deltay >= 1 - - # now we compute these properties correctly and turn on type checking - if isinstance(self._xmin, int) and isinstance(self._ymin, int): - self._isstatic = True - else: - self._isstatic = False + self._isstatic = isinstance(self._xmin, int) and isinstance(self._ymin, int) self._dotypeconversion = True def _check_scalar(self, x, name): @@ -757,14 +734,12 @@ def numpyShape(self): else: return 0, 0 - # we store xmin internally as a float even though it is an int - # so that autodiff works properly (needs floats in general) @property def xmin(self): - if self._isstatic: - return self._xmin - else: - return jnp.astype(self._xmin, int) + # for non-static bounds we store xmin internally as a float even + # though it is an int so that autodiff works properly (needs floats in general). + # thus we cast here. + return cast_to_int(self._xmin) @xmin.setter def xmin(self, value): @@ -811,14 +786,12 @@ def xmax(self, value): ) self.deltax = value - self.xmin + 1 - # we store ymin internally as a float even though it is an int - # so that autodiff works properly (needs floats in general) @property def ymin(self): - if self._isstatic: - return self._ymin - else: - return jnp.astype(self._ymin, int) + # for non-static bounds we store ymin internally as a float even + # though it is an int so that autodiff works properly (needs floats in general). + # thus we cast here. + return cast_to_int(self._ymin) @ymin.setter def ymin(self, value): diff --git a/tests/jax/test_bounds_jax.py b/tests/jax/test_bounds_jax.py index 5b5d05a6..4a03ed93 100644 --- a/tests/jax/test_bounds_jax.py +++ b/tests/jax/test_bounds_jax.py @@ -356,3 +356,126 @@ def _make_bnds_set_xmin_static(xmin): assert isinstance(bnds.xmin, int) assert isinstance(bnds.ymin, int) assert bnds.xmin == 2 + + +@jax.vmap +@jax.jit +def _make_bounds_int_nodef(xmin, ymin): + bnds = jax_galsim.BoundsI(xmin=xmin, ymin=ymin, deltax=0, deltay=0) + return bnds, bnds.isDefined() + + +@jax.vmap +@jax.jit +def _bounds_includes_bounds_float(bnds): + return bnds.includes(jax_galsim.BoundsD(9.5, 9.75, 9.5, 9.75)) + + +@jax.vmap +@jax.jit +def _bounds_includes_bounds_nodef_float(bnds): + return bnds.includes(jax_galsim.BoundsD()) + + +@jax.vmap +@jax.jit +def _bounds_includes_pos_float(bnds): + return bnds.includes(jax_galsim.PositionD(9.5, 9.5)) + + +@jax.vmap +@jax.jit +def _bounds_includes_xy_float(bnds): + return bnds.includes(9.5, 9.5) + + +def test_bounds_jax_vmap_includes_float(): + xmin = jnp.array([8, 9, 10, 11, 12]) + xmax = jnp.array([9, 12, 11, 10, 9]) + ymin = jnp.array([7, 9, 11, 10, 12]) + ymax = jnp.array([8, 10, 10, 10, 10]) + bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + + incs = _bounds_includes_bounds_float(bnds) + np.testing.assert_array_equal(incs, [False, True, False, False, False], strict=True) + + incs = _bounds_includes_bounds_nodef_float(bnds) + np.testing.assert_array_equal( + incs, [False, False, False, False, False], strict=True + ) + + incs = _bounds_includes_pos_float(bnds) + np.testing.assert_array_equal(incs, [False, True, False, False, False], strict=True) + + incs = _bounds_includes_xy_float(bnds) + np.testing.assert_array_equal(incs, [False, True, False, False, False], strict=True) + + +@jax.vmap +@jax.jit +def _bounds_includes_bounds_int(bnds): + return bnds.includes(jax_galsim.BoundsI(9, 9, 10, 10)) + + +@jax.vmap +@jax.jit +def _bounds_includes_bounds_nodef_int(bnds): + return bnds.includes(jax_galsim.BoundsI()) + + +@jax.vmap +@jax.jit +def _bounds_includes_pos_int(bnds): + return bnds.includes(jax_galsim.PositionD(9, 10)) + + +@jax.vmap +@jax.jit +def _bounds_includes_xy_int(bnds): + return bnds.includes(9.5, 9.7) + + +def test_bounds_jax_vmap_includes_int(): + xmin = jnp.array([8, 9, 10, 11, 12]) + ymin = jnp.array([7, 9, 11, 10, 12]) + + bnds, isdef = _make_bounds_int(xmin, ymin) + np.testing.assert_array_equal(bnds.isDefined(), isdef) + + incs = _bounds_includes_bounds_int(bnds) + np.testing.assert_array_equal(incs, [True, True, False, False, False], strict=True) + + incs = _bounds_includes_bounds_nodef_int(bnds) + np.testing.assert_array_equal( + incs, [False, False, False, False, False], strict=True + ) + + incs = _bounds_includes_pos_int(bnds) + np.testing.assert_array_equal(incs, [True, True, False, False, False], strict=True) + + incs = _bounds_includes_xy_int(bnds) + np.testing.assert_array_equal(incs, [True, True, False, False, False], strict=True) + + bnds, isdef = _make_bounds_int_nodef(xmin, ymin) + np.testing.assert_array_equal(bnds.isDefined(), isdef) + + incs = _bounds_includes_bounds_int(bnds) + np.testing.assert_array_equal( + incs, [False, False, False, False, False], strict=True + ) + + incs = _bounds_includes_bounds_nodef_int(bnds) + np.testing.assert_array_equal( + incs, [False, False, False, False, False], strict=True + ) + + incs = _bounds_includes_pos_int(bnds) + np.testing.assert_array_equal( + incs, [False, False, False, False, False], strict=True + ) + + incs = _bounds_includes_xy_int(bnds) + np.testing.assert_array_equal( + incs, [False, False, False, False, False], strict=True + ) From 863a99ed0dc9e0d745316f1b607b44b1ddc164c6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 23 May 2026 07:47:18 -0400 Subject: [PATCH 63/72] test: run faster? --- .github/workflows/python_package.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index 2b16e5f1..cabcf97e 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -83,7 +83,7 @@ jobs: --splitting-algorithm least_duration \ --clean-durations \ --retries 1 \ - -n 2 + -n 3 - name: Upload test durations uses: actions/upload-artifact@v7 From 137edd1126078288e24841a064d373279540322c Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 23 May 2026 07:56:03 -0400 Subject: [PATCH 64/72] test: use more splits --- .github/workflows/python_package.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index cabcf97e..48b53470 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -21,9 +21,9 @@ jobs: fail-fast: false matrix: python-version: ["3.12"] - group: [1, 2, 3, 4] + group: [1, 2, 3, 4, 5, 6, 7, 8] env: - NUM_SPLITS: 4 + NUM_SPLITS: 8 steps: - uses: actions/checkout@v6 @@ -83,7 +83,7 @@ jobs: --splitting-algorithm least_duration \ --clean-durations \ --retries 1 \ - -n 3 + -n 2 - name: Upload test durations uses: actions/upload-artifact@v7 From 6a5a37dc9821a9d14e5de46c7f79457543c8130c Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 23 May 2026 08:03:12 -0400 Subject: [PATCH 65/72] test: fewer splits, more workers --- .github/workflows/python_package.yaml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index 48b53470..ae9dd1ac 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -21,7 +21,7 @@ jobs: fail-fast: false matrix: python-version: ["3.12"] - group: [1, 2, 3, 4, 5, 6, 7, 8] + group: [1, 2, 3, 4, 5, 6] env: NUM_SPLITS: 8 @@ -69,7 +69,8 @@ jobs: --splits ${NUM_SPLITS} --group ${{ matrix.group }} \ --splitting-algorithm least_duration \ --retries 1 \ - --test-in-float32 + --test-in-float32 \ + -n 4 - name: Test with pytest run: | @@ -83,7 +84,7 @@ jobs: --splitting-algorithm least_duration \ --clean-durations \ --retries 1 \ - -n 2 + -n 4 - name: Upload test durations uses: actions/upload-artifact@v7 From c7348ef6feb1c7369c810800fcfa38b54f598cf2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 23 May 2026 08:03:27 -0400 Subject: [PATCH 66/72] test: less logging --- .github/workflows/python_package.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index ae9dd1ac..fa76ae47 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -63,7 +63,7 @@ jobs: - name: Test with pytest in float32 run: | pytest \ - -vv \ + -v \ --durations=100 \ --randomly-seed=42 \ --splits ${NUM_SPLITS} --group ${{ matrix.group }} \ @@ -75,7 +75,7 @@ jobs: - name: Test with pytest run: | pytest \ - -vv \ + -v \ --durations=100 \ --randomly-seed=42 \ --splits ${NUM_SPLITS} --group ${{ matrix.group }} \ From 7b35ae88027097b4e6a09612f7f63dd670f11f90 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 23 May 2026 08:10:05 -0400 Subject: [PATCH 67/72] test: try four xdist only --- .github/workflows/python_package.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index fa76ae47..ad342130 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -21,9 +21,9 @@ jobs: fail-fast: false matrix: python-version: ["3.12"] - group: [1, 2, 3, 4, 5, 6] + group: [1, 2, 3, 4] env: - NUM_SPLITS: 8 + NUM_SPLITS: 4 steps: - uses: actions/checkout@v6 From b71db8bd0109c65148879300baff8d43a64a0316 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Sat, 23 May 2026 07:11:01 -0500 Subject: [PATCH 68/72] Apply suggestion from @beckermr --- tests/jax/test_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index f314b784..411d4a80 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -388,7 +388,7 @@ def _reg_fun(p): ]: continue - # jax-galsim Bounds classes do not store xmin, ymin + # jax-galsim Bounds classes do not store xmax, ymax # and have extra method if issubclass(cls, jax_galsim.Bounds) and method in [ "xmax", From 642b4ae052e103ae222406362fce13880435d038 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Sat, 23 May 2026 07:11:34 -0500 Subject: [PATCH 69/72] Apply suggestion from @beckermr --- tests/jax/test_api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 411d4a80..6557baf3 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -380,6 +380,7 @@ def _reg_fun(p): continue # jax-galsim BoundsI classes do not store xmin, ymin + # or deltax/y directly if issubclass(cls, jax_galsim.BoundsI) and method in [ "xmin", "ymin", From caa1e424241046eef824e7b180caed52ebf0b132 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 23 May 2026 08:22:55 -0400 Subject: [PATCH 70/72] fix: implement _BoundsD/I --- jax_galsim/__init__.py | 2 +- jax_galsim/bounds.py | 16 ++++++++++++++++ tests/GalSim | 2 +- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index a90a5b6c..8bf81f45 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -33,7 +33,7 @@ ) # Basic building blocks -from .bounds import Bounds, BoundsD, BoundsI +from .bounds import Bounds, BoundsD, BoundsI, _BoundsD, _BoundsI from .gsparams import GSParams from .position import Position, PositionD, PositionI from .angle import Angle, AngleUnit, _Angle, radians, hours, degrees, arcmin, arcsec diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 7d74791b..9da0e9bc 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -967,3 +967,19 @@ def tree_unflatten(cls, aux_data, children): ret._isdefined = aux_data["isdefined"] ret._isstatic = aux_data["isstatic"] return ret + + +@implements( + _galsim._BoundsD, + lax_description="JAX-GalSim doesn't skip sanity checks for ``_BoundsD``.", +) +def _BoundsD(xmin, xmax, ymin, ymax): + return BoundsD(xmin, xmax, ymin, ymax) + + +@implements( + _galsim._BoundsI, + lax_description="JAX-GalSim doesn't skip sanity checks for ``_BoundsI``.", +) +def _BoundsI(xmin, xmax, ymin, ymax): + return BoundsI(xmin, xmax, ymin, ymax) diff --git a/tests/GalSim b/tests/GalSim index 63d576d1..ba294ca5 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 63d576d1ffe836e965a1c6b547e127e5f457cbb9 +Subproject commit ba294ca5fd19ad8c656beaf3f9b7d177134ae6c4 From d9af53528d456258bc6fdab7106558fa3d9c54fa Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 23 May 2026 08:32:19 -0400 Subject: [PATCH 71/72] test: just a bit faster --- .github/workflows/python_package.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index ad342130..d8ee661d 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -21,9 +21,9 @@ jobs: fail-fast: false matrix: python-version: ["3.12"] - group: [1, 2, 3, 4] + group: [1, 2, 3, 4, 5, 6] env: - NUM_SPLITS: 4 + NUM_SPLITS: 6 steps: - uses: actions/checkout@v6 From afe868f0b684dad246e51344d7ce197839a7e445 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 23 May 2026 08:34:08 -0400 Subject: [PATCH 72/72] fix: do not cat this giant file --- .github/workflows/python_package.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index d8ee661d..b476886e 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -57,7 +57,6 @@ jobs: cp .test_durations .test_durations.${{ matrix.group }} ls -lah .test_durations* echo " " - cat .test_durations* fi - name: Test with pytest in float32