Skip to content

Commit 4797489

Browse files
committed
refactor: strategies now is one per distribution
1 parent 659e91f commit 4797489

6 files changed

Lines changed: 177 additions & 73 deletions

File tree

src/pysatl_core/distributions/distribution.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
__copyright__ = "Copyright (c) 2025 PySATL project"
1212
__license__ = "SPDX-License-Identifier: MIT"
1313

14-
from typing import TYPE_CHECKING, Protocol, cast, runtime_checkable
14+
from typing import TYPE_CHECKING, Protocol, Self, cast, runtime_checkable
1515

1616
from pysatl_core.types import NumericArray
1717

18+
_KEEP: object = object()
19+
20+
1821
if TYPE_CHECKING:
1922
from collections.abc import Mapping
2023
from typing import Any
@@ -71,6 +74,39 @@ def computation_strategy(self) -> ComputationStrategy: ...
7174
@property
7275
def support(self) -> Support | None: ...
7376

77+
def _clone_with_strategies(
78+
self,
79+
*,
80+
sampling_strategy: SamplingStrategy | None | object = _KEEP,
81+
computation_strategy: ComputationStrategy | None | object = _KEEP,
82+
) -> Distribution: ...
83+
84+
def with_sampling_strategy(self, sampling_strategy: SamplingStrategy | None) -> Self:
85+
"""Return a copy of this distribution with an updated sampling strategy."""
86+
return cast(Self, self._clone_with_strategies(sampling_strategy=sampling_strategy))
87+
88+
def with_computation_strategy(self, computation_strategy: ComputationStrategy | None) -> Self:
89+
"""Return a copy of this distribution with an updated computation strategy."""
90+
return cast(
91+
Self,
92+
self._clone_with_strategies(computation_strategy=computation_strategy),
93+
)
94+
95+
def with_strategies(
96+
self,
97+
*,
98+
sampling_strategy: SamplingStrategy | None | object = _KEEP,
99+
computation_strategy: ComputationStrategy | None | object = _KEEP,
100+
) -> Self:
101+
"""Return a copy of this distribution with updated strategies."""
102+
return cast(
103+
Self,
104+
self._clone_with_strategies(
105+
sampling_strategy=sampling_strategy,
106+
computation_strategy=computation_strategy,
107+
),
108+
)
109+
74110
def query_method(
75111
self, characteristic_name: GenericCharacteristicName, **options: Any
76112
) -> Method[Any, Any]:

src/pysatl_core/families/distribution.py

Lines changed: 75 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
__copyright__ = "Copyright (c) 2025 PySATL project"
1212
__license__ = "SPDX-License-Identifier: MIT"
1313

14-
15-
from dataclasses import dataclass
1614
from typing import TYPE_CHECKING, cast
1715

18-
from pysatl_core.distributions.distribution import Distribution
16+
from pysatl_core.distributions.distribution import _KEEP, Distribution
17+
from pysatl_core.distributions.strategies import (
18+
DefaultComputationStrategy,
19+
DefaultSamplingUnivariateStrategy,
20+
)
1921
from pysatl_core.families.registry import ParametricFamilyRegister
2022
from pysatl_core.types import NumericArray
2123

@@ -41,7 +43,6 @@
4143
)
4244

4345

44-
@dataclass(slots=True)
4546
class ParametricFamilyDistribution(Distribution):
4647
"""
4748
A specific distribution instance from a parametric family.
@@ -53,18 +54,46 @@ class ParametricFamilyDistribution(Distribution):
5354
----------
5455
family_name : str
5556
Name of the distribution family.
56-
_distribution_type : DistributionType
57+
distribution_type : DistributionType
5758
Type of this distribution.
58-
_parametrization : Parametrization
59+
parametrization : Parametrization
5960
Parameter values for this distribution.
60-
_support : Support or None
61+
support : Support or None
6162
Support of this distribution.
63+
sampling_strategy : SamplingStrategy
64+
Strategy for generating random samples.
65+
Such an object is unique for each distribution.
66+
computation_strategy : ComputationStrategy
67+
Strategy for computing characteristics and conversions.
68+
Such an object is unique for each distribution.
6269
"""
6370

