diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index c1abea64..c2e3de13 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -61,6 +61,17 @@ jobs: cat .test_durations* fi + - name: Test with pytest in float32 + run: | + pytest \ + -vv \ + --durations=100 \ + --randomly-seed=42 \ + --splits ${NUM_SPLITS} --group ${{ matrix.group }} \ + --splitting-algorithm least_duration \ + --retries 1 \ + --test-in-float32 + - name: Test with pytest run: | pytest \ @@ -74,13 +85,6 @@ jobs: --clean-durations \ --retries 1 - - name: Test with pytest in float32 - if: ${{ matrix.group == '1' }} - run: | - pytest \ - -vv \ - --test-in-float32 - - name: Upload test durations uses: actions/upload-artifact@v7 with: diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index 8e586d32..a90a5b6c 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -8,7 +8,7 @@ from .errors import GalSimKeyError, GalSimIndexError, GalSimNotImplementedError from .errors import GalSimBoundsError, GalSimUndefinedBoundsError, GalSimImmutableError from .errors import GalSimIncompatibleValuesError, GalSimSEDError, GalSimHSMError -from .errors import GalSimFFTSizeError +from .errors import GalSimFFTSizeError, GalSimFFTSizeWarning from .errors import GalSimConfigError, GalSimConfigValueError from .errors import GalSimWarning, GalSimDeprecationWarning diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 6961ad39..e6df0769 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -486,10 +486,12 @@ def _kValue(self, pos): def _drawKImage(self, image, jac=None): image = self.orig_obj._drawKImage(image, jac) - image._array = jnp.where( - jnp.abs(image.array) > self._min_acc_kvalue, - 1.0 / image.array, - self._inv_min_acc_kvalue, + image._array = image._array.at[...].set( + jnp.where( + jnp.abs(image.array) > self._min_acc_kvalue, + 1.0 / image.array, + self._inv_min_acc_kvalue, + ) ) kx, ky = image.get_pixel_centers() _jac = jnp.eye(2) if jac is None else jac @@ -500,10 +502,12 @@ def _drawKImage(self, image, jac=None): ) ksq = (kx**2 + ky**2) * image.scale**2 # Set to zero outside of nominal maxk so as not to amplify high frequencies. - image._array = jnp.where( - ksq > self.maxk**2, - 0.0, - image.array, + image._array = image._array.at[...].set( + jnp.where( + ksq > self.maxk**2, + 0.0, + image.array, + ) ) return image diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index c4cc5413..3fcbf46d 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -9,6 +9,17 @@ from jax.tree_util import tree_flatten +def cast_numpy_array_to_native_byte_order(arr): + """Cast an array to native byte order.""" + if not isinstance(arr, np.ndarray): + return arr + + if arr.dtype.isnative: + return arr + + return arr.astype(arr.dtype.newbyteorder("=")) + + def has_tracers(x): """Return True if the input item is a JAX tracer or object, False otherwise.""" for item in tree_flatten(x)[0]: @@ -296,7 +307,7 @@ class ParsedDoc(NamedTuple): sections: dict[str, str] = {} -def _break_off_body_section_by_newline(body): +def _break_off_body_section_by_newline(body, double_check_first_indent=False): first_lines = [] body_lines = [] found_first_break = False @@ -314,7 +325,14 @@ def _break_off_body_section_by_newline(body): else: first_lines.append(line) + if double_check_first_indent and len(first_lines) > 1: + len_first_indent = len(first_lines[1]) - len(first_lines[1].lstrip()) + if len_first_indent > 0: + first_indent = first_lines[1][:len_first_indent] + first_lines[0] = first_indent + first_lines[0].lstrip() + firstline = "\n".join(first_lines) + firstline = textwrap.dedent(firstline) body = "\n".join(body_lines) body = textwrap.dedent(body.lstrip("\n")) @@ -337,7 +355,9 @@ def _parse_galsimdoc(docstr): signature, body = "", docstr - firstline, body = _break_off_body_section_by_newline(body) + firstline, body = _break_off_body_section_by_newline( + body, double_check_first_indent=True + ) summary = firstline if not summary: diff --git a/jax_galsim/errors.py b/jax_galsim/errors.py index f08e8fb1..5f98e67a 100644 --- a/jax_galsim/errors.py +++ b/jax_galsim/errors.py @@ -5,6 +5,7 @@ GalSimDeprecationWarning, GalSimError, GalSimFFTSizeError, + GalSimFFTSizeWarning, GalSimHSMError, GalSimImmutableError, GalSimIncompatibleValuesError, diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 64cc9ec7..915583f7 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -809,7 +809,7 @@ def drawReal(self, image, add_to_image=False): im1 = self._drawReal(image) temp = im1.subImage(image.bounds) if add_to_image: - image._array = image._array + temp._array + image._array = image._array.at[...].add(temp._array) else: image._array = temp._array @@ -929,7 +929,7 @@ def drawFFT_finish(self, image, kimage, wrap_size, add_to_image): # Add (a portion of) this to the original image. temp = real_image.subImage(image.bounds) if add_to_image: - image._array = image._array + temp._array + image._array = image._array.at[...].add(temp._array) else: image._array = temp._array @@ -1043,7 +1043,7 @@ def drawKImage( if not add_to_image: image._array = im2._array else: - image._array = im2._array + image._array + image._array = image._array.at[...].add(im2._array) image_in._array = image._array image_in._bounds = image._bounds diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 4f9acc4a..efce0e10 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -5,7 +5,12 @@ from jax.tree_util import register_pytree_node_class from jax_galsim.bounds import Bounds, BoundsD, BoundsI -from jax_galsim.core.utils import ensure_hashable, has_tracers, implements +from jax_galsim.core.utils import ( + cast_numpy_array_to_native_byte_order, + ensure_hashable, + has_tracers, + implements, +) from jax_galsim.errors import GalSimImmutableError from jax_galsim.position import PositionI from jax_galsim.utilities import parse_pos_args @@ -75,7 +80,7 @@ def __init__(self, *args, **kwargs): ymin = kwargs.pop("ymin", 1) elif len(args) == 1: if isinstance(args[0], np.ndarray): - array = jnp.array(args[0]) + array = jnp.array(cast_numpy_array_to_native_byte_order(args[0])) array, xmin, ymin = self._get_xmin_ymin( array, kwargs, check_bounds=_check_bounds ) @@ -191,14 +196,14 @@ def __init__(self, *args, **kwargs): else: self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow) if init_value: - self._array = self._array + init_value + self._array = self._array.at[...].add(init_value) elif bounds is not None: if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") self._array = self._make_empty(bounds.numpyShape(), dtype=self._dtype) self._bounds = bounds if init_value: - self._array = self._array + init_value + self._array = self._array.at[...].add(init_value) elif array is not None: self._array = array.view(dtype=self._dtype) nrow, ncol = array.shape @@ -315,7 +320,7 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): def __repr__(self): s = "galsim.Image(bounds=%r" % self.bounds if self.bounds.isDefined(): - s += ", array=\n%r" % np.array(self.array) + s += ", array=\n%r" % (ensure_hashable(np.array(self.array)),) s += ", wcs=%r" % self.wcs if self.isconst: s += ", make_const=True" @@ -357,6 +362,12 @@ def bounds(self): def array(self): return self._array + @array.setter + def array(self, other): + self._array = self._array.at[...].set( + _safe_cast(other, self.isinteger, self.array.dtype) + ) + @property @implements(_galsim.Image.nrow) def nrow(self): @@ -688,37 +699,43 @@ def _wrap(self, bounds, hermx, hermy, hermitian_wrap_size): if not hermx and not hermy: from jax_galsim.core.wrap_image import wrap_nonhermitian - self._array = wrap_nonhermitian( - self._array, - # zero indexed location of subimage - bounds.xmin - self.xmin, - bounds.ymin - self.ymin, - bounds.deltax, - bounds.deltay, + self._array = self._array.at[...].set( + wrap_nonhermitian( + self._array, + # zero indexed location of subimage + bounds.xmin - self.xmin, + bounds.ymin - self.ymin, + bounds.deltax, + bounds.deltay, + ) ) elif hermx and not hermy: from jax_galsim.core.wrap_image import wrap_hermitian_x - self._array = wrap_hermitian_x( - self._array, - -self.xmax, - self.ymin, - -bounds.xmax + 1, - bounds.ymin, - hermitian_wrap_size, - bounds.deltay, + self._array = self._array.at[...].set( + wrap_hermitian_x( + self._array, + -self.xmax, + self.ymin, + -bounds.xmax + 1, + bounds.ymin, + hermitian_wrap_size, + bounds.deltay, + ) ) elif not hermx and hermy: from jax_galsim.core.wrap_image import wrap_hermitian_y - self._array = wrap_hermitian_y( - self._array, - self.xmin, - -self.ymax, - bounds.xmin, - -bounds.ymax + 1, - bounds.deltax, - hermitian_wrap_size, + self._array = self._array.at[...].set( + wrap_hermitian_y( + self._array, + self.xmin, + -self.ymax, + bounds.xmin, + -bounds.ymax + 1, + bounds.deltax, + hermitian_wrap_size, + ) ) return self.subImage(bounds) @@ -776,8 +793,8 @@ def calculate_fft(self): ) # we shift the image before and after the FFT to match the layout of the modes # used by GalSim - out._array = jnp.fft.fftshift( - jnp.fft.rfft2(jnp.fft.fftshift(ximage.array)), axes=0 + out._array = out._array.at[...].set( + jnp.fft.fftshift(jnp.fft.rfft2(jnp.fft.fftshift(ximage.array)), axes=0) ) out *= dx * dx @@ -831,21 +848,23 @@ def calculate_inverse_fft(self): # dx = 2pi / (N dk) dx = jnp.pi / (No2 * dk) - # For the inverse, we need a bit of extra space for the fft. - out_extra = Image( - BoundsI(xmin=-No2, deltax=2 * No2 + 2, ymin=-No2, deltay=2 * No2), + # In GalSim, they use inplace FFTW transforms which require the + # array that holds the input/output to have extra padding on the + # x dimension. + # jax-galsim does not need the padding since it does not use an + # inplace FFT. Thus we do not use the + # padding. + + out = Image( + bounds=BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2), dtype=float, scale=dx, + # we shift the image before and after the FFT to match the layout used by galsim + array=jnp.fft.fftshift( + jnp.fft.irfft2(jnp.fft.fftshift(kimage.array, axes=0)) + ) + * (dk * No2 / jnp.pi) ** 2, ) - # we shift the image before and after the FFT to match the layout used by galsim - out_extra._array = jnp.fft.fftshift( - jnp.fft.irfft2(jnp.fft.fftshift(kimage.array, axes=0)) - ) - # Now cut off the bit we don't need. - out = out_extra.subImage( - BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2) - ) - out *= (dk * No2 / jnp.pi) ** 2 out.setCenter(0, 0) return out @@ -879,7 +898,13 @@ def copyFrom(self, rhs): self_image=self, rhs=rhs, ) - self._array = rhs._array + self._copyFrom(rhs) + + def _copyFrom(self, rhs): + """Same as copyFrom, but no sanity checks.""" + self._array = self._array.at[...].set( + _safe_cast(rhs._array, self.isinteger, self.array.dtype) + ) @implements( _galsim.Image.view, @@ -923,7 +948,8 @@ def view( if dtype != self.array.dtype: array = self.array.astype(dtype) elif contiguous: - array = np.ascontiguousarray(self.array) + # this is a noop since all jax arrays are contiguous + pass else: array = self.array @@ -1055,7 +1081,7 @@ def fill(self, value): @implements(_galsim.Image._fill) def _fill(self, value): - self._array = jnp.zeros_like(self._array) + value + self._array = self._array.at[...].set(value) @implements(_galsim.Image.setZero) def setZero(self): @@ -1075,9 +1101,15 @@ def invertSelf(self): @implements(_galsim.Image._invertSelf) def _invertSelf(self): - array = 1.0 / self._array - array = array.at[jnp.isinf(array)].set(0.0) - self._array = array.astype(self._array.dtype) + msk = self._array == 0 + safe_array = jnp.where( + msk, + 1.0, + self._array, + ) + self._array = self._array.at[...].set( + (jnp.where(msk, 0.0, 1.0 / safe_array)).astype(self._array.dtype) + ) @implements(_galsim.Image.replaceNegative) def replaceNegative(self, replace_value=0): @@ -1210,7 +1242,9 @@ def from_galsim(cls, galsim_image): else None ) im = cls( - array=jnp.asarray(galsim_image.array), + array=jnp.asarray( + cast_numpy_array_to_native_byte_order(galsim_image.array) + ), wcs=wcs, bounds=Bounds.from_galsim(galsim_image.bounds), ) @@ -1348,6 +1382,18 @@ def ImageCD(*args, **kwargs): # +def _safe_cast(array, target_isinteger, target_dtype): + # code snippet pulled from upstream GalSim and turned into a general purpose + # function + # + # Assign the given array to self.array, safely casting it to the required type. + # Most important is to make sure integer types round first before casting, since + # numpy's astype doesn't do any rounding. + if target_isinteger: + array = jnp.around(array) + return array.astype(target_dtype) + + # Define a utility function to be used by the arithmetic functions below def check_image_consistency(im1, im2, integer=False): if integer and not im1.isinteger: @@ -1381,7 +1427,9 @@ def Image_iadd(self, other): if dt == self.array.dtype: self._array = self.array.at[...].add(a) else: - self._array = self.array.at[...].set((self.array + a).astype(self.array.dtype)) + self._array = self.array.at[...].set( + _safe_cast(self.array + a, self.isinteger, self.array.dtype) + ) return self @@ -1409,7 +1457,9 @@ def Image_isub(self, other): if dt == self.array.dtype: self._array = self.array.at[...].subtract(a) else: - self._array = self.array.at[...].set((self.array - a).astype(self.array.dtype)) + self._array = self.array.at[...].set( + _safe_cast(self.array - a, self.isinteger, self.array.dtype) + ) return self @@ -1433,7 +1483,9 @@ def Image_imul(self, other): if dt == self.array.dtype: self._array = self.array.at[...].multiply(a) else: - self._array = self.array.at[...].set((self.array * a).astype(self.array.dtype)) + self._array = self.array.at[...].set( + _safe_cast(self.array * a, self.isinteger, self.array.dtype) + ) return self @@ -1463,7 +1515,9 @@ def Image_idiv(self, other): # back to an integer array. So for integers (or mixed types), don't use /=. self._array = self.array.at[...].divide(a) else: - self._array = self.array.at[...].set((self.array / a).astype(self.array.dtype)) + self._array = self.array.at[...].set( + _safe_cast(self.array / a, self.isinteger, self.array.dtype) + ) return self @@ -1492,7 +1546,9 @@ def Image_ifloordiv(self, other): if dt == self.array.dtype: self._array = self.array.at[...].set(self.array // a) else: - self._array = self.array.at[...].set((self.array // a).astype(self.array.dtype)) + self._array = self.array.at[...].set( + _safe_cast(self.array // a, self.isinteger, self.array.dtype) + ) return self @@ -1521,7 +1577,9 @@ def Image_imod(self, other): if dt == self.array.dtype: self._array = self.array.at[...].set(self.array % a) else: - self._array = self.array.at[...].set((self.array % a).astype(self.array.dtype)) + self._array = self.array.at[...].set( + _safe_cast(self.array % a, self.isinteger, self.array.dtype) + ) return self @@ -1532,7 +1590,13 @@ def Image_pow(self, other): def Image_ipow(self, other): if not isinstance(other, int) and not isinstance(other, float): raise TypeError("Can only raise an image to a float or int power!") - self._array = self.array.at[...].power(other) + + if not self.isinteger or isinstance(other, int): + self._array = self.array.at[...].power(other) + else: + self._array = self.array.at[...].set( + _safe_cast(self.array**other, self.isinteger, self.array.dtype) + ) return self diff --git a/jax_galsim/noise.py b/jax_galsim/noise.py index f49ffa50..3956cf6a 100644 --- a/jax_galsim/noise.py +++ b/jax_galsim/noise.py @@ -142,8 +142,8 @@ def sigma(self): return self._sigma def _applyTo(self, image): - image._array = (image._array + self._rng.generate(image._array)).astype( - image.dtype + image._array = image._array.at[...].add( + self._rng.generate(image._array).astype(image.dtype) ) def _getVariance(self): @@ -231,14 +231,15 @@ def _applyTo(self, image): frac_sky, ) # Noise array is now the correct value for each pixel. - image._array = noise_array.astype(image.dtype) - image._array = jax.lax.cond( - int_sky != 0.0, - lambda na, ints: (na - ints).astype(float), - lambda na, ints: na.astype(float), - image._array, - int_sky, - ).astype(image.dtype) + image._array = image._array.at[...].set( + jax.lax.cond( + int_sky != 0.0, + lambda na, ints: (na - ints).astype(float), + lambda na, ints: na.astype(float), + noise_array.astype(image.dtype), + int_sky, + ).astype(image.dtype) + ) def _getVariance(self): return self.sky_level @@ -362,14 +363,15 @@ def _applyTo(self, image): frac_sky, ) # Noise array is now the correct value for each pixel. - image._array = noise_array.astype(image.dtype) - image._array = jax.lax.cond( - int_sky != 0.0, - lambda na, ints: (na - ints).astype(float), - lambda na, ints: na.astype(float), - image._array, - int_sky, - ).astype(image.dtype) + image._array = image._array.at[...].set( + jax.lax.cond( + int_sky != 0.0, + lambda na, ints: (na - ints).astype(float), + lambda na, ints: na.astype(float), + noise_array.astype(image.dtype), + int_sky, + ).astype(image.dtype) + ) def _getVariance(self): return jax.lax.cond( @@ -454,8 +456,8 @@ def __init__(self, dev): super().__init__(dev) def _applyTo(self, image): - image._array = (image._array + self._rng.generate(image._array)).astype( - image.dtype + image._array = image._array.at[...].add( + (self._rng.generate(image._array)).astype(image.dtype) ) def _getVariance(self): @@ -533,7 +535,7 @@ def applyTo(self, image): def _applyTo(self, image): # jax galsim never fills an image so this is safe noise_array = self._rng.generate_from_variance(self.var_image.array) - image._array = image._array + noise_array.astype(image.dtype) + image._array = image._array.at[...].add(noise_array.astype(image.dtype)) @implements( _galsim.noise.VariableGaussianNoise.copy, diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index a0cade02..f6fffe51 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -6,7 +6,11 @@ import jax.random as jrng from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import cast_to_python_int, implements +from jax_galsim.core.utils import ( + cast_numpy_array_to_native_byte_order, + cast_to_python_int, + implements, +) from jax_galsim.errors import ( GalSimIncompatibleValuesError, GalSimRangeError, @@ -863,7 +867,7 @@ def addTo(self, image): image.bounds.ymin, image._array, ) - image._array = _arr + image._array = image.array.at[...].set(_arr) return _flux_sum @@ -952,21 +956,35 @@ def read(cls, file_name): photons = cls( N, - x=jnp.array(data["x"]), - y=jnp.array(data["y"]), - flux=jnp.array(data["flux"]), + x=jnp.array(cast_numpy_array_to_native_byte_order(data["x"])), + y=jnp.array(cast_numpy_array_to_native_byte_order(data["y"])), + flux=jnp.array(cast_numpy_array_to_native_byte_order(data["flux"])), + ) + photons._nokeep = jnp.array( + cast_numpy_array_to_native_byte_order(data["_nokeep"]) ) - photons._nokeep = jnp.array(data["_nokeep"]) if "dxdz" in names: - photons.dxdz = jnp.array(data["dxdz"]) - photons.dydz = jnp.array(data["dydz"]) + photons.dxdz = jnp.array( + cast_numpy_array_to_native_byte_order(data["dxdz"]) + ) + photons.dydz = jnp.array( + cast_numpy_array_to_native_byte_order(data["dydz"]) + ) if "wavelength" in names: - photons.wavelength = jnp.array(data["wavelength"]) + photons.wavelength = jnp.array( + cast_numpy_array_to_native_byte_order(data["wavelength"]) + ) if "pupil_u" in names: - photons.pupil_u = jnp.array(data["pupil_u"]) - photons.pupil_v = jnp.array(data["pupil_v"]) + photons.pupil_u = jnp.array( + cast_numpy_array_to_native_byte_order(data["pupil_u"]) + ) + photons.pupil_v = jnp.array( + cast_numpy_array_to_native_byte_order(data["pupil_v"]) + ) if "time" in names: - photons.time = jnp.array(data["time"]) + photons.time = jnp.array( + cast_numpy_array_to_native_byte_order(data["time"]) + ) return photons diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index 374a6a16..8f575c69 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -144,7 +144,7 @@ def unweighted_shape(arg): @functools.partial(jax.jit, static_argnames=("dtype",)) def horner(x, coef, dtype=None): x = jnp.array(x) - coef = jnp.atleast_1d(coef) + coef = jnp.atleast_1d(jnp.asarray(coef)) if dtype is None: res_dtype = jnp.result_type(x, coef) else: @@ -172,7 +172,7 @@ def horner(x, coef, dtype=None): def horner2d(x, y, coefs, dtype=None, triangle=False): x = jnp.array(x) y = jnp.array(y) - coefs = jnp.atleast_1d(coefs) + coefs = jnp.atleast_1d(jnp.asarray(coefs)) if dtype is None: res_dtype = jnp.result_type(x, coefs) else: diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index ee18ef33..5eeccae5 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -516,7 +516,7 @@ def _makeSkyImage(self, image, sky_level, color): dvdy = 0.5 * (v[2:ny, 1 : nx - 1] - v[0 : ny - 2, 1 : nx - 1]) area = jnp.abs(dudx * dvdy - dvdx * dudy) - image._array = (area * sky_level).astype(image.dtype) + image._array = image._array.at[...].set((area * sky_level).astype(image.dtype)) # Each class should define the __eq__ function. Then __ne__ is obvious. def __ne__(self, other): @@ -750,7 +750,9 @@ def _makeSkyImage(self, image, sky_level, color): area = jnp.abs(dudx * dvdy - dvdx * dudy) factor = radians / arcsec - image._array = area * sky_level * factor**2 + image._array = image._array.at[...].set( + (area * sky_level * factor**2).astype(image.dtype) + ) # Simple. Just call _radec. def _posToWorld(self, image_pos, color, project_center=None, projection="gnomonic"): diff --git a/tests/GalSim b/tests/GalSim index cf0c4b19..1a490c3b 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit cf0c4b196dcbd789cc550b8fe7d0643d6e508db5 +Subproject commit 1a490c3b558fddf2cab1fc0e6d449b73fa3b4eda diff --git a/tests/jax/test_implements.py b/tests/jax/test_implements.py index ab01887a..afe27990 100644 --- a/tests/jax/test_implements.py +++ b/tests/jax/test_implements.py @@ -37,7 +37,7 @@ def test_implements(): p = _parse_galsimdoc(docstring) assert p.signature == "" - assert p.summary == "The summary is\n cool." + assert p.summary == "The summary is\ncool." assert "This is front matter." in p.front_matter assert "LAX" not in p.front_matter assert p.sections == {} @@ -45,7 +45,7 @@ def test_implements(): docstring = LAXTestImplements.__doc__ p = _parse_galsimdoc(docstring) assert p.signature == "" - assert p.summary == "The summary is\n cool." + assert p.summary == "The summary is\ncool." assert "This is front matter." in p.front_matter assert "LAX" in p.front_matter assert p.sections == {} diff --git a/tests/jax/test_utils.py b/tests/jax/test_utils.py new file mode 100644 index 00000000..4220f67d --- /dev/null +++ b/tests/jax/test_utils.py @@ -0,0 +1,20 @@ +import jax.numpy as jnp +import numpy as np +import pytest + +from jax_galsim.core.utils import cast_numpy_array_to_native_byte_order + + +@pytest.mark.parametrize( + "arr", + [ + np.arange(10), + np.arange(10, dtype="