Skip to content

Commit 25232d1

Browse files
committed
Add validate_default param to Field, add tests
1 parent 51aefc7 commit 25232d1

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

sqlmodel/main.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def Field(
231231
allow_mutation: bool = True,
232232
regex: Optional[str] = None,
233233
discriminator: Optional[str] = None,
234+
validate_default: Optional[bool] = None,
234235
repr: bool = True,
235236
primary_key: Union[bool, UndefinedType] = Undefined,
236237
foreign_key: Any = Undefined,
@@ -275,6 +276,7 @@ def Field(
275276
allow_mutation: bool = True,
276277
regex: Optional[str] = None,
277278
discriminator: Optional[str] = None,
279+
validate_default: Optional[bool] = None,
278280
repr: bool = True,
279281
primary_key: Union[bool, UndefinedType] = Undefined,
280282
foreign_key: str,
@@ -328,6 +330,7 @@ def Field(
328330
allow_mutation: bool = True,
329331
regex: Optional[str] = None,
330332
discriminator: Optional[str] = None,
333+
validate_default: Optional[bool] = None,
331334
repr: bool = True,
332335
sa_column: Union[Column[Any], UndefinedType] = Undefined,
333336
schema_extra: Optional[dict[str, Any]] = None,
@@ -362,6 +365,7 @@ def Field(
362365
allow_mutation: bool = True,
363366
regex: Optional[str] = None,
364367
discriminator: Optional[str] = None,
368+
validate_default: Optional[bool] = None,
365369
repr: bool = True,
366370
primary_key: Union[bool, UndefinedType] = Undefined,
367371
foreign_key: Any = Undefined,
@@ -377,7 +381,10 @@ def Field(
377381
) -> Any:
378382
current_schema_extra = schema_extra or {}
379383

380-
for param_name in ("coerce_numbers_to_str",):
384+
for param_name in (
385+
"coerce_numbers_to_str",
386+
"validate_default",
387+
):
381388
if param_name in current_schema_extra:
382389
msg = f"Pass `{param_name}` parameter directly to Field instead of passing it via `schema_extra`"
383390
warnings.warn(msg, UserWarning, stacklevel=2)
@@ -388,6 +395,9 @@ def Field(
388395
current_coerce_numbers_to_str = coerce_numbers_to_str or current_schema_extra.pop(
389396
"coerce_numbers_to_str", None
390397
)
398+
current_validate_default = validate_default or current_schema_extra.pop(
399+
"validate_default", None
400+
)
391401
field_info_kwargs = {
392402
"alias": alias,
393403
"title": title,
@@ -396,6 +406,7 @@ def Field(
396406
"include": include,
397407
"const": const,
398408
"coerce_numbers_to_str": current_coerce_numbers_to_str,
409+
"validate_default": current_validate_default,
399410
"gt": gt,
400411
"ge": ge,
401412
"lt": lt,

tests/test_pydantic/test_field.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55
from pydantic import ValidationError
6-
from sqlmodel import Field, SQLModel
6+
from sqlmodel import Field, Session, SQLModel, create_engine
77

88

99
def test_decimal():
@@ -87,3 +87,60 @@ class Model(SQLModel):
8787

8888
assert Model.model_validate({"val": 123}).val == "123"
8989
assert Model.model_validate({"val": 45.67}).val == "45.67"
90+
91+
92+
def test_validate_default_true():
93+
class Model(SQLModel):
94+
val: int = Field(default="123", validate_default=True)
95+
96+
assert Model.model_validate({}).val == 123
97+
98+
class Model2(SQLModel):
99+
val: int = Field(default=None, validate_default=True)
100+
101+
with pytest.raises(ValidationError):
102+
Model2.model_validate({})
103+
104+
105+
def test_validate_default_table_model():
106+
class Model(SQLModel):
107+
id: Optional[int] = Field(default=None, primary_key=True)
108+
val: int = Field(default="123", validate_default=True)
109+
110+
class ModelDB(Model, table=True):
111+
pass
112+
113+
engine = create_engine("sqlite://", echo=True)
114+
115+
SQLModel.metadata.create_all(engine)
116+
117+
model = ModelDB()
118+
with Session(engine) as session:
119+
session.add(model)
120+
session.commit()
121+
session.refresh(model)
122+
123+
assert model.val == 123
124+
125+
126+
@pytest.mark.parametrize("validate_default", [None, False])
127+
def test_validate_default_false(validate_default: Optional[bool]):
128+
class Model3(SQLModel):
129+
val: int = Field(default="123", validate_default=validate_default)
130+
131+
assert Model3().val == "123"
132+
133+
134+
def test_validate_default_via_schema_extra(): # Current workaround. Remove after some time
135+
with pytest.warns(
136+
UserWarning,
137+
match=(
138+
"Pass `validate_default` parameter directly to Field instead of passing "
139+
"it via `schema_extra`"
140+
),
141+
):
142+
143+
class Model(SQLModel):
144+
val: int = Field(default="123", schema_extra={"validate_default": True})
145+
146+
assert Model.model_validate({}).val == 123

0 commit comments

Comments
 (0)