Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 36 additions & 26 deletions src/pysatl_core/distributions/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,25 @@
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

from collections.abc import Callable
from dataclasses import dataclass
from math import floor
from typing import TYPE_CHECKING, Protocol, cast, overload, runtime_checkable
from typing import (
TYPE_CHECKING,
Protocol,
cast,
runtime_checkable,
)

import numpy as np

from pysatl_core.types import BoolArray, Interval1D, Number, NumericArray
from pysatl_core.types import (
BoolArray,
Interval1D,
IntervalND,
Number,
NumericArray,
)

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
Expand All @@ -34,10 +46,7 @@ class Support(Protocol):
Support defines the set of values where a distribution is defined.
"""

@overload
def contains(self, x: Number) -> bool: ...
@overload
def contains(self, x: NumericArray) -> BoolArray: ...
def contains(self, x: NumericArray) -> bool | BoolArray: ...


class ContinuousSupport(Interval1D, Support):
Expand All @@ -49,6 +58,15 @@ class ContinuousSupport(Interval1D, Support):
"""


class ContinuousNDSupport(IntervalND, Support):
    """
    Support for continuous distributions represented as an array of intervals.

    This class inherits from IntervalND and implements the Support protocol
    for continuous distributions defined on a list of intervals [left, right].

    The class body adds no attributes or methods of its own; all behaviour
    comes from its bases.

    NOTE(review): ``Support`` requires ``contains``; presumably ``IntervalND``
    supplies it -- confirm against ``pysatl_core.types``.
    """


@runtime_checkable
class DiscreteSupport(Support, Protocol):
"""
Expand Down Expand Up @@ -128,12 +146,7 @@ def __init__(self, points: Iterable[Number], assume_sorted: bool = False) -> Non

self._points = arr[unique_mask]

@overload
def contains(self, x: Number) -> bool: ...
@overload
def contains(self, x: NumericArray) -> BoolArray: ...

def contains(self, x: Number | NumericArray) -> bool | BoolArray:
def contains(self, x: NumericArray) -> bool | BoolArray:
"""
Check if point(s) are in the support.

Expand Down Expand Up @@ -162,10 +175,6 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray:
return bool(result)
return cast(BoolArray, result)

def __contains__(self, x: object) -> bool:
"""Check if a point is in the support."""
return bool(self.contains(cast(Number, x)))

def iter_points(self) -> Iterator[Number]:
"""Iterate through all points in the support."""
return iter(self._points)
Expand Down Expand Up @@ -252,12 +261,7 @@ def __post_init__(self) -> None:
if self.modulus <= 0:
raise ValueError("modulus must be a positive integer.")

@overload
def contains(self, x: Number) -> bool: ...
@overload
def contains(self, x: NumericArray) -> BoolArray: ...

def contains(self, x: Number | NumericArray) -> bool | BoolArray:
def contains(self, x: NumericArray) -> bool | BoolArray:
"""
Check if point(s) are in the integer lattice support.

Expand All @@ -283,10 +287,6 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray:
return bool(result)
return cast(BoolArray, result)

def __contains__(self, x: object) -> bool:
"""Check if a point is in the integer lattice support."""
return bool(self.contains(cast(Number, x)))

def iter_points(self) -> Iterator[int]:
"""
Iterate through all points in the integer lattice support.
Expand Down Expand Up @@ -430,10 +430,20 @@ def is_right_bounded(self) -> bool:
__iter__ = iter_points


@dataclass(frozen=True, slots=True)
class PredicateSupport(Support):
    """
    Support whose membership test is delegated to a caller-supplied predicate.

    Instances are frozen, slotted dataclasses holding a single field.

    Attributes
    ----------
    predicate : Callable[[NumericArray], bool | BoolArray]
        Callable invoked by :meth:`contains` with the queried point(s).
    """

    # Membership function; it is called as-is, with no validation or
    # conversion of its argument or of its result.
    predicate: Callable[[NumericArray], bool | BoolArray]

    def contains(self, x: NumericArray) -> bool | BoolArray:
        """
        Check if point(s) are in the support.

        Parameters
        ----------
        x : NumericArray
            Point(s) to test; forwarded unchanged to ``predicate``.

        Returns
        -------
        bool or BoolArray
            Whatever ``predicate(x)`` returns.
        """
        return self.predicate(x)


