From d6bbd80680a6249247136fd4e737e6de19b78eda Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 16 Apr 2026 13:11:54 -0500 Subject: [PATCH 01/26] test: update tests to branch to ensure they pass --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index cf0c4b19..7c66c94b 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit cf0c4b196dcbd789cc550b8fe7d0643d6e508db5 +Subproject commit 7c66c94b2ebeeb91edcc43c815ce4e9e5e01098c From 4b2f72313c8d77445165f1bdb930177273aec413 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Apr 2026 15:27:52 -0500 Subject: [PATCH 02/26] fix: use in-place ops properly --- jax_galsim/convolve.py | 20 +++--- jax_galsim/gsobject.py | 6 +- jax_galsim/image.py | 140 +++++++++++++++++++++++++------------ jax_galsim/noise.py | 44 ++++++------ jax_galsim/photon_array.py | 2 +- jax_galsim/wcs.py | 6 +- 6 files changed, 140 insertions(+), 78 deletions(-) diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 6961ad39..e6df0769 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -486,10 +486,12 @@ def _kValue(self, pos): def _drawKImage(self, image, jac=None): image = self.orig_obj._drawKImage(image, jac) - image._array = jnp.where( - jnp.abs(image.array) > self._min_acc_kvalue, - 1.0 / image.array, - self._inv_min_acc_kvalue, + image._array = image._array.at[...].set( + jnp.where( + jnp.abs(image.array) > self._min_acc_kvalue, + 1.0 / image.array, + self._inv_min_acc_kvalue, + ) ) kx, ky = image.get_pixel_centers() _jac = jnp.eye(2) if jac is None else jac @@ -500,10 +502,12 @@ def _drawKImage(self, image, jac=None): ) ksq = (kx**2 + ky**2) * image.scale**2 # Set to zero outside of nominal maxk so as not to amplify high frequencies. - image._array = jnp.where( - ksq > self.maxk**2, - 0.0, - image.array, + image._array = image._array.at[...].set( + jnp.where( + ksq > self.maxk**2, + 0.0, + image.array, + ) ) return image diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 64cc9ec7..915583f7 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -809,7 +809,7 @@ def drawReal(self, image, add_to_image=False): im1 = self._drawReal(image) temp = im1.subImage(image.bounds) if add_to_image: - image._array = image._array + temp._array + image._array = image._array.at[...].add(temp._array) else: image._array = temp._array @@ -929,7 +929,7 @@ def drawFFT_finish(self, image, kimage, wrap_size, add_to_image): # Add (a portion of) this to the original image. temp = real_image.subImage(image.bounds) if add_to_image: - image._array = image._array + temp._array + image._array = image._array.at[...].add(temp._array) else: image._array = temp._array @@ -1043,7 +1043,7 @@ def drawKImage( if not add_to_image: image._array = im2._array else: - image._array = im2._array + image._array + image._array = image._array.at[...].add(im2._array) image_in._array = image._array image_in._bounds = image._bounds diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 4f9acc4a..11f8bb37 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -191,16 +191,16 @@ def __init__(self, *args, **kwargs): else: self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow) if init_value: - self._array = self._array + init_value + self._array = self._array.at[...].add(init_value) elif bounds is not None: if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") self._array = self._make_empty(bounds.numpyShape(), dtype=self._dtype) self._bounds = bounds if init_value: - self._array = self._array + init_value + self._array = self._array.at[...].add(init_value) elif array is not None: - self._array = array.view(dtype=self._dtype) + self._array = self._array.at[...].set(array.view(dtype=self._dtype)) nrow, ncol = array.shape if not has_tracers(xmin) and not has_tracers(ymin): self._bounds = BoundsI( @@ -233,7 +233,7 @@ def __init__(self, *args, **kwargs): # e.g. im = ImageF(...) # im2 = ImageD(im) self._dtype = dtype - self._array = image.array.astype(self._dtype) + self._array = self._array.at[...].set(image.array.astype(self._dtype)) else: self._array = jnp.zeros(shape=(1, 1), dtype=self._dtype) self._bounds = BoundsI() @@ -357,6 +357,12 @@ def bounds(self): def array(self): return self._array + @array.setter + def array(self, other): + self._array = self._array.at[...].set( + _safe_cast(other, self.isinteger, self.array.dtype) + ) + @property @implements(_galsim.Image.nrow) def nrow(self): @@ -688,37 +694,43 @@ def _wrap(self, bounds, hermx, hermy, hermitian_wrap_size): if not hermx and not hermy: from jax_galsim.core.wrap_image import wrap_nonhermitian - self._array = wrap_nonhermitian( - self._array, - # zero indexed location of subimage - bounds.xmin - self.xmin, - bounds.ymin - self.ymin, - bounds.deltax, - bounds.deltay, + self._array = self._array.at[...].set( + wrap_nonhermitian( + self._array, + # zero indexed location of subimage + bounds.xmin - self.xmin, + bounds.ymin - self.ymin, + bounds.deltax, + bounds.deltay, + ) ) elif hermx and not hermy: from jax_galsim.core.wrap_image import wrap_hermitian_x - self._array = wrap_hermitian_x( - self._array, - -self.xmax, - self.ymin, - -bounds.xmax + 1, - bounds.ymin, - hermitian_wrap_size, - bounds.deltay, + self._array = self._array.at[...].set( + wrap_hermitian_x( + self._array, + -self.xmax, + self.ymin, + -bounds.xmax + 1, + bounds.ymin, + hermitian_wrap_size, + bounds.deltay, + ) ) elif not hermx and hermy: from jax_galsim.core.wrap_image import wrap_hermitian_y - self._array = wrap_hermitian_y( - self._array, - self.xmin, - -self.ymax, - bounds.xmin, - -bounds.ymax + 1, - bounds.deltax, - hermitian_wrap_size, + self._array = self._array.at[...].set( + wrap_hermitian_y( + self._array, + self.xmin, + -self.ymax, + bounds.xmin, + -bounds.ymax + 1, + bounds.deltax, + hermitian_wrap_size, + ) ) return self.subImage(bounds) @@ -776,8 +788,8 @@ def calculate_fft(self): ) # we shift the image before and after the FFT to match the layout of the modes # used by GalSim - out._array = jnp.fft.fftshift( - jnp.fft.rfft2(jnp.fft.fftshift(ximage.array)), axes=0 + out._array = out._array.at[...].set( + jnp.fft.fftshift(jnp.fft.rfft2(jnp.fft.fftshift(ximage.array)), axes=0) ) out *= dx * dx @@ -838,8 +850,8 @@ def calculate_inverse_fft(self): scale=dx, ) # we shift the image before and after the FFT to match the layout used by galsim - out_extra._array = jnp.fft.fftshift( - jnp.fft.irfft2(jnp.fft.fftshift(kimage.array, axes=0)) + out_extra._array = out_extra._array.at[...].set( + jnp.fft.fftshift(jnp.fft.irfft2(jnp.fft.fftshift(kimage.array, axes=0))) ) # Now cut off the bit we don't need. out = out_extra.subImage( @@ -879,7 +891,13 @@ def copyFrom(self, rhs): self_image=self, rhs=rhs, ) - self._array = rhs._array + self._copyFrom(rhs) + + def _copyFrom(self, rhs): + """Same as copyFrom, but no sanity checks.""" + self._array = self._array.at[...].set( + _safe_cast(rhs._array, self.isinteger, self.array.dtype) + ) @implements( _galsim.Image.view, @@ -1055,7 +1073,7 @@ def fill(self, value): @implements(_galsim.Image._fill) def _fill(self, value): - self._array = jnp.zeros_like(self._array) + value + self._array = self._array.at[...].set(value) @implements(_galsim.Image.setZero) def setZero(self): @@ -1075,9 +1093,15 @@ def invertSelf(self): @implements(_galsim.Image._invertSelf) def _invertSelf(self): - array = 1.0 / self._array - array = array.at[jnp.isinf(array)].set(0.0) - self._array = array.astype(self._array.dtype) + msk = self._array == 0 + safe_array = jnp.where( + msk, + 1.0, + self._array, + ) + self._array = self._array.at[...].set( + (jnp.where(msk, 0.0, 1.0 / safe_array)).astype(self._array.dtype) + ) @implements(_galsim.Image.replaceNegative) def replaceNegative(self, replace_value=0): @@ -1348,6 +1372,18 @@ def ImageCD(*args, **kwargs): # +def _safe_cast(array, target_isinteger, target_dtype): + # code snippet pulled from upstream GalSim and turned into a general purpose + # function + # + # Assign the given array to self.array, safely casting it to the required type. + # Most important is to make sure integer types round first before casting, since + # numpy's astype doesn't do any rounding. + if target_isinteger: + array = jnp.around(array) + return array.astype(target_dtype) + + # Define a utility function to be used by the arithmetic functions below def check_image_consistency(im1, im2, integer=False): if integer and not im1.isinteger: @@ -1381,7 +1417,9 @@ def Image_iadd(self, other): if dt == self.array.dtype: self._array = self.array.at[...].add(a) else: - self._array = self.array.at[...].set((self.array + a).astype(self.array.dtype)) + self._array = self.array.at[...].set( + _safe_cast(self.array + a, self.isinteger, self.array.dtype) + ) return self @@ -1409,7 +1447,9 @@ def Image_isub(self, other): if dt == self.array.dtype: self._array = self.array.at[...].subtract(a) else: - self._array = self.array.at[...].set((self.array - a).astype(self.array.dtype)) + self._array = self.array.at[...].set( + _safe_cast(self.array - a, self.isinteger, self.array.dtype) + ) return self @@ -1433,7 +1473,9 @@ def Image_imul(self, other): if dt == self.array.dtype: self._array = self.array.at[...].multiply(a) else: - self._array = self.array.at[...].set((self.array * a).astype(self.array.dtype)) + self._array = self.array.at[...].set( + _safe_cast(self.array * a, self.isinteger, self.array.dtype) + ) return self @@ -1463,7 +1505,9 @@ def Image_idiv(self, other): # back to an integer array. So for integers (or mixed types), don't use /=. self._array = self.array.at[...].divide(a) else: - self._array = self.array.at[...].set((self.array / a).astype(self.array.dtype)) + self._array = self.array.at[...].set( + _safe_cast(self.array / a, self.isinteger, self.array.dtype) + ) return self @@ -1492,7 +1536,9 @@ def Image_ifloordiv(self, other): if dt == self.array.dtype: self._array = self.array.at[...].set(self.array // a) else: - self._array = self.array.at[...].set((self.array // a).astype(self.array.dtype)) + self._array = self.array.at[...].set( + _safe_cast(self.array // a, self.isinteger, self.array.dtype) + ) return self @@ -1521,7 +1567,9 @@ def Image_imod(self, other): if dt == self.array.dtype: self._array = self.array.at[...].set(self.array % a) else: - self._array = self.array.at[...].set((self.array % a).astype(self.array.dtype)) + self._array = self.array.at[...].set( + _safe_cast(self.array % a, self.isinteger, self.array.dtype) + ) return self @@ -1532,7 +1580,13 @@ def Image_pow(self, other): def Image_ipow(self, other): if not isinstance(other, int) and not isinstance(other, float): raise TypeError("Can only raise an image to a float or int power!") - self._array = self.array.at[...].power(other) + + if not self.isinteger or isinstance(other, int): + self._array = self.array.at[...].power(other) + else: + self._array = self.array.at[...].set( + _safe_cast(self.array**other, self.isinteger, self.array.dtype) + ) return self diff --git a/jax_galsim/noise.py b/jax_galsim/noise.py index f49ffa50..3956cf6a 100644 --- a/jax_galsim/noise.py +++ b/jax_galsim/noise.py @@ -142,8 +142,8 @@ def sigma(self): return self._sigma def _applyTo(self, image): - image._array = (image._array + self._rng.generate(image._array)).astype( - image.dtype + image._array = image._array.at[...].add( + self._rng.generate(image._array).astype(image.dtype) ) def _getVariance(self): @@ -231,14 +231,15 @@ def _applyTo(self, image): frac_sky, ) # Noise array is now the correct value for each pixel. - image._array = noise_array.astype(image.dtype) - image._array = jax.lax.cond( - int_sky != 0.0, - lambda na, ints: (na - ints).astype(float), - lambda na, ints: na.astype(float), - image._array, - int_sky, - ).astype(image.dtype) + image._array = image._array.at[...].set( + jax.lax.cond( + int_sky != 0.0, + lambda na, ints: (na - ints).astype(float), + lambda na, ints: na.astype(float), + noise_array.astype(image.dtype), + int_sky, + ).astype(image.dtype) + ) def _getVariance(self): return self.sky_level @@ -362,14 +363,15 @@ def _applyTo(self, image): frac_sky, ) # Noise array is now the correct value for each pixel. - image._array = noise_array.astype(image.dtype) - image._array = jax.lax.cond( - int_sky != 0.0, - lambda na, ints: (na - ints).astype(float), - lambda na, ints: na.astype(float), - image._array, - int_sky, - ).astype(image.dtype) + image._array = image._array.at[...].set( + jax.lax.cond( + int_sky != 0.0, + lambda na, ints: (na - ints).astype(float), + lambda na, ints: na.astype(float), + noise_array.astype(image.dtype), + int_sky, + ).astype(image.dtype) + ) def _getVariance(self): return jax.lax.cond( @@ -454,8 +456,8 @@ def __init__(self, dev): super().__init__(dev) def _applyTo(self, image): - image._array = (image._array + self._rng.generate(image._array)).astype( - image.dtype + image._array = image._array.at[...].add( + (self._rng.generate(image._array)).astype(image.dtype) ) def _getVariance(self): @@ -533,7 +535,7 @@ def applyTo(self, image): def _applyTo(self, image): # jax galsim never fills an image so this is safe noise_array = self._rng.generate_from_variance(self.var_image.array) - image._array = image._array + noise_array.astype(image.dtype) + image._array = image._array.at[...].add(noise_array.astype(image.dtype)) @implements( _galsim.noise.VariableGaussianNoise.copy, diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index a0cade02..cfa90831 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -863,7 +863,7 @@ def addTo(self, image): image.bounds.ymin, image._array, ) - image._array = _arr + image._array = image.array.at[...].set(_arr) return _flux_sum diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index ee18ef33..5eeccae5 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -516,7 +516,7 @@ def _makeSkyImage(self, image, sky_level, color): dvdy = 0.5 * (v[2:ny, 1 : nx - 1] - v[0 : ny - 2, 1 : nx - 1]) area = jnp.abs(dudx * dvdy - dvdx * dudy) - image._array = (area * sky_level).astype(image.dtype) + image._array = image._array.at[...].set((area * sky_level).astype(image.dtype)) # Each class should define the __eq__ function. Then __ne__ is obvious. def __ne__(self, other): @@ -750,7 +750,9 @@ def _makeSkyImage(self, image, sky_level, color): area = jnp.abs(dudx * dvdy - dvdx * dudy) factor = radians / arcsec - image._array = area * sky_level * factor**2 + image._array = image._array.at[...].set( + (area * sky_level * factor**2).astype(image.dtype) + ) # Simple. Just call _radec. def _posToWorld(self, image_pos, color, project_center=None, projection="gnomonic"): From 6e9a10c584cd50514e2d4dc68b62e9897621931a Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Apr 2026 15:40:12 -0500 Subject: [PATCH 03/26] fix: buggy thing --- jax_galsim/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 11f8bb37..b957df13 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -200,7 +200,7 @@ def __init__(self, *args, **kwargs): if init_value: self._array = self._array.at[...].add(init_value) elif array is not None: - self._array = self._array.at[...].set(array.view(dtype=self._dtype)) + self._array = array.view(dtype=self._dtype) nrow, ncol = array.shape if not has_tracers(xmin) and not has_tracers(ymin): self._bounds = BoundsI( From 50f4d5e6e561f10cc723c37325d6eeb07ab7f4b9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Apr 2026 16:49:28 -0500 Subject: [PATCH 04/26] fix: correct FFT shape for inplace --- jax_galsim/image.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index b957df13..ff0028a0 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -233,7 +233,7 @@ def __init__(self, *args, **kwargs): # e.g. im = ImageF(...) # im2 = ImageD(im) self._dtype = dtype - self._array = self._array.at[...].set(image.array.astype(self._dtype)) + self._array = image.array.astype(self._dtype) else: self._array = jnp.zeros(shape=(1, 1), dtype=self._dtype) self._bounds = BoundsI() @@ -832,9 +832,9 @@ def calculate_inverse_fft(self): kimage = Image(full_bounds, dtype=self.dtype, init_value=0) posx_bounds = BoundsI( xmin=0, - xmax=self.bounds.xmax, + deltax=self.bounds.deltax, ymin=self.bounds.ymin, - ymax=self.bounds.ymax, + deltay=self.bounds.deltay, ) kimage[posx_bounds] = self[posx_bounds] kimage = kimage._wrap(target_bounds, True, False, 2 * No2) @@ -843,20 +843,22 @@ def calculate_inverse_fft(self): # dx = 2pi / (N dk) dx = jnp.pi / (No2 * dk) - # For the inverse, we need a bit of extra space for the fft. - out_extra = Image( - BoundsI(xmin=-No2, deltax=2 * No2 + 2, ymin=-No2, deltay=2 * No2), + # In GalSim, they use inplace FFTW transforms which require the + # array that holds the input/output to have extra padding on the + # x dimension. + # jax-galsim does not need the padding since it does not use an + # inplace FFT. Thus we do not use the + # padding. + + out = Image( + BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2), dtype=float, scale=dx, ) # we shift the image before and after the FFT to match the layout used by galsim - out_extra._array = out_extra._array.at[...].set( + out._array = out._array.at[...].set( jnp.fft.fftshift(jnp.fft.irfft2(jnp.fft.fftshift(kimage.array, axes=0))) ) - # Now cut off the bit we don't need. - out = out_extra.subImage( - BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2) - ) out *= (dk * No2 / jnp.pi) ** 2 out.setCenter(0, 0) return out From aa9b46fb3dc27a04c1746a03a31a1b70aa50895b Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Apr 2026 16:52:38 -0500 Subject: [PATCH 05/26] fix: avoid multiple extra calls to set things inplace --- jax_galsim/image.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index ff0028a0..4acffc10 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -851,15 +851,15 @@ def calculate_inverse_fft(self): # padding. out = Image( - BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2), + bounds=BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2), dtype=float, scale=dx, + # we shift the image before and after the FFT to match the layout used by galsim + array=jnp.fft.fftshift( + jnp.fft.irfft2(jnp.fft.fftshift(kimage.array, axes=0)) + ) + * (dk * No2 / jnp.pi) ** 2, ) - # we shift the image before and after the FFT to match the layout used by galsim - out._array = out._array.at[...].set( - jnp.fft.fftshift(jnp.fft.irfft2(jnp.fft.fftshift(kimage.array, axes=0))) - ) - out *= (dk * No2 / jnp.pi) ** 2 out.setCenter(0, 0) return out From 579ceba153f09e58190ccfe9fc88518111cfd663 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Apr 2026 17:02:46 -0500 Subject: [PATCH 06/26] test: fix more test warnings --- jax_galsim/__init__.py | 2 +- jax_galsim/errors.py | 1 + jax_galsim/image.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index 8e586d32..a90a5b6c 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -8,7 +8,7 @@ from .errors import GalSimKeyError, GalSimIndexError, GalSimNotImplementedError from .errors import GalSimBoundsError, GalSimUndefinedBoundsError, GalSimImmutableError from .errors import GalSimIncompatibleValuesError, GalSimSEDError, GalSimHSMError -from .errors import GalSimFFTSizeError +from .errors import GalSimFFTSizeError, GalSimFFTSizeWarning from .errors import GalSimConfigError, GalSimConfigValueError from .errors import GalSimWarning, GalSimDeprecationWarning diff --git a/jax_galsim/errors.py b/jax_galsim/errors.py index f08e8fb1..5f98e67a 100644 --- a/jax_galsim/errors.py +++ b/jax_galsim/errors.py @@ -5,6 +5,7 @@ GalSimDeprecationWarning, GalSimError, GalSimFFTSizeError, + GalSimFFTSizeWarning, GalSimHSMError, GalSimImmutableError, GalSimIncompatibleValuesError, diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 4acffc10..16d032bf 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -832,9 +832,9 @@ def calculate_inverse_fft(self): kimage = Image(full_bounds, dtype=self.dtype, init_value=0) posx_bounds = BoundsI( xmin=0, - deltax=self.bounds.deltax, + xmax=self.bounds.xmax, ymin=self.bounds.ymin, - deltay=self.bounds.deltay, + ymax=self.bounds.ymax, ) kimage[posx_bounds] = self[posx_bounds] kimage = kimage._wrap(target_bounds, True, False, 2 * No2) From 6a246a97ed4be53794c3c0bd6659f5d2192c6937 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Apr 2026 17:12:19 -0500 Subject: [PATCH 07/26] test: fix more test warnings --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 7c66c94b..fc76358a 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 7c66c94b2ebeeb91edcc43c815ce4e9e5e01098c +Subproject commit fc76358a40c70d08ee1fa5e7c7bfd723758a6cd5 From 31bb208096687503dd5ffaa676979ffea7b51422 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Apr 2026 17:32:29 -0500 Subject: [PATCH 08/26] test: latest commit --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index fc76358a..5a343559 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit fc76358a40c70d08ee1fa5e7c7bfd723758a6cd5 +Subproject commit 5a3435598344e41a1d8f542b64b43c661b9fc2b5 From f993804ef2da08a38ad85627469cd5ec231e0fbd Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Apr 2026 17:51:44 -0500 Subject: [PATCH 09/26] fix: tests for implements maybe --- tests/jax/test_implements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_implements.py b/tests/jax/test_implements.py index ab01887a..afe27990 100644 --- a/tests/jax/test_implements.py +++ b/tests/jax/test_implements.py @@ -37,7 +37,7 @@ def test_implements(): p = _parse_galsimdoc(docstring) assert p.signature == "" - assert p.summary == "The summary is\n cool." + assert p.summary == "The summary is\ncool." assert "This is front matter." in p.front_matter assert "LAX" not in p.front_matter assert p.sections == {} @@ -45,7 +45,7 @@ def test_implements(): docstring = LAXTestImplements.__doc__ p = _parse_galsimdoc(docstring) assert p.signature == "" - assert p.summary == "The summary is\n cool." + assert p.summary == "The summary is\ncool." assert "This is front matter." in p.front_matter assert "LAX" in p.front_matter assert p.sections == {} From fb173500113c8c1a38de3bf8ec9128f4fd6c1147 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 17 Apr 2026 18:04:00 -0500 Subject: [PATCH 10/26] test: make tests robust across versions --- tests/jax/test_implements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_implements.py b/tests/jax/test_implements.py index afe27990..24d7ac13 100644 --- a/tests/jax/test_implements.py +++ b/tests/jax/test_implements.py @@ -37,7 +37,7 @@ def test_implements(): p = _parse_galsimdoc(docstring) assert p.signature == "" - assert p.summary == "The summary is\ncool." + assert p.summary in ["The summary is\ncool.", "The summary is\n cool."] assert "This is front matter." in p.front_matter assert "LAX" not in p.front_matter assert p.sections == {} @@ -45,7 +45,7 @@ def test_implements(): docstring = LAXTestImplements.__doc__ p = _parse_galsimdoc(docstring) assert p.signature == "" - assert p.summary == "The summary is\ncool." + assert p.summary in ["The summary is\ncool.", "The summary is\n cool."] assert "This is front matter." in p.front_matter assert "LAX" in p.front_matter assert p.sections == {} From d6624a4571f9b2c063f7338fd52f5f0592c7c823 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Apr 2026 06:09:55 -0500 Subject: [PATCH 11/26] fix: ensure implements removes leading spaces in older pythons --- jax_galsim/core/utils.py | 13 +++++++++++-- tests/jax/test_implements.py | 4 ++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index c4cc5413..4247f465 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -296,7 +296,7 @@ class ParsedDoc(NamedTuple): sections: dict[str, str] = {} -def _break_off_body_section_by_newline(body): +def _break_off_body_section_by_newline(body, double_check_first_indent=False): first_lines = [] body_lines = [] found_first_break = False @@ -314,7 +314,14 @@ def _break_off_body_section_by_newline(body): else: first_lines.append(line) + if double_check_first_indent and len(first_lines) > 1: + len_first_indent = len(first_lines[1]) - len(first_lines[1].lstrip()) + if len_first_indent > 0: + first_indent = first_lines[1][:len_first_indent] + first_lines[0] = first_indent + first_lines[0].lstrip() + firstline = "\n".join(first_lines) + firstline = textwrap.dedent(firstline) body = "\n".join(body_lines) body = textwrap.dedent(body.lstrip("\n")) @@ -337,7 +344,9 @@ def _parse_galsimdoc(docstr): signature, body = "", docstr - firstline, body = _break_off_body_section_by_newline(body) + firstline, body = _break_off_body_section_by_newline( + body, double_check_first_indent=True + ) summary = firstline if not summary: diff --git a/tests/jax/test_implements.py b/tests/jax/test_implements.py index 24d7ac13..afe27990 100644 --- a/tests/jax/test_implements.py +++ b/tests/jax/test_implements.py @@ -37,7 +37,7 @@ def test_implements(): p = _parse_galsimdoc(docstring) assert p.signature == "" - assert p.summary in ["The summary is\ncool.", "The summary is\n cool."] + assert p.summary == "The summary is\ncool." assert "This is front matter." in p.front_matter assert "LAX" not in p.front_matter assert p.sections == {} @@ -45,7 +45,7 @@ def test_implements(): docstring = LAXTestImplements.__doc__ p = _parse_galsimdoc(docstring) assert p.signature == "" - assert p.summary in ["The summary is\ncool.", "The summary is\n cool."] + assert p.summary == "The summary is\ncool." assert "This is front matter." in p.front_matter assert "LAX" in p.front_matter assert p.sections == {} From 585ecce238ba4ee3f11bf581a8d34be2b829a5f2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Apr 2026 06:27:59 -0500 Subject: [PATCH 12/26] fix: make sure coeffs is cast to an array --- jax_galsim/image.py | 5 +++-- jax_galsim/utilities.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 16d032bf..4995ecbd 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -315,7 +315,7 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): def __repr__(self): s = "galsim.Image(bounds=%r" % self.bounds if self.bounds.isDefined(): - s += ", array=\n%r" % np.array(self.array) + s += ", array=\n%r" % (ensure_hashable(np.array(self.array)),) s += ", wcs=%r" % self.wcs if self.isconst: s += ", make_const=True" @@ -943,7 +943,8 @@ def view( if dtype != self.array.dtype: array = self.array.astype(dtype) elif contiguous: - array = np.ascontiguousarray(self.array) + # this is a noop since all jax arrays are contiguous + pass else: array = self.array diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index 374a6a16..8f575c69 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -144,7 +144,7 @@ def unweighted_shape(arg): @functools.partial(jax.jit, static_argnames=("dtype",)) def horner(x, coef, dtype=None): x = jnp.array(x) - coef = jnp.atleast_1d(coef) + coef = jnp.atleast_1d(jnp.asarray(coef)) if dtype is None: res_dtype = jnp.result_type(x, coef) else: @@ -172,7 +172,7 @@ def horner(x, coef, dtype=None): def horner2d(x, y, coefs, dtype=None, triangle=False): x = jnp.array(x) y = jnp.array(y) - coefs = jnp.atleast_1d(coefs) + coefs = jnp.atleast_1d(jnp.asarray(coefs)) if dtype is None: res_dtype = jnp.result_type(x, coefs) else: From 23ea81bcf28330f98ae8c7991b1748d1ca463a54 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Apr 2026 06:30:24 -0500 Subject: [PATCH 13/26] debug: stop on first failure for now --- .github/workflows/python_package.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index c1abea64..1c53409d 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -64,7 +64,7 @@ jobs: - name: Test with pytest run: | pytest \ - -vv \ + -vvsx \ --durations=100 \ --randomly-seed=42 \ --splits ${NUM_SPLITS} --group ${{ matrix.group }} \ From e55dcfa157820304fc9576762a72441dbba2c2c1 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Apr 2026 06:46:56 -0500 Subject: [PATCH 14/26] test: use latest testing branch --- .github/workflows/python_package.yaml | 2 +- tests/GalSim | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index 1c53409d..ca933397 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -64,7 +64,7 @@ jobs: - name: Test with pytest run: | pytest \ - -vvsx \ + -vvx \ --durations=100 \ --randomly-seed=42 \ --splits ${NUM_SPLITS} --group ${{ matrix.group }} \ diff --git a/tests/GalSim b/tests/GalSim index 5a343559..89f2aad6 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 5a3435598344e41a1d8f542b64b43c661b9fc2b5 +Subproject commit 89f2aad6faf0c5fb6f034b1b91824199103d3e77 From 3a425978c901c379c85c04c2573d78670e95a3af Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Apr 2026 07:06:49 -0500 Subject: [PATCH 15/26] fix: ensure we cast to native byte order --- jax_galsim/core/utils.py | 11 +++++++++++ jax_galsim/photon_array.py | 40 +++++++++++++++++++++++++++----------- tests/jax/test_utils.py | 20 +++++++++++++++++++ 3 files changed, 60 insertions(+), 11 deletions(-) create mode 100644 tests/jax/test_utils.py diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 4247f465..5c49cdb2 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -9,6 +9,17 @@ from jax.tree_util import tree_flatten +def cast_numpy_array_to_native_byte_order(arr): + """Cast an array to native byte order.""" + if not isinstance(arr, np.ndarray): + return arr + + if arr.dtype.isnative: + return arr + + return arr.byteswap().view(arr.dtype.newbyteorder("=")) + + def has_tracers(x): """Return True if the input item is a JAX tracer or object, False otherwise.""" for item in tree_flatten(x)[0]: diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index cfa90831..f6fffe51 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -6,7 +6,11 @@ import jax.random as jrng from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import cast_to_python_int, implements +from jax_galsim.core.utils import ( + cast_numpy_array_to_native_byte_order, + cast_to_python_int, + implements, +) from jax_galsim.errors import ( GalSimIncompatibleValuesError, GalSimRangeError, @@ -952,21 +956,35 @@ def read(cls, file_name): photons = cls( N, - x=jnp.array(data["x"]), - y=jnp.array(data["y"]), - flux=jnp.array(data["flux"]), + x=jnp.array(cast_numpy_array_to_native_byte_order(data["x"])), + y=jnp.array(cast_numpy_array_to_native_byte_order(data["y"])), + flux=jnp.array(cast_numpy_array_to_native_byte_order(data["flux"])), + ) + photons._nokeep = jnp.array( + cast_numpy_array_to_native_byte_order(data["_nokeep"]) ) - photons._nokeep = jnp.array(data["_nokeep"]) if "dxdz" in names: - photons.dxdz = jnp.array(data["dxdz"]) - photons.dydz = jnp.array(data["dydz"]) + photons.dxdz = jnp.array( + cast_numpy_array_to_native_byte_order(data["dxdz"]) + ) + photons.dydz = jnp.array( + cast_numpy_array_to_native_byte_order(data["dydz"]) + ) if "wavelength" in names: - photons.wavelength = jnp.array(data["wavelength"]) + photons.wavelength = jnp.array( + cast_numpy_array_to_native_byte_order(data["wavelength"]) + ) if "pupil_u" in names: - photons.pupil_u = jnp.array(data["pupil_u"]) - photons.pupil_v = jnp.array(data["pupil_v"]) + photons.pupil_u = jnp.array( + cast_numpy_array_to_native_byte_order(data["pupil_u"]) + ) + photons.pupil_v = jnp.array( + cast_numpy_array_to_native_byte_order(data["pupil_v"]) + ) if "time" in names: - photons.time = jnp.array(data["time"]) + photons.time = jnp.array( + cast_numpy_array_to_native_byte_order(data["time"]) + ) return photons diff --git a/tests/jax/test_utils.py b/tests/jax/test_utils.py new file mode 100644 index 00000000..4220f67d --- /dev/null +++ b/tests/jax/test_utils.py @@ -0,0 +1,20 @@ +import jax.numpy as jnp +import numpy as np +import pytest + +from jax_galsim.core.utils import cast_numpy_array_to_native_byte_order + + +@pytest.mark.parametrize( + "arr", + [ + np.arange(10), + np.arange(10, dtype=" Date: Sat, 18 Apr 2026 07:08:31 -0500 Subject: [PATCH 16/26] test: use latest testing code --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 89f2aad6..191e8786 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 89f2aad6faf0c5fb6f034b1b91824199103d3e77 +Subproject commit 191e8786ea40515ab466248c0f607ee846eee085 From 7f85cb029048e02d6b47702b14ed3c68c92b633d Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Apr 2026 07:14:16 -0500 Subject: [PATCH 17/26] fix: more native byte order casts --- jax_galsim/image.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 4995ecbd..efce0e10 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -5,7 +5,12 @@ from jax.tree_util import register_pytree_node_class from jax_galsim.bounds import Bounds, BoundsD, BoundsI -from jax_galsim.core.utils import ensure_hashable, has_tracers, implements +from jax_galsim.core.utils import ( + cast_numpy_array_to_native_byte_order, + ensure_hashable, + has_tracers, + implements, +) from jax_galsim.errors import GalSimImmutableError from jax_galsim.position import PositionI from jax_galsim.utilities import parse_pos_args @@ -75,7 +80,7 @@ def __init__(self, *args, **kwargs): ymin = kwargs.pop("ymin", 1) elif len(args) == 1: if isinstance(args[0], np.ndarray): - array = jnp.array(args[0]) + array = jnp.array(cast_numpy_array_to_native_byte_order(args[0])) array, xmin, ymin = self._get_xmin_ymin( array, kwargs, check_bounds=_check_bounds ) @@ -1237,7 +1242,9 @@ def from_galsim(cls, galsim_image): else None ) im = cls( - array=jnp.asarray(galsim_image.array), + array=jnp.asarray( + cast_numpy_array_to_native_byte_order(galsim_image.array) + ), wcs=wcs, bounds=Bounds.from_galsim(galsim_image.bounds), ) From 067cc85c26c4b63b4251c912e62b5f559e1441b2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Apr 2026 07:15:54 -0500 Subject: [PATCH 18/26] fix: make a copy --- jax_galsim/core/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 5c49cdb2..eac1cf09 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -17,7 +17,7 @@ def cast_numpy_array_to_native_byte_order(arr): if arr.dtype.isnative: return arr - return arr.byteswap().view(arr.dtype.newbyteorder("=")) + return arr.astype(arr.dtype.newbyteorder('=')) def has_tracers(x): From c07ba67be34e9ec6a165ce333734eb8426c2d01f Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Apr 2026 07:16:40 -0500 Subject: [PATCH 19/26] style: pre-commit --- jax_galsim/core/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index eac1cf09..3fcbf46d 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -17,7 +17,7 @@ def cast_numpy_array_to_native_byte_order(arr): if arr.dtype.isnative: return arr - return arr.astype(arr.dtype.newbyteorder('=')) + return arr.astype(arr.dtype.newbyteorder("=")) def has_tracers(x): From 2812c6760cfc55cd32a38a1350a8785b326d9df3 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Sat, 18 Apr 2026 07:34:23 -0500 Subject: [PATCH 20/26] Apply suggestion from @beckermr --- .github/workflows/python_package.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index ca933397..c1abea64 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -64,7 +64,7 @@ jobs: - name: Test with pytest run: | pytest \ - -vvx \ + -vv \ --durations=100 \ --randomly-seed=42 \ --splits ${NUM_SPLITS} --group ${{ matrix.group }} \ From 8664851ea556ecc4d128ef9a53a5ee2a4c86be8b Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Apr 2026 07:38:45 -0500 Subject: [PATCH 21/26] test: try a new test that might be more robust --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 191e8786..07b6344a 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 191e8786ea40515ab466248c0f607ee846eee085 +Subproject commit 07b6344a8ac44b189eeb75591b606b539928e695 From ea1f9443ae21cfbdc6c977f53f271506216f5809 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Apr 2026 07:40:40 -0500 Subject: [PATCH 22/26] test: run it all via split --- .github/workflows/python_package.yaml | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index c1abea64..d56025de 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -75,10 +75,17 @@ jobs: --retries 1 - name: Test with pytest in float32 - if: ${{ matrix.group == '1' }} run: | pytest \ -vv \ + --durations=100 \ + --randomly-seed=42 \ + --splits ${NUM_SPLITS} --group ${{ matrix.group }} \ + --store-durations \ + --durations-path=.test_durations.${{ matrix.group }} \ + --splitting-algorithm least_duration \ + --clean-durations \ + --retries 1 \ --test-in-float32 - name: Upload test durations From 71ebad17b8bc0d8725a6d48a91935125e3a5169f Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Apr 2026 07:45:48 -0500 Subject: [PATCH 23/26] test: try this test --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 07b6344a..61c6ce19 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 07b6344a8ac44b189eeb75591b606b539928e695 +Subproject commit 61c6ce193569d78d8575cb149800dc6cf423729a From dd3e982450419e17c465fc35f7c6d10ae824deed Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Apr 2026 08:30:02 -0500 Subject: [PATCH 24/26] test: use latest changes --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 61c6ce19..bcc1ee10 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 61c6ce193569d78d8575cb149800dc6cf423729a +Subproject commit bcc1ee102fd68ea5e9d5f635b79e5939b9ec3be2 From 302e110aca96a76943a11db80810ffb36b378da9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Apr 2026 08:45:32 -0500 Subject: [PATCH 25/26] fix: do not store durations for float32 tests --- .github/workflows/python_package.yaml | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index d56025de..c2e3de13 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -61,20 +61,18 @@ jobs: cat .test_durations* fi - - name: Test with pytest + - name: Test with pytest in float32 run: | pytest \ -vv \ --durations=100 \ --randomly-seed=42 \ --splits ${NUM_SPLITS} --group ${{ matrix.group }} \ - --store-durations \ - --durations-path=.test_durations.${{ matrix.group }} \ --splitting-algorithm least_duration \ - --clean-durations \ - --retries 1 + --retries 1 \ + --test-in-float32 - - name: Test with pytest in float32 + - name: Test with pytest run: | pytest \ -vv \ @@ -85,8 +83,7 @@ jobs: --durations-path=.test_durations.${{ matrix.group }} \ --splitting-algorithm least_duration \ --clean-durations \ - --retries 1 \ - --test-in-float32 + --retries 1 - name: Upload test durations uses: actions/upload-artifact@v7 From 11e59bbe0a472695ac0e25da8c0d42dab352b44c Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 18 Apr 2026 08:48:25 -0500 Subject: [PATCH 26/26] test: update tests submodule --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index bcc1ee10..1a490c3b 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit bcc1ee102fd68ea5e9d5f635b79e5939b9ec3be2 +Subproject commit 1a490c3b558fddf2cab1fc0e6d449b73fa3b4eda