Skip to content

Commit 254360a

Browse files
committed
refactor(families): preserve partial family support and constraints
1 parent 0e0ff57 commit 254360a

2 files changed

Lines changed: 94 additions & 33 deletions

File tree

src/pysatl_core/families/parametric_family.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from pysatl_core.distributions.computations.computation import AnalyticalComputation
2424
from pysatl_core.families.distribution import ParametricFamilyDistribution
25-
from pysatl_core.families.parametrizations import Parametrization
25+
from pysatl_core.families.parametrizations import Parametrization, ParametrizationConstraint
2626
from pysatl_core.types import (
2727
DEFAULT_ANALYTICAL_COMPUTATION_LABEL,
2828
ComputationFunc,
@@ -546,36 +546,48 @@ def __init__(
546546
"Use `.distribution()` directly."
547547
)
548548

549-
# Generate lightweight parametrization with only free fields
550-
self._free_param_class = self._create_free_param_class()
551-
# Assign __param_name__ and __family__ so that instances have .name and .family
552-
self._free_param_class.__param_name__ = self._fixed_in_param
553-
self._free_param_class.__family__ = self
554-
555549
self._free_parameter_names = tuple(
556550
name
557551
for name in getattr(self._param_class, "__dataclass_fields__", {})
558552
if name not in self._fixed_params
559553
)
560554

555+
# Generate lightweight parametrization with only free fields
556+
self._free_param_class = self._create_free_param_class()
557+
# Assign __param_name__ and __family__ so that instances have .name and .family
558+
self._free_param_class.__param_name__ = self._fixed_in_param
559+
self._free_param_class.__family__ = self
560+
561561
def _view_distr_type(params: Parametrization) -> DistributionType:
562562
canonical = base_family.to_base(params)
563563
return base_family._distr_type(canonical)
564564

565+
def _view_support(params: Parametrization) -> Support | None:
566+
full_params = self._to_full_parametrization(params)
567+
return base_family.support_resolver(full_params)
568+
565569
view_chars = self._build_view_characteristics(base_family)
566570

567571
super().__init__(
568572
name=base_family._name,
569573
distr_type=_view_distr_type,
570574
distr_parametrizations=[self._fixed_in_param],
571575
distr_characteristics=view_chars,
572-
support_by_parametrization=base_family._support_resolver,
576+
support_by_parametrization=_view_support,
573577
base_score=base_family._base_score,
574578
)
575579

576580
# Register the parametrization (needed for parent methods)
577581
self.register_parametrization(self._fixed_in_param, self._free_param_class)
578582

583+
def _to_full_parametrization(self, params: Parametrization) -> Parametrization:
584+
"""Reconstruct the original parametrization by injecting fixed parameters."""
585+
combined = {
586+
**self._fixed_params,
587+
**{name: getattr(params, name) for name in self._free_parameter_names},
588+
}
589+
return self._param_class(**combined)
590+
579591
def _create_free_param_class(self) -> type[Parametrization]:
580592
"""Create a parametrization class containing only the free (unfixed) parameters.
581593
@@ -597,10 +609,10 @@ def _create_free_param_class(self) -> type[Parametrization]:
597609
type[Parametrization]
598610
A lightweight parametrization class with only the unfixed fields.
599611
"""
600-
fixed_params = self._fixed_params
601612
original_class = self._param_class
602613
all_fields = getattr(original_class, "__dataclass_fields__", {})
603-
free_field_names = [name for name in all_fields if name not in fixed_params]
614+
free_field_names = list(self._free_parameter_names)
615+
partial_family = self
604616

605617
def __init__(self: Parametrization, **kwargs: Any) -> None:
606618
unexpected = set(kwargs) - set(free_field_names)
@@ -620,20 +632,12 @@ def __init__(self: Parametrization, **kwargs: Any) -> None:
620632

621633
def transform_to_base(self: Parametrization) -> Parametrization:
622634
"""Substitute fixed values and delegate to the original parametrization."""
623-
combined = {
624-
**fixed_params,
625-
**{f: getattr(self, f) for f in free_field_names},
626-
}
627-
original_instance = original_class(**combined)
628-
return original_instance.transform_to_base_parametrization()
635+
full_params = partial_family._to_full_parametrization(self)
636+
return full_params.transform_to_base_parametrization()
629637

630638
def validate(self: Parametrization) -> None:
631639
"""Validate by combining fixed and free parameters, then delegating."""
632-
combined = {
633-
**fixed_params,
634-
**{f: getattr(self, f) for f in free_field_names},
635-
}
636-
original_class(**combined).validate()
640+
partial_family._to_full_parametrization(self).validate()
637641

638642
def gradient_transform(self: Parametrization, base_grad: NumericArray) -> NumericArray:
639643
"""Map a gradient from the base parametrization to free-parameter space.
@@ -642,16 +646,28 @@ def gradient_transform(self: Parametrization, base_grad: NumericArray) -> Numeri
642646
parametrization. Components that correspond to fixed parameters are
643647
then discarded, keeping only the directions of the free parameters.
644648
"""
645-
combined = {
646-
**fixed_params,
647-
**{f: getattr(self, f) for f in free_field_names},
648-
}
649-
full_instance = original_class(**combined)
649+
full_instance = partial_family._to_full_parametrization(self)
650650
full_grad = full_instance.gradient_transform(base_grad)
651651
all_field_names = list(all_fields.keys())
652652
free_indices = [i for i, name in enumerate(all_field_names) if name in free_field_names]
653653
return full_grad[..., free_indices]
654654

655+
def adapt_constraint(
656+
original_constraint: ParametrizationConstraint,
657+
) -> ParametrizationConstraint:
658+
def check(params: Parametrization) -> bool:
659+
return original_constraint.check(partial_family._to_full_parametrization(params))
660+
661+
return ParametrizationConstraint(
662+
description=original_constraint.description,
663+
check=check,
664+
)
665+
666+
adapted_constraints = [
667+
adapt_constraint(constraint)
668+
for constraint in getattr(original_class, "_constraints", [])
669+
]
670+
655671
new_class = type(
656672
f"{original_class.__name__}Free",
657673
(Parametrization,),
@@ -662,6 +678,7 @@ def gradient_transform(self: Parametrization, base_grad: NumericArray) -> Numeri
662678
"gradient_transform": gradient_transform,
663679
"__dataclass_fields__": {name: all_fields[name] for name in free_field_names},
664680
"__annotations__": {name: all_fields[name].type for name in free_field_names},
681+
"_constraints": adapted_constraints,
665682
},
666683
)
667684
return new_class
@@ -728,7 +745,7 @@ def to_base(self, parameters: Parametrization) -> Parametrization:
728745
The view's own base is the lightweight class, but the true base is the
729746
original family's base. We always transform through the full parametrization.
730747
"""
731-
return parameters.transform_to_base_parametrization()
748+
return self._base_family.to_base(self._to_full_parametrization(parameters))
732749

733750
def _build_view_characteristics(self, base_family: ParametricFamily) -> dict[str, Any]:
734751
view_chars = {}
@@ -748,11 +765,7 @@ def wrap_fixed_provider(
748765
provider: ParametricFamilyCharacteristic[Any, Any],
749766
) -> ParametricFamilyCharacteristic[Any, Any]:
750767
def wrapped(params: Parametrization, *args: Any, **kwargs: Any) -> Any:
751-
combined = {
752-
**self._fixed_params,
753-
**{f: getattr(params, f) for f in self.free_parameter_names},
754-
}
755-
full_params = self._param_class(**combined)
768+
full_params = self._to_full_parametrization(params)
756769
bound = ParametricFamily._bind_parametrization(provider, full_params)
757770
return bound(*args, **kwargs)
758771

tests/unit/families/test_partial_family.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
import pytest
1212

1313
from pysatl_core.distributions.strategies import DefaultComputationStrategy
14+
from pysatl_core.distributions.support import ContinuousSupport
1415
from pysatl_core.families.parametric_family import ParametricFamily, PartialParametricFamily
15-
from pysatl_core.families.parametrizations import Parametrization
16+
from pysatl_core.families.parametrizations import Parametrization, constraint
1617
from pysatl_core.sampling.default import DefaultSamplingUnivariateStrategy
1718
from pysatl_core.types import (
1819
CharacteristicName,
@@ -176,6 +177,53 @@ def test_distribution_passes_sampling_and_computation_strategies(self) -> None:
176177
assert dist.sampling_strategy is sampling
177178
assert dist.computation_strategy is computation
178179

180+
def test_support_resolver_receives_full_original_parametrization(self) -> None:
181+
def support(params: Parametrization) -> ContinuousSupport:
182+
full_params = cast(TwoParam, params)
183+
return ContinuousSupport(full_params.a, full_params.b)
184+
185+
fam = ParametricFamily(
186+
name="SupportDependsOnParams",
187+
distr_type=UnivariateContinuous,
188+
distr_parametrizations=["base"],
189+
distr_characteristics={
190+
CharacteristicName.PDF: {"base": {"default": lambda p, x: p.a + p.b}},
191+
},
192+
support_by_parametrization=support,
193+
)
194+
fam.register_parametrization("base", TwoParam)
195+
196+
dist = fam.view(a=1.0).distribution(b=3.0)
197+
198+
assert dist.support == ContinuousSupport(1.0, 3.0)
199+
200+
def test_partial_distribution_preserves_adapted_constraints(self) -> None:
201+
fam = ParametricFamily(
202+
name="ConstrainedFamily",
203+
distr_type=UnivariateContinuous,
204+
distr_parametrizations=["base"],
205+
distr_characteristics={
206+
CharacteristicName.PDF: {"base": {"default": lambda p, x: p.a + p.b}},
207+
},
208+
)
209+
210+
@fam.parametrization(name="base")
211+
class OrderedParams(Parametrization):
212+
a: float
213+
b: float
214+
215+
@constraint(description="a < b")
216+
def check_order(self) -> bool:
217+
return self.a < self.b
218+
219+
dist = fam.view(a=1.0).distribution(b=3.0)
220+
221+
assert [c.description for c in dist.parameters_constraints] == ["a < b"]
222+
assert all(c.check(dist.parametrization) for c in dist.parameters_constraints)
223+
224+
with pytest.raises(ValueError, match='Constraint "a < b" does not hold'):
225+
fam.view(a=3.0).distribution(b=1.0)
226+
179227
# Chaining view
180228
def test_view_on_view_creates_new_view_until_complete(self) -> None:
181229
fam = self._make_two_param_family()

0 commit comments

Comments
 (0)