Skip to content

Commit 280efce

Browse files
committed
feat(knowledge base): add optimization with knowledge base for ParametricFamilyDistribution
1 parent 6ac34bf commit 280efce

5 files changed

Lines changed: 299 additions & 5 deletions

File tree

src/pysatl_core/families/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
parametrization,
2323
)
2424
from .registry import ParametricFamilyRegister
25+
from .registry_graph import BinaryOperationType
2526

2627
__all__ = [
2728
"ParametricFamilyRegister",
@@ -32,6 +33,7 @@
3233
"constraint",
3334
"parametrization",
3435
"configure_families_register",
36+
"BinaryOperationType",
3537
# builtins
3638
*_builtins_all,
3739
]

src/pysatl_core/families/distribution.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,15 @@
77

88
from __future__ import annotations
99

10+
import dataclasses
11+
12+
from pysatl_core.families.registry_graph import BinaryOperationType
13+
1014
__author__ = "Leonid Elkin, Mikhail Mikhailov"
1115
__copyright__ = "Copyright (c) 2025 PySATL project"
1216
__license__ = "SPDX-License-Identifier: MIT"
1317

18+
from types import NotImplementedType
1419
from typing import TYPE_CHECKING
1520

1621
from pysatl_core.distributions.distribution import _KEEP, Distribution
@@ -193,3 +198,112 @@ def sample(self, n: int, **options: Any) -> NumericArray:
193198
the sampling strategy.
194199
"""
195200
return self.sampling_strategy.sample(n, distr=self, **options)
201+
202+
def _try_to_transform_with_optimization(
203+
self, other: ParametricFamilyDistribution, kind: BinaryOperationType
204+
) -> None | ParametricFamilyDistribution:
205+
registry = ParametricFamilyRegister()
206+
transform_result = registry.find_binary_transformation(
207+
self.family_name, other.family_name, kind
208+
)
209+
if transform_result is None:
210+
return None
211+
212+
family, transform_parametrization = transform_result
213+
new_parametrization = transform_parametrization(
214+
self.parametrization.transform_to_base_parametrization(),
215+
other.parametrization.transform_to_base_parametrization(),
216+
)
217+
return family(new_parametrization.name, **dataclasses.asdict(new_parametrization)) # type:ignore[call-overload]
218+
219+
def __add__(
220+
self, other: object
221+
) -> ParametricFamilyDistribution | Distribution | NotImplementedType:
222+
"""Return ``self + other`` for scalar or distribution operands."""
223+
if isinstance(other, ParametricFamilyDistribution):
224+
transformation_result = self._try_to_transform_with_optimization(
225+
other, BinaryOperationType.ADD
226+
)
227+
if transformation_result is not None:
228+
return transformation_result
229+
230+
return TransformationOperatorsMixin.__add__(self, other)
231+
232+
def __radd__(
233+
self, other: object
234+
) -> ParametricFamilyDistribution | Distribution | NotImplementedType:
235+
"""Return ``other + self`` for scalar or distribution operands."""
236+
if isinstance(other, ParametricFamilyDistribution):
237+
transformation_result = other._try_to_transform_with_optimization(
238+
self, BinaryOperationType.ADD
239+
)
240+
if transformation_result is not None:
241+
return transformation_result
242+
243+
return TransformationOperatorsMixin.__radd__(self, other)
244+
245+
def __sub__(self, other: object) -> Distribution | NotImplementedType:
246+
"""Return ``self - other`` for scalar or distribution operands."""
247+
if isinstance(other, ParametricFamilyDistribution):
248+
transformation_result = self._try_to_transform_with_optimization(
249+
other, BinaryOperationType.SUB
250+
)
251+
if transformation_result is not None:
252+
return transformation_result
253+
254+
return TransformationOperatorsMixin.__sub__(self, other)
255+
256+
def __rsub__(self, other: object) -> Distribution | NotImplementedType:
257+
"""Return ``other - self`` for scalar or distribution operands."""
258+
if isinstance(other, ParametricFamilyDistribution):
259+
transformation_result = other._try_to_transform_with_optimization(
260+
self, BinaryOperationType.SUB
261+
)
262+
if transformation_result is not None:
263+
return transformation_result
264+
265+
return TransformationOperatorsMixin.__rsub__(self, other)
266+
267+
def __mul__(self, other: object) -> Distribution | NotImplementedType:
268+
"""Return ``self * other`` for scalar or distribution operands."""
269+
if isinstance(other, ParametricFamilyDistribution):
270+
transformation_result = self._try_to_transform_with_optimization(
271+
other, BinaryOperationType.MUL
272+
)
273+
if transformation_result is not None:
274+
return transformation_result
275+
276+
return TransformationOperatorsMixin.__mul__(self, other)
277+
278+
def __rmul__(self, other: object) -> Distribution | NotImplementedType:
279+
"""Return ``other * self`` for scalar or distribution operands."""
280+
if isinstance(other, ParametricFamilyDistribution):
281+
transformation_result = other._try_to_transform_with_optimization(
282+
self, BinaryOperationType.MUL
283+
)
284+
if transformation_result is not None:
285+
return transformation_result
286+
287+
return TransformationOperatorsMixin.__rmul__(self, other)
288+
289+
def __truediv__(self, other: object) -> Distribution | NotImplementedType:
290+
"""Return ``self / other`` for scalar or distribution operands."""
291+
if isinstance(other, ParametricFamilyDistribution):
292+
transformation_result = self._try_to_transform_with_optimization(
293+
other, BinaryOperationType.DIV
294+
)
295+
if transformation_result is not None:
296+
return transformation_result
297+
298+
return TransformationOperatorsMixin.__truediv__(self, other)
299+
300+
def __rtruediv__(self, other: object) -> Distribution | NotImplementedType:
301+
"""Return ``other / self`` for distribution operands."""
302+
if isinstance(other, ParametricFamilyDistribution):
303+
transformation_result = other._try_to_transform_with_optimization(
304+
self, BinaryOperationType.DIV
305+
)
306+
if transformation_result is not None:
307+
return transformation_result
308+
309+
return TransformationOperatorsMixin.__rtruediv__(self, other)

src/pysatl_core/families/registry.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
from __future__ import annotations
1010

1111
from pysatl_core.families.parametrizations import Parametrization
12-
from pysatl_core.families.registry_graph import RegistryGraphTransformations
12+
from pysatl_core.families.registry_graph import (
13+
BinaryOperationType,
14+
RegistryGraphTransformations,
15+
)
1316

1417
__author__ = "Leonid Elkin, Mikhail Mikhailov, Fedor Myznikov"
1518
__copyright__ = "Copyright (c) 2025 PySATL project"
@@ -69,6 +72,41 @@ def get_optimal_density(
6972
family_name, transform_function = self._registry_graph.get_optimal_transoformation(name)
7073
return self.get(family_name), transform_function
7174

75+
@classmethod
76+
def add_binary_transformation(
77+
cls,
78+
left: str,
79+
right: str,
80+
result: str,
81+
operation: BinaryOperationType,
82+
parametrization_transformation: Callable[
83+
[Parametrization, Parametrization], Parametrization
84+
],
85+
) -> bool:
86+
self = cls()
87+
return self._registry_graph.add_binary_transformation(
88+
left, right, result, operation, parametrization_transformation
89+
)
90+
91+
@classmethod
92+
def find_binary_transformation(
93+
cls, left: str, right: str, operation: BinaryOperationType
94+
) -> (
95+
None
96+
| tuple[ParametricFamily, Callable[[Parametrization, Parametrization], Parametrization]]
97+
):
98+
self = cls()
99+
transformation = self._registry_graph.find_binary_transform(left, right, operation)
100+
101+
if transformation is None:
102+
return transformation
103+
104+
family = self._registered_families.get(transformation.result, None)
105+
if family is None:
106+
return family
107+
108+
return family, transformation.transformation
109+
72110
@classmethod
73111
def register_parametrization_transformation(
74112
cls,
@@ -145,14 +183,14 @@ def register(cls, family: ParametricFamily, temperature: int = 128) -> None:
145183
If a family with the same name is already registered.
146184
"""
147185
self = cls()
148-
self.change_family_temperature(family.name, temperature)
186+
self._change_family_temperature(family.name, temperature)
149187

