Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
68c6938
fix: redo bounds again for dynamic usages
beckermr May 20, 2026
9d9dc80
style: please the dog
beckermr May 20, 2026
956f778
fix: only apply ~ to array bool
beckermr May 20, 2026
03c7869
fix: try dynamic children
beckermr May 20, 2026
231fdb2
fix: ensure bounds iuncludes tests are done properly
beckermr May 20, 2026
97b8823
fix: cast in a different way
beckermr May 20, 2026
997dadf
fix: finish dynamic bounds impl
beckermr May 20, 2026
0774b6a
fix: be sure to test everything
beckermr May 20, 2026
80f7327
fix: get bools right
beckermr May 20, 2026
9012a3e
fix: wrong tracing branch; remove test
beckermr May 20, 2026
1abee57
refactor: only need one of these
beckermr May 20, 2026
6718b24
Apply suggestion from @beckermr
beckermr May 20, 2026
f8d2c00
Apply suggestion from @beckermr
beckermr May 20, 2026
2b7634f
Apply suggestion from @beckermr
beckermr May 20, 2026
e2a4fac
test: update to latest test submodule
beckermr May 20, 2026
93eb30f
fix: ensure we handle branches on bounds eq properly
beckermr May 20, 2026
e44d8df
fix: this needs to be a float
beckermr May 20, 2026
ef586a3
fix: ensure we can FFT OK
beckermr May 20, 2026
618e137
Update jax_galsim/bounds.py
beckermr May 20, 2026
01f893c
Update jax_galsim/bounds.py
beckermr May 20, 2026
9d6f2fa
fix: use latest submodule
beckermr May 20, 2026
13e0266
test: update to latest submodule
beckermr May 20, 2026
2a9c0e7
fix: use to_galsim for fpacking
beckermr May 20, 2026
ac4e000
fix: do not convert all bounds props to arrays
beckermr May 21, 2026
004ada8
fix: put back variable pytree def
beckermr May 21, 2026
7f97f11
fix: make sure to send fits headers to galsim
beckermr May 21, 2026
a5fc11e
Apply suggestion from @beckermr
beckermr May 21, 2026
a0d0ba6
Merge branch 'typing-inits-cleanup' into bounds-fix-again-all-static
beckermr May 21, 2026
996a5fb
Merge branch 'typing-inits-cleanup' into bounds-fix-again-all-static
beckermr May 21, 2026
2143616
Merge branch 'main' into bounds-fix-again-all-static
beckermr May 21, 2026
fea3aec
fix: add back isStatic method
beckermr May 21, 2026
1a8f579
test: ensure api tests are correct
beckermr May 21, 2026
053de86
style: remove extra blank space changes
beckermr May 21, 2026
51efa74
Apply suggestion from @beckermr
beckermr May 21, 2026
cb9d084
Apply suggestion from @beckermr
beckermr May 21, 2026
63fe4b3
Apply suggestion from @beckermr
beckermr May 21, 2026
d0aab7e
fix: allow python bool for static bounds eq
beckermr May 21, 2026
9f99b92
fix: start on eq using jax bool values in most cases
beckermr May 21, 2026
0485a14
Apply suggestion from @beckermr
beckermr May 21, 2026
f642c26
fix: remove extra prints
beckermr May 21, 2026
f3d8e25
Merge branch 'bounds-fix-again-all-static' of https://github.com/GalS…
beckermr May 21, 2026
7d176a7
fix: return JAX bools for rest of things
beckermr May 21, 2026
1bf4a5d
fix: bool conversion in pos comp
beckermr May 21, 2026
4703099
fix: bool conversion in pos comp
beckermr May 21, 2026
de2bb72
test: update to latest submodule
beckermr May 21, 2026
29df9c3
doc: add docs
beckermr May 21, 2026
2d53ab0
doc: be a bit more specific
beckermr May 21, 2026
1c94cf3
fix: dead code
beckermr May 21, 2026
95f0ae9
test: add test of bool eq api
beckermr May 21, 2026
0d07661
test: add test of bool eq api
beckermr May 21, 2026
dea6915
test: add tests of bounds and vmap
beckermr May 21, 2026
1671b37
test: add tests of bounds and vmap
beckermr May 21, 2026
946958c
Apply suggestion from @beckermr
beckermr May 22, 2026
860a419
test: more tests for bounds and vmap
beckermr May 22, 2026
f312d87
Apply suggestion from @beckermr
beckermr May 22, 2026
8e7280f
test: ensure we have proper logic here
beckermr May 22, 2026
a9ef07f
Apply suggestions from code review
beckermr May 22, 2026
401e540
doc: clarify
beckermr May 22, 2026
0a1ddd1
test: more tests for bnds int and vmap
beckermr May 22, 2026
aa8d587
test: add tests for type conversion
beckermr May 22, 2026
4922f21
test: more tests for bounds set
beckermr May 22, 2026
2d48748
test: fix bounds api tests
beckermr May 22, 2026
64af33a
fix: sharpen tests
beckermr May 22, 2026
a7efe2d
doc: add comment
beckermr May 22, 2026
eb2748a
fix: rename for clarity
beckermr May 22, 2026
215b412
test: add tests of includes and simpler bounds init
beckermr May 23, 2026
863a99e
test: run faster?
beckermr May 23, 2026
137edd1
test: use more splits
beckermr May 23, 2026
6a5a37d
test: fewer splits, more workers
beckermr May 23, 2026
c7348ef
test: less logging
beckermr May 23, 2026
7b35ae8
test: try four xdist only
beckermr May 23, 2026
b71db8b
Apply suggestion from @beckermr
beckermr May 23, 2026
642b4ae
Apply suggestion from @beckermr
beckermr May 23, 2026
caa1e42
fix: implement _BoundsD/I
beckermr May 23, 2026
2a1a36d
Merge branch 'bounds-fix-again-all-static' of https://github.com/GalS…
beckermr May 23, 2026
d9af535
test: just a bit faster
beckermr May 23, 2026
afe868f
fix: do not cat this giant file
beckermr May 23, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .github/workflows/python_package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ jobs:
fail-fast: false
matrix:
python-version: ["3.12"]
group: [1, 2, 3, 4]
group: [1, 2, 3, 4, 5, 6]
env:
NUM_SPLITS: 4
NUM_SPLITS: 6

