Skip to content

Commit 320aeac

Browse files
committed
refactor(exponential): fix pr issues
1 parent cfe1562 commit 320aeac

5 files changed

Lines changed: 86 additions & 41 deletions

File tree

src/pysatl_core/distributions/support.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,14 @@
2727

2828
import numpy as np
2929

30-
from pysatl_core.types import BoolArray, Interval1D, IntervalND, Number, NumericArray
30+
from pysatl_core.types import (
31+
BoolArray,
32+
Interval1D,
33+
IntervalND,
34+
Number,
35+
NumberParameter,
36+
NumericArray,
37+
)
3138

3239
if TYPE_CHECKING:
3340
from collections.abc import Iterable, Iterator
@@ -56,7 +63,7 @@ class ContinuousSupport(Interval1D, Support):
5663
"""
5764

5865

59-
class ContinuousNDSupport(IntervalND, Support): # type: ignore[misc]
66+
class ContinuousNDSupport(IntervalND, Support):
6067
"""
6168
Support for continuous distributions represented as an array of intervals.
6269
@@ -446,17 +453,20 @@ def is_right_bounded(self) -> bool:
446453
__iter__ = iter_points
447454

448455

449-
class SupportByPredicate:
450-
def __init__(self, predicate: Callable[[NumericArray | Number], bool]):
451-
self._predicate = predicate
456+
@dataclass(slots=True)
457+
class SupportByPredicate(Support):
458+
predicate: Callable[[NumberParameter], bool]
452459

453-
def __contains__(self, item: NumericArray | Number) -> bool:
454-
return self._predicate(item)
460+
@overload
461+
def contains(self, x: Number) -> bool: ...
462+
@overload
463+
def contains(self, x: NumericArray) -> BoolArray: ...
455464

465+
def contains(self, x: NumberParameter) -> bool | BoolArray:
466+
return self.predicate(x)
456467

457-
class SupportByIntervals(SupportByPredicate):
458-
def __init__(self, support: ContinuousNDSupport):
459-
SupportByPredicate.__init__(self, lambda x: x in support)
468+
def __contains__(self, item: object) -> bool | BoolArray:
469+
return self.contains(cast(NumberParameter, item))
460470

461471

462472
__all__ = [
@@ -465,7 +475,6 @@ def __init__(self, support: ContinuousNDSupport):
465475
"ContinuousSupport",
466476
"ContinuousNDSupport",
467477
"SupportByPredicate",
468-
"SupportByIntervals",
469478
# Discrete support protocol and implementations
470479
"DiscreteSupport",
471480
"ExplicitTableDiscreteSupport",

src/pysatl_core/families/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
"ContinuousExponentialClassFamily",
4444
"ExponentialFamilyParametrization",
4545
"ExponentialConjugateHyperparameters",
46-
# "CanonicalContinuousExponentialClassFamily",
4746
]
4847

4948
del _builtins_all

src/pysatl_core/families/exponential_family.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
from __future__ import annotations
22

3-
from collections.abc import Callable
3+
__author__ = "Vinogradov Ilya"
4+
__copyright__ = "Copyright (c) 2025 PySATL project"
5+
__license__ = "SPDX-License-Identifier: MIT"
6+
7+
8+
from collections.abc import Callable, Iterable
49
from dataclasses import dataclass
510
from typing import TYPE_CHECKING, Any, cast
611

@@ -25,11 +30,10 @@
2530

2631
if TYPE_CHECKING:
2732
from pysatl_core.distributions.support import Support
28-
from pysatl_core.types import Number, NumericArray
33+
from pysatl_core.types import Number, NumberParameter, NumericArray
2934

3035
type ParametrizedFunction = Callable[[Parametrization, Any], Any]
3136
type SupportArg = Callable[[Parametrization], Support | None] | None
32-
type NumberParameter = Number | NumericArray
3337

3438

3539
@dataclass
@@ -77,6 +81,7 @@ def __init__(
7781
distr_type: DistributionType | Callable[[Parametrization], DistributionType],
7882
distr_parametrizations: list[ParametrizationName],
7983
support_by_parametrization: SupportArg = None,
84+
base_score: Callable[[Parametrization, NumericArray], NumericArray] | None = None,
8085
):
8186
self._sufficient = sufficient_statistics
8287
self._log_partition = log_partition
@@ -91,8 +96,8 @@ def __init__(
9196
dict[ParametrizationName, ParametrizedFunction] | ParametrizedFunction,
9297
] = {
9398
CharacteristicName.PDF: self.density,
94-
CharacteristicName.MEAN: self._mean,
95-
CharacteristicName.VAR: self._var,
99+
CharacteristicName.MEAN_DEFAULT: self._mean,
100+
CharacteristicName.VAR_DEFAULT: self._var,
96101
}
97102

