Skip to content

Commit 57757d7

Browse files
committed
Support overrides in annotated attributes
1 parent 309e9d1 commit 57757d7

5 files changed

Lines changed: 175 additions & 22 deletions

File tree

src/cattrs/cols.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,15 @@
55
from collections import defaultdict
66
from collections.abc import Callable, Iterable
77
from functools import partial
8-
from typing import (
9-
TYPE_CHECKING,
10-
Any,
11-
DefaultDict,
12-
Literal,
13-
NamedTuple,
14-
TypeVar,
15-
get_type_hints,
16-
)
8+
from typing import TYPE_CHECKING, Any, DefaultDict, Literal, NamedTuple, TypeVar
179

1810
from attrs import NOTHING, Attribute, NothingType
1911

2012
from ._compat import (
2113
ANIES,
2214
AbcSet,
2315
get_args,
16+
get_full_type_hints,
2417
get_origin,
2518
is_bare,
2619
is_frozenset,
@@ -246,7 +239,7 @@ def _namedtuple_to_attrs(cl: type[tuple]) -> list[Attribute]:
246239
type=a,
247240
alias=name,
248241
)
249-
for name, a in get_type_hints(cl).items()
242+
for name, a in get_full_type_hints(cl).items()
250243
]
251244

252245

src/cattrs/gen/__init__.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from ._consts import AttributeOverride, already_generating, neutral
3434
from ._generics import generate_mapping
3535
from ._lc import generate_unique_filename
36-
from ._shared import find_structure_handler
36+
from ._shared import _annotated_override_or_default, find_structure_handler
3737

