Skip to content

Commit 248c8d2

Browse files
committed
refactor: generics now used more carefully. Exactly where they are should be
1 parent b868794 commit 248c8d2

9 files changed

Lines changed: 45 additions & 42 deletions

File tree

src/pysatl_core/distributions/computation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,6 @@ def fit(self, distribution: Distribution, **options: Any) -> FittedComputationMe
145145
Fitted method ready for evaluation.
146146
"""
147147
return self.fitter(distribution, **options)
148+
149+
150+
type Method[In, Out] = AnalyticalComputation[In, Out] | FittedComputationMethod[In, Out]

src/pysatl_core/distributions/distribution.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@
1919
from collections.abc import Mapping
2020
from typing import Any
2121

22-
from pysatl_core.distributions.computation import AnalyticalComputation
22+
from pysatl_core.distributions.computation import AnalyticalComputation, Method
2323
from pysatl_core.distributions.strategies import (
2424
ComputationStrategy,
25-
Method,
2625
SamplingStrategy,
2726
)
2827
from pysatl_core.distributions.support import Support
@@ -67,7 +66,7 @@ def analytical_computations(
6766
def sampling_strategy(self) -> SamplingStrategy: ...
6867

6968
@property
70-
def computation_strategy(self) -> ComputationStrategy[Any, Any]: ...
69+
def computation_strategy(self) -> ComputationStrategy: ...
7170

7271
@property
7372
def support(self) -> Support | None: ...

src/pysatl_core/distributions/fitters.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
from typing import Any
2828

2929
from pysatl_core.distributions.distribution import Distribution
30-
from pysatl_core.types import GenericCharacteristicName, ScalarFunc
30+
from pysatl_core.types import GenericCharacteristicName
31+
32+
type ScalarFunc = Callable[[float], float]
3133

3234

3335
def _resolve(distribution: Distribution, name: GenericCharacteristicName) -> ScalarFunc:

src/pysatl_core/distributions/strategies.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,12 @@
2121
if TYPE_CHECKING:
2222
from typing import Any
2323

24-
from pysatl_core.distributions.computation import AnalyticalComputation, FittedComputationMethod
24+
from pysatl_core.distributions.computation import FittedComputationMethod, Method
2525
from pysatl_core.distributions.distribution import Distribution
2626
from pysatl_core.types import GenericCharacteristicName
2727

28-
type Method[In, Out] = AnalyticalComputation[In, Out] | FittedComputationMethod[In, Out]
2928

30-
31-
class ComputationStrategy[In, Out](Protocol):
29+
class ComputationStrategy(Protocol):
3230
"""
3331
Protocol for strategies that resolve computation methods for characteristics.
3432
@@ -42,10 +40,10 @@ class ComputationStrategy[In, Out](Protocol):
4240

4341
def query_method(
4442
self, state: GenericCharacteristicName, distr: Distribution, **options: Any
45-
) -> Method[In, Out]: ...
43+
) -> Method[Any, Any]: ...
4644

4745

