1818from typing import TYPE_CHECKING , Any , cast , dataclass_transform
1919
2020from pysatl_core .distributions .computation import AnalyticalComputation
21- from pysatl_core .distributions .strategies import (
22- DefaultComputationStrategy ,
23- DefaultSamplingUnivariateStrategy ,
24- )
2521from pysatl_core .families .distribution import ParametricFamilyDistribution
2622from pysatl_core .types import ComputationFunc , DistributionType
2723
@@ -83,10 +79,6 @@ class ParametricFamily:
8379 - pointwise characteristics (e.g., pdf, cdf, ppf): provider(params, x, **kwargs) -> Any
8480
8581 If a single callable is provided, it is treated as defined for the base parametrization.
86- sampling_strategy : SamplingStrategy, optional
87- Strategy for sampling from distributions.
88- computation_strategy : ComputationStrategy, optional
89- Strategy for computing distribution characteristics.
9082 support_by_parametrization : Callable or None, optional
9183 Function that returns support for given parameters.
9284 """
@@ -97,37 +89,27 @@ def __init__(
9789 distr_type : DistributionType | Callable [[Parametrization ], DistributionType ],
9890 distr_parametrizations : list [ParametrizationName ],
9991 distr_characteristics : CharacteristicsMap ,
100- sampling_strategy : SamplingStrategy | None = None ,
101- computation_strategy : ComputationStrategy | None = None ,
10292 support_by_parametrization : SupportArg = None ,
10393 ):
94+ if not distr_parametrizations :
95+ raise ValueError (
96+ "distr_parametrizations must be non-empty (base parametrization is required)."
97+ )
98+
10499 self ._name = name
100+ # Ordered names; the first one is the base parametrization name
101+ self .parametrization_names = distr_parametrizations
102+ self .base_parametrization_name = self .parametrization_names [0 ]
105103 self ._distr_type : Callable [[Parametrization ], DistributionType ] = (
106104 (lambda params : distr_type ) if isinstance (distr_type , DistributionType ) else distr_type
107105 )
108106
109- self .computation_strategy = (
110- DefaultComputationStrategy () if computation_strategy is None else computation_strategy
111- )
112-
113- if support_by_parametrization is None :
114- self ._support_resolver : SupportResolver
115- self ._support_resolver = lambda _params : None
116- else :
117- self ._support_resolver = support_by_parametrization
118-
119- # Ordered names; the first one is the base parametrization name
120- self .parametrization_names : list [ParametrizationName ] = distr_parametrizations
121- self .base_parametrization_name : ParametrizationName = self .parametrization_names [0 ]
107+ self ._support_resolver : SupportResolver = support_by_parametrization or (lambda _p : None )
122108
123109 # Runtime registry of parametrization classes
124110 self ._parametrizations : dict [ParametrizationName , type [Parametrization ]] = {}
125111
126- self .sampling_strategy = (
127- DefaultSamplingUnivariateStrategy () if sampling_strategy is None else sampling_strategy
128- )
129-
130- def _process_char_val (
112+ def _normalize_characteristic (
131113 value : Mapping [ParametrizationName , CharacteristicFunction [Any , Any ]]
132114 | CharacteristicFunction [Any , Any ],
133115 ) -> dict [ParametrizationName , CharacteristicFunction [Any , Any ]]:
@@ -139,21 +121,33 @@ def _process_char_val(
139121
140122 self .distr_characteristics : dict [
141123 GenericCharacteristicName , dict [ParametrizationName , CharacteristicFunction [Any , Any ]]
142- ] = {key : _process_char_val (val ) for key , val in distr_characteristics .items ()}
143-
144- # Precompute analytical plan
124+ ] = {k : _normalize_characteristic (v ) for k , v in distr_characteristics .items ()}
125+
126+ # Validate characteristic providers
127+ valid_names = set (self .parametrization_names )
128+ for char_name , forms in self .distr_characteristics .items ():
129+ unknown = set (forms ) - valid_names
130+ if unknown :
131+ raise ValueError (
132+ f"Characteristic '{ char_name } ' has providers for unknown parametrizations: "
133+ f"{ sorted (unknown )} ."
134+ )
135+ if self .base_parametrization_name not in forms and len (forms ) == 0 :
136+ raise ValueError (f"Characteristic '{ char_name } ' has no providers." )
137+
138+ # Precompute analytical plan: for each parametrization pick provider (self or base)
145139 self ._analytical_plan : dict [
146140 ParametrizationName , dict [GenericCharacteristicName , ParametrizationName ]
147141 ] = {}
148- base_name = self .base_parametrization_name
142+ base = self .base_parametrization_name
149143 for pname in self .parametrization_names :
150- plan_for_p : dict [GenericCharacteristicName , ParametrizationName ] = {}
144+ plan : dict [GenericCharacteristicName , ParametrizationName ] = {}
151145 for characteristic , forms in self .distr_characteristics .items ():
152146 if pname in forms :
153- plan_for_p [characteristic ] = pname
154- elif base_name in forms :
155- plan_for_p [characteristic ] = base_name
156- self ._analytical_plan [pname ] = plan_for_p
147+ plan [characteristic ] = pname
148+ elif base in forms :
149+ plan [characteristic ] = base
150+ self ._analytical_plan [pname ] = plan
157151
158152 @property
159153 def name (self ) -> str :
@@ -184,7 +178,7 @@ def base(self) -> type[Parametrization]:
184178
185179 @property
186180 def support_resolver (self ) -> SupportResolver :
187- """Get the support resolver function ."""
181+ """Support resolver callable ."""
188182 return self ._support_resolver
189183
190184 def register_parametrization (
@@ -269,7 +263,7 @@ def _bind_parametrization[In, Out](
269263 else func ,
270264 )
271265
272- def _build_analytical_computations (
266+ def build_analytical_computations (
273267 self , parameters : Parametrization
274268 ) -> dict [GenericCharacteristicName , AnalyticalComputation [Any , Any ]]:
275269 """
@@ -285,8 +279,7 @@ def _build_analytical_computations(
285279 if provider_name == parameters .name :
286280 params_obj = parameters
287281 else :
288- if base_params is None :
289- base_params = self .to_base (parameters )
282+ base_params = base_params or self .to_base (parameters )
290283 params_obj = base_params
291284
292285 func_factory = self .distr_characteristics [characteristic ][provider_name ]
@@ -299,16 +292,23 @@ def _build_analytical_computations(
299292
300293 def distribution (
301294 self ,
302- parametrization_name : str | None = None ,
295+ parametrization_name : ParametrizationName | None = None ,
296+ sampling_strategy : SamplingStrategy | None = None ,
297+ computation_strategy : ComputationStrategy | None = None ,
303298 ** parameters_values : Any ,
304299 ) -> ParametricFamilyDistribution :
305300 """
306301 Create a distribution instance with given parameters.
307302
308303 Parameters
309304 ----------
310- parametrization_name : str , optional
305+ parametrization_name : ParametrizationName | None , optional
311306 Name of parametrization to use (defaults to base).
307+ sampling_strategy : SamplingStrategy
308+ Strategy for generating random samples. Such an object is unique for each distribution.
309+ computation_strategy : ComputationStrategy
310+ Strategy for computing characteristics and conversions.
311+ Such an object is unique for each distribution.
312312 **parameters_values
313313 Parameter values for the distribution.
314314
@@ -324,22 +324,28 @@ def distribution(
324324 ValueError
325325 If parameters don't satisfy constraints.
326326 """
327- if parametrization_name is None :
328- parametrization_class = self .base
329- else :
330- parametrization_class = self ._parametrizations [parametrization_name ]
327+ parametrization_class = (
328+ self .base
329+ if parametrization_name is None
330+ else self ._parametrizations [parametrization_name ]
331+ )
331332
332333 parameters = parametrization_class (** parameters_values )
333334 parameters .validate ()
334335 base_parameters = self .to_base (parameters )
335336 distribution_type = self ._distr_type (base_parameters )
336337 return ParametricFamilyDistribution (
337- self .name , distribution_type , parameters , self .support_resolver (parameters )
338+ family_name = self .name ,
339+ distribution_type = distribution_type ,
340+ parametrization = parameters ,
341+ support = self .support_resolver (parameters ),
342+ sampling_strategy = sampling_strategy ,
343+ computation_strategy = computation_strategy ,
338344 )
339345
340346 @dataclass_transform ()
341347 def parametrization (
342- self , * , name : str
348+ self , * , name : ParametrizationName
343349 ) -> Callable [[type [Parametrization ]], type [Parametrization ]]:
344350 """
345351 Create a class decorator that registers a parametrization.
0 commit comments