Skip to content

Commit e36e2ea

Browse files
committed
Add generics tests
1 parent a93027c commit e36e2ea

1 file changed

Lines changed: 118 additions & 0 deletions

File tree

tests/test_generics.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from enum import Enum
2+
from typing import Generic, Literal, TypeVar
3+
4+
import pytest
5+
from sqlalchemy import create_engine
6+
from sqlmodel import Field, Session, SQLModel, select
7+
from typing_extensions import assert_type
8+
9+
10+
def test_generic_type_with_bound(clear_sqlmodel) -> None:
11+
TagT = TypeVar("TagT", bound=int)
12+
13+
class HeroFields(SQLModel, Generic[TagT]):
14+
tag: TagT
15+
16+
class Hero(HeroFields[int], table=True):
17+
id: int | None = Field(default=None, primary_key=True)
18+
19+
engine = create_engine("sqlite://")
20+
SQLModel.metadata.create_all(engine)
21+
22+
with Session(engine) as session:
23+
tag_number = 67
24+
hero = Hero(tag=tag_number)
25+
session.add(hero)
26+
27+
hero = session.exec(select(Hero).where(Hero.tag == tag_number)).first()
28+
assert hero is not None
29+
assert hero.tag == tag_number
30+
31+
32+
def test_generic_type_with_constraints(clear_sqlmodel) -> None:
33+
TagT = TypeVar("TagT", int, None)
34+
35+
class HeroFields(SQLModel, Generic[TagT]):
36+
tag: TagT
37+
38+
class Hero(HeroFields[int], table=True):
39+
id: int | None = Field(default=None, primary_key=True)
40+
41+
engine = create_engine("sqlite://")
42+
SQLModel.metadata.create_all(engine)
43+
44+
with Session(engine) as session:
45+
tag_number = 67
46+
hero = Hero(tag=tag_number)
47+
session.add(hero)
48+
49+
hero = session.exec(select(Hero).where(Hero.tag == tag_number)).first()
50+
assert hero is not None
51+
assert hero.tag == tag_number
52+
53+
54+
def test_generic_type_with_multiple_type_constraints_raises_error(
55+
clear_sqlmodel,
56+
) -> None:
57+
with pytest.raises(ValueError):
58+
TagT = TypeVar("TagT", int, str)
59+
60+
class HeroFields(SQLModel, Generic[TagT]):
61+
tag: TagT
62+
63+
class Hero(HeroFields[int], table=True):
64+
id: int | None = Field(default=None, primary_key=True)
65+
66+
67+
def test_discriminated_union_with_generics(clear_sqlmodel) -> None:
68+
AmountRefundedT = TypeVar("AmountRefundedT", bound=int | None)
69+
RejectionMessageT = TypeVar("RejectionMessageT", bound=str | None)
70+
71+
class RefundStatus(str, Enum):
72+
ACCEPTED = "ACCEPTED"
73+
REJECTED = "REJECTED"
74+
75+
DiscriminantT = TypeVar("DiscriminantT", bound=RefundStatus)
76+
77+
class RefundRequestFields(SQLModel, Generic[AmountRefundedT, RejectionMessageT, DiscriminantT]):
78+
item_name: str
79+
amount_refunded: AmountRefundedT
80+
rejection_message: RejectionMessageT
81+
status: DiscriminantT
82+
83+
class RefundRequest(RefundRequestFields[int | None, str | None, RefundStatus], table=True):
84+
id: int | None = Field(default=None, primary_key=True)
85+
status: RefundStatus
86+
87+
class AcceptedRequest(RefundRequestFields[int, None, RefundStatus.ACCEPTED]):
88+
amount_refunded: int
89+
rejection_message: None = None
90+
status: Literal[RefundStatus.ACCEPTED] = RefundStatus.ACCEPTED
91+
92+
class RejectedRequest(RefundRequestFields[None, str, RefundStatus.REJECTED]):
93+
rejection_message: str
94+
amount_refunded: None = None
95+
status: Literal[RefundStatus.REJECTED] = RefundStatus.REJECTED
96+
97+
engine = create_engine("sqlite://")
98+
SQLModel.metadata.create_all(engine)
99+
100+
with Session(engine) as session:
101+
c = RejectedRequest(
102+
item_name="EmptyJuice",
103+
rejection_message="This item cannot be refunded because it has been emptied",
104+
)
105+
session.add(RefundRequest.model_validate(c.model_dump()))
106+
107+
requests = session.exec(
108+
select(RefundRequest).where(
109+
RefundRequest.status == RefundStatus.REJECTED,
110+
)
111+
).all()
112+
rejected_requests = [
113+
RejectedRequest.model_validate(request.model_dump())
114+
for request in requests
115+
if request.status == RefundStatus.REJECTED
116+
]
117+
assert_type(rejected_requests, list[RejectedRequest])
118+
assert len(rejected_requests) == 1

0 commit comments

Comments
 (0)