Skip to content

Commit 5c3c777

Browse files
committed
refactor: distribution is abstract not the protocol
1 parent 1533dd2 commit 5c3c777

4 files changed

Lines changed: 35 additions & 25 deletions

File tree

src/pysatl_core/distributions/computation.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,12 @@ def _bound(*args: Any, **kwargs: Any) -> Out:
175175
)
176176

177177
@overload
178-
def __call__(self, distribution: Distribution, **options: Any) -> Out: ...
178+
def evaluate(self, distribution: Distribution, **options: Any) -> Out: ...
179179

180180
@overload
181-
def __call__(self, distribution: Distribution, data: In, **options: Any) -> Out: ...
181+
def evaluate(self, distribution: Distribution, data: In, **options: Any) -> Out: ...
182182

183-
def __call__(self, distribution: Distribution, *args: Any, **options: Any) -> Out:
183+
def evaluate(self, distribution: Distribution, *args: Any, **options: Any) -> Out:
184184
"""Evaluate *direct* computation methods.
185185
186186
This is only available for methods defined via ``evaluator``.
@@ -215,5 +215,15 @@ def fit(self, distribution: Distribution, **options: Any) -> FittedComputationMe
215215
)
216216
return self.fitter(distribution, **options)
217217

218+
@overload
219+
def __call__(self, distribution: Distribution, **options: Any) -> Out: ...
220+
221+
@overload
222+
def __call__(self, distribution: Distribution, data: In, **options: Any) -> Out: ...
223+
224+
def __call__(self, distribution: Distribution, *args: Any, **options: Any) -> Out:
225+
"""Fit if possible and then evaluate"""
226+
return self.prepare(distribution, **options)(*args)
227+
218228

219229
type Method[In, Out] = AnalyticalComputation[In, Out] | FittedComputationMethod[In, Out]

src/pysatl_core/distributions/distribution.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
__copyright__ = "Copyright (c) 2025 PySATL project"
1212
__license__ = "SPDX-License-Identifier: MIT"
1313

14-
from typing import TYPE_CHECKING, Protocol, Self, cast, runtime_checkable
14+
from abc import ABC, abstractmethod
15+
from typing import TYPE_CHECKING, Self, cast
1516

1617
from pysatl_core.types import NumericArray
1718

@@ -34,8 +35,7 @@
3435
)
3536

3637

37-
@runtime_checkable
38-
class Distribution(Protocol):
38+
class Distribution(ABC):
3939
"""
4040
Protocol defining the interface for probability distributions.
4141
@@ -58,22 +58,28 @@ class Distribution(Protocol):
5858
"""
5959

6060
@property
61+
@abstractmethod
6162
def distribution_type(self) -> DistributionType: ...
6263

6364
@property
65+
@abstractmethod
6466
def analytical_computations(
6567
self,
6668
) -> Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]]: ...
6769

6870
@property
71+
@abstractmethod
6972
def sampling_strategy(self) -> SamplingStrategy: ...
7073

7174
@property
75+
@abstractmethod
7276
def computation_strategy(self) -> ComputationStrategy: ...
7377

7478
@property
79+
@abstractmethod
7580
def support(self) -> Support | None: ...
7681

82+
@abstractmethod
7783
def _clone_with_strategies(
7884
self,
7985
*,

src/pysatl_core/distributions/strategies.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ class ComputationStrategy(Protocol):
3636
Whether to cache fitted computation methods.
3737
"""
3838

39-
enable_caching: bool
40-
4139
def query_method(
4240
self, state: GenericCharacteristicName, distr: Distribution, **options: Any
4341
) -> Method[Any, Any]: ...
@@ -59,7 +57,7 @@ class DefaultComputationStrategy:
5957
6058
Attributes
6159
----------
62-
enable_caching : bool
60+
_enable_caching : bool
6361
Whether caching is enabled.
6462
_cache : dict[str, FittedComputationMethod]
6563
Cache of fitted computation methods.
@@ -68,10 +66,14 @@ class DefaultComputationStrategy:
6866
"""
6967

7068
def __init__(self, enable_caching: bool = False) -> None:
71-
self.enable_caching = enable_caching
69+
self._enable_caching = enable_caching
7270
self._cache: dict[GenericCharacteristicName, FittedComputationMethod[Any, Any]] = {}
7371
self._resolving: dict[int, set[GenericCharacteristicName]] = {}
7472

73+
@property
74+
def is_caching_enabled(self) -> bool:
75+
return self._enable_caching
76+
7577
def _push_guard(self, distr: Distribution, state: GenericCharacteristicName) -> None:
7678
"""
7779
Push a characteristic onto the resolution stack to detect cycles.
@@ -135,7 +137,7 @@ def query_method(
135137
return distr.analytical_computations[state]
136138

137139
# 2. Check cache if enabled
138-
if self.enable_caching:
140+
if self._enable_caching:
139141
cached = self._cache.get(state)
140142
if cached is not None:
141143
return cached
@@ -165,7 +167,7 @@ def query_method(
165167
last_fitted: FittedComputationMethod[Any, Any] | None = None
166168
for edge in path:
167169
fitted = edge.prepare(distr, **options)
168-
if self.enable_caching and edge.cacheable:
170+
if self._enable_caching and edge.cacheable:
169171
self._cache[edge.target] = fitted
170172
last_fitted = fitted
171173

src/pysatl_core/families/parametrizations.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414

1515
from abc import ABC
1616
from dataclasses import dataclass, is_dataclass
17-
from functools import wraps
1817
from inspect import isfunction
19-
from typing import TYPE_CHECKING, ParamSpec, dataclass_transform
18+
from typing import TYPE_CHECKING, dataclass_transform
2019

2120
from pysatl_core.types import ParametrizationName
2221

@@ -107,10 +106,7 @@ def transform_to_base_parametrization(self) -> Parametrization:
107106
return self
108107

109108

110-
P = ParamSpec("P")
111-
112-
113-
def constraint(description: str) -> Callable[[Callable[P, bool]], Callable[P, bool]]:
109+
def constraint[**P](description: str) -> Callable[[Callable[P, bool]], Callable[P, bool]]:
114110
"""
115111
Decorator to mark an instance method as a parameter constraint.
116112
@@ -133,13 +129,9 @@ def constraint(description: str) -> Callable[[Callable[P, bool]], Callable[P, bo
133129
"""
134130

135131
def decorator(func: Callable[P, bool]) -> Callable[P, bool]:
136-
@wraps(func)
137-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> bool:
138-
return func(*args, **kwargs)
139-
140-
setattr(wrapper, "__is_constraint", True)
141-
setattr(wrapper, "__constraint_description", description)
142-
return wrapper
132+
setattr(func, "__is_constraint", True)
133+
setattr(func, "__constraint_description", description)
134+
return func
143135

144136
return decorator
145137

0 commit comments

Comments
 (0)