@@ -106,11 +106,15 @@ def _applyTo(self, image):
106106 raise NotImplementedError ("Cannot call applyTo on a pure BaseNoise object" )
107107
108108 def __eq__ (self , other ):
109- # Quick and dirty. Just check reprs are equal.
110- return self is other or repr (self ) == repr (other )
109+ if self is other :
110+ return jnp .array (True )
111+ elif isinstance (other , BaseNoise ):
112+ return jnp .array (self ._rng == other ._rng )
113+ else :
114+ return jnp .array (False )
111115
112116 def __ne__ (self , other ):
113- return not self .__eq__ (other )
117+ return ~ self .__eq__ (other )
114118
115119 __hash__ = None
116120
@@ -173,6 +177,16 @@ def __repr__(self):
173177 def __str__ (self ):
174178 return "galsim.GaussianNoise(sigma=%s)" % (ensure_hashable (self .sigma ),)
175179
180+ def __eq__ (self , other ):
181+ if self is other :
182+ return jnp .array (True )
183+ elif isinstance (other , self .__class__ ):
184+ return jnp .array (self ._rng == other ._rng ) & jnp .array_equal (
185+ self ._sigma , other ._sigma
186+ )
187+ else :
188+ return jnp .array (False )
189+
176190 def tree_flatten (self ):
177191 """This function flattens the GaussianNoise into a list of children
178192 nodes that will be traced by JAX and auxiliary static data."""
@@ -265,6 +279,16 @@ def __repr__(self):
265279 def __str__ (self ):
266280 return "galsim.PoissonNoise(sky_level=%s)" % (self .sky_level )
267281
282+ def __eq__ (self , other ):
283+ if self is other :
284+ return jnp .array (True )
285+ elif isinstance (other , self .__class__ ):
286+ return jnp .array (self ._rng == other ._rng ) & jnp .array_equal (
287+ self ._sky_level , other ._sky_level
288+ )
289+ else :
290+ return jnp .array (False )
291+
268292 def tree_flatten (self ):
269293 """This function flattens the PoissonNoise into a list of children
270294 nodes that will be traced by JAX and auxiliary static data."""
@@ -429,6 +453,19 @@ def __str__(self):
429453 self .read_noise ,
430454 )
431455
456+ def __eq__ (self , other ):
457+ if self is other :
458+ return jnp .array (True )
459+ elif isinstance (other , self .__class__ ):
460+ return (
461+ jnp .array (self ._rng == other ._rng )
462+ & jnp .array_equal (self ._sky_level , other ._sky_level )
463+ & jnp .array_equal (self ._gain , other ._gain )
464+ & jnp .array_equal (self ._read_noise , other ._read_noise )
465+ )
466+ else :
467+ return jnp .array (False )
468+
432469 def tree_flatten (self ):
433470 """This function flattens the CCDNoise into a list of children
434471 nodes that will be traced by JAX and auxiliary static data."""
@@ -570,6 +607,16 @@ def __repr__(self):
570607 def __str__ (self ):
571608 return "galsim.VariableGaussianNoise(var_image=%s)" % (self .var_image )
572609
610+ def __eq__ (self , other ):
611+ if self is other :
612+ return jnp .array (True )
613+ elif isinstance (other , self .__class__ ):
614+ return jnp .array (self ._rng == other ._rng ) & jnp .array (
615+ self ._var_image == other ._var_image
616+ )
617+ else :
618+ return jnp .array (False )
619+
573620 def tree_flatten (self ):
574621 """This function flattens the VariableGaussianNoise into a list of children
575622 nodes that will be traced by JAX and auxiliary static data."""
0 commit comments