|
1 | 1 | import galsim as _galsim |
2 | 2 | import jax |
3 | 3 | import jax.numpy as jnp |
4 | | -import numpy as np |
5 | 4 | from jax.tree_util import register_pytree_node_class |
6 | 5 |
|
7 | 6 | from jax_galsim.core.utils import ( |
| 7 | + CONST_TYPES, |
8 | 8 | cast_to_float, |
9 | 9 | cast_to_int, |
| 10 | + cast_to_python_float, |
| 11 | + check_is_int_then_cast, |
10 | 12 | ensure_hashable, |
11 | 13 | has_tracers, |
12 | 14 | implements, |
13 | 15 | ) |
14 | 16 | from jax_galsim.position import Position, PositionD, PositionI |
15 | 17 |
|
16 | | -CONST_TYPES = (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64) |
17 | | -CONST_TYPES_WITH_JAX = CONST_TYPES + ( |
18 | | - jax.Array, |
19 | | - jnp.array, |
20 | | - jnp.int32, |
21 | | - jnp.int64, |
22 | | - jnp.float32, |
23 | | - jnp.float64, |
24 | | -) |
25 | | - |
26 | | -# TODO: write extra docs for JAX changes |
27 | 18 | BOUNDS_LAX_DESCR = """\ |
28 | 19 | The JAX implementation |
29 | 20 |
|
30 | 21 | - will not always test whether the bounds are valid |
31 | | -- will not always test whether BoundsI is initialized with integers |
32 | 22 |
|
33 | 23 | Further, the JAX implementation adds a new method, ``isStatic`` to the |
34 | 24 | ``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance |
@@ -525,31 +515,27 @@ def __init__(self, *args, **kwargs): |
525 | 515 | f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}." |
526 | 516 | ) |
527 | 517 |
|
| 518 | + self.deltax = cast_to_python_float(self.deltax) |
| 519 | + self.deltay = cast_to_python_float(self.deltay) |
| 520 | + if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): |
| 521 | + raise TypeError("BoundsI must be initialized with integer values") |
528 | 522 | self.deltax = int(cast_to_int(self.deltax)) |
529 | 523 | self.deltay = int(cast_to_int(self.deltay)) |
530 | 524 |
|
531 | | - if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): |
532 | | - raise TypeError("BoundsI must be initialized with integer values") |
| 525 | + if has_tracers(self._xmin) or has_tracers(self._ymin): |
| 526 | + self._isstatic = False |
| 527 | + |
| 528 | + # validate inputs are ints |
| 529 | + self._xmin = check_is_int_then_cast( |
| 530 | + self._xmin, "BoundsI must be initialized with integer values" |
| 531 | + ) |
| 532 | + self._ymin = check_is_int_then_cast( |
| 533 | + self._ymin, "BoundsI must be initialized with integer values" |
| 534 | + ) |
533 | 535 |
|
534 | 536 | if self.deltax < 1 and self.deltay < 1: |
535 | 537 | self._isdefined = False |
536 | 538 |
|
537 | | - # for simple inputs, we can check if the bounds are valid ints |
538 | | - if isinstance(self._xmin, CONST_TYPES) and self._xmin != int(self._xmin): |
539 | | - raise TypeError("BoundsI must be initialized with integer values") |
540 | | - |
541 | | - if isinstance(self._ymin, CONST_TYPES) and self._ymin != int(self._ymin): |
542 | | - raise TypeError("BoundsI must be initialized with integer values") |
543 | | - |
544 | | - if not has_tracers(self._xmin) and not has_tracers(self._ymin): |
545 | | - self._isstatic = True |
546 | | - self._xmin = int(np.trunc(self._xmin)) |
547 | | - self._ymin = int(np.trunc(self._ymin)) |
548 | | - else: |
549 | | - self._isstatic = False |
550 | | - self._xmin = cast_to_float(jnp.trunc(self._xmin)) |
551 | | - self._ymin = cast_to_float(jnp.trunc(self._ymin)) |
552 | | - |
553 | 539 | if force_static and not self._isstatic: |
554 | 540 | raise RuntimeError( |
555 | 541 | "BoundsI initialized with non-static " |
|
0 commit comments