Skip to content

Commit 83f5584

Browse files
committed
feat(exponential): remove NumberParameter logic
1 parent 1d58077 commit 83f5584

4 files changed

Lines changed: 58 additions & 61 deletions

File tree

src/pysatl_core/distributions/support.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,6 @@
3232
Interval1D,
3333
IntervalND,
3434
Number,
35-
NumberParameter,
3635
NumericArray,
3736
)
3837

@@ -63,7 +62,8 @@ class ContinuousSupport(Interval1D, Support):
6362
"""
6463

6564

66-
class ContinuousNDSupport(IntervalND, Support):
65+
# Support wants `contains` to accept a Number as a parameter, but we decided to avoid this.
66+
class ContinuousNDSupport(IntervalND, Support): # type: ignore[misc]
6767
"""
6868
Support for continuous distributions represented as an array of intervals.
6969
@@ -455,18 +455,18 @@ def is_right_bounded(self) -> bool:
455455

456456
@dataclass(slots=True)
457457
class SupportByPredicate(Support):
458-
predicate: Callable[[NumberParameter], bool]
458+
predicate: Callable[[NumericArray], bool]
459459

460460
@overload
461461
def contains(self, x: Number) -> bool: ...
462462
@overload
463463
def contains(self, x: NumericArray) -> BoolArray: ...
464464

465-
def contains(self, x: NumberParameter) -> bool | BoolArray:
465+
def contains(self, x: NumericArray) -> bool | BoolArray: # type: ignore[misc]
466466
return self.predicate(x)
467467

468468
def __contains__(self, item: object) -> bool | BoolArray:
469-
return self.contains(cast(NumberParameter, item))
469+
return self.contains(cast(NumericArray, item))
470470

471471

472472
__all__ = [

src/pysatl_core/families/exponential_family.py

Lines changed: 38 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@
3636

3737
if TYPE_CHECKING:
3838
from pysatl_core.distributions.support import Support
39-
from pysatl_core.types import Number, NumberParameter, NumericArray
39+
from pysatl_core.types import Number, NumericArray
4040

4141
type ParametrizedFunction = Callable[[Parametrization, Any], Any]
4242
type SupportArg = Callable[[Parametrization], Support | None] | None
@@ -52,10 +52,10 @@ class ExponentialFamilyParametrization(Parametrization):
5252
f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ))
5353
5454
Attributes:
55-
theta (NumberParameter): Natural parameter vector (can be a scalar or array)
55+
theta (NumericArray): Natural parameter vector (can be a scalar or array)
5656
"""
5757

58-
theta: NumberParameter
58+
theta: NumericArray
5959

6060
def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization:
6161
"""Return the base parametrization (identity transform for canonical form)."""
@@ -120,9 +120,9 @@ class ContinuousExponentialClassFamily(ParametricFamily):
120120
def __init__(
121121
self,
122122
*,
123-
log_partition: Callable[[NumberParameter], NumberParameter],
124-
sufficient_statistics: Callable[[NumberParameter], NumberParameter],
125-
normalization_constant: Callable[[NumberParameter], NumberParameter],
123+
log_partition: Callable[[NumericArray], NumericArray],
124+
sufficient_statistics: Callable[[NumericArray], NumericArray],
125+
normalization_constant: Callable[[NumericArray], Number],
126126
support: SupportByPredicate,
127127
parameter_space: SupportByPredicate,
128128
sufficient_statistics_values: SupportByPredicate,
@@ -185,13 +185,13 @@ def log_density(self) -> ParametrizedFunction:
185185
and a point `x`, and returns log f(x|θ). Returns -inf for x outside the support.
186186
187187
Returns:
188-
Callable[[Parametrization, NumberParameter], Number]
188+
Callable[[Parametrization, NumericArray], Number]
189189
"""
190190

191-
def log_density_func(parametrization: Parametrization, x: NumberParameter) -> Number:
191+
def log_density_func(parametrization: Parametrization, x: NumericArray) -> Number:
192192
parametrization = cast(ExponentialFamilyParametrization, parametrization)
193193
parametrization = parametrization.transform_to_base_parametrization()
194-
if x not in self._support:
194+
if np.array([x]) not in self._support:
195195
return -np.inf
196196

197197
theta = parametrization.theta
@@ -211,7 +211,7 @@ def density(self) -> ParametrizedFunction:
211211
Density function (exponentiated log‑density).
212212
213213
Returns:
214-
Callable[[Parametrization, NumberParameter], Number]
214+
Callable[[Parametrization, NumericArray], Number]
215215
"""
216216
return lambda parametrization, x: np.exp(self.log_density(parametrization, x))
217217

@@ -230,8 +230,8 @@ def conjugate_prior_family(self) -> ContinuousExponentialClassFamily:
230230
"""
231231

232232
def conjugate_sufficient(
233-
theta: NumberParameter,
234-
) -> NumberParameter:
233+
theta: NumericArray,
234+
) -> NumericArray:
235235
if not hasattr(theta, "__len__"):
236236
theta = np.array([theta])
237237

