Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
4077a18
fix: clean up type handling
beckermr May 18, 2026
5f17c95
fix: allow strings optionally for floating casts
beckermr May 18, 2026
bb482c3
fix: more type cleanups
beckermr May 18, 2026
434f800
Apply suggestion from @beckermr
beckermr May 18, 2026
76c5097
fix: remove more weird casts and be more struct on types
beckermr May 18, 2026
3a6e4e0
fix: ensure bounds deltax is numeric scalar
beckermr May 18, 2026
cb1387f
fix: handle photon array sizes as well
beckermr May 18, 2026
2ea2e88
refactor: simpler code
beckermr May 18, 2026
19fa527
fix: accept any array with one element
beckermr May 18, 2026
edad337
fix: simpler
beckermr May 18, 2026
f7aa2a6
doc: add comment
beckermr May 18, 2026
caecf66
fix: start to remove has_tracers func
beckermr May 18, 2026
d5ee292
fix: accept array inputs for boundsI static
beckermr May 18, 2026
bb313a4
Apply suggestion from @beckermr
beckermr May 18, 2026
da4db03
fix: remove more has_tracers
beckermr May 18, 2026
afea1e6
Merge branch 'typing-inits-cleanup' of https://github.com/GalSim-deve…
beckermr May 18, 2026
7bda500
fix: more cleanup of has_tracers
beckermr May 18, 2026
d54e844
fix: remove use of has_tracers
beckermr May 18, 2026
c40a6aa
fix: remove unneeded api
beckermr May 18, 2026
9e45b84
fix: remove extra keyword we no longer need
beckermr May 18, 2026
15b4e22
Apply suggestion from @beckermr
beckermr May 18, 2026
402eb30
fix: bug in hermitian detecting checks
beckermr May 18, 2026
b185d6d
Merge branch 'typing-inits-cleanup' of https://github.com/GalSim-deve…
beckermr May 18, 2026
d6633c6
Merge branch 'equinox-err-2' into typing-inits-cleanup
beckermr May 18, 2026
eceafb9
doc: add some docs
beckermr May 18, 2026
9e668fa
doc: finish docs
beckermr May 18, 2026
1a1f18b
Merge branch 'main' into typing-inits-cleanup
beckermr May 18, 2026
c2c6097
Apply suggestion from @beckermr
beckermr May 18, 2026
d28f0b9
Apply suggestion from @beckermr
beckermr May 18, 2026
250560f
Apply suggestion from @beckermr
beckermr May 18, 2026
e686a8a
Apply suggestion from @beckermr
beckermr May 18, 2026
db83109
fix: no need to ensure it is hashable here
beckermr May 18, 2026
83df9cb
Merge branch 'main' into typing-inits-cleanup
beckermr May 21, 2026
0677f1a
Apply suggestion from @ismael-mendoza
beckermr May 21, 2026
b848ae0
Apply suggestion from @ismael-mendoza
beckermr May 21, 2026
be78240
Apply suggestion from @ismael-mendoza
beckermr May 21, 2026
f12c9f5
Apply suggestion from @ismael-mendoza
beckermr May 21, 2026
cacca74
Apply suggestion from @ismael-mendoza
beckermr May 21, 2026
f9f0613
Update descriptions of numeric data types in JAX
beckermr May 21, 2026
b1329d7
Fix condition to check if array shape is 1
beckermr May 21, 2026
3624566
Implement size check for photon array
beckermr May 21, 2026
1aeb448
fix: missed an import
beckermr May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 34 additions & 23 deletions docs/sharp-bits.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,34 @@ does not affect the original.
# JAX-GalSim — real_part is a copy
real_part = complex_image.real # independent array

Scalar Types, Array Types, and Casting
--------------------------------------

With the use of JAX, there are now many possible types for numeric data. These include

- **Python scalars**: Objects with types that are ``float``, ``int``, or ``complex``.
- **NumPy scalars**: Objects with types that are subclasses of ``np.floating``, ``np.integer``, etc.
- **NumPy array scalars**: Objects with a type that is ``np.ndarray`` and have ``np.ndim(...) == 0``.
- **NumPy arrays**: Objects with a type that is ``np.ndarray`` and have ``np.ndim(...) > 0``.
- **JAX array scalars**: Objects with a type that is ``jax.numpy.ndarray`` and have ``jax.numpy.ndim(...) == 0``.
- **JAX arrays**: Objects with a type that is ``jax.numpy.ndarray`` and have ``jax.numpy.ndim(...) > 0``.

**JAX does not have pure scalar types like NumPy. JAX uses array scalars for those instead.**

JAX-GalSim uses the following rules when handling data types and casting.