150188
if family.name in self._registered_families:
151189
raise ValueError(f"Family {family.name} already found in register")
152190
self._registered_families[family.name] = family
153191

154192
@classmethod
155-
def change_family_temperature(cls, family_name: str, new_temperature: int) -> None:
193+
def _change_family_temperature(cls, family_name: str, new_temperature: int) -> None:
156194
self = cls()
157195
self._registry_graph.register_family_temperature(family_name, new_temperature)
158196

src/pysatl_core/families/registry_graph.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable
4+
from enum import StrEnum
45
from queue import Queue
56
from typing import TYPE_CHECKING, cast
67

@@ -60,20 +61,122 @@ def transform_density(self, argument: Number | NumericArray) -> Number | Numeric
6061
return self._transform_function(argument)
6162

6263

64+
class BinaryOperationType(StrEnum):
65+
ADD = "add"
66+
SUB = "sub"
67+
MUL = "multiply"
68+
DIV = "divide"
69+
70+
71+
class FamilyBinaryOperationTransformationRecord:
72+
"""
73+
Record for a bainray transformation between 2 families
74+
It doesn't accept constraints on transformation, just for now
75+
"""
76+
77+
def __init__(
78+
self,
79+
left: str,
80+
right: str,
81+
result: str,
82+
kind: BinaryOperationType,
83+
parametrization_transformation: Callable[
84+
[Parametrization, Parametrization], Parametrization
85+
],
86+
):
87+
self._left = left
88+
self._right = right
89+
self._kind = kind
90+
self._result = result
91+
self._transformation = parametrization_transformation
92+
93+
def accepts(self, left: str, right: str, kind: BinaryOperationType) -> bool:
94+
return (self._left, self._right, self._kind) == (left, right, kind)
95+
96+
def transform(
97+
self, left_parametrization: Parametrization, right_parametrization: Parametrization
98+
) -> tuple[str, Parametrization]:
99+
return self._result, self._transformation(left_parametrization, right_parametrization)
100+
101+
def __eq__(self, other: object) -> bool:
102+
if not isinstance(other, FamilyBinaryOperationTransformationRecord):
103+
return False
104+
105+
return self.accepts(other._left, other._right, other._kind)
106+
107+
@property
108+
def result(self) -> str:
109+
return self._result
110+
111+
@property
112+
def transformation(self) -> Callable[[Parametrization, Parametrization], Parametrization]:
113+
return self._transformation
114+
115+
def __hash__(self) -> int:
116+
return (
117+
self._left.__hash__()
118+
^ self._right.__hash__()
119+
^ self._kind.__hash__()
120+
^ self._result.__hash__()
121+
)
122+
123+
63124
class RegistryGraphTransformations:
125+
"""
126+
Registry with transformations, such as transformation, when densitry equals or smth like this
127+
TODO: find a way to merge ways between density and parametrization transformations
128+
"""
129+
64130
_instance: ClassVar[RegistryGraphTransformations | None] = None
65131
_registered_families_temperature: dict[str, int]
66132
_registered_parametrzation_transformations: dict[str, list[TransformatedParametrizationEdge]]
67133
_registered_transformations: dict[str, list[TransformatedDensityEdge]]
134+
_binary_transformations: list[FamilyBinaryOperationTransformationRecord]
68135

