Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 105 additions & 47 deletions src/pysatl_core/distributions/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,25 @@
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Protocol, cast, overload, runtime_checkable

from pysatl_core.types import ComputationFunc

if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from typing import Any

from mypy_extensions import KwArg

from pysatl_core.distributions.distribution import Distribution
from pysatl_core.types import GenericCharacteristicName

type Fitter[In, Out] = Callable[[Distribution, KwArg(Any)], FittedComputationMethod[In, Out]]
type Evaluator[In, Out] = (
Callable[[Distribution, KwArg(Any)], Out] | Callable[[Distribution, In, KwArg(Any)], Out]
)


@runtime_checkable
class Computation[In, Out](Protocol):
Expand All @@ -37,42 +44,14 @@ class Computation[In, Out](Protocol):

@property
def target(self) -> GenericCharacteristicName: ...
def __call__(self, data: In, **options: Any) -> Out: ...


@runtime_checkable
class FittedComputationMethodProtocol[In, Out](Protocol):
"""
Protocol for fitted computation methods ready for evaluation.

Attributes
----------
target : str
Destination characteristic name.
sources : Sequence[str]
Source characteristic names this method depends on.
"""

@property
def target(self) -> GenericCharacteristicName: ...
@property
def sources(self) -> Sequence[GenericCharacteristicName]: ...
def __call__(self, data: In, **options: Any) -> Out: ...

@overload
def __call__(self, **kwargs: Any) -> Out: ...

@runtime_checkable
class ComputationMethodProtocol[In, Out](Protocol):
"""
Protocol for computation method factories that can be fitted to distributions.
"""
@overload
def __call__(self, x: In, **kwargs: Any) -> Out: ...

@property
def target(self) -> GenericCharacteristicName: ...
@property
def sources(self) -> Sequence[GenericCharacteristicName]: ...
def fit(
self, distribution: Distribution, **options: Any
) -> FittedComputationMethodProtocol[In, Out]: ...
def __call__(self, *args: Any, **kwargs: Any) -> Out: ...


@dataclass(frozen=True, slots=True)
Expand All @@ -84,16 +63,22 @@ class AnalyticalComputation[In, Out]:
----------
target : str
Characteristic name (e.g., "pdf", "cdf").
func : Callable[[In, KwArg(Any)], Out]
func : ComputationFunc[In, Out]
Analytical function that computes the characteristic.
"""

target: GenericCharacteristicName
func: Callable[[In, KwArg(Any)], Out]
func: ComputationFunc[In, Out]

@overload
def __call__(self, **options: Any) -> Out: ...

def __call__(self, data: In, **options: Any) -> Out:
"""Evaluate the analytical function at the given data."""
return self.func(data, **options)
@overload
def __call__(self, data: In, **options: Any) -> Out: ...

def __call__(self, *args: Any, **options: Any) -> Out:
"""Evaluate the analytical function."""
return self.func(*args, **options)


@dataclass(frozen=True, slots=True)
Expand All @@ -107,17 +92,23 @@ class FittedComputationMethod[In, Out]:
Destination characteristic name.
sources : Sequence[str]
Source characteristic names (typically length 1 for unary conversions).
func : Callable[[In, KwArg(Any)], Out]
func : ComputationFunc[In, Out]
Callable implementing the fitted conversion.
"""

target: GenericCharacteristicName
sources: Sequence[GenericCharacteristicName]
func: Callable[[In, KwArg(Any)], Out]
func: ComputationFunc[In, Out]

@overload
def __call__(self, **options: Any) -> Out: ...

def __call__(self, data: In, **options: Any) -> Out:
"""Evaluate the fitted conversion at the given data."""
return self.func(data, **options)
@overload
def __call__(self, data: In, **options: Any) -> Out: ...

def __call__(self, *args: Any, **options: Any) -> Out:
"""Evaluate the fitted conversion."""
return self.func(*args, **options)


