Skip to content

Commit d4e221a

Browse files
interactive
1 parent cccd5f8 commit d4e221a

File tree

1 file changed

+25
-103
lines changed

1 file changed

+25
-103
lines changed

singlestoredb/functions/ext/asgi.py

Lines changed: 25 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@
6666
from ..signature import signature_to_sql
6767
from ..typing import Masked
6868
from ..typing import Table
69+
from ...config import get_option
70+
6971

7072
try:
7173
import cloudpickle
@@ -273,24 +275,12 @@ def build_udf_endpoint(
273275
"""
274276
if returns_data_format in ['scalar', 'list']:
275277

276-
is_async = asyncio.iscoroutinefunction(func)
277-
278278
async def do_func(
279279
row_ids: Sequence[int],
280280
rows: Sequence[Sequence[Any]],
281281
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
282282
'''Call function on given rows of data.'''
283-
out = []
284-
for row in rows:
285-
if cancel_event.is_set():
286-
raise asyncio.CancelledError(
287-
'Function call was cancelled',
288-
)
289-
if is_async:
290-
out.append(await func(*row))
291-
else:
292-
out.append(func(*row))
293-
return row_ids, list(zip(out))
283+
return row_ids, [as_tuple(x) for x in zip(func_map(func, rows))]
294284

295285
return do_func
296286

@@ -319,7 +309,6 @@ def build_vector_udf_endpoint(
319309
"""
320310
masks = get_masked_params(func)
321311
array_cls = get_array_class(returns_data_format)
322-
is_async = asyncio.iscoroutinefunction(func)
323312

324313
async def do_func(
325314
row_ids: Sequence[int],
@@ -332,17 +321,18 @@ async def do_func(
332321
row_ids = array_cls(row_ids)
333322

334323
# Call the function with `cols` as the function parameters
324+
is_async = inspect.iscoroutinefunction(func) or inspect.iscoroutinefunction(getattr(func, "__wrapped__", None))
335325
if cols and cols[0]:
336326
if is_async:
337327
out = await func(*[x if m else x[0] for x, m in zip(cols, masks)])
338328
else:
339-
out = func(*[x if m else x[0] for x, m in zip(cols, masks)])
329+
out = await asyncio.to_thread(func, *[x if m else x[0] for x, m in zip(cols, masks)])
340330
else:
341331
if is_async:
342332
out = await func()
343333
else:
344-
out = func()
345-
334+
out = await asyncio.to_thread(func())
335+
346336
# Single masked value
347337
if isinstance(out, Masked):
348338
return row_ids, [tuple(out)]
@@ -379,8 +369,6 @@ def build_tvf_endpoint(
379369
"""
380370
if returns_data_format in ['scalar', 'list']:
381371

382-
is_async = asyncio.iscoroutinefunction(func)
383-
384372
async def do_func(
385373
row_ids: Sequence[int],
386374
rows: Sequence[Sequence[Any]],
@@ -389,15 +377,7 @@ async def do_func(
389377
out_ids: List[int] = []
390378
out = []
391379
# Call function on each row of data
392-
for i, row in zip(row_ids, rows):
393-
if cancel_event.is_set():
394-
raise asyncio.CancelledError(
395-
'Function call was cancelled',
396-
)
397-
if is_async:
398-
res = await func(*row)
399-
else:
400-
res = func(*row)
380+
for i, res in zip(row_ids, func_map(func, rows)):
401381
out.extend(as_list_of_tuples(res))
402382
out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
403383
return out_ids, out
@@ -442,23 +422,13 @@ async def do_func(
442422
# each result row, so we just have to use the same
443423
# row ID for all rows in the result.
444424

445-
is_async = asyncio.iscoroutinefunction(func)
446-
447425
# Call function on each column of data
448426
if cols and cols[0]:
449-
if is_async:
450-
res = get_dataframe_columns(
451-
await func(*[x if m else x[0] for x, m in zip(cols, masks)]),
452-
)
453-
else:
454-
res = get_dataframe_columns(
455-
func(*[x if m else x[0] for x, m in zip(cols, masks)]),
456-
)
427+
res = get_dataframe_columns(
428+
func(*[x if m else x[0] for x, m in zip(cols, masks)]),
429+
)
457430
else:
458-
if is_async:
459-
res = get_dataframe_columns(await func())
460-
else:
461-
res = get_dataframe_columns(func())
431+
res = get_dataframe_columns(func())
462432

463433
# Generate row IDs
464434
if isinstance(res[0], Masked):
@@ -516,9 +486,6 @@ def make_func(
516486
# Set function type
517487
info['function_type'] = function_type
518488

519-
# Set async flag
520-
info['is_async'] = asyncio.iscoroutinefunction(func)
521-
522489
# Setup argument types for rowdat_1 parser
523490
colspec = []
524491
for x in sig['args']:
@@ -901,64 +868,11 @@ async def __call__(
901868
output_handler = self.handlers[(accepts, data_version, returns_data_format)]
902869

903870
try:
904-
result = []
905-
906-
cancel_event = threading.Event()
907-
908-
if func_info['is_async']:
909-
func_task = asyncio.create_task(
910-
func(
911-
cancel_event,
912-
*input_handler['load']( # type: ignore
913-
func_info['colspec'], b''.join(data),
914-
),
915-
),
916-
)
917-
else:
918-
func_task = asyncio.create_task(
919-
to_thread(
920-
lambda: asyncio.run(
921-
func(
922-
cancel_event,
923-
*input_handler['load']( # type: ignore
924-
func_info['colspec'], b''.join(data),
925-
),
926-
),
927-
),
928-
),
929-
)
930-
disconnect_task = asyncio.create_task(
931-
cancel_on_disconnect(receive),
932-
)
933-
timeout_task = asyncio.create_task(
934-
cancel_on_timeout(func_info['timeout']),
935-
)
936-
937-
all_tasks = [func_task, disconnect_task, timeout_task]
938-
939-
done, pending = await asyncio.wait(
940-
all_tasks, return_when=asyncio.FIRST_COMPLETED,
871+
out = await func(
872+
*input_handler['load']( # type: ignore
873+
func_info['colspec'], b''.join(data),
874+
),
941875
)
942-
943-
cancel_all_tasks(pending)
944-
945-
for task in done:
946-
if task is disconnect_task:
947-
cancel_event.set()
948-
raise asyncio.CancelledError(
949-
'Function call was cancelled by client disconnect',
950-
)
951-
952-
elif task is timeout_task:
953-
cancel_event.set()
954-
raise asyncio.TimeoutError(
955-
'Function call was cancelled due to timeout',
956-
)
957-
958-
elif task is func_task:
959-
result.extend(task.result())
960-
961-
print(result)
962876
body = output_handler['dump'](
963877
[x[1] for x in func_info['returns']], *out, # type: ignore
964878
)
@@ -1065,6 +979,11 @@ def get_function_info(
1065979
for (_, info), sql in zip(self.endpoints.values(), create_sqls):
1066980
sig = info['signature']
1067981
sql_map[sig['name']] = sql
982+
983+
connection_info = {}
984+
workspace_group_id = os.environ.get('SINGLESTOREDB_WORKSPACE_GROUP')
985+
connection_info['workspace_group_id'] = workspace_group_id
986+
1068987

1069988
for key, (_, info) in self.endpoints.items():
1070989
if not func_name or key == func_name:
@@ -1111,7 +1030,10 @@ def get_function_info(
11111030
sql_statement=sql,
11121031
)
11131032

1114-
return functions
1033+
return {
1034+
'functions': functions,
1035+
'connection_info': connection_info
1036+
}
11151037

11161038
def get_create_functions(
11171039
self,

0 commit comments

Comments
 (0)