-
Notifications
You must be signed in to change notification settings - Fork 9
fix: clean up type handling #248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4077a18
5f17c95
bb482c3
434f800
76c5097
3a6e4e0
cb1387f
2ea2e88
19fa527
edad337
f7aa2a6
caecf66
d5ee292
bb313a4
da4db03
afea1e6
7bda500
d54e844
c40a6aa
9e45b84
15b4e22
402eb30
b185d6d
d6633c6
eceafb9
9e668fa
1a1f18b
c2c6097
d28f0b9
250560f
e686a8a
db83109
83df9cb
0677f1a
b848ae0
be78240
f12c9f5
cacca74
f9f0613
b1329d7
3624566
1aeb448
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,16 +1,14 @@ | ||
| import galsim as _galsim | ||
| import jax | ||
| import jax.numpy as jnp | ||
| import numpy as np | ||
| from jax.tree_util import register_pytree_node_class | ||
|
|
||
| from jax_galsim.core.utils import ( | ||
| CONST_TYPES, | ||
| cast_to_float, | ||
| cast_to_int, | ||
| cast_to_python_float, | ||
| check_is_int_then_cast, | ||
| ensure_hashable, | ||
| has_tracers, | ||
| implements, | ||
| ) | ||
| from jax_galsim.position import Position, PositionD, PositionI | ||
|
|
@@ -23,10 +21,7 @@ | |
| Further, the JAX implementation adds a new method, ``isStatic`` to the | ||
| ``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance | ||
| has been instantiated with static, known values, ``isStatic()`` will | ||
| return ``True``. You can indicate to JAX-GalSim that a ``BoundsI`` | ||
| instance should be static via initializing it with the ``static`` | ||
| keyword set to the ``True``. If the object detects that it is being | ||
| initialized with non-static data, an error will be raised. | ||
| return ``True``. | ||
|
|
||
| ``BoundsI`` objects in JAX-Galsim support an additional initialization | ||
| call ``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)``. In this case, | ||
|
|
@@ -139,8 +134,8 @@ def _parse_args(self, *args, **kwargs): | |
| else: | ||
| max_delta = 1 | ||
| if ( | ||
| isinstance(self.deltax, CONST_TYPES) | ||
| and isinstance(self.deltay, CONST_TYPES) | ||
| isinstance(self.deltax, (int, float, np.integer, np.floating)) | ||
| and isinstance(self.deltay, (int, float, np.integer, np.floating)) | ||
| and (self.deltax < max_delta or self.deltay < max_delta) | ||
| ): | ||
| self._isdefined = False | ||
|
|
@@ -364,10 +359,8 @@ def from_galsim(cls, galsim_bounds): | |
| """Create a jax_galsim `BoundsD/I` from a `galsim.BoundsD/I` object.""" | ||
| if isinstance(galsim_bounds, _galsim.BoundsD): | ||
| _cls = BoundsD | ||
| kwargs = {} | ||
| elif isinstance(galsim_bounds, _galsim.BoundsI): | ||
| _cls = BoundsI | ||
| kwargs = {"static": True} | ||
| else: | ||
| raise TypeError( | ||
| "galsim_bounds must be either a %s or a %s" | ||
|
|
@@ -379,7 +372,6 @@ def from_galsim(cls, galsim_bounds): | |
| galsim_bounds.xmax, | ||
| galsim_bounds.ymin, | ||
| galsim_bounds.ymax, | ||
| **kwargs, | ||
| ) | ||
| else: | ||
| return _cls() | ||
|
|
@@ -505,24 +497,25 @@ def __init__(self, *args, **kwargs): | |
| # initial setting to let stuff pass through freely | ||
| self._isstatic = True | ||
|
|
||
| force_static = kwargs.pop("static", False) | ||
|
|
||
| self._parse_args(*args, **kwargs) | ||
|
|
||
| if has_tracers(self.deltax) or has_tracers(self.deltay): | ||
| raise RuntimeError( | ||
| "Jax-GalSim BoundsI instances must have a fixed width! " | ||
| f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}." | ||
| ) | ||
|
|
||
| self.deltax = cast_to_python_float(self.deltax) | ||
| self.deltay = cast_to_python_float(self.deltay) | ||
| self.deltax = cast_to_float(self.deltax) | ||
| self.deltay = cast_to_float(self.deltay) | ||
| if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): | ||
| raise TypeError("BoundsI must be initialized with integer values") | ||
| self.deltax = int(cast_to_int(self.deltax)) | ||
| self.deltay = int(cast_to_int(self.deltay)) | ||
| self.deltax = cast_to_int(self.deltax) | ||
| self.deltay = cast_to_int(self.deltay) | ||
|
Comment on lines
+502
to
+507
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so if I understand correctly, if someone tries to have non static deltax or deltay they will now get a
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes that is the intention. I am further cleaning this up in #250. |
||
|
|
||
| if has_tracers(self._xmin) or has_tracers(self._ymin): | ||
| if not ( | ||
| isinstance( | ||
| self._xmin, | ||
| (int, float, np.floating, np.integer), | ||
| ) | ||
| and isinstance( | ||
| self._ymin, | ||
| (int, float, np.floating, np.integer), | ||
| ) | ||
| ): | ||
| self._isstatic = False | ||
|
|
||
| # validate inputs are ints | ||
|
|
@@ -536,13 +529,6 @@ def __init__(self, *args, **kwargs): | |
| if self.deltax < 1 and self.deltay < 1: | ||
| self._isdefined = False | ||
|
|
||
| if force_static and not self._isstatic: | ||
| raise RuntimeError( | ||
| "BoundsI initialized with non-static " | ||
| f"data (xmin,ymin = {self._xmin},{self._yminb}) " | ||
| "when static data was explicitly requested." | ||
| ) | ||
|
|
||
| def _check_scalar(self, x, name): | ||
| try: | ||
| if ( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain why this is the check here? I guess more specifically, why not ensure that
otheris scalar instead?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the goal in this PR is to let arrays and tracers pass through without trying to detect which is which. We can do that for the most part (except for a few cases like BoundsI). So the goal of the check here is guard against common errors of handling Angles and AngleUnits.