Skip to content

fix: clean up type handling#248

Merged
beckermr merged 42 commits into
mainfrom
typing-inits-cleanup
May 21, 2026
Merged

fix: clean up type handling#248
beckermr merged 42 commits into
mainfrom
typing-inits-cleanup

Conversation

@beckermr
Copy link
Copy Markdown
Collaborator

@beckermr beckermr commented May 18, 2026

This PR cleans up the type handling.

The goal is to have objects do one of three things:

  1. If it is a python numeric type (i.e., int or float) or a numpy scalar type (i.e., isinstance(x, np.number), isinstance(x, np.integer), etc.; see https://numpy.org/doc/stable/reference/arrays.scalars.html#built-in-scalar-types), convert it to a python type of the appropriate kind.
  2. Otherwise, cast to the correct type via jax.numpy.astype(x, ...).
  3. For outputs in FITS headers, we allow conversions of numpy/jax arrays to python numeric types as long as there is one element in the array (i.e., it is a numpy scalar type, an array scalar, or a 1D array with one element).

This set of rules is simple to understand, consistently handles array scalars (i.e., arrays with zero dimensions) in that any numpy array scalar is converted to a jax array scalar, and transparently handles JAX tracing.

The diff on the PR results in fewer lines of code, which is a nice side effect!

TODO:

  • instead of duck-typing via try...except blocks, explicitly test for types
  • eliminate the has_tracers function in favor of transparent handling
  • eliminate use of cast_to_python_float / cast_to_python_int outside of FITS handling
  • document the ideas above in docs

closes #246

@beckermr beckermr changed the base branch from main to equinox-err-2 May 18, 2026 11:03
@codspeed-hq
Copy link
Copy Markdown

codspeed-hq Bot commented May 18, 2026

Merging this PR will not alter performance

⚠️ Unknown Walltime execution environment detected

Using the Walltime instrument on standard Hosted Runners will lead to inconsistent data.

For the most accurate results, we recommend using CodSpeed Macro Runners: bare-metal machines fine-tuned for performance measurement consistency.

⚠️ Different runtime environments detected

Some benchmarks with significant performance changes were compared across different runtime environments,
which may affect the accuracy of the results.

Open the report in CodSpeed to investigate

⚡ 4 improved benchmarks
❌ 3 regressed benchmarks
✅ 29 untouched benchmarks

Warning

Please fix the performance issues or acknowledge them on CodSpeed.

Performance Changes

Mode Benchmark BASE HEAD Efficiency
WallTime test_benchmarks_lanczos_interp[xval-conserve_dc-run] 126.5 µs 95.8 µs +32.01%
WallTime test_benchmarks_lanczos_interp[xval-no_conserve_dc-run] 123.7 µs 92.6 µs +33.52%
WallTime test_benchmarks_lanczos_interp[kval-no_conserve_dc-run] 57.9 µs 43.7 µs +32.7%
WallTime test_benchmarks_metacal[run] 20.2 ms 34.8 ms -41.84%
WallTime test_benchmark_spergel_conv[run] 169.2 ms 242.8 ms -30.3%
WallTime test_benchmark_moffat_init[run] 103 µs 58.9 µs +75.04%
WallTime test_benchmark_moffat_conv[run] 195.3 ms 285.9 ms -31.7%

Tip

Investigate this regression by commenting @codspeedbot fix this regression on this PR, or directly use the CodSpeed MCP with your agent.


Comparing typing-inits-cleanup (1aeb448) with main (6ebfcb8)

Open in CodSpeed

Comment thread jax_galsim/photon_array.py Outdated
Comment thread jax_galsim/photon_array.py Outdated
Comment thread jax_galsim/angle.py Outdated
@beckermr beckermr requested a review from ismael-mendoza May 18, 2026 21:11
@beckermr
Copy link
Copy Markdown
Collaborator Author

@ismael-mendoza This PR is based on top of #243. I plan to merge it after #243. Comments welcome!

Base automatically changed from equinox-err-2 to main May 18, 2026 21:36
Comment thread docs/sharp-bits.rst Outdated
Comment thread jax_galsim/image.py
Comment thread docs/sharp-bits.rst Outdated
Comment thread jax_galsim/image.py Outdated
Copy link
Copy Markdown
Collaborator

@ismael-mendoza ismael-mendoza left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Matt! I just have a few minor comments/questions and perhaps one small bug in _cast_to_static_numeric_scalar?

Comment thread docs/sharp-bits.rst Outdated
Comment thread docs/sharp-bits.rst Outdated
Comment thread docs/sharp-bits.rst Outdated
Comment thread docs/sharp-bits.rst Outdated
Comment thread jax_galsim/angle.py

def __mul__(self, other):
if not (has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES)):
if isinstance(other, (Angle, AngleUnit)):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why this is the check here? I guess more specifically, why not ensure that other is scalar instead?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the goal in this PR is to let arrays and tracers pass through without trying to detect which is which. We can do that for the most part (except for a few cases like BoundsI). So the goal of the check here is guard against common errors of handling Angles and AngleUnits.

Comment thread jax_galsim/bounds.py
Comment on lines +502 to +507
self.deltax = cast_to_float(self.deltax)
self.deltay = cast_to_float(self.deltay)
if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)):
raise TypeError("BoundsI must be initialized with integer values")
self.deltax = int(cast_to_int(self.deltax))
self.deltay = int(cast_to_int(self.deltay))
self.deltax = cast_to_int(self.deltax)
self.deltay = cast_to_int(self.deltay)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so if I understand correctly, if someone tries to have non static deltax or deltay they will now get a ConcretizationTypeError, with no error raised from JAX Galsim explicitly. Is that the intention since we warn the user in the LAX docs and we want to avoid using has_tracers?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that is the intention. I am further cleaning this up in #250.

Comment thread jax_galsim/image.py Outdated
Comment thread jax_galsim/image.py Outdated
Comment thread jax_galsim/photon_array.py
Comment thread jax_galsim/wcs.py Outdated
beckermr and others added 9 commits May 21, 2026 06:52
Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com>
Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com>
Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com>
Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com>
Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com>
Clarified descriptions of numeric data types in JAX.
Add error handling for photon array size exceeding limit
@beckermr
Copy link
Copy Markdown
Collaborator Author

pre-commit.ci autofix

@beckermr beckermr requested a review from ismael-mendoza May 21, 2026 12:14
@beckermr
Copy link
Copy Markdown
Collaborator Author

OK @ismael-mendoza this one is ready for another look!

Copy link
Copy Markdown
Collaborator

@ismael-mendoza ismael-mendoza left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

@beckermr beckermr merged commit e2b1029 into main May 21, 2026
10 checks passed
@beckermr beckermr deleted the typing-inits-cleanup branch May 21, 2026 13:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

clean up and unify type handling

2 participants