steps:
- uses: actions/checkout@v6
Expand Down Expand Up @@ -57,24 +57,24 @@ jobs:
cp .test_durations .test_durations.${{ matrix.group }}
ls -lah .test_durations*
echo " "
cat .test_durations*
fi

- name: Test with pytest in float32
run: |
pytest \
-vv \
-v \
--durations=100 \
--randomly-seed=42 \
--splits ${NUM_SPLITS} --group ${{ matrix.group }} \
--splitting-algorithm least_duration \
--retries 1 \
--test-in-float32
--test-in-float32 \
-n 4

- name: Test with pytest
run: |
pytest \
-vv \
-v \
--durations=100 \
--randomly-seed=42 \
--splits ${NUM_SPLITS} --group ${{ matrix.group }} \
Expand All @@ -83,7 +83,7 @@ jobs:
--splitting-algorithm least_duration \
--clean-durations \
--retries 1 \
-n 2
-n 4

- name: Upload test durations
uses: actions/upload-artifact@v7
Expand Down
35 changes: 33 additions & 2 deletions docs/sharp-bits.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,21 @@ does not affect the original.
# JAX-GalSim — real_part is a copy
real_part = complex_image.real # independent array
Scalar Types, Array Types, and Casting
--------------------------------------
Fixed Array Shapes in JAX Function Transformations
--------------------------------------------------

JAX function transformations (e.g., ``jax.jit``, ``jax.vmap``, etc.) require statically known
array shapes in order to support tracing. To support this, the JAX-GalSim ``BoundsI`` class must
have a statically known shape. Further this class can be instantiated via the syntax
``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)`` where ``deltax/y`` are the statically defined
shape. ``BoundsI`` classes may have dynamically set ``x/ymin`` values. However, in this case the ``&``
and ``+`` operations, which can change the shape of the ``BoundsI`` instance are not allowed in
JAX-traced code. ``BoundsI`` instances have a special method ``isStatic()`` which returns ``True``
if the object was instantiated with statically know ``x/ymin`` values. A static ``BoundsI`` class
cannot be converted to a dynamic one via assignment and an attempt to do so will raise an exception.

Scalar Types, Array Types, and Type Casting
-------------------------------------------

With the use of JAX, there are now many possible types for numeric data. These include

Expand All @@ -89,6 +102,24 @@ These rules allow JAX-GalSim to transparently handle JAX's tracing operations, b
the code raising generic ``Exception`` instances instead of more specific ``GalSim`` exceptions in
some cases.

Object Comparison with the ``==`` Operator
------------------------------------------

In JAX-GalSim, all objects which define arrays to be traced by JAX will return JAX boolean
array scalars (i.e., ``jax.numpy.array(True)`` or ``jax.numpy.array(False)``) as the result
of the ``==`` operator. Otherwise the return value is a Python boolean. Important cases of this
rule are static ``BoundsI`` objects, ``Interpolant`` objects (and their subclasses), and ``GSParams``
objects, all of which return Python boolean values (i.e. ``True`` and ``False``). These difference
can be a source of subtle bugs since the negation of JAX array boolean values is typically done
with ``~``, while for Python boolean values it is done with ``not``. Mixing these two forms can
cause unexpected and incorrect results since

.. code-block:: python
>>> ~True is False
<python-input-0>:1: SyntaxWarning: "is" with 'int' literal. Did you mean "=="?
False
Random Number Generation
------------------------

Expand Down
2 changes: 1 addition & 1 deletion jax_galsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)

# Basic building blocks
from .bounds import Bounds, BoundsD, BoundsI
from .bounds import Bounds, BoundsD, BoundsI, _BoundsD, _BoundsI
from .gsparams import GSParams
from .position import Position, PositionD, PositionI
from .angle import Angle, AngleUnit, _Angle, radians, hours, degrees, arcmin, arcsec
Expand Down
14 changes: 10 additions & 4 deletions jax_galsim/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,13 @@ def __repr__(self):
return "galsim.AngleUnit(%r)" % (ensure_hashable(self.value),)

def __eq__(self, other):
return isinstance(other, AngleUnit) and jnp.array_equal(self.value, other.value)
if not isinstance(other, AngleUnit):
return jnp.array(False)
else:
return jnp.array_equal(self.value, other.value)

def __ne__(self, other):
return not self.__eq__(other)
return ~self.__eq__(other)

def __hash__(self):
return hash(("galsim.AngleUnit", ensure_hashable(self.value)))
Expand Down Expand Up @@ -253,10 +256,13 @@ def __repr__(self):
return "galsim.Angle(%r, galsim.radians)" % (ensure_hashable(self.rad),)

def __eq__(self, other):
return isinstance(other, Angle) and jnp.array_equal(self.rad, other.rad)
if not isinstance(other, Angle):
return jnp.array(False)
else:
return jnp.array_equal(self.rad, other.rad)

def __ne__(self, other):
return not self.__eq__(other)
return ~self.__eq__(other)

def __le__(self, other):
if not isinstance(other, Angle):
Expand Down
Loading
Loading