2121 TYPE_CHECKING ,
2222 Protocol ,
2323 cast ,
24- overload ,
2524 runtime_checkable ,
2625)
2726
@@ -47,10 +46,7 @@ class Support(Protocol):
4746 Support defines the set of values where a distribution is defined.
4847 """
4948
50- @overload
51- def contains (self , x : Number ) -> bool : ...
52- @overload
53- def contains (self , x : NumericArray ) -> BoolArray : ...
49+ def contains (self , x : NumericArray ) -> bool | BoolArray : ...
5450
5551
5652class ContinuousSupport (Interval1D , Support ):
@@ -62,8 +58,7 @@ class ContinuousSupport(Interval1D, Support):
6258 """
6359
6460
65- # Support want to have Number as a parameter of contains, but we decided that we should avoid this
66- class ContinuousNDSupport (IntervalND , Support ): # type: ignore[misc]
61+ class ContinuousNDSupport (IntervalND , Support ):
6762 """
6863 Support for continuous distributions represented as an array of intervals.
6964
@@ -151,12 +146,7 @@ def __init__(self, points: Iterable[Number], assume_sorted: bool = False) -> Non
151146
152147 self ._points = arr [unique_mask ]
153148
154- @overload
155- def contains (self , x : Number ) -> bool : ...
156- @overload
157- def contains (self , x : NumericArray ) -> BoolArray : ...
158-
159- def contains (self , x : Number | NumericArray ) -> bool | BoolArray :
149+ def contains (self , x : NumericArray ) -> bool | BoolArray :
160150 """
161151 Check if point(s) are in the support.
162152
@@ -185,10 +175,6 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray:
185175 return bool (result )
186176 return cast (BoolArray , result )
187177
188- def __contains__ (self , x : object ) -> bool :
189- """Check if a point is in the support."""
190- return bool (self .contains (cast (Number , x )))
191-
192178 def iter_points (self ) -> Iterator [Number ]:
193179 """Iterate through all points in the support."""
194180 return iter (self ._points )
@@ -275,12 +261,7 @@ def __post_init__(self) -> None:
275261 if self .modulus <= 0 :
276262 raise ValueError ("modulus must be a positive integer." )
277263
278- @overload
279- def contains (self , x : Number ) -> bool : ...
280- @overload
281- def contains (self , x : NumericArray ) -> BoolArray : ...
282-
283- def contains (self , x : Number | NumericArray ) -> bool | BoolArray :
264+ def contains (self , x : NumericArray ) -> bool | BoolArray :
284265 """
285266 Check if point(s) are in the integer lattice support.
286267
@@ -306,10 +287,6 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray:
306287 return bool (result )
307288 return cast (BoolArray , result )
308289
309- def __contains__ (self , x : object ) -> bool :
310- """Check if a point is in the integer lattice support."""
311- return bool (self .contains (cast (Number , x )))
312-
313290 def iter_points (self ) -> Iterator [int ]:
314291 """
315292 Iterate through all points in the integer lattice support.
@@ -453,28 +430,20 @@ def is_right_bounded(self) -> bool:
453430 __iter__ = iter_points
454431
455432
456- @dataclass (slots = True )
457- class SupportByPredicate (Support ):
458- predicate : Callable [[NumericArray ], bool ]
459-
460- @overload
461- def contains (self , x : Number ) -> bool : ...
462- @overload
463- def contains (self , x : NumericArray ) -> BoolArray : ...
433+ @dataclass (frozen = True , slots = True )
434+ class PredicateSupport (Support ):
435+ predicate : Callable [[NumericArray ], bool | BoolArray ]
464436
465- def contains (self , x : NumericArray ) -> bool | BoolArray : # type: ignore[misc]
437+ def contains (self , x : NumericArray ) -> bool | BoolArray :
466438 return self .predicate (x )
467439
468- def __contains__ (self , item : object ) -> bool | BoolArray :
469- return self .contains (cast (NumericArray , item ))
470-
471440
472441__all__ = [
473442 # Base support protocol
474443 "Support" ,
475444 "ContinuousSupport" ,
476445 "ContinuousNDSupport" ,
477- "SupportByPredicate " ,
446+ "PredicateSupport " ,
478447 # Discrete support protocol and implementations
479448 "DiscreteSupport" ,
480449 "ExplicitTableDiscreteSupport" ,
0 commit comments