Skip to content

Commit 8ba5e15

Browse files
committed
refactor: strategies now is one per distribution
1 parent 5417c30 commit 8ba5e15

5 files changed

Lines changed: 191 additions & 74 deletions

File tree

src/pysatl_core/distributions/distribution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ class Distribution(Protocol):
4848
analytical_computations : Mapping[str, AnalyticalComputation]
4949
Direct analytical computations provided by the distribution.
5050
sampling_strategy : SamplingStrategy
51-
Strategy for generating random samples.
51+
Strategy for generating random samples. Such an object is unique for each distribution.
5252
computation_strategy : ComputationStrategy
5353
Strategy for computing characteristics and conversions.
54+
Such an object is unique for each distribution.
5455
support : Support or None
5556
Support of the distribution, if defined.
5657
"""

src/pysatl_core/families/distribution.py

Lines changed: 133 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
__copyright__ = "Copyright (c) 2025 PySATL project"
1212
__license__ = "SPDX-License-Identifier: MIT"
1313

14-
15-
from dataclasses import dataclass
1614
from typing import TYPE_CHECKING, cast
1715

1816
from pysatl_core.distributions.distribution import Distribution
17+
from pysatl_core.distributions.strategies import (
18+
DefaultComputationStrategy,
19+
DefaultSamplingUnivariateStrategy,
20+
)
1921
from pysatl_core.families.registry import ParametricFamilyRegister
2022
from pysatl_core.types import NumericArray
2123

@@ -40,8 +42,9 @@
4042
ParametrizationName,
4143
)
4244

45+
_KEEP: object = object()
46+
4347

44-
@dataclass(slots=True)
4548
class ParametricFamilyDistribution(Distribution):
4649
"""
4750
A specific distribution instance from a parametric family.
@@ -53,18 +56,45 @@ class ParametricFamilyDistribution(Distribution):
5356
----------
5457
family_name : str
5558
Name of the distribution family.
56-
_distribution_type : DistributionType
59+
distribution_type : DistributionType
5760
Type of this distribution.
58-
_parametrization : Parametrization
61+
parametrization : Parametrization
5962
Parameter values for this distribution.
60-
_support : Support or None
63+
support : Support or None
6164
Support of this distribution.
65+
sampling_strategy : SamplingStrategy
66+
Strategy for generating random samples. Such an object is unique for each distribution.
67+
computation_strategy : ComputationStrategy
68+
Strategy for computing characteristics and conversions.
69+
Such an object is unique for each distribution.
6270
"""
6371

64-
family_name: str
65-
_distribution_type: DistributionType
66-
_parametrization: Parametrization
67-
_support: Support | None
72+
def __init__(
73+
self,
74+
family_name: str,
75+
distribution_type: DistributionType,
76+
parametrization: Parametrization,
77+
support: Support | None,
78+
sampling_strategy: SamplingStrategy | None = None,
79+
computation_strategy: ComputationStrategy[Any, Any] | None = None,
80+
):
81+
self._distribution_type = distribution_type
82+
self._family_name = family_name
83+
self._parametrization = parametrization
84+
self._support = support
85+
86+
self._computation_strategy = computation_strategy or DefaultComputationStrategy()
87+
self._sampling_strategy = sampling_strategy or DefaultSamplingUnivariateStrategy()
88+
89+
self._analytical_cache_key: tuple[int, GenericCharacteristicName] | None = None
90+
self._analytical_cache_val: (
91+
Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]] | None
92+
) = None
93+
94+
@property
95+
def family_name(self) -> str:
96+
"Get the name of the family this distribution belongs to."
97+
return self._family_name
6898

6999
@property
70100
def distribution_type(self) -> DistributionType:
@@ -142,25 +172,110 @@ def analytical_computations(
142172
parametrization object or name changes.
143173
"""
144174
key = (id(self.parametrization), self.parametrization_name)
145-
cache_key = getattr(self, "_analytical_cache_key", None)
146-
cache_val = getattr(self, "_analytical_cache_val", None)
147175