3838
if TYPE_CHECKING:
3939
from ..converters import BaseConverter
@@ -117,7 +117,9 @@ def make_dict_unstructure_fn_from_attrs(
117117

118118
for a in attrs:
119119
attr_name = a.name
120-
override = kwargs.get(attr_name, neutral)
120+
override = kwargs.get(
121+
attr_name, _annotated_override_or_default(a.type, neutral)
122+
)
121123
if override.omit:
122124
continue
123125
if override.omit is None and not a.init and not _cattrs_include_init_false:
@@ -408,7 +410,7 @@ def make_dict_structure_fn_from_attrs(
408410
internal_arg_parts["__c_avn"] = AttributeValidationNote
409411
for a in attrs:
410412
an = a.name
411-
override = kwargs.get(an, neutral)
413+
override = kwargs.get(an, _annotated_override_or_default(a.type, neutral))
412414
if override.omit:
413415
continue
414416
if override.omit is None and not a.init and not _cattrs_include_init_false:
@@ -539,7 +541,7 @@ def make_dict_structure_fn_from_attrs(
539541
# The first loop deals with required args.
540542
for a in attrs:
541543
an = a.name
542-
override = kwargs.get(an, neutral)
544+
override = kwargs.get(an, _annotated_override_or_default(a.type, neutral))
543545
if override.omit:
544546
continue
545547
if override.omit is None and not a.init and not _cattrs_include_init_false:
@@ -614,7 +616,9 @@ def make_dict_structure_fn_from_attrs(
614616

615617
for a in non_required:
616618
an = a.name
617-
override = kwargs.get(an, neutral)
619+
override = kwargs.get(
620+
an, _annotated_override_or_default(a.type, neutral)
621+
)
618622
t = a.type
619623
if isinstance(t, TypeVar):
620624
t = typevar_map.get(t.__name__, t)

src/cattrs/gen/_shared.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,31 @@
44

55
from attrs import NOTHING, Attribute, Factory
66

7-
from .._compat import is_bare_final
7+
from .._compat import get_args, is_annotated, is_bare_final
88
from ..dispatch import StructureHook
99
from ..errors import StructureHandlerNotFoundError
1010
from ..fns import raise_error
11+
from ._consts import AttributeOverride
1112

1213
if TYPE_CHECKING:
1314
from ..converters import BaseConverter
1415

1516

17+
def _annotated_override_or_default(
18+
type: Any, default: AttributeOverride
19+
) -> AttributeOverride:
20+
"""
21+
If the type is Annotated containing an AttributeOverride, return it.
22+
Otherwise, return the default.
23+
"""
24+
if is_annotated(type):
25+
for arg in get_args(type):
26+
if isinstance(arg, AttributeOverride):
27+
return arg
28+
29+
return default
30+
31+
1632
def find_structure_handler(
1733
a: Attribute, type: Any, c: BaseConverter, prefer_attrs_converters: bool = False
1834
) -> StructureHook | None:

src/cattrs/gen/typeddicts.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from ._consts import already_generating, neutral
3030
from ._generics import generate_mapping
3131
from ._lc import generate_unique_filename
32-
from ._shared import find_structure_handler
32+
from ._shared import _annotated_override_or_default, find_structure_handler
3333

3434
if TYPE_CHECKING:
3535
from ..converters import BaseConverter
@@ -102,7 +102,9 @@ def make_dict_unstructure_fn(
102102
# * all attributes resolve to `converter._unstructure_identity`
103103
for a in attrs:
104104
attr_name = a.name
105-
override = kwargs.get(attr_name, neutral)
105+
override = kwargs.get(
106+
attr_name, _annotated_override_or_default(a.type, neutral)
107+
)
106108
if override != neutral:
107109
break
108110
handler = None
@@ -135,7 +137,9 @@ def make_dict_unstructure_fn(
135137

136138
for ix, a in enumerate(attrs):
137139
attr_name = a.name
138-
override = kwargs.get(attr_name, neutral)
140+
override = kwargs.get(
141+
attr_name, _annotated_override_or_default(a.type, neutral)
142+
)
139143
if override.omit:
140144
lines.append(f" res.pop('{attr_name}', None)")
141145
continue
@@ -319,7 +323,7 @@ def make_dict_structure_fn(
319323
for ix, a in enumerate(attrs):
320324
an = a.name
321325
attr_required = an in req_keys
322-
override = kwargs.get(an, neutral)
326+
override = kwargs.get(an, _annotated_override_or_default(a.type, neutral))
323327
if override.omit:
324328
continue
325329
t = a.type
@@ -392,7 +396,7 @@ def make_dict_structure_fn(
392396
for ix, a in enumerate(attrs):
393397
an = a.name
394398
attr_required = an in req_keys
395-
override = kwargs.get(an, neutral)
399+
override = kwargs.get(an, _annotated_override_or_default(a.type, neutral))
396400
if override.omit:
397401
continue
398402
if not attr_required:
@@ -441,7 +445,9 @@ def make_dict_structure_fn(
441445
if non_required:
442446
for ix, a in non_required:
443447
an = a.name
444-
override = kwargs.get(an, neutral)
448+
override = kwargs.get(
449+
an, _annotated_override_or_default(a.type, neutral)
450+
)
445451
t = a.type
446452

447453
nrb = get_notrequired_base(t)

tests/test_annotated_overrides.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from dataclasses import dataclass
2+
from typing import Annotated, NamedTuple, TypedDict
3+
4+
from attrs import define
5+
6+
from cattrs import Converter
7+
from cattrs.cols import (
8+
is_namedtuple,
9+
namedtuple_dict_structure_factory,
10+
namedtuple_dict_unstructure_factory,
11+
)
12+
from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn, override
13+
from cattrs.gen.typeddicts import make_dict_structure_fn as make_td_structure_fn
14+
from cattrs.gen.typeddicts import make_dict_unstructure_fn as make_td_unstructure_fn
15+
16+
17+
def test_annotated_override_attrs(genconverter: Converter):
18+
"""Annotated overrides work for attrs classes."""
19+
20+
@define
21+
class A:
22+
a: Annotated[int, override(rename="b")]
23+
c: Annotated[int, override(omit=True)] = 1
24+
d: Annotated[int, override(rename="e")] = 2
25+
26+
instance = A(1)
27+
# 'a' is renamed to 'b', 'c' is omitted. 'd' is default so present as 'e'
28+
assert genconverter.unstructure(instance) == {"b": 1, "e": 2}
29+
30+
assert genconverter.structure({"b": 1, "e": 2}, A) == A(1)
31+
32+
33+
def test_annotated_override_dataclasses(genconverter: Converter):
34+
"""Annotated overrides work for dataclasses."""
35+
36+
@dataclass
37+
class A:
38+
a: Annotated[int, override(rename="b")]
39+
c: Annotated[int, override(omit=True)] = 1
40+
41+
instance = A(1)
42+
assert genconverter.unstructure(instance) == {"b": 1}
43+
44+
assert genconverter.structure({"b": 1}, A) == A(1)
45+
46+
47+
def test_annotated_override_typeddict(genconverter: Converter):
48+
"""Annotated overrides work for TypedDicts."""
49+
50+
class TD(TypedDict):
51+
a: Annotated[int, override(rename="b")]
52+
c: Annotated[int, override(omit=True)]
53+
54+
instance: TD = {"a": 1, "c": 2}
55+
56+
assert genconverter.unstructure(instance, TD) == {"b": 1}
57+
58+
# Let's simplify and just test rename for now to avoid required field issues with omit.
59+
class TD2(TypedDict):
60+
a: Annotated[int, override(rename="b")]
61+
62+
inst2: TD2 = {"a": 1}
63+
assert genconverter.unstructure(inst2, TD2) == {"b": 1}
64+
assert genconverter.structure({"b": 1}, TD2) == {"a": 1}
65+
66+
67+
def test_annotated_override_namedtuple(genconverter: Converter):
68+
"""Annotated overrides work for NamedTuples using dict factories."""
69+
70+
# We need to register the dict factories for NamedTuples
71+
genconverter.register_unstructure_hook_factory(
72+
is_namedtuple, namedtuple_dict_unstructure_factory
73+
)
74+
genconverter.register_structure_hook_factory(
75+
is_namedtuple, namedtuple_dict_structure_factory
76+
)
77+
78+
class NT(NamedTuple):
79+
a: Annotated[int, override(rename="b")]
80+
c: Annotated[int, override(omit=True)] = 1
81+
82+
instance = NT(1)
83+
assert genconverter.unstructure(instance) == {"b": 1}
84+
assert genconverter.structure({"b": 1}, NT) == NT(1)
85+
86+
87+
def test_annotated_override_precedence(genconverter: Converter):
88+
"""Test that explicit kwargs override Annotated metadata."""
89+
90+
@define
91+
class A:
92+
a: Annotated[int, override(rename="b")]
93+
94+
# Override the rename back to 'a' explicitly
95+
unstructure_fn = make_dict_unstructure_fn(A, genconverter, a=override(rename="a"))
96+
genconverter.register_unstructure_hook(A, unstructure_fn)
97+
98+
assert genconverter.unstructure(A(1)) == {"a": 1}
99+
100+
# # Structure override
101+
structure_fn = make_dict_structure_fn(A, genconverter, a=override(rename="a"))
102+
genconverter.register_structure_hook(A, structure_fn)
103+
104+
assert genconverter.structure({"a": 1}, A) == A(1)
105+
106+
107+
def test_annotated_override_hooks(genconverter: Converter):
108+
"""struct_hook and unstruct_hook work in Annotated."""
109+
110+
def double_hook(v):
111+
return v * 2
112+
113+
def half_hook(v, _):
114+
return v // 2
115+
116+
@define
117+
class A:
118+
a: Annotated[int, override(unstruct_hook=double_hook, struct_hook=half_hook)]
119+
120+
assert genconverter.unstructure(A(10)) == {"a": 20}
121+
assert genconverter.structure({"a": 20}, A) == A(10)
122+
123+
124+
def test_annotated_override_omit_if_default(genconverter: Converter):
125+
"""omit_if_default works in Annotated."""
126+
127+
@define
128+
class A:
129+
a: Annotated[int, override(omit_if_default=True)] = 0
130+
b: int = 1
131+
132+
# a matches default, should be omitted. b matches default but no override, should stay (default behavior is to keep)
133+
assert genconverter.unstructure(A()) == {"b": 1}
134+
assert genconverter.unstructure(A(a=1)) == {"a": 1, "b": 1}

0 commit comments

Comments
 (0)