Skip to content

Commit 96faae2

Browse files
committed
refactor(distributions): now the singleton is really a singleton, some checks have been added + test covering
1 parent 8b58c4e commit 96faae2

5 files changed

Lines changed: 73 additions & 18 deletions

File tree

src/pysatl_core/distributions/computations/continuous.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,15 +346,26 @@ def _fit_ppf_to_cdf_1C(
346346
lower bound of the CDF domain approximation.
347347
q_highest : float, default 1 - 1e-12
348348
*(Characteristic option)* Right bracket for root search. Defines the
349-
upper bound of the CDF domain approximation.
349+
upper bound of the CDF domain approximation. Must be strictly greater
350+
than *q_lowest*.
350351
max_iter : int, default 256
351352
*(Computation option)* Maximum brentq iterations per point.
352353
353354
Returns
354355
-------
355356
FittedComputationMethod[NumericArray, NumericArray]
356357
Array-semantic ``cdf`` callable.
358+
359+
Raises
360+
------
361+
ValueError
362+
If ``q_highest <= q_lowest``.
357363
"""
364+
if q_highest <= q_lowest:
365+
raise ValueError(
366+
f"q_highest must be greater than q_lowest, got q_lowest={q_lowest!r}, "
367+
f"q_highest={q_highest!r}."
368+
)
358369
ppf_func = resolve(distribution, CharacteristicName.PPF)
359370

360371
def _cdf(x: NumericArray, **options: Any) -> NumericArray:

src/pysatl_core/distributions/computations/registry.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
__license__ = "SPDX-License-Identifier: MIT"
1818

1919
from functools import lru_cache
20-
from typing import TYPE_CHECKING
20+
from typing import TYPE_CHECKING, Any, ClassVar, Self
2121

2222
if TYPE_CHECKING:
2323
from collections.abc import Sequence
@@ -30,19 +30,50 @@ class FitterRegistry:
3030
"""
3131
Registry that stores fitter descriptors and selects the best match.
3232
33+
This class is a singleton: every call to ``FitterRegistry()`` returns the
34+
same instance. Use ``FitterRegistry._reset()`` in tests to clear state.
35+
3336
Examples
3437
--------
3538
>>> registry = FitterRegistry()
3639
>>> registry.register(some_descriptor)
3740
>>> desc = registry.find("cdf", ["pdf"], required_tags={"continuous", "univariate"})
3841
"""
3942

43+
_instance: ClassVar[FitterRegistry | None] = None
44+
45+
def __new__(cls) -> Self:
46+
if cls._instance is None:
47+
cls._instance = super().__new__(cls)
48+
return cls._instance # type: ignore[return-value]
49+
4050
def __init__(self) -> None:
51+
if getattr(self, "_initialized", False):
52+
return
4153
self._by_key: dict[
4254
tuple[GenericCharacteristicName, tuple[GenericCharacteristicName, ...]],
4355
list[FitterDescriptor],
4456
] = {}
4557
self._all: list[FitterDescriptor] = []
58+
self._initialized = True
59+
60+
def __copy__(self) -> Self:
61+
"""Singleton copy returns the same instance."""
62+
return self
63+
64+
def __deepcopy__(self, memo: dict[Any, Any]) -> Self:
65+
"""Singleton deepcopy returns the same instance."""
66+
return self
67+
68+
@classmethod
69+
def _reset(cls) -> None:
70+
"""
71+
Clear the singleton instance (for testing purposes only).
72+
73+
Resets all registered descriptors and allows the next ``FitterRegistry()``
74+
call to create a fresh instance.
75+
"""
76+
cls._instance = None
4677

4778
def register(self, descriptor: FitterDescriptor) -> None:
4879
"""
@@ -153,8 +184,9 @@ def fitter_registry() -> FitterRegistry:
153184
-----
154185
- Descriptors are **not** created at import time; they are built here on
155186
first access, keeping module-level side-effects to a minimum.
156-
- Users who need a custom registry should instantiate ``FitterRegistry()``
157-
directly and populate it themselves.
187+
- ``FitterRegistry`` is a singleton; ``FitterRegistry() is fitter_registry()``
188+
is always ``True`` after the first call.
189+
- To reset the singleton (e.g. in tests), call ``reset_fitter_registry()``.
158190
"""
159191
from pysatl_core.distributions.computations.continuous import (
160192
_build_continuous_descriptors,
@@ -172,6 +204,7 @@ def fitter_registry() -> FitterRegistry:
172204
def reset_fitter_registry() -> None:
173205
"""Reset the cached fitter registry (useful in tests)."""
174206
fitter_registry.cache_clear()
207+
FitterRegistry._reset()
175208

176209

177210
__all__ = [

src/pysatl_core/distributions/registry/configuration.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from functools import lru_cache
2727
from typing import TYPE_CHECKING
2828

29-
from pysatl_core.distributions.computations.registry import FitterRegistry, fitter_registry
29+
from pysatl_core.distributions.computations.registry import fitter_registry
3030
from pysatl_core.distributions.registry.constraint import (
3131
GraphPrimitiveConstraint,
3232
NonNullConstraint,
@@ -48,7 +48,6 @@
4848

4949
def _add_edges(
5050
reg: CharacteristicRegistry,
51-
fitter_reg: FitterRegistry,
5251
pairs: Iterable[tuple[GenericCharacteristicName, GenericCharacteristicName]],
5352
*,
5453
tags: frozenset[str],
@@ -62,8 +61,6 @@ def _add_edges(
6261
----------
6362
reg : CharacteristicRegistry
6463
Target characteristic graph.
65-
fitter_reg : FitterRegistry
66-
Index of fitter descriptors to query.
6764
pairs : Iterable[tuple[str, str]]
6865
``(source, target)`` pairs to register.
6966
tags : frozenset[str]
@@ -76,6 +73,7 @@ def _add_edges(
7673
RuntimeError
7774
If no descriptor matches one of the requested ``(source, target, tags)``.
7875
"""
76+
fitter_reg = fitter_registry()
7977
for src, tgt in pairs:
8078
descriptor = fitter_reg.find(tgt, [src], required_tags=tags)
8179
if descriptor is None:
@@ -92,8 +90,6 @@ def _add_edges(
9290

9391
def _configure(reg: CharacteristicRegistry) -> None:
9492
"""Default PySATL configuration for characteristic registry."""
95-
fitter_reg = fitter_registry()
96-
9793
dim1_constraint = NumericConstraint(allowed=frozenset({1}))
9894
kind_continuous = SetConstraint(allowed=frozenset({Kind.CONTINUOUS}))
9995
kind_discrete = SetConstraint(allowed=frozenset({Kind.DISCRETE}))
@@ -144,7 +140,6 @@ def _configure(reg: CharacteristicRegistry) -> None:
144140

145141
_add_edges(
146142
reg,
147-
fitter_reg,
148143
pairs=(
149144
(CharacteristicName.PDF, CharacteristicName.CDF),
150145
(CharacteristicName.CDF, CharacteristicName.PDF),
@@ -157,7 +152,6 @@ def _configure(reg: CharacteristicRegistry) -> None:
157152

158153
_add_edges(
159154
reg,
160-
fitter_reg,
161155
pairs=(
162156
(CharacteristicName.PMF, CharacteristicName.CDF),
163157
(CharacteristicName.CDF, CharacteristicName.PMF),

src/pysatl_core/distributions/strategies.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -727,9 +727,9 @@ def query_method(
727727

728728
self._push_guard(distr, state)
729729
self._char_options_stack.append(dict(effective_char_options))
730+
injected_keys: list[tuple[int, GenericCharacteristicName]] = []
730731
try:
731732
last_fitted: FittedComputationMethod[Any, Any] | None = None
732-
injected_keys: list[tuple[int, GenericCharacteristicName]] = []
733733
for step_idx, edge in enumerate(cached_plan.edges):
734734
method = edge.method
735735

@@ -783,14 +783,14 @@ def query_method(
783783
)
784784
injected_keys.append(intermediate_key)
785785

786-
# Remove the temporary loop plans injected for intermediate targets.
787-
for key in injected_keys:
788-
self._path_cache.pop(key, None)
789-
790786
if last_fitted is None:
791787
raise RuntimeError(f"Empty path when resolving '{state}'.")
792788
return last_fitted
793789
finally:
790+
# Remove the temporary loop plans injected for intermediate targets.
791+
# Placed in finally to ensure cleanup even if fitting raises.
792+
for key in injected_keys:
793+
self._path_cache.pop(key, None)
794794
self._char_options_stack.pop()
795795
self._pop_guard(distr, state)
796796

tests/unit/distributions/computations/test_registry.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
from typing import Any
1212

13+
import pytest
14+
1315
from pysatl_core.distributions.computations.descriptors import FitterDescriptor
14-
from pysatl_core.distributions.computations.registry import FitterRegistry
16+
from pysatl_core.distributions.computations.registry import FitterRegistry, reset_fitter_registry
1517
from pysatl_core.types import CharacteristicName
1618

1719

@@ -37,6 +39,11 @@ def _make_desc(
3739
class TestFitterRegistry:
3840
"""Tests for the FitterRegistry class."""
3941

42+
@pytest.fixture(autouse=True)
43+
def reset_registry(self) -> None:
44+
"""Reset the FitterRegistry singleton before each test."""
45+
reset_fitter_registry()
46+
4047
def test_register_and_find(self) -> None:
4148
reg = FitterRegistry()
4249
desc = _make_desc("pdf_to_cdf")
@@ -140,6 +147,16 @@ def test_contains(self) -> None:
140147
assert "pdf_to_cdf" in reg
141148
assert "nonexistent" not in reg
142149

150+
def test_singleton_returns_same_instance(self) -> None:
151+
reg1 = FitterRegistry()
152+
reg2 = FitterRegistry()
153+
assert reg1 is reg2
154+
155+
def test_fitter_registry_function_returns_same_instance(self) -> None:
156+
from pysatl_core.distributions.computations.registry import fitter_registry
157+
158+
assert FitterRegistry() is fitter_registry()
159+
143160
def test_different_source_target_pairs_are_independent(self) -> None:
144161
reg = FitterRegistry()
145162
pdf_cdf = _make_desc(

0 commit comments

Comments
 (0)