Skip to content

Commit fd56c65

Browse files
awalker4claude
and committed
Convert SplitPdfHook to use async hooks
- Implement AsyncBeforeRequestHook, AsyncAfterSuccessHook, and AsyncAfterErrorHook interfaces - Convert _await_elements to async method that directly awaits coroutines - Remove blocking executor.submit() and .result() calls that were blocking the event loop - Remove ThreadPoolExecutor and _run_coroutines_in_separate_thread workaround - before_request_async contains full implementation (CPU-bound work noted in docstring) - Register async hook implementations in init_hooks - Sync hooks still work for backward compatibility with partition() - Async hooks enable true concurrent requests with partition_async() Fixes ENG-792: partition_async now supports concurrent SDK requests without blocking. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 6746539 commit fd56c65

File tree

2 files changed

+249
-27
lines changed

2 files changed

+249
-27
lines changed

src/unstructured_client/_hooks/custom/split_pdf_hook.py

Lines changed: 238 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import tempfile
1010
import uuid
1111
from collections.abc import Awaitable
12-
from concurrent import futures
1312
from functools import partial
1413
from pathlib import Path
1514
from typing import Any, Coroutine, Optional, Tuple, Union, cast, Generator, BinaryIO
@@ -38,6 +37,9 @@
3837
AfterErrorHook,
3938
AfterSuccessContext,
4039
AfterSuccessHook,
40+
AsyncAfterErrorHook,
41+
AsyncAfterSuccessHook,
42+
AsyncBeforeRequestHook,
4143
BeforeRequestContext,
4244
BeforeRequestHook,
4345
SDKInitHook,
@@ -57,12 +59,6 @@
5759
HI_RES_STRATEGY = 'hi_res'
5860
MAX_PAGE_LENGTH = 4000
5961

60-
def _run_coroutines_in_separate_thread(
61-
coroutines_task: Coroutine[Any, Any, list[tuple[int, httpx.Response]]],
62-
) -> list[tuple[int, httpx.Response]]:
63-
return asyncio.run(coroutines_task)
64-
65-
6662
async def _order_keeper(index: int, coro: Awaitable) -> Tuple[int, httpx.Response]:
6763
response = await coro
6864
return index, response
@@ -143,7 +139,15 @@ def load_elements_from_response(response: httpx.Response) -> list[dict]:
143139
return json.load(file)
144140

145141

146-
class SplitPdfHook(SDKInitHook, BeforeRequestHook, AfterSuccessHook, AfterErrorHook):
142+
class SplitPdfHook(
143+
SDKInitHook,
144+
BeforeRequestHook,
145+
AfterSuccessHook,
146+
AfterErrorHook,
147+
AsyncBeforeRequestHook,
148+
AsyncAfterSuccessHook,
149+
AsyncAfterErrorHook,
150+
):
147151
"""
148152
A hook class that splits a PDF file into multiple pages and sends each page as
149153
a separate request. This hook is designed to be used with an Speakeasy SDK.
@@ -164,7 +168,6 @@ def __init__(self) -> None:
164168
self.concurrency_level: dict[str, int] = {}
165169
self.api_successful_responses: dict[str, list[httpx.Response]] = {}
166170
self.api_failed_responses: dict[str, list[httpx.Response]] = {}
167-
self.executors: dict[str, futures.ThreadPoolExecutor] = {}
168171
self.tempdirs: dict[str, tempfile.TemporaryDirectory] = {}
169172
self.allow_failed: bool = DEFAULT_ALLOW_FAILED
170173
self.cache_tmp_data_feature: bool = DEFAULT_CACHE_TMP_DATA
@@ -316,9 +319,6 @@ def before_request(
316319
max_allowed=MAX_CONCURRENCY_LEVEL,
317320
)
318321

319-
executor = futures.ThreadPoolExecutor(max_workers=1)
320-
self.executors[operation_id] = executor
321-
322322
self.cache_tmp_data_feature = form_utils.get_split_pdf_cache_tmp_data(
323323
form_data,
324324
key=PARTITION_FORM_SPLIT_CACHE_TMP_DATA_KEY,
@@ -603,7 +603,7 @@ def _get_pdf_chunk_files(
603603
raise
604604
yield pdf_chunk_file, offset
605605

606-
def _await_elements(self, operation_id: str) -> Optional[list]:
606+
async def _await_elements(self, operation_id: str) -> Optional[list]:
607607
"""
608608
Waits for the partition requests to complete and returns the flattened
609609
elements.
@@ -620,15 +620,7 @@ def _await_elements(self, operation_id: str) -> Optional[list]:
620620
return None
621621

