Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion src/pysatl_core/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

from typing import TYPE_CHECKING, Protocol, cast, runtime_checkable
from typing import TYPE_CHECKING, Protocol, Self, cast, runtime_checkable

from pysatl_core.types import NumericArray

_KEEP: object = object()


if TYPE_CHECKING:
from collections.abc import Mapping
from typing import Any
Expand Down Expand Up @@ -71,6 +74,39 @@ def computation_strategy(self) -> ComputationStrategy: ...
@property
def support(self) -> Support | None: ...

def _clone_with_strategies(
self,
*,
sampling_strategy: SamplingStrategy | None | object = _KEEP,
computation_strategy: ComputationStrategy | None | object = _KEEP,
) -> Distribution: ...

def with_sampling_strategy(self, sampling_strategy: SamplingStrategy | None) -> Self:
"""Return a copy of this distribution with an updated sampling strategy."""
return cast(Self, self._clone_with_strategies(sampling_strategy=sampling_strategy))

def with_computation_strategy(self, computation_strategy: ComputationStrategy | None) -> Self:
"""Return a copy of this distribution with an updated computation strategy."""
return cast(
Self,
self._clone_with_strategies(computation_strategy=computation_strategy),
)

def with_strategies(
self,
*,
sampling_strategy: SamplingStrategy | None | object = _KEEP,
computation_strategy: ComputationStrategy | None | object = _KEEP,
) -> Self:
"""Return a copy of this distribution with updated strategies."""
return cast(
Self,
self._clone_with_strategies(
sampling_strategy=sampling_strategy,
computation_strategy=computation_strategy,
),
)

def query_method(
self, characteristic_name: GenericCharacteristicName, **options: Any
) -> Method[Any, Any]:
Expand Down
94 changes: 75 additions & 19 deletions src/pysatl_core/families/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"


from dataclasses import dataclass
from typing import TYPE_CHECKING, cast

from pysatl_core.distributions.distribution import Distribution
from pysatl_core.distributions.distribution import _KEEP, Distribution
from pysatl_core.distributions.strategies import (
DefaultComputationStrategy,
DefaultSamplingUnivariateStrategy,
)
from pysatl_core.families.registry import ParametricFamilyRegister
from pysatl_core.types import NumericArray

Expand All @@ -41,7 +43,6 @@
)


@dataclass(slots=True)
class ParametricFamilyDistribution(Distribution):
"""
A specific distribution instance from a parametric family.
Expand All @@ -53,18 +54,46 @@ class ParametricFamilyDistribution(Distribution):
----------
family_name : str
Name of the distribution family.
_distribution_type : DistributionType
distribution_type : DistributionType
Type of this distribution.
_parametrization : Parametrization
parametrization : Parametrization
Parameter values for this distribution.
_support : Support or None
support : Support or None
Support of this distribution.
sampling_strategy : SamplingStrategy
Strategy for generating random samples.
Such an object is unique for each distribution.
computation_strategy : ComputationStrategy
Strategy for computing characteristics and conversions.
Such an object is unique for each distribution.
"""

family_name: str
_distribution_type: DistributionType
_parametrization: Parametrization
_support: Support | None
def __init__(
self,
family_name: str,
distribution_type: DistributionType,
parametrization: Parametrization,
support: Support | None,
sampling_strategy: SamplingStrategy | None = None,
computation_strategy: ComputationStrategy | None = None,
):
self._distribution_type = distribution_type
self._family_name = family_name
self._parametrization = parametrization
self._support = support

self._computation_strategy = computation_strategy or DefaultComputationStrategy()
self._sampling_strategy = sampling_strategy or DefaultSamplingUnivariateStrategy()

self._analytical_cache_key: tuple[int, GenericCharacteristicName] | None = None
self._analytical_cache_val: (
Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]] | None
) = None

@property
def family_name(self) -> str:
"Get the name of the family this distribution belongs to."
return self._family_name

