Skip to content

Commit 459d8ab

Browse files
interactive
1 parent cccd5f8 commit 459d8ab

File tree

1 file changed

+34
-103
lines changed

1 file changed

+34
-103
lines changed

singlestoredb/functions/ext/asgi.py

Lines changed: 34 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@
6666
from ..signature import signature_to_sql
6767
from ..typing import Masked
6868
from ..typing import Table
69+
from ...config import get_option
70+
from singlestoredb.connection import build_params
71+
6972

7073
try:
7174
import cloudpickle
@@ -273,24 +276,12 @@ def build_udf_endpoint(
273276
"""
274277
if returns_data_format in ['scalar', 'list']:
275278

276-
is_async = asyncio.iscoroutinefunction(func)
277-
278279
async def do_func(
279280
row_ids: Sequence[int],
280281
rows: Sequence[Sequence[Any]],
281282
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
282283
'''Call function on given rows of data.'''
283-
out = []
284-
for row in rows:
285-
if cancel_event.is_set():
286-
raise asyncio.CancelledError(
287-
'Function call was cancelled',
288-
)
289-
if is_async:
290-
out.append(await func(*row))
291-
else:
292-
out.append(func(*row))
293-
return row_ids, list(zip(out))
284+
return row_ids, [as_tuple(x) for x in zip(func_map(func, rows))]
294285

295286
return do_func
296287

@@ -319,7 +310,6 @@ def build_vector_udf_endpoint(
319310
"""
320311
masks = get_masked_params(func)
321312
array_cls = get_array_class(returns_data_format)
322-
is_async = asyncio.iscoroutinefunction(func)
323313

324314
async def do_func(
325315
row_ids: Sequence[int],
@@ -332,17 +322,18 @@ async def do_func(
332322
row_ids = array_cls(row_ids)
333323

334324
# Call the function with `cols` as the function parameters
325+
is_async = inspect.iscoroutinefunction(func) or inspect.iscoroutinefunction(getattr(func, "__wrapped__", None))
335326
if cols and cols[0]:
336327
if is_async:
337328
out = await func(*[x if m else x[0] for x, m in zip(cols, masks)])
338329
else:
339-
out = func(*[x if m else x[0] for x, m in zip(cols, masks)])
330+
out = await asyncio.to_thread(func, *[x if m else x[0] for x, m in zip(cols, masks)])
340331
else:
341332
if is_async:
342333
out = await func()
343334
else:
344-
out = func()
345-
335+
out = await asyncio.to_thread(func())
336+
346337
# Single masked value
347338
if isinstance(out, Masked):
348339
return row_ids, [tuple(out)]
@@ -379,8 +370,6 @@ def build_tvf_endpoint(
379370
"""
380371
if returns_data_format in ['scalar', 'list']:
381372

382-
is_async = asyncio.iscoroutinefunction(func)
383-
384373
async def do_func(
385374
row_ids: Sequence[int],
386375
rows: Sequence[Sequence[Any]],
@@ -389,15 +378,7 @@ async def do_func(
389378
out_ids: List[int] = []
390379
out = []
391380
# Call function on each row of data
392-
for i, row in zip(row_ids, rows):
393-
if cancel_event.is_set():
394-
raise asyncio.CancelledError(
395-
'Function call was cancelled',
396-
)
397-
if is_async:
398-
res = await func(*row)
399-
else:
400-
res = func(*row)
381+
for i, res in zip(row_ids, func_map(func, rows)):
401382
out.extend(as_list_of_tuples(res))
402383
out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
403384
return out_ids, out
@@ -442,23 +423,13 @@ async def do_func(
442423
# each result row, so we just have to use the same
443424
# row ID for all rows in the result.
444425

445-
is_async = asyncio.iscoroutinefunction(func)
446-
447426
# Call function on each column of data
448427
if cols and cols[0]:
449-
if is_async:
450-
res = get_dataframe_columns(
451-
await func(*[x if m else x[0] for x, m in zip(cols, masks)]),
452-
)
453-
else:
454-
res = get_dataframe_columns(
455-
func(*[x if m else x[0] for x, m in zip(cols, masks)]),
456-
)
428+
res = get_dataframe_columns(
429+
func(*[x if m else x[0] for x, m in zip(cols, masks)]),
430+
)
457431
else:
458-
if is_async:
459-
res = get_dataframe_columns(await func())
460-
else:
461-
res = get_dataframe_columns(func())
432+
res = get_dataframe_columns(func())
462433

463434
# Generate row IDs
464435
if isinstance(res[0], Masked):
@@ -516,9 +487,6 @@ def make_func(
516487
# Set function type
517488
info['function_type'] = function_type
518489

519-
# Set async flag
520-
info['is_async'] = asyncio.iscoroutinefunction(func)
521-
522490
# Setup argument types for rowdat_1 parser
523491
colspec = []
524492
for x in sig['args']:
@@ -901,64 +869,11 @@ async def __call__(
901869
output_handler = self.handlers[(accepts, data_version, returns_data_format)]
902870

903871
try:
904-
result = []
905-
906-
cancel_event = threading.Event()
907-
908-
if func_info['is_async']:
909-
func_task = asyncio.create_task(
910-
func(
911-
cancel_event,
912-
*input_handler['load']( # type: ignore
913-
func_info['colspec'], b''.join(data),
914-
),
915-
),
916-
)
917-
else:
918-
func_task = asyncio.create_task(
919-
to_thread(
920-
lambda: asyncio.run(
921-
func(
922-
cancel_event,
923-
*input_handler['load']( # type: ignore
924-
func_info['colspec'], b''.join(data),
925-
),
926-
),
927-
),
928-
),
929-
)
930-
disconnect_task = asyncio.create_task(
931-
cancel_on_disconnect(receive),
932-
)
933-
timeout_task = asyncio.create_task(
934-
cancel_on_timeout(func_info['timeout']),
935-
)
936-
937-
all_tasks = [func_task, disconnect_task, timeout_task]
938-
939-
done, pending = await asyncio.wait(
940-
all_tasks, return_when=asyncio.FIRST_COMPLETED,
872+
out = await func(
873+
*input_handler['load']( # type: ignore
874+
func_info['colspec'], b''.join(data),
875+
),
941876
)
942-
943-
cancel_all_tasks(pending)
944-
945-
for task in done:
946-
if task is disconnect_task:
947-
cancel_event.set()
948-
raise asyncio.CancelledError(
949-
'Function call was cancelled by client disconnect',
950-
)
951-
952-
elif task is timeout_task:
953-
cancel_event.set()
954-
raise asyncio.TimeoutError(
955-
'Function call was cancelled due to timeout',
956-
)
957-
958-
elif task is func_task:
959-
result.extend(task.result())
960-
961-
print(result)
962877
body = output_handler['dump'](
963878
[x[1] for x in func_info['returns']], *out, # type: ignore
964879
)
@@ -1066,6 +981,19 @@ def get_function_info(
1066981
sig = info['signature']
1067982
sql_map[sig['name']] = sql
1068983

984+
if 'SINGLESTOREDB_URL' in os.environ:
985+
dbname = build_params(host=os.environ['SINGLESTOREDB_URL']).get('database')
986+
elif 'SINGLESTOREDB_HOST' in os.environ:
987+
dbname = build_params(host=os.environ['SINGLESTOREDB_HOST']).get('database')
988+
elif 'SINGLESTOREDB_DATABASE' in os.environ:
989+
dbname = os.environ['SINGLESTOREDB_DATBASE']
990+
991+
connection_info = {}
992+
workspace_group_id = os.environ.get('SINGLESTOREDB_WORKSPACE_GROUP')
993+
connection_info['database_name'] = dbname
994+
connection_info['workspace_group_id'] = workspace_group_id
995+
996+
1069997
for key, (_, info) in self.endpoints.items():
1070998
if not func_name or key == func_name:
1071999
sig = info['signature']
@@ -1111,7 +1039,10 @@ def get_function_info(
11111039
sql_statement=sql,
11121040
)
11131041

1114-
return functions
1042+
return {
1043+
'functions': functions,
1044+
'connection_info': connection_info
1045+
}
11151046

11161047
def get_create_functions(
11171048
self,

0 commit comments

Comments
 (0)