Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
3d11630
fix: raise errors for invalid shears and PixelScale WCS inits
beckermr May 14, 2026
bd0e282
please the dog
beckermr May 14, 2026
a3b7ba4
fix: mock up equinox
beckermr May 14, 2026
5a43c92
test: more array equals
beckermr May 14, 2026
ff18900
doc: update docs for shears
beckermr May 14, 2026
85569db
fix: clarify docs
beckermr May 14, 2026
d8e24ae
fix: raise erorr on failed integrations
beckermr May 14, 2026
81976e6
style: please the dog
beckermr May 14, 2026
d21e383
fix: try code with equinox filter_jit
beckermr May 14, 2026
8805c8c
fix: use standard JIT
beckermr May 14, 2026
4d661a2
doc: update doc strings
beckermr May 14, 2026
1561dc0
fix: only use generic Exception
beckermr May 14, 2026
8c69aac
doc: update docs
beckermr May 14, 2026
6d0a66b
refactor:
beckermr May 14, 2026
2609840
doc: ensure doc string is accurate
beckermr May 14, 2026
ed7c508
fix: enable tests for image gain, area, exptime, and max_extra_noise
beckermr May 15, 2026
831990a
doc: update doc strings
beckermr May 15, 2026
42fb804
doc: update doc string
beckermr May 15, 2026
2ef7195
doc: add doc string for position exceptions
beckermr May 15, 2026
3ee5f33
fix+doc: do more error checking and more docs
beckermr May 15, 2026
48c2569
doc: fix doc string formatting
beckermr May 15, 2026
6e9c4de
Apply suggestion from @beckermr
beckermr May 15, 2026
37310d9
fix: add the rest of the types
beckermr May 15, 2026
9aa63de
Merge branch 'equinox-err-2' of https://github.com/GalSim-developers/…
beckermr May 15, 2026
88c1c7b
fix: use proper array ref
beckermr May 15, 2026
70e4c3f
Apply suggestion from @beckermr
beckermr May 15, 2026
d80f878
Merge branch 'main' into equinox-err-2
beckermr May 15, 2026
164c2c6
fix: docs done wrong
beckermr May 15, 2026
0ec07d8
Merge branch 'main' into equinox-err-2
beckermr May 15, 2026
aafb848
test: update to latest submodule
beckermr May 15, 2026
4ed8c8e
fix: raise for interpolated image init problems
beckermr May 15, 2026
7963c1f
doc: add docs for exceptions
beckermr May 15, 2026
2e2aaca
fix: use array in transform, not image
beckermr May 15, 2026
58196ca
fix: ensure repr of image prints even with tracers
beckermr May 15, 2026
58efd99
fix: raise for invalid beta
beckermr May 15, 2026
21ffe09
fix: need to ensure images always hold jax arrays
beckermr May 15, 2026
c892e04
fix: raise if we do not have array as kwarg
beckermr May 15, 2026
0b366de
fix: raise errors for RNG inits
beckermr May 15, 2026
6811835
fix: ensure errors are raised for random permutations
beckermr May 15, 2026
15e1398
fix: accept integer scalars too
beckermr May 15, 2026
3d645a2
fix: apparently this does not work on tracers
beckermr May 15, 2026
a9f5171
Apply suggestions from code review
beckermr May 15, 2026
c559408
fix: really on this one
beckermr May 15, 2026
df2a3d4
fix: more tests
beckermr May 15, 2026
6d40f5a
fix: ensure we use any or all in calls for errors
beckermr May 15, 2026
1f0d1da
fix: use any everywhere
beckermr May 15, 2026
a7e2300
fix: enable index check for photon arrays
beckermr May 15, 2026
6d970a3
Apply suggestion from @beckermr
beckermr May 15, 2026
f1498e2
fix: enable exceptions for WCS
beckermr May 15, 2026
6c526d4
fix: raise errors for celestial coords
beckermr May 15, 2026
97ed8d8
fix: remove skip statements from celestial coord
beckermr May 15, 2026
7486ffb
test: use ab matrix that is more stable for benchmarks
beckermr May 15, 2026
2610daa
Apply suggestions from code review
beckermr May 18, 2026
29f83f3
fix: add type checking to seed method
beckermr May 18, 2026
24f1ccf
Merge branch 'main' into equinox-err-2
beckermr May 18, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions jax_galsim/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,30 @@
# SOFTWARE.
import galsim as _galsim
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import cast_to_float, ensure_hashable, implements
from jax_galsim.core.utils import (
cast_to_float,
ensure_hashable,
has_tracers,
implements,
)