@property
def distribution_type(self) -> DistributionType:
Expand Down Expand Up @@ -142,25 +171,52 @@ def analytical_computations(
parametrization object or name changes.
"""
key = (id(self.parametrization), self.parametrization_name)
cache_key = getattr(self, "_analytical_cache_key", None)
cache_val = getattr(self, "_analytical_cache_val", None)

if cache_key != key or cache_val is None:
cache_val = self.family._build_analytical_computations(self.parametrization)
if self._analytical_cache_key != key or self._analytical_cache_val is None:
self._analytical_cache_val = self.family.build_analytical_computations(
self.parametrization
)
self._analytical_cache_key = key
self._analytical_cache_val = cache_val

return cache_val
return self._analytical_cache_val

@property
def sampling_strategy(self) -> SamplingStrategy:
"""Get the sampling strategy for this distribution."""
return self.family.sampling_strategy
return self._sampling_strategy

@property
def computation_strategy(self) -> ComputationStrategy:
"""Get the computation strategy for this distribution."""
return self.family.computation_strategy
return self._computation_strategy

def _clone_with_strategies(
self,
*,
sampling_strategy: SamplingStrategy | None | object = _KEEP,
computation_strategy: ComputationStrategy | None | object = _KEEP,
) -> ParametricFamilyDistribution:
"""Return a copy of this distribution with updated strategies."""
new_sampling: SamplingStrategy | None = (
self._sampling_strategy
if sampling_strategy is _KEEP
else cast(SamplingStrategy | None, sampling_strategy)
)

new_computation: ComputationStrategy | None = (
self._computation_strategy
if computation_strategy is _KEEP
else cast(ComputationStrategy | None, computation_strategy)
)

return ParametricFamilyDistribution(
family_name=self._family_name,
distribution_type=self._distribution_type,
parametrization=self._parametrization,
support=self._support,
sampling_strategy=new_sampling,
computation_strategy=new_computation,
)

@property
def support(self) -> Support | None:
Expand Down
104 changes: 55 additions & 49 deletions src/pysatl_core/families/parametric_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@
from typing import TYPE_CHECKING, Any, cast, dataclass_transform

from pysatl_core.distributions.computation import AnalyticalComputation
from pysatl_core.distributions.strategies import (
DefaultComputationStrategy,
DefaultSamplingUnivariateStrategy,
)
from pysatl_core.families.distribution import ParametricFamilyDistribution
from pysatl_core.types import ComputationFunc, DistributionType

Expand Down Expand Up @@ -83,10 +79,6 @@ class ParametricFamily:
- pointwise characteristics (e.g., pdf, cdf, ppf): provider(params, x, **kwargs) -> Any

If a single callable is provided, it is treated as defined for the base parametrization.
sampling_strategy : SamplingStrategy, optional
Strategy for sampling from distributions.
computation_strategy : ComputationStrategy, optional
Strategy for computing distribution characteristics.
support_by_parametrization : Callable or None, optional
Function that returns support for given parameters.
"""
Expand All @@ -97,37 +89,27 @@ def __init__(
distr_type: DistributionType | Callable[[Parametrization], DistributionType],
distr_parametrizations: list[ParametrizationName],
distr_characteristics: CharacteristicsMap,
sampling_strategy: SamplingStrategy | None = None,
computation_strategy: ComputationStrategy | None = None,
support_by_parametrization: SupportArg = None,
):
if not distr_parametrizations:
raise ValueError(
"distr_parametrizations must be non-empty (base parametrization is required)."
)

self._name = name
# Ordered names; the first one is the base parametrization name
self.parametrization_names = distr_parametrizations
self.base_parametrization_name = self.parametrization_names[0]
self._distr_type: Callable[[Parametrization], DistributionType] = (
(lambda params: distr_type) if isinstance(distr_type, DistributionType) else distr_type
)

self.computation_strategy = (
DefaultComputationStrategy() if computation_strategy is None else computation_strategy
)

if support_by_parametrization is None:
self._support_resolver: SupportResolver
self._support_resolver = lambda _params: None
else:
self._support_resolver = support_by_parametrization

# Ordered names; the first one is the base parametrization name
self.parametrization_names: list[ParametrizationName] = distr_parametrizations
self.base_parametrization_name: ParametrizationName = self.parametrization_names[0]
self._support_resolver: SupportResolver = support_by_parametrization or (lambda _p: None)

# Runtime registry of parametrization classes
self._parametrizations: dict[ParametrizationName, type[Parametrization]] = {}

self.sampling_strategy = (
DefaultSamplingUnivariateStrategy() if sampling_strategy is None else sampling_strategy
)

def _process_char_val(
def _normalize_characteristic(
value: Mapping[ParametrizationName, CharacteristicFunction[Any, Any]]
| CharacteristicFunction[Any, Any],
) -> dict[ParametrizationName, CharacteristicFunction[Any, Any]]:
Expand All @@ -139,21 +121,33 @@ def _process_char_val(

self.distr_characteristics: dict[
GenericCharacteristicName, dict[ParametrizationName, CharacteristicFunction[Any, Any]]
] = {key: _process_char_val(val) for key, val in distr_characteristics.items()}

# Precompute analytical plan
] = {k: _normalize_characteristic(v) for k, v in distr_characteristics.items()}

# Validate characteristic providers
valid_names = set(self.parametrization_names)
for char_name, forms in self.distr_characteristics.items():
unknown = set(forms) - valid_names
if unknown:
raise ValueError(
f"Characteristic '{char_name}' has providers for unknown parametrizations: "
f"{sorted(unknown)}."
)
if self.base_parametrization_name not in forms and len(forms) == 0:
raise ValueError(f"Characteristic '{char_name}' has no providers.")

# Precompute analytical plan: for each parametrization pick provider (self or base)
self._analytical_plan: dict[
ParametrizationName, dict[GenericCharacteristicName, ParametrizationName]
] = {}
base_name = self.base_parametrization_name
base = self.base_parametrization_name
for pname in self.parametrization_names:
plan_for_p: dict[GenericCharacteristicName, ParametrizationName] = {}
plan: dict[GenericCharacteristicName, ParametrizationName] = {}
for characteristic, forms in self.distr_characteristics.items():
if pname in forms:
plan_for_p[characteristic] = pname
elif base_name in forms:
plan_for_p[characteristic] = base_name
self._analytical_plan[pname] = plan_for_p
plan[characteristic] = pname
elif base in forms:
plan[characteristic] = base
self._analytical_plan[pname] = plan

@property
def name(self) -> str:
Expand Down Expand Up @@ -184,7 +178,7 @@ def base(self) -> type[Parametrization]:

@property
def support_resolver(self) -> SupportResolver:
"""Get the support resolver function."""
"""Support resolver callable."""
return self._support_resolver

def register_parametrization(
Expand Down Expand Up @@ -269,7 +263,7 @@ def _bind_parametrization[In, Out](
else func,
)

def _build_analytical_computations(
def build_analytical_computations(
self, parameters: Parametrization
) -> dict[GenericCharacteristicName, AnalyticalComputation[Any, Any]]:
"""
Expand All @@ -285,8 +279,7 @@ def _build_analytical_computations(
if provider_name == parameters.name:
params_obj = parameters
else:
if base_params is None:
base_params = self.to_base(parameters)
base_params = base_params or self.to_base(parameters)
params_obj = base_params

func_factory = self.distr_characteristics[characteristic][provider_name]
Expand All @@ -299,16 +292,23 @@ def _build_analytical_computations(

def distribution(
self,
parametrization_name: str | None = None,
parametrization_name: ParametrizationName | None = None,
sampling_strategy: SamplingStrategy | None = None,
computation_strategy: ComputationStrategy | None = None,
**parameters_values: Any,
) -> ParametricFamilyDistribution:
"""
Create a distribution instance with given parameters.

Parameters
----------
parametrization_name : str, optional
parametrization_name : ParametrizationName | None, optional
Name of parametrization to use (defaults to base).
sampling_strategy : SamplingStrategy
Strategy for generating random samples. Such an object is unique for each distribution.
computation_strategy : ComputationStrategy
Strategy for computing characteristics and conversions.
Such an object is unique for each distribution.
**parameters_values
Parameter values for the distribution.

Expand All @@ -324,22 +324,28 @@ def distribution(
ValueError
If parameters don't satisfy constraints.
"""
if parametrization_name is None:
parametrization_class = self.base
else:
parametrization_class = self._parametrizations[parametrization_name]
parametrization_class = (
self.base
if parametrization_name is None
else self._parametrizations[parametrization_name]
)

parameters = parametrization_class(**parameters_values)
parameters.validate()
base_parameters = self.to_base(parameters)
distribution_type = self._distr_type(base_parameters)
return ParametricFamilyDistribution(
self.name, distribution_type, parameters, self.support_resolver(parameters)
family_name=self.name,
distribution_type=distribution_type,
parametrization=parameters,
support=self.support_resolver(parameters),
sampling_strategy=sampling_strategy,
computation_strategy=computation_strategy,
)

@dataclass_transform()
def parametrization(
self, *, name: str
self, *, name: ParametrizationName
) -> Callable[[type[Parametrization]], type[Parametrization]]:
"""
Create a class decorator that registers a parametrization.
Expand Down
Loading