Skip to content

Commit 948d92b

Browse files
authored
Merge pull request #243 from GalSim-developers/equinox-err-2
2 parents f6ed18a + 24f1ccf commit 948d92b

23 files changed

Lines changed: 453 additions & 102 deletions

jax_galsim/angle.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,30 @@
2121
# SOFTWARE.
2222
import galsim as _galsim
2323
import jax.numpy as jnp
24+
import numpy as np
2425
from jax.tree_util import register_pytree_node_class
2526

26-
from jax_galsim.core.utils import cast_to_float, ensure_hashable, implements
27+
from jax_galsim.core.utils import (
28+
cast_to_float,
29+
ensure_hashable,
30+
has_tracers,
31+
implements,
32+
)
33+
34+
NON_COMPLEX_TYPES = (
35+
float,
36+
int,
37+
np.int16,
38+
np.int32,
39+
np.int64,
40+
np.float32,
41+
np.float64,
42+
jnp.int16,
43+
jnp.int32,
44+
jnp.int64,
45+
jnp.float32,
46+
jnp.float64,
47+
)
2748

2849

2950
@implements(_galsim.AngleUnit)
@@ -178,15 +199,23 @@ def __sub__(self, other):
178199
return _Angle(self._rad - other._rad)
179200

180201
def __mul__(self, other):
202+
if not (has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES)):
203+
raise TypeError(
204+
"Cannot multiply Angle by %s of type %s" % (other, type(other))
205+
)
181206
return _Angle(self._rad * other)
182207

183208
__rmul__ = __mul__
184209

185210
def __div__(self, other):
186211
if isinstance(other, AngleUnit):
187212
return self._rad / other.value
188-
else:
213+
elif has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES):
189214
return _Angle(self._rad / other)
215+
else:
216+
raise TypeError(
217+
"Cannot divide Angle by %s of type %s" % (other, type(other))
218+
)
190219

191220
__truediv__ = __div__
192221

jax_galsim/bounds.py

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,24 @@
11
import galsim as _galsim
22
import jax
33
import jax.numpy as jnp
4-
import numpy as np
54
from jax.tree_util import register_pytree_node_class
65

76
from jax_galsim.core.utils import (
7+
CONST_TYPES,
88
cast_to_float,
99
cast_to_int,
10+
cast_to_python_float,
11+
check_is_int_then_cast,
1012
ensure_hashable,
1113
has_tracers,
1214
implements,
1315
)
1416
from jax_galsim.position import Position, PositionD, PositionI
1517

