Skip to content

Commit ec832d5

Browse files
committed
refactor(exponential): update support protocol and fix some minor problems
1 parent 2500678 commit ec832d5

8 files changed

Lines changed: 309 additions & 150 deletions

File tree

src/pysatl_core/distributions/support.py

Lines changed: 9 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
TYPE_CHECKING,
2222
Protocol,
2323
cast,
24-
overload,
2524
runtime_checkable,
2625
)
2726

@@ -47,10 +46,7 @@ class Support(Protocol):
4746
Support defines the set of values where a distribution is defined.
4847
"""
4948

50-
@overload
51-
def contains(self, x: Number) -> bool: ...
52-
@overload
53-
def contains(self, x: NumericArray) -> BoolArray: ...
49+
def contains(self, x: NumericArray) -> bool | BoolArray: ...
5450

5551

5652
class ContinuousSupport(Interval1D, Support):
@@ -62,8 +58,7 @@ class ContinuousSupport(Interval1D, Support):
6258
"""
6359

6460

65-
# Support want to have Number as a parameter of contains, but we decided that we should avoid this
66-
class ContinuousNDSupport(IntervalND, Support): # type: ignore[misc]
61+
class ContinuousNDSupport(IntervalND, Support):
6762
"""
6863
Support for continuous distributions represented as an array of intervals.
6964
@@ -151,12 +146,7 @@ def __init__(self, points: Iterable[Number], assume_sorted: bool = False) -> Non
151146

152147
self._points = arr[unique_mask]
153148

154-
@overload
155-
def contains(self, x: Number) -> bool: ...
156-
@overload
157-
def contains(self, x: NumericArray) -> BoolArray: ...
158-
159-
def contains(self, x: Number | NumericArray) -> bool | BoolArray:
149+
def contains(self, x: NumericArray) -> bool | BoolArray:
160150
"""
161151
Check if point(s) are in the support.
162152
@@ -185,10 +175,6 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray:
185175
return bool(result)
186176
return cast(BoolArray, result)
187177

188-
def __contains__(self, x: object) -> bool:
189-
"""Check if a point is in the support."""
190-
return bool(self.contains(cast(Number, x)))
191-
192178
def iter_points(self) -> Iterator[Number]:
193179
"""Iterate through all points in the support."""
194180
return iter(self._points)
@@ -275,12 +261,7 @@ def __post_init__(self) -> None:
275261
if self.modulus <= 0:
276262
raise ValueError("modulus must be a positive integer.")
277263

278-
@overload
279-
def contains(self, x: Number) -> bool: ...
280-
@overload
281-
def contains(self, x: NumericArray) -> BoolArray: ...
282-
283-
def contains(self, x: Number | NumericArray) -> bool | BoolArray:
264+
def contains(self, x: NumericArray) -> bool | BoolArray:
284265
"""
285266
Check if point(s) are in the integer lattice support.
286267
@@ -306,10 +287,6 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray:
306287
return bool(result)
307288
return cast(BoolArray, result)
308289

309-
def __contains__(self, x: object) -> bool:
310-
"""Check if a point is in the integer lattice support."""
311-
return bool(self.contains(cast(Number, x)))
312-
313290
def iter_points(self) -> Iterator[int]:
314291
"""
315292
Iterate through all points in the integer lattice support.
@@ -453,28 +430,20 @@ def is_right_bounded(self) -> bool:
453430
__iter__ = iter_points
454431

455432

456-
@dataclass(slots=True)
457-
class SupportByPredicate(Support):
458-
predicate: Callable[[NumericArray], bool]
459-
460-
@overload
461-
def contains(self, x: Number) -> bool: ...
462-
@overload
463-
def contains(self, x: NumericArray) -> BoolArray: ...
433+
@dataclass(frozen=True, slots=True)
434+
class PredicateSupport(Support):
435+
predicate: Callable[[NumericArray], bool | BoolArray]
464436

465-
def contains(self, x: NumericArray) -> bool | BoolArray: # type: ignore[misc]
437+
def contains(self, x: NumericArray) -> bool | BoolArray:
466438
return self.predicate(x)
467439

468-
def __contains__(self, item: object) -> bool | BoolArray:
469-
return self.contains(cast(NumericArray, item))
470-
471440

472441
__all__ = [
473442
# Base support protocol
474443
"Support",
475444
"ContinuousSupport",
476445
"ContinuousNDSupport",
477-
"SupportByPredicate",
446+
"PredicateSupport",
478447
# Discrete support protocol and implementations
479448
"DiscreteSupport",
480449
"ExplicitTableDiscreteSupport",

0 commit comments

Comments
 (0)