48-
class DefaultComputationStrategy[In, Out]:
46+
class DefaultComputationStrategy:
4947
"""
5048
Default strategy for resolving characteristic computation methods.
5149
@@ -71,7 +69,7 @@ class DefaultComputationStrategy[In, Out]:
7169

7270
def __init__(self, enable_caching: bool = False) -> None:
7371
self.enable_caching = enable_caching
74-
self._cache: dict[GenericCharacteristicName, FittedComputationMethod[In, Out]] = {}
72+
self._cache: dict[GenericCharacteristicName, FittedComputationMethod[Any, Any]] = {}
7573
self._resolving: dict[int, set[GenericCharacteristicName]] = {}
7674

7775
def _push_guard(self, distr: Distribution, state: GenericCharacteristicName) -> None:
@@ -103,7 +101,7 @@ def _pop_guard(self, distr: Distribution, state: GenericCharacteristicName) -> N
103101

104102
def query_method(
105103
self, state: GenericCharacteristicName, distr: Distribution, **options: Any
106-
) -> Method[In, Out]:
104+
) -> Method[Any, Any]:
107105
"""
108106
Resolve a computation method for the target characteristic.
109107
@@ -164,7 +162,7 @@ def query_method(
164162
continue
165163

166164
# Fit each edge along the path
167-
last_fitted: FittedComputationMethod[In, Out] | None = None
165+
last_fitted: FittedComputationMethod[Any, Any] | None = None
168166
for edge in path:
169167
fitted = edge.fit(distr, **options)
170168
if self.enable_caching:

src/pysatl_core/families/distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def sampling_strategy(self) -> SamplingStrategy:
158158
return self.family.sampling_strategy
159159

160160
@property
161-
def computation_strategy(self) -> ComputationStrategy[Any, Any]:
161+
def computation_strategy(self) -> ComputationStrategy:
162162
"""Get the computation strategy for this distribution."""
163163
return self.family.computation_strategy
164164

src/pysatl_core/families/parametric_family.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,19 @@
4141
type SupportArg = Callable[[Parametrization], Support | None] | None
4242
type SupportResolver = Callable[[Parametrization], Support | None]
4343
type CharacteristicProvider = (
44-
Mapping[ParametrizationName, CharacteristicFunction] | CharacteristicFunction
44+
Mapping[ParametrizationName, CharacteristicFunction[Any, Any]]
45+
| CharacteristicFunction[Any, Any]
4546
)
4647
type CharacteristicsMap = Mapping[GenericCharacteristicName, CharacteristicProvider]
48+
type NonParametrizedCharacteristic[In, Out] = Callable[[], Out]
49+
type CharacteristicFunction[In, Out] = (
50+
NonParametrizedCharacteristic[In, Out] | ParametrizedCharacteristic[In, Out]
51+
)
52+
4753

48-
type NonParametrizedCharacteristic = Callable[[], Any]
49-
type ParametrizedCharacteristic = (
50-
Callable[[Parametrization, Any], Any] | Callable[[Parametrization], Any]
54+
type ParametrizedCharacteristic[In, Out] = (
55+
Callable[[Parametrization, In], Out] | Callable[[Parametrization], Out]
5156
)
52-
type CharacteristicFunction = NonParametrizedCharacteristic | ParametrizedCharacteristic
5357

5458

5559
class ParametricFamily:
@@ -94,7 +98,7 @@ def __init__(
9498
distr_parametrizations: list[ParametrizationName],
9599
distr_characteristics: CharacteristicsMap,
96100
sampling_strategy: SamplingStrategy | None = None,
97-
computation_strategy: ComputationStrategy[Any, Any] | None = None,
101+
computation_strategy: ComputationStrategy | None = None,
98102
support_by_parametrization: SupportArg = None,
99103
):
100104
self._name = name
@@ -124,16 +128,17 @@ def __init__(
124128
)
125129

126130
def _process_char_val(
127-
value: Mapping[ParametrizationName, CharacteristicFunction] | CharacteristicFunction,
128-
) -> dict[ParametrizationName, CharacteristicFunction]:
131+
value: Mapping[ParametrizationName, CharacteristicFunction[Any, Any]]
132+
| CharacteristicFunction[Any, Any],
133+
) -> dict[ParametrizationName, CharacteristicFunction[Any, Any]]:
129134
return (
130135
dict(value)
131136
if isinstance(value, Mapping)
132137
else {self.base_parametrization_name: value}
133138
)
134139

135140
self.distr_characteristics: dict[
136-
GenericCharacteristicName, dict[ParametrizationName, CharacteristicFunction]
141+
GenericCharacteristicName, dict[ParametrizationName, CharacteristicFunction[Any, Any]]
137142
] = {key: _process_char_val(val) for key, val in distr_characteristics.items()}
138143

139144
# Precompute analytical plan
@@ -236,9 +241,9 @@ def to_base(self, parameters: Parametrization) -> Parametrization:
236241
return parameters.transform_to_base_parametrization()
237242

238243
@staticmethod
239-
def _bind_parametrization(
240-
func: CharacteristicFunction, params_obj: Parametrization
241-
) -> ComputationFunc[Any, Any]:
244+
def _bind_parametrization[In, Out](
245+
func: CharacteristicFunction[In, Out], params_obj: Parametrization
246+
) -> ComputationFunc[In, Out]:
242247
"""Bind ``params_obj`` to ``func`` only when ``func`` can accept positional arguments.
243248
244249
This allows parametrization-independent analytical providers to be written without
@@ -258,8 +263,8 @@ def _bind_parametrization(
258263
)
259264

260265
return cast(
261-
ComputationFunc[Any, Any],
262-
partial(cast(ParametrizedCharacteristic, func), params_obj)
266+
ComputationFunc[In, Out],
267+
partial(cast(ParametrizedCharacteristic[In, Out], func), params_obj)
263268
if accepts_first_positional
264269
else func,
265270
)

src/pysatl_core/types.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,24 +89,24 @@ class EuclideanDistributionType(DistributionType):
8989

9090

9191
UnivariateContinuous = EuclideanDistributionType(kind=Kind.CONTINUOUS, dimension=1)
92-
"""Type for univariate continuous distributions."""
92+
"""Predefined DistributionType object for univariate continuous distributions."""
9393

9494
UnivariateDiscrete = EuclideanDistributionType(kind=Kind.DISCRETE, dimension=1)
95-
"""Type for univariate discrete distributions."""
95+
"""Predefined DistributionType object for univariate discrete distributions."""
9696

97-
NumPyNumber = np.floating[Any] | np.integer[Any]
97+
type NumPyNumber = np.floating[Any] | np.integer[Any]
9898
"""Type alias for NumPy numeric types."""
9999

100-
Number = NumPyNumber | int | float
100+
type Number = NumPyNumber | int | float
101101
"""Type alias for all numeric types."""
102102

103-
NumericArray = NDArray[NumPyNumber]
103+
type NumericArray = NDArray[NumPyNumber]
104104
"""Type alias for numeric arrays."""
105105

106-
ComplexArray = NDArray[np.complexfloating[Any]]
106+
type ComplexArray = NDArray[np.complexfloating[Any]]
107107
"""Type alias for complex arrays."""
108108

109-
BoolArray = NDArray[np.bool_]
109+
type BoolArray = NDArray[np.bool_]
110110
"""Type alias for boolean arrays."""
111111

112112

@@ -253,9 +253,6 @@ def shape(self) -> ContinuousSupportShape1D:
253253
``**options`` dynamically.
254254
"""
255255

256-
ScalarFunc = Callable[[float], float]
257-
"""Type alias for scalar functions (float -> float)."""
258-
259256

260257
class CharacteristicName(StrEnum):
261258
"""
@@ -301,7 +298,6 @@ class FamilyName(StrEnum):
301298
"ParametrizationName",
302299
"ComputationFunc",
303300
"DistributionType",
304-
"ScalarFunc",
305301
"Interval1D",
306302
"ContinuousSupportShape1D",
307303
"BoolArray",

tests/unit/distributions/test_registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_configuration_continuous_presence_and_connectivity(self) -> None:
4444
assert view.find_path(CharacteristicName.PPF, CharacteristicName.CDF) is not None
4545

4646
# Strategy resolves and roundtrips: CDF(PPF(q)) ~ q
47-
strategy = DefaultComputationStrategy[float, float](enable_caching=False)
47+
strategy = DefaultComputationStrategy(enable_caching=False)
4848
ppf = strategy.query_method(CharacteristicName.PPF, distr)
4949
cdf = strategy.query_method(CharacteristicName.CDF, distr)
5050
qs = np.linspace(1e-6, 1.0 - 1e-6, 7)
@@ -76,7 +76,7 @@ def test_configuration_discrete_requires_support_then_ok(self) -> None:
7676
assert view.find_path(CharacteristicName.CDF, CharacteristicName.PPF) is not None
7777
assert view.find_path(CharacteristicName.PPF, CharacteristicName.CDF) is not None
7878

79-
strategy = DefaultComputationStrategy[float, float](enable_caching=False)
79+
strategy = DefaultComputationStrategy(enable_caching=False)
8080
cdf = strategy.query_method(CharacteristicName.CDF, distr)
8181
assert cdf(0.0) == pytest.approx(0.2, abs=1e-10)
8282
assert cdf(1.0) == pytest.approx(0.7, abs=1e-10)

tests/utils/mocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def sampling_strategy(self) -> SamplingStrategy:
8181
return DefaultSamplingUnivariateStrategy()
8282

8383
@property
84-
def computation_strategy(self) -> ComputationStrategy[Any, Any]:
84+
def computation_strategy(self) -> ComputationStrategy:
8585
"""Computation strategy instance."""
8686
return DefaultComputationStrategy()
8787

0 commit comments

Comments
 (0)