@@ -240,9 +240,9 @@ def conjugate_sufficient(
240240
return np.append(theta, self._log_partition(theta))
241241

242242
def conjugate_log_partition(
243-
parametrization: NumberParameter,
244-
) -> NumberParameter:
245-
def pdf(theta: NumberParameter) -> NumberParameter:
243+
parametrization: NumericArray,
244+
) -> NumericArray:
245+
def pdf(theta: NumericArray) -> Number:
246246
if not hasattr(theta, "__len__"):
247247
theta = np.array([theta])
248248
return cast(
@@ -259,23 +259,25 @@ def pdf(theta: NumberParameter) -> NumberParameter:
259259
lambda x: pdf(x) if x in self._parameter_space else 0, # type: ignore[arg-type]
260260
[(float("-inf"), float("+inf"))],
261261
)[0]
262-
return cast(np.float64, -np.log(all_value))
262+
return np.array([cast(np.float64, -np.log(all_value))])
263263

264264
def conjugate_sufficient_accepts(
265265
theta: NumericArray,
266266
) -> bool:
267267
xi = theta[:-1]
268268
nu = theta[-1]
269269

270-
return xi in self._sufficient_statistics_values and nu in ContinuousSupport(0, np.inf)
270+
return xi in self._sufficient_statistics_values and np.array([nu]) in ContinuousSupport(
271+
0, np.inf
272+
)
271273