- If the item is a Python numeric type (i.e., ``int`` or ``float``) or a
NumPy scalar type (i.e., ``isinstance(x, np.number)``, ``isinstance(x, np.integer)``, etc.),
convert it to a Python type of the appropriate kind.
- For all other array-like types, cast to the correct type via ``jax.numpy.astype(x, ...)``.
- For putting data into FITS headers only, JAX-GalSim converts of NumPy/JAX arrays to Python
numeric types as long as there is one element in the array (i.e., it is a NumPy scalar type,
an array scalar, or a 1D array with one element).

These rules allow JAX-GalSim to transparently handle JAX's tracing operations, but can result in
the code raising generic ``Exception`` instances instead of more specific ``GalSim`` exceptions in
some cases.

Random Number Generation
------------------------

Expand Down Expand Up @@ -163,9 +191,6 @@ profile parameters passed into a ``jit``-compiled function):
def good(sigma):
return jax.lax.cond(sigma > 1.0, lambda s: s * 2, lambda s: s, sigma)

JAX-GalSim uses an internal ``has_tracers()`` utility to detect tracing and
avoid problematic control flow in its own implementations.

Fixed output shapes
^^^^^^^^^^^^^^^^^^^

Expand Down Expand Up @@ -197,20 +222,9 @@ The ``__init__`` gotcha

During ``jit`` tracing, JAX calls constructors with **tracer objects** rather
than concrete Python numbers. Type checks like ``isinstance(sigma, float)`` will
fail on tracers. JAX-GalSim handles this internally, but if you subclass any
JAX-GalSim object, be aware that ``__init__`` may receive tracers:

.. code-block:: python

from jax_galsim.core.utils import has_tracers

