Skip to content

Commit 7d315a6

Browse files
committed
feat(transformations): transformations base primitives
1 parent 1581b1c commit 7d315a6

26 files changed

Lines changed: 2564 additions & 27 deletions

src/pysatl_core/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
from importlib.metadata import version
1212

13+
# isort: off
14+
from .transformations import *
15+
from .transformations import __all__ as _transformations_all
1316
from .distributions import *
1417
from .distributions import __all__ as _distr_all
1518
from .families import *
@@ -18,6 +21,7 @@
1821
from .sampling import __all__ as _sampling_all
1922
from .types import *
2023
from .types import __all__ as _types_all
24+
# isort: on
2125

2226
__version__ = version("pysatl-core")
2327
__all__ = [
@@ -26,9 +30,11 @@
2630
*_family_all,
2731
*_types_all,
2832
*_sampling_all,
33+
*_transformations_all,
2934
]
3035

3136
del _distr_all
3237
del _family_all
3338
del _types_all
39+
del _transformations_all
3440
del _sampling_all

src/pysatl_core/distributions/distribution.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ComputationStrategy,
2121
SamplingStrategy,
2222
)
23+
from pysatl_core.transformations.operators_mixin import TransformationOperatorsMixin
2324
from pysatl_core.types import DEFAULT_ANALYTICAL_COMPUTATION_LABEL, NumericArray
2425

2526
_KEEP: object = object()
@@ -38,7 +39,7 @@
3839
)
3940

4041

