Skip to content

Commit 3eaef31

Browse files
committed
reconstruct PR 1439
1 parent dd51746 commit 3eaef31

3 files changed

Lines changed: 210 additions & 1 deletion

File tree

sqlmodel/_compat.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Annotated,
1010
Any,
1111
ForwardRef,
12+
Literal,
1213
Optional,
1314
TypeVar,
1415
Union,
@@ -64,6 +65,36 @@ def _is_union_type(t: Any) -> bool:
6465
return t is UnionType or t is Union
6566

6667

68+
def get_literal_annotation_info(
69+
annotation: Any,
70+
) -> Optional[tuple[type[Any], tuple[Any, ...]]]:
71+
if annotation is None or get_origin(annotation) is None:
72+
return None
73+
origin = get_origin(annotation)
74+
if origin is Annotated:
75+
return get_literal_annotation_info(get_args(annotation)[0])
76+
if _is_union_type(origin):
77+
bases = get_args(annotation)
78+
if len(bases) > 2:
79+
raise ValueError("Cannot have a Union with more than 2 members")
80+
if bases[0] is not NoneType and bases[1] is not NoneType:
81+
raise ValueError("Cannot have a Union without None")
82+
use_type = bases[0] if bases[0] is not NoneType else bases[1]
83+
return get_literal_annotation_info(use_type)
84+
if origin is Literal:
85+
literal_args = get_args(annotation)
86+
if not literal_args:
87+
return None
88+
if all(isinstance(arg, bool) for arg in literal_args): # all bools
89+
base_type: type[Any] = bool
90+
elif all(isinstance(arg, int) for arg in literal_args): # all ints
91+
base_type = int
92+
else:
93+
base_type = str
94+
return base_type, tuple(literal_args)
95+
return None
96+
97+
6798
finish_init: ContextVar[bool] = ContextVar("finish_init", default=True)
6899

69100

@@ -189,6 +220,12 @@ def get_sa_type_from_type_annotation(annotation: Any) -> Any:
189220
# Optional unions are allowed
190221
use_type = bases[0] if bases[0] is not NoneType else bases[1]
191222
return get_sa_type_from_type_annotation(use_type)
223+
if origin is Literal:
224+
literal_info = get_literal_annotation_info(annotation)
225+
if literal_info is None:
226+
raise ValueError("Literal without values is not supported")
227+
base_type, _ = literal_info
228+
return base_type
192229
return origin
193230

194231