class MyProfile(jax_galsim.GSObject):
def __init__(self, sigma, gsparams=None):
if not has_tracers(sigma):
# Only validate with concrete values
if sigma <= 0:
raise ValueError("sigma must be positive")
...
return ``False`` on tracers, and you cannot check correctness of values (e.g.,
``if sigma > 0: ...```). JAX-GalSim handles this internally, but if you subclass any
JAX-GalSim object, be aware that ``__init__`` may receive tracers.

Profile Restrictions
--------------------
Expand All @@ -221,12 +235,9 @@ Some GalSim features are not yet implemented in JAX-GalSim:
- **ChromaticObject**: All chromatic functionality (wavelength-dependent
profiles) is not available.
- **InterpolatedKImage**: Not implemented.
- **Airy, Kolmogorov, OpticalPSF, RealGalaxy**: See :doc:`api-coverage` for
- **Airy, Kolmogorov, OpticalPSF, RealGalaxy, etc.**: See :doc:`api-coverage` for
the full list.

The project currently implements **22.5 %** of the GalSim public API, focused
on the most commonly used profiles and operations.

Numerical Precision
-------------------

Expand All @@ -249,11 +260,11 @@ These differences are typically at the level of floating-point round-off
should not affect scientific conclusions.

⚠️ Additional Sharp Bits
--------------------------
------------------------

In the :doc:`api/index` you will find **🔪 JAX-GalSim - The Sharp Bits 🔪** blocks highlighting additional important caveats for specific classes and or methods. These could include things like:

- Many classes do not perform some of Galsim's test for correctness during initialization (e.g., :meth:`~jax_galsim.GSObject.drawImage`).
- Some classes do not perform some of Galsim's test for correctness during initialization (e.g., :meth:`~jax_galsim.InterpolatedImage`).
- Certain profiles might not be auto-differentiable with respect to some of their parameters (e.g., :class:`~jax_galsim.Spergel`, :class:`~jax_galsim.Moffat`)
- Limitations regarding what types of inputes are handled (e.g., :meth:`~jax_galsim.Image.calculate_fft` does not accept complex dtypes.)

7 changes: 3 additions & 4 deletions jax_galsim/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from jax_galsim.core.utils import (
cast_to_float,
ensure_hashable,
has_tracers,
implements,
)

Expand Down Expand Up @@ -55,7 +54,7 @@ class AngleUnit(object):
def __init__(self, value):
if isinstance(value, AngleUnit):
raise TypeError("Cannot construct AngleUnit from another AngleUnit")
self._value = cast_to_float(value)
self._value = cast_to_float(value, accept_strings=True)

@property
@implements(_galsim.AngleUnit.value)
Expand Down Expand Up @@ -199,7 +198,7 @@ def __sub__(self, other):
return _Angle(self._rad - other._rad)

def __mul__(self, other):
if not (has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES)):
if isinstance(other, (Angle, AngleUnit)):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why this is the check here? I guess more specifically, why not ensure that other is scalar instead?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the goal in this PR is to let arrays and tracers pass through without trying to detect which is which. We can do that for the most part (except for a few cases like BoundsI). So the goal of the check here is guard against common errors of handling Angles and AngleUnits.

raise TypeError(
"Cannot multiply Angle by %s of type %s" % (other, type(other))
)
Expand All @@ -210,7 +209,7 @@ def __mul__(self, other):
def __div__(self, other):
if isinstance(other, AngleUnit):
return self._rad / other.value
elif has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES):
elif not isinstance(other, Angle):
return _Angle(self._rad / other)
else:
raise TypeError(
Expand Down
50 changes: 18 additions & 32 deletions jax_galsim/bounds.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import (
CONST_TYPES,
cast_to_float,
cast_to_int,
cast_to_python_float,
check_is_int_then_cast,
ensure_hashable,
has_tracers,
implements,
)
from jax_galsim.position import Position, PositionD, PositionI
Expand All @@ -23,10 +21,7 @@
Further, the JAX implementation adds a new method, ``isStatic`` to the
``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance
has been instantiated with static, known values, ``isStatic()`` will
return ``True``. You can indicate to JAX-GalSim that a ``BoundsI``
instance should be static via initializing it with the ``static``
keyword set to the ``True``. If the object detects that it is being
initialized with non-static data, an error will be raised.
return ``True``.

``BoundsI`` objects in JAX-Galsim support an additional initialization
call ``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)``. In this case,
Expand Down Expand Up @@ -139,8 +134,8 @@ def _parse_args(self, *args, **kwargs):
else:
max_delta = 1
if (
isinstance(self.deltax, CONST_TYPES)
and isinstance(self.deltay, CONST_TYPES)
isinstance(self.deltax, (int, float, np.integer, np.floating))
and isinstance(self.deltay, (int, float, np.integer, np.floating))
and (self.deltax < max_delta or self.deltay < max_delta)
):
self._isdefined = False
Expand Down Expand Up @@ -364,10 +359,8 @@ def from_galsim(cls, galsim_bounds):
"""Create a jax_galsim `BoundsD/I` from a `galsim.BoundsD/I` object."""
if isinstance(galsim_bounds, _galsim.BoundsD):
_cls = BoundsD
kwargs = {}
elif isinstance(galsim_bounds, _galsim.BoundsI):
_cls = BoundsI
kwargs = {"static": True}
else:
raise TypeError(
"galsim_bounds must be either a %s or a %s"
Expand All @@ -379,7 +372,6 @@ def from_galsim(cls, galsim_bounds):
galsim_bounds.xmax,
galsim_bounds.ymin,
galsim_bounds.ymax,
**kwargs,
)
else:
return _cls()
Expand Down Expand Up @@ -505,24 +497,25 @@ def __init__(self, *args, **kwargs):
# initial setting to let stuff pass through freely
self._isstatic = True

force_static = kwargs.pop("static", False)

self._parse_args(*args, **kwargs)

if has_tracers(self.deltax) or has_tracers(self.deltay):
raise RuntimeError(
"Jax-GalSim BoundsI instances must have a fixed width! "
f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}."
)

self.deltax = cast_to_python_float(self.deltax)
self.deltay = cast_to_python_float(self.deltay)
self.deltax = cast_to_float(self.deltax)
self.deltay = cast_to_float(self.deltay)
if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)):
raise TypeError("BoundsI must be initialized with integer values")
self.deltax = int(cast_to_int(self.deltax))
self.deltay = int(cast_to_int(self.deltay))
self.deltax = cast_to_int(self.deltax)
self.deltay = cast_to_int(self.deltay)
Comment on lines +502 to +507
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so if I understand correctly, if someone tries to have non static deltax or deltay they will now get a ConcretizationTypeError, with no error raised from JAX Galsim explicitly. Is that the intention since we warn the user in the LAX docs and we want to avoid using has_tracers?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that is the intention. I am further cleaning this up in #250.


if has_tracers(self._xmin) or has_tracers(self._ymin):
if not (
isinstance(
self._xmin,
(int, float, np.floating, np.integer),
)
and isinstance(
self._ymin,
(int, float, np.floating, np.integer),
)
):
self._isstatic = False

# validate inputs are ints
Expand All @@ -536,13 +529,6 @@ def __init__(self, *args, **kwargs):
if self.deltax < 1 and self.deltay < 1:
self._isdefined = False

if force_static and not self._isstatic:
raise RuntimeError(
"BoundsI initialized with non-static "
f"data (xmin,ymin = {self._xmin},{self._yminb}) "
"when static data was explicitly requested."
)

def _check_scalar(self, x, name):
try:
if (
Expand Down
Loading
Loading