6666from ..signature import signature_to_sql
6767from ..typing import Masked
6868from ..typing import Table
69+ from ...config import get_option
70+ from singlestoredb .connection import build_params
71+
6972
7073try :
7174 import cloudpickle
@@ -273,24 +276,12 @@ def build_udf_endpoint(
273276 """
274277 if returns_data_format in ['scalar' , 'list' ]:
275278
276- is_async = asyncio .iscoroutinefunction (func )
277-
278279 async def do_func (
279280 row_ids : Sequence [int ],
280281 rows : Sequence [Sequence [Any ]],
281282 ) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
282283 '''Call function on given rows of data.'''
283- out = []
284- for row in rows :
285- if cancel_event .is_set ():
286- raise asyncio .CancelledError (
287- 'Function call was cancelled' ,
288- )
289- if is_async :
290- out .append (await func (* row ))
291- else :
292- out .append (func (* row ))
293- return row_ids , list (zip (out ))
284+ return row_ids , [as_tuple (x ) for x in zip (func_map (func , rows ))]
294285
295286 return do_func
296287
@@ -319,7 +310,6 @@ def build_vector_udf_endpoint(
319310 """
320311 masks = get_masked_params (func )
321312 array_cls = get_array_class (returns_data_format )
322- is_async = asyncio .iscoroutinefunction (func )
323313
324314 async def do_func (
325315 row_ids : Sequence [int ],
@@ -332,17 +322,18 @@ async def do_func(
332322 row_ids = array_cls (row_ids )
333323
334324 # Call the function with `cols` as the function parameters
325+ is_async = inspect .iscoroutinefunction (func ) or inspect .iscoroutinefunction (getattr (func , "__wrapped__" , None ))
335326 if cols and cols [0 ]:
336327 if is_async :
337328 out = await func (* [x if m else x [0 ] for x , m in zip (cols , masks )])
338329 else :
339- out = func ( * [x if m else x [0 ] for x , m in zip (cols , masks )])
330+ out = await asyncio . to_thread ( func , * [x if m else x [0 ] for x , m in zip (cols , masks )])
340331 else :
341332 if is_async :
342333 out = await func ()
343334 else :
344- out = func ()
345-
335+ out = await asyncio . to_thread ( func () )
336+
346337 # Single masked value
347338 if isinstance (out , Masked ):
348339 return row_ids , [tuple (out )]
@@ -379,8 +370,6 @@ def build_tvf_endpoint(
379370 """
380371 if returns_data_format in ['scalar' , 'list' ]:
381372
382- is_async = asyncio .iscoroutinefunction (func )
383-
384373 async def do_func (
385374 row_ids : Sequence [int ],
386375 rows : Sequence [Sequence [Any ]],
@@ -389,15 +378,7 @@ async def do_func(
389378 out_ids : List [int ] = []
390379 out = []
391380 # Call function on each row of data
392- for i , row in zip (row_ids , rows ):
393- if cancel_event .is_set ():
394- raise asyncio .CancelledError (
395- 'Function call was cancelled' ,
396- )
397- if is_async :
398- res = await func (* row )
399- else :
400- res = func (* row )
381+ for i , res in zip (row_ids , func_map (func , rows )):
401382 out .extend (as_list_of_tuples (res ))
402383 out_ids .extend ([row_ids [i ]] * (len (out )- len (out_ids )))
403384 return out_ids , out
@@ -442,23 +423,13 @@ async def do_func(
442423 # each result row, so we just have to use the same
443424 # row ID for all rows in the result.
444425
445- is_async = asyncio .iscoroutinefunction (func )
446-
447426 # Call function on each column of data
448427 if cols and cols [0 ]:
449- if is_async :
450- res = get_dataframe_columns (
451- await func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
452- )
453- else :
454- res = get_dataframe_columns (
455- func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
456- )
428+ res = get_dataframe_columns (
429+ func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
430+ )
457431 else :
458- if is_async :
459- res = get_dataframe_columns (await func ())
460- else :
461- res = get_dataframe_columns (func ())
432+ res = get_dataframe_columns (func ())
462433
463434 # Generate row IDs
464435 if isinstance (res [0 ], Masked ):
@@ -516,9 +487,6 @@ def make_func(
516487 # Set function type
517488 info ['function_type' ] = function_type
518489
519- # Set async flag
520- info ['is_async' ] = asyncio .iscoroutinefunction (func )
521-
522490 # Setup argument types for rowdat_1 parser
523491 colspec = []
524492 for x in sig ['args' ]:
@@ -901,64 +869,11 @@ async def __call__(
901869 output_handler = self .handlers [(accepts , data_version , returns_data_format )]
902870
903871 try :
904- result = []
905-
906- cancel_event = threading .Event ()
907-
908- if func_info ['is_async' ]:
909- func_task = asyncio .create_task (
910- func (
911- cancel_event ,
912- * input_handler ['load' ]( # type: ignore
913- func_info ['colspec' ], b'' .join (data ),
914- ),
915- ),
916- )
917- else :
918- func_task = asyncio .create_task (
919- to_thread (
920- lambda : asyncio .run (
921- func (
922- cancel_event ,
923- * input_handler ['load' ]( # type: ignore
924- func_info ['colspec' ], b'' .join (data ),
925- ),
926- ),
927- ),
928- ),
929- )
930- disconnect_task = asyncio .create_task (
931- cancel_on_disconnect (receive ),
932- )
933- timeout_task = asyncio .create_task (
934- cancel_on_timeout (func_info ['timeout' ]),
935- )
936-
937- all_tasks = [func_task , disconnect_task , timeout_task ]
938-
939- done , pending = await asyncio .wait (
940- all_tasks , return_when = asyncio .FIRST_COMPLETED ,
872+ out = await func (
873+ * input_handler ['load' ]( # type: ignore
874+ func_info ['colspec' ], b'' .join (data ),
875+ ),
941876 )
942-
943- cancel_all_tasks (pending )
944-
945- for task in done :
946- if task is disconnect_task :
947- cancel_event .set ()
948- raise asyncio .CancelledError (
949- 'Function call was cancelled by client disconnect' ,
950- )
951-
952- elif task is timeout_task :
953- cancel_event .set ()
954- raise asyncio .TimeoutError (
955- 'Function call was cancelled due to timeout' ,
956- )
957-
958- elif task is func_task :
959- result .extend (task .result ())
960-
961- print (result )
962877 body = output_handler ['dump' ](
963878 [x [1 ] for x in func_info ['returns' ]], * out , # type: ignore
964879 )
@@ -1066,6 +981,19 @@ def get_function_info(
1066981 sig = info ['signature' ]
1067982 sql_map [sig ['name' ]] = sql
1068983
984+ if 'SINGLESTOREDB_URL' in os .environ :
985+ dbname = build_params (host = os .environ ['SINGLESTOREDB_URL' ]).get ('database' )
986+ elif 'SINGLESTOREDB_HOST' in os .environ :
987+ dbname = build_params (host = os .environ ['SINGLESTOREDB_HOST' ]).get ('database' )
988+ elif 'SINGLESTOREDB_DATABASE' in os .environ :
989+ dbname = os .environ ['SINGLESTOREDB_DATBASE' ]
990+
991+ connection_info = {}
992+ workspace_group_id = os .environ .get ('SINGLESTOREDB_WORKSPACE_GROUP' )
993+ connection_info ['database_name' ] = dbname
994+ connection_info ['workspace_group_id' ] = workspace_group_id
995+
996+
1069997 for key , (_ , info ) in self .endpoints .items ():
1070998 if not func_name or key == func_name :
1071999 sig = info ['signature' ]
@@ -1111,7 +1039,10 @@ def get_function_info(
11111039 sql_statement = sql ,
11121040 )
11131041
1114- return functions
1042+ return {
1043+ 'functions' : functions ,
1044+ 'connection_info' : connection_info
1045+ }
11151046
11161047 def get_create_functions (
11171048 self ,
0 commit comments