Skip to content

Commit 59144c2

Browse files
committed
Add strict param to Field, add tests
1 parent 5611bda commit 59144c2

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

sqlmodel/main.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import builtins
44
import ipaddress
55
import uuid
6+
import warnings
67
import weakref
78
from collections.abc import Mapping, Sequence, Set
89
from datetime import date, datetime, time, timedelta
@@ -228,6 +229,7 @@ def Field(
228229
max_length: Optional[int] = None,
229230
allow_mutation: bool = True,
230231
regex: Optional[str] = None,
232+
strict: Optional[bool] = None,
231233
discriminator: Optional[str] = None,
232234
repr: bool = True,
233235
primary_key: Union[bool, UndefinedType] = Undefined,
@@ -271,6 +273,7 @@ def Field(
271273
max_length: Optional[int] = None,
272274
allow_mutation: bool = True,
273275
regex: Optional[str] = None,
276+
strict: Optional[bool] = None,
274277
discriminator: Optional[str] = None,
275278
repr: bool = True,
276279
primary_key: Union[bool, UndefinedType] = Undefined,
@@ -323,6 +326,7 @@ def Field(
323326
max_length: Optional[int] = None,
324327
allow_mutation: bool = True,
325328
regex: Optional[str] = None,
329+
strict: Optional[bool] = None,
326330
discriminator: Optional[str] = None,
327331
repr: bool = True,
328332
sa_column: Union[Column[Any], UndefinedType] = Undefined,
@@ -356,6 +360,7 @@ def Field(
356360
max_length: Optional[int] = None,
357361
allow_mutation: bool = True,
358362
regex: Optional[str] = None,
363+
strict: Optional[bool] = None,
359364
discriminator: Optional[str] = None,
360365
repr: bool = True,
361366
primary_key: Union[bool, UndefinedType] = Undefined,
@@ -371,9 +376,16 @@ def Field(
371376
schema_extra: Optional[dict[str, Any]] = None,
372377
) -> Any:
373378
current_schema_extra = schema_extra or {}
379+
380+
for param_name in ("strict",):
381+
if param_name in current_schema_extra:
382+
msg = f"Pass `{param_name}` parameter directly to Field instead of passing it via `schema_extra`"
383+
warnings.warn(msg, DeprecationWarning, stacklevel=2)
384+
374385
# Extract possible alias settings from schema_extra so we can control precedence
375386
schema_validation_alias = current_schema_extra.pop("validation_alias", None)
376387
schema_serialization_alias = current_schema_extra.pop("serialization_alias", None)
388+
current_strict = strict or current_schema_extra.pop("strict", None)
377389
field_info_kwargs = {
378390
"alias": alias,
379391
"title": title,
@@ -395,6 +407,7 @@ def Field(
395407
"max_length": max_length,
396408
"allow_mutation": allow_mutation,
397409
"regex": regex,
410+
"strict": current_strict,
398411
"discriminator": discriminator,
399412
"repr": repr,
400413
"primary_key": primary_key,

tests/test_pydantic/test_field.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55
from pydantic import ValidationError
6-
from sqlmodel import Field, SQLModel
6+
from sqlmodel import Field, Session, SQLModel, create_engine
77

88

99
def test_decimal():
@@ -54,3 +54,49 @@ class Model(SQLModel):
5454

5555
instance = Model(id=123, foo="bar")
5656
assert "foo=" not in repr(instance)
57+
58+
59+
def test_strict():
60+
class Model(SQLModel):
61+
id: Optional[int] = Field(default=None, primary_key=True)
62+
val: int
63+
val_strict: int = Field(strict=True)
64+
65+
class ModelDB(Model, table=True):
66+
pass
67+
68+
Model(val=123, val_strict=456)
69+
Model(val="123", val_strict=456)
70+
71+
with pytest.raises(ValidationError):
72+
Model(val=123, val_strict="456")
73+
74+
engine = create_engine("sqlite://", echo=True)
75+
76+
SQLModel.metadata.create_all(engine)
77+
78+
model = ModelDB(val=123, val_strict=456)
79+
with Session(engine) as session:
80+
session.add(model)
81+
session.commit()
82+
session.refresh(model)
83+
84+
assert model.val == 123
85+
assert model.val_strict == 456
86+
87+
88+
def test_strict_via_schema_extra(): # Current workaround. Remove after some time
89+
with pytest.warns(
90+
DeprecationWarning,
91+
match="Pass `strict` parameter directly to Field instead of passing it via `schema_extra`",
92+
):
93+
94+
class Model(SQLModel):
95+
val: int
96+
val_strict: int = Field(schema_extra={"strict": True})
97+
98+
Model(val=123, val_strict=456)
99+
Model(val="123", val_strict=456)
100+
101+
with pytest.raises(ValidationError):
102+
Model(val=123, val_strict="456")

0 commit comments

Comments
 (0)