Skip to content

Commit 3cb5eb1

Browse files
committed
feat(families): added ability for partial fixing parameters in parametric families
1 parent 937137c commit 3cb5eb1

2 files changed

Lines changed: 556 additions & 4 deletions

File tree

src/pysatl_core/families/parametric_family.py

Lines changed: 274 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
import inspect
1616
from collections.abc import Mapping
1717
from functools import partial
18-
from typing import TYPE_CHECKING, Any, cast, dataclass_transform
18+
from typing import TYPE_CHECKING, Any, cast, dataclass_transform, overload
1919

2020
import numpy as np
2121

2222
from pysatl_core.distributions.computation import AnalyticalComputation
2323
from pysatl_core.families.distribution import ParametricFamilyDistribution
24+
from pysatl_core.families.parametrizations import Parametrization
2425
from pysatl_core.types import (
2526
DEFAULT_ANALYTICAL_COMPUTATION_LABEL,
2627
ComputationFunc,
@@ -32,9 +33,6 @@
3233

3334
from pysatl_core.distributions.strategies import ComputationStrategy, SamplingStrategy
3435
from pysatl_core.distributions.support import Support
35-
from pysatl_core.families.parametrizations import (
36-
Parametrization,
37-
)
3836
from pysatl_core.types import (
3937
GenericCharacteristicName,
4038
LabelName,
@@ -451,4 +449,276 @@ def score(self, parameters: Parametrization, x: NumericArray) -> NumericArray:
451449
base_grad = self._base_score(base_params, x_arr)
452450
return parameters.gradient_transform(base_grad)
453451

452+
def view(
453+
self,
454+
*,
455+
parametrization_name: str | None = None,
456+
**fixed_params: Any,
457+
) -> PartialParametricFamily:
458+
"""
459+
Create a view of this family with partially fixed parameters.
460+
461+
Parameters
462+
----------
463+
parametrization_name : str, optional
464+
Name of the parametrization in which the fixed parameters are given.
465+
If not provided, the base parametrization of the family is used.
466+
**fixed_params : Any
467+
Parameter names and values to fix.
468+
469+
Returns
470+
-------
471+
PartialParametricFamily
472+
A view that behaves like the original family but with the specified
473+
parameters fixed.
474+
475+
Examples
476+
--------
477+
>>> uniform = ParametricFamilyRegister.get("uniform")
478+
>>> uniform_lower0 = uniform.view(lower_bound=0)
479+
>>> dist = uniform_lower0.distribution(upper_bound=1) # Uniform(0,1)
480+
"""
481+
if parametrization_name is not None and parametrization_name not in self.parametrizations:
482+
raise ValueError(
483+
f"Unknown parametrization '{parametrization_name}' for family '{self.name}'"
484+
)
485+
return PartialParametricFamily(
486+
base_family=self,
487+
fixed_params=fixed_params,
488+
parametrization_name=parametrization_name,
489+
)
490+
491+
__call__ = distribution
492+
493+
494+
class PartialParametricFamily(ParametricFamily):
495+
"""
496+
View on a parametric family with partially fixed parameters.
497+
498+
This class represents a parametric family where some parameters have been
499+
fixed to specific values. It inherits all behaviour from `ParametricFamily`
500+
but restricts the available parametrization to the one in which parameters
501+
are fixed. All analytical characteristics are preserved via delegation to
502+
the base parametrization of the original family.
503+
504+
Parameters
505+
----------
506+
base_family : ParametricFamily
507+
The original parametric family.
508+
fixed_params : dict[str, Any]
509+
Dictionary of fixed parameter names and their values.
510+
parametrization_name : str, optional
511+
Name of the parametrization in which the fixed parameters are specified.
512+
If not provided, the base parametrization of the family is used.
513+
514+
Raises
515+
------
516+
ValueError
517+
If all parameters of the chosen parametrization are fixed (use `.distribution()` directly),
518+
or if any fixed parameter name is unknown for that parametrization,
519+
or if the parametrization name is not registered in the family.
520+
"""
521+
522+
def __init__(
523+
self,
524+
base_family: ParametricFamily,
525+
fixed_params: dict[str, Any],
526+
parametrization_name: str | None = None,
527+
) -> None:
528+
self._fixed_in_param = parametrization_name or base_family.base_parametrization_name
529+
self._base_family = base_family
530+
self._param_class = base_family.get_parametrization(self._fixed_in_param)
531+
self._fixed_params = fixed_params.copy()
532+
533+
required_fields = set(getattr(self._param_class, "__dataclass_fields__", {}).keys())
534+
# Validate that fixed parameters exist
535+
unknown = set(self._fixed_params) - required_fields
536+
if unknown:
537+
raise ValueError(
538+
f"Unknown parameters for parametrization '{self._fixed_in_param}': {unknown}"
539+
)
540+
541+
# Check completeness: if all parameters are fixed, raise an error
542+
if required_fields.issubset(fixed_params):
543+
raise ValueError(
544+
f"All parameters of parametrization '{self._fixed_in_param}' are already fixed. "
545+
"Use `.distribution()` directly."
546+
)
547+
548+
def _view_distr_type(params: Parametrization) -> DistributionType:
549+
canonical = base_family.to_base(params)
550+
return base_family._distr_type(canonical)
551+
552+
view_chars = self._build_view_characteristics(base_family)
553+
554+
super().__init__(
555+
name=base_family._name,
556+
distr_type=_view_distr_type,
557+
distr_parametrizations=[self._fixed_in_param],
558+
distr_characteristics=view_chars,
559+
support_by_parametrization=base_family._support_resolver,
560+
base_score=base_family._base_score,
561+
)
562+
563+
# Register the parametrization (needed for parent methods)
564+
self.register_parametrization(self._fixed_in_param, self._param_class)
565+
566+
def _create_free_param_class(self) -> type[Parametrization]:
567+
"""Create a parametrization class containing only the free (unfixed) parameters.
568+
569+
The generated class exposes only the fields that were *not* fixed via
570+
:meth:`view`. Its :meth:`transform_to_base_parametrization` automatically
571+
injects the fixed values and delegates to the conversion logic of the
572+
original parametrization class.
573+
574+
Returns
575+
-------
576+
type[Parametrization]
577+
A lightweight parametrization class with only the unfixed fields.
578+
"""
579+
fixed_params = self._fixed_params
580+
original_class = self._param_class
581+
all_fields = getattr(original_class, "__dataclass_fields__", {})
582+
free_field_names = [name for name in all_fields if name not in fixed_params]
583+
584+
def __init__(self: Parametrization, **kwargs: Any) -> None:
585+
for name in free_field_names:
586+
object.__setattr__(self, name, kwargs[name])
587+
588+
def transform_to_base(self: Parametrization) -> Parametrization:
589+
"""Substitute fixed values and delegate to the original parametrization."""
590+
combined = {
591+
**fixed_params,
592+
**{f: getattr(self, f) for f in free_field_names},
593+
}
594+
original_instance = original_class(**combined)
595+
return original_instance.transform_to_base_parametrization()
596+
597+
new_class = type(
598+
f"{original_class.__name__}Free",
599+
(Parametrization,),
600+
{
601+
"__init__": __init__,
602+
"transform_to_base_parametrization": transform_to_base,
603+
"__dataclass_fields__": {name: all_fields[name] for name in free_field_names},
604+
"__annotations__": {name: all_fields[name].type for name in free_field_names},
605+
},
606+
)
607+
return cast(type[Parametrization], new_class)
608+
609+
@property
610+
def base(self) -> type[Parametrization]:
611+
"""Return a parametrization class containing only the free (unfixed) parameters.
612+
613+
Unlike the original parametrization class, the returned class accepts
614+
*only* the parameters that were not fixed via :meth:`view`. Its
615+
:meth:`~Parametrization.transform_to_base_parametrization` method
616+
automatically supplies the fixed values before delegating to the
617+
original conversion logic.
618+
619+
Returns
620+
-------
621+
type[Parametrization]
622+
A lightweight parametrization class with only the unfixed fields.
623+
"""
624+
if not hasattr(self, "_free_param_class"):
625+
self._free_param_class = self._create_free_param_class()
626+
return self._free_param_class
627+
628+
@property
629+
def parametrizations(self) -> dict[str, type[Parametrization]]:
630+
"""Return a dictionary containing only the fixed parametrization."""
631+
return {self._fixed_in_param: self._param_class}
632+
633+
@overload
634+
def get_parametrization(self) -> type[Parametrization]: ...
635+
636+
@overload
637+
def get_parametrization(self, name: ParametrizationName) -> type[Parametrization]: ...
638+
639+
def get_parametrization(self, name: ParametrizationName | None = None) -> type[Parametrization]:
640+
if name is None:
641+
return self._param_class
642+
if name != self._fixed_in_param:
643+
raise KeyError(
644+
f"Parametrization '{name}' is not available in this view. "
645+
f"Only '{self._fixed_in_param}' is available."
646+
)
647+
return self._param_class
648+
649+
def _build_view_characteristics(self, base_family: ParametricFamily) -> dict[str, Any]:
650+
view_chars = {}
651+
original_base = base_family.base_parametrization_name
652+
653+
def wrap_provider(provider: Callable[..., Any]) -> Callable[..., Any]:
654+
def wrapped(params: Parametrization, *args: Any, **kwargs: Any) -> Any:
655+
base_params = base_family.to_base(params)
656+
return provider(base_params, *args, **kwargs)
657+
658+
return wrapped
659+
660+
for char_name, char_map in base_family.distr_characteristics.items():
661+
if self._fixed_in_param in char_map:
662+
providers = char_map[self._fixed_in_param]
663+
view_chars[char_name] = {self._fixed_in_param: dict(providers.items())}
664+
elif original_base in char_map:
665+
original_provider = char_map[original_base]
666+
wrapped = {
667+
label: wrap_provider(provider) if callable(provider) else provider
668+
for label, provider in original_provider.items()
669+
}
670+
view_chars[char_name] = {self._fixed_in_param: wrapped}
671+
672+
return view_chars
673+
674+
def is_complete(self) -> bool:
675+
"""Check whether all parameters of the fixed parametrization are already fixed."""
676+
required_fields = set(getattr(self._param_class, "__dataclass_fields__", {}).keys())
677+
return required_fields.issubset(self._fixed_params)
678+
679+
def distribution(
680+
self,
681+
parametrization_name: str | None = None,
682+
sampling_strategy: SamplingStrategy | None = None,
683+
computation_strategy: ComputationStrategy | None = None,
684+
**kwargs: Any,
685+
) -> ParametricFamilyDistribution:
686+
target = parametrization_name or self._fixed_in_param
687+
if target != self._fixed_in_param:
688+
raise ValueError(
689+
f"Only parametrization '{self._fixed_in_param}' is available in this view. "
690+
"Please omit 'parametrization_name' or use the fixed one."
691+
)
692+
for key, fixed_val in self._fixed_params.items():
693+
if key in kwargs and kwargs[key] != fixed_val:
694+
raise ValueError(
695+
f"Parameter '{key}' is fixed to {fixed_val}, but got {kwargs[key]}"
696+
)
697+
all_params = {**self._fixed_params, **kwargs}
698+
return super().distribution(
699+
parametrization_name=target,
700+
sampling_strategy=sampling_strategy,
701+
computation_strategy=computation_strategy,
702+
**all_params,
703+
)
704+
705+
def view(
706+
self,
707+
*,
708+
parametrization_name: str | None = None,
709+
**additional_params: Any,
710+
) -> PartialParametricFamily:
711+
if parametrization_name is not None and parametrization_name != self._fixed_in_param:
712+
raise ValueError(
713+
f"Cannot change parametrization. Current fixed parametrization is "
714+
f"'{self._fixed_in_param}'. Use the same or omit the argument."
715+
)
716+
for key, fixed_val in self._fixed_params.items():
717+
if key in additional_params and additional_params[key] != fixed_val:
718+
raise ValueError(
719+
f"Parameter '{key}' is fixed to {fixed_val}, but got {additional_params[key]}"
720+
)
721+
new_fixed = {**self._fixed_params, **additional_params}
722+
return PartialParametricFamily(self._base_family, new_fixed, self._fixed_in_param)
723+
454724
__call__ = distribution

0 commit comments

Comments
 (0)