2424"""
2525import argparse
2626import asyncio
27+ import contextvars
2728import dataclasses
29+ import functools
2830import importlib .util
2931import inspect
3032import io
3739import sys
3840import tempfile
3941import textwrap
42+ import threading
4043import typing
4144import urllib
4245import zipfile
9598 func_map = itertools .starmap
9699
97100
101+ async def to_thread (
102+ func : Any , / , * args : Any , ** kwargs : Dict [str , Any ],
103+ ) -> Any :
104+ loop = asyncio .get_running_loop ()
105+ ctx = contextvars .copy_context ()
106+ func_call = functools .partial (ctx .run , func , * args , ** kwargs )
107+ return await loop .run_in_executor (None , func_call )
108+
109+
98110# Use negative values to indicate unsigned ints / binary data / usec time precision
99111rowdat_1_type_map = {
100112 'bool' : ft .LONGLONG ,
@@ -274,11 +286,19 @@ def build_udf_endpoint(
274286 if returns_data_format in ['scalar' , 'list' ]:
275287
276288 async def do_func (
289+ cancel_event : threading .Event ,
277290 row_ids : Sequence [int ],
278291 rows : Sequence [Sequence [Any ]],
279292 ) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
280293 '''Call function on given rows of data.'''
281- return row_ids , [as_tuple (x ) for x in zip (func_map (func , rows ))]
294+ out = []
295+ for row in rows :
296+ if cancel_event .is_set ():
297+ raise asyncio .CancelledError (
298+ 'Function call was cancelled' ,
299+ )
300+ out .append (func (* row ))
301+ return row_ids , list (zip (out ))
282302
283303 return do_func
284304
@@ -309,6 +329,7 @@ def build_vector_udf_endpoint(
309329 array_cls = get_array_class (returns_data_format )
310330
311331 async def do_func (
332+ cancel_event : threading .Event ,
312333 row_ids : Sequence [int ],
313334 cols : Sequence [Tuple [Sequence [Any ], Optional [Sequence [bool ]]]],
314335 ) -> Tuple [
@@ -361,6 +382,7 @@ def build_tvf_endpoint(
361382 if returns_data_format in ['scalar' , 'list' ]:
362383
363384 async def do_func (
385+ cancel_event : threading .Event ,
364386 row_ids : Sequence [int ],
365387 rows : Sequence [Sequence [Any ]],
366388 ) -> Tuple [Sequence [int ], List [Tuple [Any , ...]]]:
@@ -369,6 +391,10 @@ async def do_func(
369391 out = []
370392 # Call function on each row of data
371393 for i , res in zip (row_ids , func_map (func , rows )):
394+ if cancel_event .is_set ():
395+ raise asyncio .CancelledError (
396+ 'Function call was cancelled' ,
397+ )
372398 out .extend (as_list_of_tuples (res ))
373399 out_ids .extend ([row_ids [i ]] * (len (out )- len (out_ids )))
374400 return out_ids , out
@@ -402,6 +428,7 @@ def build_vector_tvf_endpoint(
402428 array_cls = get_array_class (returns_data_format )
403429
404430 async def do_func (
431+ cancel_event : threading .Event ,
405432 row_ids : Sequence [int ],
406433 cols : Sequence [Tuple [Sequence [Any ], Optional [Sequence [bool ]]]],
407434 ) -> Tuple [
@@ -458,6 +485,7 @@ def make_func(
458485 function_type = sig .get ('function_type' , 'udf' )
459486 args_data_format = sig .get ('args_data_format' , 'scalar' )
460487 returns_data_format = sig .get ('returns_data_format' , 'scalar' )
488+ timeout = sig .get ('timeout' , get_option ('external_function.timeout' ))
461489
462490 if function_type == 'tvf' :
463491 do_func = build_tvf_endpoint (func , returns_data_format )
@@ -477,6 +505,9 @@ def make_func(
477505 # Set function type
478506 info ['function_type' ] = function_type
479507
508+ # Set timeout
509+ info ['timeout' ] = max (timeout , 1 )
510+
480511 # Setup argument types for rowdat_1 parser
481512 colspec = []
482513 for x in sig ['args' ]:
@@ -498,6 +529,37 @@ def make_func(
498529 return do_func , info
499530
500531
532+ async def cancel_on_timeout (timeout : int ) -> None :
533+ """Cancel request if it takes too long."""
534+ await asyncio .sleep (timeout )
535+ raise asyncio .CancelledError (
536+ 'Function call was cancelled due to timeout' ,
537+ )
538+
539+
540+ async def cancel_on_disconnect (
541+ receive : Callable [..., Awaitable [Any ]],
542+ ) -> None :
543+ """Cancel request if client disconnects."""
544+ while True :
545+ message = await receive ()
546+ if message ['type' ] == 'http.disconnect' :
547+ raise asyncio .CancelledError (
548+ 'Function call was cancelled by client' ,
549+ )
550+
551+
552+ def cancel_all_tasks (tasks : Iterable [asyncio .Task [Any ]]) -> None :
553+ """Cancel all tasks."""
554+ for task in tasks :
555+ if task .done ():
556+ continue
557+ try :
558+ task .cancel ()
559+ except Exception :
560+ pass
561+
562+
501563class Application (object ):
502564 """
503565 Create an external function application.
@@ -851,6 +913,8 @@ async def __call__(
851913 more_body = True
852914 while more_body :
853915 request = await receive ()
916+ if request ['type' ] == 'http.disconnect' :
917+ raise RuntimeError ('client disconnected' )
854918 data .append (request ['body' ])
855919 more_body = request .get ('more_body' , False )
856920
@@ -859,21 +923,87 @@ async def __call__(
859923 output_handler = self .handlers [(accepts , data_version , returns_data_format )]
860924
861925 try :
862- out = await func (
863- * input_handler ['load' ]( # type: ignore
864- func_info ['colspec' ], b'' .join (data ),
926+ result = []
927+
928+ cancel_event = threading .Event ()
929+
930+ func_task = asyncio .create_task (
931+ to_thread (
932+ lambda : asyncio .run (
933+ func (
934+ cancel_event ,
935+ * input_handler ['load' ]( # type: ignore
936+ func_info ['colspec' ], b'' .join (data ),
937+ ),
938+ ),
939+ ),
865940 ),
866941 )
942+ disconnect_task = asyncio .create_task (
943+ cancel_on_disconnect (receive ),
944+ )
945+ timeout_task = asyncio .create_task (
946+ cancel_on_timeout (func_info ['timeout' ]),
947+ )
948+
949+ all_tasks = [func_task , disconnect_task , timeout_task ]
950+
951+ done , pending = await asyncio .wait (
952+ all_tasks , return_when = asyncio .FIRST_COMPLETED ,
953+ )
954+
955+ cancel_all_tasks (pending )
956+
957+ for task in done :
958+ if task is disconnect_task :
959+ cancel_event .set ()
960+ raise asyncio .CancelledError (
961+ 'Function call was cancelled by client disconnect' ,
962+ )
963+
964+ elif task is timeout_task :
965+ cancel_event .set ()
966+ raise asyncio .TimeoutError (
967+ 'Function call was cancelled due to timeout' ,
968+ )
969+
970+ elif task is func_task :
971+ result .extend (task .result ())
972+
867973 body = output_handler ['dump' ](
868- [x [1 ] for x in func_info ['returns' ]], * out , # type: ignore
974+ [x [1 ] for x in func_info ['returns' ]], * result , # type: ignore
869975 )
976+
870977 await send (output_handler ['response' ])
871978
979+ except asyncio .TimeoutError :
980+ logging .exception (
981+ 'Timeout in function call: ' + func_name .decode ('utf-8' ),
982+ )
983+ body = (
984+ '[TimeoutError] Function call timed out after ' +
985+ str (func_info ['timeout' ]) +
986+ ' seconds'
987+ ).encode ('utf-8' )
988+ await send (self .error_response_dict )
989+
990+ except asyncio .CancelledError :
991+ logging .exception (
992+ 'Function call cancelled: ' + func_name .decode ('utf-8' ),
993+ )
994+ body = b'[CancelledError] Function call was cancelled'
995+ await send (self .error_response_dict )
996+
872997 except Exception as e :
873- logging .exception ('Error in function call' )
998+ logging .exception (
999+ 'Error in function call: ' + func_name .decode ('utf-8' ),
1000+ )
8741001 body = f'[{ type (e ).__name__ } ] { str (e ).strip ()} ' .encode ('utf-8' )
8751002 await send (self .error_response_dict )
8761003
1004+ finally :
1005+ cancel_all_tasks (all_tasks )
1006+
8771007 # Handle api reflection
8781008 elif method == 'GET' and path == self .show_create_function_path :
8791009 host = headers .get (b'host' , b'localhost:80' )
0 commit comments