Skip to content

Commit 1eb4db7

Browse files
committed
Logic fixes
1 parent fa780b8 commit 1eb4db7

File tree

3 files changed

+131
-35
lines changed

3 files changed

+131
-35
lines changed

_test_unstructured_client/integration/test_decorators.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from unstructured_client._hooks.custom import split_pdf_hook
2525

2626
FAKE_KEY = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
27+
TEST_TIMEOUT_MS = 300_000
2728

2829
_HI_RES_STRATEGIES = ("hi_res", Strategy.HI_RES)
2930

@@ -143,7 +144,7 @@ def test_integration_split_pdf_has_same_output_as_non_split(
143144
except requests.exceptions.ConnectionError:
144145
assert False, "The unstructured-api is not running on localhost:8000"
145146

146-
client = UnstructuredClient(api_key_auth=FAKE_KEY)
147+
client = UnstructuredClient(api_key_auth=FAKE_KEY, timeout_ms=TEST_TIMEOUT_MS)
147148

148149
with open(filename, "rb") as f:
149150
files = shared.Files(
@@ -215,7 +216,7 @@ def test_integration_split_pdf_with_caching(
215216
except requests.exceptions.ConnectionError:
216217
assert False, "The unstructured-api is not running on localhost:8000"
217218

218-
client = UnstructuredClient(api_key_auth=FAKE_KEY)
219+
client = UnstructuredClient(api_key_auth=FAKE_KEY, timeout_ms=TEST_TIMEOUT_MS)
219220

220221
with open(filename, "rb") as f:
221222
files = shared.Files(
@@ -282,7 +283,7 @@ def test_long_pages_hi_res(filename):
282283
split_pdf_concurrency_level=15
283284
), )
284285

285-
client = UnstructuredClient(api_key_auth=FAKE_KEY)
286+
client = UnstructuredClient(api_key_auth=FAKE_KEY, timeout_ms=TEST_TIMEOUT_MS)
286287

287288
response = client.general.partition(
288289
request=req,
@@ -301,7 +302,7 @@ def test_integration_split_pdf_for_file_with_no_name():
301302
except requests.exceptions.ConnectionError:
302303
assert False, "The unstructured-api is not running on localhost:8000"
303304

304-
client = UnstructuredClient(api_key_auth=FAKE_KEY)
305+
client = UnstructuredClient(api_key_auth=FAKE_KEY, timeout_ms=TEST_TIMEOUT_MS)
305306

306307
with open("_sample_docs/layout-parser-paper-fast.pdf", "rb") as f:
307308
files = shared.Files(
@@ -357,7 +358,7 @@ def test_integration_split_pdf_with_page_range(
357358
except requests.exceptions.ConnectionError:
358359
assert False, "The unstructured-api is not running on localhost:8000"
359360

360-
client = UnstructuredClient(api_key_auth=FAKE_KEY)
361+
client = UnstructuredClient(api_key_auth=FAKE_KEY, timeout_ms=TEST_TIMEOUT_MS)
361362

362363
filename = "_sample_docs/layout-parser-paper.pdf"
363364
with open(filename, "rb") as f:
@@ -421,7 +422,7 @@ def test_integration_split_pdf_strict_mode(
421422
except requests.exceptions.ConnectionError:
422423
assert False, "The unstructured-api is not running on localhost:8000"
423424

424-
client = UnstructuredClient(api_key_auth=FAKE_KEY)
425+
client = UnstructuredClient(api_key_auth=FAKE_KEY, timeout_ms=TEST_TIMEOUT_MS)
425426

426427
with open(filename, "rb") as f:
427428
files = shared.Files(

_test_unstructured_client/unit/test_split_pdf_hook.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import io
45
from asyncio import Task
56
from collections import Counter
67
from functools import partial
@@ -28,6 +29,7 @@
2829
MAX_PAGES_PER_SPLIT,
2930
MIN_PAGES_PER_SPLIT,
3031
SplitPdfHook,
32+
_get_request_timeout_seconds,
3133
get_optimal_split_size, run_tasks,
3234
)
3335
from unstructured_client._hooks.types import BeforeRequestContext
@@ -47,6 +49,8 @@ async def example():
4749
requests.Response(),
4850
requests.Response(),
4951
]
52+
hook.api_failed_responses[operation_id] = [requests.Response()]
53+
hook.operation_timeouts[operation_id] = 30.0
5054

5155
assert len(hook.coroutines_to_execute[operation_id]) == 2
5256
assert len(hook.api_successful_responses[operation_id]) == 2
@@ -55,6 +59,18 @@ async def example():
5559

5660
assert hook.coroutines_to_execute.get(operation_id) is None
5761
assert hook.api_successful_responses.get(operation_id) is None
62+
assert hook.api_failed_responses.get(operation_id) is None
63+
assert hook.operation_timeouts.get(operation_id) is None
64+
65+
66+
def test_unit_get_request_timeout_seconds_uses_request_timeout_extension():
67+
request = httpx.Request(
68+
"POST",
69+
"http://localhost",
70+
extensions={"timeout": {"connect": 10.0, "read": 30.0, "write": 20.0, "pool": 5.0}},
71+
)
72+
73+
assert _get_request_timeout_seconds(request) == 30.0
5874

5975

6076
def test_unit_prepare_request_headers():
@@ -525,4 +541,52 @@ def test_before_request_raises_pdf_validation_error_when_pdf_check_fails():
525541
# Verify that the mocked functions were called as expected
526542
mock_get_fields.assert_called_once_with(mock_request)
527543
mock_read_pdf.assert_called_once_with(mock_pdf_file)
528-
mock_check_pdf.assert_called_once_with(mock_pdf_reader)
544+
mock_check_pdf.assert_called_once_with(mock_pdf_reader)
545+
546+
547+
def test_before_request_uses_in_memory_noop_request_for_split_pdf():
548+
hook = SplitPdfHook()
549+
mock_client = httpx.Client()
550+
hook.sdk_init(base_url="http://localhost:8888", client=mock_client)
551+
552+
mock_hook_ctx = MagicMock()
553+
mock_hook_ctx.operation_id = "partition"
554+
mock_hook_ctx.config.timeout_ms = 12_000
555+
556+
mock_request = MagicMock(spec=httpx.Request)
557+
mock_request.headers = {"Content-Type": "multipart/form-data"}
558+
mock_request.url = httpx.URL("http://localhost:8888/general/v0/general")
559+
mock_request.extensions = {"timeout": {"connect": 12.0, "read": 12.0, "write": 12.0, "pool": 12.0}}
560+
561+
mock_pdf_file = MagicMock()
562+
mock_form_data = {
563+
"split_pdf_page": "true",
564+
"strategy": "fast",
565+
"files": {
566+
"filename": "test.pdf",
567+
"content_type": "application/pdf",
568+
"file": mock_pdf_file,
569+
},
570+
}
571+
mock_pdf_reader = MagicMock()
572+
mock_pdf_reader.get_num_pages.return_value = 100
573+
mock_pdf_reader.pages = [MagicMock()] * 100
574+
mock_pdf_reader.stream = io.BytesIO(b"fake-pdf-bytes")
575+
576+
with patch("unstructured_client._hooks.custom.request_utils.get_multipart_stream_fields") as mock_get_fields, \
577+
patch("unstructured_client._hooks.custom.pdf_utils.read_pdf") as mock_read_pdf, \
578+
patch("unstructured_client._hooks.custom.pdf_utils.check_pdf") as mock_check_pdf, \
579+
patch("unstructured_client._hooks.custom.request_utils.get_base_url") as mock_get_base_url, \
580+
patch.object(hook, "_trim_large_pages", side_effect=lambda pdf, fd: pdf), \
581+
patch.object(hook, "_get_pdf_chunks_in_memory", return_value=[]):
582+
mock_get_fields.return_value = mock_form_data
583+
mock_read_pdf.return_value = mock_pdf_reader
584+
mock_check_pdf.return_value = mock_pdf_reader
585+
mock_get_base_url.return_value = "http://localhost:8888"
586+
587+
result = hook.before_request(mock_hook_ctx, mock_request)
588+
589+
assert isinstance(result, httpx.Request)
590+
assert str(result.url) == "http://no-op"
591+
assert result.headers["operation_id"]
592+
assert result.extensions["timeout"]["read"] == 12.0

src/unstructured_client/_hooks/custom/split_pdf_hook.py

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,27 @@
5656
MAX_PAGES_PER_SPLIT = 20
5757
HI_RES_STRATEGY = 'hi_res'
5858
MAX_PAGE_LENGTH = 4000
59+
TIMEOUT_BUFFER_SECONDS = 5
60+
61+
62+
def _get_request_timeout_seconds(request: httpx.Request) -> Optional[float]:
63+
timeout = request.extensions.get("timeout")
64+
if timeout is None:
65+
return None
66+
67+
if isinstance(timeout, (int, float)):
68+
return float(timeout)
69+
70+
if isinstance(timeout, dict):
71+
timeout_values = [
72+
float(value)
73+
for value in timeout.values()
74+
if isinstance(value, (int, float))
75+
]
76+
if timeout_values:
77+
return max(timeout_values)
78+
79+
return None
5980

6081
def _run_coroutines_in_separate_thread(
6182
coroutines_task: Coroutine[Any, Any, list[tuple[int, httpx.Response]]],
@@ -72,6 +93,7 @@ async def run_tasks(
7293
coroutines: list[partial[Coroutine[Any, Any, httpx.Response]]],
7394
allow_failed: bool = False,
7495
concurrency_level: int = 10,
96+
client_timeout: Optional[httpx.Timeout] = None,
7597
) -> list[tuple[int, httpx.Response]]:
7698
"""Run a list of coroutines in parallel and return the results in order.
7799
@@ -84,14 +106,15 @@ async def run_tasks(
84106
"""
85107

86108

87-
# Use a variable to adjust the httpx client timeout, or default to 30 minutes
88-
# When we're able to reuse the SDK to make these calls, we can remove this var
89-
# The SDK timeout will be controlled by parameter
90109
limiter = asyncio.Semaphore(concurrency_level)
91-
client_timeout_minutes = 60
92-
if timeout_var := os.getenv("UNSTRUCTURED_CLIENT_TIMEOUT_MINUTES"):
93-
client_timeout_minutes = int(timeout_var)
94-
client_timeout = httpx.Timeout(60 * client_timeout_minutes)
110+
if client_timeout is None:
111+
# Use a variable to adjust the httpx client timeout, or default to 60 minutes.
112+
# When we're able to reuse the SDK to make these calls, we can remove this var
113+
# and let the SDK timeout flow through directly.
114+
client_timeout_minutes = 60
115+
if timeout_var := os.getenv("UNSTRUCTURED_CLIENT_TIMEOUT_MINUTES"):
116+
client_timeout_minutes = int(timeout_var)
117+
client_timeout = httpx.Timeout(60 * client_timeout_minutes)
95118

96119
async with httpx.AsyncClient(timeout=client_timeout) as client:
97120
armed_coroutines = [coro(async_client=client, limiter=limiter) for coro in coroutines] # type: ignore
@@ -166,6 +189,7 @@ def __init__(self) -> None:
166189
self.api_failed_responses: dict[str, list[httpx.Response]] = {}
167190
self.executors: dict[str, futures.ThreadPoolExecutor] = {}
168191
self.tempdirs: dict[str, tempfile.TemporaryDirectory] = {}
192+
self.operation_timeouts: dict[str, Optional[float]] = {}
169193
self.allow_failed: bool = DEFAULT_ALLOW_FAILED
170194
self.cache_tmp_data_feature: bool = DEFAULT_CACHE_TMP_DATA
171195
self.cache_tmp_data_dir: str = DEFAULT_CACHE_TMP_DATA_DIR
@@ -268,6 +292,7 @@ def before_request(
268292
# We need to pass it on to after_success so
269293
# we know which results are ours
270294
operation_id = str(uuid.uuid4())
295+
self.operation_timeouts[operation_id] = _get_request_timeout_seconds(request)
271296

272297
content_type = request.headers.get("Content-Type")
273298
if content_type is None:
@@ -397,14 +422,11 @@ def before_request(
397422
# This allows us to skip right to the AfterRequestHook and await all the calls
398423
# Also, pass the operation_id so after_success can await the right results
399424

400-
# Note: We need access to the async_client from the sdk_init hook in order to set
401-
# up a mock request like this.
402-
# For now, just make an extra request against our api, which should return 200.
403-
# dummy_request = httpx.Request("GET", "http://no-op")
404425
return httpx.Request(
405426
"GET",
406-
f"{self.partition_base_url}/general/docs",
427+
"http://no-op",
407428
headers={"operation_id": operation_id},
429+
extensions=request.extensions.copy(),
408430
)
409431

410432
async def call_api_partial(
@@ -620,15 +642,25 @@ def _await_elements(self, operation_id: str) -> Optional[list]:
620642
return None
621643

622644
concurrency_level = self.concurrency_level.get(operation_id, DEFAULT_CONCURRENCY_LEVEL)
623-
coroutines = run_tasks(tasks, allow_failed=self.allow_failed, concurrency_level=concurrency_level)
645+
timeout_seconds = self.operation_timeouts.get(operation_id)
646+
client_timeout = httpx.Timeout(timeout_seconds) if timeout_seconds is not None else None
647+
coroutines = run_tasks(
648+
tasks,
649+
allow_failed=self.allow_failed,
650+
concurrency_level=concurrency_level,
651+
client_timeout=client_timeout,
652+
)
624653

625654
# sending the coroutines to a separate thread to avoid blocking the current event loop
626655
# this operation should be removed when the SDK is updated to support async hooks
627656
executor = self.executors.get(operation_id)
628657
if executor is None:
629658
raise RuntimeError("Executor not found for operation_id")
630659
task_responses_future = executor.submit(_run_coroutines_in_separate_thread, coroutines)
631-
task_responses = task_responses_future.result()
660+
if timeout_seconds is None:
661+
task_responses = task_responses_future.result()
662+
else:
663+
task_responses = task_responses_future.result(timeout=timeout_seconds + TIMEOUT_BUFFER_SECONDS)
632664

633665
if task_responses is None:
634666
return None
@@ -683,23 +715,20 @@ def after_success(
683715

684716
# Grab the correct id out of the dummy request
685717
operation_id = response.request.headers.get("operation_id")
718+
try:
719+
elements = self._await_elements(operation_id)
686720

687-
elements = self._await_elements(operation_id)
688-
689-
# if fails are disallowed, return the first failed response
690-
if not self.allow_failed and self.api_failed_responses.get(operation_id):
691-
failure_response = self.api_failed_responses[operation_id][0]
692-
693-
self._clear_operation(operation_id)
694-
return failure_response
695-
696-
if elements is None:
697-
return response
721+
# if fails are disallowed, return the first failed response
722+
if not self.allow_failed and self.api_failed_responses.get(operation_id):
723+
return self.api_failed_responses[operation_id][0]
698724

699-
new_response = request_utils.create_response(elements)
700-
self._clear_operation(operation_id)
725+
if elements is None:
726+
return response
701727

702-
return new_response
728+
return request_utils.create_response(elements)
729+
finally:
730+
if operation_id is not None:
731+
self._clear_operation(operation_id)
703732

704733
def after_error(
705734
self,
@@ -732,7 +761,9 @@ def _clear_operation(self, operation_id: str) -> None:
732761
"""
733762
self.coroutines_to_execute.pop(operation_id, None)
734763
self.api_successful_responses.pop(operation_id, None)
764+
self.api_failed_responses.pop(operation_id, None)
735765
self.concurrency_level.pop(operation_id, None)
766+
self.operation_timeouts.pop(operation_id, None)
736767
executor = self.executors.pop(operation_id, None)
737768
if executor is not None:
738769
executor.shutdown(wait=True)

0 commit comments

Comments
 (0)