|
15 | 15 | import inspect |
16 | 16 | from collections.abc import Mapping |
17 | 17 | from functools import partial |
18 | | -from typing import TYPE_CHECKING, Any, cast, dataclass_transform |
| 18 | +from typing import TYPE_CHECKING, Any, cast, dataclass_transform, overload |
19 | 19 |
|
20 | 20 | import numpy as np |
21 | 21 |
|
22 | 22 | from pysatl_core.distributions.computation import AnalyticalComputation |
23 | 23 | from pysatl_core.families.distribution import ParametricFamilyDistribution |
| 24 | +from pysatl_core.families.parametrizations import Parametrization |
24 | 25 | from pysatl_core.types import ( |
25 | 26 | DEFAULT_ANALYTICAL_COMPUTATION_LABEL, |
26 | 27 | ComputationFunc, |
|
32 | 33 |
|
33 | 34 | from pysatl_core.distributions.strategies import ComputationStrategy, SamplingStrategy |
34 | 35 | from pysatl_core.distributions.support import Support |
35 | | - from pysatl_core.families.parametrizations import ( |
36 | | - Parametrization, |
37 | | - ) |
38 | 36 | from pysatl_core.types import ( |
39 | 37 | GenericCharacteristicName, |
40 | 38 | LabelName, |
@@ -451,4 +449,276 @@ def score(self, parameters: Parametrization, x: NumericArray) -> NumericArray: |
451 | 449 | base_grad = self._base_score(base_params, x_arr) |
452 | 450 | return parameters.gradient_transform(base_grad) |
453 | 451 |
|
| 452 | + def view( |
| 453 | + self, |
| 454 | + *, |
| 455 | + parametrization_name: str | None = None, |
| 456 | + **fixed_params: Any, |
| 457 | + ) -> PartialParametricFamily: |
| 458 | + """ |
| 459 | + Create a view of this family with partially fixed parameters. |
| 460 | +
|
| 461 | + Parameters |
| 462 | + ---------- |
| 463 | + parametrization_name : str, optional |
| 464 | + Name of the parametrization in which the fixed parameters are given. |
| 465 | + If not provided, the base parametrization of the family is used. |
| 466 | + **fixed_params : Any |
| 467 | + Parameter names and values to fix. |
| 468 | +
|
| 469 | + Returns |
| 470 | + ------- |
| 471 | + PartialParametricFamily |
| 472 | + A view that behaves like the original family but with the specified |
| 473 | + parameters fixed. |
| 474 | +
|
| 475 | + Examples |
| 476 | + -------- |
| 477 | + >>> uniform = ParametricFamilyRegister.get("uniform") |
| 478 | + >>> uniform_lower0 = uniform.view(lower_bound=0) |
| 479 | + >>> dist = uniform_lower0.distribution(upper_bound=1) # Uniform(0,1) |
| 480 | + """ |
| 481 | + if parametrization_name is not None and parametrization_name not in self.parametrizations: |
| 482 | + raise ValueError( |
| 483 | + f"Unknown parametrization '{parametrization_name}' for family '{self.name}'" |
| 484 | + ) |
| 485 | + return PartialParametricFamily( |
| 486 | + base_family=self, |
| 487 | + fixed_params=fixed_params, |
| 488 | + parametrization_name=parametrization_name, |
| 489 | + ) |
| 490 | + |
| 491 | + __call__ = distribution |
| 492 | + |
| 493 | + |
| 494 | +class PartialParametricFamily(ParametricFamily): |
| 495 | + """ |
| 496 | + View on a parametric family with partially fixed parameters. |
| 497 | +
|
| 498 | + This class represents a parametric family where some parameters have been |
| 499 | + fixed to specific values. It inherits all behaviour from `ParametricFamily` |
| 500 | + but restricts the available parametrization to the one in which parameters |
| 501 | + are fixed. All analytical characteristics are preserved via delegation to |
| 502 | + the base parametrization of the original family. |
| 503 | +
|
| 504 | + Parameters |
| 505 | + ---------- |
| 506 | + base_family : ParametricFamily |
| 507 | + The original parametric family. |
| 508 | + fixed_params : dict[str, Any] |
| 509 | + Dictionary of fixed parameter names and their values. |
| 510 | + parametrization_name : str, optional |
| 511 | + Name of the parametrization in which the fixed parameters are specified. |
| 512 | + If not provided, the base parametrization of the family is used. |
| 513 | +
|
| 514 | + Raises |
| 515 | + ------ |
| 516 | + ValueError |
| 517 | + If all parameters of the chosen parametrization are fixed (use `.distribution()` directly), |
| 518 | + or if any fixed parameter name is unknown for that parametrization, |
| 519 | + or if the parametrization name is not registered in the family. |
| 520 | + """ |
| 521 | + |
| 522 | + def __init__( |
| 523 | + self, |
| 524 | + base_family: ParametricFamily, |
| 525 | + fixed_params: dict[str, Any], |
| 526 | + parametrization_name: str | None = None, |
| 527 | + ) -> None: |
| 528 | + self._fixed_in_param = parametrization_name or base_family.base_parametrization_name |
| 529 | + self._base_family = base_family |
| 530 | + self._param_class = base_family.get_parametrization(self._fixed_in_param) |
| 531 | + self._fixed_params = fixed_params.copy() |
| 532 | + |
| 533 | + required_fields = set(getattr(self._param_class, "__dataclass_fields__", {}).keys()) |
| 534 | + # Validate that fixed parameters exist |
| 535 | + unknown = set(self._fixed_params) - required_fields |
| 536 | + if unknown: |
| 537 | + raise ValueError( |
| 538 | + f"Unknown parameters for parametrization '{self._fixed_in_param}': {unknown}" |
| 539 | + ) |
| 540 | + |
| 541 | + # Check completeness: if all parameters are fixed, raise an error |
| 542 | + if required_fields.issubset(fixed_params): |
| 543 | + raise ValueError( |
| 544 | + f"All parameters of parametrization '{self._fixed_in_param}' are already fixed. " |
| 545 | + "Use `.distribution()` directly." |
| 546 | + ) |
| 547 | + |
| 548 | + def _view_distr_type(params: Parametrization) -> DistributionType: |
| 549 | + canonical = base_family.to_base(params) |
| 550 | + return base_family._distr_type(canonical) |
| 551 | + |
| 552 | + view_chars = self._build_view_characteristics(base_family) |
| 553 | + |
| 554 | + super().__init__( |
| 555 | + name=base_family._name, |
| 556 | + distr_type=_view_distr_type, |
| 557 | + distr_parametrizations=[self._fixed_in_param], |
| 558 | + distr_characteristics=view_chars, |
| 559 | + support_by_parametrization=base_family._support_resolver, |
| 560 | + base_score=base_family._base_score, |
| 561 | + ) |
| 562 | + |
| 563 | + # Register the parametrization (needed for parent methods) |
| 564 | + self.register_parametrization(self._fixed_in_param, self._param_class) |
| 565 | + |
| 566 | + def _create_free_param_class(self) -> type[Parametrization]: |
| 567 | + """Create a parametrization class containing only the free (unfixed) parameters. |
| 568 | +
|
| 569 | + The generated class exposes only the fields that were *not* fixed via |
| 570 | + :meth:`view`. Its :meth:`transform_to_base_parametrization` automatically |
| 571 | + injects the fixed values and delegates to the conversion logic of the |
| 572 | + original parametrization class. |
| 573 | +
|
| 574 | + Returns |
| 575 | + ------- |
| 576 | + type[Parametrization] |
| 577 | + A lightweight parametrization class with only the unfixed fields. |
| 578 | + """ |
| 579 | + fixed_params = self._fixed_params |
| 580 | + original_class = self._param_class |
| 581 | + all_fields = getattr(original_class, "__dataclass_fields__", {}) |
| 582 | + free_field_names = [name for name in all_fields if name not in fixed_params] |
| 583 | + |
| 584 | + def __init__(self: Parametrization, **kwargs: Any) -> None: |
| 585 | + for name in free_field_names: |
| 586 | + object.__setattr__(self, name, kwargs[name]) |
| 587 | + |
| 588 | + def transform_to_base(self: Parametrization) -> Parametrization: |
| 589 | + """Substitute fixed values and delegate to the original parametrization.""" |
| 590 | + combined = { |
| 591 | + **fixed_params, |
| 592 | + **{f: getattr(self, f) for f in free_field_names}, |
| 593 | + } |
| 594 | + original_instance = original_class(**combined) |
| 595 | + return original_instance.transform_to_base_parametrization() |
| 596 | + |
| 597 | + new_class = type( |
| 598 | + f"{original_class.__name__}Free", |
| 599 | + (Parametrization,), |
| 600 | + { |
| 601 | + "__init__": __init__, |
| 602 | + "transform_to_base_parametrization": transform_to_base, |
| 603 | + "__dataclass_fields__": {name: all_fields[name] for name in free_field_names}, |
| 604 | + "__annotations__": {name: all_fields[name].type for name in free_field_names}, |
| 605 | + }, |
| 606 | + ) |
| 607 | + return cast(type[Parametrization], new_class) |
| 608 | + |
| 609 | + @property |
| 610 | + def base(self) -> type[Parametrization]: |
| 611 | + """Return a parametrization class containing only the free (unfixed) parameters. |
| 612 | +
|
| 613 | + Unlike the original parametrization class, the returned class accepts |
| 614 | + *only* the parameters that were not fixed via :meth:`view`. Its |
| 615 | + :meth:`~Parametrization.transform_to_base_parametrization` method |
| 616 | + automatically supplies the fixed values before delegating to the |
| 617 | + original conversion logic. |
| 618 | +
|
| 619 | + Returns |
| 620 | + ------- |
| 621 | + type[Parametrization] |
| 622 | + A lightweight parametrization class with only the unfixed fields. |
| 623 | + """ |
| 624 | + if not hasattr(self, "_free_param_class"): |
| 625 | + self._free_param_class = self._create_free_param_class() |
| 626 | + return self._free_param_class |
| 627 | + |
| 628 | + @property |
| 629 | + def parametrizations(self) -> dict[str, type[Parametrization]]: |
| 630 | + """Return a dictionary containing only the fixed parametrization.""" |
| 631 | + return {self._fixed_in_param: self._param_class} |
| 632 | + |
| 633 | + @overload |
| 634 | + def get_parametrization(self) -> type[Parametrization]: ... |
| 635 | + |
| 636 | + @overload |
| 637 | + def get_parametrization(self, name: ParametrizationName) -> type[Parametrization]: ... |
| 638 | + |
| 639 | + def get_parametrization(self, name: ParametrizationName | None = None) -> type[Parametrization]: |
| 640 | + if name is None: |
| 641 | + return self._param_class |
| 642 | + if name != self._fixed_in_param: |
| 643 | + raise KeyError( |
| 644 | + f"Parametrization '{name}' is not available in this view. " |
| 645 | + f"Only '{self._fixed_in_param}' is available." |
| 646 | + ) |
| 647 | + return self._param_class |
| 648 | + |
| 649 | + def _build_view_characteristics(self, base_family: ParametricFamily) -> dict[str, Any]: |
| 650 | + view_chars = {} |
| 651 | + original_base = base_family.base_parametrization_name |
| 652 | + |
| 653 | + def wrap_provider(provider: Callable[..., Any]) -> Callable[..., Any]: |
| 654 | + def wrapped(params: Parametrization, *args: Any, **kwargs: Any) -> Any: |
| 655 | + base_params = base_family.to_base(params) |
| 656 | + return provider(base_params, *args, **kwargs) |
| 657 | + |
| 658 | + return wrapped |
| 659 | + |
| 660 | + for char_name, char_map in base_family.distr_characteristics.items(): |
| 661 | + if self._fixed_in_param in char_map: |
| 662 | + providers = char_map[self._fixed_in_param] |
| 663 | + view_chars[char_name] = {self._fixed_in_param: dict(providers.items())} |
| 664 | + elif original_base in char_map: |
| 665 | + original_provider = char_map[original_base] |
| 666 | + wrapped = { |
| 667 | + label: wrap_provider(provider) if callable(provider) else provider |
| 668 | + for label, provider in original_provider.items() |
| 669 | + } |
| 670 | + view_chars[char_name] = {self._fixed_in_param: wrapped} |
| 671 | + |
| 672 | + return view_chars |
| 673 | + |
| 674 | + def is_complete(self) -> bool: |
| 675 | + """Check whether all parameters of the fixed parametrization are already fixed.""" |
| 676 | + required_fields = set(getattr(self._param_class, "__dataclass_fields__", {}).keys()) |
| 677 | + return required_fields.issubset(self._fixed_params) |
| 678 | + |
| 679 | + def distribution( |
| 680 | + self, |
| 681 | + parametrization_name: str | None = None, |
| 682 | + sampling_strategy: SamplingStrategy | None = None, |
| 683 | + computation_strategy: ComputationStrategy | None = None, |
| 684 | + **kwargs: Any, |
| 685 | + ) -> ParametricFamilyDistribution: |
| 686 | + target = parametrization_name or self._fixed_in_param |
| 687 | + if target != self._fixed_in_param: |
| 688 | + raise ValueError( |
| 689 | + f"Only parametrization '{self._fixed_in_param}' is available in this view. " |
| 690 | + "Please omit 'parametrization_name' or use the fixed one." |
| 691 | + ) |
| 692 | + for key, fixed_val in self._fixed_params.items(): |
| 693 | + if key in kwargs and kwargs[key] != fixed_val: |
| 694 | + raise ValueError( |
| 695 | + f"Parameter '{key}' is fixed to {fixed_val}, but got {kwargs[key]}" |
| 696 | + ) |
| 697 | + all_params = {**self._fixed_params, **kwargs} |
| 698 | + return super().distribution( |
| 699 | + parametrization_name=target, |
| 700 | + sampling_strategy=sampling_strategy, |
| 701 | + computation_strategy=computation_strategy, |
| 702 | + **all_params, |
| 703 | + ) |
| 704 | + |
| 705 | + def view( |
| 706 | + self, |
| 707 | + *, |
| 708 | + parametrization_name: str | None = None, |
| 709 | + **additional_params: Any, |
| 710 | + ) -> PartialParametricFamily: |
| 711 | + if parametrization_name is not None and parametrization_name != self._fixed_in_param: |
| 712 | + raise ValueError( |
| 713 | + f"Cannot change parametrization. Current fixed parametrization is " |
| 714 | + f"'{self._fixed_in_param}'. Use the same or omit the argument." |
| 715 | + ) |
| 716 | + for key, fixed_val in self._fixed_params.items(): |
| 717 | + if key in additional_params and additional_params[key] != fixed_val: |
| 718 | + raise ValueError( |
| 719 | + f"Parameter '{key}' is fixed to {fixed_val}, but got {additional_params[key]}" |
| 720 | + ) |
| 721 | + new_fixed = {**self._fixed_params, **additional_params} |
| 722 | + return PartialParametricFamily(self._base_family, new_fixed, self._fixed_in_param) |
| 723 | + |
454 | 724 | __call__ = distribution |
0 commit comments