@dataclass(frozen=True, slots=True)
Expand All @@ -134,13 +125,72 @@ class ComputationMethod[In, Out]:
Destination characteristic name.
sources : Sequence[str]
Source characteristic names (typically length 1 for unary conversions).
fitter : Callable[[Distribution, **options], FittedComputationMethod]
fitter : Fitter[In, Out] | None
Function that fits the computation method to a distribution.
If provided, the method is considered *cacheable* (fitting may perform
expensive precomputation).
evaluator : Evaluator[In, Out] | None
Direct evaluator that performs the computation in one step, without
a separate fitting stage. If provided, the method is considered
*non-cacheable* at the strategy level.
"""

target: GenericCharacteristicName
sources: Sequence[GenericCharacteristicName]
fitter: Callable[[Distribution, KwArg(Any)], FittedComputationMethod[In, Out]]
fitter: Fitter[In, Out] | None = None
evaluator: Evaluator[In, Out] | None = None

def __post_init__(self) -> None:
has_fitter = self.fitter is not None
has_eval = self.evaluator is not None
if has_fitter == has_eval:
raise ValueError(
"ComputationMethod must define exactly one of 'fitter' or 'evaluator'."
)

@property
def cacheable(self) -> bool:
"""Whether it makes sense to cache the prepared method at strategy level."""
return self.fitter is not None

def prepare(
self, distribution: Distribution, **options: Any
) -> FittedComputationMethod[In, Out]:
"""Prepare a callable method for a specific distribution.

- If ``fitter`` is provided, run the fitting stage and return the fitted method.
- If ``evaluator`` is provided, bind the distribution and return a lightweight
fitted wrapper.
"""
if self.fitter is not None:
return self.fitter(distribution, **options)

def _bound(*args: Any, **kwargs: Any) -> Out:
return cast(Evaluator[In, Out], self.evaluator)(distribution, *args, **kwargs)

return FittedComputationMethod[In, Out](
target=self.target,
sources=self.sources,
func=_bound,
)

@overload
def __call__(self, distribution: Distribution, **options: Any) -> Out: ...

@overload
def __call__(self, distribution: Distribution, data: In, **options: Any) -> Out: ...

def __call__(self, distribution: Distribution, *args: Any, **options: Any) -> Out:
"""Evaluate *direct* computation methods.

This is only available for methods defined via ``evaluator``.
"""
if self.evaluator is None:
raise RuntimeError(
"This ComputationMethod requires fitting. "
"Call .fit(...) / .prepare(...) to obtain a callable."
)
return self.evaluator(distribution, *args, **options)

def fit(self, distribution: Distribution, **options: Any) -> FittedComputationMethod[In, Out]:
"""
Expand All @@ -158,4 +208,12 @@ def fit(self, distribution: Distribution, **options: Any) -> FittedComputationMe
FittedComputationMethod
Fitted method ready for evaluation.
"""
if self.fitter is None:
raise RuntimeError(
"This ComputationMethod is evaluator-based and does not support .fit(). "
"Use .prepare(...) or call the method directly."
)
return self.fitter(distribution, **options)


type Method[In, Out] = AnalyticalComputation[In, Out] | FittedComputationMethod[In, Out]
5 changes: 2 additions & 3 deletions src/pysatl_core/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
from collections.abc import Mapping
from typing import Any

from pysatl_core.distributions.computation import AnalyticalComputation
from pysatl_core.distributions.computation import AnalyticalComputation, Method
from pysatl_core.distributions.strategies import (
ComputationStrategy,
Method,
SamplingStrategy,
)
from pysatl_core.distributions.support import Support
Expand Down Expand Up @@ -67,7 +66,7 @@ def analytical_computations(
def sampling_strategy(self) -> SamplingStrategy: ...

@property
def computation_strategy(self) -> ComputationStrategy[Any, Any]: ...
def computation_strategy(self) -> ComputationStrategy: ...

@property
def support(self) -> Support | None: ...
Expand Down
4 changes: 3 additions & 1 deletion src/pysatl_core/distributions/fitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from typing import Any

from pysatl_core.distributions.distribution import Distribution
from pysatl_core.types import GenericCharacteristicName, ScalarFunc
from pysatl_core.types import GenericCharacteristicName

type ScalarFunc = Callable[[float], float]


def _resolve(distribution: Distribution, name: GenericCharacteristicName) -> ScalarFunc:
Expand Down
20 changes: 9 additions & 11 deletions src/pysatl_core/distributions/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@
if TYPE_CHECKING:
from typing import Any

from pysatl_core.distributions.computation import AnalyticalComputation, FittedComputationMethod
from pysatl_core.distributions.computation import FittedComputationMethod, Method
from pysatl_core.distributions.distribution import Distribution
from pysatl_core.types import GenericCharacteristicName

type Method[In, Out] = AnalyticalComputation[In, Out] | FittedComputationMethod[In, Out]


class ComputationStrategy[In, Out](Protocol):
class ComputationStrategy(Protocol):
"""
Protocol for strategies that resolve computation methods for characteristics.

