Skip to content

Commit d0e4081

Browse files
✨ Add model_fields_optional to convert inherited fields to Optional
Add support for `model_fields_optional="all"` configuration that makes all inherited fields Optional with a default of None. This enables the common pattern of creating "update" models where all fields are optional, reducing boilerplate when building CRUD APIs. Fixes #64
1 parent 66c2d82 commit d0e4081

3 files changed

Lines changed: 291 additions & 0 deletions

File tree

sqlmodel/_compat.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def partial_init() -> Generator[None, None, None]:
7878
class SQLModelConfig(BaseConfig, total=False):
7979
table: bool | None
8080
registry: Any | None
81+
model_fields_optional: str | None
8182

8283

8384
def get_model_fields(model: InstanceOrType[BaseModel]) -> dict[str, "FieldInfo"]:

sqlmodel/main.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
TypeVar,
2020
Union,
2121
cast,
22+
get_args,
2223
get_origin,
2324
overload,
2425
)
@@ -56,10 +57,12 @@
5657
PYDANTIC_MINOR_VERSION,
5758
BaseConfig,
5859
ModelMetaclass,
60+
NoneType,
5961
Representation,
6062
SQLModelConfig,
6163
Undefined,
6264
UndefinedType,
65+
_is_union_type,
6366
finish_init,
6467
get_annotations,
6568
get_field_metadata,
@@ -91,6 +94,17 @@
9194
OnDeleteType = Literal["CASCADE", "SET NULL", "RESTRICT"]
9295

9396

