Skip to content

Commit 2148a1f

Browse files
committed
Add default values from pydantic / namedtuple / etc
1 parent 114c73a commit 2148a1f

File tree

1 file changed

+141
-29
lines changed

1 file changed

+141
-29
lines changed

singlestoredb/functions/signature.py

Lines changed: 141 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import Optional
1717
from typing import Sequence
1818
from typing import Tuple
19+
from typing import Type
1920
from typing import TypeVar
2021
from typing import Union
2122

@@ -27,6 +28,7 @@
2728

2829
try:
2930
import pydantic
31+
import pydantic_core
3032
has_pydantic = True
3133
except ImportError:
3234
has_pydantic = False
@@ -46,6 +48,9 @@ def is_union(x: Any) -> bool:
4648
return typing.get_origin(x) in _UNION_TYPES
4749

4850

51+
NO_DEFAULT = object()
52+
53+
4954
array_types: Tuple[Any, ...]
5055

5156
if has_numpy:
@@ -608,7 +613,10 @@ def collapse_dtypes(dtypes: Union[str, List[str]]) -> str:
608613
return dtypes[0] + ('?' if is_nullable else '')
609614

610615

def get_dataclass_schema(
    obj: Any,
    include_default: bool = False,
) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
    """
    Get the schema of a dataclass.

    Parameters
    ----------
    obj : dataclass
        The dataclass to get the schema of
    include_default : bool, optional
        Whether to include the default value in the column specification

    Returns
    -------
    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
        A list of tuples containing the field names and field types.
        When ``include_default`` is true, each tuple carries a third
        item: the field's ``default``, or ``NO_DEFAULT`` when that
        attribute is ``dataclasses.MISSING``.

    """
    columns: List[Union[Tuple[str, Any], Tuple[str, Any, Any]]] = []
    for field in dataclasses.fields(obj):
        if not include_default:
            columns.append((field.name, field.type))
            continue
        # Translate the dataclasses "no default" marker into this
        # module's own sentinel so callers only test for one value.
        default = field.default
        if default is dataclasses.MISSING:
            default = NO_DEFAULT
        columns.append((field.name, field.type, default))
    return columns
627643

628644

def get_typeddict_schema(
    obj: Any,
    include_default: bool = False,
) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
    """
    Get the schema of a TypedDict.

    Parameters
    ----------
    obj : TypedDict
        The TypedDict to get the schema of
    include_default : bool, optional
        Whether to include the default value in the column specification

    Returns
    -------
    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
        A list of tuples containing the field names and field types.
        When ``include_default`` is true, each tuple carries a third
        item: the value assigned to that field name in the class body,
        or ``NO_DEFAULT`` when there is none.

    """
    annotations = get_annotations(obj)
    if not include_default:
        return list(annotations.items())

    # Defaults are looked up per field name in the class's own namespace.
    # Going through ``getattr`` would also find inherited ``dict``
    # attributes, so a field called ``items`` or ``keys`` would report a
    # dict method as its default.
    namespace = vars(obj)
    return [
        (name, dtype, namespace.get(name, NO_DEFAULT))
        for name, dtype in annotations.items()
    ]
645671

646672

def get_pydantic_schema(
    obj: pydantic.BaseModel,
    include_default: bool = False,
) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
    """
    Get the schema of a pydantic model.

    Parameters
    ----------
    obj : pydantic.BaseModel
        The pydantic model to get the schema of
    include_default : bool, optional
        Whether to include the default value in the column specification

    Returns
    -------
    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
        A list of tuples containing the field names and field types.
        When ``include_default`` is true, each tuple carries a third
        item: the field's ``default``, or ``NO_DEFAULT`` when that
        attribute is ``pydantic_core.PydanticUndefined``.

    """
    if not include_default:
        return [
            (name, info.annotation)
            for name, info in obj.model_fields.items()
        ]

    schema: List[Union[Tuple[str, Any], Tuple[str, Any, Any]]] = []
    for name, info in obj.model_fields.items():
        # Translate pydantic's "undefined" marker into this module's
        # own sentinel so callers only test for one value.
        if info.default is pydantic_core.PydanticUndefined:
            default = NO_DEFAULT
        else:
            default = info.default
        schema.append((name, info.annotation, default))
    return schema
663702

664703

def get_namedtuple_schema(
    obj: Any,
    include_default: bool = False,
) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
    """
    Get the schema of a named tuple.

    Parameters
    ----------
    obj : NamedTuple
        The named tuple to get the schema of
    include_default : bool, optional
        Whether to include the default value in the column specification

    Returns
    -------
    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
        A list of tuples containing the field names and field types.
        When ``include_default`` is true, each tuple carries a third
        item: the entry for that field in ``obj._field_defaults``, or
        ``NO_DEFAULT`` when the field has no entry there.

    """
    annotations = get_annotations(obj)
    if not include_default:
        return list(annotations.items())

    defaults = obj._field_defaults
    return [
        (name, dtype, defaults.get(name, NO_DEFAULT))
        for name, dtype in annotations.items()
    ]
681734

682735

