Skip to content

Commit a4a7046

Browse files
committed
refactor: fit-evaluate strategy now work more clearly
1 parent 248c8d2 commit a4a7046

2 files changed

Lines changed: 76 additions & 7 deletions

File tree

src/pysatl_core/distributions/computation.py

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from collections.abc import Callable, Sequence
1515
from dataclasses import dataclass
16-
from typing import TYPE_CHECKING, Protocol, overload, runtime_checkable
16+
from typing import TYPE_CHECKING, Protocol, cast, overload, runtime_checkable
1717

1818
from pysatl_core.types import ComputationFunc
1919

@@ -25,6 +25,11 @@
2525
from pysatl_core.distributions.distribution import Distribution
2626
from pysatl_core.types import GenericCharacteristicName
2727

28+
type Fitter[In, Out] = Callable[[Distribution, KwArg(Any)], FittedComputationMethod[In, Out]]
29+
type Evaluator[In, Out] = (
30+
Callable[[Distribution, KwArg(Any)], Out] | Callable[[Distribution, In, KwArg(Any)], Out]
31+
)
32+
2833

2934
@runtime_checkable
3035
class Computation[In, Out](Protocol):
@@ -58,7 +63,7 @@ class AnalyticalComputation[In, Out]:
5863
----------
5964
target : str
6065
Characteristic name (e.g., "pdf", "cdf").
61-
func : Callable[..., Out]
66+
func : ComputationFunc[In, Out]
6267
Analytical function that computes the characteristic.
6368
"""
6469

@@ -87,7 +92,7 @@ class FittedComputationMethod[In, Out]:
8792
Destination characteristic name.
8893
sources : Sequence[str]
8994
Source characteristic names (typically length 1 for unary conversions).
90-
func : Callable[..., Out]
95+
func : ComputationFunc[In, Out]
9196
Callable implementing the fitted conversion.
9297
"""
9398

@@ -120,13 +125,72 @@ class ComputationMethod[In, Out]:
120125
Destination characteristic name.
121126
sources : Sequence[str]
122127
Source characteristic names (typically length 1 for unary conversions).
123-
fitter : Callable[[Distribution, **options], FittedComputationMethod]
128+
fitter : Fitter[In, Out] | None
124129
Function that fits the computation method to a distribution.
130+
If provided, the method is considered *cacheable* (fitting may perform
131+
expensive precomputation).
132+
evaluator : Evaluator[In, Out] | None
133+
Direct evaluator that performs the computation in one step, without
134+
a separate fitting stage. If provided, the method is considered
135+
*non-cacheable* at the strategy level.
125136
"""
126137

127138
target: GenericCharacteristicName
128139
sources: Sequence[GenericCharacteristicName]
129-
fitter: Callable[[Distribution, KwArg(Any)], FittedComputationMethod[In, Out]]
140+
fitter: Fitter[In, Out] | None = None
141+
evaluator: Evaluator[In, Out] | None = None
142+
143+
def __post_init__(self) -> None:
144+
has_fitter = self.fitter is not None
145+
has_eval = self.evaluator is not None
146+
if has_fitter == has_eval:
147+
raise ValueError(
148+
"ComputationMethod must define exactly one of 'fitter' or 'evaluator'."
149+
)
150+
151+
@property
152+
def cacheable(self) -> bool:
153+
"""Whether it makes sense to cache the prepared method at strategy level."""
154+
return self.fitter is not None
155+
156+
def prepare(
157+
self, distribution: Distribution, **options: Any
158+
) -> FittedComputationMethod[In, Out]:
159+
"""Prepare a callable method for a specific distribution.
160+
161+
- If ``fitter`` is provided, run the fitting stage and return the fitted method.
162+
- If ``evaluator`` is provided, bind the distribution and return a lightweight
163+
fitted wrapper.
164+
"""
165+
if self.fitter is not None:
166+
return self.fitter(distribution, **options)
167+
168+
def _bound(*args: Any, **kwargs: Any) -> Out:
169+
return cast(Evaluator[In, Out], self.evaluator)(distribution, *args, **kwargs)
170+
171+
return FittedComputationMethod[In, Out](
172+
target=self.target,
173+
sources=self.sources,
174+
func=_bound,
175+
)
176+
177+
@overload
178+
def __call__(self, distribution: Distribution, **options: Any) -> Out: ...
179+
180+
@overload
181+
def __call__(self, distribution: Distribution, data: In, **options: Any) -> Out: ...
182+
183+
def __call__(self, distribution: Distribution, *args: Any, **options: Any) -> Out:
184+
"""Evaluate *direct* computation methods.
185+
186+
This is only available for methods defined via ``evaluator``.
187+
"""
188+
if self.evaluator is None:
189+
raise RuntimeError(
190+
"This ComputationMethod requires fitting. "
191+
"Call .fit(...) / .prepare(...) to obtain a callable."
192+
)
193+
return self.evaluator(distribution, *args, **options)
130194

131195
def fit(self, distribution: Distribution, **options: Any) -> FittedComputationMethod[In, Out]:
132196
"""
@@ -144,6 +208,11 @@ def fit(self, distribution: Distribution, **options: Any) -> FittedComputationMe
144208
FittedComputationMethod
145209
Fitted method ready for evaluation.
146210
"""
211+
if self.fitter is None:
212+
raise RuntimeError(
213+
"This ComputationMethod is evaluator-based and does not support .fit(). "
214+
"Use .prepare(...) or call the method directly."
215+
)
147216
return self.fitter(distribution, **options)
148217

149218

src/pysatl_core/distributions/strategies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ def query_method(
164164
# Fit each edge along the path
165165
last_fitted: FittedComputationMethod[Any, Any] | None = None
166166
for edge in path:
167-
fitted = edge.fit(distr, **options)
168-
if self.enable_caching:
167+
fitted = edge.prepare(distr, **options)
168+
if self.enable_caching and edge.cacheable:
169169
self._cache[edge.target] = fitted
170170
last_fitted = fitted
171171

0 commit comments

Comments
 (0)