622622
concurrency_level = self.concurrency_level.get(operation_id, DEFAULT_CONCURRENCY_LEVEL)
623-
coroutines = run_tasks(tasks, allow_failed=self.allow_failed, concurrency_level=concurrency_level)
624-
625-
# sending the coroutines to a separate thread to avoid blocking the current event loop
626-
# this operation should be removed when the SDK is updated to support async hooks
627-
executor = self.executors.get(operation_id)
628-
if executor is None:
629-
raise RuntimeError("Executor not found for operation_id")
630-
task_responses_future = executor.submit(_run_coroutines_in_separate_thread, coroutines)
631-
task_responses = task_responses_future.result()
623+
task_responses = await run_tasks(tasks, allow_failed=self.allow_failed, concurrency_level=concurrency_level)
632624

633625
if task_responses is None:
634626
return None
@@ -723,6 +715,230 @@ def after_error(
723715
"""
724716
return (response, error)
725717

718+
async def before_request_async(
719+
self, hook_ctx: BeforeRequestContext, request: httpx.Request
720+
) -> Union[httpx.Request, Exception]:
721+
"""If `splitPdfPage` is set to `true` in the request, the PDF file is split into
722+
separate pages. Each page is sent as a separate request in parallel. The last
723+
page request is returned by this method. It will return the original request
724+
when: `splitPdfPage` is set to `false`, the file is not a PDF, or the HTTP
725+
has not been initialized.
726+
727+
Note: The preparation work (PDF splitting, building requests) is CPU-bound
728+
and doesn't benefit from async, but this method is async to fit the hook interface.
729+
730+
Args:
731+
hook_ctx (BeforeRequestContext): The hook context containing information about
732+
the operation.
733+
request (httpx.PreparedRequest): The request object.
734+
735+
Returns:
736+
Union[httpx.PreparedRequest, Exception]: If `splitPdfPage` is set to `true`,
737+
the last page request; otherwise, the original request.
738+
"""
739+
740+
# Actually the general.partition operation overwrites the default client's base url (as
741+
# the platform operations do). Here we need to get the base url from the request object.
742+
if hook_ctx.operation_id == "partition":
743+
self.partition_base_url = get_base_url(request.url)
744+
self.is_partition_request = True
745+
else:
746+
self.is_partition_request = False
747+
return request
748+
749+
if self.client is None:
750+
logger.warning("HTTP client not accessible! Continuing without splitting.")
751+
return request
752+
753+
# This is our key into coroutines_to_execute
754+
# We need to pass it on to after_success so
755+
# we know which results are ours
756+
operation_id = str(uuid.uuid4())
757+
758+
content_type = request.headers.get("Content-Type")
759+
if content_type is None:
760+
return request
761+
762+
form_data = request_utils.get_multipart_stream_fields(request)
763+
if not form_data:
764+
return request
765+
766+
split_pdf_page = form_data.get(PARTITION_FORM_SPLIT_PDF_PAGE_KEY)
767+
if split_pdf_page is None or split_pdf_page == "false":
768+
return request
769+
770+
pdf_file_meta = form_data.get(PARTITION_FORM_FILES_KEY)
771+
if (
772+
pdf_file_meta is None or not all(metadata in pdf_file_meta for metadata in
773+
["filename", "content_type", "file"])
774+
):
775+
return request
776+
pdf_file = pdf_file_meta.get("file")
777+
if pdf_file is None:
778+
return request
779+
780+
pdf = pdf_utils.read_pdf(pdf_file)
781+
if pdf is None:
782+
return request
783+
784+
pdf = pdf_utils.check_pdf(pdf)
785+
786+
starting_page_number = form_utils.get_starting_page_number(
787+
form_data,
788+
key=PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
789+
fallback_value=DEFAULT_STARTING_PAGE_NUMBER,
790+
)
791+
792+
self.allow_failed = form_utils.get_split_pdf_allow_failed_param(
793+
form_data,
794+
key=PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY,
795+
fallback_value=DEFAULT_ALLOW_FAILED,
796+
)
797+
798+
self.concurrency_level[operation_id] = form_utils.get_split_pdf_concurrency_level_param(
799+
form_data,
800+
key=PARTITION_FORM_CONCURRENCY_LEVEL_KEY,
801+
fallback_value=DEFAULT_CONCURRENCY_LEVEL,
802+
max_allowed=MAX_CONCURRENCY_LEVEL,
803+
)
804+
805+
self.cache_tmp_data_feature = form_utils.get_split_pdf_cache_tmp_data(
806+
form_data,
807+
key=PARTITION_FORM_SPLIT_CACHE_TMP_DATA_KEY,
808+
fallback_value=DEFAULT_CACHE_TMP_DATA,
809+
)
810+
811+
self.cache_tmp_data_dir = form_utils.get_split_pdf_cache_tmp_data_dir(
812+
form_data,
813+
key=PARTITION_FORM_SPLIT_CACHE_TMP_DATA_DIR_KEY,
814+
fallback_value=DEFAULT_CACHE_TMP_DATA_DIR,
815+
)
816+
817+
page_range_start, page_range_end = form_utils.get_page_range(
818+
form_data,
819+
key=PARTITION_FORM_PAGE_RANGE_KEY.replace("[]", ""),
820+
max_pages=pdf.get_num_pages(),
821+
)
822+
823+
page_count = page_range_end - page_range_start + 1
824+
825+
split_size = get_optimal_split_size(
826+
num_pages=page_count, concurrency_level=self.concurrency_level[operation_id]
827+
)
828+
829+
# If the doc is small enough, and we aren't slicing it with a page range:
830+
# do not split, just continue with the original request
831+
if split_size >= page_count and page_count == len(pdf.pages):
832+
return request
833+
834+
pdf = self._trim_large_pages(pdf, form_data)
835+
836+
pdf.stream.seek(0)
837+
pdf_bytes = pdf.stream.read()
838+
839+
if self.cache_tmp_data_feature:
840+
pdf_chunk_paths = self._get_pdf_chunk_paths(
841+
pdf_bytes,
842+
operation_id=operation_id,
843+
split_size=split_size,
844+
page_start=page_range_start,
845+
page_end=page_range_end
846+
)
847+
# force free PDF object memory
848+
del pdf
849+
pdf_chunks = self._get_pdf_chunk_files(pdf_chunk_paths)
850+
else:
851+
pdf_chunks = self._get_pdf_chunks_in_memory(
852+
pdf_bytes,
853+
split_size=split_size,
854+
page_start=page_range_start,
855+
page_end=page_range_end
856+
)
857+
858+
self.coroutines_to_execute[operation_id] = []
859+
set_index = 1
860+
for pdf_chunk_file, page_index in pdf_chunks:
861+
page_number = page_index + starting_page_number
862+
pdf_chunk_request = request_utils.create_pdf_chunk_request(
863+
form_data=form_data,
864+
pdf_chunk=(pdf_chunk_file, page_number),
865+
filename=pdf_file_meta["filename"],
866+
original_request=request,
867+
)
868+
# using partial as the shared client parameter must be passed in `run_tasks` function
869+
# in `after_success`.
870+
coroutine = partial(
871+
self.call_api_partial,
872+
operation_id=operation_id,
873+
pdf_chunk_request=pdf_chunk_request,
874+
pdf_chunk_file=pdf_chunk_file,
875+
)
876+
self.coroutines_to_execute[operation_id].append(coroutine)
877+
set_index += 1
878+
879+
# Return a dummy request for the SDK to use
880+
# This allows us to skip right to the AfterRequestHook and await all the calls
881+
# Also, pass the operation_id so after_success can await the right results
882+
883+
# Note: We need access to the async_client from the sdk_init hook in order to set
884+
# up a mock request like this.
885+
# For now, just make an extra request against our api, which should return 200.
886+
# dummy_request = httpx.Request("GET", "http://no-op")
887+
return httpx.Request(
888+
"GET",
889+
f"{self.partition_base_url}/general/docs",
890+
headers={"operation_id": operation_id},
891+
)
892+
893+
async def after_success_async(
894+
self, hook_ctx: AfterSuccessContext, response: httpx.Response
895+
) -> Union[httpx.Response, Exception]:
896+
"""Async version of after_success. Awaits all parallel requests and
897+
combines the responses into a single response object.
898+
899+
Args:
900+
hook_ctx (AfterSuccessContext): The context object containing information
901+
about the hook execution.
902+
response (httpx.Response): The response object from the SDK call. This was a dummy
903+
request just to get us to the AfterSuccessHook.
904+
905+
Returns:
906+
Union[httpx.Response, Exception]: If requests were run in parallel, a
907+
combined response object; otherwise, the original response. Can return
908+
exception if it occurred during the execution.
909+
"""
910+
if not self.is_partition_request:
911+
return response
912+
913+
# Grab the correct id out of the dummy request
914+
operation_id = response.request.headers.get("operation_id")
915+
916+
elements = await self._await_elements(operation_id)
917+
918+
# if fails are disallowed, return the first failed response
919+
if not self.allow_failed and self.api_failed_responses.get(operation_id):
920+
failure_response = self.api_failed_responses[operation_id][0]
921+
922+
self._clear_operation(operation_id)
923+
return failure_response
924+
925+
if elements is None:
926+
return response
927+
928+
new_response = request_utils.create_response(elements)
929+
self._clear_operation(operation_id)
930+
931+
return new_response
932+
933+
async def after_error_async(
934+
self,
935+
hook_ctx: AfterErrorContext,
936+
response: Optional[httpx.Response],
937+
error: Optional[Exception],
938+
) -> Union[Tuple[Optional[httpx.Response], Optional[Exception]], Exception]:
939+
"""Async version of after_error. Delegates to sync version since no async work needed."""
940+
return self.after_error(hook_ctx, response, error)
941+
726942
def _clear_operation(self, operation_id: str) -> None:
727943
"""
728944
Clears the operation data associated with the given operation ID.
@@ -733,9 +949,6 @@ def _clear_operation(self, operation_id: str) -> None:
733949
self.coroutines_to_execute.pop(operation_id, None)
734950
self.api_successful_responses.pop(operation_id, None)
735951
self.concurrency_level.pop(operation_id, None)
736-
executor = self.executors.pop(operation_id, None)
737-
if executor is not None:
738-
executor.shutdown(wait=True)
739952
tempdir = self.tempdirs.pop(operation_id, None)
740953
if tempdir:
741954
tempdir.cleanup()

src/unstructured_client/_hooks/registration.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,19 @@ def init_hooks(hooks: Hooks):
3636
# Register Before Request hooks
3737
hooks.register_before_request_hook(split_pdf_hook)
3838

39-
# Register After Error hooks
39+
# Register After Success hooks
4040
hooks.register_after_success_hook(split_pdf_hook)
4141
hooks.register_after_success_hook(logger_hook)
4242

4343
# Register After Error hooks
4444
hooks.register_after_error_hook(split_pdf_hook)
45-
hooks.register_after_error_hook(logger_hook)
45+
hooks.register_after_error_hook(logger_hook)
46+
47+
# Register Async Before Request hooks
48+
hooks.register_async_before_request_hook(split_pdf_hook)
49+
50+
# Register Async After Success hooks
51+
hooks.register_async_after_success_hook(split_pdf_hook)
52+
53+
# Register Async After Error hooks
54+
hooks.register_async_after_error_hook(split_pdf_hook)

0 commit comments

Comments (0)