|
14 | 14 | __copyright__ = "Copyright (c) 2025 PySATL project" |
15 | 15 | __license__ = "SPDX-License-Identifier: MIT" |
16 | 16 |
|
| 17 | +from collections.abc import Callable |
17 | 18 | from dataclasses import dataclass |
18 | 19 | from math import floor |
19 | | -from typing import TYPE_CHECKING, Protocol, cast, overload, runtime_checkable |
| 20 | +from typing import ( |
| 21 | + TYPE_CHECKING, |
| 22 | + Protocol, |
| 23 | + cast, |
| 24 | + overload, |
| 25 | + runtime_checkable, |
| 26 | +) |
20 | 27 |
|
21 | 28 | import numpy as np |
22 | 29 |
|
23 | | -from pysatl_core.types import BoolArray, Interval1D, Number, NumericArray |
| 30 | +from pysatl_core.types import BoolArray, Interval1D, IntervalND, Number, NumericArray |
24 | 31 |
|
25 | 32 | if TYPE_CHECKING: |
26 | 33 | from collections.abc import Iterable, Iterator |
@@ -49,6 +56,15 @@ class ContinuousSupport(Interval1D, Support): |
49 | 56 | """ |
50 | 57 |
|
51 | 58 |
|
| 59 | +class ContinuousNDSupport(IntervalND, Support): # type: ignore[misc] |
| 60 | + """ |
| 61 | + Support for continuous distributions represented as an array of intervals. |
| 62 | +
|
| 63 | + This class inherits from IntervalND and implements the Support protocol |
| 64 | + for continuous distributions defined on a list of intervals [left, right]. |
| 65 | + """ |
| 66 | + |
| 67 | + |
52 | 68 | @runtime_checkable |
53 | 69 | class DiscreteSupport(Support, Protocol): |
54 | 70 | """ |
@@ -430,10 +446,26 @@ def is_right_bounded(self) -> bool: |
430 | 446 | __iter__ = iter_points |
431 | 447 |
|
432 | 448 |
|
| 449 | +class SupportByPredicate: |
| 450 | + def __init__(self, predicate: Callable[[NumericArray | Number], bool]): |
| 451 | + self._predicate = predicate |
| 452 | + |
| 453 | + def __contains__(self, item: NumericArray | Number) -> bool: |
| 454 | + return self._predicate(item) |
| 455 | + |
| 456 | + |
| 457 | +class SupportByIntervals(SupportByPredicate): |
| 458 | + def __init__(self, support: ContinuousNDSupport): |
| 459 | + SupportByPredicate.__init__(self, lambda x: x in support) |
| 460 | + |
| 461 | + |
433 | 462 | __all__ = [ |
434 | 463 | # Base support protocol |
435 | 464 | "Support", |
436 | 465 | "ContinuousSupport", |
| 466 | + "ContinuousNDSupport", |
| 467 | + "SupportByPredicate", |
| 468 | + "SupportByIntervals", |
437 | 469 | # Discrete support protocol and implementations |
438 | 470 | "DiscreteSupport", |
439 | 471 | "ExplicitTableDiscreteSupport", |
|
0 commit comments