2222
2323from pysatl_core .distributions .computations .computation import AnalyticalComputation
2424from pysatl_core .families .distribution import ParametricFamilyDistribution
25- from pysatl_core .families .parametrizations import Parametrization
25+ from pysatl_core .families .parametrizations import Parametrization , ParametrizationConstraint
2626from pysatl_core .types import (
2727 DEFAULT_ANALYTICAL_COMPUTATION_LABEL ,
2828 ComputationFunc ,
@@ -546,36 +546,48 @@ def __init__(
546546 "Use `.distribution()` directly."
547547 )
548548
549- # Generate lightweight parametrization with only free fields
550- self ._free_param_class = self ._create_free_param_class ()
551- # Assign __param_name__ and __family__ so that instances have .name and .family
552- self ._free_param_class .__param_name__ = self ._fixed_in_param
553- self ._free_param_class .__family__ = self
554-
555549 self ._free_parameter_names = tuple (
556550 name
557551 for name in getattr (self ._param_class , "__dataclass_fields__" , {})
558552 if name not in self ._fixed_params
559553 )
560554
555+ # Generate lightweight parametrization with only free fields
556+ self ._free_param_class = self ._create_free_param_class ()
557+ # Assign __param_name__ and __family__ so that instances have .name and .family
558+ self ._free_param_class .__param_name__ = self ._fixed_in_param
559+ self ._free_param_class .__family__ = self
560+
561561 def _view_distr_type (params : Parametrization ) -> DistributionType :
562562 canonical = base_family .to_base (params )
563563 return base_family ._distr_type (canonical )
564564
565+ def _view_support (params : Parametrization ) -> Support | None :
566+ full_params = self ._to_full_parametrization (params )
567+ return base_family .support_resolver (full_params )
568+
565569 view_chars = self ._build_view_characteristics (base_family )
566570
567571 super ().__init__ (
568572 name = base_family ._name ,
569573 distr_type = _view_distr_type ,
570574 distr_parametrizations = [self ._fixed_in_param ],
571575 distr_characteristics = view_chars ,
572- support_by_parametrization = base_family . _support_resolver ,
576+ support_by_parametrization = _view_support ,
573577 base_score = base_family ._base_score ,
574578 )
575579
576580 # Register the parametrization (needed for parent methods)
577581 self .register_parametrization (self ._fixed_in_param , self ._free_param_class )
578582
583+ def _to_full_parametrization (self , params : Parametrization ) -> Parametrization :
584+ """Reconstruct the original parametrization by injecting fixed parameters."""
585+ combined = {
586+ ** self ._fixed_params ,
587+ ** {name : getattr (params , name ) for name in self ._free_parameter_names },
588+ }
589+ return self ._param_class (** combined )
590+
579591 def _create_free_param_class (self ) -> type [Parametrization ]:
580592 """Create a parametrization class containing only the free (unfixed) parameters.
581593
@@ -597,10 +609,10 @@ def _create_free_param_class(self) -> type[Parametrization]:
597609 type[Parametrization]
598610 A lightweight parametrization class with only the unfixed fields.
599611 """
600- fixed_params = self ._fixed_params
601612 original_class = self ._param_class
602613 all_fields = getattr (original_class , "__dataclass_fields__" , {})
603- free_field_names = [name for name in all_fields if name not in fixed_params ]
614+ free_field_names = list (self ._free_parameter_names )
615+ partial_family = self
604616
605617 def __init__ (self : Parametrization , ** kwargs : Any ) -> None :
606618 unexpected = set (kwargs ) - set (free_field_names )
@@ -620,20 +632,12 @@ def __init__(self: Parametrization, **kwargs: Any) -> None:
620632
621633 def transform_to_base (self : Parametrization ) -> Parametrization :
622634 """Substitute fixed values and delegate to the original parametrization."""
623- combined = {
624- ** fixed_params ,
625- ** {f : getattr (self , f ) for f in free_field_names },
626- }
627- original_instance = original_class (** combined )
628- return original_instance .transform_to_base_parametrization ()
635+ full_params = partial_family ._to_full_parametrization (self )
636+ return full_params .transform_to_base_parametrization ()
629637
630638 def validate (self : Parametrization ) -> None :
631639 """Validate by combining fixed and free parameters, then delegating."""
632- combined = {
633- ** fixed_params ,
634- ** {f : getattr (self , f ) for f in free_field_names },
635- }
636- original_class (** combined ).validate ()
640+ partial_family ._to_full_parametrization (self ).validate ()
637641
638642 def gradient_transform (self : Parametrization , base_grad : NumericArray ) -> NumericArray :
639643 """Map a gradient from the base parametrization to free-parameter space.
@@ -642,16 +646,28 @@ def gradient_transform(self: Parametrization, base_grad: NumericArray) -> Numeri
642646 parametrization. Components that correspond to fixed parameters are
643647 then discarded, keeping only the directions of the free parameters.
644648 """
645- combined = {
646- ** fixed_params ,
647- ** {f : getattr (self , f ) for f in free_field_names },
648- }
649- full_instance = original_class (** combined )
649+ full_instance = partial_family ._to_full_parametrization (self )
650650 full_grad = full_instance .gradient_transform (base_grad )
651651 all_field_names = list (all_fields .keys ())
652652 free_indices = [i for i , name in enumerate (all_field_names ) if name in free_field_names ]
653653 return full_grad [..., free_indices ]
654654
655+ def adapt_constraint (
656+ original_constraint : ParametrizationConstraint ,
657+ ) -> ParametrizationConstraint :
658+ def check (params : Parametrization ) -> bool :
659+ return original_constraint .check (partial_family ._to_full_parametrization (params ))
660+
661+ return ParametrizationConstraint (
662+ description = original_constraint .description ,
663+ check = check ,
664+ )
665+
666+ adapted_constraints = [
667+ adapt_constraint (constraint )
668+ for constraint in getattr (original_class , "_constraints" , [])
669+ ]
670+
655671 new_class = type (
656672 f"{ original_class .__name__ } Free" ,
657673 (Parametrization ,),
@@ -662,6 +678,7 @@ def gradient_transform(self: Parametrization, base_grad: NumericArray) -> Numeri
662678 "gradient_transform" : gradient_transform ,
663679 "__dataclass_fields__" : {name : all_fields [name ] for name in free_field_names },
664680 "__annotations__" : {name : all_fields [name ].type for name in free_field_names },
681+ "_constraints" : adapted_constraints ,
665682 },
666683 )
667684 return new_class
@@ -728,7 +745,7 @@ def to_base(self, parameters: Parametrization) -> Parametrization:
728745 The view's own base is the lightweight class, but the true base is the
729746 original family's base. We always transform through the full parametrization.
730747 """
731- return parameters . transform_to_base_parametrization ( )
748+ return self . _base_family . to_base ( self . _to_full_parametrization ( parameters ) )
732749
733750 def _build_view_characteristics (self , base_family : ParametricFamily ) -> dict [str , Any ]:
734751 view_chars = {}
@@ -748,11 +765,7 @@ def wrap_fixed_provider(
748765 provider : ParametricFamilyCharacteristic [Any , Any ],
749766 ) -> ParametricFamilyCharacteristic [Any , Any ]:
750767 def wrapped (params : Parametrization , * args : Any , ** kwargs : Any ) -> Any :
751- combined = {
752- ** self ._fixed_params ,
753- ** {f : getattr (params , f ) for f in self .free_parameter_names },
754- }
755- full_params = self ._param_class (** combined )
768+ full_params = self ._to_full_parametrization (params )
756769 bound = ParametricFamily ._bind_parametrization (provider , full_params )
757770 return bound (* args , ** kwargs )
758771
0 commit comments