Skip to content

Commit 48c2abc

Browse files
committed
Fix metadata serialization
In the previous implementation, a Metadata object had two different dict representations, depending on which conversion route was taken. The problem has been captured in tests and is now now fixed throught the conversion hooks. A minimal reproducing example demonstrating the issue: ```python from baybe.utils.metadata import Metadata, to_metadata dct = {"description": "test", "unit": "m", "key": "value"} metadata = to_metadata(dct) print("Dict:\n", dct, "\n") print("Via to_metadata:\n", metadata, "\n") print("Via from_dict:\n", Metadata.from_dict(dct), "\n") print("Via to_dict:\n", metadata.to_dict(), "\n") ```
1 parent 70e676c commit 48c2abc

2 files changed

Lines changed: 58 additions & 14 deletions

File tree

baybe/utils/metadata.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44

55
from typing import Any
66

7+
import cattrs
78
from attrs import define, field, fields
89
from attrs.validators import deep_mapping, instance_of
910
from attrs.validators import optional as optional_v
1011

11-
from baybe.serialization.mixin import SerialMixin
12+
from baybe.serialization import SerialMixin, converter
13+
from baybe.utils.basic import classproperty
1214

1315

1416
@define(frozen=True)
@@ -34,6 +36,20 @@ class Metadata(SerialMixin):
3436
)
3537
"""Additional user-defined metadata."""
3638

39+
@misc.validator
40+
def _validate_misc(self, _, value: dict[str, Any]) -> None:
41+
if inv := set(value).intersection(self._explicit_fields):
42+
raise ValueError(
43+
f"Miscellaneous metadata cannot contain the following fields: {inv}. "
44+
f"Use the corresponding attributes instead."
45+
)
46+
47+
@classproperty
48+
def _explicit_fields(self) -> set[str]:
49+
"""The explicit metadata fields.""" # noqa: D401
50+
flds = fields(Metadata)
51+
return {fld.name for fld in flds if fld.name != flds.misc.name}
52+
3753

3854
def to_metadata(value: dict[str, Any] | Metadata, /) -> Metadata:
3955
"""Convert a dictionary to :class:`Metadata` (with :class:`Metadata` passthrough).
@@ -57,10 +73,22 @@ def to_metadata(value: dict[str, Any] | Metadata, /) -> Metadata:
5773
)
5874

5975
# Separate known fields from unknown ones
60-
flds = fields(Metadata)
61-
value = value.copy()
62-
known_fields = {
63-
fld: value.pop(fld, None) for fld in (flds.description.name, flds.unit.name)
64-
}
76+
return converter.structure(value, Metadata)
77+
78+
79+
@converter.register_structure_hook
80+
def _separate_metadata_fields(dct: dict[str, Any], _: type[Metadata]) -> Metadata:
81+
"""Separate known fields from miscellaneous metadata."""
82+
dct = dct.copy()
83+
explicit = {fld: dct.pop(fld, None) for fld in Metadata._explicit_fields}
84+
return Metadata(**explicit, misc=dct)
85+
6586

66-
return Metadata(**known_fields, misc=value)
87+
@converter.register_unstructure_hook
88+
def _flatten_misc_metadata(metadata: Metadata) -> dict[str, Any]:
89+
"""Flatten the metadata for serialization."""
90+
fn = cattrs.gen.make_dict_unstructure_fn(Metadata, converter)
91+
dct = fn(metadata)
92+
fld = fields(Metadata).misc.name
93+
dct = dct | dct.pop(fld)
94+
return dct

tests/validation/test_metadata_validation.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,33 @@
77

88

99
@pytest.mark.parametrize(
10-
("description", "unit", "misc", "match"),
10+
("description", "unit", "misc", "error", "match"),
1111
[
12-
param(0, None, None, "must be <class 'str'>", id="desc-non-str"),
13-
param(None, 0, None, "must be <class 'str'>", id="unit-non-str"),
14-
param(None, None, 0, "must be <class 'dict'>", id="misc-non-dict"),
15-
param(None, None, {0: 0}, "must be <class 'str'>", id="misc-non-str-keys"),
12+
param(0, None, None, TypeError, "must be <class 'str'>", id="desc-non-str"),
13+
param(None, 0, None, TypeError, "must be <class 'str'>", id="unit-non-str"),
14+
param(None, None, 0, TypeError, "must be <class 'dict'>", id="misc-non-dict"),
15+
param(
16+
None,
17+
None,
18+
{0: 0},
19+
TypeError,
20+
"must be <class 'str'>",
21+
id="misc-non-str-keys",
22+
),
23+
param(
24+
None,
25+
None,
26+
{"description": 0},
27+
ValueError,
28+
"fields: {'description'}",
29+
id="desc",
30+
),
31+
param(None, None, {"unit": 0}, ValueError, "fields: {'unit'}", id="unit"),
1632
],
1733
)
18-
def test_invalid_arguments(description, unit, misc, match):
34+
def test_invalid_arguments(description, unit, misc, error, match):
1935
"""Providing invalid arguments raises an error."""
20-
with pytest.raises(TypeError, match=match):
36+
with pytest.raises(error, match=match):
2137
Metadata(description, unit, misc)
2238

2339

0 commit comments

Comments
 (0)