Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d6bbd80
test: update tests to branch to ensure they pass
beckermr Apr 16, 2026
4b2f723
fix: use in-place ops properly
beckermr Apr 17, 2026
6e9a10c
fix: buggy thing
beckermr Apr 17, 2026
50f4d5e
fix: correct FFT shape for inplace
beckermr Apr 17, 2026
aa9b46f
fix: avoid multiple extra calls to set things inplace
beckermr Apr 17, 2026
579ceba
test: fix more test warnings
beckermr Apr 17, 2026
6a246a9
test: fix more test warnings
beckermr Apr 17, 2026
31bb208
test: latest commit
beckermr Apr 17, 2026
f993804
fix: tests for implements maybe
beckermr Apr 17, 2026
fb17350
test: make tests robust across versions
beckermr Apr 17, 2026
d6624a4
fix: ensure implements removes leading spaces in older pythons
beckermr Apr 18, 2026
585ecce
fix: make sure coeffs is cast to an array
beckermr Apr 18, 2026
23ea81b
debug: stop on first failure for now
beckermr Apr 18, 2026
e55dcfa
test: use latest testing branch
beckermr Apr 18, 2026
3a42597
fix: ensure we cast to native byte order
beckermr Apr 18, 2026
b27a12a
test: use latest testing code
beckermr Apr 18, 2026
7f85cb0
fix: more native byte order casts
beckermr Apr 18, 2026
067cc85
fix: make a copy
beckermr Apr 18, 2026
c07ba67
style: pre-commit
beckermr Apr 18, 2026
2812c67
Apply suggestion from @beckermr
beckermr Apr 18, 2026
8664851
test: try a new test that might be more robust
beckermr Apr 18, 2026
0b01023
Merge branch 'update-tests-galsim-2026-04-18' of https://github.com/G…
beckermr Apr 18, 2026
ea1f944
test: run it all via split
beckermr Apr 18, 2026
71ebad1
test: try this test
beckermr Apr 18, 2026
dd3e982
test: use latest changes
beckermr Apr 18, 2026
302e110
fix: do not store durations for float32 tests
beckermr Apr 18, 2026
11e59bb
test: update tests submodule
beckermr Apr 18, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions .github/workflows/python_package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ jobs:
cat .test_durations*
fi

- name: Test with pytest in float32
run: |
pytest \
-vv \
--durations=100 \
--randomly-seed=42 \
--splits ${NUM_SPLITS} --group ${{ matrix.group }} \
--splitting-algorithm least_duration \
--retries 1 \
--test-in-float32

- name: Test with pytest
run: |
pytest \
Expand All @@ -74,13 +85,6 @@ jobs:
--clean-durations \
--retries 1

- name: Test with pytest in float32
if: ${{ matrix.group == '1' }}
run: |
pytest \
-vv \
--test-in-float32

- name: Upload test durations
uses: actions/upload-artifact@v7
with:
Expand Down
2 changes: 1 addition & 1 deletion jax_galsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .errors import GalSimKeyError, GalSimIndexError, GalSimNotImplementedError
from .errors import GalSimBoundsError, GalSimUndefinedBoundsError, GalSimImmutableError
from .errors import GalSimIncompatibleValuesError, GalSimSEDError, GalSimHSMError
from .errors import GalSimFFTSizeError
from .errors import GalSimFFTSizeError, GalSimFFTSizeWarning
from .errors import GalSimConfigError, GalSimConfigValueError
from .errors import GalSimWarning, GalSimDeprecationWarning

Expand Down
20 changes: 12 additions & 8 deletions jax_galsim/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,10 +486,12 @@ def _kValue(self, pos):

def _drawKImage(self, image, jac=None):
image = self.orig_obj._drawKImage(image, jac)
image._array = jnp.where(
jnp.abs(image.array) > self._min_acc_kvalue,
1.0 / image.array,
self._inv_min_acc_kvalue,
image._array = image._array.at[...].set(
jnp.where(
jnp.abs(image.array) > self._min_acc_kvalue,
1.0 / image.array,
self._inv_min_acc_kvalue,
)
)
kx, ky = image.get_pixel_centers()
_jac = jnp.eye(2) if jac is None else jac
Expand All @@ -500,10 +502,12 @@ def _drawKImage(self, image, jac=None):
)
ksq = (kx**2 + ky**2) * image.scale**2
# Set to zero outside of nominal maxk so as not to amplify high frequencies.
image._array = jnp.where(
ksq > self.maxk**2,
0.0,
image.array,
image._array = image._array.at[...].set(
jnp.where(
ksq > self.maxk**2,
0.0,
image.array,
)
)
return image

Expand Down
24 changes: 22 additions & 2 deletions jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@
from jax.tree_util import tree_flatten


def cast_numpy_array_to_native_byte_order(arr):
"""Cast an array to native byte order."""
if not isinstance(arr, np.ndarray):
return arr

if arr.dtype.isnative:
return arr

return arr.astype(arr.dtype.newbyteorder("="))


def has_tracers(x):
"""Return True if the input item is a JAX tracer or object, False otherwise."""
for item in tree_flatten(x)[0]:
Expand Down Expand Up @@ -296,7 +307,7 @@ class ParsedDoc(NamedTuple):
sections: dict[str, str] = {}


def _break_off_body_section_by_newline(body):
def _break_off_body_section_by_newline(body, double_check_first_indent=False):
first_lines = []
body_lines = []
found_first_break = False
Expand All @@ -314,7 +325,14 @@ def _break_off_body_section_by_newline(body):
else:
first_lines.append(line)

if double_check_first_indent and len(first_lines) > 1:
len_first_indent = len(first_lines[1]) - len(first_lines[1].lstrip())
if len_first_indent > 0:
first_indent = first_lines[1][:len_first_indent]
first_lines[0] = first_indent + first_lines[0].lstrip()

firstline = "\n".join(first_lines)
firstline = textwrap.dedent(firstline)
body = "\n".join(body_lines)
body = textwrap.dedent(body.lstrip("\n"))

Expand All @@ -337,7 +355,9 @@ def _parse_galsimdoc(docstr):

signature, body = "", docstr

firstline, body = _break_off_body_section_by_newline(body)
firstline, body = _break_off_body_section_by_newline(
body, double_check_first_indent=True
)

summary = firstline
if not summary:
Expand Down
1 change: 1 addition & 0 deletions jax_galsim/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
GalSimDeprecationWarning,
GalSimError,
GalSimFFTSizeError,
GalSimFFTSizeWarning,
GalSimHSMError,
GalSimImmutableError,
GalSimIncompatibleValuesError,
Expand Down
6 changes: 3 additions & 3 deletions jax_galsim/gsobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ def drawReal(self, image, add_to_image=False):
im1 = self._drawReal(image)
temp = im1.subImage(image.bounds)
if add_to_image:
image._array = image._array + temp._array
image._array = image._array.at[...].add(temp._array)
else:
image._array = temp._array

Expand Down Expand Up @@ -929,7 +929,7 @@ def drawFFT_finish(self, image, kimage, wrap_size, add_to_image):
# Add (a portion of) this to the original image.
temp = real_image.subImage(image.bounds)
if add_to_image:
image._array = image._array + temp._array
image._array = image._array.at[...].add(temp._array)
else:
image._array = temp._array

Expand Down Expand Up @@ -1043,7 +1043,7 @@ def drawKImage(
if not add_to_image:
image._array = im2._array
else:
image._array = im2._array + image._array
image._array = image._array.at[...].add(im2._array)

image_in._array = image._array
image_in._bounds = image._bounds
Expand Down
Loading