148-
if cache_key != key or cache_val is None:
149-
cache_val = self.family._build_analytical_computations(self.parametrization)
176+
if self._analytical_cache_key != key or self._analytical_cache_val is None:
177+
self._analytical_cache_val = self.family.build_analytical_computations(
178+
self.parametrization
179+
)
150180
self._analytical_cache_key = key
151-
self._analytical_cache_val = cache_val
152181

153-
return cache_val
182+
return self._analytical_cache_val
154183

155184
@property
156185
def sampling_strategy(self) -> SamplingStrategy:
157186
"""Get the sampling strategy for this distribution."""
158-
return self.family.sampling_strategy
187+
return self._sampling_strategy
159188

160189
@property
161190
def computation_strategy(self) -> ComputationStrategy[Any, Any]:
162191
"""Get the computation strategy for this distribution."""
163-
return self.family.computation_strategy
192+
return self._computation_strategy
193+
194+
def with_sampling_strategy(
195+
self, sampling_strategy: SamplingStrategy | None
196+
) -> ParametricFamilyDistribution:
197+
"""
198+
Return a copy of this distribution with an updated sampling strategy.
199+
200+
Parameters
201+
----------
202+
sampling_strategy : SamplingStrategy | None
203+
New sampling strategy. If ``None``, the default sampling strategy is used.
204+
205+
Returns
206+
-------
207+
ParametricFamilyDistribution
208+
New distribution instance with the same parameters and updated strategy.
209+
"""
210+
return ParametricFamilyDistribution(
211+
family_name=self._family_name,
212+
distribution_type=self._distribution_type,
213+
parametrization=self._parametrization,
214+
support=self._support,
215+
sampling_strategy=sampling_strategy,
216+
computation_strategy=self._computation_strategy,
217+
)
218+
219+
def with_computation_strategy(
220+
self, computation_strategy: ComputationStrategy[Any, Any] | None
221+
) -> ParametricFamilyDistribution:
222+
"""
223+
Return a copy of this distribution with an updated computation strategy.
224+
225+
Parameters
226+
----------
227+
computation_strategy : ComputationStrategy[Any, Any] | None
228+
New computation strategy. If ``None``, the default computation strategy is used.
229+
230+
Returns
231+
-------
232+
ParametricFamilyDistribution
233+
New distribution instance with the same parameters and updated strategy.
234+
"""
235+
return ParametricFamilyDistribution(
236+
family_name=self._family_name,
237+
distribution_type=self._distribution_type,
238+
parametrization=self._parametrization,
239+
support=self._support,
240+
sampling_strategy=self._sampling_strategy,
241+
computation_strategy=computation_strategy,
242+
)
243+
244+
def with_strategies(
245+
self,
246+
*,
247+
sampling_strategy: SamplingStrategy | None = None,
248+
computation_strategy: ComputationStrategy[Any, Any] | None = None,
249+
) -> ParametricFamilyDistribution:
250+
"""
251+
Return a copy of this distribution with updated strategies.
252+
253+
Parameters
254+
----------
255+
sampling_strategy : SamplingStrategy | None | object, optional
256+
New sampling strategy. If not provided, the current strategy is preserved.
257+
If explicitly set to ``None``, the default sampling strategy is used.
258+
computation_strategy : ComputationStrategy[Any, Any] | None | object, optional
259+
New computation strategy. If not provided, the current strategy is preserved.
260+
If explicitly set to ``None``, the default computation strategy is used.
261+
262+
Returns
263+
-------
264+
ParametricFamilyDistribution
265+
New distribution instance with the same parameters and updated strategies.
266+
"""
267+
new_sampling = self._sampling_strategy if sampling_strategy is _KEEP else sampling_strategy
268+
new_computation = (
269+
self._computation_strategy if computation_strategy is _KEEP else computation_strategy
270+
)
271+
return ParametricFamilyDistribution(
272+
family_name=self._family_name,
273+
distribution_type=self._distribution_type,
274+
parametrization=self._parametrization,
275+
support=self._support,
276+
sampling_strategy=new_sampling,
277+
computation_strategy=new_computation,
278+
)
164279

165280
@property
166281
def support(self) -> Support | None:

0 commit comments

Comments
 (0)