Expand All @@ -42,10 +40,10 @@ class ComputationStrategy[In, Out](Protocol):

def query_method(
self, state: GenericCharacteristicName, distr: Distribution, **options: Any
) -> Method[In, Out]: ...
) -> Method[Any, Any]: ...


class DefaultComputationStrategy[In, Out]:
class DefaultComputationStrategy:
"""
Default strategy for resolving characteristic computation methods.

Expand All @@ -71,7 +69,7 @@ class DefaultComputationStrategy[In, Out]:

def __init__(self, enable_caching: bool = False) -> None:
self.enable_caching = enable_caching
self._cache: dict[GenericCharacteristicName, FittedComputationMethod[In, Out]] = {}
self._cache: dict[GenericCharacteristicName, FittedComputationMethod[Any, Any]] = {}
self._resolving: dict[int, set[GenericCharacteristicName]] = {}

def _push_guard(self, distr: Distribution, state: GenericCharacteristicName) -> None:
Expand Down Expand Up @@ -103,7 +101,7 @@ def _pop_guard(self, distr: Distribution, state: GenericCharacteristicName) -> N

def query_method(
self, state: GenericCharacteristicName, distr: Distribution, **options: Any
) -> Method[In, Out]:
) -> Method[Any, Any]:
"""
Resolve a computation method for the target characteristic.

Expand Down Expand Up @@ -164,10 +162,10 @@ def query_method(
continue

# Fit each edge along the path
last_fitted: FittedComputationMethod[In, Out] | None = None
last_fitted: FittedComputationMethod[Any, Any] | None = None
for edge in path:
fitted = edge.fit(distr, **options)
if self.enable_caching:
fitted = edge.prepare(distr, **options)
if self.enable_caching and edge.cacheable:
self._cache[edge.target] = fitted
last_fitted = fitted

Expand Down
13 changes: 5 additions & 8 deletions src/pysatl_core/families/builtins/continuous/exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

from typing import TYPE_CHECKING, cast
from typing import cast

import numpy as np

Expand All @@ -30,9 +30,6 @@
UnivariateContinuous,
)

if TYPE_CHECKING:
from typing import Any


def configure_exponential_family() -> None:
"""
Expand Down Expand Up @@ -165,21 +162,21 @@ def char_func(parameters: Parametrization, t: NumericArray) -> ComplexArray:
)
return cast(ComplexArray, result)

def mean_func(parameters: Parametrization, _: Any) -> float:
def mean_func(parameters: Parametrization) -> float:
"""Mean of exponential distribution."""
parameters = cast(_Rate, parameters)
return 1.0 / parameters.lambda_

def var_func(parameters: Parametrization, _: Any) -> float:
def var_func(parameters: Parametrization) -> float:
"""Variance of exponential distribution."""
parameters = cast(_Rate, parameters)
return 1.0 / (parameters.lambda_**2)

def skew_func(_1: Parametrization, _2: Any) -> float:
def skew_func() -> float:
"""Skewness of exponential distribution (always 2)."""
return 2.0

def kurt_func(_1: Parametrization, _2: Any, excess: bool = False) -> float:
def kurt_func(*, excess: bool = False) -> float:
"""Raw or excess kurtosis of exponential distribution.

Parameters
Expand Down
Loading