-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathdecorator.py
More file actions
209 lines (171 loc) · 6.65 KB
/
Copy pathdecorator.py
File metadata and controls
209 lines (171 loc) · 6.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import asyncio
import functools
import inspect
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Type
from typing import Union
from . import utils
from .dtypes import SQLString
# Forms accepted when describing parameter types: a SQL string, a callable
# annotated as returning ``SQLString``, a list mixing those two, or a class.
ParameterType = Union[
str,
Callable[..., SQLString],
List[Union[str, Callable[..., SQLString]]],
Type[Any],
]
# Return type specifications accept exactly the same forms as parameters.
ReturnType = ParameterType
# Signature-agnostic alias used for decorated functions and for the values
# the decorators in this module return.
UDFType = Callable[..., Any]
def is_valid_type(obj: Any) -> bool:
    """
    Determine whether ``obj`` is a class usable as a schema definition.

    Only classes qualify; anything that is not a class is rejected before
    any further inspection. A class is accepted when at least one of the
    ``utils`` predicates (TypedDict, NamedTuple, dataclass, pydantic)
    recognizes it.

    Parameters
    ----------
    obj : Any
        The object to inspect

    Returns
    -------
    bool

    """
    if not inspect.isclass(obj):
        return False

    # The pydantic check is delegated to ``utils`` so that pydantic does
    # not have to be imported by this module.
    schema_checks = (
        utils.is_typeddict,
        utils.is_namedtuple,
        utils.is_dataclass,
        utils.is_pydantic,
    )
    return any(check(obj) for check in schema_checks)
def is_sqlstr_callable(obj: Any) -> bool:
    """
    Determine whether ``obj`` is a callable usable as a parameter type.

    A callable qualifies only when its ``return`` annotation is a class
    that is a subclass of ``SQLString``.

    Parameters
    ----------
    obj : Any
        The object to inspect

    Returns
    -------
    bool

    """
    if callable(obj):
        return_hint = utils.get_annotations(obj).get('return', None)
        # ``issubclass`` raises on non-class arguments, so confirm the
        # annotation is a class before testing the hierarchy.
        return inspect.isclass(return_hint) and issubclass(return_hint, SQLString)
    return False
def expand_types(args: Any) -> Optional[List[Any]]:
    """
    Normalize a parameter / return type specification into a list.

    Parameters
    ----------
    args : Any
        ``None``, a SQL string, a callable annotated as returning
        ``SQLString``, a class, or a list of such items

    Returns
    -------
    List[Any] or None
        ``None`` when ``args`` is ``None``; otherwise a list in which
        strings and classes are kept as given and SQL-string callables
        are replaced by the result of calling them with no arguments

    Raises
    ------
    TypeError
        If ``args`` (or an element of a list) is not a recognized form

    """
    if args is None:
        return None

    # A single SQL string
    if isinstance(args, str):
        return [args]

    # A list: every element is normalized independently, in order
    if isinstance(args, list):
        expanded: List[Any] = []
        for arg in args:
            if isinstance(arg, str):
                expanded.append(arg)
                continue
            if is_sqlstr_callable(arg):
                # Call with defaults to obtain the SQL string
                expanded.append(arg())
                continue
            if type(arg) is type or is_valid_type(arg):
                expanded.append(arg)
                continue
            raise TypeError(f'unrecognized type for parameter: {arg}')
        return expanded

    # A single callable that returns a SQL string
    if is_sqlstr_callable(args):
        return [args()]

    # A schema class (pydantic.BaseModel, NamedTuple, TypedDict, ...) or
    # a plain class
    if is_valid_type(args) or type(args) is type:
        return [args]

    raise TypeError(f'unrecognized type for parameter: {args}')
