Skip to content

Commit 6ca886c

Browse files
kesmit13kkampli-singlestore
authored andcommitted
Add async support
1 parent a01f9ae commit 6ca886c

File tree

2 files changed

+137
-25
lines changed

2 files changed

+137
-25
lines changed

singlestoredb/functions/decorator.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import functools
23
import inspect
34
from typing import Any
@@ -19,6 +20,7 @@
1920
]
2021

2122
ReturnType = ParameterType
23+
UDFType = Callable[..., Any]
2224

2325

2426
def is_valid_type(obj: Any) -> bool:
@@ -100,7 +102,8 @@ def _func(
100102
name: Optional[str] = None,
101103
args: Optional[ParameterType] = None,
102104
returns: Optional[ReturnType] = None,
103-
) -> Callable[..., Any]:
105+
timeout: Optional[int] = None,
106+
) -> UDFType:
104107
"""Generic wrapper for UDF and TVF decorators."""
105108

106109
_singlestoredb_attrs = { # type: ignore
@@ -115,23 +118,33 @@ def _func(
115118
# called later, so the wrapper much be created with the func passed
116119
# in at that time.
117120
if func is None:
118-
def decorate(func: Callable[..., Any]) -> Callable[..., Any]:
121+
def decorate(func: UDFType) -> UDFType:
119122

120-
def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
121-
return func(*args, **kwargs) # type: ignore
123+
if asyncio.iscoroutinefunction(func):
124+
async def async_wrapper(*args: Any, **kwargs: Any) -> UDFType:
125+
return await func(*args, **kwargs) # type: ignore
126+
async_wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
127+
return functools.wraps(func)(async_wrapper)
122128

123-
wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
124-
125-
return functools.wraps(func)(wrapper)
129+
else:
130+
def wrapper(*args: Any, **kwargs: Any) -> UDFType:
131+
return func(*args, **kwargs) # type: ignore
132+
wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
133+
return functools.wraps(func)(wrapper)
126134

127135
return decorate
128136

129-
def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
130-
return func(*args, **kwargs) # type: ignore
131-
132-
wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
137+
if asyncio.iscoroutinefunction(func):
138+
async def async_wrapper(*args: Any, **kwargs: Any) -> UDFType:
139+
return await func(*args, **kwargs) # type: ignore
140+
async_wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
141+
return functools.wraps(func)(async_wrapper)
133142

134-
return functools.wraps(func)(wrapper)
143+
else:
144+
def wrapper(*args: Any, **kwargs: Any) -> UDFType:
145+
return func(*args, **kwargs) # type: ignore
146+
wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
147+
return functools.wraps(func)(wrapper)
135148

136149

137150
def udf(
@@ -140,7 +153,8 @@ def udf(
140153
name: Optional[str] = None,
141154
args: Optional[ParameterType] = None,
142155
returns: Optional[ReturnType] = None,
143-
) -> Callable[..., Any]:
156+
timeout: Optional[int] = None,
157+
) -> UDFType:
144158
"""
145159
Define a user-defined function (UDF).
146160

singlestoredb/functions/ext/asgi.py

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -273,12 +273,24 @@ def build_udf_endpoint(
273273
"""
274274
if returns_data_format in ['scalar', 'list']:
275275

276+
is_async = asyncio.iscoroutinefunction(func)
277+
276278
async def do_func(
277279
row_ids: Sequence[int],
278280
rows: Sequence[Sequence[Any]],
279281
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
280282
'''Call function on given rows of data.'''
281-
return row_ids, [as_tuple(x) for x in zip(func_map(func, rows))]
283+
out = []
284+
for row in rows:
285+
if cancel_event.is_set():
286+
raise asyncio.CancelledError(
287+
'Function call was cancelled',
288+
)
289+
if is_async:
290+
out.append(await func(*row))
291+
else:
292+
out.append(func(*row))
293+
return row_ids, list(zip(out))
282294

283295
return do_func
284296

@@ -307,6 +319,7 @@ def build_vector_udf_endpoint(
307319
"""
308320
masks = get_masked_params(func)
309321
array_cls = get_array_class(returns_data_format)
322+
is_async = asyncio.iscoroutinefunction(func)
310323

311324
async def do_func(
312325
row_ids: Sequence[int],
@@ -320,9 +333,15 @@ async def do_func(
320333

321334
# Call the function with `cols` as the function parameters
322335
if cols and cols[0]:
323-
out = func(*[x if m else x[0] for x, m in zip(cols, masks)])
336+
if is_async:
337+
out = await func(*[x if m else x[0] for x, m in zip(cols, masks)])
338+
else:
339+
out = func(*[x if m else x[0] for x, m in zip(cols, masks)])
324340
else:
325-
out = func()
341+
if is_async:
342+
out = await func()
343+
else:
344+
out = func()
326345

327346
# Single masked value
328347
if isinstance(out, Masked):
@@ -360,6 +379,8 @@ def build_tvf_endpoint(
360379
"""
361380
if returns_data_format in ['scalar', 'list']:
362381

382+
is_async = asyncio.iscoroutinefunction(func)
383+
363384
async def do_func(
364385
row_ids: Sequence[int],
365386
rows: Sequence[Sequence[Any]],
@@ -368,7 +389,15 @@ async def do_func(
368389
out_ids: List[int] = []
369390
out = []
370391
# Call function on each row of data
371-
for i, res in zip(row_ids, func_map(func, rows)):
392+
for i, row in zip(row_ids, rows):
393+
if cancel_event.is_set():
394+
raise asyncio.CancelledError(
395+
'Function call was cancelled',
396+
)
397+
if is_async:
398+
res = await func(*row)
399+
else:
400+
res = func(*row)
372401
out.extend(as_list_of_tuples(res))
373402
out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
374403
return out_ids, out
@@ -413,13 +442,23 @@ async def do_func(
413442
# each result row, so we just have to use the same
414443
# row ID for all rows in the result.
415444

445+
is_async = asyncio.iscoroutinefunction(func)
446+
416447
# Call function on each column of data
417448
if cols and cols[0]:
418-
res = get_dataframe_columns(
419-
func(*[x if m else x[0] for x, m in zip(cols, masks)]),
420-
)
449+
if is_async:
450+
res = get_dataframe_columns(
451+
await func(*[x if m else x[0] for x, m in zip(cols, masks)]),
452+
)
453+
else:
454+
res = get_dataframe_columns(
455+
func(*[x if m else x[0] for x, m in zip(cols, masks)]),
456+
)
421457
else:
422-
res = get_dataframe_columns(func())
458+
if is_async:
459+
res = get_dataframe_columns(await func())
460+
else:
461+
res = get_dataframe_columns(func())
423462

424463
# Generate row IDs
425464
if isinstance(res[0], Masked):
@@ -477,6 +516,12 @@ def make_func(
477516
# Set function type
478517
info['function_type'] = function_type
479518

519+
# Set timeout
520+
info['timeout'] = max(timeout, 1)
521+
522+
# Set async flag
523+
info['is_async'] = asyncio.iscoroutinefunction(func)
524+
480525
# Setup argument types for rowdat_1 parser
481526
colspec = []
482527
for x in sig['args']:
@@ -859,11 +904,64 @@ async def __call__(
859904
output_handler = self.handlers[(accepts, data_version, returns_data_format)]
860905

861906
try:
862-
out = await func(
863-
*input_handler['load']( # type: ignore
864-
func_info['colspec'], b''.join(data),
865-
),
907+
result = []
908+
909+
cancel_event = threading.Event()
910+
911+
if func_info['is_async']:
912+
func_task = asyncio.create_task(
913+
func(
914+
cancel_event,
915+
*input_handler['load']( # type: ignore
916+
func_info['colspec'], b''.join(data),
917+
),
918+
),
919+
)
920+
else:
921+
func_task = asyncio.create_task(
922+
to_thread(
923+
lambda: asyncio.run(
924+
func(
925+
cancel_event,
926+
*input_handler['load']( # type: ignore
927+
func_info['colspec'], b''.join(data),
928+
),
929+
),
930+
),
931+
),
932+
)
933+
disconnect_task = asyncio.create_task(
934+
cancel_on_disconnect(receive),
935+
)
936+
timeout_task = asyncio.create_task(
937+
cancel_on_timeout(func_info['timeout']),
866938
)
939+
940+
all_tasks = [func_task, disconnect_task, timeout_task]
941+
942+
done, pending = await asyncio.wait(
943+
all_tasks, return_when=asyncio.FIRST_COMPLETED,
944+
)
945+
946+
cancel_all_tasks(pending)
947+
948+
for task in done:
949+
if task is disconnect_task:
950+
cancel_event.set()
951+
raise asyncio.CancelledError(
952+
'Function call was cancelled by client disconnect',
953+
)
954+
955+
elif task is timeout_task:
956+
cancel_event.set()
957+
raise asyncio.TimeoutError(
958+
'Function call was cancelled due to timeout',
959+
)
960+
961+
elif task is func_task:
962+
result.extend(task.result())
963+
964+
print(result)
867965
body = output_handler['dump'](
868966
[x[1] for x in func_info['returns']], *out, # type: ignore
869967
)

0 commit comments

Comments
 (0)