diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index 2b16e5f1..b476886e 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -21,9 +21,9 @@ jobs: fail-fast: false matrix: python-version: ["3.12"] - group: [1, 2, 3, 4] + group: [1, 2, 3, 4, 5, 6] env: - NUM_SPLITS: 4 + NUM_SPLITS: 6 steps: - uses: actions/checkout@v6 @@ -57,24 +57,24 @@ jobs: cp .test_durations .test_durations.${{ matrix.group }} ls -lah .test_durations* echo " " - cat .test_durations* fi - name: Test with pytest in float32 run: | pytest \ - -vv \ + -v \ --durations=100 \ --randomly-seed=42 \ --splits ${NUM_SPLITS} --group ${{ matrix.group }} \ --splitting-algorithm least_duration \ --retries 1 \ - --test-in-float32 + --test-in-float32 \ + -n 4 - name: Test with pytest run: | pytest \ - -vv \ + -v \ --durations=100 \ --randomly-seed=42 \ --splits ${NUM_SPLITS} --group ${{ matrix.group }} \ @@ -83,7 +83,7 @@ jobs: --splitting-algorithm least_duration \ --clean-durations \ --retries 1 \ - -n 2 + -n 4 - name: Upload test durations uses: actions/upload-artifact@v7 diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index ff28546e..0b1297f0 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -61,8 +61,21 @@ does not affect the original. # JAX-GalSim — real_part is a copy real_part = complex_image.real # independent array -Scalar Types, Array Types, and Casting --------------------------------------- +Fixed Array Shapes in JAX Function Transformations +-------------------------------------------------- + +JAX function transformations (e.g., ``jax.jit``, ``jax.vmap``, etc.) require statically known +array shapes in order to support tracing. To support this, the JAX-GalSim ``BoundsI`` class must +have a statically known shape. Further this class can be instantiated via the syntax +``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)`` where ``deltax/y`` are the statically defined +shape. ``BoundsI`` classes may have dynamically set ``x/ymin`` values. However, in this case the ``&`` +and ``+`` operations, which can change the shape of the ``BoundsI`` instance are not allowed in +JAX-traced code. ``BoundsI`` instances have a special method ``isStatic()`` which returns ``True`` +if the object was instantiated with statically know ``x/ymin`` values. A static ``BoundsI`` class +cannot be converted to a dynamic one via assignment and an attempt to do so will raise an exception. + +Scalar Types, Array Types, and Type Casting +------------------------------------------- With the use of JAX, there are now many possible types for numeric data. These include @@ -89,6 +102,24 @@ These rules allow JAX-GalSim to transparently handle JAX's tracing operations, b the code raising generic ``Exception`` instances instead of more specific ``GalSim`` exceptions in some cases. +Object Comparison with the ``==`` Operator +------------------------------------------ + +In JAX-GalSim, all objects which define arrays to be traced by JAX will return JAX boolean +array scalars (i.e., ``jax.numpy.array(True)`` or ``jax.numpy.array(False)``) as the result +of the ``==`` operator. Otherwise the return value is a Python boolean. Important cases of this +rule are static ``BoundsI`` objects, ``Interpolant`` objects (and their subclasses), and ``GSParams`` +objects, all of which return Python boolean values (i.e. ``True`` and ``False``). These difference +can be a source of subtle bugs since the negation of JAX array boolean values is typically done +with ``~``, while for Python boolean values it is done with ``not``. Mixing these two forms can +cause unexpected and incorrect results since + +.. code-block:: python + + >>> ~True is False + :1: SyntaxWarning: "is" with 'int' literal. Did you mean "=="? + False + Random Number Generation ------------------------ diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index a90a5b6c..8bf81f45 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -33,7 +33,7 @@ ) # Basic building blocks -from .bounds import Bounds, BoundsD, BoundsI +from .bounds import Bounds, BoundsD, BoundsI, _BoundsD, _BoundsI from .gsparams import GSParams from .position import Position, PositionD, PositionI from .angle import Angle, AngleUnit, _Angle, radians, hours, degrees, arcmin, arcsec diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index fd3af1c3..c248c1a2 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -112,10 +112,13 @@ def __repr__(self): return "galsim.AngleUnit(%r)" % (ensure_hashable(self.value),) def __eq__(self, other): - return isinstance(other, AngleUnit) and jnp.array_equal(self.value, other.value) + if not isinstance(other, AngleUnit): + return jnp.array(False) + else: + return jnp.array_equal(self.value, other.value) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) def __hash__(self): return hash(("galsim.AngleUnit", ensure_hashable(self.value))) @@ -253,10 +256,13 @@ def __repr__(self): return "galsim.Angle(%r, galsim.radians)" % (ensure_hashable(self.rad),) def __eq__(self, other): - return isinstance(other, Angle) and jnp.array_equal(self.rad, other.rad) + if not isinstance(other, Angle): + return jnp.array(False) + else: + return jnp.array_equal(self.rad, other.rad) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) def __le__(self, other): if not isinstance(other, Angle): diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 69ab8746..9da0e9bc 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -1,3 +1,4 @@ +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -14,26 +15,33 @@ from jax_galsim.position import Position, PositionD, PositionI BOUNDS_LAX_DESCR = """\ -The JAX implementation - -- will not always test whether the bounds are valid - -Further, the JAX implementation adds a new method, ``isStatic`` to the -``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance -has been instantiated with static, known values, ``isStatic()`` will -return ``True``. - -``BoundsI`` objects in JAX-Galsim support an additional initialization -call ``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)``. In this case, -the values for ``deltax/y`` indicate the width of the bounds and must be -static constants. - -When calling ``jax.vmap`` over ``BoundsI`` objects, only ``x/ymin`` -are vectorized over. This restriction allows for code that renders -objects in fixed sized stamps with variable locations, a common -operation. ``BoundsI`` objects which are static (i.e., ``isStatic()`` -returns ``True``) are treated as constants with respect to ``vmap``, -``jit``, and other JAX transforms. +The JAX-GalSim implementation of the ``BoundsI/D`` classes have some key differences +from GalSim. + +- ``BoundsI`` instances must have statically known shapes, but may have non-static + start locations (i.e., ``xmin`` and ``ymin`` may be JAX arrays, traced in JIT operations, etc.). + This restriction mirrors the JAX restriction that arrays have fixed shapes when traced + for function transformations like ``jax.vmap``, ``jax.jit``, etc. +- Upon initialization, if a ``BoundsI`` object has a non-static shape, JAX-GalSim will attempt to convert + it to a static shape by extracting the dimensions from the array via ``.item()``. This operation will + cause JAX to raise an error if the code is being traced. JAX-Galsim performs the same conversion operation + when the ``deltax`` or ``deltay`` properties are set to non-static values via assignment. +- If a ``BoundsI`` object is declared with static ``xmin`` and ``ymin`` values, and then one attempts to + convert them to non-static values via assignment, JAX-GalSim will attempt to convert the assigned values + back to static values. This operation will raise an error if the code is being traced. +- ``Bounds`` classes in JAX-GalSim have an etxra method, ``isStatic`` that returns ``True`` if the object + was instantiated with static ``xmin`` and ``ymin`` values. This method always returns ``False`` for + ``BoundsD`` objects. +- JAX-GalSim does not support the use of the `&` and `+` operators (i.e., the dunder methods ``__and__`` + and ``__add__`` ) with ``BoundsI`` objects when tracing code. +- JAX-Galsim supports an additional initialization signature ``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)`` + to help users specify the widths ``deltax`` and ``deltay`` statically at initialization. +- When calling ``jax.vmap``, ``jax.jit`` etc. with ``BoundsI`` objects, ``xmin`` and ``ymin`` are + traced by JAX. The combination of this feature with statically known shapes allows for code that renders + objects in fixed sized stamps with variable locations, a common operation. +- For ``BoundsD``, all ``x(y)min(max)`` values are traced as arrays. +- ``Bounds`` objects always return a JAX boolean values for various method calls, except for + ``BoundsI.isDefined()`` which is always a Python boolean value. """ @@ -46,11 +54,13 @@ def __init__(self): ) def _parse_args(self, *args, **kwargs): + do_isdefined = True + if len(kwargs) == 0: if len(args) == 4: - self._isdefined = True self.xmin, self.xmax, self.ymin, self.ymax = args elif len(args) == 0: + do_isdefined = False self._isdefined = False self.xmin = 0 self.ymin = 0 @@ -70,7 +80,6 @@ def _parse_args(self, *args, **kwargs): self.ymin = args[0].ymin self.deltay = args[0].deltay + offset elif isinstance(args[0], Position): - self._isdefined = True self.xmin = self.xmax = args[0].x self.ymin = self.ymax = args[0].y else: @@ -78,14 +87,12 @@ def _parse_args(self, *args, **kwargs): "Single argument to %s must be either a Bounds or a Position" % (self.__class__.__name__) ) - self._isdefined = True elif len(args) == 2: if isinstance(args[0], Position) and isinstance(args[1], Position): - self._isdefined = True - self.xmin = min(args[0].x, args[1].x) - self.xmax = max(args[0].x, args[1].x) - self.ymin = min(args[0].y, args[1].y) - self.ymax = max(args[0].y, args[1].y) + self.xmin = jnp.minimum(args[0].x, args[1].x) + self.xmax = jnp.maximum(args[0].x, args[1].x) + self.ymin = jnp.minimum(args[0].y, args[1].y) + self.ymax = jnp.maximum(args[0].y, args[1].y) else: raise TypeError( "Two arguments to %s must be Positions" @@ -103,7 +110,6 @@ def _parse_args(self, *args, **kwargs): ) else: try: - self._isdefined = True self.xmin = kwargs.pop("xmin") self.ymin = kwargs.pop("ymin") except KeyError: @@ -128,17 +134,7 @@ def _parse_args(self, *args, **kwargs): if kwargs: raise TypeError("Got unexpected keyword arguments %s" % kwargs.keys()) - # for simple inputs, we can check if the bounds are valid - if isinstance(self, BoundsD): - max_delta = 0 - else: - max_delta = 1 - if ( - isinstance(self.deltax, (int, float, np.integer, np.floating)) - and isinstance(self.deltay, (int, float, np.integer, np.floating)) - and (self.deltax < max_delta or self.deltay < max_delta) - ): - self._isdefined = False + return do_isdefined @implements(_galsim.Bounds.area) def area(self): @@ -166,18 +162,32 @@ def origin(self): @property @implements(_galsim.Bounds.center) def center(self): - if not self.isDefined(): - raise _galsim.GalSimUndefinedBoundsError( - "center is invalid for an undefined Bounds" + if isinstance(self._isdefined, bool): + if not self._isdefined: + raise _galsim.GalSimUndefinedBoundsError( + "center is invalid for an undefined Bounds" + ) + else: + self._isdefined = equinox.error_if( + self._isdefined, + jnp.any(~self._isdefined), + "center is invalid for an undefined Bounds", ) return self._center @property @implements(_galsim.Bounds.true_center) def true_center(self): - if not self.isDefined(): - raise _galsim.GalSimUndefinedBoundsError( - "true_center is invalid for an undefined Bounds" + if isinstance(self._isdefined, bool): + if not self._isdefined: + raise _galsim.GalSimUndefinedBoundsError( + "true_center is invalid for an undefined Bounds" + ) + else: + self._isdefined = equinox.error_if( + self._isdefined, + jnp.any(~self._isdefined), + "true_center is invalid for an undefined Bounds", ) return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) @@ -187,32 +197,34 @@ def includes(self, *args): if isinstance(args[0], Bounds): b = args[0] return ( - self.isDefined() - & b.isDefined() - & (self.xmin <= b.xmin) - & (self.xmax >= b.xmax) - & (self.ymin <= b.ymin) - & (self.ymax >= b.ymax) + jnp.array(self.isDefined()) + & jnp.array(b.isDefined()) + & jnp.array(self.xmin <= b.xmin) + & jnp.array(self.xmax >= b.xmax) + & jnp.array(self.ymin <= b.ymin) + & jnp.array(self.ymax >= b.ymax) ) elif isinstance(args[0], Position): p = args[0] return ( - self.isDefined() - & (self.xmin <= p.x) - & (self.ymin <= p.y) - & (p.x <= self.xmax) - & (p.y <= self.ymax) + jnp.array(self.isDefined()) + & jnp.array(self.xmin <= p.x) + & jnp.array(self.ymin <= p.y) + & jnp.array(p.x <= self.xmax) + & jnp.array(p.y <= self.ymax) ) else: raise TypeError("Invalid argument %s" % args[0]) elif len(args) == 2: x, y = args + x = cast_to_float(x) + y = cast_to_float(y) return ( - self.isDefined() - & (self.xmin <= float(x)) - & (self.ymin <= float(y)) - & (float(x) <= self.xmax) - & (float(y) <= self.ymax) + jnp.array(self.isDefined()) + & jnp.array(self.xmin <= x) + & jnp.array(self.ymin <= y) + & jnp.array(x <= self.xmax) + & jnp.array(y <= self.ymax) ) elif len(args) == 0: raise TypeError("include takes at least 1 argument (0 given)") @@ -264,59 +276,29 @@ def shift(self, delta): def __and__(self, other): if not isinstance(other, self.__class__): raise TypeError("other must be a %s instance" % self.__class__.__name__) - if not self.isDefined() or not other.isDefined(): - return self.__class__() - else: - xmin = jnp.maximum(self.xmin, other.xmin) - xmax = jnp.minimum(self.xmax, other.xmax) - ymin = jnp.maximum(self.ymin, other.ymin) - ymax = jnp.minimum(self.ymax, other.ymax) - if xmin > xmax or ymin > ymax: - return self.__class__() - else: - return self.__class__(xmin, xmax, ymin, ymax) + + return _bounds_and_op_dynamic(self, other) def __add__(self, other): if isinstance(other, self.__class__): - if not other.isDefined(): - return self - elif self.isDefined(): - xmin = jnp.minimum(self.xmin, other.xmin) - xmax = jnp.maximum(self.xmax, other.xmax) - ymin = jnp.minimum(self.ymin, other.ymin) - ymax = jnp.maximum(self.ymax, other.ymax) - return self.__class__(xmin, xmax, ymin, ymax) - else: - return other + return _bounds_bounds_add_op_dynamic(self, other) elif isinstance(other, self._pos_class): - if self.isDefined(): - xmin = jnp.minimum(self.xmin, other.x) - xmax = jnp.maximum(self.xmax, other.x) - ymin = jnp.minimum(self.ymin, other.y) - ymax = jnp.maximum(self.ymax, other.y) - return self.__class__(xmin, xmax, ymin, ymax) - else: - return self.__class__(other) + return _bounds_pos_add_op_dynamic(self, other) else: raise TypeError( "other must be either a %s or a %s" % (self.__class__.__name__, self._pos_class.__name__) ) - def _getinitargs(self): - if self.isDefined(): - return (self.xmin, self.xmax, self.ymin, self.ymax) - else: - return () - def __eq__(self, other): - return self is other or ( - isinstance(other, self.__class__) - and self._getinitargs() == other._getinitargs() + raise NotImplementedError( + "The `__eq__` magic method must be implemented by subclasses of `Bounds`." ) def __ne__(self, other): - return not self.__eq__(other) + raise NotImplementedError( + "The `__ne__` magic method must be implemented by subclasses of `Bounds`." + ) def __hash__(self): return hash( @@ -333,26 +315,23 @@ def tree_flatten(self): """This function flattens the Bounds into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - if self.isDefined(): - children = (self.xmin, self.deltax, self.ymin, self.deltay) - else: - children = tuple() + children = (self.xmin, self.deltax, self.ymin, self.deltay, self._isdefined) # Define auxiliary static data that doesn’t need to be traced - aux_data = None + aux_data = {"isstatic": self._isstatic} return (children, aux_data) @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" - if children: - return cls( - xmin=children[0], - deltax=children[1], - ymin=children[2], - deltay=children[3], - ) - else: - return cls() + ret = cls.__new__(cls) + ret.xmin = children[0] + ret.deltax = children[1] + ret.ymin = children[2] + ret.deltay = children[3] + ret._isdefined = children[4] + ret._isstatic = aux_data["isstatic"] + + return ret @classmethod def from_galsim(cls, galsim_bounds): @@ -402,25 +381,205 @@ def isStatic(self): return self._isstatic +def _bounds_and_op_dynamic(self, other): + xmin = jnp.maximum(self.xmin, other.xmin) + xmax = jnp.minimum(self.xmax, other.xmax) + ymin = jnp.maximum(self.ymin, other.ymin) + ymax = jnp.minimum(self.ymax, other.ymax) + + is_defined = ( + jnp.array(self.isDefined()) + & jnp.array(other.isDefined()) + & jnp.array(ymin <= ymax) + & jnp.array(xmin <= xmax) + ) + xmin = jnp.where( + is_defined, + xmin, + 0.0, + ) + xmax = jnp.where( + is_defined, + xmax, + 0.0, + ) + ymin = jnp.where( + is_defined, + ymin, + 0.0, + ) + ymax = jnp.where( + is_defined, + ymax, + 0.0, + ) + + cls = self.__class__ + if isinstance(self, BoundsI): + # we use the class constructor here to ensure we properly convert + # bounds shape to static ints + ret = cls( + xmin=xmin, + deltax=xmax - xmin + 1, + ymin=ymin, + deltay=ymax - ymin + 1, + ) + # we have to do a conversion to static bools here too + with jax.ensure_compile_time_eval(): + ret._isdefined = bool(is_defined.item()) + else: + ret = cls.__new__(cls) + ret.xmin = xmin + ret.deltax = xmax - xmin + ret.ymin = ymin + ret.deltay = ymax - ymin + ret._isdefined = is_defined + ret._isstatic = False + + return ret + + +def _bounds_bounds_add_op_dynamic(self, other): + def _ret_correct_attr(self_isdef, self_attr, other_isdef, other_attr, op): + return jnp.where( + ~jnp.array(other_isdef), + self_attr, + jnp.where(jnp.array(self_isdef), op(self_attr, other_attr), other_attr), + ) + + xmin = _ret_correct_attr( + self._isdefined, self.xmin, other._isdefined, other.xmin, jnp.minimum + ) + xmax = _ret_correct_attr( + self._isdefined, self.xmax, other._isdefined, other.xmax, jnp.maximum + ) + ymin = _ret_correct_attr( + self._isdefined, self.ymin, other._isdefined, other.ymin, jnp.minimum + ) + ymax = _ret_correct_attr( + self._isdefined, self.ymax, other._isdefined, other.ymax, jnp.maximum + ) + + cls = self.__class__ + if isinstance(self, BoundsI): + # we use the class constructor here to ensure we properly convert + # bounds shape to static ints + ret = cls( + xmin=xmin, + deltax=xmax - xmin + 1, + ymin=ymin, + deltay=ymax - ymin + 1, + ) + is_defined = jnp.where( + ~jnp.array(other._isdefined), + jnp.array(self._isdefined), + jnp.where( + jnp.array(self._isdefined), + jnp.array(ret.deltax >= 1) & jnp.array(ret.deltay >= 1), + jnp.array(other._isdefined), + ), + ) + # we have to do a conversion to static bools here too + with jax.ensure_compile_time_eval(): + ret._isdefined = bool(is_defined.item()) + else: + ret = cls.__new__(cls) + ret.xmin = xmin + ret.deltax = xmax - xmin + ret.ymin = ymin + ret.deltay = ymax - ymin + ret._isdefined = jnp.where( + ~jnp.array(other._isdefined), + jnp.array(self._isdefined), + jnp.where( + jnp.array(self._isdefined), + jnp.array(ret.deltax >= 0) & jnp.array(ret.deltay >= 0), + jnp.array(other._isdefined), + ), + ) + ret._isstatic = False + + return ret + + +def _bounds_pos_add_op_dynamic(self, other): + xmin = jnp.where( + self._isdefined, + jnp.minimum(self.xmin, other.x), + other.x, + ) + xmax = jnp.where( + self._isdefined, + jnp.maximum(self.xmax, other.x), + other.x, + ) + ymin = jnp.where( + self._isdefined, + jnp.minimum(self.ymin, other.y), + other.y, + ) + ymax = jnp.where( + self._isdefined, + jnp.maximum(self.ymax, other.y), + other.y, + ) + + cls = self.__class__ + if isinstance(self, BoundsI): + # we use the class constructor here to ensure we properly convert + # bounds shape to static ints + ret = cls( + xmin=xmin, + deltax=xmax - xmin + 1, + ymin=ymin, + deltay=ymax - ymin + 1, + ) + is_defined = jnp.where( + jnp.array(self._isdefined), + jnp.array(ret.deltax >= 1) & jnp.array(ret.deltay >= 1), + jnp.array(True), + ) + # we have to do a conversion to static bools here too + with jax.ensure_compile_time_eval(): + ret._isdefined = bool(is_defined.item()) + else: + ret = cls.__new__(cls) + ret.xmin = xmin + ret.deltax = xmax - xmin + ret.ymin = ymin + ret.deltay = ymax - ymin + ret._isdefined = jnp.where( + self._isdefined, + jnp.array(ret.deltax >= 0) & jnp.array(ret.deltay >= 0), + jnp.array(True), + ) + ret._isstatic = False + + return ret + + @implements(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class class BoundsD(Bounds): _pos_class = PositionD def __init__(self, *args, **kwargs): + do_isdefined = self._parse_args(*args, **kwargs) + self.xmin = cast_to_float(jnp.array(self.xmin)) + self.deltax = cast_to_float(jnp.array(self.deltax)) + self.ymin = cast_to_float(jnp.array(self.ymin)) + self.deltay = cast_to_float(jnp.array(self.deltay)) + if do_isdefined: + self._isdefined = (self.deltax >= 0) & (self.deltay >= 0) + self._isdefined = jnp.array(self._isdefined) self._isstatic = False - self._parse_args(*args, **kwargs) - self.xmin = cast_to_float(self.xmin) - self.deltax = cast_to_float(self.deltax) - self.ymin = cast_to_float(self.ymin) - self.deltay = cast_to_float(self.deltay) def _check_scalar(self, x, name): try: if ( - isinstance(x, jax.Array) + isinstance(x, (jax.Array, jnp.ndarray, np.ndarray)) and x.shape == () - and x.dtype.name in ["float32", "float64", "float"] + and jnp.issubdtype(jnp.array(x).dtype, jnp.floating) ): return elif x == float(x): @@ -453,7 +612,17 @@ def _center(self): return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) def __repr__(self): - if self.isDefined(): + # sometimes we will encounter a tracer here + # and so we suppress any boolean conversion errors + try: + if jnp.any(self.isDefined()): + print_full = True + else: + print_full = False + except Exception: + print_full = True + + if print_full: return "galsim.%s(%r, %r, %r, %r)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -465,7 +634,17 @@ def __repr__(self): return "galsim.%s()" % (self.__class__.__name__) def __str__(self): - if self.isDefined(): + # sometimes we will encounter a tracer here + # and so we suppress any boolean conversion errors + try: + if jnp.any(self.isDefined()): + print_full = True + else: + print_full = False + except Exception: + print_full = True + + if print_full: return "galsim.%s(%s,%s,%s,%s)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -476,6 +655,26 @@ def __str__(self): else: return "galsim.%s()" % (self.__class__.__name__) + def __eq__(self, other): + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + self_isdef = jnp.array(self.isDefined()) + other_isdef = jnp.array(other.isDefined()) + return ( + self_isdef + & other_isdef + & jnp.array(self.xmin == other.xmin) + & jnp.array(self.ymin == other.ymin) + & jnp.array(self.xmax == other.xmax) + & jnp.array(self.ymax == other.ymax) + ) | ((~self_isdef) & (~other_isdef)) + else: + return jnp.array(False) + + def __ne__(self, other): + return ~self.__eq__(other) + def __hash__(self): return hash( ( @@ -487,6 +686,10 @@ def __hash__(self): ) ) + def _getinitargs(self): + # defined only for galsim test suite + return (self.xmin, self.xmax, self.ymin, self.ymax) + @implements(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class @@ -494,47 +697,28 @@ class BoundsI(Bounds): _pos_class = PositionI def __init__(self, *args, **kwargs): - # initial setting to let stuff pass through freely + # we set these variables to disable array to python int + # or python int to array conversions for xmin/ymin while we + # initialize the object. + # the setter methods below validate that the inputs are ints, + # so we skip that in the init. + # the class always converts deltax/deltay to python ints and + # an error will be raised if that cannot be done. self._isstatic = True - + self._dotypeconversion = False self._parse_args(*args, **kwargs) - self.deltax = cast_to_float(self.deltax) - self.deltay = cast_to_float(self.deltay) - if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): - raise TypeError("BoundsI must be initialized with integer values") - self.deltax = cast_to_int(self.deltax) - self.deltay = cast_to_int(self.deltay) - - if not ( - isinstance( - self._xmin, - (int, float, np.floating, np.integer), - ) - and isinstance( - self._ymin, - (int, float, np.floating, np.integer), - ) - ): - self._isstatic = False - - # validate inputs are ints - self._xmin = check_is_int_then_cast( - self._xmin, "BoundsI must be initialized with integer values" - ) - self._ymin = check_is_int_then_cast( - self._ymin, "BoundsI must be initialized with integer values" - ) - - if self.deltax < 1 and self.deltay < 1: - self._isdefined = False + # now we compute these properties correctly and turn on type conversion + self._isdefined = self.deltax >= 1 and self.deltay >= 1 + self._isstatic = isinstance(self._xmin, int) and isinstance(self._ymin, int) + self._dotypeconversion = True def _check_scalar(self, x, name): try: if ( - isinstance(x, jax.Array) + isinstance(x, (jax.Array, jnp.ndarray, np.ndarray)) and x.shape == () - and x.dtype.name in ["int32", "int64", "int"] + and jnp.issubdtype(jnp.array(x).dtype, jnp.integer) ): return elif x == int(x): @@ -545,58 +729,118 @@ def _check_scalar(self, x, name): def numpyShape(self): "A simple utility function to get the numpy shape that corresponds to this `Bounds` object." - if self.isDefined(): + if self._isdefined: return self.deltay, self.deltax else: return 0, 0 @property def xmin(self): - if self._isstatic: - return self._xmin - else: - return jnp.astype(self._xmin, jnp.int_) + # for non-static bounds we store xmin internally as a float even + # though it is an int so that autodiff works properly (needs floats in general). + # thus we cast here. + return cast_to_int(self._xmin) @xmin.setter def xmin(self, value): + value = check_is_int_then_cast(value, "BoundsI xmin values must be integers") if self._isstatic: + if self._dotypeconversion: + # attempt to convert widths to static values + # this will raise if values are being traced + # we let that error propagate instead of reraising + # our own. + with jax.ensure_compile_time_eval(): + if not isinstance(value, int): + value = int(value.item()) self._xmin = value else: - self._xmin = jnp.astype(value, jnp.float_) + self._xmin = jnp.astype(value, float) + + @property + def deltax(self): + return self._deltax + + @deltax.setter + def deltax(self, value): + value = check_is_int_then_cast( + value, "BoundsI deltax must be set to an integer value" + ) + # attempt to convert widths to static values + # this will raise if values are being traced + # we let that error propagate instead of reraising + # our own. + with jax.ensure_compile_time_eval(): + if not isinstance(value, int): + value = int(value.item()) + self._deltax = value @property def xmax(self): - return self.xmin + self.deltax - 1 + return cast_to_int(self.xmin + self.deltax - 1) @xmax.setter def xmax(self, value): + value = check_is_int_then_cast( + value, "BoundsI xmax must be set to an integer value" + ) self.deltax = value - self.xmin + 1 @property def ymin(self): - if self._isstatic: - return self._ymin - else: - return jnp.astype(self._ymin, jnp.int_) + # for non-static bounds we store ymin internally as a float even + # though it is an int so that autodiff works properly (needs floats in general). + # thus we cast here. + return cast_to_int(self._ymin) @ymin.setter def ymin(self, value): + value = check_is_int_then_cast(value, "BoundsI ymin values must be integers") if self._isstatic: + if self._dotypeconversion: + # attempt to convert widths to static values + # this will raise if values are being traced + # we let that error propagate instead of reraising + # our own. + with jax.ensure_compile_time_eval(): + if not isinstance(value, int): + value = int(value.item()) self._ymin = value else: - self._ymin = jnp.astype(value, jnp.float_) + self._ymin = jnp.astype(value, float) + + @property + def deltay(self): + return self._deltay + + @deltay.setter + def deltay(self, value): + value = check_is_int_then_cast( + value, "BoundsI deltay must be set to an integer value" + ) + # attempt to convert widths to static values + # this will raise if values are being traced + # we let that error propagate instead of reraising + # our own. + with jax.ensure_compile_time_eval(): + if not isinstance(value, int): + value = int(value.item()) + self._deltay = value @property def ymax(self): - return self.ymin + self.deltay - 1 + return cast_to_int(self.ymin + self.deltay - 1) @ymax.setter def ymax(self, value): + value = check_is_int_then_cast( + value, "BoundsI ymax must be set to an integer value" + ) self.deltay = value - self.ymin + 1 def _area(self): # Remember the + 1 this time to include the pixels on both edges of the bounds. - if not self.isDefined(): + if not self._isdefined: return 0 else: return self.deltax * self.deltay @@ -613,58 +857,8 @@ def _center(self): self.ymin + self.deltay // 2, ) - def tree_flatten(self): - """This function flattens the Bounds into a list of children - nodes that will be traced by JAX and auxiliary static data.""" - # Define the children nodes of the PyTree that need tracing - if self.isDefined(): - if self._isstatic: - # Define the children nodes of the PyTree that need tracing - children = tuple() - - # Define auxiliary static data that doesn’t need to be traced - aux_data = { - "xmin": self._xmin, - "ymin": self._ymin, - "deltax": self.deltax, - "deltay": self.deltay, - } - else: - children = (self._xmin, self._ymin) - # Define auxiliary static data that doesn’t need to be traced - aux_data = {"deltax": self.deltax, "deltay": self.deltay} - else: - children = tuple() - aux_data = None - - return (children, aux_data) - - @classmethod - def tree_unflatten(cls, aux_data, children): - """Recreates an instance of the class from flatten representation""" - if aux_data is not None: - ret = cls.__new__(cls) - if "xmin" in aux_data and "ymin" in aux_data: - ret._isstatic = True - ret._xmin = aux_data["xmin"] - ret._ymin = aux_data["ymin"] - else: - ret._isstatic = False - ret._xmin = children[0] - ret._ymin = children[1] - ret.deltax = aux_data["deltax"] - ret.deltay = aux_data["deltay"] - if ret.deltax < 1 and ret.deltay < 1: - ret._isdefined = False - else: - ret._isdefined = True - else: - ret = cls() - - return ret - def __repr__(self): - if self.isDefined(): + if self._isdefined: return "galsim.%s(xmin=%r, deltax=%r, ymin=%r, deltay=%r)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -676,7 +870,7 @@ def __repr__(self): return "galsim.%s()" % (self.__class__.__name__) def __str__(self): - if self.isDefined(): + if self._isdefined: return "galsim.%s(xmin=%s, deltax=%s, ymin=%s, deltay=%s)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -687,16 +881,45 @@ def __str__(self): else: return "galsim.%s()" % (self.__class__.__name__) - def _getinitargs(self): - if self.isDefined(): - return (self.xmin, self.deltax, self.ymin, self.deltay) + def __eq__(self, other): + if self is other: + if self._isstatic: + return True + else: + return jnp.array(True) + elif isinstance(other, self.__class__): + if self._isstatic and other._isstatic: + return ( + self._isdefined + and other._isdefined + and self.xmin == other.xmin + and self.ymin == other.ymin + and self.deltax == other.deltax + and self.deltay == other.deltay + ) or ((not self._isdefined) and (not other._isdefined)) + else: + self_isdef = jnp.array(self.isDefined()) + other_isdef = jnp.array(other.isDefined()) + return ( + self_isdef + & other_isdef + & jnp.array(self.xmin == other.xmin) + & jnp.array(self.ymin == other.ymin) + & jnp.array(self.deltax == other.deltax) + & jnp.array(self.deltay == other.deltay) + ) | ((~self_isdef) & (~other_isdef)) else: - return () + if self._isstatic: + return False + else: + return jnp.array(False) - def __eq__(self, other): - return self is other or ( - isinstance(other, BoundsI) and self._getinitargs() == other._getinitargs() - ) + def __ne__(self, other): + eqval = self.__eq__(other) + if isinstance(eqval, bool): + return not eqval + else: + return ~eqval def __hash__(self): return hash( @@ -708,3 +931,55 @@ def __hash__(self): ensure_hashable(self.deltay), ) ) + + def tree_flatten(self): + """This function flattens the Bounds into a list of children + nodes that will be traced by JAX and auxiliary static data.""" + aux_data = {"isstatic": self._isstatic} + + # Define the children nodes of the PyTree that need tracing + if self._isstatic: + children = tuple() + aux_data["xmin"] = self._xmin + aux_data["ymin"] = self._ymin + else: + children = (self._xmin, self._ymin) + + # untraced aux data + aux_data["deltax"] = self.deltax + aux_data["deltay"] = self.deltay + aux_data["isdefined"] = self._isdefined + + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + ret = cls.__new__(cls) + if aux_data["isstatic"]: + ret._xmin = aux_data["xmin"] + ret._ymin = aux_data["ymin"] + else: + ret._xmin = children[0] + ret._ymin = children[1] + ret.deltax = aux_data["deltax"] + ret.deltay = aux_data["deltay"] + ret._isdefined = aux_data["isdefined"] + ret._isstatic = aux_data["isstatic"] + return ret + + +@implements( + _galsim._BoundsD, + lax_description="JAX-GalSim doesn't skip sanity checks for ``_BoundsD``.", +) +def _BoundsD(xmin, xmax, ymin, ymax): + return BoundsD(xmin, xmax, ymin, ymax) + + +@implements( + _galsim._BoundsI, + lax_description="JAX-GalSim doesn't skip sanity checks for ``_BoundsI``.", +) +def _BoundsI(xmin, xmax, ymin, ymax): + return BoundsI(xmin, xmax, ymin, ymax) diff --git a/jax_galsim/celestial.py b/jax_galsim/celestial.py index 5645eda3..f5c3c08b 100644 --- a/jax_galsim/celestial.py +++ b/jax_galsim/celestial.py @@ -839,14 +839,15 @@ def __hash__(self): return hash(repr(self)) def __eq__(self, other): - return ( - isinstance(other, CelestialCoord) - and jnp.array_equal(self._ra.rad, other._ra.rad) - and jnp.array_equal(self._dec.rad, other._dec.rad) - ) + if not isinstance(other, CelestialCoord): + return jnp.array(False) + else: + return jnp.array_equal(self._ra.rad, other._ra.rad) & jnp.array_equal( + self._dec.rad, other._dec.rad + ) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) def tree_flatten(self): """This function flattens the CelestialCoord into a list of children diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index e6df0769..c8c4c1ae 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -162,13 +162,20 @@ def withGSParams(self, gsparams=None, **kwargs): return ret def __eq__(self, other): - return self is other or ( - isinstance(other, Convolution) - and self.obj_list == other.obj_list - and self.real_space == other.real_space - and self.gsparams == other.gsparams - and self._propagate_gsparams == other._propagate_gsparams - ) + if self is other: + return jnp.array(True) + elif isinstance(other, Convolution): + return ( + jnp.array(self.obj_list == other.obj_list) + & jnp.array(self.real_space == other.real_space) + & jnp.array(self.gsparams == other.gsparams) + & jnp.array(self._propagate_gsparams == other._propagate_gsparams) + ) + else: + return jnp.array(False) + + def __ne__(self, other): + return ~self.__eq__(other) def __hash__(self): return hash( @@ -403,12 +410,19 @@ def withGSParams(self, gsparams=None, **kwargs): return ret def __eq__(self, other): - return self is other or ( - isinstance(other, Deconvolution) - and self.orig_obj == other.orig_obj - and self.gsparams == other.gsparams - and self._propagate_gsparams == other._propagate_gsparams - ) + if self is other: + return jnp.array(True) + elif isinstance(other, Deconvolution): + return ( + jnp.array(self.orig_obj == other.orig_obj) + & jnp.array(self.gsparams == other.gsparams) + & jnp.array(self._propagate_gsparams == other._propagate_gsparams) + ) + else: + return jnp.array(False) + + def __ne__(self, other): + return ~self.__eq__(other) def __hash__(self): return hash( diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 839d0658..d7394b6b 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -8,12 +8,14 @@ import jax.numpy as jnp import numpy as np +STATIC_SCALAR_TYPES = (int, float, np.integer, np.floating) + def check_is_int_then_cast(val, msg): """Check if `val` is an integer, raise if not, otherwise cast to int.""" val = cast_to_float(val) - if isinstance(val, (int, float, np.integer, np.floating)): + if isinstance(val, STATIC_SCALAR_TYPES): # for simple inputs, we can check direct in python if val != int(val): raise TypeError(msg) @@ -23,7 +25,7 @@ def check_is_int_then_cast(val, msg): val = jnp.array(val) val = equinox.error_if( val, - np.any(val != jnp.trunc(val)), + jnp.any(val != jnp.trunc(val)), msg, ) val = val.astype(int) @@ -43,9 +45,7 @@ def cast_numpy_array_to_native_byte_order(arr): def _cast_to_type(x, typ, accept_strings=False): - if isinstance(x, (int, float, np.integer, np.floating)) or ( - accept_strings and isinstance(x, str) - ): + if isinstance(x, STATIC_SCALAR_TYPES) or (accept_strings and isinstance(x, str)): return typ(x) else: return jnp.astype(x, typ) @@ -77,52 +77,61 @@ def cast_to_int(x, accept_strings=False): return _cast_to_type(x, int, accept_strings=accept_strings) -def is_equal_with_arrays(x, y): +def is_equal_with_arrays(x, y, currval=None, no_jax=False): """Return True if the data is equal, False otherwise. Handles jax.Array types.""" + if no_jax: + arr_func = np.array + arr_eq_func = np.array_equal + else: + arr_func = jnp.array + arr_eq_func = jnp.array_equal + + if currval is None: + currval = arr_func(True) + if isinstance(x, list): if isinstance(y, list) and len(x) == len(y): for vx, vy in zip(x, y): - if not is_equal_with_arrays(vx, vy): - return False - return True + currval &= is_equal_with_arrays(vx, vy, currval=currval, no_jax=no_jax) else: - return False + currval &= arr_func(False) elif isinstance(x, tuple): if isinstance(y, tuple) and len(x) == len(y): for vx, vy in zip(x, y): - if not is_equal_with_arrays(vx, vy): - return False - return True + currval &= is_equal_with_arrays(vx, vy, currval=currval, no_jax=no_jax) else: - return False + currval &= arr_func(False) elif isinstance(x, set): if isinstance(y, set) and len(x) == len(y): for vx, vy in zip(sorted(x), sorted(y)): - if not is_equal_with_arrays(vx, vy): - return False - return True + currval &= is_equal_with_arrays(vx, vy, currval=currval, no_jax=no_jax) else: - return False + currval &= arr_func(False) elif isinstance(x, dict): if isinstance(y, dict) and len(x) == len(y): for kx, vx in x.items(): - if kx not in y or (not is_equal_with_arrays(vx, y[kx])): - return False - return True + if kx not in y: + currval &= arr_func(False) + else: + currval &= is_equal_with_arrays( + vx, y[kx], currval=currval, no_jax=no_jax + ) else: - return False + currval &= jnp.array(False) elif isinstance(x, jax.Array) and jnp.ndim(x) > 0: if isinstance(y, jax.Array) and y.shape == x.shape: - return jnp.array_equal(x, y) + currval &= arr_eq_func(x, y) else: - return False + currval &= arr_func(False) elif (isinstance(x, jax.Array) and jnp.ndim(x) == 0) or ( isinstance(y, jax.Array) and jnp.ndim(y) == 0 ): # this case covers comparing an array scalar to a python scalar or vice versa - return jnp.array_equal(x, y) + currval &= arr_eq_func(x, y) else: - return x == y + currval &= arr_func(x == y) + + return currval def _convert_to_numpy_nan(x): diff --git a/jax_galsim/fits.py b/jax_galsim/fits.py index 8c7837f9..5c66e1c6 100644 --- a/jax_galsim/fits.py +++ b/jax_galsim/fits.py @@ -56,18 +56,7 @@ def readCube(*args, **kwargs): @contextmanager def _image_as_numpy(image): if isinstance(image, Image): - try: - orig_array = image._array - # convert to numpy so astropy doesn't complain - image._array = np.array(image.array, dtype=orig_array.dtype) - # some of these check for Image instances, so we hackily set the class - # on the way in - old_class = image.__class__ - image.__class__ = _galsim.Image - yield image - finally: - image.__class__ = old_class - image._array = orig_array + yield image.to_galsim() else: try: yield np.array(image, dtype=image.dtype) diff --git a/jax_galsim/fitswcs.py b/jax_galsim/fitswcs.py index 01654860..7c8fab94 100644 --- a/jax_galsim/fitswcs.py +++ b/jax_galsim/fitswcs.py @@ -832,25 +832,39 @@ def copy(self): return copy.copy(self) def __eq__(self, other): - return self is other or ( - isinstance(other, GSFitsWCS) - and self.wcs_type == other.wcs_type - and jnp.array_equal(self.crpix, other.crpix) - and jnp.array_equal(self.cd, other.cd) - and self.center == other.center - and ( - (self.pv is None and other.pv is None) - or jnp.array_equal(self.pv, other.pv) + if self is other: + return jnp.array(True) + elif isinstance(other, GSFitsWCS): + is_eq = ( + jnp.array(self.wcs_type == other.wcs_type) + & jnp.array_equal(self.crpix, other.crpix) + & jnp.array_equal(self.cd, other.cd) + & jnp.array(self.center == other.center) ) - and ( - (self.ab is None and other.ab is None) - or jnp.array_equal(self.ab, other.ab) - ) - and ( - (self.abp is None and other.abp is None) - or jnp.array_equal(self.abp, other.abp) - ) - ) + if self.pv is None and other.pv is None: + pass + elif self.pv is not None and other.pv is not None: + is_eq &= jnp.array_equal(self.pv, other.pv) + else: + is_eq &= jnp.array(False) + + if self.ab is None and other.ab is None: + pass + elif self.ab is not None and other.ab is not None: + is_eq &= jnp.array_equal(self.ab, other.ab) + else: + is_eq &= jnp.array(False) + + if self.abp is None and other.abp is None: + pass + elif self.abp is not None and other.abp is not None: + is_eq &= jnp.array_equal(self.abp, other.abp) + else: + is_eq &= jnp.array(False) + + return is_eq + else: + return jnp.array(False) def __repr__(self): pv_repr = repr(ensure_hashable(self.pv)) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index a5cf51b7..641bbdf3 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -197,10 +197,15 @@ def __neg__(self): return -1.0 * self def __eq__(self, other): - return (self is other) or ( - (type(other) is self.__class__) - and is_equal_with_arrays(self.tree_flatten(), other.tree_flatten()) - ) + if self is other: + return jnp.array(True) + elif type(other) is self.__class__: + return is_equal_with_arrays(self.tree_flatten(), other.tree_flatten()) + else: + return jnp.array(False) + + def __ne__(self, other): + return ~self.__eq__(other) @implements(_galsim.GSObject.xValue) def xValue(self, *args, **kwargs): @@ -1052,12 +1057,12 @@ def drawKImage( ) # Can't both recenter a provided image and add to it. - if recenter and image.center != PositionI(0, 0) and add_to_image: - raise _galsim.GalSimIncompatibleValuesError( + if recenter and add_to_image: + zp = PositionI(0, 0) + zp.x = equinox.error_if( + jnp.array(zp.x, dtype=int), + image.center != zp, "Cannot use add_to_image=True unless image is centered at (0,0) or recenter=False", - recenter=recenter, - image=image, - add_to_image=add_to_image, ) # Set the center to 0,0 if appropriate diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 63e4da7a..50e1a7ed 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -584,14 +584,12 @@ def subImage(self, bounds): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access subImage of undefined image" ) - if ( - self.bounds.isStatic() - and bounds.isStatic() - and not self.bounds.includes(bounds) - ): - raise _galsim.GalSimBoundsError( - "Attempt to access subImage not (fully) in image", bounds, self.bounds - ) + inc_val = jnp.array(self.bounds.includes(bounds)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to access subImage not (fully) in image", + ) if self.bounds.isStatic() and bounds.isStatic(): i1 = bounds.ymin - self.ymin @@ -623,14 +621,14 @@ def setSubImage(self, bounds, rhs): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" ) - if ( - self.bounds.isStatic() - and bounds.isStatic() - and not self.bounds.includes(bounds) - ): - raise _galsim.GalSimBoundsError( - "Attempt to access subImage not (fully) in image", bounds, self.bounds - ) + + inc_val = jnp.array(self.bounds.includes(bounds)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to access subImage not (fully) in image", + ) + if not isinstance(rhs, Image): raise TypeError("Trying to copyFrom a non-image") if bounds.numpyShape() != rhs.bounds.numpyShape(): @@ -848,17 +846,23 @@ def calculate_fft(self): ) # TODO: figure out how to do FFT at fixed size and then reconstruct - # the result - No2 = max( - max( - -self.bounds.xmin, - self.bounds.xmax + 1, - ), - max( - -self.bounds.ymin, - self.bounds.ymax + 1, - ), - ) + # the result. - MRB + # This has to be a static known constant since it is an array size + # so we ensure it is evaluated at compile-time and extract it + # from the array. + with jax.ensure_compile_time_eval(): + No2 = jnp.maximum( + jnp.maximum( + -self.bounds.xmin, + self.bounds.xmax + 1, + ), + jnp.maximum( + -self.bounds.ymin, + self.bounds.ymax + 1, + ), + ) + if not isinstance(No2, int): + No2 = int(No2.item()) full_bounds = BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2) if self.bounds == full_bounds: @@ -902,18 +906,25 @@ def calculate_inverse_fft(self): raise _galsim.GalSimError( "calculate_inverse_fft requires that the image has a PixelScale wcs." ) - if not self.bounds.includes(0, 0): - raise _galsim.GalSimBoundsError( - "calculate_inverse_fft requires that the image includes (0,0)", - PositionI(0, 0), - self.bounds, - ) - No2 = max( - max(self.bounds.xmax, -self.bounds.ymin), - self.bounds.ymax, + inc_val = jnp.array(self.bounds.includes(0, 0)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "calculate_inverse_fft requires that the image includes (0,0)", ) + # This has to be a static known constant since it is an array size + # so we ensure it is evaluated at compile-time and extract it + # from the array. + with jax.ensure_compile_time_eval(): + No2 = jnp.maximum( + jnp.maximum(self.bounds.xmax, -self.bounds.ymin), + self.bounds.ymax, + ) + if not isinstance(No2, int): + No2 = int(No2.item()) + target_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2) if self.bounds == target_bounds: # Then the image is already in the shape we need. @@ -1067,12 +1078,13 @@ def getValue(self, x, y): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" ) - if not self.bounds.includes(x, y): - raise _galsim.GalSimBoundsError( - "Attempt to access position not in bounds of image.", - PositionI(x, y), - self.bounds, - ) + inc_val = jnp.array(self.bounds.includes(x, y)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to access position not in bounds of image.", + ) + return self._getValue(x, y) @implements(_galsim.Image._getValue) @@ -1090,10 +1102,13 @@ def setValue(self, *args, **kwargs): pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) - if not self.bounds.includes(pos): - raise _galsim.GalSimBoundsError( - "Attempt to set position not in bounds of image", pos, self.bounds - ) + inc_val = jnp.array(self.bounds.includes(pos)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to set position not in bounds of image", + ) + self._setValue(pos.x, pos.y, value) @implements(_galsim.Image._setValue) @@ -1111,10 +1126,13 @@ def addValue(self, *args, **kwargs): pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) - if not self.bounds.includes(pos): - raise _galsim.GalSimBoundsError( - "Attempt to set position not in bounds of image", pos, self.bounds - ) + inc_val = jnp.array(self.bounds.includes(pos)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to set position not in bounds of image", + ) + self._addValue(pos.x, pos.y, value) @implements(_galsim.Image._addValue) @@ -1184,18 +1202,23 @@ def __eq__(self, other): # >>> assert galsim.ImageD(int_array) == galsim.ImageF(int_array) # passes # >>> assert galsim.ImageD(double_array) == galsim.ImageF(double_array) # fails - return self is other or ( - isinstance(other, Image) - and self.bounds == other.bounds - and self.wcs == other.wcs - and ( - not self.bounds.isDefined() or jnp.array_equal(self.array, other.array) + if self is other: + return jnp.array(True) + elif isinstance(other, Image): + return ( + jnp.array(self.bounds == other.bounds) + & jnp.array(self.wcs == other.wcs) + & ( + (~jnp.array(self.bounds.isDefined())) + | jnp.array_equal(self.array, other.array) + ) + & jnp.array(self.isconst == other.isconst) ) - and self.isconst == other.isconst - ) + else: + return jnp.array(False) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) @implements(_galsim.Image.transpose) def transpose(self): @@ -1242,16 +1265,8 @@ def rot_180(self): def tree_flatten(self): """Flatten the image into a list of values.""" # Define the children nodes of the PyTree that need tracing - if self.bounds.isStatic(): - children = (self.array, self.wcs) - aux_data = { - "dtype": self.dtype, - "bounds": self.bounds, - "isconst": self.isconst, - } - else: - children = (self.array, self.wcs, self.bounds) - aux_data = {"dtype": self.dtype, "isconst": self.isconst} + children = (self.array, self.wcs, self.bounds) + aux_data = {"dtype": self.dtype, "isconst": self.isconst} # other routines may add these attributes to images on the fly # we have to include them here so that JAX knows how to handle them in jitting etc. if hasattr(self, "added_flux"): @@ -1269,26 +1284,16 @@ def tree_unflatten(cls, aux_data, children): obj = object.__new__(cls) obj._array = children[0] obj.wcs = children[1] - if "bounds" in aux_data: - obj._bounds = aux_data["bounds"] - obj._dtype = aux_data["dtype"] - obj._is_const = aux_data["isconst"] - if len(children) > 2: - obj.added_flux = children[2] - if "header" in aux_data: - obj.header = aux_data["header"] - if len(children) > 3: - obj.photons = children[3] - else: - obj._bounds = children[2] - obj._dtype = aux_data["dtype"] - obj._is_const = aux_data["isconst"] - if len(children) > 3: - obj.added_flux = children[3] - if "header" in aux_data: - obj.header = aux_data["header"] - if len(children) > 4: - obj.photons = children[4] + obj._bounds = children[2] + obj._dtype = aux_data["dtype"] + obj._is_const = aux_data["isconst"] + if len(children) > 3: + obj.added_flux = children[3] + if "header" in aux_data: + obj.header = aux_data["header"] + if len(children) > 4: + obj.photons = children[4] + return obj @classmethod @@ -1313,9 +1318,12 @@ def from_galsim(cls, galsim_image): def to_galsim(self): """Create a galsim `Image` from a `jax_galsim.Image` object.""" wcs = self.wcs.to_galsim() if self.wcs is not None else None - return _galsim.Image( + ret = _galsim.Image( np.asarray(self.array), bounds=self.bounds.to_galsim(), wcs=wcs ) + if hasattr(self, "header"): + ret.header = self.header + return ret @implements( _galsim.Image.FindAdaptiveMom, diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 1b489398..8cfe75bb 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -133,7 +133,11 @@ def _i(self): def __eq__(self, other): return (self is other) or ( type(other) is self.__class__ - and is_equal_with_arrays(self.tree_flatten()[1], other.tree_flatten()[1]) + and is_equal_with_arrays( + self.tree_flatten()[1], + other.tree_flatten()[1], + no_jax=True, + ) ) def __ne__(self, other): diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 9564bd36..589c27a9 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -328,17 +328,24 @@ def __str__(self): return "galsim.InterpolatedImage(image=%s, flux=%s)" % (self.image, self.flux) def __eq__(self, other): - return self is other or ( - isinstance(other, InterpolatedImage) - and self._xim == other._xim - and self.x_interpolant == other.x_interpolant - and self.k_interpolant == other.k_interpolant - and self.flux == other.flux - and self._original._offset == other._original._offset - and self.gsparams == other.gsparams - and self._stepk == other._stepk - and self._maxk == other._maxk - ) + if self is other: + return jnp.array(True) + elif isinstance(other, InterpolatedImage): + return ( + jnp.array(self._xim == other._xim) + & jnp.array(self.x_interpolant == other.x_interpolant) + & jnp.array(self.k_interpolant == other.k_interpolant) + & jnp.array_equal(self.flux, other.flux) + & jnp.array(self._original._offset == other._original._offset) + & jnp.array(self.gsparams == other.gsparams) + & jnp.array_equal(self._stepk, other._stepk) + & jnp.array_equal(self._maxk, other._maxk) + ) + else: + return jnp.array(False) + + def __ne__(self, other): + return ~self.__eq__(other) def tree_flatten(self): """This function flattens the InterpolatedImage into a list of children diff --git a/jax_galsim/noise.py b/jax_galsim/noise.py index 3956cf6a..3daf04c1 100644 --- a/jax_galsim/noise.py +++ b/jax_galsim/noise.py @@ -106,11 +106,15 @@ def _applyTo(self, image): raise NotImplementedError("Cannot call applyTo on a pure BaseNoise object") def __eq__(self, other): - # Quick and dirty. Just check reprs are equal. - return self is other or repr(self) == repr(other) + if self is other: + return jnp.array(True) + elif isinstance(other, BaseNoise): + return jnp.array(self._rng == other._rng) + else: + return jnp.array(False) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) __hash__ = None @@ -173,6 +177,16 @@ def __repr__(self): def __str__(self): return "galsim.GaussianNoise(sigma=%s)" % (ensure_hashable(self.sigma),) + def __eq__(self, other): + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + return jnp.array(self._rng == other._rng) & jnp.array_equal( + self._sigma, other._sigma + ) + else: + return jnp.array(False) + def tree_flatten(self): """This function flattens the GaussianNoise into a list of children nodes that will be traced by JAX and auxiliary static data.""" @@ -265,6 +279,16 @@ def __repr__(self): def __str__(self): return "galsim.PoissonNoise(sky_level=%s)" % (self.sky_level) + def __eq__(self, other): + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + return jnp.array(self._rng == other._rng) & jnp.array_equal( + self._sky_level, other._sky_level + ) + else: + return jnp.array(False) + def tree_flatten(self): """This function flattens the PoissonNoise into a list of children nodes that will be traced by JAX and auxiliary static data.""" @@ -429,6 +453,19 @@ def __str__(self): self.read_noise, ) + def __eq__(self, other): + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + return ( + jnp.array(self._rng == other._rng) + & jnp.array_equal(self._sky_level, other._sky_level) + & jnp.array_equal(self._gain, other._gain) + & jnp.array_equal(self._read_noise, other._read_noise) + ) + else: + return jnp.array(False) + def tree_flatten(self): """This function flattens the CCDNoise into a list of children nodes that will be traced by JAX and auxiliary static data.""" @@ -570,6 +607,16 @@ def __repr__(self): def __str__(self): return "galsim.VariableGaussianNoise(var_image=%s)" % (self.var_image) + def __eq__(self, other): + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + return jnp.array(self._rng == other._rng) & jnp.array( + self._var_image == other._var_image + ) + else: + return jnp.array(False) + def tree_flatten(self): """This function flattens the VariableGaussianNoise into a list of children nodes that will be traced by JAX and auxiliary static data.""" diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index e97ee2b7..e9b79411 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -847,22 +847,26 @@ def __str__(self): __hash__ = None def __eq__(self, other): - return self is other or ( - isinstance(other, PhotonArray) - and jnp.array_equal(self.x, other.x) - and jnp.array_equal(self.y, other.y) - and jnp.array_equal(self.flux, other.flux) - and jnp.array_equal(self._nokeep, other._nokeep) - and jnp.array_equal(self.dxdz, other.dxdz, equal_nan=True) - and jnp.array_equal(self.dydz, other.dydz, equal_nan=True) - and jnp.array_equal(self.wavelength, other.wavelength, equal_nan=True) - and jnp.array_equal(self.pupil_u, other.pupil_u, equal_nan=True) - and jnp.array_equal(self.pupil_v, other.pupil_v, equal_nan=True) - and jnp.array_equal(self.time, other.time, equal_nan=True) - ) + if self is other: + return jnp.array(True) + elif isinstance(other, PhotonArray): + return ( + jnp.array_equal(self.x, other.x) + & jnp.array_equal(self.y, other.y) + & jnp.array_equal(self.flux, other.flux) + & jnp.array_equal(self._nokeep, other._nokeep) + & jnp.array_equal(self.dxdz, other.dxdz, equal_nan=True) + & jnp.array_equal(self.dydz, other.dydz, equal_nan=True) + & jnp.array_equal(self.wavelength, other.wavelength, equal_nan=True) + & jnp.array_equal(self.pupil_u, other.pupil_u, equal_nan=True) + & jnp.array_equal(self.pupil_v, other.pupil_v, equal_nan=True) + & jnp.array_equal(self.time, other.time, equal_nan=True) + ) + else: + return jnp.array(False) def __ne__(self, other): - return not self == other + return ~self.__eq__(other) @implements( _galsim.PhotonArray.addTo, diff --git a/jax_galsim/position.py b/jax_galsim/position.py index cf36dba8..a277413f 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -114,14 +114,15 @@ def __str__(self): ) def __eq__(self, other): - return self is other or ( - isinstance(other, self.__class__) - and self.x == other.x - and self.y == other.y - ) + if self is other: + return jnp.array(True) + elif not isinstance(other, self.__class__): + return jnp.array(False) + else: + return jnp.array_equal(self.x, other.x) & jnp.array_equal(self.y, other.y) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) def __hash__(self): return hash( diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 2cf2db27..e5442480 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -9,7 +9,7 @@ import numpy as np from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import implements +from jax_galsim.core.utils import STATIC_SCALAR_TYPES, implements, is_equal_with_arrays try: from jax.extend.random import wrap_key_data @@ -95,7 +95,7 @@ def generates_in_pairs(self): def seed(self, seed=None): if seed is None: self._seed(seed=seed) - elif isinstance(seed, (int, float, np.integer, np.floating)): + elif isinstance(seed, STATIC_SCALAR_TYPES): if seed == int(seed): self._seed(seed=int(seed)) else: @@ -244,16 +244,17 @@ def __copy__(self): return self.duplicate() def __eq__(self, other): - return self is other or ( - isinstance(other, self.__class__) - and jnp.array_equal( + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + return jnp.array_equal( jrandom.key_data(self._key), jrandom.key_data(other._key) - ) - and self._params == other._params - ) + ) & is_equal_with_arrays(self._params, other._params) + else: + return jnp.array(False) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) __hash__ = None diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index b0663890..33a58668 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -280,10 +280,15 @@ def __sub__(self, other): return self + (-other) def __eq__(self, other): - return self is other or (isinstance(other, Shear) and self._g == other._g) + if self is other: + return jnp.array(True) + elif not isinstance(other, Shear): + return jnp.array(False) + else: + return jnp.array_equal(self._g, other._g) def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) @implements(_galsim.Shear.getMatrix) def getMatrix(self): diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index e78a67e3..7e9bc8dd 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -158,15 +158,22 @@ def withGSParams(self, gsparams=None, **kwargs): return self.tree_unflatten(aux, chld) def __eq__(self, other): - return self is other or ( - isinstance(other, Transformation) - and self._original == other._original - and jnp.array_equal(self._jac, other._jac) - and self._offset == other._params["offset"] - and self._flux_ratio == other._flux_ratio - and self._gsparams == other._gsparams - and self._propagate_gsparams == other._propagate_gsparams - ) + if self is other: + return jnp.array(True) + elif isinstance(other, Transformation): + return ( + jnp.array(self._original == other._original) + & jnp.array_equal(self._jac, other._jac) + & jnp.array(self._offset == other._params["offset"]) + & jnp.array_equal(self._flux_ratio, other._flux_ratio) + & jnp.array(self._gsparams == other._gsparams) + & jnp.array(self._propagate_gsparams == other._propagate_gsparams) + ) + else: + return jnp.array(False) + + def __ne__(self, other): + return ~self.__eq__(other) def __hash__(self): return hash( diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index ea3a9dbc..43498c77 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -7,6 +7,7 @@ from jax_galsim.angle import AngleUnit, arcsec, radians from jax_galsim.celestial import CelestialCoord from jax_galsim.core.utils import ( + STATIC_SCALAR_TYPES, cast_to_float, ensure_hashable, implements, @@ -22,7 +23,7 @@ # this kind of casting is only done for writing FITS headers # and should never be done anywhere else in the code base def _cast_to_static_numeric_scalar(x, msg=None): - if isinstance(x, (int, float, np.integer, np.floating)): + if isinstance(x, STATIC_SCALAR_TYPES): return x if isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)): @@ -556,7 +557,7 @@ def _makeSkyImage(self, image, sky_level, color): # Each class should define the __eq__ function. Then __ne__ is obvious. def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) @implements(_galsim.wcs.UniformWCS) @@ -598,12 +599,16 @@ def _makeSkyImage(self, image, sky_level, color): # Just check if the locals match and if the origins match. def __eq__(self, other): - return self is other or ( - isinstance(other, self.__class__) - and self._local_wcs == other._local_wcs - and self.origin == other.origin - and self.world_origin == other.world_origin - ) + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + return ( + jnp.array(self._local_wcs == other._local_wcs) + & jnp.array(self.origin == other.origin) + & jnp.array(self.world_origin == other.world_origin) + ) + else: + return jnp.array(False) @implements(_galsim.wcs.LocalWCS) @@ -827,7 +832,7 @@ def _radecToxy(self, ra, dec, units, color): # Each class should define the __eq__ function. Then __ne__ is obvious. def __ne__(self, other): - return not self.__eq__(other) + return ~self.__eq__(other) ######################################################################################### @@ -952,9 +957,12 @@ def copy(self): return PixelScale(self._scale) def __eq__(self, other): - return self is other or ( - isinstance(other, PixelScale) and self.scale == other.scale - ) + if self is other: + return jnp.array(True) + elif isinstance(other, PixelScale): + return jnp.array_equal(self.scale, other.scale) + else: + return jnp.array(False) def __repr__(self): return "galsim.PixelScale(%r)" % (ensure_hashable(self.scale),) @@ -1075,11 +1083,14 @@ def copy(self): return ShearWCS(self._scale, self._shear) def __eq__(self, other): - return self is other or ( - isinstance(other, ShearWCS) - and self.scale == other.scale - and self.shear == other.shear - ) + if self is other: + return jnp.array(True) + elif isinstance(other, ShearWCS): + return jnp.array(self.scale == other.scale) & jnp.array( + self.shear == other.shear + ) + else: + return jnp.array(False) def __repr__(self): return "galsim.ShearWCS(%r, %r)" % (ensure_hashable(self.scale), self.shear) @@ -1281,13 +1292,17 @@ def copy(self): return JacobianWCS(self.dudx, self.dudy, self.dvdx, self.dvdy) def __eq__(self, other): - return self is other or ( - isinstance(other, JacobianWCS) - and self.dudx == other.dudx - and self.dudy == other.dudy - and self.dvdx == other.dvdx - and self.dvdy == other.dvdy - ) + if self is other: + return jnp.array(True) + elif isinstance(other, JacobianWCS): + return ( + jnp.array_equal(self.dudx, other.dudx) + & jnp.array_equal(self.dudy, other.dudy) + & jnp.array_equal(self.dvdx, other.dvdx) + & jnp.array_equal(self.dvdy, other.dvdy) + ) + else: + return jnp.array(False) def __repr__(self): return "galsim.JacobianWCS(%r, %r, %r, %r)" % ( diff --git a/tests/GalSim b/tests/GalSim index e102c876..ba294ca5 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit e102c876b36c5cb1f1b8e9ab3d17cf6d22727803 +Subproject commit ba294ca5fd19ad8c656beaf3f9b7d177134ae6c4 diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index e76b081c..6557baf3 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -126,6 +126,14 @@ def _run_object_checks(obj, cls, kind): # check that we can hash the object hash(obj) + + # check that val jax array + if (hasattr(obj, "isStatic") and obj.isStatic()) or isinstance( + obj, jax_galsim.Sensor + ): + assert isinstance(eval(repr(obj)) == obj, bool) + else: + assert isinstance(eval(repr(obj)) == obj, jnp.ndarray) elif kind == "to-from-galsim": gs_obj = obj.to_galsim() jgs_obj = obj.from_galsim(gs_obj) @@ -141,6 +149,14 @@ def _run_object_checks(obj, cls, kind): # check that we cannot hash the object assert obj.__hash__ is None + + # check that val jax array + if (hasattr(obj, "isStatic") and obj.isStatic()) or isinstance( + obj, jax_galsim.Sensor + ): + assert isinstance(eval(repr(obj)) == obj, bool) + else: + assert isinstance(eval(repr(obj)) == obj, jnp.ndarray) elif kind == "pickle-eval-repr-wcs": import jax_galsim as galsim # noqa: F401 @@ -152,6 +168,14 @@ def _run_object_checks(obj, cls, kind): # check that we cannot hash the object hash(obj) + + # check that val jax array + if (hasattr(obj, "isStatic") and obj.isStatic()) or isinstance( + obj, jax_galsim.Sensor + ): + assert isinstance(eval(repr(obj)) == obj, bool) + else: + assert isinstance(eval(repr(obj)) == obj, jnp.ndarray) elif kind == "jax-compatible": # JAX tracing should be an identity assert cls.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj @@ -355,7 +379,18 @@ def _reg_fun(p): ): continue + # jax-galsim BoundsI classes do not store xmin, ymin + # or deltax/y directly + if issubclass(cls, jax_galsim.BoundsI) and method in [ + "xmin", + "ymin", + "deltax", + "deltay", + ]: + continue + # jax-galsim Bounds classes do not store xmax, ymax + # and have extra method if issubclass(cls, jax_galsim.Bounds) and method in [ "xmax", "ymax", @@ -363,12 +398,6 @@ def _reg_fun(p): ]: continue - if issubclass(cls, jax_galsim.BoundsI) and method in [ - "xmin", - "ymin", - ]: - continue - assert method in dir(gscls), ( cls.__name__ + "." + method + " not in galsim." + gscls.__name__ ) @@ -1122,3 +1151,9 @@ def test_api_gsparams(): assert getattr(jgsp, k) == v assert getattr(gsp, k) == v assert getattr(jjgsp, k) == v + + assert jgsp == jjgsp + assert isinstance(jgsp == jjgsp, bool) + + kwargs["minimum_fft_size"] = 126 + assert jgsp != jax_galsim.GSParams(**kwargs) diff --git a/tests/jax/test_bounds_jax.py b/tests/jax/test_bounds_jax.py new file mode 100644 index 00000000..4a03ed93 --- /dev/null +++ b/tests/jax/test_bounds_jax.py @@ -0,0 +1,481 @@ +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +import jax_galsim + + +@jax.vmap +@jax.jit +def _make_bounds_float(xmin, ymin, xmax, ymax): + bnds = jax_galsim.BoundsD(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax) + return bnds, bnds.isDefined() + + +def test_bounds_jax_vmap_isdefined_float(): + xmin = jnp.array([9, 10, 11, 12]) + xmax = jnp.array([12, 11, 10, 9]) + ymin = jnp.array([9, 11, 10, 12]) + ymax = jnp.array([10, 10, 10, 10]) + bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + + +@jax.vmap +@jax.jit +def _and_bounds_empty_float(bnds): + bnds = bnds & jax_galsim.BoundsD() + return bnds, bnds.isDefined() + + +@jax.vmap +@jax.jit +def _and_bounds_float(bnds): + bnds = bnds & jax_galsim.BoundsD(xmin=10, xmax=11, ymin=10, ymax=11) + return bnds, bnds.isDefined() + + +@jax.vmap +@jax.jit +def _and_bounds_far_away_float(bnds): + bnds = bnds & jax_galsim.BoundsD(xmin=100, xmax=110, ymin=100, ymax=110) + return bnds, bnds.isDefined() + + +def test_bounds_jax_vmap_and_isdefined_float(): + xmin = jnp.array([9, 10, 11, 12]) + xmax = jnp.array([12, 11, 10, 9]) + ymin = jnp.array([9, 11, 10, 12]) + ymax = jnp.array([10, 10, 10, 10]) + + bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + bnds, isdef = _and_bounds_empty_float(bnds) + assert bnds.isDefined().shape == (4,) + assert not jnp.any(bnds.isDefined()) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + np.testing.assert_array_equal(bnds.isDefined(), False) + + bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + bnds, isdef = _and_bounds_float(bnds) + assert bnds.isDefined().shape == (4,) + np.testing.assert_array_equal( + bnds.isDefined(), jnp.array([True, False, False, False]) + ) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + assert bnds.xmin[0] == 10 + assert bnds.xmax[0] == 11 + assert bnds.ymin[0] == 10 + assert bnds.ymax[0] == 10 + + bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + bnds, isdef = _and_bounds_far_away_float(bnds) + assert bnds.isDefined().shape == (4,) + assert not jnp.any(bnds.isDefined()) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + np.testing.assert_array_equal(bnds.isDefined(), False) + + +@jax.vmap +@jax.jit +def _plus_bounds_far_away_float(bnds): + bnds = bnds + jax_galsim.BoundsD(xmin=100, xmax=110, ymin=100, ymax=110) + return bnds, bnds.isDefined() + + +@jax.vmap +@jax.jit +def _plus_bounds_pos_far_away_float(bnds): + bnds = bnds + jax_galsim.PositionD(x=100, y=110) + return bnds, bnds.isDefined() + + +def test_bounds_jax_vmap_plus_float(): + xmin = jnp.array([9, 10, 11, 12]) + xmax = jnp.array([12, 11, 10, 9]) + ymin = jnp.array([9, 11, 10, 12]) + ymax = jnp.array([10, 10, 10, 10]) + + bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + bnds, isdef = _plus_bounds_far_away_float(bnds) + assert bnds.isDefined().shape == (4,) + np.testing.assert_array_equal(bnds.isDefined(), True) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + assert bnds.xmin[0] == 9 + assert bnds.xmax[0] == 110 + assert bnds.ymin[0] == 9 + assert bnds.ymax[0] == 110 + + np.testing.assert_array_equal(bnds.xmin[1:], 100) + np.testing.assert_array_equal(bnds.xmax[1:], 110) + np.testing.assert_array_equal(bnds.ymin[1:], 100) + np.testing.assert_array_equal(bnds.ymax[1:], 110) + + bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + bnds, isdef = _plus_bounds_pos_far_away_float(bnds) + assert bnds.isDefined().shape == (4,) + np.testing.assert_array_equal(bnds.isDefined(), True) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + assert bnds.xmin[0] == 9 + assert bnds.xmax[0] == 100 + assert bnds.ymin[0] == 9 + assert bnds.ymax[0] == 110 + + np.testing.assert_array_equal(bnds.xmin[1:], 100) + np.testing.assert_array_equal(bnds.xmax[1:], 100) + np.testing.assert_array_equal(bnds.ymin[1:], 110) + np.testing.assert_array_equal(bnds.ymax[1:], 110) + + +@jax.vmap +@jax.jit +def _make_bounds_int(xmin, ymin): + bnds = jax_galsim.BoundsI(xmin=xmin, ymin=ymin, deltax=10, deltay=11) + return bnds, bnds.isDefined() + + +def test_bounds_jax_vmap_isdefined_int(): + xmin = jnp.array([9, 10, 11, 12]) + ymin = jnp.array([9, 11, 10, 12]) + bnds, isdef = _make_bounds_int(xmin, ymin) + np.testing.assert_array_equal(bnds.isDefined(), isdef[0], strict=True) + np.testing.assert_array_equal(bnds.isDefined(), True) + assert jnp.all(isdef) + np.testing.assert_array_equal(bnds.xmin, xmin, strict=True) + np.testing.assert_array_equal(bnds.ymin, ymin, strict=True) + assert isinstance(bnds.deltax, int) + assert bnds.deltax == 10 + assert isinstance(bnds.deltay, int) + assert bnds.deltay == 11 + + +@jax.vmap +@jax.jit +def _make_bounds_int_bad(xmin, ymin, delta): + bnds = jax_galsim.BoundsI(xmin=xmin, ymin=ymin, deltax=delta, deltay=delta) + return bnds, bnds.isDefined() + + +def test_bounds_jax_vmap_varying_shape_raises_int(): + xmin = jnp.array([9, 10, 11, 12]) + ymin = jnp.array([9, 11, 10, 12]) + delta = jnp.array([9, 11, 10, 12]) + with pytest.raises(Exception): + _make_bounds_int_bad(xmin, ymin, delta) + + +@jax.vmap +@jax.jit +def _and_bounds_empty_int(bnds): + bnds = bnds & jax_galsim.BoundsI() + return bnds, bnds.isDefined() + + +def test_bounds_jax_vmap_and_raises_isdefined_int(): + xmin = jnp.array([9, 10, 11, 12]) + ymin = jnp.array([9, 11, 10, 12]) + bnds, isdef = _make_bounds_int(xmin, ymin) + np.testing.assert_array_equal(bnds.isDefined(), isdef[0], strict=True) + np.testing.assert_array_equal(bnds.isDefined(), True) + assert jnp.all(isdef) + + with pytest.raises(Exception): + _and_bounds_empty_int(bnds) + + +@jax.vmap +@jax.jit +def _plus_bounds_far_away_int(bnds): + bnds = bnds + jax_galsim.BoundsI(xmin=100, deltax=110, ymin=100, deltay=110) + return bnds, bnds.isDefined() + + +@jax.vmap +@jax.jit +def _plus_bounds_pos_far_away_int(bnds): + bnds = bnds + jax_galsim.PositionI(x=100, y=110) + return bnds, bnds.isDefined() + + +def test_bounds_jax_vmap_plus_raises_int(): + xmin = jnp.array([9, 10, 11, 12]) + ymin = jnp.array([9, 11, 10, 12]) + bnds, isdef = _make_bounds_int(xmin, ymin) + np.testing.assert_array_equal(bnds.isDefined(), isdef[0], strict=True) + np.testing.assert_array_equal(bnds.isDefined(), True) + assert jnp.all(isdef) + + with pytest.raises(Exception): + _plus_bounds_far_away_int(bnds) + + with pytest.raises(Exception): + _plus_bounds_pos_far_away_float(bnds) + + +def test_bounds_jax_int_set_static(): + bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11) + + bnds.xmin = 11.0 + assert isinstance(bnds.xmin, int) + assert bnds.xmin == 11 + bnds.xmin = jnp.array(12, dtype=float) + assert isinstance(bnds.xmin, int) + assert bnds.xmin == 12 + + bnds.ymin = 12.0 + assert isinstance(bnds.ymin, int) + assert bnds.ymin == 12 + bnds.ymin = jnp.array(13, dtype=float) + assert isinstance(bnds.ymin, int) + assert bnds.ymin == 13 + + bnds.deltax = 11.0 + assert isinstance(bnds.deltax, int) + assert bnds.deltax == 11 + bnds.deltax = jnp.array(12, dtype=float) + assert isinstance(bnds.deltax, int) + assert bnds.deltax == 12 + + bnds.deltay = 12.0 + assert isinstance(bnds.deltay, int) + assert bnds.deltay == 12 + bnds.deltay = jnp.array(13, dtype=float) + assert isinstance(bnds.deltay, int) + assert bnds.deltay == 13 + + +def test_bounds_jax_int_set_dynamic(): + bnds = jax_galsim.BoundsI( + xmin=jnp.array(1), ymin=jnp.array(2), deltax=10, deltay=11 + ) + + bnds.xmin = 11.0 + assert isinstance(bnds.xmin, jnp.ndarray) + assert bnds.xmin == 11 + bnds.xmin = jnp.array(12, dtype=float) + assert isinstance(bnds.xmin, jnp.ndarray) + assert bnds.xmin == 12 + + bnds.ymin = 12.0 + assert isinstance(bnds.ymin, jnp.ndarray) + assert bnds.ymin == 12 + bnds.ymin = jnp.array(13, dtype=float) + assert isinstance(bnds.ymin, jnp.ndarray) + assert bnds.ymin == 13 + + bnds.deltax = 11.0 + assert isinstance(bnds.deltax, int) + assert bnds.deltax == 11 + bnds.deltax = jnp.array(12, dtype=float) + assert isinstance(bnds.deltax, int) + assert bnds.deltax == 12 + + bnds.deltay = 12.0 + assert isinstance(bnds.deltay, int) + assert bnds.deltay == 12 + bnds.deltay = jnp.array(13, dtype=float) + assert isinstance(bnds.deltay, int) + assert bnds.deltay == 13 + + +def test_bounds_jax_int_set_jit_raises(): + @jax.jit + def _make_bnds_bad_xmin(xmin): + bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11) + bnds.xmin = xmin + return bnds + + @jax.jit + def _make_bnds_bad_ymin(ymin): + bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11) + bnds.ymin = ymin + return bnds + + with pytest.raises(Exception): + _make_bnds_bad_xmin(2) + + with pytest.raises(Exception): + _make_bnds_bad_ymin(2) + + @jax.jit + def _make_bnds_bad_deltax(deltax): + bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11) + bnds.deltax = deltax + return bnds + + @jax.jit + def _make_bnds_bad_deltay(deltay): + bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11) + bnds.deltay = deltay + return bnds + + with pytest.raises(Exception): + _make_bnds_bad_deltay(2) + + with pytest.raises(Exception): + _make_bnds_bad_deltay(2) + + +def test_bounds_jax_int_set_jit(): + @jax.jit + def _make_bnds_set_xmin(xmin): + bnds = jax_galsim.BoundsI(xmin=jnp.array(1), ymin=1, deltax=10, deltay=11) + bnds.xmin = xmin + return bnds + + @jax.jit + def _make_bnds_set_ymin(ymin): + bnds = jax_galsim.BoundsI(xmin=jnp.array(1), ymin=1, deltax=10, deltay=11) + bnds.ymin = ymin + return bnds + + bnds = _make_bnds_set_xmin(2) + assert isinstance(bnds.ymin, jnp.ndarray) + assert bnds.xmin == 2 + assert isinstance(bnds.xmin, jnp.ndarray) + + bnds = _make_bnds_set_ymin(2) + assert isinstance(bnds.ymin, jnp.ndarray) + assert bnds.ymin == 2 + assert isinstance(bnds.ymin, jnp.ndarray) + + @partial(jax.jit, static_argnames=("xmin",)) + def _make_bnds_set_xmin_static(xmin): + bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11) + bnds.xmin = xmin + return bnds + + bnds = _make_bnds_set_xmin_static(2) + assert isinstance(bnds.xmin, int) + assert isinstance(bnds.ymin, int) + assert bnds.xmin == 2 + + +@jax.vmap +@jax.jit +def _make_bounds_int_nodef(xmin, ymin): + bnds = jax_galsim.BoundsI(xmin=xmin, ymin=ymin, deltax=0, deltay=0) + return bnds, bnds.isDefined() + + +@jax.vmap +@jax.jit +def _bounds_includes_bounds_float(bnds): + return bnds.includes(jax_galsim.BoundsD(9.5, 9.75, 9.5, 9.75)) + + +@jax.vmap +@jax.jit +def _bounds_includes_bounds_nodef_float(bnds): + return bnds.includes(jax_galsim.BoundsD()) + + +@jax.vmap +@jax.jit +def _bounds_includes_pos_float(bnds): + return bnds.includes(jax_galsim.PositionD(9.5, 9.5)) + + +@jax.vmap +@jax.jit +def _bounds_includes_xy_float(bnds): + return bnds.includes(9.5, 9.5) + + +def test_bounds_jax_vmap_includes_float(): + xmin = jnp.array([8, 9, 10, 11, 12]) + xmax = jnp.array([9, 12, 11, 10, 9]) + ymin = jnp.array([7, 9, 11, 10, 12]) + ymax = jnp.array([8, 10, 10, 10, 10]) + bnds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bnds.isDefined(), isdef, strict=True) + + incs = _bounds_includes_bounds_float(bnds) + np.testing.assert_array_equal(incs, [False, True, False, False, False], strict=True) + + incs = _bounds_includes_bounds_nodef_float(bnds) + np.testing.assert_array_equal( + incs, [False, False, False, False, False], strict=True + ) + + incs = _bounds_includes_pos_float(bnds) + np.testing.assert_array_equal(incs, [False, True, False, False, False], strict=True) + + incs = _bounds_includes_xy_float(bnds) + np.testing.assert_array_equal(incs, [False, True, False, False, False], strict=True) + + +@jax.vmap +@jax.jit +def _bounds_includes_bounds_int(bnds): + return bnds.includes(jax_galsim.BoundsI(9, 9, 10, 10)) + + +@jax.vmap +@jax.jit +def _bounds_includes_bounds_nodef_int(bnds): + return bnds.includes(jax_galsim.BoundsI()) + + +@jax.vmap +@jax.jit +def _bounds_includes_pos_int(bnds): + return bnds.includes(jax_galsim.PositionD(9, 10)) + + +@jax.vmap +@jax.jit +def _bounds_includes_xy_int(bnds): + return bnds.includes(9.5, 9.7) + + +def test_bounds_jax_vmap_includes_int(): + xmin = jnp.array([8, 9, 10, 11, 12]) + ymin = jnp.array([7, 9, 11, 10, 12]) + + bnds, isdef = _make_bounds_int(xmin, ymin) + np.testing.assert_array_equal(bnds.isDefined(), isdef) + + incs = _bounds_includes_bounds_int(bnds) + np.testing.assert_array_equal(incs, [True, True, False, False, False], strict=True) + + incs = _bounds_includes_bounds_nodef_int(bnds) + np.testing.assert_array_equal( + incs, [False, False, False, False, False], strict=True + ) + + incs = _bounds_includes_pos_int(bnds) + np.testing.assert_array_equal(incs, [True, True, False, False, False], strict=True) + + incs = _bounds_includes_xy_int(bnds) + np.testing.assert_array_equal(incs, [True, True, False, False, False], strict=True) + + bnds, isdef = _make_bounds_int_nodef(xmin, ymin) + np.testing.assert_array_equal(bnds.isDefined(), isdef) + + incs = _bounds_includes_bounds_int(bnds) + np.testing.assert_array_equal( + incs, [False, False, False, False, False], strict=True + ) + + incs = _bounds_includes_bounds_nodef_int(bnds) + np.testing.assert_array_equal( + incs, [False, False, False, False, False], strict=True + ) + + incs = _bounds_includes_pos_int(bnds) + np.testing.assert_array_equal( + incs, [False, False, False, False, False], strict=True + ) + + incs = _bounds_includes_xy_int(bnds) + np.testing.assert_array_equal( + incs, [False, False, False, False, False], strict=True + )