@@ -150,15 +150,19 @@ def __init__(self, *args, **kwargs):
150150 dtype = array .dtype .type
151151 if dtype in self ._alias_dtypes :
152152 dtype = self ._alias_dtypes [dtype ]
153- array = array .astype (dtype )
153+ # jax-galsim's rounding of float-to-int is platform dependent
154+ # so we explicitly round to ints if needed
155+ array = _safe_cast (array , jnp .issubdtype (dtype , jnp .integer ), dtype )
154156 elif dtype not in self ._valid_dtypes :
155157 raise _galsim .GalSimValueError (
156158 "Invalid dtype of provided array." ,
157159 array .dtype ,
158160 self ._valid_dtypes ,
159161 )
160162 else :
161- array = array .astype (dtype )
163+ # jax-galsim's rounding of float-to-int is platform dependent
164+ # so we explicitly round to ints if needed
165+ array = _safe_cast (array , jnp .issubdtype (dtype , jnp .integer ), dtype )
162166 # Be careful here: we have to watch out for little-endian / big-endian issues.
163167 # The path of least resistance is to check whether the array.dtype is equal to the
164168 # native one (using the dtype.isnative flag), and if not, make a new array that has a
@@ -206,7 +210,7 @@ def __init__(self, *args, **kwargs):
206210 if init_value :
207211 self ._array = self ._array .at [...].add (init_value )
208212 elif array is not None :
209- self ._array = array .view (dtype = self . _dtype )
213+ self ._array = array .view ()
210214 nrow , ncol = array .shape
211215 if not has_tracers (xmin ) and not has_tracers (ymin ):
212216 self ._bounds = BoundsI (
@@ -239,7 +243,12 @@ def __init__(self, *args, **kwargs):
239243 # e.g. im = ImageF(...)
240244 # im2 = ImageD(im)
241245 self ._dtype = dtype
242- self ._array = image .array .astype (self ._dtype )
246+
247+ # jax-galsim's rounding of float-to-int is platform dependent
248+ # so we explicitly round to ints if needed
249+ self ._array = _safe_cast (
250+ image .array , jnp .issubdtype (self ._dtype , jnp .integer ), self ._dtype
251+ )
243252 else :
244253 self ._array = jnp .zeros (shape = (1 , 1 ), dtype = self ._dtype )
245254 self ._bounds = BoundsI ()
@@ -365,6 +374,8 @@ def array(self):
365374
366375 @array .setter
367376 def array (self , other ):
377+ # jax-galsim's rounding of float-to-int is platform dependent
378+ # so we explicitly round to ints if needed
368379 self ._array = self ._array .at [...].set (
369380 _safe_cast (other , self .isinteger , self .array .dtype )
370381 )
@@ -590,8 +601,12 @@ def setSubImage(self, bounds, rhs):
590601 i2 = bounds .ymax - self .ymin + 1
591602 j1 = bounds .xmin - self .xmin
592603 j2 = bounds .xmax - self .xmin + 1
604+ # jax-galsim's rounding of float-to-int is platform dependent
605+ # so we explicitly round to ints if needed
593606 self ._array = self ._array .at [i1 :i2 , j1 :j2 ].set (
594- jnp .astype (rhs .array , self .dtype )
607+ _safe_cast (
608+ rhs .array , jnp .issubdtype (self .dtype , jnp .integer ), self .dtype
609+ )
595610 )
596611 else :
597612 start_inds = (
@@ -600,7 +615,11 @@ def setSubImage(self, bounds, rhs):
600615 )
601616 self ._array = jax .lax .dynamic_update_slice (
602617 self .array ,
603- jnp .astype (rhs .array , self .dtype ),
618+ # jax-galsim's rounding of float-to-int is platform dependent
619+ # so we explicitly round to ints if needed
620+ _safe_cast (
621+ rhs .array , jnp .issubdtype (self .dtype , jnp .integer ), self .dtype
622+ ),
604623 start_inds ,
605624 )
606625
@@ -904,6 +923,8 @@ def copyFrom(self, rhs):
904923 def _copyFrom (self , rhs ):
905924 """Same as copyFrom, but no sanity checks."""
906925 self ._array = self ._array .at [...].set (
926+ # jax-galsim's rounding of float-to-int is platform dependent
927+ # so we explicitly round to ints if needed
907928 _safe_cast (rhs ._array , self .isinteger , self .array .dtype )
908929 )
909930
@@ -947,7 +968,9 @@ def view(
947968
948969 # Recast the array type if necessary
949970 if dtype != self .array .dtype :
950- array = self .array .astype (dtype )
971+ # jax-galsim's rounding of float-to-int is platform dependent
972+ # so we explicitly round to ints if needed
973+ array = _safe_cast (self .array , jnp .issubdtype (dtype , jnp .integer ), dtype )
951974 elif contiguous :
952975 # this is a noop since all jax arrays are contiguous
953976 pass
@@ -1109,7 +1132,13 @@ def _invertSelf(self):
11091132 self ._array ,
11101133 )
11111134 self ._array = self ._array .at [...].set (
1112- (jnp .where (msk , 0.0 , 1.0 / safe_array )).astype (self ._array .dtype )
1135+ # jax-galsim's rounding of float-to-int is platform dependent
1136+ # so we explicitly round to ints if needed
1137+ _safe_cast (
1138+ (jnp .where (msk , 0.0 , 1.0 / safe_array )),
1139+ jnp .issubdtype (self ._array .dtype , jnp .integer ),
1140+ self ._array .dtype ,
1141+ )
11131142 )
11141143
11151144 @implements (_galsim .Image .replaceNegative )
@@ -1289,7 +1318,9 @@ def _Image(array, bounds, wcs):
12891318 ret ._dtype = array .dtype .type
12901319 if ret ._dtype in Image ._alias_dtypes :
12911320 ret ._dtype = Image ._alias_dtypes [ret ._dtype ]
1292- array = array .astype (ret ._dtype )
1321+ # jax-galsim's rounding of float-to-int is platform dependent
1322+ # so we explicitly round to ints if needed
1323+ array = _safe_cast (array , jnp .issubdtype (ret ._dtype , jnp .integer ), ret ._dtype )
12931324 ret ._array = array
12941325 ret ._bounds = bounds
12951326 return ret
@@ -1428,6 +1459,8 @@ def Image_iadd(self, other):
14281459 if dt == self .array .dtype :
14291460 self ._array = self .array .at [...].add (a )
14301461 else :
1462+ # jax-galsim's rounding of float-to-int is platform dependent
1463+ # so we explicitly round to ints if needed
14311464 self ._array = self .array .at [...].set (
14321465 _safe_cast (self .array + a , self .isinteger , self .array .dtype )
14331466 )
@@ -1458,6 +1491,8 @@ def Image_isub(self, other):
14581491 if dt == self .array .dtype :
14591492 self ._array = self .array .at [...].subtract (a )
14601493 else :
1494+ # jax-galsim's rounding of float-to-int is platform dependent
1495+ # so we explicitly round to ints if needed
14611496 self ._array = self .array .at [...].set (
14621497 _safe_cast (self .array - a , self .isinteger , self .array .dtype )
14631498 )
@@ -1484,6 +1519,8 @@ def Image_imul(self, other):
14841519 if dt == self .array .dtype :
14851520 self ._array = self .array .at [...].multiply (a )
14861521 else :
1522+ # jax-galsim's rounding of float-to-int is platform dependent
1523+ # so we explicitly round to ints if needed
14871524 self ._array = self .array .at [...].set (
14881525 _safe_cast (self .array * a , self .isinteger , self .array .dtype )
14891526 )
@@ -1516,6 +1553,8 @@ def Image_idiv(self, other):
15161553 # back to an integer array. So for integers (or mixed types), don't use /=.
15171554 self ._array = self .array .at [...].divide (a )
15181555 else :
1556+ # jax-galsim's rounding of float-to-int is platform dependent
1557+ # so we explicitly round to ints if needed
15191558 self ._array = self .array .at [...].set (
15201559 _safe_cast (self .array / a , self .isinteger , self .array .dtype )
15211560 )
@@ -1547,6 +1586,8 @@ def Image_ifloordiv(self, other):
15471586 if dt == self .array .dtype :
15481587 self ._array = self .array .at [...].set (self .array // a )
15491588 else :
1589+ # jax-galsim's rounding of float-to-int is platform dependent
1590+ # so we explicitly round to ints if needed
15501591 self ._array = self .array .at [...].set (
15511592 _safe_cast (self .array // a , self .isinteger , self .array .dtype )
15521593 )
@@ -1578,6 +1619,8 @@ def Image_imod(self, other):
15781619 if dt == self .array .dtype :
15791620 self ._array = self .array .at [...].set (self .array % a )
15801621 else :
1622+ # jax-galsim's rounding of float-to-int is platform dependent
1623+ # so we explicitly round to ints if needed
15811624 self ._array = self .array .at [...].set (
15821625 _safe_cast (self .array % a , self .isinteger , self .array .dtype )
15831626 )
@@ -1595,6 +1638,8 @@ def Image_ipow(self, other):
15951638 if not self .isinteger or isinstance (other , int ):
15961639 self ._array = self .array .at [...].power (other )
15971640 else :
1641+ # jax-galsim's rounding of float-to-int is platform dependent
1642+ # so we explicitly round to ints if needed
15981643 self ._array = self .array .at [...].set (
15991644 _safe_cast (self .array ** other , self .isinteger , self .array .dtype )
16001645 )
0 commit comments