55import uuid
66import weakref
77from collections .abc import Mapping , Sequence , Set
8+ from dataclasses import dataclass
89from datetime import date , datetime , time , timedelta
910from decimal import Decimal
1011from enum import Enum
@@ -200,6 +201,38 @@ def __init__(
200201 self .sa_relationship_kwargs = sa_relationship_kwargs
201202
202203
204+ @dataclass
205+ class FieldInfoMetadata :
206+ primary_key : Union [bool , UndefinedType ] = Undefined
207+ nullable : Union [bool , UndefinedType ] = Undefined
208+ foreign_key : Any = Undefined
209+ ondelete : Union [OnDeleteType , UndefinedType ] = Undefined
210+ unique : Union [bool , UndefinedType ] = Undefined
211+ index : Union [bool , UndefinedType ] = Undefined
212+ sa_type : Union [type [Any ], UndefinedType ] = Undefined
213+ sa_column : Union [Column [Any ], UndefinedType ] = Undefined
214+ sa_column_args : Union [Sequence [Any ], UndefinedType ] = Undefined
215+ sa_column_kwargs : Union [Mapping [str , Any ], UndefinedType ] = Undefined
216+
217+
218+ def _get_sqlmodel_field_metadata (field_info : Any ) -> Optional [FieldInfoMetadata ]:
219+ metadata_items = getattr (field_info , "metadata" , None )
220+ if metadata_items :
221+ for meta in metadata_items :
222+ if isinstance (meta , FieldInfoMetadata ):
223+ return meta
224+ return None
225+
226+
227+ def _get_sqlmodel_field_value (
228+ field_info : Any , attribute : str , default : Any = Undefined
229+ ) -> Any :
230+ metadata = _get_sqlmodel_field_metadata (field_info )
231+ if metadata is not None and hasattr (metadata , attribute ):
232+ return getattr (metadata , attribute )
233+ return getattr (field_info , attribute , default )
234+
235+
203236# include sa_type, sa_column_args, sa_column_kwargs
204237@overload
205238def Field (
@@ -423,6 +456,20 @@ def Field(
423456 default_factory = default_factory ,
424457 ** field_info_kwargs ,
425458 )
459+ field_metadata = FieldInfoMetadata (
460+ primary_key = primary_key ,
461+ nullable = nullable ,
462+ foreign_key = foreign_key ,
463+ ondelete = ondelete ,
464+ unique = unique ,
465+ index = index ,
466+ sa_type = sa_type ,
467+ sa_column = sa_column ,
468+ sa_column_args = sa_column_args ,
469+ sa_column_kwargs = sa_column_kwargs ,
470+ )
471+ if hasattr (field_info , "metadata" ):
472+ field_info .metadata .append (field_metadata )
426473 return field_info
427474
428475
@@ -637,7 +684,7 @@ def __init__(
637684
638685def get_sqlalchemy_type (field : Any ) -> Any :
639686 field_info = field
640- sa_type = getattr (field_info , "sa_type" , Undefined ) # noqa: B009
687+ sa_type = _get_sqlmodel_field_value (field_info , "sa_type" , Undefined ) # noqa: B009
641688 if sa_type is not Undefined :
642689 return sa_type
643690
@@ -691,39 +738,39 @@ def get_sqlalchemy_type(field: Any) -> Any:
691738
692739def get_column_from_field (field : Any ) -> Column : # type: ignore
693740 field_info = field
694- sa_column = getattr (field_info , "sa_column" , Undefined )
741+ sa_column = _get_sqlmodel_field_value (field_info , "sa_column" , Undefined )
695742 if isinstance (sa_column , Column ):
696743 return sa_column
697744 sa_type = get_sqlalchemy_type (field )
698- primary_key = getattr (field_info , "primary_key" , Undefined )
745+ primary_key = _get_sqlmodel_field_value (field_info , "primary_key" , Undefined )
699746 if primary_key is Undefined :
700747 primary_key = False
701- index = getattr (field_info , "index" , Undefined )
748+ index = _get_sqlmodel_field_value (field_info , "index" , Undefined )
702749 if index is Undefined :
703750 index = False
704751 nullable = not primary_key and is_field_noneable (field )
705752 # Override derived nullability if the nullable property is set explicitly
706753 # on the field
707- field_nullable = getattr (field_info , "nullable" , Undefined ) # noqa: B009
754+ field_nullable = _get_sqlmodel_field_value (field_info , "nullable" , Undefined )
708755 if field_nullable is not Undefined :
709756 assert not isinstance (field_nullable , UndefinedType )
710757 nullable = field_nullable
711758 args = []
712- foreign_key = getattr (field_info , "foreign_key" , Undefined )
759+ foreign_key = _get_sqlmodel_field_value (field_info , "foreign_key" , Undefined )
713760 if foreign_key is Undefined :
714761 foreign_key = None
715- unique = getattr (field_info , "unique" , Undefined )
762+ unique = _get_sqlmodel_field_value (field_info , "unique" , Undefined )
716763 if unique is Undefined :
717764 unique = False
718765 if foreign_key :
719- if field_info .ondelete == "SET NULL" and not nullable :
766+ ondelete_value = _get_sqlmodel_field_value (field_info , "ondelete" , Undefined )
767+ if ondelete_value is Undefined :
768+ ondelete_value = None
769+ if ondelete_value == "SET NULL" and not nullable :
720770 raise RuntimeError ('ondelete="SET NULL" requires nullable=True' )
721771 assert isinstance (foreign_key , str )
722- ondelete = getattr (field_info , "ondelete" , Undefined )
723- if ondelete is Undefined :
724- ondelete = None
725- assert isinstance (ondelete , (str , type (None ))) # for typing
726- args .append (ForeignKey (foreign_key , ondelete = ondelete ))
772+ assert isinstance (ondelete_value , (str , type (None ))) # for typing
773+ args .append (ForeignKey (foreign_key , ondelete = ondelete_value ))
727774 kwargs = {
728775 "primary_key" : primary_key ,
729776 "nullable" : nullable ,
@@ -737,10 +784,12 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
737784 sa_default = field_info .default
738785 if sa_default is not Undefined :
739786 kwargs ["default" ] = sa_default
740- sa_column_args = getattr (field_info , "sa_column_args" , Undefined )
787+ sa_column_args = _get_sqlmodel_field_value (field_info , "sa_column_args" , Undefined )
741788 if sa_column_args is not Undefined :
742789 args .extend (list (cast (Sequence [Any ], sa_column_args )))
743- sa_column_kwargs = getattr (field_info , "sa_column_kwargs" , Undefined )
790+ sa_column_kwargs = _get_sqlmodel_field_value (
791+ field_info , "sa_column_kwargs" , Undefined
792+ )
744793 if sa_column_kwargs is not Undefined :
745794 kwargs .update (cast (dict [Any , Any ], sa_column_kwargs ))
746795 return Column (sa_type , * args , ** kwargs ) # type: ignore
0 commit comments