98103
ParametricFamily.__init__(
@@ -102,6 +107,7 @@ def __init__(
102107
distr_parametrizations=distr_parametrizations,
103108
distr_characteristics=distr_characteristics,
104109
support_by_parametrization=support_by_parametrization,
110+
base_score=base_score,
105111
)
106112
parametrization(family=self, name="theta")(ExponentialFamilyParametrization)
107113

@@ -176,7 +182,7 @@ def conjugate_sufficient_accepts(
176182
normalization_constant=lambda _: 1,
177183
support=self._parameter_space,
178184
sufficient_statistics_values=self._parameter_space, # TODO: write convex hull for this
179-
parameter_space=SupportByPredicate(conjugate_sufficient_accepts), # type: ignore[arg-type]
185+
parameter_space=SupportByPredicate(predicate=conjugate_sufficient_accepts), # type: ignore[arg-type]
180186
name=self.name,
181187
distr_type=self._distr_type,
182188
distr_parametrizations=self.parametrization_names,
@@ -185,28 +191,28 @@ def conjugate_sufficient_accepts(
185191

186192
def transform(
187193
self,
188-
transform_function: Callable[[Any], Any],
194+
transform_function: Callable[[NumberParameter], NumberParameter],
189195
) -> ContinuousExponentialClassFamily:
190-
def calculate_jacobian(x: Any) -> Any:
191-
if type(x) is not list:
196+
def calculate_jacobian(x: NumberParameter) -> NumberParameter:
197+
if not isinstance(x, Iterable):
192198
x = np.array([x])
193199

194200
return np.abs(det(jacobian(transform_function, x).df))
195201

196-
def new_support(x: Any) -> bool:
202+
def new_support(x: NumberParameter) -> bool:
197203
return transform_function(x) in self._support
198204

199-
def new_sufficient(x: Any) -> Any:
205+
def new_sufficient(x: NumberParameter) -> NumberParameter:
200206
return self._sufficient(transform_function(x))
201207

202-
def new_normalization(x: Any) -> Any:
208+
def new_normalization(x: NumberParameter) -> NumberParameter:
203209
return self._normalization(x) * calculate_jacobian(x)
204210

205211
return ContinuousExponentialClassFamily(
206212
log_partition=self._log_partition,
207213
sufficient_statistics=new_sufficient,
208214
normalization_constant=new_normalization,
209-
support=SupportByPredicate(new_support),
215+
support=SupportByPredicate(predicate=new_support),
210216
parameter_space=self._parameter_space,
211217
sufficient_statistics_values=self._sufficient_statistics_values,
212218
name=f"Transformed{self._name}",

src/pysatl_core/types.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ class EuclideanDistributionType(DistributionType):
112112
type BoolArray = NDArray[np.bool_]
113113
"""Type alias for boolean arrays."""
114114

115+
type NumberParameter = Number | NumericArray
116+
"""Type alias for numeric or list parameter"""
117+
115118

116119
class ContinuousSupportShape1D(Enum):
117120
"""
@@ -246,6 +249,12 @@ def shape(self) -> ContinuousSupportShape1D:
246249
class IntervalND:
247250
intervals: list[Interval1D]
248251

252+
@overload
253+
def contains(self, x: Number) -> bool: ...
254+
255+
@overload
256+
def contains(self, x: NumericArray) -> BoolArray: ...
257+
249258
def contains(self, x: Number | NumericArray) -> bool | BoolArray:
250259
if not hasattr(x, "__iter__"):
251260
x = np.array([x])
@@ -307,6 +316,8 @@ class CharacteristicName(StrEnum):
307316
CDF = "cdf"
308317
PPF = "ppf"
309318
PMF = "pmf"
319+
MEAN_DEFAULT = "MEAN_DEFAULT" # defined in class implementation of mean
320+
VAR_DEFAULT = "VAR_DEFAULT" # defined in class implementation of var
310321
LPDF = "lpdf" # unimplemented in graph yet
311322
CF = "cf" # unimplemented in graph yet
312323
SF = "sf" # unimplemented in graph yet

tests/unit/families/test_exponential_family.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
1+
__author__ = "Leonid Elkin"
2+
__copyright__ = "Copyright (c) 2025 PySATL project"
3+
__license__ = "SPDX-License-Identifier: MIT"
4+
5+
import itertools
6+
from collections.abc import Iterable
17
from typing import cast
28

39
import numpy as np
410
import pytest
511
import scipy
612
from numpy.testing import assert_allclose
713

8-
from pysatl_core.distributions.support import ContinuousNDSupport, SupportByIntervals
14+
from pysatl_core.distributions.support import ContinuousNDSupport, SupportByPredicate
915
from pysatl_core.families import (
1016
ContinuousExponentialClassFamily,
1117
)
1218
from pysatl_core.families.registry import ParametricFamilyRegister
13-
from pysatl_core.types import Interval1D, UnivariateContinuous
19+
from pysatl_core.types import CharacteristicName, Interval1D, NumberParameter, UnivariateContinuous
1420

1521

1622
def gamma_pdf(alpha: float, beta: float, x: float) -> float:
@@ -19,13 +25,17 @@ def gamma_pdf(alpha: float, beta: float, x: float) -> float:
1925

2026
@pytest.fixture(scope="function")
2127
def conjugate_for_exponential() -> ContinuousExponentialClassFamily:
22-
def transform_function(x: list[float] | float) -> list[float] | float:
23-
if type(x) is list:
24-
return [-x[0]]
25-
return -x # type: ignore[operator]
28+
def transform_function(x: NumberParameter) -> NumberParameter:
29+
if isinstance(x, Iterable):
30+
return np.array([-x[0]])
31+
return -x
2632

27-
support_neg = SupportByIntervals(ContinuousNDSupport(intervals=[Interval1D(-np.inf, 0)]))
28-
support_pos = SupportByIntervals(ContinuousNDSupport(intervals=[Interval1D(0, np.inf)]))
33+
support_neg = SupportByPredicate(
34+
predicate=lambda x: x in ContinuousNDSupport(intervals=[Interval1D(-np.inf, 0)])
35+
)
36+
support_pos = SupportByPredicate(
37+
predicate=lambda x: x in ContinuousNDSupport(intervals=[Interval1D(0, np.inf)])
38+
)
2939
fam = ContinuousExponentialClassFamily(
3040
log_partition=lambda parametrization: np.log(-parametrization),
3141
sufficient_statistics=lambda x: x,
@@ -45,8 +55,10 @@ def transform_function(x: list[float] | float) -> list[float] | float:
4555
)
4656

4757

48-
@pytest.mark.parametrize("theta1", range(2, 5))
49-
@pytest.mark.parametrize("theta2", range(2, 5))
58+
@pytest.mark.parametrize(
59+
("theta1", "theta2"),
60+
itertools.product(range(2, 5), range(2, 5)),
61+
)
5062
def test_exponential_pdf(theta1, theta2, conjugate_for_exponential):
5163
gamma_family: ContinuousExponentialClassFamily = conjugate_for_exponential
5264

@@ -61,27 +73,35 @@ def test_exponential_pdf(theta1, theta2, conjugate_for_exponential):
6173
assert_allclose([pdf(xx) for xx in x], [gamma_pdf(alpha, beta, xx) for xx in x], rtol=1e-6)
6274

6375

64-
@pytest.mark.parametrize("theta1", range(2, 5))
65-
@pytest.mark.parametrize("theta2", range(2, 5))
76+
@pytest.mark.parametrize(
77+
("theta1", "theta2"),
78+
itertools.product(range(2, 5), range(2, 5)),
79+
)
6680
def test_exponential_mean(theta1, theta2, conjugate_for_exponential):
6781
gamma_family: ContinuousExponentialClassFamily = conjugate_for_exponential
6882

6983
alpha = theta2 + 1
7084
beta = theta1
7185

7286
exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta")
73-
mean = exponential.computation_strategy.query_method("mean", distr=exponential)
87+
mean = exponential.computation_strategy.query_method(
88+
CharacteristicName.MEAN_DEFAULT, distr=exponential
89+
)
7490
assert np.isclose(mean(12), alpha / beta, rtol=1e-6)
7591

7692

77-
@pytest.mark.parametrize("theta1", range(2, 5))
78-
@pytest.mark.parametrize("theta2", range(2, 5))
93+
@pytest.mark.parametrize(
94+
("theta1", "theta2"),
95+
itertools.product(range(2, 5), range(2, 5)),
96+
)
7997
def test_exponential_var(theta1, theta2, conjugate_for_exponential):
8098
gamma_family: ContinuousExponentialClassFamily = conjugate_for_exponential
8199

82100
alpha = theta2 + 1
83101
beta = theta1
84102

85103
exponential = gamma_family(theta=np.array([theta1, theta2]), parametrization_name="theta")
86-
var = exponential.computation_strategy.query_method("var", distr=exponential)
104+
var = exponential.computation_strategy.query_method(
105+
CharacteristicName.VAR_DEFAULT, distr=exponential
106+
)
87107
assert np.isclose(var(12), alpha / beta**2, rtol=1e-6)

0 commit comments

Comments
 (0)