1+ __author__ = "Leonid Elkin"
2+ __copyright__ = "Copyright (c) 2025 PySATL project"
3+ __license__ = "SPDX-License-Identifier: MIT"
4+
5+ import itertools
6+ from collections .abc import Iterable
17from typing import cast
28
39import numpy as np
410import pytest
511import scipy
612from numpy .testing import assert_allclose
713
8- from pysatl_core .distributions .support import ContinuousNDSupport , SupportByIntervals
14+ from pysatl_core .distributions .support import ContinuousNDSupport , SupportByPredicate
915from pysatl_core .families import (
1016 ContinuousExponentialClassFamily ,
1117)
1218from pysatl_core .families .registry import ParametricFamilyRegister
13- from pysatl_core .types import Interval1D , UnivariateContinuous
19+ from pysatl_core .types import CharacteristicName , Interval1D , NumberParameter , UnivariateContinuous
1420
1521
1622def gamma_pdf (alpha : float , beta : float , x : float ) -> float :
@@ -19,13 +25,17 @@ def gamma_pdf(alpha: float, beta: float, x: float) -> float:
1925
2026@pytest .fixture (scope = "function" )
2127def conjugate_for_exponential () -> ContinuousExponentialClassFamily :
22- def transform_function (x : list [ float ] | float ) -> list [ float ] | float :
23- if type ( x ) is list :
24- return [- x [0 ]]
25- return - x # type: ignore[operator]
28+ def transform_function (x : NumberParameter ) -> NumberParameter :
29+ if isinstance ( x , Iterable ) :
30+ return np . array ( [- x [0 ]])
31+ return - x
2632
27- support_neg = SupportByIntervals (ContinuousNDSupport (intervals = [Interval1D (- np .inf , 0 )]))
28- support_pos = SupportByIntervals (ContinuousNDSupport (intervals = [Interval1D (0 , np .inf )]))
33+ support_neg = SupportByPredicate (
34+ predicate = lambda x : x in ContinuousNDSupport (intervals = [Interval1D (- np .inf , 0 )])
35+ )
36+ support_pos = SupportByPredicate (
37+ predicate = lambda x : x in ContinuousNDSupport (intervals = [Interval1D (0 , np .inf )])
38+ )
2939 fam = ContinuousExponentialClassFamily (
3040 log_partition = lambda parametrization : np .log (- parametrization ),
3141 sufficient_statistics = lambda x : x ,
@@ -45,8 +55,10 @@ def transform_function(x: list[float] | float) -> list[float] | float:
4555 )
4656
4757
48- @pytest .mark .parametrize ("theta1" , range (2 , 5 ))
49- @pytest .mark .parametrize ("theta2" , range (2 , 5 ))
58+ @pytest .mark .parametrize (
59+ ("theta1" , "theta2" ),
60+ itertools .product (range (2 , 5 ), range (2 , 5 )),
61+ )
5062def test_exponential_pdf (theta1 , theta2 , conjugate_for_exponential ):
5163 gamma_family : ContinuousExponentialClassFamily = conjugate_for_exponential
5264
@@ -61,27 +73,35 @@ def test_exponential_pdf(theta1, theta2, conjugate_for_exponential):
6173 assert_allclose ([pdf (xx ) for xx in x ], [gamma_pdf (alpha , beta , xx ) for xx in x ], rtol = 1e-6 )
6274
6375
64- @pytest .mark .parametrize ("theta1" , range (2 , 5 ))
65- @pytest .mark .parametrize ("theta2" , range (2 , 5 ))
76+ @pytest .mark .parametrize (
77+ ("theta1" , "theta2" ),
78+ itertools .product (range (2 , 5 ), range (2 , 5 )),
79+ )
6680def test_exponential_mean (theta1 , theta2 , conjugate_for_exponential ):
6781 gamma_family : ContinuousExponentialClassFamily = conjugate_for_exponential
6882
6983 alpha = theta2 + 1
7084 beta = theta1
7185
7286 exponential = gamma_family (theta = np .array ([theta1 , theta2 ]), parametrization_name = "theta" )
73- mean = exponential .computation_strategy .query_method ("mean" , distr = exponential )
87+ mean = exponential .computation_strategy .query_method (
88+ CharacteristicName .MEAN_DEFAULT , distr = exponential
89+ )
7490 assert np .isclose (mean (12 ), alpha / beta , rtol = 1e-6 )
7591
7692
77- @pytest .mark .parametrize ("theta1" , range (2 , 5 ))
78- @pytest .mark .parametrize ("theta2" , range (2 , 5 ))
93+ @pytest .mark .parametrize (
94+ ("theta1" , "theta2" ),
95+ itertools .product (range (2 , 5 ), range (2 , 5 )),
96+ )
7997def test_exponential_var (theta1 , theta2 , conjugate_for_exponential ):
8098 gamma_family : ContinuousExponentialClassFamily = conjugate_for_exponential
8199
82100 alpha = theta2 + 1
83101 beta = theta1
84102
85103 exponential = gamma_family (theta = np .array ([theta1 , theta2 ]), parametrization_name = "theta" )
86- var = exponential .computation_strategy .query_method ("var" , distr = exponential )
104+ var = exponential .computation_strategy .query_method (
105+ CharacteristicName .VAR_DEFAULT , distr = exponential
106+ )
87107 assert np .isclose (var (12 ), alpha / beta ** 2 , rtol = 1e-6 )
0 commit comments