64-
family_name: str
65-
_distribution_type: DistributionType
66-
_parametrization: Parametrization
67-
_support: Support | None
71+
def __init__(
72+
self,
73+
family_name: str,
74+
distribution_type: DistributionType,
75+
parametrization: Parametrization,
76+
support: Support | None,
77+
sampling_strategy: SamplingStrategy | None = None,
78+
computation_strategy: ComputationStrategy | None = None,
79+
):
80+
self._distribution_type = distribution_type
81+
self._family_name = family_name
82+
self._parametrization = parametrization
83+
self._support = support
84+
85+
self._computation_strategy = computation_strategy or DefaultComputationStrategy()
86+
self._sampling_strategy = sampling_strategy or DefaultSamplingUnivariateStrategy()
87+
88+
self._analytical_cache_key: tuple[int, GenericCharacteristicName] | None = None
89+
self._analytical_cache_val: (
90+
Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]] | None
91+
) = None
92+
93+
@property
94+
def family_name(self) -> str:
95+
"Get the name of the family this distribution belongs to."
96+
return self._family_name
6897

6998
@property
7099
def distribution_type(self) -> DistributionType:
@@ -142,25 +171,52 @@ def analytical_computations(
142171
parametrization object or name changes.
143172
"""
144173
key = (id(self.parametrization), self.parametrization_name)
145-
cache_key = getattr(self, "_analytical_cache_key", None)
146-
cache_val = getattr(self, "_analytical_cache_val", None)
147174

148-
if cache_key != key or cache_val is None:
149-
cache_val = self.family._build_analytical_computations(self.parametrization)
175+
if self._analytical_cache_key != key or self._analytical_cache_val is None:
176+
self._analytical_cache_val = self.family.build_analytical_computations(
177+
self.parametrization
178+
)
150179
self._analytical_cache_key = key
151-
self._analytical_cache_val = cache_val
152180

153-
return cache_val
181+
return self._analytical_cache_val
154182

155183
@property
156184
def sampling_strategy(self) -> SamplingStrategy:
157185
"""Get the sampling strategy for this distribution."""
158-
return self.family.sampling_strategy
186+
return self._sampling_strategy
159187

160188
@property
161189
def computation_strategy(self) -> ComputationStrategy:
162190
"""Get the computation strategy for this distribution."""
163-
return self.family.computation_strategy
191+
return self._computation_strategy
192+
193+
def _clone_with_strategies(
194+
self,
195+
*,
196+
sampling_strategy: SamplingStrategy | None | object = _KEEP,
197+
computation_strategy: ComputationStrategy | None | object = _KEEP,
198+
) -> ParametricFamilyDistribution:
199+
"""Return a copy of this distribution with updated strategies."""
200+
new_sampling: SamplingStrategy | None = (
201+
self._sampling_strategy
202+
if sampling_strategy is _KEEP
203+
else cast(SamplingStrategy | None, sampling_strategy)
204+
)
205+
206+
new_computation: ComputationStrategy | None = (
207+
self._computation_strategy
208+
if computation_strategy is _KEEP
209+
else cast(ComputationStrategy | None, computation_strategy)
210+
)
211+
212+
return ParametricFamilyDistribution(
213+
family_name=self._family_name,
214+
distribution_type=self._distribution_type,
215+
parametrization=self._parametrization,
216+
support=self._support,
217+
sampling_strategy=new_sampling,
218+
computation_strategy=new_computation,
219+
)
164220

165221
@property
166222
def support(self) -> Support | None:

src/pysatl_core/families/parametric_family.py

Lines changed: 55 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@
1818
from typing import TYPE_CHECKING, Any, cast, dataclass_transform
1919

2020
from pysatl_core.distributions.computation import AnalyticalComputation
21-
from pysatl_core.distributions.strategies import (
22-
DefaultComputationStrategy,
23-
DefaultSamplingUnivariateStrategy,
24-
)
2521
from pysatl_core.families.distribution import ParametricFamilyDistribution
2622
from pysatl_core.types import ComputationFunc, DistributionType
2723

@@ -83,10 +79,6 @@ class ParametricFamily:
8379
- pointwise characteristics (e.g., pdf, cdf, ppf): provider(params, x, **kwargs) -> Any
8480
8581
If a single callable is provided, it is treated as defined for the base parametrization.
86-
sampling_strategy : SamplingStrategy, optional
87-
Strategy for sampling from distributions.
88-
computation_strategy : ComputationStrategy, optional
89-
Strategy for computing distribution characteristics.
9082
support_by_parametrization : Callable or None, optional
9183
Function that returns support for given parameters.
9284
"""
@@ -97,37 +89,27 @@ def __init__(
9789
distr_type: DistributionType | Callable[[Parametrization], DistributionType],
9890
distr_parametrizations: list[ParametrizationName],
9991
distr_characteristics: CharacteristicsMap,
100-
sampling_strategy: SamplingStrategy | None = None,
101-
computation_strategy: ComputationStrategy | None = None,
10292
support_by_parametrization: SupportArg = None,
10393
):
94+
if not distr_parametrizations:
95+
raise ValueError(
96+
"distr_parametrizations must be non-empty (base parametrization is required)."
97+
)
98+
10499
self._name = name
100+
# Ordered names; the first one is the base parametrization name
101+
self.parametrization_names = distr_parametrizations
102+
self.base_parametrization_name = self.parametrization_names[0]
105103
self._distr_type: Callable[[Parametrization], DistributionType] = (
106104
(lambda params: distr_type) if isinstance(distr_type, DistributionType) else distr_type
107105
)
108106

109-
self.computation_strategy = (
110-
DefaultComputationStrategy() if computation_strategy is None else computation_strategy
111-
)
112-
113-
if support_by_parametrization is None:
114-
self._support_resolver: SupportResolver
115-
self._support_resolver = lambda _params: None
116-
else:
117-
self._support_resolver = support_by_parametrization
118-
119-
# Ordered names; the first one is the base parametrization name
120-
self.parametrization_names: list[ParametrizationName] = distr_parametrizations
121-
self.base_parametrization_name: ParametrizationName = self.parametrization_names[0]
107+
self._support_resolver: SupportResolver = support_by_parametrization or (lambda _p: None)
122108

123109
# Runtime registry of parametrization classes
124110
self._parametrizations: dict[ParametrizationName, type[Parametrization]] = {}
125111

126-
self.sampling_strategy = (
127-
DefaultSamplingUnivariateStrategy() if sampling_strategy is None else sampling_strategy
128-
)
129-
130-
def _process_char_val(
112+
def _normalize_characteristic(
131113
value: Mapping[ParametrizationName, CharacteristicFunction[Any, Any]]
132114
| CharacteristicFunction[Any, Any],
133115
) -> dict[ParametrizationName, CharacteristicFunction[Any, Any]]:
@@ -139,21 +121,33 @@ def _process_char_val(
139121

140122
self.distr_characteristics: dict[
141123
GenericCharacteristicName, dict[ParametrizationName, CharacteristicFunction[Any, Any]]
142-
] = {key: _process_char_val(val) for key, val in distr_characteristics.items()}
143-
144-
# Precompute analytical plan
124+
] = {k: _normalize_characteristic(v) for k, v in distr_characteristics.items()}
125+
126+
# Validate characteristic providers
127+
valid_names = set(self.parametrization_names)
128+
for char_name, forms in self.distr_characteristics.items():
129+
unknown = set(forms) - valid_names
130+
if unknown:
131+
raise ValueError(
132+
f"Characteristic '{char_name}' has providers for unknown parametrizations: "
133+
f"{sorted(unknown)}."
134+
)
135+
if self.base_parametrization_name not in forms and len(forms) == 0:
136+
raise ValueError(f"Characteristic '{char_name}' has no providers.")
137+
138+
# Precompute analytical plan: for each parametrization pick provider (self or base)
145139
self._analytical_plan: dict[
146140
ParametrizationName, dict[GenericCharacteristicName, ParametrizationName]
147141
] = {}
148-
base_name = self.base_parametrization_name
142+
base = self.base_parametrization_name
149143
for pname in self.parametrization_names:
150-
plan_for_p: dict[GenericCharacteristicName, ParametrizationName] = {}
144+
plan: dict[GenericCharacteristicName, ParametrizationName] = {}
151145
for characteristic, forms in self.distr_characteristics.items():
152146
if pname in forms:
153-
plan_for_p[characteristic] = pname
154-
elif base_name in forms:
155-
plan_for_p[characteristic] = base_name
156-
self._analytical_plan[pname] = plan_for_p
147+
plan[characteristic] = pname
148+
elif base in forms:
149+
plan[characteristic] = base
150+
self._analytical_plan[pname] = plan
157151

158152
@property
159153
def name(self) -> str:
@@ -184,7 +178,7 @@ def base(self) -> type[Parametrization]:
184178

185179
@property
186180
def support_resolver(self) -> SupportResolver:
187-
"""Get the support resolver function."""
181+
"""Support resolver callable."""
188182
return self._support_resolver
189183

190184
def register_parametrization(
@@ -269,7 +263,7 @@ def _bind_parametrization[In, Out](
269263
else func,
270264
)
271265

272-
def _build_analytical_computations(
266+
def build_analytical_computations(
273267
self, parameters: Parametrization
274268
) -> dict[GenericCharacteristicName, AnalyticalComputation[Any, Any]]:
275269
"""
@@ -285,8 +279,7 @@ def _build_analytical_computations(
285279
if provider_name == parameters.name:
286280
params_obj = parameters
287281
else:
288-
if base_params is None:
289-
base_params = self.to_base(parameters)
282+
base_params = base_params or self.to_base(parameters)
290283
params_obj = base_params
291284

292285
func_factory = self.distr_characteristics[characteristic][provider_name]
@@ -299,16 +292,23 @@ def _build_analytical_computations(
299292

300293
def distribution(
301294
self,
302-
parametrization_name: str | None = None,
295+
parametrization_name: ParametrizationName | None = None,
296+
sampling_strategy: SamplingStrategy | None = None,
297+
computation_strategy: ComputationStrategy | None = None,
303298
**parameters_values: Any,
304299
) -> ParametricFamilyDistribution:
305300
"""
306301
Create a distribution instance with given parameters.
307302
308303
Parameters
309304
----------
310-
parametrization_name : str, optional
305+
parametrization_name : ParametrizationName | None, optional
311306
Name of parametrization to use (defaults to base).
307+
sampling_strategy : SamplingStrategy
308+
Strategy for generating random samples. Such an object is unique for each distribution.
309+
computation_strategy : ComputationStrategy
310+
Strategy for computing characteristics and conversions.
311+
Such an object is unique for each distribution.
312312
**parameters_values
313313
Parameter values for the distribution.
314314
@@ -324,22 +324,28 @@ def distribution(
324324
ValueError
325325
If parameters don't satisfy constraints.
326326
"""
327-
if parametrization_name is None:
328-
parametrization_class = self.base
329-
else:
330-
parametrization_class = self._parametrizations[parametrization_name]
327+
parametrization_class = (
328+
self.base
329+
if parametrization_name is None
330+
else self._parametrizations[parametrization_name]
331+
)
331332

332333
parameters = parametrization_class(**parameters_values)
333334
parameters.validate()
334335
base_parameters = self.to_base(parameters)
335336
distribution_type = self._distr_type(base_parameters)
336337
return ParametricFamilyDistribution(
337-
self.name, distribution_type, parameters, self.support_resolver(parameters)
338+
family_name=self.name,
339+
distribution_type=distribution_type,
340+
parametrization=parameters,
341+
support=self.support_resolver(parameters),
342+
sampling_strategy=sampling_strategy,
343+
computation_strategy=computation_strategy,
338344
)
339345

340346
@dataclass_transform()
341347
def parametrization(
342-
self, *, name: str
348+
self, *, name: ParametrizationName
343349
) -> Callable[[type[Parametrization]], type[Parametrization]]:
344350
"""
345351
Create a class decorator that registers a parametrization.

0 commit comments

Comments
 (0)