6666from ..signature import signature_to_sql
6767from ..typing import Masked
6868from ..typing import Table
69+ from ...config import get_option
70+
6971
7072try :
7173 import cloudpickle
@@ -273,24 +275,12 @@ def build_udf_endpoint(
273275 """
274276 if returns_data_format in ['scalar' , 'list' ]:
275277
276- is_async = asyncio .iscoroutinefunction (func )
277-
278278 async def do_func (
279279 row_ids : Sequence [int ],
280280 rows : Sequence [Sequence [Any ]],
281281 ) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
282282 '''Call function on given rows of data.'''
283- out = []
284- for row in rows :
285- if cancel_event .is_set ():
286- raise asyncio .CancelledError (
287- 'Function call was cancelled' ,
288- )
289- if is_async :
290- out .append (await func (* row ))
291- else :
292- out .append (func (* row ))
293- return row_ids , list (zip (out ))
283+ return row_ids , [as_tuple (x ) for x in zip (func_map (func , rows ))]
294284
295285 return do_func
296286
@@ -319,7 +309,6 @@ def build_vector_udf_endpoint(
319309 """
320310 masks = get_masked_params (func )
321311 array_cls = get_array_class (returns_data_format )
322- is_async = asyncio .iscoroutinefunction (func )
323312
324313 async def do_func (
325314 row_ids : Sequence [int ],
@@ -332,17 +321,18 @@ async def do_func(
332321 row_ids = array_cls (row_ids )
333322
334323 # Call the function with `cols` as the function parameters
324+ is_async = inspect .iscoroutinefunction (func ) or inspect .iscoroutinefunction (getattr (func , "__wrapped__" , None ))
335325 if cols and cols [0 ]:
336326 if is_async :
337327 out = await func (* [x if m else x [0 ] for x , m in zip (cols , masks )])
338328 else :
339- out = func ( * [x if m else x [0 ] for x , m in zip (cols , masks )])
329+ out = await asyncio . to_thread ( func , * [x if m else x [0 ] for x , m in zip (cols , masks )])
340330 else :
341331 if is_async :
342332 out = await func ()
343333 else :
344- out = func ()
345-
334+ out = await asyncio . to_thread ( func () )
335+
346336 # Single masked value
347337 if isinstance (out , Masked ):
348338 return row_ids , [tuple (out )]
@@ -379,8 +369,6 @@ def build_tvf_endpoint(
379369 """
380370 if returns_data_format in ['scalar' , 'list' ]:
381371
382- is_async = asyncio .iscoroutinefunction (func )
383-
384372 async def do_func (
385373 row_ids : Sequence [int ],
386374 rows : Sequence [Sequence [Any ]],
@@ -389,15 +377,7 @@ async def do_func(
389377 out_ids : List [int ] = []
390378 out = []
391379 # Call function on each row of data
392- for i , row in zip (row_ids , rows ):
393- if cancel_event .is_set ():
394- raise asyncio .CancelledError (
395- 'Function call was cancelled' ,
396- )
397- if is_async :
398- res = await func (* row )
399- else :
400- res = func (* row )
380+ for i , res in zip (row_ids , func_map (func , rows )):
401381 out .extend (as_list_of_tuples (res ))
402382 out_ids .extend ([row_ids [i ]] * (len (out )- len (out_ids )))
403383 return out_ids , out
@@ -442,23 +422,13 @@ async def do_func(
442422 # each result row, so we just have to use the same
443423 # row ID for all rows in the result.
444424
445- is_async = asyncio .iscoroutinefunction (func )
446-
447425 # Call function on each column of data
448426 if cols and cols [0 ]:
449- if is_async :
450- res = get_dataframe_columns (
451- await func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
452- )
453- else :
454- res = get_dataframe_columns (
455- func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
456- )
427+ res = get_dataframe_columns (
428+ func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
429+ )
457430 else :
458- if is_async :
459- res = get_dataframe_columns (await func ())
460- else :
461- res = get_dataframe_columns (func ())
431+ res = get_dataframe_columns (func ())
462432
463433 # Generate row IDs
464434 if isinstance (res [0 ], Masked ):
@@ -516,9 +486,6 @@ def make_func(
516486 # Set function type
517487 info ['function_type' ] = function_type
518488
519- # Set async flag
520- info ['is_async' ] = asyncio .iscoroutinefunction (func )
521-
522489 # Setup argument types for rowdat_1 parser
523490 colspec = []
524491 for x in sig ['args' ]:
@@ -901,64 +868,11 @@ async def __call__(
901868 output_handler = self .handlers [(accepts , data_version , returns_data_format )]
902869
903870 try :
904- result = []
905-
906- cancel_event = threading .Event ()
907-
908- if func_info ['is_async' ]:
909- func_task = asyncio .create_task (
910- func (
911- cancel_event ,
912- * input_handler ['load' ]( # type: ignore
913- func_info ['colspec' ], b'' .join (data ),
914- ),
915- ),
916- )
917- else :
918- func_task = asyncio .create_task (
919- to_thread (
920- lambda : asyncio .run (
921- func (
922- cancel_event ,
923- * input_handler ['load' ]( # type: ignore
924- func_info ['colspec' ], b'' .join (data ),
925- ),
926- ),
927- ),
928- ),
929- )
930- disconnect_task = asyncio .create_task (
931- cancel_on_disconnect (receive ),
932- )
933- timeout_task = asyncio .create_task (
934- cancel_on_timeout (func_info ['timeout' ]),
935- )
936-
937- all_tasks = [func_task , disconnect_task , timeout_task ]
938-
939- done , pending = await asyncio .wait (
940- all_tasks , return_when = asyncio .FIRST_COMPLETED ,
871+ out = await func (
872+ * input_handler ['load' ]( # type: ignore
873+ func_info ['colspec' ], b'' .join (data ),
874+ ),
941875 )
942-
943- cancel_all_tasks (pending )
944-
945- for task in done :
946- if task is disconnect_task :
947- cancel_event .set ()
948- raise asyncio .CancelledError (
949- 'Function call was cancelled by client disconnect' ,
950- )
951-
952- elif task is timeout_task :
953- cancel_event .set ()
954- raise asyncio .TimeoutError (
955- 'Function call was cancelled due to timeout' ,
956- )
957-
958- elif task is func_task :
959- result .extend (task .result ())
960-
961- print (result )
962876 body = output_handler ['dump' ](
963877 [x [1 ] for x in func_info ['returns' ]], * out , # type: ignore
964878 )
@@ -1065,6 +979,12 @@ def get_function_info(
1065979 for (_ , info ), sql in zip (self .endpoints .values (), create_sqls ):
1066980 sig = info ['signature' ]
1067981 sql_map [sig ['name' ]] = sql
982+
983+ connection_info = {}
984+ workspace_group_id = os .environ .get ('SINGLESTOREDB_WORKSPACE_GROUP' )
985+ connection_info ['database_name' ] = os .environ .get ('SINGLESTOREDB_DATABASE' )
986+ connection_info ['workspace_group_id' ] = workspace_group_id
987+
1068988
1069989 for key , (_ , info ) in self .endpoints .items ():
1070990 if not func_name or key == func_name :
@@ -1111,7 +1031,10 @@ def get_function_info(
11111031 sql_statement = sql ,
11121032 )
11131033
1114- return functions
1034+ return {
1035+ 'functions' : functions ,
1036+ 'connection_info' : connection_info
1037+ }
11151038
11161039 def get_create_functions (
11171040 self ,
0 commit comments