Skip to content

Commit 779438f

Browse files
author
domosedy
committed
[refactor] finally (i hope) fixed mr issues
1 parent c1c0054 commit 779438f

5 files changed

Lines changed: 156 additions & 213 deletions

File tree

src/pysatl_core/distributions/support.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,20 @@
1414
__copyright__ = "Copyright (c) 2025 PySATL project"
1515
__license__ = "SPDX-License-Identifier: MIT"
1616

17+
from collections.abc import Callable
1718
from dataclasses import dataclass
1819
from math import floor
19-
from typing import TYPE_CHECKING, Protocol, cast, overload, runtime_checkable
20+
from typing import (
21+
TYPE_CHECKING,
22+
Protocol,
23+
cast,
24+
overload,
25+
runtime_checkable,
26+
)
2027

2128
import numpy as np
2229

23-
from pysatl_core.types import BoolArray, Interval1D, Number, NumericArray
30+
from pysatl_core.types import BoolArray, Interval1D, IntervalND, Number, NumericArray
2431

2532
if TYPE_CHECKING:
2633
from collections.abc import Iterable, Iterator
@@ -49,6 +56,15 @@ class ContinuousSupport(Interval1D, Support):
4956
"""
5057

5158

59+
class ContinuousNDSupport(IntervalND, Support): # type: ignore[misc]
60+
"""
61+
Support for continuous distributions represented as an array of intervals.
62+
63+
This class inherits from IntervalND and implements the Support protocol
64+
for continuous distributions defined on a list of intervals [left, right].
65+
"""
66+
67+
5268
@runtime_checkable
5369
class DiscreteSupport(Support, Protocol):
5470
"""
@@ -430,10 +446,26 @@ def is_right_bounded(self) -> bool:
430446
__iter__ = iter_points
431447

432448

449+
class SupportByPredicate:
450+
def __init__(self, predicate: Callable[[NumericArray | Number], bool]):
451+
self._predicate = predicate
452+
453+
def __contains__(self, item: NumericArray | Number) -> bool:
454+
return self._predicate(item)
455+
456+
457+
class SupportByIntervals(SupportByPredicate):
458+
def __init__(self, support: ContinuousNDSupport):
459+
SupportByPredicate.__init__(self, lambda x: x in support)
460+
461+
433462
__all__ = [
434463
# Base support protocol
435464
"Support",
436465
"ContinuousSupport",
466+
"ContinuousNDSupport",
467+
"SupportByPredicate",
468+
"SupportByIntervals",
437469
# Discrete support protocol and implementations
438470
"DiscreteSupport",
439471
"ExplicitTableDiscreteSupport",

src/pysatl_core/families/__init__.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515
from .configuration import configure_families_register
1616
from .distribution import ParametricFamilyDistribution
1717
from .exponential_family import (
18+
# CanonicalContinuousExponentialClassFamily,
19+
ContinuousExponentialClassFamily,
1820
ExponentialConjugateHyperparameters,
19-
ExponentialFamily,
2021
ExponentialFamilyParametrization,
21-
NaturalExponentialFamily,
22-
SpacePredicate,
23-
SpacePredicateArray,
2422
)
2523
from .parametric_family import ParametricFamily
2624
from .parametrizations import (
@@ -42,12 +40,10 @@
4240
"configure_families_register",
4341
# builtins
4442
*_builtins_all,
45-
"ExponentialFamily",
43+
"ContinuousExponentialClassFamily",
4644
"ExponentialFamilyParametrization",
4745
"ExponentialConjugateHyperparameters",
48-
"SpacePredicate",
49-
"SpacePredicateArray",
50-
"NaturalExponentialFamily",
46+
# "CanonicalContinuousExponentialClassFamily",
5147
]
5248

5349
del _builtins_all

0 commit comments

Comments
 (0)