Skip to content

Commit a8ed14f

Browse files
committed
test(distributions): add tests for execution plans and option-aware caching
1 parent 08a75af commit a8ed14f

23 files changed

Lines changed: 1218 additions & 518 deletions

tests/unit/distributions/computations/test_base.py

Lines changed: 0 additions & 461 deletions
This file was deleted.
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""
2+
Tests for CharacteristicOption descriptor.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
__author__ = "Irina Sergeeva"
8+
__copyright__ = "Copyright (c) 2025 PySATL project"
9+
__license__ = "SPDX-License-Identifier: MIT"
10+
11+
from typing import Any
12+
13+
import pytest
14+
15+
from pysatl_core.distributions.computations.base import CharacteristicOption
16+
17+
18+
class TestCharacteristicOption:
19+
"""Tests for the CharacteristicOption dataclass."""
20+
21+
def test_resolve_returns_default_when_key_absent(self) -> None:
22+
opt = CharacteristicOption(name="eps", type=float, default=1e-6)
23+
kwargs: dict[str, Any] = {}
24+
assert opt.resolve(kwargs) == 1e-6
25+
26+
def test_resolve_returns_caller_value_when_present(self) -> None:
27+
opt = CharacteristicOption(name="eps", type=float, default=1e-6)
28+
kwargs: dict[str, Any] = {"eps": 1e-3}
29+
assert opt.resolve(kwargs) == pytest.approx(1e-3)
30+
31+
def test_resolve_pops_key_from_kwargs(self) -> None:
32+
opt = CharacteristicOption(name="eps", type=float, default=1e-6)
33+
kwargs: dict[str, Any] = {"eps": 1e-3, "other": 42}
34+
opt.resolve(kwargs)
35+
assert "eps" not in kwargs
36+
assert "other" in kwargs
37+
38+
def test_resolve_casts_to_declared_type(self) -> None:
39+
opt = CharacteristicOption(name="x0", type=float, default=0.0)
40+
kwargs: dict[str, Any] = {"x0": 1}
41+
result = opt.resolve(kwargs)
42+
assert isinstance(result, float)
43+
assert result == 1.0
44+
45+
def test_resolve_raises_type_error_on_bad_cast(self) -> None:
46+
opt = CharacteristicOption(name="eps", type=float, default=1e-6)
47+
kwargs: dict[str, Any] = {"eps": "not_a_number"}
48+
with pytest.raises(TypeError, match="cannot convert"):
49+
opt.resolve(kwargs)
50+
51+
def test_resolve_raises_value_error_on_failed_validation(self) -> None:
52+
opt = CharacteristicOption(
53+
name="eps", type=float, default=1e-6, validate=lambda v: 0 < v < 0.5
54+
)
55+
kwargs: dict[str, Any] = {"eps": -1.0}
56+
with pytest.raises(ValueError, match="failed validation"):
57+
opt.resolve(kwargs)
58+
59+
def test_resolve_passes_validation_when_valid(self) -> None:
60+
opt = CharacteristicOption(
61+
name="eps", type=float, default=1e-6, validate=lambda v: 0 < v < 0.5
62+
)
63+
kwargs: dict[str, Any] = {"eps": 0.1}
64+
assert opt.resolve(kwargs) == pytest.approx(0.1)
65+
66+
def test_frozen_dataclass(self) -> None:
67+
opt = CharacteristicOption(name="eps", type=float, default=1e-6)
68+
with pytest.raises(AttributeError):
69+
opt.name = "other" # type: ignore[misc]
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""
2+
Tests for ComputationOption descriptor.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
__author__ = "Irina Sergeeva"
8+
__copyright__ = "Copyright (c) 2025 PySATL project"
9+
__license__ = "SPDX-License-Identifier: MIT"
10+
11+
from typing import Any
12+
13+
import pytest
14+
15+
from pysatl_core.distributions.computations.base import ComputationOption
16+
17+
18+
class TestComputationOption:
19+
"""Tests for the ComputationOption dataclass."""
20+
21+
def test_resolve_returns_default_when_key_absent(self) -> None:
22+
opt = ComputationOption(name="limit", type=int, default=200)
23+
kwargs: dict[str, Any] = {}
24+
assert opt.resolve(kwargs) == 200
25+
26+
def test_resolve_returns_caller_value_when_present(self) -> None:
27+
opt = ComputationOption(name="limit", type=int, default=200)
28+
kwargs: dict[str, Any] = {"limit": 500}
29+
assert opt.resolve(kwargs) == 500
30+
31+
def test_resolve_pops_key_from_kwargs(self) -> None:
32+
opt = ComputationOption(name="limit", type=int, default=200)
33+
kwargs: dict[str, Any] = {"limit": 500, "other": 42}
34+
opt.resolve(kwargs)
35+
assert "limit" not in kwargs
36+
assert "other" in kwargs
37+
38+
def test_resolve_casts_to_declared_type(self) -> None:
39+
opt = ComputationOption(name="h", type=float, default=1e-5)
40+
kwargs: dict[str, Any] = {"h": 1}
41+
result = opt.resolve(kwargs)
42+
assert isinstance(result, float)
43+
assert result == 1.0
44+
45+
def test_resolve_raises_type_error_on_bad_cast(self) -> None:
46+
opt = ComputationOption(name="limit", type=int, default=200)
47+
kwargs: dict[str, Any] = {"limit": "not_a_number"}
48+
with pytest.raises(TypeError, match="cannot convert"):
49+
opt.resolve(kwargs)
50+
51+
def test_resolve_raises_value_error_on_failed_validation(self) -> None:
52+
opt = ComputationOption(name="limit", type=int, default=200, validate=lambda v: v > 0)
53+
kwargs: dict[str, Any] = {"limit": -1}
54+
with pytest.raises(ValueError, match="failed validation"):
55+
opt.resolve(kwargs)
56+
57+
def test_resolve_passes_validation_when_valid(self) -> None:
58+
opt = ComputationOption(name="limit", type=int, default=200, validate=lambda v: v > 0)
59+
kwargs: dict[str, Any] = {"limit": 100}
60+
assert opt.resolve(kwargs) == 100
61+
62+
def test_resolve_no_validation_when_none(self) -> None:
63+
opt = ComputationOption(name="x", type=float, default=0.0, validate=None)
64+
kwargs: dict[str, Any] = {"x": -999.0}
65+
assert opt.resolve(kwargs) == -999.0
66+
67+
def test_frozen_dataclass(self) -> None:
68+
opt = ComputationOption(name="x", type=float, default=0.0)
69+
with pytest.raises(AttributeError):
70+
opt.name = "y" # type: ignore[misc]