sqlmodel/main.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from pydantic.fields import FieldInfo as PydanticFieldInfo
2828
from sqlalchemy import (
2929
Boolean,
30+
CheckConstraint,
3031
Column,
3132
Date,
3233
DateTime,
@@ -63,6 +64,7 @@
6364
finish_init,
6465
get_annotations,
6566
get_field_metadata,
67+
get_literal_annotation_info,
6668
get_model_fields,
6769
get_relationship_to,
6870
get_sa_type_from_field,
@@ -678,6 +680,31 @@ def __init__(
678680
# Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
679681
# Tag: 1.4.36
680682
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
683+
table = getattr(cls, "__table__", None)
684+
if table is not None:
685+
# Attach Literal-based value constraints at the database level
686+
for field_name, field in get_model_fields(cls).items():
687+
annotation = getattr(field, "annotation", None)
688+
literal_info = get_literal_annotation_info(annotation)
689+
if literal_info is None:
690+
continue
691+
base_type, values = literal_info
692+
assert base_type in (str, int, bool)
693+
column = table.c.get(field_name)
694+
if column is None:
695+
continue
696+
if base_type is int:
697+
coerced_values = tuple(int(v) for v in values)
698+
elif base_type is bool:
699+
coerced_values = tuple(bool(v) for v in values)
700+
else:
701+
coerced_values = tuple(str(v) for v in values)
702+
constraint_name = f"ck_{table.name}_{field_name}_literal"
703+
constraint = CheckConstraint(
704+
column.in_(coerced_values),
705+
name=constraint_name,
706+
)
707+
table.append_constraint(constraint)
681708
else:
682709
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
683710

tests/test_main.py

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Annotated, Optional
1+
from typing import Annotated, Literal, Optional, Union
22

33
import pytest
4+
from sqlalchemy import text
45
from sqlalchemy.exc import IntegrityError
56
from sqlalchemy.orm import RelationshipProperty
67
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
@@ -216,3 +217,147 @@ class Hero(SQLModel, table=True):
216217
assert len(foreign_keys) == 1
217218
assert foreign_keys[0].ondelete == "CASCADE"
218219
assert team_id_column.nullable is False
220+
221+
222+
def test_literal_valid_values(clear_sqlmodel, caplog):
223+
"""Test https://github.com/fastapi/sqlmodel/issues/57"""
224+
225+
class Model(SQLModel, table=True):
226+
id: Optional[int] = Field(default=None, primary_key=True)
227+
all_str: Literal["a", "b", "c"]
228+
mixed: Literal["yes", "no", 1, 0]
229+
all_int: Literal[1, 2, 3]
230+
int_bool: Literal[0, 1, True, False]
231+
all_bool: Literal[True, False]
232+
233+
obj = Model(
234+
all_str="a",
235+
mixed="yes",
236+
all_int=1,
237+
int_bool=True,
238+
all_bool=False,
239+
)
240+
241+
engine = create_engine("sqlite://", echo=True)
242+
243+
SQLModel.metadata.create_all(engine)
244+
245+
# Check DDL
246+
assert "all_str VARCHAR NOT NULL" in caplog.text
247+
assert "mixed VARCHAR NOT NULL" in caplog.text
248+
assert "all_int INTEGER NOT NULL" in caplog.text
249+
assert "int_bool INTEGER NOT NULL" in caplog.text
250+
assert "all_bool BOOLEAN NOT NULL" in caplog.text
251+
252+
# Check query
253+
with Session(engine) as session:
254+
session.add(obj)
255+
session.commit()
256+
session.refresh(obj)
257+
assert isinstance(obj.all_str, str)
258+
assert obj.all_str == "a"
259+
assert isinstance(obj.mixed, str)
260+
assert obj.mixed == "yes"
261+
assert isinstance(obj.all_int, int)
262+
assert obj.all_int == 1
263+
assert isinstance(obj.int_bool, int)
264+
assert obj.int_bool == 1
265+
assert isinstance(obj.all_bool, bool)
266+
assert obj.all_bool is False
267+
268+
269+
def test_literal_constraints_invalid_values(clear_sqlmodel):
270+
"""DB should reject values that are not part of the Literal choices."""
271+
272+
class Model(SQLModel, table=True):
273+
id: Optional[int] = Field(default=None, primary_key=True)
274+
all_str: Literal["a", "b", "c"]
275+
mixed: Literal["yes", "no", 1, 0]
276+
all_int: Literal[1, 2, 3]
277+
int_bool: Literal[0, 1, True, False]
278+
all_bool: Literal[True, False]
279+
280+
engine = create_engine("sqlite://")
281+
SQLModel.metadata.create_all(engine)
282+
283+
# Helper to attempt a raw insert that bypasses Pydantic validation so we
284+
# can verify that the database-level CHECK constraints are enforced.
285+
def insert_raw(values: dict[str, object]) -> None:
286+
stmt = text(
287+
"INSERT INTO model (all_str, mixed, all_int, int_bool, all_bool) "
288+
"VALUES (:all_str, :mixed, :all_int, :int_bool, :all_bool)"
289+
).bindparams(**values)
290+
with pytest.raises(IntegrityError):
291+
with Session(engine) as session:
292+
session.exec(stmt)
293+
session.commit()
294+
295+
# Invalid string literal for all_str
296+
insert_raw(
297+
{
298+
"all_str": "z", # invalid, not in {"a","b","c"}
299+
"mixed": "yes",
300+
"all_int": 1,
301+
"int_bool": 1,
302+
"all_bool": 0,
303+
}
304+
)
305+
306+
# Invalid int literal for all_int
307+
insert_raw(
308+
{
309+
"all_str": "a",
310+
"mixed": "yes",
311+
"all_int": 5, # invalid, not in {1,2,3}
312+
"int_bool": 1,
313+
"all_bool": 0,
314+
}
315+
)
316+
317+
# Invalid bool literal for all_bool
318+
insert_raw(
319+
{
320+
"all_str": "a",
321+
"mixed": "yes",
322+
"all_int": 1,
323+
"int_bool": 1,
324+
"all_bool": 2, # invalid boolean value
325+
}
326+
)
327+
328+
329+
def test_literal_optional_and_union_constraints(clear_sqlmodel):
330+
"""Literals inside Optional/Union should also be enforced at the DB level."""
331+
332+
class Model(SQLModel, table=True):
333+
id: Optional[int] = Field(default=None, primary_key=True)
334+
opt_str: Optional[Literal["x", "y"]] = None
335+
union_int: Union[Literal[10, 20], None] = None
336+
337+
engine = create_engine("sqlite://")
338+
SQLModel.metadata.create_all(engine)
339+
340+
# Valid values should be accepted
341+
obj = Model(opt_str="x", union_int=10)
342+
with Session(engine) as session:
343+
session.add(obj)
344+
session.commit()
345+
session.refresh(obj)
346+
assert obj.opt_str == "x"
347+
assert obj.union_int == 10
348+
349+
# Invalid values should be rejected by the database
350+
def insert_raw(values: dict[str, object]) -> None:
351+
stmt = text(
352+
"INSERT INTO model (opt_str, union_int) VALUES (:opt_str, :union_int)"
353+
).bindparams(**values)
354+
with pytest.raises(IntegrityError):
355+
with Session(engine) as session:
356+
session.exec(stmt)
357+
session.commit()
358+
359+
# opt_str not in {"x", "y"}
360+
insert_raw({"opt_str": "z", "union_int": 10})
361+
362+
# union_int not in {10, 20}
363+
insert_raw({"opt_str": "x", "union_int": 30})

0 commit comments

Comments
 (0)