Skip to content

Commit 7b6aecc

Browse files
authored
fix: round properly when rendering int images (#238)
* fix: round properly * fix: do not cast in draewing by k value * fix: more rounding * fix: one more spot * doc: add todo for later * fix: convert to double to then round * test: add test for inverting self and dtypes * test: update submodule to latest test
1 parent 6a8d05a commit 7b6aecc

8 files changed

Lines changed: 130 additions & 16 deletions

File tree

jax_galsim/core/draw.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,14 @@ def draw_by_xValue(
3434
)
3535

3636
# Apply the flux scaling
37-
im = (im * flux_scaling).astype(image.dtype)
37+
im *= flux_scaling
38+
39+
# jax-galsim's rounding of float-to-int is platform dependent
40+
# so we explicitly round to ints if needed
41+
if jnp.issubdtype(im.dtype, jnp.floating) and jnp.issubdtype(
42+
image.dtype, jnp.integer
43+
):
44+
im = jnp.around(im)
3845

3946
# Return an image
4047
return Image(array=im, bounds=image.bounds, wcs=image.wcs, _check_bounds=False)
@@ -53,7 +60,7 @@ def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)):
5360
im = jax.vmap(lambda *args: gsobject._kValue(PositionD(*args)))(
5461
coords[..., 0], coords[..., 1]
5562
)
56-
im = (im).astype(image.dtype)
63+
im = im.astype(image.dtype)
5764

5865
# Return an image
5966
return Image(array=im, bounds=image.bounds, wcs=image.wcs, _check_bounds=False)

jax_galsim/gsobject.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,14 @@ def drawReal(self, image, add_to_image=False):
808808
)
809809
im1 = self._drawReal(image)
810810
temp = im1.subImage(image.bounds)
811+
812+
if jnp.issubdtype(temp.array.dtype, jnp.floating) and jnp.issubdtype(
813+
image.array.dtype, jnp.integer
814+
):
815+
# jax-galsim's rounding of float-to-int is platform dependent
816+
# so we explicitly round to ints if needed
817+
temp.array = jnp.around(temp.array)
818+
811819
if add_to_image:
812820
image._array = image._array.at[...].add(temp._array)
813821
else:
@@ -926,6 +934,14 @@ def drawFFT_finish(self, image, kimage, wrap_size, add_to_image):
926934
real_image = Image(
927935
bounds=breal, array=real_image_arr, dtype=image.dtype, wcs=image.wcs
928936
)
937+
938+
if jnp.issubdtype(real_image.array.dtype, jnp.floating) and jnp.issubdtype(
939+
image.array.dtype, jnp.integer
940+
):
941+
# jax-galsim's rounding of float-to-int is platform dependent
942+
# so we explicitly round to ints if needed
943+
real_image.array = jnp.around(real_image.array)
944+
929945
# Add (a portion of) this to the original image.
930946
temp = real_image.subImage(image.bounds)
931947
if add_to_image:

jax_galsim/image.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,19 @@ def __init__(self, *args, **kwargs):
150150
dtype = array.dtype.type
151151
if dtype in self._alias_dtypes:
152152
dtype = self._alias_dtypes[dtype]
153-
array = array.astype(dtype)
153+
# jax-galsim's rounding of float-to-int is platform dependent
154+
# so we explicitly round to ints if needed
155+
array = _safe_cast(array, jnp.issubdtype(dtype, jnp.integer), dtype)
154156
elif dtype not in self._valid_dtypes:
155157
raise _galsim.GalSimValueError(
156158
"Invalid dtype of provided array.",
157159
array.dtype,
158160
self._valid_dtypes,
159161
)
160162
else:
161-
array = array.astype(dtype)
163+
# jax-galsim's rounding of float-to-int is platform dependent
164+
# so we explicitly round to ints if needed
165+
array = _safe_cast(array, jnp.issubdtype(dtype, jnp.integer), dtype)
162166
# Be careful here: we have to watch out for little-endian / big-endian issues.
163167
# The path of least resistance is to check whether the array.dtype is equal to the
164168
# native one (using the dtype.isnative flag), and if not, make a new array that has a
@@ -206,7 +210,7 @@ def __init__(self, *args, **kwargs):
206210
if init_value:
207211
self._array = self._array.at[...].add(init_value)
208212
elif array is not None:
209-
self._array = array.view(dtype=self._dtype)
213+
self._array = array.view()
210214
nrow, ncol = array.shape
211215
if not has_tracers(xmin) and not has_tracers(ymin):
212216
self._bounds = BoundsI(
@@ -239,7 +243,12 @@ def __init__(self, *args, **kwargs):
239243
# e.g. im = ImageF(...)
240244
# im2 = ImageD(im)
241245
self._dtype = dtype
242-
self._array = image.array.astype(self._dtype)
246+
247+
# jax-galsim's rounding of float-to-int is platform dependent
248+
# so we explicitly round to ints if needed
249+
self._array = _safe_cast(
250+
image.array, jnp.issubdtype(self._dtype, jnp.integer), self._dtype
251+
)
243252
else:
244253
self._array = jnp.zeros(shape=(1, 1), dtype=self._dtype)
245254
self._bounds = BoundsI()
@@ -365,6 +374,8 @@ def array(self):
365374

366375
@array.setter
367376
def array(self, other):
377+
# jax-galsim's rounding of float-to-int is platform dependent
378+
# so we explicitly round to ints if needed
368379
self._array = self._array.at[...].set(
369380
_safe_cast(other, self.isinteger, self.array.dtype)
370381
)
@@ -590,8 +601,12 @@ def setSubImage(self, bounds, rhs):
590601
i2 = bounds.ymax - self.ymin + 1
591602
j1 = bounds.xmin - self.xmin
592603
j2 = bounds.xmax - self.xmin + 1
604+
# jax-galsim's rounding of float-to-int is platform dependent
605+
# so we explicitly round to ints if needed
593606
self._array = self._array.at[i1:i2, j1:j2].set(
594-
jnp.astype(rhs.array, self.dtype)
607+
_safe_cast(
608+
rhs.array, jnp.issubdtype(self.dtype, jnp.integer), self.dtype
609+
)
595610
)
596611
else:
597612
start_inds = (
@@ -600,7 +615,11 @@ def setSubImage(self, bounds, rhs):
600615
)
601616
self._array = jax.lax.dynamic_update_slice(
602617
self.array,
603-
jnp.astype(rhs.array, self.dtype),
618+
# jax-galsim's rounding of float-to-int is platform dependent
619+
# so we explicitly round to ints if needed
620+
_safe_cast(
621+
rhs.array, jnp.issubdtype(self.dtype, jnp.integer), self.dtype
622+
),
604623
start_inds,
605624
)
606625

@@ -904,6 +923,8 @@ def copyFrom(self, rhs):
904923
def _copyFrom(self, rhs):
905924
"""Same as copyFrom, but no sanity checks."""
906925
self._array = self._array.at[...].set(
926+
# jax-galsim's rounding of float-to-int is platform dependent
927+
# so we explicitly round to ints if needed
907928
_safe_cast(rhs._array, self.isinteger, self.array.dtype)
908929
)
909930

@@ -947,7 +968,9 @@ def view(
947968

948969
# Recast the array type if necessary
949970
if dtype != self.array.dtype:
950-
array = self.array.astype(dtype)
971+
# jax-galsim's rounding of float-to-int is platform dependent
972+
# so we explicitly round to ints if needed
973+
array = _safe_cast(self.array, jnp.issubdtype(dtype, jnp.integer), dtype)
951974
elif contiguous:
952975
# this is a noop since all jax arrays are contiguous
953976
pass
@@ -1109,7 +1132,13 @@ def _invertSelf(self):
11091132
self._array,
11101133
)
11111134
self._array = self._array.at[...].set(
1112-
(jnp.where(msk, 0.0, 1.0 / safe_array)).astype(self._array.dtype)
1135+
# jax-galsim's rounding of float-to-int is platform dependent
1136+
# so we explicitly round to ints if needed
1137+
_safe_cast(
1138+
(jnp.where(msk, 0.0, 1.0 / safe_array)),
1139+
jnp.issubdtype(self._array.dtype, jnp.integer),
1140+
self._array.dtype,
1141+
)
11131142
)
11141143

11151144
@implements(_galsim.Image.replaceNegative)
@@ -1289,7 +1318,9 @@ def _Image(array, bounds, wcs):
12891318
ret._dtype = array.dtype.type
12901319
if ret._dtype in Image._alias_dtypes:
12911320
ret._dtype = Image._alias_dtypes[ret._dtype]
1292-
array = array.astype(ret._dtype)
1321+
# jax-galsim's rounding of float-to-int is platform dependent
1322+
# so we explicitly round to ints if needed
1323+
array = _safe_cast(array, jnp.issubdtype(ret._dtype, jnp.integer), ret._dtype)
12931324
ret._array = array
12941325
ret._bounds = bounds
12951326
return ret
@@ -1428,6 +1459,8 @@ def Image_iadd(self, other):
14281459
if dt == self.array.dtype:
14291460
self._array = self.array.at[...].add(a)
14301461
else:
1462+
# jax-galsim's rounding of float-to-int is platform dependent
1463+
# so we explicitly round to ints if needed
14311464
self._array = self.array.at[...].set(
14321465
_safe_cast(self.array + a, self.isinteger, self.array.dtype)
14331466
)
@@ -1458,6 +1491,8 @@ def Image_isub(self, other):
14581491
if dt == self.array.dtype:
14591492
self._array = self.array.at[...].subtract(a)
14601493
else:
1494+
# jax-galsim's rounding of float-to-int is platform dependent
1495+
# so we explicitly round to ints if needed
14611496
self._array = self.array.at[...].set(
14621497
_safe_cast(self.array - a, self.isinteger, self.array.dtype)
14631498
)
@@ -1484,6 +1519,8 @@ def Image_imul(self, other):
14841519
if dt == self.array.dtype:
14851520
self._array = self.array.at[...].multiply(a)
14861521
else:
1522+
# jax-galsim's rounding of float-to-int is platform dependent
1523+
# so we explicitly round to ints if needed
14871524
self._array = self.array.at[...].set(
14881525
_safe_cast(self.array * a, self.isinteger, self.array.dtype)
14891526
)
@@ -1516,6 +1553,8 @@ def Image_idiv(self, other):
15161553
# back to an integer array. So for integers (or mixed types), don't use /=.
15171554
self._array = self.array.at[...].divide(a)
15181555
else:
1556+
# jax-galsim's rounding of float-to-int is platform dependent
1557+
# so we explicitly round to ints if needed
15191558
self._array = self.array.at[...].set(
15201559
_safe_cast(self.array / a, self.isinteger, self.array.dtype)
15211560
)
@@ -1547,6 +1586,8 @@ def Image_ifloordiv(self, other):
15471586
if dt == self.array.dtype:
15481587
self._array = self.array.at[...].set(self.array // a)
15491588
else:
1589+
# jax-galsim's rounding of float-to-int is platform dependent
1590+
# so we explicitly round to ints if needed
15501591
self._array = self.array.at[...].set(
15511592
_safe_cast(self.array // a, self.isinteger, self.array.dtype)
15521593
)
@@ -1578,6 +1619,8 @@ def Image_imod(self, other):
15781619
if dt == self.array.dtype:
15791620
self._array = self.array.at[...].set(self.array % a)
15801621
else:
1622+
# jax-galsim's rounding of float-to-int is platform dependent
1623+
# so we explicitly round to ints if needed
15811624
self._array = self.array.at[...].set(
15821625
_safe_cast(self.array % a, self.isinteger, self.array.dtype)
15831626
)
@@ -1595,6 +1638,8 @@ def Image_ipow(self, other):
15951638
if not self.isinteger or isinstance(other, int):
15961639
self._array = self.array.at[...].power(other)
15971640
else:
1641+
# jax-galsim's rounding of float-to-int is platform dependent
1642+
# so we explicitly round to ints if needed
15981643
self._array = self.array.at[...].set(
15991644
_safe_cast(self.array**other, self.isinteger, self.array.dtype)
16001645
)

jax_galsim/interpolatedimage.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,14 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
792792
)
793793

794794
# Apply the flux scaling
795-
im = (im * flux_scaling).astype(image.dtype)
795+
im *= flux_scaling
796+
797+
# jax-galsim's rounding of float-to-int is platform dependent
798+
# so we explicitly round to ints if needed
799+
if jnp.issubdtype(im.dtype, jnp.floating) and jnp.issubdtype(
800+
image.dtype, jnp.integer
801+
):
802+
im = jnp.around(im)
796803

797804
# Return an image
798805
return Image(array=im, bounds=image.bounds, wcs=image.wcs, _check_bounds=False)
@@ -817,7 +824,7 @@ def _drawKImage(self, image, jac=None):
817824
self._x_interpolant,
818825
self._k_interpolant,
819826
)
820-
im = (im).astype(image.dtype)
827+
im = im.astype(image.dtype)
821828

822829
# Return an image
823830
return Image(array=im, bounds=image.bounds, wcs=image.wcs, _check_bounds=False)

jax_galsim/photon_array.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,14 @@ def _add_photons_to_image(x, y, flux, xmin, ymin, arr):
10001000
# dropped and negative indices wrap around
10011001
good = (xinds >= 0) & (xinds < arr.shape[1]) & (yinds >= 0) & (yinds < arr.shape[0])
10021002
_flux = jnp.where(good, flux, 0.0)
1003-
_arr = arr.at[yinds, xinds].add(_flux.astype(arr.dtype))
1003+
1004+
# jax-galsim's rounding of float-to-int is platform dependent
1005+
# so we explicitly round to ints if needed
1006+
if jnp.issubdtype(arr.dtype, jnp.integer):
1007+
_arr = arr.astype(float).at[yinds, xinds].add(_flux.astype(float))
1008+
_arr = jnp.around(_arr).astype(arr.dtype)
1009+
else:
1010+
_arr = arr.at[yinds, xinds].add(_flux.astype(arr.dtype))
10041011

10051012
return _arr, _flux.sum()
10061013

jax_galsim/wcs.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,16 @@ def _makeSkyImage(self, image, sky_level, color):
516516
dvdy = 0.5 * (v[2:ny, 1 : nx - 1] - v[0 : ny - 2, 1 : nx - 1])
517517

518518
area = jnp.abs(dudx * dvdy - dvdx * dudy)
519-
image._array = image._array.at[...].set((area * sky_level).astype(image.dtype))
519+
im = area * sky_level
520+
521+
# jax-galsim's rounding of float-to-int is platform dependent
522+
# so we explicitly round to ints if needed
523+
if jnp.issubdtype(im.dtype, jnp.floating) and jnp.issubdtype(
524+
image.dtype, jnp.integer
525+
):
526+
im = jnp.around(im)
527+
528+
image._array = image._array.at[...].set(im)
520529

521530
# Each class should define the __eq__ function. Then __ne__ is obvious.
522531
def __ne__(self, other):

tests/GalSim

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import galsim as _galsim
2+
import jax.numpy as jnp
3+
import numpy as np
4+
5+
import jax_galsim
6+
7+
8+
def test_int_float_dtype_handling_invertSelf():
9+
gim = _galsim.Image(np.arange(20).reshape(4, 5), dtype=int)
10+
gim.invertSelf()
11+
12+
assert gim[1, 1] == 0
13+
assert gim[2, 1] == 1
14+
assert gim[4, 4] == 0
15+
16+
jgim = jax_galsim.Image(jnp.arange(20).reshape(4, 5), dtype=int)
17+
jgim.invertSelf()
18+
19+
assert jgim[1, 1] == 0
20+
assert jgim[2, 1] == 1
21+
assert jgim[4, 4] == 0
22+
23+
np.testing.assert_array_equal(gim.array, jgim.array)

0 commit comments

Comments
 (0)