Skip to content

Commit 66533c9

Browse files
vimotapre-commit-ci[bot]svlandegpre-commit-ci-lite[bot]tiangolo
authored
🐛 Fix support for Annotated fields with Pydantic 2.12+ (#1607)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: svlandeg <svlandeg@github.com> Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
1 parent 5611bda commit 66533c9

File tree

4 files changed

+231
-17
lines changed

4 files changed

+231
-17
lines changed

sqlmodel/main.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import uuid
66
import weakref
77
from collections.abc import Mapping, Sequence, Set
8+
from dataclasses import dataclass
89
from datetime import date, datetime, time, timedelta
910
from decimal import Decimal
1011
from enum import Enum
@@ -200,6 +201,38 @@ def __init__(
200201
self.sa_relationship_kwargs = sa_relationship_kwargs
201202

202203

204+
@dataclass
205+
class FieldInfoMetadata:
206+
primary_key: Union[bool, UndefinedType] = Undefined
207+
nullable: Union[bool, UndefinedType] = Undefined
208+
foreign_key: Any = Undefined
209+
ondelete: Union[OnDeleteType, UndefinedType] = Undefined
210+
unique: Union[bool, UndefinedType] = Undefined
211+
index: Union[bool, UndefinedType] = Undefined
212+
sa_type: Union[type[Any], UndefinedType] = Undefined
213+
sa_column: Union[Column[Any], UndefinedType] = Undefined
214+
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined
215+
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined
216+
217+
218+
def _get_sqlmodel_field_metadata(field_info: Any) -> Optional[FieldInfoMetadata]:
219+
metadata_items = getattr(field_info, "metadata", None)
220+
if metadata_items:
221+
for meta in metadata_items:
222+
if isinstance(meta, FieldInfoMetadata):
223+
return meta
224+
return None
225+
226+
227+
def _get_sqlmodel_field_value(
228+
field_info: Any, attribute: str, default: Any = Undefined
229+
) -> Any:
230+
metadata = _get_sqlmodel_field_metadata(field_info)
231+
if metadata is not None and hasattr(metadata, attribute):
232+
return getattr(metadata, attribute)
233+
return getattr(field_info, attribute, default)
234+
235+
203236
# include sa_type, sa_column_args, sa_column_kwargs
204237
@overload
205238
def Field(
@@ -423,6 +456,20 @@ def Field(
423456
default_factory=default_factory,
424457
**field_info_kwargs,
425458
)
459+
field_metadata = FieldInfoMetadata(
460+
primary_key=primary_key,
461+
nullable=nullable,
462+
foreign_key=foreign_key,
463+
ondelete=ondelete,
464+
unique=unique,
465+
index=index,
466+
sa_type=sa_type,
467+
sa_column=sa_column,
468+
sa_column_args=sa_column_args,
469+
sa_column_kwargs=sa_column_kwargs,
470+
)
471+
if hasattr(field_info, "metadata"):
472+
field_info.metadata.append(field_metadata)
426473
return field_info
427474

428475

@@ -637,7 +684,7 @@ def __init__(
637684

638685
def get_sqlalchemy_type(field: Any) -> Any:
639686
field_info = field
640-
sa_type = getattr(field_info, "sa_type", Undefined) # noqa: B009
687+
sa_type = _get_sqlmodel_field_value(field_info, "sa_type", Undefined) # noqa: B009
641688
if sa_type is not Undefined:
642689
return sa_type
643690

@@ -691,39 +738,39 @@ def get_sqlalchemy_type(field: Any) -> Any:
691738

692739
def get_column_from_field(field: Any) -> Column: # type: ignore
693740
field_info = field
694-
sa_column = getattr(field_info, "sa_column", Undefined)
741+
sa_column = _get_sqlmodel_field_value(field_info, "sa_column", Undefined)
695742
if isinstance(sa_column, Column):
696743
return sa_column
697744
sa_type = get_sqlalchemy_type(field)
698-
primary_key = getattr(field_info, "primary_key", Undefined)
745+
primary_key = _get_sqlmodel_field_value(field_info, "primary_key", Undefined)
699746
if primary_key is Undefined:
700747
primary_key = False
701-
index = getattr(field_info, "index", Undefined)
748+
index = _get_sqlmodel_field_value(field_info, "index", Undefined)
702749
if index is Undefined:
703750
index = False
704751
nullable = not primary_key and is_field_noneable(field)
705752
# Override derived nullability if the nullable property is set explicitly
706753
# on the field
707-
field_nullable = getattr(field_info, "nullable", Undefined) # noqa: B009
754+
field_nullable = _get_sqlmodel_field_value(field_info, "nullable", Undefined)
708755
if field_nullable is not Undefined:
709756
assert not isinstance(field_nullable, UndefinedType)
710757
nullable = field_nullable
711758
args = []
712-
foreign_key = getattr(field_info, "foreign_key", Undefined)
759+
foreign_key = _get_sqlmodel_field_value(field_info, "foreign_key", Undefined)
713760
if foreign_key is Undefined:
714761
foreign_key = None
715-
unique = getattr(field_info, "unique", Undefined)
762+
unique = _get_sqlmodel_field_value(field_info, "unique", Undefined)
716763
if unique is Undefined:
717764
unique = False
718765
if foreign_key:
719-
if field_info.ondelete == "SET NULL" and not nullable:
766+
ondelete_value = _get_sqlmodel_field_value(field_info, "ondelete", Undefined)
767+
if ondelete_value is Undefined:
768+
ondelete_value = None
769+
if ondelete_value == "SET NULL" and not nullable:
720770
raise RuntimeError('ondelete="SET NULL" requires nullable=True')
721771
assert isinstance(foreign_key, str)
722-
ondelete = getattr(field_info, "ondelete", Undefined)
723-
if ondelete is Undefined:
724-
ondelete = None
725-
assert isinstance(ondelete, (str, type(None))) # for typing
726-
args.append(ForeignKey(foreign_key, ondelete=ondelete))
772+
assert isinstance(ondelete_value, (str, type(None))) # for typing
773+
args.append(ForeignKey(foreign_key, ondelete=ondelete_value))
727774
kwargs = {
728775
"primary_key": primary_key,
729776
"nullable": nullable,
@@ -737,10 +784,12 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
737784
sa_default = field_info.default
738785
if sa_default is not Undefined:
739786
kwargs["default"] = sa_default
740-
sa_column_args = getattr(field_info, "sa_column_args", Undefined)
787+
sa_column_args = _get_sqlmodel_field_value(field_info, "sa_column_args", Undefined)
741788
if sa_column_args is not Undefined:
742789
args.extend(list(cast(Sequence[Any], sa_column_args)))
743-
sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined)
790+
sa_column_kwargs = _get_sqlmodel_field_value(
791+
field_info, "sa_column_kwargs", Undefined
792+
)
744793
if sa_column_kwargs is not Undefined:
745794
kwargs.update(cast(dict[Any, Any], sa_column_kwargs))
746795
return Column(sa_type, *args, **kwargs) # type: ignore

tests/test_field_sa_column.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Annotated, Optional
22

33
import pytest
44
from sqlalchemy import Column, Integer, String
@@ -17,6 +17,17 @@ class Item(SQLModel, table=True):
1717
assert isinstance(Item.id.type, String) # type: ignore
1818

1919

20+
def test_sa_column_with_annotated_metadata() -> None:
21+
class Item(SQLModel, table=True):
22+
id: Annotated[Optional[int], "meta"] = Field(
23+
default=None,
24+
sa_column=Column(String, primary_key=True, nullable=False),
25+
)
26+
27+
assert Item.id.nullable is False # type: ignore
28+
assert isinstance(Item.id.type, String) # type: ignore
29+
30+
2031
def test_sa_column_no_sa_args() -> None:
2132
with pytest.raises(RuntimeError):
2233

tests/test_future_annotations.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import annotations
2+
3+
from typing import Annotated, Optional
4+
5+
from sqlmodel import Field, Session, SQLModel, create_engine, select
6+
7+
8+
def test_model_with_future_annotations(clear_sqlmodel):
9+
class Hero(SQLModel, table=True):
10+
id: Annotated[Optional[int], Field(primary_key=True)] = None
11+
name: str
12+
secret_name: str
13+
age: Optional[int] = None
14+
15+
hero = Hero(name="Deadpond", secret_name="Dive Wilson", age=25)
16+
17+
engine = create_engine("sqlite://")
18+
SQLModel.metadata.create_all(engine)
19+
20+
with Session(engine) as session:
21+
session.add(hero)
22+
session.commit()
23+
session.refresh(hero)
24+
25+
assert hero.id is not None
26+
assert hero.name == "Deadpond"
27+
assert hero.secret_name == "Dive Wilson"
28+
assert hero.age == 25
29+
30+
with Session(engine) as session:
31+
heroes = session.exec(select(Hero)).all()
32+
assert len(heroes) == 1
33+
assert heroes[0].name == "Deadpond"
34+
35+
36+
def test_model_with_string_annotations(clear_sqlmodel):
37+
class Team(SQLModel, table=True):
38+
id: Annotated[Optional[int], Field(primary_key=True)] = None
39+
name: str
40+
41+
class Player(SQLModel, table=True):
42+
id: Annotated[Optional[int], Field(primary_key=True)] = None
43+
name: str
44+
team_id: Annotated[Optional[int], Field(foreign_key="team.id")] = None
45+
46+
engine = create_engine("sqlite://")
47+
SQLModel.metadata.create_all(engine)
48+
49+
team = Team(name="Champions")
50+
player = Player(name="Alice", team_id=None)
51+
52+
with Session(engine) as session:
53+
session.add(team)
54+
session.commit()
55+
session.refresh(team)
56+
57+
player.team_id = team.id
58+
session.add(player)
59+
session.commit()
60+
session.refresh(player)
61+
62+
assert team.id is not None
63+
assert player.team_id == team.id

tests/test_main.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Annotated, Optional
22

33
import pytest
44
from sqlalchemy.exc import IntegrityError
@@ -125,3 +125,94 @@ class Hero(SQLModel, table=True):
125125
# The next statement should not raise an AttributeError
126126
assert hero_rusty_man.team
127127
assert hero_rusty_man.team.name == "Preventers"
128+
129+
130+
def test_composite_primary_key(clear_sqlmodel):
131+
class UserPermission(SQLModel, table=True):
132+
user_id: int = Field(primary_key=True)
133+
resource_id: int = Field(primary_key=True)
134+
permission: str
135+
136+
engine = create_engine("sqlite://")
137+
SQLModel.metadata.create_all(engine)
138+
139+
pk_column_names = {column.name for column in UserPermission.__table__.primary_key}
140+
assert pk_column_names == {"user_id", "resource_id"}
141+
142+
with Session(engine) as session:
143+
perm1 = UserPermission(user_id=1, resource_id=1, permission="read")
144+
perm2 = UserPermission(user_id=1, resource_id=2, permission="write")
145+
session.add(perm1)
146+
session.add(perm2)
147+
session.commit()
148+
149+
with pytest.raises(IntegrityError):
150+
with Session(engine) as session:
151+
perm3 = UserPermission(user_id=1, resource_id=1, permission="admin")
152+
session.add(perm3)
153+
session.commit()
154+
155+
156+
def test_composite_primary_key_and_validator(clear_sqlmodel):
157+
from pydantic import AfterValidator
158+
159+
def validate_resource_id(value: int) -> int:
160+
if value < 1:
161+
raise ValueError("Resource ID must be positive")
162+
return value
163+
164+
class UserPermission(SQLModel, table=True):
165+
user_id: int = Field(primary_key=True)
166+
resource_id: Annotated[int, AfterValidator(validate_resource_id)] = Field(
167+
primary_key=True
168+
)
169+
permission: str
170+
171+
engine = create_engine("sqlite://")
172+
SQLModel.metadata.create_all(engine)
173+
174+
pk_column_names = {column.name for column in UserPermission.__table__.primary_key}
175+
assert pk_column_names == {"user_id", "resource_id"}
176+
177+
with Session(engine) as session:
178+
perm1 = UserPermission(user_id=1, resource_id=1, permission="read")
179+
perm2 = UserPermission(user_id=1, resource_id=2, permission="write")
180+
session.add(perm1)
181+
session.add(perm2)
182+
session.commit()
183+
184+
with pytest.raises(IntegrityError):
185+
with Session(engine) as session:
186+
perm3 = UserPermission(user_id=1, resource_id=1, permission="admin")
187+
session.add(perm3)
188+
session.commit()
189+
190+
191+
def test_foreign_key_ondelete_with_annotated(clear_sqlmodel):
192+
from pydantic import AfterValidator
193+
194+
def ensure_positive(value: int) -> int:
195+
if value < 0:
196+
raise ValueError("Team ID must be positive")
197+
return value
198+
199+
class Team(SQLModel, table=True):
200+
id: int = Field(primary_key=True)
201+
name: str
202+
203+
class Hero(SQLModel, table=True):
204+
id: int = Field(primary_key=True)
205+
team_id: Annotated[int, AfterValidator(ensure_positive)] = Field(
206+
foreign_key="team.id",
207+
ondelete="CASCADE",
208+
)
209+
name: str
210+
211+
engine = create_engine("sqlite://")
212+
SQLModel.metadata.create_all(engine)
213+
214+
team_id_column = Hero.__table__.c.team_id # type: ignore[attr-defined]
215+
foreign_keys = list(team_id_column.foreign_keys)
216+
assert len(foreign_keys) == 1
217+
assert foreign_keys[0].ondelete == "CASCADE"
218+
assert team_id_column.nullable is False

0 commit comments

Comments
 (0)