Skip to content

Commit c10e7d7

Browse files
authored
Add flexible metadata system (#580)
## Summary

This PR implements a flexible metadata system for annotating BayBE objects with a description, unit, and other information. It is integrated into `Parameter`, `Target`, and `Objective`.

## Usage Example

```python
from baybe.parameters import MeasurableMetadata, NumericalDiscreteParameter
from baybe.utils.metadata import to_metadata

# Using dict
# (Gets automatically converted and entries are separated into known and misc fields)
dct = {"description": "Reaction temperature", "unit": "°C", "custom_field": "value"}
param = NumericalDiscreteParameter(
    name="temperature", values=[90, 105, 120], metadata=dct
)

# Using Metadata class directly
meta = MeasurableMetadata(
    description="Reaction temperature", unit="°C", misc={"custom_field": "value"}
)
param = NumericalDiscreteParameter(
    name="temperature", values=[90, 105, 120], metadata=meta
)
assert meta == to_metadata(dct, MeasurableMetadata) == MeasurableMetadata.from_dict(dct)

# Accessing metadata
print(param.description)  # "Reaction temperature"
print(param.unit)  # "°C"
```
2 parents ff93cfc + f3760d9 commit c10e7d7

20 files changed

Lines changed: 568 additions & 20 deletions

File tree

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ build
1919
# VSCode
2020
.vscode
2121

22+
# Virtual environments
23+
.venv
24+
.env
25+
2226
# Testing
2327
.tox
2428
.coverage
@@ -34,5 +38,6 @@ htmlcov
3438

3539
# Folders that are temporarily created when building the documentation
3640
docs/_autosummary
41+
docs/_build
3742
docs/examples
3843
docs/sdk

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
## [Unreleased]
88
### Added
99
- API diagram in user guide
10+
- `Metadata` and `MeasurableMetadata` classes providing optional information for BayBE
11+
objects
12+
- `Objective` now has a `metadata` attribute as well as a `description` property
13+
- `Target` and `Parameter` now have a `metadata` attribute as well as `description` and
14+
`unit` properties
1015

1116
### Fixed
1217
- `Campaign` no longer allows overlapping names between parameters and targets

CONTRIBUTORS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,6 @@
3131
- Fabian Liebig (Merck KGaA, Darmstadt, Germany):\
3232
Benchmarking structure and persistence capabilities for benchmarking results
3333
- Alexander Wieczorek (Swiss Federal Institute for Materials Science and Technology, Dübendorf, Switzerland):\
34-
SHAP explainers for insights
34+
SHAP explainers for insights
35+
- Tobias Plötz (Merck KGaA, Darmstadt, Germany):\
36+
Metadata system

baybe/objectives/base.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import cattrs
88
import pandas as pd
9-
from attrs import define
9+
from attrs import define, field
1010

1111
from baybe.serialization.core import (
1212
converter,
@@ -15,6 +15,7 @@
1515
)
1616
from baybe.serialization.mixin import SerialMixin
1717
from baybe.targets.base import Target
18+
from baybe.utils.metadata import Metadata, to_metadata
1819

1920
# TODO: Reactive slots in all classes once cached_property is supported:
2021
# https://github.com/python-attrs/attrs/issues/164
@@ -27,6 +28,18 @@ class Objective(ABC, SerialMixin):
2728
is_multi_output: ClassVar[bool]
2829
"""Class variable indicating if the objective produces multiple outputs."""
2930

31+
metadata: Metadata = field(
32+
factory=Metadata,
33+
converter=lambda x: to_metadata(x, Metadata),
34+
kw_only=True,
35+
)
36+
"""Optional metadata containing description and other information."""
37+
38+
@property
39+
def description(self) -> str | None:
40+
"""The description of the objective."""
41+
return self.metadata.description
42+
3043
@property
3144
@abstractmethod
3245
def targets(self) -> tuple[Target, ...]:

baybe/parameters/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
NumericalDiscreteParameter,
1313
)
1414
from baybe.parameters.substance import SubstanceParameter
15+
from baybe.utils.metadata import MeasurableMetadata
1516

1617
__all__ = [
1718
"CategoricalEncoding",
1819
"CategoricalParameter",
1920
"CustomDiscreteParameter",
2021
"CustomEncoding",
22+
"MeasurableMetadata",
2123
"NumericalContinuousParameter",
2224
"NumericalDiscreteParameter",
2325
"SubstanceEncoding",

baybe/parameters/base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
unstructure_base,
2323
)
2424
from baybe.utils.basic import to_tuple
25+
from baybe.utils.metadata import MeasurableMetadata, to_metadata
2526