41-
class Distribution(ABC):
42+
class Distribution(TransformationOperatorsMixin, ABC):
4243
"""
4344
Protocol defining the interface for probability distributions.
4445
@@ -57,7 +58,10 @@ class Distribution(ABC):
5758
| Mapping[LabelName, AnalyticalComputation[Any, Any]]
5859
),
5960
]
60-
Direct analytical computations provided by the distribution.
61+
Distribution-provided characteristic methods.
62+
For non-transformed distributions every method in this mapping is
63+
fully analytical, so this mapping matches the set of loops with
64+
``is_analytical=True`` in the graph view.
6165
sampling_strategy : SamplingStrategy
6266
Strategy for generating random samples.
6367
computation_strategy : ComputationStrategy
@@ -92,7 +96,9 @@ def __init__(
9296
| Mapping[LabelName, AnalyticalComputation[Any, Any]]
9397
),
9498
]
95-
Analytical computations provided by the distribution.
99+
Distribution-provided characteristic methods.
100+
For non-transformed distributions these methods are fully
101+
analytical.
96102
support : Support or None, default=None
97103
Support of the distribution.
98104
sampling_strategy : SamplingStrategy or None, default=None
@@ -141,9 +147,45 @@ def distribution_type(self) -> DistributionType:
141147
def analytical_computations(
142148
self,
143149
) -> Mapping[GenericCharacteristicName, Mapping[LabelName, AnalyticalComputation[Any, Any]]]:
144-
"""Return analytical computations provided directly by this distribution."""
150+
"""
151+
Return distribution-provided characteristic methods.
152+
153+
For non-transformed distributions this mapping coincides with
154+
graph loops marked as ``is_analytical=True``.
155+
"""
145156
return self._analytical_computations
146157

158+
def loop_is_analytical(
159+
self,
160+
characteristic_name: GenericCharacteristicName,
161+
label_name: LabelName,
162+
) -> bool:
163+
"""
164+
Tell whether a self-loop method is fully analytical in the graph.
165+
166+
Parameters
167+
----------
168+
characteristic_name : GenericCharacteristicName
169+
Characteristic name of the self-loop.
170+
label_name : LabelName
171+
Label of the analytical computation variant.
172+
173+
Returns
174+
-------
175+
bool
176+
``True`` when every required predecessor in the transformation
177+
chain is analytical.
178+
179+
Notes
180+
-----
181+
Presence in ``analytical_computations`` means that a characteristic
182+
has at least one analytical ancestor in its derivation chain.
183+
For non-transformed distributions these notions coincide, therefore
184+
this method always returns ``True``.
185+
"""
186+
_ = characteristic_name, label_name
187+
return True
188+
147189
@property
148190
def sampling_strategy(self) -> SamplingStrategy:
149191
"""Return the currently attached sampling strategy."""

src/pysatl_core/distributions/registry/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
ComputationEdgeMeta,
3030
EdgeMeta,
3131
GraphInvariantError,
32+
TransformationLoopEdgeMeta,
3233
)
3334

3435
__all__ = [
@@ -37,6 +38,7 @@
3738
"EdgeMeta",
3839
"ComputationEdgeMeta",
3940
"AnalyticalLoopEdgeMeta",
41+
"TransformationLoopEdgeMeta",
4042
"GraphInvariantError",
4143
# Constraint types
4244
"Constraint",

src/pysatl_core/distributions/registry/graph.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ComputationEdgeMeta,
3434
EdgeMeta,
3535
GraphInvariantError,
36+
TransformationLoopEdgeMeta,
3637
)
3738

3839
if TYPE_CHECKING:
@@ -304,6 +305,15 @@ def _compute_definitive_nodes(self, distr: Distribution) -> set[GenericCharacter
304305
definitive.add(name)
305306
return definitive
306307

308+
@staticmethod
309+
def _loop_is_analytical(
310+
distr: Distribution,
311+
characteristic_name: GenericCharacteristicName,
312+
label_name: LabelName,
313+
) -> bool:
314+
"""Resolve loop analytical flag for a distribution-provided method."""
315+
return distr.loop_is_analytical(characteristic_name, label_name)
316+
307317
@staticmethod
308318
def _attach_analytical_loops(
309319
adj: dict[
@@ -314,12 +324,14 @@ def _attach_analytical_loops(
314324
present_nodes: set[GenericCharacteristicName],
315325
) -> None:
316326
"""
317-
Attach analytical self-loops for distribution-provided computations.
327+
Attach distribution-provided self-loops to the view graph.
318328
319329
Notes
320330
-----
321-
Analytical loops are only added for characteristics present in this view.
322-
Each labeled analytical computation becomes one loop edge ``char -> char``.
331+
Loops are only added for characteristics present in this view.
332+
Each labeled computation in ``analytical_computations`` becomes one
333+
loop edge ``char -> char``. The loop class is selected via
334+
``distr.loop_is_analytical(...)``.
323335
"""
324336
for characteristic_name, labeled_methods in distr.analytical_computations.items():
325337
if characteristic_name not in present_nodes:
@@ -329,7 +341,15 @@ def _attach_analytical_loops(
329341
characteristic_name, {}
330342
)
331343
for label_name, analytical_method in labeled_methods.items():
332-
loop_variants[label_name] = AnalyticalLoopEdgeMeta(method=analytical_method)
344+
loop_variants[label_name] = (
345+
AnalyticalLoopEdgeMeta(method=analytical_method)
346+
if CharacteristicRegistry._loop_is_analytical(
347+
distr,
348+
characteristic_name,
349+
label_name,
350+
)
351+
else TransformationLoopEdgeMeta(method=analytical_method)
352+
)
333353

334354
def view(self, distr: Distribution) -> RegistryView:
335355
"""

src/pysatl_core/distributions/registry/graph_primitives.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,23 @@ def edge_kind(self) -> str:
7777
return "analytical_loop"
7878

7979