NON_COMPLEX_TYPES = (
Comment thread
ismael-mendoza marked this conversation as resolved.
float,
int,
np.int16,
np.int32,
np.int64,
np.float32,
np.float64,
jnp.int16,
jnp.int32,
jnp.int64,
jnp.float32,
jnp.float64,
)


@implements(_galsim.AngleUnit)
Expand Down Expand Up @@ -178,15 +199,23 @@ def __sub__(self, other):
return _Angle(self._rad - other._rad)

def __mul__(self, other):
if not (has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES)):
raise TypeError(
"Cannot multiply Angle by %s of type %s" % (other, type(other))
)
return _Angle(self._rad * other)

__rmul__ = __mul__

def __div__(self, other):
if isinstance(other, AngleUnit):
return self._rad / other.value
else:
elif has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES):
return _Angle(self._rad / other)
else:
raise TypeError(
"Cannot divide Angle by %s of type %s" % (other, type(other))
)

__truediv__ = __div__

Expand Down
48 changes: 17 additions & 31 deletions jax_galsim/bounds.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,24 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import (
CONST_TYPES,
cast_to_float,
cast_to_int,
cast_to_python_float,
check_is_int_then_cast,
ensure_hashable,
has_tracers,
implements,
)
from jax_galsim.position import Position, PositionD, PositionI

CONST_TYPES = (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64)
CONST_TYPES_WITH_JAX = CONST_TYPES + (
jax.Array,
jnp.array,
jnp.int32,
jnp.int64,
jnp.float32,
jnp.float64,
)