69136
def __new__(cls) -> RegistryGraphTransformations:
70137
if cls._instance is None:
71138
cls._instance = super().__new__(cls)
72139
cls._registered_parametrzation_transformations = {}
73140
cls._registered_families_temperature = {}
74141
cls._registered_transformations = {}
142+
cls._binary_transformations = []
75143
return cls._instance
76144

145+
@classmethod
146+
def find_binary_transform(
147+
cls, left: str, right: str, operation: BinaryOperationType
148+
) -> None | FamilyBinaryOperationTransformationRecord:
149+
self = cls()
150+
151+
for transformation in self._binary_transformations:
152+
if transformation.accepts(left, right, operation):
153+
return transformation
154+
155+
return None
156+
157+
@classmethod
158+
def add_binary_transformation(
159+
cls,
160+
left: str,
161+
right: str,
162+
result: str,
163+
operation: BinaryOperationType,
164+
parametrization_transformation: Callable[
165+
[Parametrization, Parametrization], Parametrization
166+
],
167+
) -> bool:
168+
transformation = FamilyBinaryOperationTransformationRecord(
169+
left, right, result, operation, parametrization_transformation
170+
)
171+
self = cls()
172+
173+
for registered_transformation in self._binary_transformations:
174+
if transformation == registered_transformation:
175+
return False
176+
177+
self._binary_transformations.append(transformation)
178+
return True
179+
77180
@classmethod
78181
def _run_bfs(
79182
cls,

0 commit comments

Comments
 (0)