2627
if TYPE_CHECKING:
2728
from baybe.searchspace.continuous import SubspaceContinuous
@@ -48,6 +49,13 @@ class Parameter(ABC, SerialMixin):
4849
name: str = field(validator=(instance_of(str), min_len(1)))
4950
"""The name of the parameter"""
5051

52+
metadata: MeasurableMetadata = field(
53+
factory=MeasurableMetadata,
54+
converter=lambda x: to_metadata(x, MeasurableMetadata),
55+
kw_only=True,
56+
)
57+
"""Optional metadata containing description, unit, and other information."""
58+
5159
@abstractmethod
5260
def is_in_range(self, item: Any) -> bool:
5361
"""Return whether an item is within the parameter range.
@@ -88,6 +96,16 @@ def to_searchspace(self) -> SearchSpace:
8896
def summary(self) -> dict:
8997
"""Return a custom summarization of the parameter."""
9098

99+
@property
100+
def description(self) -> str | None:
101+
"""The description of the parameter."""
102+
return self.metadata.description
103+
104+
@property
105+
def unit(self) -> str | None:
106+
"""The unit of measurement for the parameter."""
107+
return self.metadata.unit
108+
91109

92110
@define(frozen=True, slots=False)
93111
class DiscreteParameter(Parameter, ABC):

baybe/parameters/enum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class CustomEncoding(ParameterEncoding):
2727
class SubstanceEncoding(ParameterEncoding):
2828
"""Available encodings for substance parameters from `scikit-fingerprints`_ package.
2929
30-
.. _scikit-fingerprints: https://scikit-fingerprints.github.io/scikit-fingerprints/
30+
.. _scikit-fingerprints: https://scikit-fingerprints.readthedocs.io/
3131
"""
3232

3333
ATOMPAIR = "ATOMPAIR"

baybe/targets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
from baybe.targets.binary import BinaryTarget
44
from baybe.targets.enum import TargetMode, TargetTransformation
55
from baybe.targets.numerical import NumericalTarget
6+
from baybe.utils.metadata import MeasurableMetadata
67

78
__all__ = [
89
"BinaryTarget",
10+
"MeasurableMetadata",
911
"NumericalTarget",
1012
"TargetMode",
1113
"TargetTransformation",

baybe/targets/base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
get_base_structure_hook,
1717
unstructure_base,
1818
)
19+
from baybe.utils.metadata import MeasurableMetadata, to_metadata
1920

2021
if TYPE_CHECKING:
2122
from baybe.objectives import SingleTargetObjective
@@ -31,6 +32,23 @@ class Target(ABC, SerialMixin):
3132
name: str = field()
3233
"""The name of the target."""
3334

35+
metadata: MeasurableMetadata = field(
36+
factory=MeasurableMetadata,
37+
converter=lambda x: to_metadata(x, MeasurableMetadata),
38+
kw_only=True,
39+
)
40+
"""Optional metadata containing description, unit, and other information."""
41+
42+
@property
43+
def description(self) -> str | None:
44+
"""The description of the target."""
45+
return self.metadata.description
46+
47+
@property
48+
def unit(self) -> str | None:
49+
"""The unit of measurement for the target."""
50+
return self.metadata.unit
51+
3452
def to_objective(self) -> SingleTargetObjective:
3553
"""Create a single-task objective from the target."""
3654
from baybe.objectives.single import SingleTargetObjective

baybe/utils/metadata.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""Generic metadata system for BayBE objects."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any, TypeVar
6+
7+
import cattrs
8+
from attrs import AttrsInstance, define, field, fields
9+
from attrs.validators import deep_mapping, instance_of
10+
from attrs.validators import optional as optional_v
11+
from typing_extensions import override
12+
13+
from baybe.serialization import SerialMixin, converter
14+
from baybe.utils.basic import classproperty
15+
16+
_TMetaData = TypeVar("_TMetaData", bound="Metadata")
17+
18+
19+
@define(frozen=True)
class Metadata(SerialMixin):
    """Container for optional descriptive information on BayBE objects.

    Stores a free-text description together with a dictionary of arbitrary
    user-defined entries. Keys that coincide with an explicit attribute of the
    class are rejected in the miscellaneous dictionary.
    """

    description: str | None = field(
        default=None, validator=optional_v(instance_of(str))
    )
    """An optional free-text description of the object."""

    misc: dict[str, Any] = field(
        factory=dict,
        validator=deep_mapping(
            mapping_validator=instance_of(dict),
            key_validator=instance_of(str),
            # FIXME: https://github.com/python-attrs/attrs/issues/1246
            value_validator=lambda *x: None,
        ),
        kw_only=True,
    )
    """Additional user-defined entries, keyed by string."""

    @misc.validator
    def _validate_misc(self, _, value: dict[str, Any]) -> None:
        """Reject miscellaneous keys that shadow explicit metadata fields."""
        clashes = set(value).intersection(self._explicit_fields)
        if not clashes:
            return

        raise ValueError(
            f"Miscellaneous metadata cannot contain the following fields: {clashes}. "
            f"Use the corresponding attributes instead."
        )

    @classproperty
    def _explicit_fields(cls: type[AttrsInstance]) -> set[str]:
        """The names of all fields except the miscellaneous one."""  # noqa: D401
        all_fields = fields(cls)
        misc_name = all_fields.misc.name
        return {attribute.name for attribute in all_fields} - {misc_name}

    @property
    def is_empty(self) -> bool:
        """Whether neither a description nor miscellaneous entries are present."""
        has_content = self.description is not None or bool(self.misc)
        return not has_content
58+
59+
60+
@define(frozen=True)
class MeasurableMetadata(Metadata):
    """Metadata for measurable BayBE objects such as parameters and targets.

    Extends the base metadata with an optional unit of measurement.
    """

    unit: str | None = field(default=None, validator=optional_v(instance_of(str)))
    """The unit of measurement of the described quantity."""

    @override
    @property
    def is_empty(self) -> bool:
        """Whether no unit is set and the inherited metadata is empty as well."""
        if self.unit is not None:
            return False
        return super().is_empty
72+
73+
74+
def to_metadata(
    value: dict[str, Any] | _TMetaData, cls: type[_TMetaData], /
) -> _TMetaData:
    """Coerce the input into an instance of the requested metadata class.

    Instances of the requested class are returned unchanged, while dictionaries
    are structured into a new instance via the converter.

    Args:
        value: A dictionary or an existing instance of ``cls``.
        cls: The metadata class the result must be an instance of.

    Returns:
        The passed-through or newly structured metadata instance.

    Raises:
        TypeError: If the input is neither a dictionary nor an instance of
            ``cls``.
    """
    # Passthrough: nothing to convert if the input already has the right type
    if isinstance(value, cls):
        return value

    # Dictionaries are structured, which separates known from unknown fields
    if isinstance(value, dict):
        return converter.structure(value, cls)

    raise TypeError(
        f"The input must be a dictionary or a '{cls.__name__}' instance. "
        f"Got: {type(value)}"
    )
101+
102+
103+
@converter.register_structure_hook
def _separate_metadata_fields(dct: dict[str, Any], cls: type[Metadata]) -> Metadata:
    """Build metadata from a flat dictionary, routing unknown keys to ``misc``.

    Keys matching an explicit field of ``cls`` are passed to the corresponding
    attribute; all remaining entries are collected as miscellaneous metadata.
    Explicit fields that are absent from the input are not passed at all, so
    that the defaults declared on ``cls`` apply instead of a forced ``None``.

    Args:
        dct: The flat dictionary to structure. It is not modified.
        cls: The metadata class to instantiate.

    Returns:
        The created metadata instance.
    """
    # Work on a shallow copy because the explicit keys are popped below
    remaining = dct.copy()
    explicit = {
        name: remaining.pop(name) for name in cls._explicit_fields if name in remaining
    }
    return cls(**explicit, misc=remaining)
109+
110+
111+
@converter.register_unstructure_hook
def _flatten_misc_metadata(metadata: Metadata) -> dict[str, Any]:
    """Unstructure metadata into a flat dictionary.

    The miscellaneous entries are lifted to the top level, next to the explicit
    fields, instead of being nested under their own key.
    """
    unstructure = cattrs.gen.make_dict_unstructure_fn(type(metadata), converter)
    flat = unstructure(metadata)

    # Remove the nested dictionary first, then merge its entries into the result
    extras = flat.pop(fields(Metadata).misc.name)
    return flat | extras

0 commit comments

Comments
 (0)