Skip to content

Commit 95f0ae9

Browse files
committed
test: add test of bool eq api
1 parent 1c94cf3 commit 95f0ae9

2 files changed

Lines changed: 78 additions & 3 deletions

File tree

jax_galsim/noise.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,15 @@ def _applyTo(self, image):
106106
raise NotImplementedError("Cannot call applyTo on a pure BaseNoise object")
107107

108108
def __eq__(self, other):
109-
# Quick and dirty. Just check reprs are equal.
110-
return self is other or repr(self) == repr(other)
109+
if self is other:
110+
return jnp.array(True)
111+
elif isinstance(other, BaseNoise):
112+
return jnp.array(self._rng == other._rng)
113+
else:
114+
return jnp.array(False)
111115

112116
def __ne__(self, other):
113-
return not self.__eq__(other)
117+
return ~self.__eq__(other)
114118

115119
__hash__ = None
116120

@@ -173,6 +177,16 @@ def __repr__(self):
173177
def __str__(self):
174178
return "galsim.GaussianNoise(sigma=%s)" % (ensure_hashable(self.sigma),)
175179

180+
def __eq__(self, other):
181+
if self is other:
182+
return jnp.array(True)
183+
elif isinstance(other, self.__class__):
184+
return jnp.array(self._rng == other._rng) & jnp.array_equal(
185+
self._sigma, other._sigma
186+
)
187+
else:
188+
return jnp.array(False)
189+
176190
def tree_flatten(self):
177191
"""This function flattens the GaussianNoise into a list of children
178192
nodes that will be traced by JAX and auxiliary static data."""
@@ -265,6 +279,16 @@ def __repr__(self):
265279
def __str__(self):
266280
return "galsim.PoissonNoise(sky_level=%s)" % (self.sky_level)
267281

282+
def __eq__(self, other):
283+
if self is other:
284+
return jnp.array(True)
285+
elif isinstance(other, self.__class__):
286+
return jnp.array(self._rng == other._rng) & jnp.array_equal(
287+
self._sky_level, other._sky_level
288+
)
289+
else:
290+
return jnp.array(False)
291+
268292
def tree_flatten(self):
269293
"""This function flattens the PoissonNoise into a list of children
270294
nodes that will be traced by JAX and auxiliary static data."""
@@ -429,6 +453,19 @@ def __str__(self):
429453
self.read_noise,
430454
)
431455

456+
def __eq__(self, other):
457+
if self is other:
458+
return jnp.array(True)
459+
elif isinstance(other, self.__class__):
460+
return (
461+
jnp.array(self._rng == other._rng)
462+
& jnp.array_equal(self._sky_level, other._sky_level)
463+
& jnp.array_equal(self._gain, other._gain)
464+
& jnp.array_equal(self._read_noise, other._read_noise)
465+
)
466+
else:
467+
return jnp.array(False)
468+
432469
def tree_flatten(self):
433470
"""This function flattens the CCDNoise into a list of children
434471
nodes that will be traced by JAX and auxiliary static data."""
@@ -570,6 +607,16 @@ def __repr__(self):
570607
def __str__(self):
571608
return "galsim.VariableGaussianNoise(var_image=%s)" % (self.var_image)
572609

610+
def __eq__(self, other):
611+
if self is other:
612+
return jnp.array(True)
613+
elif isinstance(other, self.__class__):
614+
return jnp.array(self._rng == other._rng) & jnp.array(
615+
self._var_image == other._var_image
616+
)
617+
else:
618+
return jnp.array(False)
619+
573620
def tree_flatten(self):
574621
"""This function flattens the VariableGaussianNoise into a list of children
575622
nodes that will be traced by JAX and auxiliary static data."""

tests/jax/test_api.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,16 @@ def _run_object_checks(obj, cls, kind):
126126

127127
# check that we can hash the object
128128
hash(obj)
129+
130+
# check that val jax array
131+
if (
132+
hasattr(obj, "isStatic")
133+
and obj.isStatic()
134+
or isinstance(obj, jax_galsim.Sensor)
135+
):
136+
assert isinstance(eval(repr(obj)) == obj, bool)
137+
else:
138+
assert isinstance(eval(repr(obj)) == obj, jnp.ndarray)
129139
elif kind == "to-from-galsim":
130140
gs_obj = obj.to_galsim()
131141
jgs_obj = obj.from_galsim(gs_obj)
@@ -141,6 +151,14 @@ def _run_object_checks(obj, cls, kind):
141151

142152
# check that we cannot hash the object
143153
assert obj.__hash__ is None
154+
155+
# check that val jax array
156+
if (hasattr(obj, "isStatic") and obj.isStatic()) or isinstance(
157+
obj, jax_galsim.Sensor
158+
):
159+
assert isinstance(eval(repr(obj)) == obj, bool)
160+
else:
161+
assert isinstance(eval(repr(obj)) == obj, jnp.ndarray)
144162
elif kind == "pickle-eval-repr-wcs":
145163
import jax_galsim as galsim # noqa: F401
146164

@@ -152,6 +170,16 @@ def _run_object_checks(obj, cls, kind):
152170

153171
# check that we cannot hash the object
154172
hash(obj)
173+
174+
# check that val jax array
175+
if (
176+
hasattr(obj, "isStatic")
177+
and obj.isStatic()
178+
or isinstance(obj, jax_galsim.Sensor)
179+
):
180+
assert isinstance(eval(repr(obj)) == obj, bool)
181+
else:
182+
assert isinstance(eval(repr(obj)) == obj, jnp.ndarray)
155183
elif kind == "jax-compatible":
156184
# JAX tracing should be an identity
157185
assert cls.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj

0 commit comments

Comments
 (0)