def get_colspec(
    overrides: Any,
    include_default: bool = False,
) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]:
    """
    Get the column specification from the overrides.

    Parameters
    ----------
    overrides : Any
        The overrides to get the column specification from. Dataclasses,
        TypedDicts, named tuples and pydantic models are expanded into
        one entry per field; a list yields one entry per element; any
        other truthy value yields a single entry.
    include_default : bool, optional
        Whether to include the default value in the column specification

    Returns
    -------
    List[Tuple[str, Any]] | List[Tuple[str, Any, Any]]
        A list of tuples containing the field names and field types,
        plus the default value when ``include_default`` is true. Falsy
        overrides give an empty list.

    """
    if not overrides:
        return []

    # Structured types describe their own fields (and defaults)
    if dataclasses.is_dataclass(overrides):
        return get_dataclass_schema(overrides, include_default=include_default)
    if is_typeddict(overrides):
        return get_typeddict_schema(overrides, include_default=include_default)
    if is_namedtuple(overrides):
        return get_namedtuple_schema(overrides, include_default=include_default)
    if is_pydantic(overrides):
        return get_pydantic_schema(overrides, include_default=include_default)

    def make_entry(item: Any) -> Union[Tuple[str, Any], Tuple[str, Any, Any]]:
        # Plain specs carry no default of their own; the name comes from
        # an optional ``name`` attribute on the item.
        label = getattr(item, 'name', '')
        if include_default:
            return (label, item, NO_DEFAULT)
        return (label, item)

    # A bare (non-list) override is handled as a one-element list
    items = overrides if isinstance(overrides, list) else [overrides]
    return [make_entry(item) for item in items]
789+
790+
683791
def get_schema(
684792
spec: Any,
685-
overrides: Optional[List[str]] = None,
793+
overrides: Optional[Union[List[str], Type[Any]]] = None,
686794
function_type: str = 'udf',
687795
mode: str = 'parameter',
688796
) -> Tuple[List[Tuple[str, Any, Optional[str]]], str]:
@@ -748,18 +856,7 @@ def get_schema(
748856
#
749857

750858
# Compute overrides colspec from various formats
751-
overrides_colspec = []
752-
if overrides:
753-
if dataclasses.is_dataclass(overrides):
754-
overrides_colspec = get_dataclass_schema(overrides)
755-
elif is_typeddict(overrides):
756-
overrides_colspec = get_typeddict_schema(overrides)
757-
elif is_namedtuple(overrides):
758-
overrides_colspec = get_namedtuple_schema(overrides)
759-
elif is_pydantic(overrides):
760-
overrides_colspec = get_pydantic_schema(overrides)
761-
else:
762-
overrides_colspec = [(getattr(x, 'name', ''), x) for x in overrides]
859+
overrides_colspec = get_colspec(overrides)
763860

764861
# Numpy array types
765862
if is_numpy(spec):
@@ -878,7 +975,7 @@ def get_schema(
878975
# Normalize colspec data types
879976
out = []
880977

881-
for k, v in colspec:
978+
for k, v, *_ in colspec:
882979
out.append((
883980
k,
884981
collapse_dtypes([normalize_dtype(x) for x in simplify_dtype(v)]),
@@ -953,10 +1050,13 @@ def get_signature(
9531050
# Generate the parameter type and the corresponding SQL code for that parameter
9541051
args_schema = []
9551052
args_data_formats = []
956-
for param in signature.parameters.values():
1053+
args_colspec = [x for x in get_colspec(attrs.get('args', []), include_default=True)]
1054+
args_overrides = [x[1] for x in args_colspec]
1055+
args_defaults = [x[2] for x in args_colspec] # type: ignore
1056+
for i, param in enumerate(signature.parameters.values()):
9571057
arg_schema, args_data_format = get_schema(
9581058
param.annotation,
959-
overrides=attrs.get('args', None),
1059+
overrides=args_overrides[i] if args_overrides else [],
9601060
function_type=function_type,
9611061
mode='parameter',
9621062
)
@@ -967,12 +1067,24 @@ def get_signature(
9671067
args_schema.append((param.name, *arg_schema[0][1:]))
9681068

9691069
for i, (name, atype, sql) in enumerate(args_schema):
1070+
# Get default value
1071+
default_option = {}
1072+
if args_defaults:
1073+
if args_defaults[i] is not NO_DEFAULT:
1074+
default_option['default'] = args_defaults[i]
1075+
else:
1076+
if param.default is not param.empty:
1077+
default_option['default'] = param.default
1078+
1079+
# Generate SQL code for the parameter
9701080
sql = sql or dtype_to_sql(
9711081
atype,
9721082
function_type=function_type,
973-
default=param.default if param.default is not param.empty else None,
1083+
**default_option,
9741084
)
975-
args.append(dict(name=name, dtype=atype, sql=sql))
1085+
1086+
# Add parameter to args definitions
1087+
args.append(dict(name=name, dtype=atype, sql=sql, **default_option))
9761088

9771089
# Check that all the data formats are all the same
9781090
if len(set(args_data_formats)) > 1:
@@ -1059,7 +1171,7 @@ def sql_to_dtype(sql: str) -> str:
10591171

10601172
def dtype_to_sql(
10611173
dtype: str,
1062-
default: Any = None,
1174+
default: Any = NO_DEFAULT,
10631175
field_names: Optional[List[str]] = None,
10641176
function_type: str = 'udf',
10651177
) -> str:
@@ -1092,7 +1204,7 @@ def dtype_to_sql(
10921204
nullable = ''
10931205

10941206
default_clause = ''
1095-
if default is not None:
1207+
if default is not NO_DEFAULT:
10961208
if default is dt.NULL:
10971209
default = None
10981210
default_clause = f' DEFAULT {escape_item(default, "utf8")}'

0 commit comments

Comments
 (0)