1616from typing import Optional
1717from typing import Sequence
1818from typing import Tuple
19+ from typing import Type
1920from typing import TypeVar
2021from typing import Union
2122
2728
2829try :
2930 import pydantic
31+ import pydantic_core
3032 has_pydantic = True
3133except ImportError :
3234 has_pydantic = False
@@ -46,6 +48,9 @@ def is_union(x: Any) -> bool:
4648 return typing .get_origin (x ) in _UNION_TYPES
4749
4850
51+ NO_DEFAULT = object ()
52+
53+
4954array_types : Tuple [Any , ...]
5055
5156if has_numpy :
@@ -608,7 +613,10 @@ def collapse_dtypes(dtypes: Union[str, List[str]]) -> str:
608613 return dtypes [0 ] + ('?' if is_nullable else '' )
609614
610615
611- def get_dataclass_schema (obj : Any ) -> List [Tuple [str , Any ]]:
616+ def get_dataclass_schema (
617+ obj : Any ,
618+ include_default : bool = False ,
619+ ) -> List [Union [Tuple [str , Any ], Tuple [str , Any , Any ]]]:
612620 """
613621 Get the schema of a dataclass.
614622
@@ -619,70 +627,170 @@ def get_dataclass_schema(obj: Any) -> List[Tuple[str, Any]]:
619627
620628 Returns
621629 -------
622- List[Tuple[str, Any]]
630+ List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
623631 A list of tuples containing the field names and field types
624632
625633 """
626- return list (get_annotations (obj ).items ())
634+ if include_default :
635+ return [
636+ (
637+ f .name , f .type ,
638+ NO_DEFAULT if f .default is dataclasses .MISSING else f .default ,
639+ )
640+ for f in dataclasses .fields (obj )
641+ ]
642+ return [(f .name , f .type ) for f in dataclasses .fields (obj )]
627643
628644
629- def get_typeddict_schema (obj : Any ) -> List [Tuple [str , Any ]]:
645+ def get_typeddict_schema (
646+ obj : Any ,
647+ include_default : bool = False ,
648+ ) -> List [Union [Tuple [str , Any ], Tuple [str , Any , Any ]]]:
630649 """
631650 Get the schema of a TypedDict.
632651
633652 Parameters
634653 ----------
635654 obj : TypedDict
636655 The TypedDict to get the schema of
656+ include_default : bool, optional
657+ Whether to include the default value in the column specification
637658
638659 Returns
639660 -------
640- List[Tuple[str, Any]]
661+ List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
641662 A list of tuples containing the field names and field types
642663
643664 """
665+ if include_default :
666+ return [
667+ (k , v , getattr (obj , 'k' , NO_DEFAULT ))
668+ for k , v in get_annotations (obj ).items ()
669+ ]
644670 return list (get_annotations (obj ).items ())
645671
646672
647- def get_pydantic_schema (obj : pydantic .BaseModel ) -> List [Tuple [str , Any ]]:
673+ def get_pydantic_schema (
674+ obj : pydantic .BaseModel ,
675+ include_default : bool = False ,
676+ ) -> List [Union [Tuple [str , Any ], Tuple [str , Any , Any ]]]:
648677 """
649678 Get the schema of a pydantic model.
650679
651680 Parameters
652681 ----------
653682 obj : pydantic.BaseModel
654683 The pydantic model to get the schema of
684+ include_default : bool, optional
685+ Whether to include the default value in the column specification
655686
656687 Returns
657688 -------
658- List[Tuple[str, Any]]
689+ List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
659690 A list of tuples containing the field names and field types
660691
661692 """
693+ if include_default :
694+ return [
695+ (
696+ k , v .annotation ,
697+ NO_DEFAULT if v .default is pydantic_core .PydanticUndefined else v .default ,
698+ )
699+ for k , v in obj .model_fields .items ()
700+ ]
662701 return [(k , v .annotation ) for k , v in obj .model_fields .items ()]
663702
664703
665- def get_namedtuple_schema (obj : Any ) -> List [Tuple [Any , str ]]:
704+ def get_namedtuple_schema (
705+ obj : Any ,
706+ include_default : bool = False ,
707+ ) -> List [Union [Tuple [Any , str ], Tuple [Any , str , Any ]]]:
666708 """
667709 Get the schema of a named tuple.
668710
669711 Parameters
670712 ----------
671713 obj : NamedTuple
672714 The named tuple to get the schema of
715+ include_default : bool, optional
716+ Whether to include the default value in the column specification
673717
674718 Returns
675719 -------
676- List[Tuple[Any, str]]
720+ List[Tuple[Any, str]] | List[Tuple[Any, str, Any]]
677721 A list of tuples containing the field names and field types
678722
679723 """
724+ if include_default :
725+ return [
726+ (
727+ k ,
728+ v ,
729+ obj ._field_defaults .get (k , NO_DEFAULT ),
730+ )
731+ for k , v in get_annotations (obj ).items ()
732+ ]
680733 return list (get_annotations (obj ).items ())
681734
682735
736+ def get_colspec (
737+ overrides : Any ,
738+ include_default : bool = False ,
739+ ) -> List [Union [Tuple [str , Any ], Tuple [str , Any , Any ]]]:
740+ """
741+ Get the column specification from the overrides.
742+
743+ Parameters
744+ ----------
745+ overrides : Any
746+ The overrides to get the column specification from
747+ include_default : bool, optional
748+ Whether to include the default value in the column specification
749+
750+ Returns
751+ -------
752+ List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
753+ A list of tuples containing the field names and field types
754+
755+ """
756+ overrides_colspec = []
757+ if overrides :
758+ if dataclasses .is_dataclass (overrides ):
759+ overrides_colspec = get_dataclass_schema (
760+ overrides , include_default = include_default ,
761+ )
762+ elif is_typeddict (overrides ):
763+ overrides_colspec = get_typeddict_schema (
764+ overrides , include_default = include_default ,
765+ )
766+ elif is_namedtuple (overrides ):
767+ overrides_colspec = get_namedtuple_schema (
768+ overrides , include_default = include_default ,
769+ )
770+ elif is_pydantic (overrides ):
771+ overrides_colspec = get_pydantic_schema (
772+ overrides , include_default = include_default ,
773+ )
774+ elif isinstance (overrides , list ):
775+ if include_default :
776+ overrides_colspec = [
777+ (getattr (x , 'name' , '' ), x , NO_DEFAULT ) for x in overrides
778+ ]
779+ else :
780+ overrides_colspec = [(getattr (x , 'name' , '' ), x ) for x in overrides ]
781+ else :
782+ if include_default :
783+ overrides_colspec = [
784+ (getattr (overrides , 'name' , '' ), overrides , NO_DEFAULT ),
785+ ]
786+ else :
787+ overrides_colspec = [(getattr (overrides , 'name' , '' ), overrides )]
788+ return overrides_colspec
789+
790+
683791def get_schema (
684792 spec : Any ,
685- overrides : Optional [List [str ]] = None ,
793+ overrides : Optional [Union [ List [str ], Type [ Any ] ]] = None ,
686794 function_type : str = 'udf' ,
687795 mode : str = 'parameter' ,
688796) -> Tuple [List [Tuple [str , Any , Optional [str ]]], str ]:
@@ -748,18 +856,7 @@ def get_schema(
748856 #
749857
750858 # Compute overrides colspec from various formats
751- overrides_colspec = []
752- if overrides :
753- if dataclasses .is_dataclass (overrides ):
754- overrides_colspec = get_dataclass_schema (overrides )
755- elif is_typeddict (overrides ):
756- overrides_colspec = get_typeddict_schema (overrides )
757- elif is_namedtuple (overrides ):
758- overrides_colspec = get_namedtuple_schema (overrides )
759- elif is_pydantic (overrides ):
760- overrides_colspec = get_pydantic_schema (overrides )
761- else :
762- overrides_colspec = [(getattr (x , 'name' , '' ), x ) for x in overrides ]
859+ overrides_colspec = get_colspec (overrides )
763860
764861 # Numpy array types
765862 if is_numpy (spec ):
@@ -878,7 +975,7 @@ def get_schema(
878975 # Normalize colspec data types
879976 out = []
880977
881- for k , v in colspec :
978+ for k , v , * _ in colspec :
882979 out .append ((
883980 k ,
884981 collapse_dtypes ([normalize_dtype (x ) for x in simplify_dtype (v )]),
@@ -953,10 +1050,13 @@ def get_signature(
9531050 # Generate the parameter type and the corresponding SQL code for that parameter
9541051 args_schema = []
9551052 args_data_formats = []
956- for param in signature .parameters .values ():
1053+ args_colspec = [x for x in get_colspec (attrs .get ('args' , []), include_default = True )]
1054+ args_overrides = [x [1 ] for x in args_colspec ]
1055+ args_defaults = [x [2 ] for x in args_colspec ] # type: ignore
1056+ for i , param in enumerate (signature .parameters .values ()):
9571057 arg_schema , args_data_format = get_schema (
9581058 param .annotation ,
959- overrides = attrs . get ( 'args' , None ) ,
1059+ overrides = args_overrides [ i ] if args_overrides else [] ,
9601060 function_type = function_type ,
9611061 mode = 'parameter' ,
9621062 )
@@ -967,12 +1067,24 @@ def get_signature(
9671067 args_schema .append ((param .name , * arg_schema [0 ][1 :]))
9681068
9691069 for i , (name , atype , sql ) in enumerate (args_schema ):
1070+ # Get default value
1071+ default_option = {}
1072+ if args_defaults :
1073+ if args_defaults [i ] is not NO_DEFAULT :
1074+ default_option ['default' ] = args_defaults [i ]
1075+ else :
1076+ if param .default is not param .empty :
1077+ default_option ['default' ] = param .default
1078+
1079+ # Generate SQL code for the parameter
9701080 sql = sql or dtype_to_sql (
9711081 atype ,
9721082 function_type = function_type ,
973- default = param . default if param . default is not param . empty else None ,
1083+ ** default_option ,
9741084 )
975- args .append (dict (name = name , dtype = atype , sql = sql ))
1085+
1086+ # Add parameter to args definitions
1087+ args .append (dict (name = name , dtype = atype , sql = sql , ** default_option ))
9761088
9771089 # Check that all the data formats are all the same
9781090 if len (set (args_data_formats )) > 1 :
@@ -1059,7 +1171,7 @@ def sql_to_dtype(sql: str) -> str:
10591171
10601172def dtype_to_sql (
10611173 dtype : str ,
1062- default : Any = None ,
1174+ default : Any = NO_DEFAULT ,
10631175 field_names : Optional [List [str ]] = None ,
10641176 function_type : str = 'udf' ,
10651177) -> str :
@@ -1092,7 +1204,7 @@ def dtype_to_sql(
10921204 nullable = ''
10931205
10941206 default_clause = ''
1095- if default is not None :
1207+ if default is not NO_DEFAULT :
10961208 if default is dt .NULL :
10971209 default = None
10981210 default_clause = f' DEFAULT { escape_item (default , "utf8" )} '
0 commit comments