272274
return ContinuousExponentialClassFamily(
273275
log_partition=conjugate_log_partition,
274276
sufficient_statistics=conjugate_sufficient,
275277
normalization_constant=lambda _: 1,
276278
support=self._parameter_space,
277-
sufficient_statistics_values=self._parameter_space, # TODO: write convex hull for this
278-
parameter_space=SupportByPredicate(predicate=conjugate_sufficient_accepts), # type: ignore[arg-type]
279+
sufficient_statistics_values=self._parameter_space,
280+
parameter_space=SupportByPredicate(predicate=conjugate_sufficient_accepts),
279281
name=self.name,
280282
distr_type=self._distr_type,
281283
distr_parametrizations=self.parametrization_names,
@@ -284,7 +286,7 @@ def conjugate_sufficient_accepts(
284286

285287
def transform(
286288
self,
287-
transform_function: Callable[[NumberParameter], NumberParameter],
289+
transform_function: Callable[[NumericArray], NumericArray],
288290
) -> ContinuousExponentialClassFamily:
289291
"""
290292
Transform the random variable by a monotonic, differentiable function.
@@ -301,20 +303,20 @@ def transform(
301303
ContinuousExponentialClassFamily: A new family for the transformed variable.
302304
"""
303305

304-
def calculate_jacobian(x: NumberParameter) -> NumberParameter:
306+
def calculate_jacobian(x: NumericArray) -> NumericArray:
305307
if not isinstance(x, Iterable):
306308
x = np.array([x])
307309

308310
return np.abs(det(jacobian(transform_function, x).df))
309311

310-
def new_support(x: NumberParameter) -> bool:
312+
def new_support(x: NumericArray) -> bool:
311313
return transform_function(x) in self._support
312314

313-
def new_sufficient(x: NumberParameter) -> NumberParameter:
315+
def new_sufficient(x: NumericArray) -> NumericArray:
314316
return self._sufficient(transform_function(x))
315317

316-
def new_normalization(x: NumberParameter) -> NumberParameter:
317-
return self._normalization(x) * calculate_jacobian(x)
318+
def new_normalization(x: NumericArray) -> Number:
319+
return cast(np.float64, self._normalization(x) * calculate_jacobian(x))
318320

319321
return ContinuousExponentialClassFamily(
320322
log_partition=self._log_partition,
@@ -339,7 +341,11 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any:
339341
if hasattr(x, "__len__"):
340342
dimension_size = len(x)
341343
return nquad(
342-
lambda x: np.dot(x, self.density(parametrization, x)) if x in self._support else 0,
344+
lambda x: (
345+
np.dot(x, self.density(parametrization, x))
346+
if np.array([x]) in self._support
347+
else 0
348+
),
343349
[(float("-inf"), float("inf"))] * dimension_size,
344350
)[0]
345351

@@ -355,7 +361,9 @@ def func(parametrization: Parametrization, x: Any) -> Any:
355361
if hasattr(x, "__len__"):
356362
dimension_size = len(x)
357363
return nquad(
358-
lambda x: x**2 * self.density(parametrization, x) if x in self._support else 0,
364+
lambda x: (
365+
x**2 * self.density(parametrization, x) if np.array([x]) in self._support else 0
366+
),
359367
[(float("-inf"), float("inf"))] * dimension_size,
360368
)[0]
361369

@@ -394,7 +402,7 @@ def posterior_hyperparameters(
394402
posterior_effective_sample_size = parametrizaiton.effective_sample_size
395403
if hasattr(sample, "__iter__") and not isinstance(sample, str):
396404
posterior_effective_suff_stat_value += np.sum(
397-
[self._sufficient(x) for x in sample], # type: ignore[arg-type]
405+
[self._sufficient(x) for x in sample],
398406
axis=0,
399407
)
400408
posterior_effective_sample_size += len(sample)
@@ -424,13 +432,13 @@ def posterior_predictive(self) -> ParametricFamily:
424432

425433
def conjugate_log_partition(
426434
parametrization: ExponentialConjugateHyperparameters,
427-
) -> NumberParameter:
435+
) -> NumericArray:
428436
conjugate_value = self.conjugate_prior_family._log_partition(
429437
parametrization.transform_to_base_parametrization().theta
430438
)
431439
return np.exp(conjugate_value)
432440

433-
def posterior_density(parametrization: Parametrization, x: NumberParameter) -> Number:
441+
def posterior_density(parametrization: Parametrization, x: NumericArray) -> Number:
434442
parametrization = cast(ExponentialConjugateHyperparameters, parametrization)
435443
return cast(
436444
np.float32,

src/pysatl_core/types.py

Lines changed: 11 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -121,9 +121,6 @@ class EuclideanDistributionType(DistributionType):
121121
type BoolArray = NDArray[np.bool_]
122122
"""Type alias for boolean arrays."""
123123

124-
type NumberParameter = Number | NumericArray
125-
"""Type alias for numeric or list parameter"""
126-
127124

128125
class ContinuousSupportShape1D(Enum):
129126
"""
@@ -258,27 +255,22 @@ def shape(self) -> ContinuousSupportShape1D:
258255
class IntervalND:
259256
intervals: list[Interval1D]
260257

261-
@overload
262-
def contains(self, x: Number) -> bool: ...
258+
def contains(self, x: NumericArray) -> bool | BoolArray:
259+
def contains_for_point(point: NumericArray) -> bool:
260+
assert len(point) == len(self.intervals)
261+
return all(
262+
x_coordinate in interval
263+
for interval, x_coordinate in zip(self.intervals, point, strict=True)
264+
)
263265

264-
@overload
265-
def contains(self, x: NumericArray) -> BoolArray: ...
266+
if len(x.shape) == 1:
267+
return contains_for_point(x)
266268

267-
def contains(self, x: Number | NumericArray) -> bool | BoolArray:
268-
if not hasattr(x, "__iter__"):
269-
x = np.array([x])
270-
271-
x = np.array(x)
272-
assert len(x) == len(self.intervals)
273-
274-
return all(
275-
x_coordinate in interval
276-
for interval, x_coordinate in zip(self.intervals, x, strict=True)
277-
)
269+
return np.array([contains_for_point(point) for point in x])
278270

279271
def __contains__(self, x: object) -> bool:
280272
"""Check if a single point is in the interval."""
281-
return bool(self.contains(cast(Number, x)))
273+
return bool(self.contains(cast(NumericArray, x)))
282274

283275

284276
type GenericCharacteristicName = str

tests/unit/families/test_exponential_family.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
__license__ = "SPDX-License-Identifier: MIT"
44

55
import itertools
6-
from collections.abc import Iterable
76
from typing import cast
87

98
import numpy as np
@@ -16,7 +15,7 @@
1615
ContinuousExponentialClassFamily,
1716
)
1817
from pysatl_core.families.registry import ParametricFamilyRegister
19-
from pysatl_core.types import CharacteristicName, Interval1D, NumberParameter, UnivariateContinuous
18+
from pysatl_core.types import CharacteristicName, Interval1D, NumericArray, UnivariateContinuous
2019

2120

2221
def gamma_pdf(alpha: float, beta: float, x: float) -> float:
@@ -25,16 +24,14 @@ def gamma_pdf(alpha: float, beta: float, x: float) -> float:
2524

2625
@pytest.fixture(scope="function")
2726
def conjugate_for_exponential() -> ContinuousExponentialClassFamily:
28-
def transform_function(x: NumberParameter) -> NumberParameter:
29-
if isinstance(x, Iterable):
30-
return np.array([-x[0]])
27+
def transform_function(x: NumericArray) -> NumericArray:
3128
return -x
3229

3330
support_neg = SupportByPredicate(
34-
predicate=lambda x: x in ContinuousNDSupport(intervals=[Interval1D(-np.inf, 0)])
31+
predicate=lambda x: np.array([x]) in ContinuousNDSupport(intervals=[Interval1D(-np.inf, 0)])
3532
)
3633
support_pos = SupportByPredicate(
37-
predicate=lambda x: x in ContinuousNDSupport(intervals=[Interval1D(0, np.inf)])
34+
predicate=lambda x: np.array([x]) in ContinuousNDSupport(intervals=[Interval1D(0, np.inf)])
3835
)
3936
fam = ContinuousExponentialClassFamily(
4037
log_partition=lambda parametrization: np.log(-parametrization),

0 commit comments

Comments
 (0)