Skip to content

Commit d31fc54

Browse files
author
domosedy
committed
feat(knowledge base): add prototype of knowledge graph
1 parent 937137c commit d31fc54

4 files changed

Lines changed: 116 additions & 0 deletions

File tree

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from collections.abc import Callable
2+
3+
from pysatl_core.families.parametrizations import Parametrization
4+
5+
6+
class EdgeWithFixedParametrization:
7+
def __init__(
8+
self,
9+
tail_name: str,
10+
transform_constraint: Callable[[Parametrization], bool],
11+
transform_function: Callable[[Parametrization], Parametrization],
12+
):
13+
self._transform_function = transform_function
14+
self._transform_constraint = transform_constraint
15+
self._tail_name = tail_name
16+
17+
def is_transoform_possible(self, parametrization: Parametrization) -> bool:
18+
return self._transform_constraint(parametrization)
19+
20+
def transform_parametrization(self, parametrization: Parametrization) -> Parametrization:
21+
return self._transform_function(parametrization)
22+
23+
@property
24+
def tail_name(self) -> str:
25+
return self._tail_name

src/pysatl_core/families/parametric_family.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from pysatl_core.distributions.computation import AnalyticalComputation
2323
from pysatl_core.families.distribution import ParametricFamilyDistribution
24+
from pysatl_core.families.registry import ParametricFamilyRegister
2425
from pysatl_core.types import (
2526
DEFAULT_ANALYTICAL_COMPUTATION_LABEL,
2627
ComputationFunc,
@@ -386,6 +387,19 @@ def distribution(
386387
parameters = parametrization_class(**parameters_values)
387388
parameters.validate()
388389
base_parameters = self.to_base(parameters)
390+
registry = ParametricFamilyRegister()
391+
optimized_family = registry.get_optimal_family(self.name, parameters)
392+
if optimized_family is not None:
393+
edge, new_family = optimized_family
394+
new_parametrization = edge.transform_parametrization(parameters)
395+
396+
return new_family.distribution(
397+
parametrization_name=None,
398+
sampling_strategy=sampling_strategy,
399+
computation_strategy=computation_strategy,
400+
**new_parametrization.parameters,
401+
)
402+
389403
distribution_type = self._distr_type(base_parameters)
390404
analytical_computations = self._build_analytical_computations(parameters)
391405
return ParametricFamilyDistribution(

src/pysatl_core/families/registry.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@
88

99
from __future__ import annotations
1010

11+
from pysatl_core.families.fixed_parametrization_edge import EdgeWithFixedParametrization
12+
from pysatl_core.families.parametrizations import Parametrization
13+
1114
__author__ = "Leonid Elkin, Mikhail Mikhailov, Fedor Myznikov"
1215
__copyright__ = "Copyright (c) 2025 PySATL project"
1316
__license__ = "SPDX-License-Identifier: MIT"
1417

18+
from collections.abc import Callable
1519
from typing import TYPE_CHECKING
1620

1721
if TYPE_CHECKING:
@@ -30,14 +34,49 @@ class ParametricFamilyRegister:
3034

3135
_instance: ClassVar[ParametricFamilyRegister | None] = None
3236
_registered_families: dict[str, ParametricFamily]
37+
_registered_transformations: dict[str, list[EdgeWithFixedParametrization]]
3338

3439
def __new__(cls) -> ParametricFamilyRegister:
3540
"""Create or return the singleton instance."""
3641
if cls._instance is None:
3742
cls._instance = super().__new__(cls)
3843
cls._instance._registered_families = {}
44+
cls._instance._registered_transformations = {}
3945
return cls._instance
4046

47+
@classmethod
48+
def get_optimal_family(
49+
cls, name: str, parametrization: Parametrization
50+
) -> None | tuple[EdgeWithFixedParametrization, ParametricFamily]:
51+
self = cls()
52+
53+
for optimization_edge in self._registered_transformations.get(name, []):
54+
if (
55+
optimization_edge.tail_name in self._registered_families
56+
and optimization_edge.is_transoform_possible(parametrization)
57+
):
58+
return optimization_edge, self._registered_families[optimization_edge.tail_name]
59+
60+
return None
61+
62+
@classmethod
63+
def add_optimization_edge(
64+
cls,
65+
head_name: str,
66+
tail_name: str,
67+
transform_constraint: Callable[[Parametrization], bool],
68+
transform_function: Callable[[Parametrization], Parametrization],
69+
) -> bool:
70+
self = cls()
71+
72+
if head_name in self._registered_families:
73+
self._registered_transformations.setdefault(head_name, []).append(
74+
EdgeWithFixedParametrization(tail_name, transform_constraint, transform_function)
75+
)
76+
return True
77+
78+
return False
79+
4180
@classmethod
4281
def get(cls, name: str) -> ParametricFamily:
4382
"""
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from pysatl_core.families.configuration import configure_families_register
2+
from pysatl_core.families.parametrizations import Parametrization
3+
from pysatl_core.types import CharacteristicName, FamilyName
4+
5+
6+
def test_optimization():
7+
registry = configure_families_register()
8+
exponential_fam = registry.get(FamilyName.EXPONENTIAL)
9+
uniform_fam = registry.get(FamilyName.CONTINUOUS_UNIFORM)
10+
11+
def transform_function(head_param: Parametrization) -> Parametrization:
12+
tail_param_type = uniform_fam.get_parametrization(uniform_fam.base_parametrization_name)
13+
14+
return tail_param_type(lower_bound=0, upper_bound=1) # type: ignore[call-arg]
15+
16+
def transform_constraint(head_param: Parametrization) -> bool:
17+
exponential_fam.get_parametrization(exponential_fam.base_parametrization_name)
18+
19+
return head_param.lambda_ == 1.0 # type: ignore[attr-defined]
20+
21+
registry.add_optimization_edge(
22+
FamilyName.EXPONENTIAL,
23+
FamilyName.CONTINUOUS_UNIFORM,
24+
transform_constraint,
25+
transform_function,
26+
)
27+
28+
exponential = exponential_fam(lambda_=1.0)
29+
exponential_parametrization = exponential.parametrization
30+
31+
assert exponential.family_name == FamilyName.CONTINUOUS_UNIFORM
32+
assert exponential_parametrization.lower_bound == 0 # type: ignore[attr-defined]
33+
assert exponential_parametrization.upper_bound == 1 # type: ignore[attr-defined]
34+
35+
pdf = exponential.query_method(CharacteristicName.PDF)
36+
37+
assert pdf(0.5) == 1
38+
assert pdf(10) == 0

0 commit comments

Comments
 (0)