tests/unit/distributions/fitters/test_continuous.py renamed to tests/unit/distributions/computations/test_continuous.py

File renamed without changes.

tests/unit/distributions/fitters/test_discrete.py renamed to tests/unit/distributions/computations/test_discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import pytest
2323
from mypy_extensions import KwArg
2424

25-
from pysatl_core.distributions.computation import AnalyticalComputation
25+
from pysatl_core.distributions.computations.computation import AnalyticalComputation
2626
from pysatl_core.distributions.computations.discrete import (
2727
FITTER_CDF_TO_PMF_1D,
2828
FITTER_CDF_TO_PPF_1D,
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""
2+
Tests for EvaluatorDescriptor.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
__author__ = "Irina Sergeeva"
8+
__copyright__ = "Copyright (c) 2025 PySATL project"
9+
__license__ = "SPDX-License-Identifier: MIT"
10+
11+
from typing import Any
12+
13+
import numpy as np
14+
import pytest
15+
16+
from pysatl_core.distributions.computations.base import (
17+
CharacteristicOption,
18+
ComputationOption,
19+
EvaluatorDescriptor,
20+
)
21+
from pysatl_core.types import CharacteristicName, NumericArray
22+
23+
24+
class TestEvaluatorDescriptor:
25+
"""Tests for the EvaluatorDescriptor dataclass (non-cacheable evaluators)."""
26+
27+
@staticmethod
28+
def _dummy_evaluator(distribution: Any, x: NumericArray, /, **kwargs: Any) -> NumericArray:
29+
return np.zeros_like(np.asarray(x, dtype=float))
30+
31+
def _make_descriptor(self, **overrides: Any) -> EvaluatorDescriptor:
32+
defaults: dict[str, Any] = {
33+
"name": "test_evaluator",
34+
"target": CharacteristicName.PDF,
35+
"sources": [CharacteristicName.CDF],
36+
"evaluator": self._dummy_evaluator,
37+
"characteristic_options": (
38+
CharacteristicOption(
39+
name="tol", type=float, default=1e-8, validate=lambda v: v > 0
40+
),
41+
),
42+
"computation_options": (
43+
ComputationOption(name="max_iter", type=int, default=50, validate=lambda v: v > 0),
44+
),
45+
"constraint_tags": frozenset({"continuous", "univariate"}),
46+
"description": "Test evaluator.",
47+
}
48+
defaults.update(overrides)
49+
return EvaluatorDescriptor(**defaults)
50+
51+
def test_options_property_combines_both_kinds(self) -> None:
52+
desc = self._make_descriptor()
53+
names = tuple(o.name for o in desc.options)
54+
# characteristic options come first
55+
assert names == ("tol", "max_iter")
56+
57+
def test_resolve_characteristic_options_returns_defaults(self) -> None:
58+
desc = self._make_descriptor()
59+
kwargs: dict[str, Any] = {}
60+
opts = desc.resolve_characteristic_options(kwargs)
61+
assert opts == {"tol": 1e-8}
62+
63+
def test_resolve_characteristic_options_does_not_consume_computation_keys(self) -> None:
64+
desc = self._make_descriptor()
65+
kwargs: dict[str, Any] = {"tol": 1e-6, "max_iter": 100}
66+
desc.resolve_characteristic_options(kwargs)
67+
assert "max_iter" in kwargs
68+
assert "tol" not in kwargs
69+
70+
def test_resolve_computation_options_returns_defaults(self) -> None:
71+
desc = self._make_descriptor()
72+
kwargs: dict[str, Any] = {}
73+
opts = desc.resolve_computation_options(kwargs)
74+
assert opts == {"max_iter": 50}
75+
76+
def test_resolve_computation_options_does_not_consume_characteristic_keys(self) -> None:
77+
desc = self._make_descriptor()
78+
kwargs: dict[str, Any] = {"tol": 1e-6, "max_iter": 100}
79+
desc.resolve_computation_options(kwargs)
80+
assert "tol" in kwargs
81+
assert "max_iter" not in kwargs
82+
83+
def test_resolve_options_returns_all_defaults(self) -> None:
84+
desc = self._make_descriptor()
85+
kwargs: dict[str, Any] = {}
86+
opts = desc.resolve_options(kwargs)
87+
assert opts == {"tol": 1e-8, "max_iter": 50}
88+
89+
def test_resolve_options_uses_caller_values(self) -> None:
90+
desc = self._make_descriptor()
91+
kwargs: dict[str, Any] = {"tol": 1e-6, "max_iter": 100}
92+
opts = desc.resolve_options(kwargs)
93+
assert opts["tol"] == pytest.approx(1e-6)
94+
assert opts["max_iter"] == 100
95+
96+
def test_option_names_returns_all(self) -> None:
97+
desc = self._make_descriptor()
98+
assert desc.option_names() == ("tol", "max_iter")
99+
100+
def test_characteristic_option_names(self) -> None:
101+
desc = self._make_descriptor()
102+
assert desc.characteristic_option_names() == ("tol",)
103+
104+
def test_computation_option_names(self) -> None:
105+
desc = self._make_descriptor()
106+
assert desc.computation_option_names() == ("max_iter",)
107+
108+
def test_option_defaults_returns_all(self) -> None:
109+
desc = self._make_descriptor()
110+
assert desc.option_defaults() == {"tol": 1e-8, "max_iter": 50}
111+
112+
def test_frozen_dataclass(self) -> None:
113+
desc = self._make_descriptor()
114+
with pytest.raises(AttributeError):
115+
desc.name = "other" # type: ignore[misc]
116+
117+
def test_to_computation_method_returns_evaluator_method(self) -> None:
118+
from pysatl_core.distributions.computations.computation import EvaluatorMethod
119+
120+
desc = self._make_descriptor()
121+
cm = desc.to_computation_method()
122+
assert isinstance(cm, EvaluatorMethod)
123+
assert cm.evaluator is not None
124+
assert cm.target == CharacteristicName.PDF
125+
assert list(cm.sources) == [CharacteristicName.CDF]
126+
127+
def test_empty_options(self) -> None:
128+
desc = self._make_descriptor(characteristic_options=(), computation_options=())
129+
assert desc.option_names() == ()
130+
assert desc.option_defaults() == {}
131+
assert desc.resolve_options({}) == {}
132+
assert desc.resolve_characteristic_options({}) == {}
133+
assert desc.resolve_computation_options({}) == {}

0 commit comments

Comments
 (0)