5656MAX_PAGES_PER_SPLIT = 20
5757HI_RES_STRATEGY = 'hi_res'
5858MAX_PAGE_LENGTH = 4000
59+ TIMEOUT_BUFFER_SECONDS = 5
60+
61+
62+ def _get_request_timeout_seconds (request : httpx .Request ) -> Optional [float ]:
63+ timeout = request .extensions .get ("timeout" )
64+ if timeout is None :
65+ return None
66+
67+ if isinstance (timeout , (int , float )):
68+ return float (timeout )
69+
70+ if isinstance (timeout , dict ):
71+ timeout_values = [
72+ float (value )
73+ for value in timeout .values ()
74+ if isinstance (value , (int , float ))
75+ ]
76+ if timeout_values :
77+ return max (timeout_values )
78+
79+ return None
5980
6081def _run_coroutines_in_separate_thread (
6182 coroutines_task : Coroutine [Any , Any , list [tuple [int , httpx .Response ]]],
@@ -72,6 +93,7 @@ async def run_tasks(
7293 coroutines : list [partial [Coroutine [Any , Any , httpx .Response ]]],
7394 allow_failed : bool = False ,
7495 concurrency_level : int = 10 ,
96+ client_timeout : Optional [httpx .Timeout ] = None ,
7597) -> list [tuple [int , httpx .Response ]]:
7698 """Run a list of coroutines in parallel and return the results in order.
7799
@@ -84,14 +106,15 @@ async def run_tasks(
84106 """
85107
86108
87- # Use a variable to adjust the httpx client timeout, or default to 30 minutes
88- # When we're able to reuse the SDK to make these calls, we can remove this var
89- # The SDK timeout will be controlled by parameter
90109 limiter = asyncio .Semaphore (concurrency_level )
91- client_timeout_minutes = 60
92- if timeout_var := os .getenv ("UNSTRUCTURED_CLIENT_TIMEOUT_MINUTES" ):
93- client_timeout_minutes = int (timeout_var )
94- client_timeout = httpx .Timeout (60 * client_timeout_minutes )
110+ if client_timeout is None :
111+ # Use a variable to adjust the httpx client timeout, or default to 60 minutes.
112+ # When we're able to reuse the SDK to make these calls, we can remove this var
113+ # and let the SDK timeout flow through directly.
114+ client_timeout_minutes = 60
115+ if timeout_var := os .getenv ("UNSTRUCTURED_CLIENT_TIMEOUT_MINUTES" ):
116+ client_timeout_minutes = int (timeout_var )
117+ client_timeout = httpx .Timeout (60 * client_timeout_minutes )
95118
96119 async with httpx .AsyncClient (timeout = client_timeout ) as client :
97120 armed_coroutines = [coro (async_client = client , limiter = limiter ) for coro in coroutines ] # type: ignore
@@ -166,6 +189,7 @@ def __init__(self) -> None:
166189 self .api_failed_responses : dict [str , list [httpx .Response ]] = {}
167190 self .executors : dict [str , futures .ThreadPoolExecutor ] = {}
168191 self .tempdirs : dict [str , tempfile .TemporaryDirectory ] = {}
192+ self .operation_timeouts : dict [str , Optional [float ]] = {}
169193 self .allow_failed : bool = DEFAULT_ALLOW_FAILED
170194 self .cache_tmp_data_feature : bool = DEFAULT_CACHE_TMP_DATA
171195 self .cache_tmp_data_dir : str = DEFAULT_CACHE_TMP_DATA_DIR
@@ -268,6 +292,7 @@ def before_request(
268292 # We need to pass it on to after_success so
269293 # we know which results are ours
270294 operation_id = str (uuid .uuid4 ())
295+ self .operation_timeouts [operation_id ] = _get_request_timeout_seconds (request )
271296
272297 content_type = request .headers .get ("Content-Type" )
273298 if content_type is None :
@@ -397,14 +422,11 @@ def before_request(
397422 # This allows us to skip right to the AfterRequestHook and await all the calls
398423 # Also, pass the operation_id so after_success can await the right results
399424
400- # Note: We need access to the async_client from the sdk_init hook in order to set
401- # up a mock request like this.
402- # For now, just make an extra request against our api, which should return 200.
403- # dummy_request = httpx.Request("GET", "http://no-op")
404425 return httpx .Request (
405426 "GET" ,
406- f" { self . partition_base_url } /general/docs " ,
427+ "http://no-op " ,
407428 headers = {"operation_id" : operation_id },
429+ extensions = request .extensions .copy (),
408430 )
409431
410432 async def call_api_partial (
@@ -620,15 +642,25 @@ def _await_elements(self, operation_id: str) -> Optional[list]:
620642 return None
621643
622644 concurrency_level = self .concurrency_level .get (operation_id , DEFAULT_CONCURRENCY_LEVEL )
623- coroutines = run_tasks (tasks , allow_failed = self .allow_failed , concurrency_level = concurrency_level )
645+ timeout_seconds = self .operation_timeouts .get (operation_id )
646+ client_timeout = httpx .Timeout (timeout_seconds ) if timeout_seconds is not None else None
647+ coroutines = run_tasks (
648+ tasks ,
649+ allow_failed = self .allow_failed ,
650+ concurrency_level = concurrency_level ,
651+ client_timeout = client_timeout ,
652+ )
624653
625654 # sending the coroutines to a separate thread to avoid blocking the current event loop
626655 # this operation should be removed when the SDK is updated to support async hooks
627656 executor = self .executors .get (operation_id )
628657 if executor is None :
629658 raise RuntimeError ("Executor not found for operation_id" )
630659 task_responses_future = executor .submit (_run_coroutines_in_separate_thread , coroutines )
631- task_responses = task_responses_future .result ()
660+ if timeout_seconds is None :
661+ task_responses = task_responses_future .result ()
662+ else :
663+ task_responses = task_responses_future .result (timeout = timeout_seconds + TIMEOUT_BUFFER_SECONDS )
632664
633665 if task_responses is None :
634666 return None
@@ -683,23 +715,20 @@ def after_success(
683715
684716 # Grab the correct id out of the dummy request
685717 operation_id = response .request .headers .get ("operation_id" )
718+ try :
719+ elements = self ._await_elements (operation_id )
686720
687- elements = self ._await_elements (operation_id )
688-
689- # if fails are disallowed, return the first failed response
690- if not self .allow_failed and self .api_failed_responses .get (operation_id ):
691- failure_response = self .api_failed_responses [operation_id ][0 ]
692-
693- self ._clear_operation (operation_id )
694- return failure_response
695-
696- if elements is None :
697- return response
721+ # if fails are disallowed, return the first failed response
722+ if not self .allow_failed and self .api_failed_responses .get (operation_id ):
723+ return self .api_failed_responses [operation_id ][0 ]
698724
699- new_response = request_utils . create_response ( elements )
700- self . _clear_operation ( operation_id )
725+ if elements is None :
726+ return response
701727
702- return new_response
728+ return request_utils .create_response (elements )
729+ finally :
730+ if operation_id is not None :
731+ self ._clear_operation (operation_id )
703732
704733 def after_error (
705734 self ,
@@ -732,7 +761,9 @@ def _clear_operation(self, operation_id: str) -> None:
732761 """
733762 self .coroutines_to_execute .pop (operation_id , None )
734763 self .api_successful_responses .pop (operation_id , None )
764+ self .api_failed_responses .pop (operation_id , None )
735765 self .concurrency_level .pop (operation_id , None )
766+ self .operation_timeouts .pop (operation_id , None )
736767 executor = self .executors .pop (operation_id , None )
737768 if executor is not None :
738769 executor .shutdown (wait = True )
0 commit comments