1- import dataclasses
2- import datetime
1+ from __future__ import annotations
2+
33import functools
44import inspect
55from typing import Any
66from typing import Callable
7- from typing import Dict
87from typing import List
98from typing import Optional
10- from typing import Tuple
9+ from typing import Type
1110from typing import Union
1211
13- from . import dtypes
14- from .dtypes import DataType
15- from .signature import simplify_dtype
16-
17- try :
18- import pydantic
19- has_pydantic = True
20- except ImportError :
21- has_pydantic = False
22-
23- python_type_map : Dict [Any , Callable [..., str ]] = {
24- str : dtypes .TEXT ,
25- int : dtypes .BIGINT ,
26- float : dtypes .DOUBLE ,
27- bool : dtypes .BOOL ,
28- bytes : dtypes .BINARY ,
29- bytearray : dtypes .BINARY ,
30- datetime .datetime : dtypes .DATETIME ,
31- datetime .date : dtypes .DATE ,
32- datetime .timedelta : dtypes .TIME ,
33- }
34-
35-
36- def listify (x : Any ) -> List [Any ]:
37- """Make sure value is a list."""
38- if x is None :
39- return []
40- if isinstance (x , (list , tuple , set )):
41- return list (x )
42- return [x ]
43-
44-
45- def process_annotation (annotation : Any ) -> Tuple [Any , bool ]:
46- types = simplify_dtype (annotation )
47- if isinstance (types , list ):
48- nullable = False
49- if type (None ) in types :
50- nullable = True
51- types = [x for x in types if x is not type (None )]
52- if len (types ) > 1 :
53- raise ValueError (f'multiple types not supported: { annotation } ' )
54- return types [0 ], nullable
55- return types , True
5612
13+ ParameterType = Union [
14+ str ,
15+ Callable [..., str ],
16+ List [Union [str , Callable [..., str ]]],
17+ Type [Any ],
18+ ]
5719
58- def process_types (params : Any ) -> Any :
59- if params is None :
60- return params , []
61-
62- elif isinstance (params , (list , tuple )):
63- params = list (params )
64- for i , item in enumerate (params ):
65- if params [i ] in python_type_map :
66- params [i ] = python_type_map [params [i ]]()
67- elif callable (item ):
68- params [i ] = item ()
69- for item in params :
70- if not isinstance (item , str ):
71- raise TypeError (f'unrecognized type for parameter: { item } ' )
72- return params , []
73-
74- elif isinstance (params , dict ):
75- names = []
76- params = dict (params )
77- for k , v in list (params .items ()):
78- names .append (k )
79- if params [k ] in python_type_map :
80- params [k ] = python_type_map [params [k ]]()
81- elif callable (v ):
82- params [k ] = v ()
83- for item in params .values ():
84- if not isinstance (item , str ):
85- raise TypeError (f'unrecognized type for parameter: { item } ' )
86- return params , names
87-
88- elif dataclasses .is_dataclass (params ):
89- names = []
90- out = []
91- for item in dataclasses .fields (params ):
92- typ , nullable = process_annotation (item .type )
93- sql_type = process_types (typ )[0 ]
94- if not nullable :
95- sql_type = sql_type .replace ('NULL' , 'NOT NULL' )
96- out .append (sql_type )
97- names .append (item .name )
98- return out , names
99-
100- elif has_pydantic and inspect .isclass (params ) \
101- and issubclass (params , pydantic .BaseModel ):
102- names = []
103- out = []
104- for name , item in params .model_fields .items ():
105- typ , nullable = process_annotation (item .annotation )
106- sql_type = process_types (typ )[0 ]
107- if not nullable :
108- sql_type = sql_type .replace ('NULL' , 'NOT NULL' )
109- out .append (sql_type )
110- names .append (name )
111- return out , names
112-
113- elif params in python_type_map :
114- return python_type_map [params ](), []
20+ ReturnType = ParameterType
11521
116- elif callable (params ):
117- return params (), []
11822
119- elif isinstance (params , str ):
120- return params , []
23+ def expand_types (args : Any ) -> Optional [Union [List [str ], Type [Any ]]]:
24+ """Expand the types for the function arguments / return values."""
25+ if args is None :
26+ return None
12127
122- raise TypeError (f'unrecognized data type for args: { params } ' )
28+ # SQL string
29+ if isinstance (args , str ):
30+ return [args ]
12331
32+ # General way of accepting pydantic.BaseModel, NamedTuple, TypedDict
33+ elif inspect .isclass (args ):
34+ return args
12435
125- ParameterType = Union [
126- str ,
127- List [str ],
128- Dict [str , str ],
129- 'pydantic.BaseModel' ,
130- type ,
131- ]
36+ # Callable that returns a SQL string
37+ elif callable (args ):
38+ out = args ()
39+ if not isinstance (out , str ):
40+ raise TypeError (f'unrecognized type for parameter: { args } ' )
41+ return [out ]
13242
133- ReturnType = Union [
134- str ,
135- List [DataType ],
136- List [type ],
137- 'pydantic.BaseModel' ,
138- type ,
139- ]
43+ # List of SQL strings or callables
44+ else :
45+ new_args = []
46+ for arg in args :
47+ if isinstance (arg , str ):
48+ new_args .append (arg )
49+ elif callable (arg ):
50+ new_args .append (arg ())
51+ else :
52+ raise TypeError (f'unrecognized type for parameter: { arg } ' )
53+ return new_args
14054
14155
14256def _func (
@@ -145,40 +59,18 @@ def _func(
14559 name : Optional [str ] = None ,
14660 args : Optional [ParameterType ] = None ,
14761 returns : Optional [ReturnType ] = None ,
148- data_format : Optional [str ] = None ,
14962 include_masks : bool = False ,
15063 function_type : str = 'udf' ,
151- output_fields : Optional [List [str ]] = None ,
15264) -> Callable [..., Any ]:
15365 """Generic wrapper for UDF and TVF decorators."""
154- args , _ = process_types (args )
155- returns , fields = process_types (returns )
156-
157- if not output_fields and fields :
158- output_fields = fields
159-
160- if isinstance (returns , list ) \
161- and isinstance (output_fields , list ) \
162- and len (output_fields ) != len (returns ):
163- raise ValueError (
164- 'The number of output fields must match the number of return types' ,
165- )
166-
167- if include_masks and data_format == 'python' :
168- raise RuntimeError (
169- 'include_masks is only valid when using '
170- 'vectors for input parameters' ,
171- )
17266
17367 _singlestoredb_attrs = { # type: ignore
17468 k : v for k , v in dict (
17569 name = name ,
176- args = args ,
177- returns = returns ,
178- data_format = data_format ,
70+ args = expand_types (args ),
71+ returns = expand_types (returns ),
17972 include_masks = include_masks ,
18073 function_type = function_type ,
181- output_fields = output_fields or None ,
18274 ).items () if v is not None
18375 }
18476
@@ -207,7 +99,6 @@ def udf(
20799 name : Optional [str ] = None ,
208100 args : Optional [ParameterType ] = None ,
209101 returns : Optional [ReturnType ] = None ,
210- data_format : Optional [str ] = None ,
211102 include_masks : bool = False ,
212103) -> Callable [..., Any ]:
213104 """
@@ -219,7 +110,7 @@ def udf(
219110 The UDF to apply parameters to
220111 name : str, optional
221112 The name to use for the UDF in the database
222- args : str | Callable | List[str | Callable] | Dict[str, str | Callable] , optional
113+ args : str | Callable | List[str | Callable], optional
223114 Specifies the data types of the function arguments. Typically,
224115 the function data types are derived from the function parameter
225116 annotations. These annotations can be overridden. If the function
@@ -235,8 +126,6 @@ def udf(
235126 returns : str, optional
236127 Specifies the return data type of the function. If not specified,
237128 the type annotation from the function is used.
238- data_format : str, optional
239- The data format of each parameter: python, pandas, arrow, polars
240129 include_masks : bool, optional
241130 Should boolean masks be included with each input parameter to indicate
241131 which elements are NULL? This is only used when input parameters are
@@ -252,27 +141,18 @@ def udf(
252141 name = name ,
253142 args = args ,
254143 returns = returns ,
255- data_format = data_format ,
256144 include_masks = include_masks ,
257145 function_type = 'udf' ,
258146 )
259147
260148
261- udf .pandas = functools .partial (udf , data_format = 'pandas' ) # type: ignore
262- udf .polars = functools .partial (udf , data_format = 'polars' ) # type: ignore
263- udf .arrow = functools .partial (udf , data_format = 'arrow' ) # type: ignore
264- udf .numpy = functools .partial (udf , data_format = 'numpy' ) # type: ignore
265-
266-
267149def tvf (
268150 func : Optional [Callable [..., Any ]] = None ,
269151 * ,
270152 name : Optional [str ] = None ,
271153 args : Optional [ParameterType ] = None ,
272154 returns : Optional [ReturnType ] = None ,
273- data_format : Optional [str ] = None ,
274155 include_masks : bool = False ,
275- output_fields : Optional [List [str ]] = None ,
276156) -> Callable [..., Any ]:
277157 """
278158 Apply attributes to a TVF.
@@ -283,7 +163,7 @@ def tvf(
283163 The TVF to apply parameters to
284164 name : str, optional
285165 The name to use for the TVF in the database
286- args : str | Callable | List[str | Callable] | Dict[str, str | Callable] , optional
166+ args : str | Callable | List[str | Callable], optional
287167 Specifies the data types of the function arguments. Typically,
288168 the function data types are derived from the function parameter
289169 annotations. These annotations can be overridden. If the function
@@ -299,15 +179,10 @@ def tvf(
299179 returns : str, optional
300180 Specifies the return data type of the function. If not specified,
301181 the type annotation from the function is used.
302- data_format : str, optional
303- The data format of each parameter: python, pandas, arrow, polars
304182 include_masks : bool, optional
305183 Should boolean masks be included with each input parameter to indicate
306184 which elements are NULL? This is only used when input parameters are
307185 configured to a vector type (numpy, pandas, polars, arrow).
308- output_fields : List[str], optional
309- The names of the output fields for the TVF. If not specified, the
310- names are generated.
311186
312187 Returns
313188 -------
@@ -319,14 +194,6 @@ def tvf(
319194 name = name ,
320195 args = args ,
321196 returns = returns ,
322- data_format = data_format ,
323197 include_masks = include_masks ,
324198 function_type = 'tvf' ,
325- output_fields = output_fields ,
326199 )
327-
328-
329- tvf .pandas = functools .partial (tvf , data_format = 'pandas' ) # type: ignore
330- tvf .polars = functools .partial (tvf , data_format = 'polars' ) # type: ignore
331- tvf .arrow = functools .partial (tvf , data_format = 'arrow' ) # type: ignore
332- tvf .numpy = functools .partial (tvf , data_format = 'numpy' ) # type: ignore
0 commit comments