1313
1414from collections .abc import Callable , Sequence
1515from dataclasses import dataclass
16- from typing import TYPE_CHECKING , Protocol , overload , runtime_checkable
16+ from typing import TYPE_CHECKING , Protocol , cast , overload , runtime_checkable
1717
1818from pysatl_core .types import ComputationFunc
1919
2525 from pysatl_core .distributions .distribution import Distribution
2626 from pysatl_core .types import GenericCharacteristicName
2727
28+ type Fitter [In , Out ] = Callable [[Distribution , KwArg (Any )], FittedComputationMethod [In , Out ]]
29+ type Evaluator [In , Out ] = (
30+ Callable [[Distribution , KwArg (Any )], Out ] | Callable [[Distribution , In , KwArg (Any )], Out ]
31+ )
32+
2833
2934@runtime_checkable
3035class Computation [In , Out ](Protocol ):
@@ -58,7 +63,7 @@ class AnalyticalComputation[In, Out]:
5863 ----------
5964 target : str
6065 Characteristic name (e.g., "pdf", "cdf").
61- func : Callable[... , Out]
66+ func : ComputationFunc[In , Out]
6267 Analytical function that computes the characteristic.
6368 """
6469
@@ -87,7 +92,7 @@ class FittedComputationMethod[In, Out]:
8792 Destination characteristic name.
8893 sources : Sequence[str]
8994 Source characteristic names (typically length 1 for unary conversions).
90- func : Callable[... , Out]
95+ func : ComputationFunc[In , Out]
9196 Callable implementing the fitted conversion.
9297 """
9398
@@ -120,13 +125,72 @@ class ComputationMethod[In, Out]:
120125 Destination characteristic name.
121126 sources : Sequence[str]
122127 Source characteristic names (typically length 1 for unary conversions).
123- fitter : Callable[[Distribution, **options], FittedComputationMethod]
128+ fitter : Fitter[In, Out] | None
124129 Function that fits the computation method to a distribution.
130+ If provided, the method is considered *cacheable* (fitting may perform
131+ expensive precomputation).
132+ evaluator : Evaluator[In, Out] | None
133+ Direct evaluator that performs the computation in one step, without
134+ a separate fitting stage. If provided, the method is considered
135+ *non-cacheable* at the strategy level.
125136 """
126137
127138 target : GenericCharacteristicName
128139 sources : Sequence [GenericCharacteristicName ]
129- fitter : Callable [[Distribution , KwArg (Any )], FittedComputationMethod [In , Out ]]
140+ fitter : Fitter [In , Out ] | None = None
141+ evaluator : Evaluator [In , Out ] | None = None
142+
143+ def __post_init__ (self ) -> None :
144+ has_fitter = self .fitter is not None
145+ has_eval = self .evaluator is not None
146+ if has_fitter == has_eval :
147+ raise ValueError (
148+ "ComputationMethod must define exactly one of 'fitter' or 'evaluator'."
149+ )
150+
151+ @property
152+ def cacheable (self ) -> bool :
153+ """Whether it makes sense to cache the prepared method at strategy level."""
154+ return self .fitter is not None
155+
156+ def prepare (
157+ self , distribution : Distribution , ** options : Any
158+ ) -> FittedComputationMethod [In , Out ]:
159+ """Prepare a callable method for a specific distribution.
160+
161+ - If ``fitter`` is provided, run the fitting stage and return the fitted method.
162+ - If ``evaluator`` is provided, bind the distribution and return a lightweight
163+ fitted wrapper.
164+ """
165+ if self .fitter is not None :
166+ return self .fitter (distribution , ** options )
167+
168+ def _bound (* args : Any , ** kwargs : Any ) -> Out :
169+ return cast (Evaluator [In , Out ], self .evaluator )(distribution , * args , ** kwargs )
170+
171+ return FittedComputationMethod [In , Out ](
172+ target = self .target ,
173+ sources = self .sources ,
174+ func = _bound ,
175+ )
176+
177+ @overload
178+ def __call__ (self , distribution : Distribution , ** options : Any ) -> Out : ...
179+
180+ @overload
181+ def __call__ (self , distribution : Distribution , data : In , ** options : Any ) -> Out : ...
182+
183+ def __call__ (self , distribution : Distribution , * args : Any , ** options : Any ) -> Out :
184+ """Evaluate *direct* computation methods.
185+
186+ This is only available for methods defined via ``evaluator``.
187+ """
188+ if self .evaluator is None :
189+ raise RuntimeError (
190+ "This ComputationMethod requires fitting. "
191+ "Call .fit(...) / .prepare(...) to obtain a callable."
192+ )
193+ return self .evaluator (distribution , * args , ** options )
130194
131195 def fit (self , distribution : Distribution , ** options : Any ) -> FittedComputationMethod [In , Out ]:
132196 """
@@ -144,6 +208,11 @@ def fit(self, distribution: Distribution, **options: Any) -> FittedComputationMe
144208 FittedComputationMethod
145209 Fitted method ready for evaluation.
146210 """
211+ if self .fitter is None :
212+ raise RuntimeError (
213+ "This ComputationMethod is evaluator-based and does not support .fit(). "
214+ "Use .prepare(...) or call the method directly."
215+ )
147216 return self .fitter (distribution , ** options )
148217
149218
0 commit comments