# TODO: write extra docs for JAX changes
BOUNDS_LAX_DESCR = """\
The JAX implementation

- will not always test whether the bounds are valid
- will not always test whether BoundsI is initialized with integers
Comment thread
ismael-mendoza marked this conversation as resolved.

Further, the JAX implementation adds a new method, ``isStatic`` to the
``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance
Expand Down Expand Up @@ -525,31 +515,27 @@ def __init__(self, *args, **kwargs):
f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}."
)

self.deltax = cast_to_python_float(self.deltax)
self.deltay = cast_to_python_float(self.deltay)
if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)):
raise TypeError("BoundsI must be initialized with integer values")
self.deltax = int(cast_to_int(self.deltax))
self.deltay = int(cast_to_int(self.deltay))

if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)):
raise TypeError("BoundsI must be initialized with integer values")
if has_tracers(self._xmin) or has_tracers(self._ymin):
self._isstatic = False

# validate inputs are ints
self._xmin = check_is_int_then_cast(
self._xmin, "BoundsI must be initialized with integer values"
)
self._ymin = check_is_int_then_cast(
self._ymin, "BoundsI must be initialized with integer values"
)

if self.deltax < 1 and self.deltay < 1:
self._isdefined = False

# for simple inputs, we can check if the bounds are valid ints
if isinstance(self._xmin, CONST_TYPES) and self._xmin != int(self._xmin):
raise TypeError("BoundsI must be initialized with integer values")

if isinstance(self._ymin, CONST_TYPES) and self._ymin != int(self._ymin):
raise TypeError("BoundsI must be initialized with integer values")

if not has_tracers(self._xmin) and not has_tracers(self._ymin):
self._isstatic = True
self._xmin = int(np.trunc(self._xmin))
self._ymin = int(np.trunc(self._ymin))
else:
self._isstatic = False
self._xmin = cast_to_float(jnp.trunc(self._xmin))
self._ymin = cast_to_float(jnp.trunc(self._ymin))

if force_static and not self._isstatic:
raise RuntimeError(
"BoundsI initialized with non-static "
Expand Down
40 changes: 24 additions & 16 deletions jax_galsim/celestial.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
from functools import partial

import coord as _coord
import equinox
import galsim as _galsim
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class

from jax_galsim.angle import Angle, _Angle, arcsec, degrees, radians
Expand Down Expand Up @@ -74,6 +76,16 @@ def __init__(self, ra, dec=None):
elif not isinstance(dec, Angle):
raise TypeError("dec must be a galsim.Angle")
else:
if isinstance(dec._rad, (float, int)):
if dec._rad < -np.pi / 2 or dec._rad > np.pi / 2:
raise ValueError("dec must be between -90 deg and +90 deg.")
else:
dec._rad = equinox.error_if(
jnp.array(dec._rad),
jnp.any((dec._rad < -jnp.pi / 2) | (dec._rad > jnp.pi / 2)),
"dec must be between -90 deg and +90 deg.",
)

# Normal case
self._ra = ra
self._dec = dec
Expand Down Expand Up @@ -121,15 +133,14 @@ def get_xyz(self):

@staticmethod
@jax.jit
@implements(
_galsim.celestial.CelestialCoord.from_xyz,
lax_description=(
"The JAX version of this static method does not check that the norm of the input "
"vector is non-zero."
),
)
@implements(_galsim.celestial.CelestialCoord.from_xyz)
def from_xyz(x, y, z):
norm = jnp.sqrt(x * x + y * y + z * z)
norm = equinox.error_if(
norm,
jnp.any(norm == 0),
"CelestialCoord for position (0,0,0) is undefined.",
)
ret = CelestialCoord.__new__(CelestialCoord)
ret._x = x / norm
ret._y = y / norm
Expand Down Expand Up @@ -236,13 +247,7 @@ def distanceTo(self, coord2):

return _Angle(theta)

@implements(
_galsim.celestial.CelestialCoord.greatCirclePoint,
lax_description=(
"The JAX version of this method does not check that coord2 defines a unique great "
"circle with the current coord at angle theta."
),
)
@implements(_galsim.celestial.CelestialCoord.greatCirclePoint)
@jax.jit
def greatCirclePoint(self, coord2, theta):
aux = self._get_aux()
Expand Down Expand Up @@ -280,8 +285,11 @@ def greatCirclePoint(self, coord2, theta):

# Normalize
wr = (wx**2 + wy**2 + wz**2) ** 0.5
# if wr == 0.:
# raise ValueError("coord2 does not define a unique great circle with self.")
wr = equinox.error_if(
wr,
jnp.any(wr == 0),
"coord2 does not define a unique great circle with self.",
)
wx /= wr
wy /= wr
wz /= wr
Expand Down
49 changes: 49 additions & 0 deletions jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,60 @@
from functools import partial
from typing import NamedTuple

import equinox
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import tree_flatten

CONST_TYPES = (
float,
int,
np.ndarray,
np.int8,
np.int16,
np.int32,
np.int64,
np.float16,
np.float32,
np.float64,
np.complex64,
np.complex128,
)
CONST_TYPES_WITH_JAX = CONST_TYPES + (
jax.Array,
jnp.ndarray,
jnp.int8,
jnp.int16,
jnp.int32,
jnp.int64,
jnp.float32,
jnp.float64,
jnp.complex64,
jnp.complex128,
)


def check_is_int_then_cast(val, msg):
"""Check if `val` is an integer, raise if not, otherwise cast to int."""
# for simple inputs, we can check direct in python
if isinstance(val, CONST_TYPES) and not has_tracers(val):
val = cast_to_python_float(val)
if val != int(val):
raise TypeError(msg)
val = int(val)
else:
# otherwise we use more opaque checking upon jit via equinox
val = jnp.array(val)
val = equinox.error_if(
val,
np.any(val != jnp.trunc(val)),
msg,
)
val = val.astype(int)

return val


def cast_numpy_array_to_native_byte_order(arr):
"""Cast an array to native byte order."""
Expand Down
11 changes: 5 additions & 6 deletions jax_galsim/fitswcs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import os

import equinox
import galsim as _galsim
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -1094,12 +1095,10 @@ def _step(i, args):
unroll=True,
)[0:4]

x, y = jax.lax.cond(
jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12,
lambda x, y: (x * jnp.nan, y * jnp.nan),
lambda x, y: (x, y),
x,
y,
x, y = equinox.error_if(
(x, y),
jnp.any(jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12),
"Unable to solve for image_pos (max iter reached).",
)

return x, y
Loading
Loading