@@ -273,12 +273,24 @@ def build_udf_endpoint(
273273 """
274274 if returns_data_format in ['scalar' , 'list' ]:
275275
276+ is_async = asyncio .iscoroutinefunction (func )
277+
276278 async def do_func (
277279 row_ids : Sequence [int ],
278280 rows : Sequence [Sequence [Any ]],
279281 ) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
280282 '''Call function on given rows of data.'''
281- return row_ids , [as_tuple (x ) for x in zip (func_map (func , rows ))]
283+ out = []
284+ for row in rows :
285+ if cancel_event .is_set ():
286+ raise asyncio .CancelledError (
287+ 'Function call was cancelled' ,
288+ )
289+ if is_async :
290+ out .append (await func (* row ))
291+ else :
292+ out .append (func (* row ))
293+ return row_ids , list (zip (out ))
282294
283295 return do_func
284296
@@ -307,6 +319,7 @@ def build_vector_udf_endpoint(
307319 """
308320 masks = get_masked_params (func )
309321 array_cls = get_array_class (returns_data_format )
322+ is_async = asyncio .iscoroutinefunction (func )
310323
311324 async def do_func (
312325 row_ids : Sequence [int ],
@@ -320,9 +333,15 @@ async def do_func(
320333
321334 # Call the function with `cols` as the function parameters
322335 if cols and cols [0 ]:
323- out = func (* [x if m else x [0 ] for x , m in zip (cols , masks )])
336+ if is_async :
337+ out = await func (* [x if m else x [0 ] for x , m in zip (cols , masks )])
338+ else :
339+ out = func (* [x if m else x [0 ] for x , m in zip (cols , masks )])
324340 else :
325- out = func ()
341+ if is_async :
342+ out = await func ()
343+ else :
344+ out = func ()
326345
327346 # Single masked value
328347 if isinstance (out , Masked ):
@@ -360,6 +379,8 @@ def build_tvf_endpoint(
360379 """
361380 if returns_data_format in ['scalar' , 'list' ]:
362381
382+ is_async = asyncio .iscoroutinefunction (func )
383+
363384 async def do_func (
364385 row_ids : Sequence [int ],
365386 rows : Sequence [Sequence [Any ]],
@@ -368,7 +389,15 @@ async def do_func(
368389 out_ids : List [int ] = []
369390 out = []
370391 # Call function on each row of data
371- for i , res in zip (row_ids , func_map (func , rows )):
392+ for i , row in zip (row_ids , rows ):
393+ if cancel_event .is_set ():
394+ raise asyncio .CancelledError (
395+ 'Function call was cancelled' ,
396+ )
397+ if is_async :
398+ res = await func (* row )
399+ else :
400+ res = func (* row )
372401 out .extend (as_list_of_tuples (res ))
373402 out_ids .extend ([row_ids [i ]] * (len (out )- len (out_ids )))
374403 return out_ids , out
@@ -413,13 +442,23 @@ async def do_func(
413442 # each result row, so we just have to use the same
414443 # row ID for all rows in the result.
415444
445+ is_async = asyncio .iscoroutinefunction (func )
446+
416447 # Call function on each column of data
417448 if cols and cols [0 ]:
418- res = get_dataframe_columns (
419- func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
420- )
449+ if is_async :
450+ res = get_dataframe_columns (
451+ await func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
452+ )
453+ else :
454+ res = get_dataframe_columns (
455+ func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
456+ )
421457 else :
422- res = get_dataframe_columns (func ())
458+ if is_async :
459+ res = get_dataframe_columns (await func ())
460+ else :
461+ res = get_dataframe_columns (func ())
423462
424463 # Generate row IDs
425464 if isinstance (res [0 ], Masked ):
@@ -477,6 +516,12 @@ def make_func(
477516 # Set function type
478517 info ['function_type' ] = function_type
479518
519+ # Set timeout
520+ info ['timeout' ] = max (timeout , 1 )
521+
522+ # Set async flag
523+ info ['is_async' ] = asyncio .iscoroutinefunction (func )
524+
480525 # Setup argument types for rowdat_1 parser
481526 colspec = []
482527 for x in sig ['args' ]:
@@ -859,11 +904,64 @@ async def __call__(
859904 output_handler = self .handlers [(accepts , data_version , returns_data_format )]
860905
861906 try :
862- out = await func (
863- * input_handler ['load' ]( # type: ignore
864- func_info ['colspec' ], b'' .join (data ),
865- ),
907+ result = []
908+
909+ cancel_event = threading .Event ()
910+
911+ if func_info ['is_async' ]:
912+ func_task = asyncio .create_task (
913+ func (
914+ cancel_event ,
915+ * input_handler ['load' ]( # type: ignore
916+ func_info ['colspec' ], b'' .join (data ),
917+ ),
918+ ),
919+ )
920+ else :
921+ func_task = asyncio .create_task (
922+ to_thread (
923+ lambda : asyncio .run (
924+ func (
925+ cancel_event ,
926+ * input_handler ['load' ]( # type: ignore
927+ func_info ['colspec' ], b'' .join (data ),
928+ ),
929+ ),
930+ ),
931+ ),
932+ )
933+ disconnect_task = asyncio .create_task (
934+ cancel_on_disconnect (receive ),
935+ )
936+ timeout_task = asyncio .create_task (
937+ cancel_on_timeout (func_info ['timeout' ]),
866938 )
939+
940+ all_tasks = [func_task , disconnect_task , timeout_task ]
941+
942+ done , pending = await asyncio .wait (
943+ all_tasks , return_when = asyncio .FIRST_COMPLETED ,
944+ )
945+
946+ cancel_all_tasks (pending )
947+
948+ for task in done :
949+ if task is disconnect_task :
950+ cancel_event .set ()
951+ raise asyncio .CancelledError (
952+ 'Function call was cancelled by client disconnect' ,
953+ )
954+
955+ elif task is timeout_task :
956+ cancel_event .set ()
957+ raise asyncio .TimeoutError (
958+ 'Function call was cancelled due to timeout' ,
959+ )
960+
961+ elif task is func_task :
962+ result .extend (task .result ())
963+
964+ print (result )
867965 body = output_handler ['dump' ](
868966 [x [1 ] for x in func_info ['returns' ]], * out , # type: ignore
869967 )
0 commit comments