__all__ = [
# Base support protocol
"Support",
"ContinuousSupport",
"ContinuousNDSupport",
"PredicateSupport",
# Discrete support protocol and implementations
"DiscreteSupport",
"ExplicitTableDiscreteSupport",
Expand Down
9 changes: 9 additions & 0 deletions src/pysatl_core/families/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
from .builtins import __all__ as _builtins_all
from .configuration import configure_families_register
from .distribution import ParametricFamilyDistribution
from .exponential_family import (
# CanonicalContinuousExponentialClassFamily,
ContinuousExponentialClassFamily,
ExponentialConjugateHyperparameters,
ExponentialFamilyParametrization,
)
from .parametric_family import ParametricFamily
from .parametrizations import (
Parametrization,
Expand All @@ -34,6 +40,9 @@
"configure_families_register",
# builtins
*_builtins_all,
"ContinuousExponentialClassFamily",
"ExponentialFamilyParametrization",
"ExponentialConjugateHyperparameters",
]

del _builtins_all
2 changes: 2 additions & 0 deletions src/pysatl_core/families/builtins/continuous/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@

from pysatl_core.families.builtins.continuous.exponential import configure_exponential_family
from pysatl_core.families.builtins.continuous.normal import configure_normal_family
from pysatl_core.families.builtins.continuous.pareto import configure_pareto_family
from pysatl_core.families.builtins.continuous.uniform import configure_uniform_family

__all__ = [
"configure_normal_family",
"configure_uniform_family",
"configure_exponential_family",
"configure_pareto_family",
]
109 changes: 109 additions & 0 deletions src/pysatl_core/families/builtins/continuous/pareto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""
Pareto distribution family implementation.

This module provides a Pareto Type I distribution with fixed known minimum
``x_m = 1`` as a continuous exponential family.
"""

from __future__ import annotations

__author__ = "Vinogradov Ilya"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"


import numpy as np

from pysatl_core.distributions.support import ContinuousSupport
from pysatl_core.families.exponential_family import (
ContinuousExponentialClassFamily,
ExponentialFamilyParametrization,
)
from pysatl_core.families.parametrizations import Parametrization, constraint, parametrization
from pysatl_core.families.registry import ParametricFamilyRegister
from pysatl_core.types import FamilyName, NumericArray, UnivariateContinuous

_MINIMUM = 1.0


def configure_pareto_family() -> None:
    """
    Configure and register the Pareto distribution family with fixed minimum 1.

    Returns immediately when ``ParametricFamilyRegister`` already contains
    ``FamilyName.PARETO``. Otherwise it builds a
    ``ContinuousExponentialClassFamily`` from the local callables defined
    below, assigns ``PARETO_DOC`` to its ``__doc__``, declares the ``"shape"``
    parametrization on it and finally registers the family.
    """

    if ParametricFamilyRegister.contains(FamilyName.PARETO):
        return

    # NOTE(review): this literal is assigned to ``pareto_family.__doc__``
    # below, so its text (including line-leading whitespace) is runtime data
    # and is kept exactly as it appears here.
    PARETO_DOC = """
Pareto Type I distribution with known minimum x_m = 1.

The distribution is parameterized by the shape parameter ``alpha > 0``.

Probability density function:
f(x) = alpha / x^(alpha + 1), x >= 1

In natural form this is written as
f(x | theta) = exp(theta log(x) + B(theta)), theta < -1
where ``theta = -(alpha + 1)``.

With the more common convention
f(x | theta) = exp(theta log(x) - A(theta)),
we have
A(theta) = -log(alpha * x_m^alpha)
and therefore
B(theta) = -A(theta) = log(alpha * x_m^alpha).
"""

    def _theta_to_alpha(theta: NumericArray) -> float:
        """
        Return ``alpha = -theta - 1`` as a Python float.

        Only the first element of ``theta`` is read; any further elements
        are ignored.
        """
        theta_arr = np.atleast_1d(np.asarray(theta, dtype=float))
        return float(-theta_arr[0] - 1.0)

    def log_partition(theta: NumericArray) -> NumericArray:
        """
        Return ``log(alpha) + alpha * log(_MINIMUM)`` as a one-element array.

        This is the quantity written ``B(theta) = log(alpha * x_m^alpha)`` in
        ``PARETO_DOC``, i.e. the term that is *added* inside the exponent.
        With ``_MINIMUM = 1.0`` the second summand is zero.
        """
        alpha = _theta_to_alpha(theta)
        return np.array([np.log(alpha) + alpha * np.log(_MINIMUM)])

    def sufficient_statistics(x: NumericArray) -> NumericArray:
        """
        Return ``log(x)`` wrapped in a one-element array.

        ``.item()`` only succeeds on a size-1 array, so ``x`` must hold a
        single value.
        """
        x_arr = np.asarray(x, dtype=float)
        return np.array([np.log(x_arr).item()])

    def normalization_constant(_: NumericArray) -> float:
        """Return the constant ``1.0``; the argument is ignored."""
        return 1.0

    def _support(_: Parametrization) -> ContinuousSupport:
        """Return ``ContinuousSupport(left=_MINIMUM)`` for any parametrization."""
        return ContinuousSupport(left=_MINIMUM)

    pareto_family = ContinuousExponentialClassFamily(
        name=FamilyName.PARETO,
        log_partition=log_partition,
        sufficient_statistics=sufficient_statistics,
        normalization_constant=normalization_constant,
        # Same construction as the one returned by ``_support`` above.
        support=ContinuousSupport(left=_MINIMUM),
        # theta is bounded above by -1 with the endpoint excluded, matching
        # ``theta < -1`` in PARETO_DOC.
        parameter_space=ContinuousSupport(right=-1.0, right_closed=False),
        # Lower bound is log(_MINIMUM), which is 0 for _MINIMUM = 1.0.
        sufficient_statistics_values=ContinuousSupport(left=np.log(_MINIMUM)),
        distr_type=UnivariateContinuous,
        distr_parametrizations=["theta", "shape"],
        support_by_parametrization=_support,
    )
    pareto_family.__doc__ = PARETO_DOC

    @parametrization(family=pareto_family, name="shape")
    class _Shape(Parametrization):
        """
        Shape parametrization of Pareto distribution.

        Parameters
        ----------
        alpha : float
            Shape parameter of the distribution.
        """

        alpha: float

        @constraint(description="alpha > 0")
        def check_alpha_positive(self) -> bool:
            """Return True when ``alpha`` is strictly positive."""
            return self.alpha > 0

        def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization:
            """
            Convert to the ``theta`` parametrization.

            Returns
            -------
            ExponentialFamilyParametrization
                Built with ``theta`` set to the one-element array
                ``[-(alpha + 1)]``.
            """
            return ExponentialFamilyParametrization(theta=np.array([-(self.alpha + 1.0)]))

    # Registered last, after the "shape" parametrization has been declared
    # on the family object.
    ParametricFamilyRegister.register(pareto_family)
2 changes: 2 additions & 0 deletions src/pysatl_core/families/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pysatl_core.families.builtins import (
configure_exponential_family,
configure_normal_family,
configure_pareto_family,
configure_uniform_family,
)
from pysatl_core.families.registry import ParametricFamilyRegister
Expand All @@ -46,6 +47,7 @@ def configure_families_register() -> ParametricFamilyRegister:
The global registry of parametric families.
"""
configure_exponential_family()
configure_pareto_family()
configure_uniform_family()
configure_normal_family()
return ParametricFamilyRegister()
Expand Down
Loading
Loading