16-
CONST_TYPES = (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64)
17-
CONST_TYPES_WITH_JAX = CONST_TYPES + (
18-
jax.Array,
19-
jnp.array,
20-
jnp.int32,
21-
jnp.int64,
22-
jnp.float32,
23-
jnp.float64,
24-
)
25-
26-
# TODO: write extra docs for JAX changes
2718
BOUNDS_LAX_DESCR = """\
2819
The JAX implementation
2920
3021
- will not always test whether the bounds are valid
31-
- will not always test whether BoundsI is initialized with integers
3222
3323
Further, the JAX implementation adds a new method, ``isStatic`` to the
3424
``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance
@@ -525,31 +515,27 @@ def __init__(self, *args, **kwargs):
525515
f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}."
526516
)
527517

518+
self.deltax = cast_to_python_float(self.deltax)
519+
self.deltay = cast_to_python_float(self.deltay)
520+
if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)):
521+
raise TypeError("BoundsI must be initialized with integer values")
528522
self.deltax = int(cast_to_int(self.deltax))
529523
self.deltay = int(cast_to_int(self.deltay))
530524

531-
if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)):
532-
raise TypeError("BoundsI must be initialized with integer values")
525+
if has_tracers(self._xmin) or has_tracers(self._ymin):
526+
self._isstatic = False
527+
528+
# validate inputs are ints
529+
self._xmin = check_is_int_then_cast(
530+
self._xmin, "BoundsI must be initialized with integer values"
531+
)
532+
self._ymin = check_is_int_then_cast(
533+
self._ymin, "BoundsI must be initialized with integer values"
534+
)
533535

534536
if self.deltax < 1 and self.deltay < 1:
535537
self._isdefined = False
536538

537-
# for simple inputs, we can check if the bounds are valid ints
538-
if isinstance(self._xmin, CONST_TYPES) and self._xmin != int(self._xmin):
539-
raise TypeError("BoundsI must be initialized with integer values")
540-
541-
if isinstance(self._ymin, CONST_TYPES) and self._ymin != int(self._ymin):
542-
raise TypeError("BoundsI must be initialized with integer values")
543-
544-
if not has_tracers(self._xmin) and not has_tracers(self._ymin):
545-
self._isstatic = True
546-
self._xmin = int(np.trunc(self._xmin))
547-
self._ymin = int(np.trunc(self._ymin))
548-
else:
549-
self._isstatic = False
550-
self._xmin = cast_to_float(jnp.trunc(self._xmin))
551-
self._ymin = cast_to_float(jnp.trunc(self._ymin))
552-
553539
if force_static and not self._isstatic:
554540
raise RuntimeError(
555541
"BoundsI initialized with non-static "

jax_galsim/celestial.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323
from functools import partial
2424

2525
import coord as _coord
26+
import equinox
2627
import galsim as _galsim
2728
import jax
2829
import jax.numpy as jnp
30+
import numpy as np
2931
from jax.tree_util import register_pytree_node_class
3032

3133
from jax_galsim.angle import Angle, _Angle, arcsec, degrees, radians
@@ -74,6 +76,16 @@ def __init__(self, ra, dec=None):
7476
elif not isinstance(dec, Angle):
7577
raise TypeError("dec must be a galsim.Angle")
7678
else:
79+
if isinstance(dec._rad, (float, int)):
80+
if dec._rad < -np.pi / 2 or dec._rad > np.pi / 2:
81+
raise ValueError("dec must be between -90 deg and +90 deg.")
82+
else:
83+
dec._rad = equinox.error_if(
84+
jnp.array(dec._rad),
85+
jnp.any((dec._rad < -jnp.pi / 2) | (dec._rad > jnp.pi / 2)),
86+
"dec must be between -90 deg and +90 deg.",
87+
)
88+
7789
# Normal case
7890
self._ra = ra
7991
self._dec = dec
@@ -121,15 +133,14 @@ def get_xyz(self):
121133

122134
@staticmethod
123135
@jax.jit
124-
@implements(
125-
_galsim.celestial.CelestialCoord.from_xyz,
126-
lax_description=(
127-
"The JAX version of this static method does not check that the norm of the input "
128-
"vector is non-zero."
129-
),
130-
)
136+
@implements(_galsim.celestial.CelestialCoord.from_xyz)
131137
def from_xyz(x, y, z):
132138
norm = jnp.sqrt(x * x + y * y + z * z)
139+
norm = equinox.error_if(
140+
norm,
141+
jnp.any(norm == 0),
142+
"CelestialCoord for position (0,0,0) is undefined.",
143+
)
133144
ret = CelestialCoord.__new__(CelestialCoord)
134145
ret._x = x / norm
135146
ret._y = y / norm
@@ -236,13 +247,7 @@ def distanceTo(self, coord2):
236247

237248
return _Angle(theta)
238249

239-
@implements(
240-
_galsim.celestial.CelestialCoord.greatCirclePoint,
241-
lax_description=(
242-
"The JAX version of this method does not check that coord2 defines a unique great "
243-
"circle with the current coord at angle theta."
244-
),
245-
)
250+
@implements(_galsim.celestial.CelestialCoord.greatCirclePoint)
246251
@jax.jit
247252
def greatCirclePoint(self, coord2, theta):
248253
aux = self._get_aux()
@@ -280,8 +285,11 @@ def greatCirclePoint(self, coord2, theta):
280285

281286
# Normalize
282287
wr = (wx**2 + wy**2 + wz**2) ** 0.5
283-
# if wr == 0.:
284-
# raise ValueError("coord2 does not define a unique great circle with self.")
288+
wr = equinox.error_if(
289+
wr,
290+
jnp.any(wr == 0),
291+
"coord2 does not define a unique great circle with self.",
292+
)
285293
wx /= wr
286294
wy /= wr
287295
wz /= wr

jax_galsim/core/utils.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,60 @@
33
from functools import partial
44
from typing import NamedTuple
55

6+
import equinox
67
import jax
78
import jax.numpy as jnp
89
import numpy as np
910
from jax.tree_util import tree_flatten
1011

12+
CONST_TYPES = (
13+
float,
14+
int,
15+
np.ndarray,
16+
np.int8,
17+
np.int16,
18+
np.int32,
19+
np.int64,
20+
np.float16,
21+
np.float32,
22+
np.float64,
23+
np.complex64,
24+
np.complex128,
25+
)
26+
CONST_TYPES_WITH_JAX = CONST_TYPES + (
27+
jax.Array,
28+
jnp.ndarray,
29+
jnp.int8,
30+
jnp.int16,
31+
jnp.int32,
32+
jnp.int64,
33+
jnp.float32,
34+
jnp.float64,
35+
jnp.complex64,
36+
jnp.complex128,
37+
)
38+
39+
40+
def check_is_int_then_cast(val, msg):
41+
"""Check if `val` is an integer, raise if not, otherwise cast to int."""
42+
# for simple inputs, we can check direct in python
43+
if isinstance(val, CONST_TYPES) and not has_tracers(val):
44+
val = cast_to_python_float(val)
45+
if val != int(val):
46+
raise TypeError(msg)
47+
val = int(val)
48+
else:
49+
# otherwise we use more opaque checking upon jit via equinox
50+
val = jnp.array(val)
51+
val = equinox.error_if(
52+
val,
53+
np.any(val != jnp.trunc(val)),
54+
msg,
55+
)
56+
val = val.astype(int)
57+
58+
return val
59+
1160

1261
def cast_numpy_array_to_native_byte_order(arr):
1362
"""Cast an array to native byte order."""

jax_galsim/fitswcs.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
import os
33

4+
import equinox
45
import galsim as _galsim
56
import jax
67
import jax.numpy as jnp
@@ -1094,12 +1095,10 @@ def _step(i, args):
10941095
unroll=True,
10951096
)[0:4]
10961097

1097-
x, y = jax.lax.cond(
1098-
jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12,
1099-
lambda x, y: (x * jnp.nan, y * jnp.nan),
1100-
lambda x, y: (x, y),
1101-
x,
1102-
y,
1098+
x, y = equinox.error_if(
1099+
(x, y),
1100+
jnp.any(jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12),
1101+
"Unable to solve for image_pos (max iter reached).",
11031102
)
11041103

11051104
return x, y

0 commit comments

Comments
 (0)