97+
def _is_optional_annotation(annotation: Any) -> bool:
98+
"""Check if a type annotation is already Optional (i.e., Union[X, None])."""
99+
origin = get_origin(annotation)
100+
if origin is not None and _is_union_type(origin):
101+
args = get_args(annotation)
102+
return NoneType in args
103+
if annotation is NoneType:
104+
return True
105+
return False
106+
107+
94108
def __dataclass_transform__(
95109
*,
96110
eq_default: bool = True,
@@ -560,6 +574,39 @@ def __new__(
560574
relationship_annotations[k] = v
561575
else:
562576
pydantic_annotations[k] = v
577+
578+
# Handle model_fields_optional: make all inherited fields Optional
579+
# with a default of None
580+
model_fields_optional = kwargs.pop("model_fields_optional", None)
581+
if model_fields_optional is None:
582+
# Also check model_config in class_dict
583+
config_dict = class_dict.get("model_config", {})
584+
if isinstance(config_dict, dict):
585+
model_fields_optional = config_dict.get(
586+
"model_fields_optional", None
587+
)
588+
if model_fields_optional == "all":
589+
for base in bases:
590+
base_fields = get_model_fields(base) if hasattr(base, "model_fields") else {}
591+
for field_name, field_info in base_fields.items():
592+
# Only modify fields not explicitly redefined in this class
593+
if field_name not in pydantic_annotations:
594+
ann = field_info.annotation
595+
# Only wrap in Optional if not already Optional
596+
if ann is not None and not _is_optional_annotation(ann):
597+
pydantic_annotations[field_name] = Union[ann, None]
598+
else:
599+
pydantic_annotations[field_name] = ann
600+
# Set default to None if the field was required and
601+
# not already defined in the current class
602+
if field_name not in dict_for_pydantic:
603+
# Copy the FieldInfo to preserve metadata like
604+
# min_length, ge, etc.
605+
new_field_info = field_info._copy()
606+
if new_field_info.is_required():
607+
new_field_info.default = None
608+
dict_for_pydantic[field_name] = new_field_info
609+
563610
dict_used = {
564611
**dict_for_pydantic,
565612
"__weakref__": None,
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
from typing import Optional
2+
3+
import pytest
4+
from pydantic import ValidationError
5+
from sqlmodel import Field, SQLModel
6+
from sqlmodel._compat import SQLModelConfig
7+
8+
9+
def test_model_fields_optional_basic(clear_sqlmodel):
10+
"""Test that model_fields_optional='all' makes all inherited fields Optional
11+
with a default of None."""
12+
13+
class HeroBase(SQLModel):
14+
name: str
15+
secret_name: str
16+
age: Optional[int] = None
17+
18+
class HeroUpdate(HeroBase, model_fields_optional="all"):
19+
pass
20+
21+
# All fields should be optional (not required)
22+
for field_info in HeroUpdate.model_fields.values():
23+
assert not field_info.is_required()
24+
25+
# Should be able to create with no arguments
26+
hero = HeroUpdate()
27+
assert hero.name is None
28+
assert hero.secret_name is None
29+
assert hero.age is None
30+
31+
32+
def test_model_fields_optional_partial_data(clear_sqlmodel):
33+
"""Test creating an instance with only some fields set."""
34+
35+
class HeroBase(SQLModel):
36+
name: str
37+
secret_name: str
38+
age: Optional[int] = None
39+
40+
class HeroUpdate(HeroBase, model_fields_optional="all"):
41+
pass
42+
43+
hero = HeroUpdate(name="Spider-Man")
44+
assert hero.name == "Spider-Man"
45+
assert hero.secret_name is None
46+
assert hero.age is None
47+
48+
49+
def test_model_fields_optional_exclude_unset(clear_sqlmodel):
50+
"""Test that model_dump(exclude_unset=True) only includes explicitly set
51+
fields."""
52+
53+
class HeroBase(SQLModel):
54+
name: str
55+
secret_name: str
56+
age: Optional[int] = None
57+
58+
class HeroUpdate(HeroBase, model_fields_optional="all"):
59+
pass
60+
61+
hero = HeroUpdate(name="Spider-Man")
62+
dumped = hero.model_dump(exclude_unset=True)
63+
assert dumped == {"name": "Spider-Man"}
64+
65+
66+
def test_model_fields_optional_override_field(clear_sqlmodel):
67+
"""Test that explicitly redefined fields in the child class are not
68+
overridden by model_fields_optional."""
69+
70+
class HeroBase(SQLModel):
71+
name: str
72+
secret_name: str
73+
age: Optional[int] = None
74+
75+
class HeroUpdate(HeroBase, model_fields_optional="all"):
76+
name: str # Keep name required
77+
78+
# name should still be required
79+
assert HeroUpdate.model_fields["name"].is_required()
80+
# Other fields should be optional
81+
assert not HeroUpdate.model_fields["secret_name"].is_required()
82+
assert not HeroUpdate.model_fields["age"].is_required()
83+
84+
with pytest.raises(ValidationError):
85+
HeroUpdate() # name is required
86+
87+
hero = HeroUpdate(name="Batman")
88+
assert hero.name == "Batman"
89+
assert hero.secret_name is None
90+
91+
92+
def test_model_fields_optional_preserves_constraints(clear_sqlmodel):
93+
"""Test that field constraints (min_length, ge, etc.) are preserved when
94+
making fields optional."""
95+
96+
class HeroBase(SQLModel):
97+
name: str = Field(min_length=1)
98+
age: Optional[int] = Field(default=None, ge=0)
99+
100+
class HeroUpdate(HeroBase, model_fields_optional="all"):
101+
pass
102+
103+
# None should be valid for all fields
104+
hero = HeroUpdate(name=None, age=None)
105+
assert hero.name is None
106+
assert hero.age is None
107+
108+
# Non-None values should still be validated
109+
with pytest.raises(ValidationError):
110+
HeroUpdate(name="") # min_length=1 violated
111+
112+
with pytest.raises(ValidationError):
113+
HeroUpdate(age=-1) # ge=0 violated
114+
115+
# Valid non-None values should work
116+
hero = HeroUpdate(name="X", age=5)
117+
assert hero.name == "X"
118+
assert hero.age == 5
119+
120+
121+
def test_model_fields_optional_multiple_inheritance(clear_sqlmodel):
122+
"""Test model_fields_optional with multiple levels of inheritance."""
123+
124+
class PersonBase(SQLModel):
125+
first_name: str
126+
last_name: str
127+
128+
class EmployeeBase(PersonBase):
129+
employee_id: int
130+
department: str
131+
132+
class EmployeeUpdate(EmployeeBase, model_fields_optional="all"):
133+
pass
134+
135+
# All fields from all base classes should be optional
136+
for field_info in EmployeeUpdate.model_fields.values():
137+
assert not field_info.is_required()
138+
139+
employee = EmployeeUpdate(department="Engineering")
140+
assert employee.department == "Engineering"
141+
assert employee.first_name is None
142+
assert employee.last_name is None
143+
assert employee.employee_id is None
144+
145+
146+
def test_model_fields_optional_via_model_config(clear_sqlmodel):
147+
"""Test model_fields_optional via model_config dict."""
148+
149+
class HeroBase(SQLModel):
150+
name: str
151+
secret_name: str
152+
age: Optional[int] = None
153+
154+
class HeroUpdate(HeroBase):
155+
model_config = SQLModelConfig(model_fields_optional="all")
156+
157+
# All fields should be optional
158+
for field_info in HeroUpdate.model_fields.values():
159+
assert not field_info.is_required()
160+
161+
hero = HeroUpdate()
162+
assert hero.name is None
163+
assert hero.secret_name is None
164+
assert hero.age is None
165+
166+
167+
def test_model_fields_optional_with_table_base(clear_sqlmodel):
168+
"""Test that model_fields_optional works alongside table models."""
169+
170+
class HeroBase(SQLModel):
171+
name: str
172+
secret_name: str
173+
age: Optional[int] = None
174+
175+
class Hero(HeroBase, table=True):
176+
id: Optional[int] = Field(default=None, primary_key=True)
177+
178+
class HeroUpdate(HeroBase, model_fields_optional="all"):
179+
pass
180+
181+
# Table model should still work normally
182+
hero = Hero(name="Batman", secret_name="Bruce Wayne")
183+
assert hero.name == "Batman"
184+
185+
# Update model should have all optional fields
186+
update = HeroUpdate(name="Dark Knight")
187+
assert update.name == "Dark Knight"
188+
assert update.secret_name is None
189+
190+
191+
def test_model_fields_optional_already_optional_fields(clear_sqlmodel):
192+
"""Test that already-optional fields remain optional and keep their
193+
defaults."""
194+
195+
class HeroBase(SQLModel):
196+
name: str
197+
nickname: Optional[str] = "Unknown"
198+
age: Optional[int] = None
199+
200+
class HeroUpdate(HeroBase, model_fields_optional="all"):
201+
pass
202+
203+
hero = HeroUpdate()
204+
# name was required, should now be None
205+
assert hero.name is None
206+
# nickname had a default of "Unknown", should keep it
207+
assert hero.nickname == "Unknown"
208+
# age had a default of None, should stay None
209+
assert hero.age is None
210+
211+
212+
def test_model_fields_optional_model_validate(clear_sqlmodel):
213+
"""Test that model_validate works correctly with model_fields_optional."""
214+
215+
class HeroBase(SQLModel):
216+
name: str
217+
secret_name: str
218+
age: Optional[int] = None
219+
220+
class HeroUpdate(HeroBase, model_fields_optional="all"):
221+
pass
222+
223+
hero = HeroUpdate.model_validate({"name": "Spider-Man"})
224+
assert hero.name == "Spider-Man"
225+
assert hero.secret_name is None
226+
227+
hero2 = HeroUpdate.model_validate({})
228+
assert hero2.name is None
229+
230+
231+
def test_model_fields_optional_json_schema(clear_sqlmodel):
232+
"""Test that JSON schema reflects optional fields."""
233+
234+
class HeroBase(SQLModel):
235+
name: str
236+
secret_name: str
237+
238+
class HeroUpdate(HeroBase, model_fields_optional="all"):
239+
pass
240+
241+
schema = HeroUpdate.model_json_schema()
242+
# No fields should be required in the schema
243+
assert "required" not in schema or len(schema.get("required", [])) == 0

0 commit comments

Comments
 (0)