def _func(
    func: Optional[Callable[..., Any]] = None,
    *,
    name: Optional[str] = None,
    args: Optional[ParameterType] = None,
    returns: Optional[ReturnType] = None,
    timeout: Optional[int] = None,
    concurrency_limit: Optional[int] = None,
) -> UDFType:
    """
    Generic wrapper for UDF and TVF decorators.

    The supplied options are collected into a dictionary (options left as
    ``None`` are omitted) and attached to the wrapper function as the
    ``_singlestoredb_attrs`` attribute.

    Parameters
    ----------
    func : callable, optional
        The function to wrap. When omitted, a decorator is returned that
        wraps whatever function it is later applied to.
    name : str, optional
        Stored under the ``name`` key
    args : ParameterType, optional
        Parameter type specification; normalized with ``expand_types``
    returns : ReturnType, optional
        Return type specification; normalized with ``expand_types``
    timeout : int, optional
        Stored under the ``timeout`` key
    concurrency_limit : int, optional
        Stored under the ``concurrency_limit`` key

    Returns
    -------
    Callable
        The wrapped function, or a decorator if ``func`` was not given.
        Coroutine functions get an ``async`` wrapper; all other callables
        get a plain wrapper.

    Raises
    ------
    TypeError
        Propagated from ``expand_types`` when ``args`` or ``returns`` has
        an unrecognized form

    """
    _singlestoredb_attrs = {  # type: ignore
        k: v for k, v in dict(
            name=name,
            args=expand_types(args),
            returns=expand_types(returns),
            timeout=timeout,
            concurrency_limit=concurrency_limit,
        ).items() if v is not None
    }

    def decorate(func: UDFType) -> UDFType:
        """Wrap ``func`` and attach the collected attributes."""
        if asyncio.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*call_args: Any, **call_kwargs: Any) -> Any:
                return await func(*call_args, **call_kwargs)  # type: ignore
            wrapped: Any = async_wrapper
        else:
            @functools.wraps(func)
            def wrapper(*call_args: Any, **call_kwargs: Any) -> Any:
                return func(*call_args, **call_kwargs)  # type: ignore
            wrapped = wrapper

        # Attach the attributes only AFTER functools.wraps has run: wraps
        # copies func.__dict__ into the wrapper, so if func already carries
        # a `_singlestoredb_attrs` (e.g. it was decorated before), setting
        # ours first would let the stale value overwrite it.
        wrapped._singlestoredb_attrs = _singlestoredb_attrs
        return wrapped  # type: ignore

    # No func was specified, this is an uncalled decorator that will get
    # called later, so the wrapper must be created with the func passed
    # in at that time.
    if func is None:
        return decorate

    return decorate(func)
def udf(
    func: Optional[Callable[..., Any]] = None,
    *,
    name: Optional[str] = None,
    args: Optional[ParameterType] = None,
    returns: Optional[ReturnType] = None,
    timeout: Optional[int] = None,
    concurrency_limit: Optional[int] = None,
) -> UDFType:
    """
    Define a user-defined function (UDF).

    May be applied directly to a function or called with keyword options
    to produce a decorator. All options are forwarded unchanged to the
    generic ``_func`` wrapper.

    Parameters
    ----------
    func : callable, optional
        The UDF to apply parameters to
    name : str, optional
        The name to use for the UDF in the database
    args : str | Type | Callable | List[str | Callable], optional
        Specifies the data types of the function arguments. Typically,
        the function data types are derived from the function parameter
        annotations; this option overrides them. If the function takes
        a single type for all parameters, `args` can be set to a SQL
        string describing all parameters. If the function takes more
        than one parameter and all of them are being manually defined,
        a list of SQL strings may be used (one for each parameter).
        Callables may also be used for datatypes. This is primarily for
        using the functions in the ``dtypes`` module that are associated
        with SQL types with all default options (e.g., ``dt.FLOAT``).
        NOTE(review): earlier documentation also described a dictionary
        keyed by parameter name; the declared ``ParameterType`` does not
        include dict -- confirm before relying on that form.
    returns : str | Type | Callable | List[str | Callable] | Table, optional
        Specifies the return data type of the function. This parameter
        works the same way as `args`. If the function is a table-valued
        function, the return type should be a `Table` object.
    timeout : int, optional
        The timeout in seconds for the UDF execution. If not specified,
        the global default timeout is used.
    concurrency_limit : int, optional
        The maximum number of concurrent subsets of rows that will be
        processed simultaneously by the UDF. If not specified,
        the global default concurrency limit is used.

    Returns
    -------
    Callable

    """
    return _func(
        func,
        name=name,
        args=args,
        returns=returns,
        timeout=timeout,
        concurrency_limit=concurrency_limit,
    )