80+
@dataclass(frozen=True, slots=True)
81+
class TransformationLoopEdgeMeta(EdgeMeta):
82+
"""
83+
Edge metadata for transformation-provided self-loops.
84+
85+
Such loops are attached from ``analytical_computations`` as regular
86+
stopping points for the strategy, but they are not considered fully
87+
analytical by the graph semantics.
88+
"""
89+
90+
method: AnalyticalComputation[Any, Any]
91+
is_analytical: bool = field(default=False)
92+
93+
def edge_kind(self) -> str:
94+
return "transformation_loop"
95+
96+
8097
class GraphInvariantError(RuntimeError):
8198
"""
8299
Raised when characteristic graph invariants are violated.

src/pysatl_core/distributions/strategies.py

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
__copyright__ = "Copyright (c) 2025 PySATL project"
1212
__license__ = "SPDX-License-Identifier: MIT"
1313

14-
from typing import TYPE_CHECKING, Protocol, cast
14+
from typing import TYPE_CHECKING, Any, Protocol, cast
1515

1616
import numpy as np
1717

@@ -20,14 +20,13 @@
2020

2121
if TYPE_CHECKING:
2222
from collections.abc import Mapping
23-
from typing import Any
2423

2524
from pysatl_core.distributions.computation import (
2625
AnalyticalComputation,
2726
FittedComputationMethod,
28-
Method,
2927
)
3028
from pysatl_core.distributions.distribution import Distribution
29+
from pysatl_core.distributions.registry.graph import RegistryView
3130
from pysatl_core.types import GenericCharacteristicName, LabelName
3231

3332

@@ -126,16 +125,30 @@ def _pick_analytical_method(
126125
f"Characteristic '{state}' provides no labeled analytical computations."
127126
) from exc
128127

128+
@staticmethod
129+
def _pick_loop_method(
130+
state: GenericCharacteristicName,
131+
view: RegistryView,
132+
) -> Method[Any, Any] | None:
133+
"""
134+
Pick the first available self-loop method for a characteristic in a view.
135+
"""
136+
loops = view.variants(state, state)
137+
if not loops:
138+
return None
139+
return cast(Method[Any, Any], next(iter(loops.values())).method)
140+
129141
def query_method(
130142
self, state: GenericCharacteristicName, distr: Distribution, **options: Any
131143
) -> Method[Any, Any]:
132144
"""
133145
Resolve a computation method for the target characteristic.
134146
135147
Resolution order:
136-
1. Analytical implementation from the distribution
137-
2. Cached fitted method (if caching enabled)
138-
3. Conversion path from an analytical characteristic via the graph
148+
1. Cached fitted method (if caching enabled)
149+
2. Analytical implementation for non-registry characteristics
150+
3. First self-loop from the registry view
151+
4. Conversion path from loop characteristics via the graph
139152
140153
Parameters
141154
----------
@@ -157,34 +170,46 @@ def query_method(
157170
If no analytical base exists, no conversion path is found,
158171
or a cycle is detected.
159172
"""
160-
# 1. Check for analytical implementation
161-
if state in distr.analytical_computations:
162-
return self._pick_analytical_method(state, distr.analytical_computations[state])
163-
164-
# 2. Check cache if enabled
173+
# 1. Check cache if enabled
165174
if self._enable_caching:
166175
cached = self._cache.get(state)
167176
if cached is not None:
168177
return cached
169178

170-
# 3. Require at least one analytical characteristic
179+
# 2. Require at least one analytical characteristic
171180
if not distr.analytical_computations:
172181
raise RuntimeError(
173182
"Distribution provides no analytical computations to ground conversions."
174183
)
175184

176-
# 4. Get filtered graph view for this distribution
177-
reg = characteristic_registry().view(distr)
185+
# 3. Non-registry characteristics are resolved directly.
186+
# It covers the situation where user is providing their analytical computation which isn't
187+
# in the graph
188+
registry = characteristic_registry()
189+
if state not in registry.declared_characteristics:
190+
if state in distr.analytical_computations:
191+
return self._pick_analytical_method(state, distr.analytical_computations[state])
192+
raise RuntimeError(
193+
f"Characteristic '{state}' is not declared in the registry and has no "
194+
"analytical implementation in the distribution."
195+
)
196+
197+
# 4. Get filtered graph view for this distribution.
198+
view = registry.view(distr)
178199

179200
self._push_guard(distr, state)
180201
try:
181-
# 5. Try each analytical characteristic as a source
202+
loop_method = self._pick_loop_method(state, view)
203+
if loop_method is not None:
204+
return loop_method
205+
206+
# 5. Try each loop characteristic as a source
182207
for src in distr.analytical_computations:
183-
if src == state:
184-
return self._pick_analytical_method(src, distr.analytical_computations[src])
208+
if not view.variants(src, src):
209+
continue
185210

186211
# Find conversion path in the graph
187-
path = reg.find_path(src, state)
212+
path = view.find_path(src, state)
188213
if not path:
189214
continue
190215

@@ -201,7 +226,8 @@ def query_method(
201226
return last_fitted
202227

203228
raise RuntimeError(
204-
f"No conversion path from any analytical characteristic to '{state}'."
229+
"No conversion path from any characteristic in "
230+
f"analytical_computations to '{state}'."
205231
)
206232
finally:
207233
self._pop_guard(distr, state)

0 commit comments

Comments
 (0)