Skip to content

Commit 4fc51a8

Browse files
committed
feat(docling_serve): support async conversion jobs
1 parent 92810ad commit 4fc51a8

3 files changed

Lines changed: 402 additions & 7 deletions

File tree

integrations/docling_serve/src/haystack_integrations/components/converters/docling_serve/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from haystack_integrations.components.converters.docling_serve.converter import DoclingServeConverter, ExportType
5+
from haystack_integrations.components.converters.docling_serve.converter import (
6+
ConversionMode,
7+
DoclingServeConversionError,
8+
DoclingServeConverter,
9+
ExportType,
10+
)
611

7-
__all__ = ["DoclingServeConverter", "ExportType"]
12+
__all__ = ["ConversionMode", "DoclingServeConversionError", "DoclingServeConverter", "ExportType"]

integrations/docling_serve/src/haystack_integrations/components/converters/docling_serve/converter.py

Lines changed: 264 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import asyncio
56
import json
67
import mimetypes
8+
import time
79
from enum import Enum
810
from pathlib import Path
911
from typing import Any
@@ -19,6 +21,10 @@
1921

2022
_FILE_CONVERT_PATH = "/v1/convert/file"
2123
_SOURCE_CONVERT_PATH = "/v1/convert/source"
24+
_STATUS_POLL_PATH = "/v1/status/poll"
25+
_RESULT_PATH = "/v1/result"
26+
_TERMINAL_TASK_STATUSES = {"success", "failure"}
27+
_FAILED_CONVERSION_STATUSES = {"failure", "skipped"}
2228

2329

2430
class ExportType(str, Enum):
@@ -35,6 +41,22 @@ class ExportType(str, Enum):
3541
JSON = "json"
3642

3743

44+
class ConversionMode(str, Enum):
45+
"""
46+
Execution mode for DoclingServe conversions.
47+
48+
- `SYNC`: Uses DoclingServe's synchronous conversion endpoints.
49+
- `ASYNC`: Uses DoclingServe's async job endpoints and polls for completion.
50+
"""
51+
52+
SYNC = "sync"
53+
ASYNC = "async"
54+
55+
56+
class DoclingServeConversionError(Exception):
57+
"""Raised when DoclingServe reports an async task or conversion failure."""
58+
59+
3860
def _is_url(source: str) -> bool:
3961
parsed = urlparse(source)
4062
return parsed.scheme in ("http", "https")
@@ -88,6 +110,9 @@ def __init__(
88110
convert_options: dict[str, Any] | None = None,
89111
timeout: float = 120.0,
90112
api_key: Secret | None = Secret.from_env_var("DOCLING_SERVE_API_KEY", strict=False),
113+
mode: ConversionMode | str = ConversionMode.SYNC,
114+
poll_interval: float = 2.0,
115+
job_timeout: float = 600.0,
91116
) -> None:
92117
"""
93118
Initializes the DoclingServeConverter.
@@ -108,12 +133,29 @@ def __init__(
108133
API key for authenticating with a secured DoclingServe instance. Reads from the
109134
`DOCLING_SERVE_API_KEY` environment variable by default. Set to `None` to disable
110135
authentication.
136+
:param mode:
137+
Conversion mode. `sync` uses DoclingServe's synchronous endpoints. `async` submits
138+
conversion jobs to DoclingServe's async endpoints and polls until completion.
139+
:param poll_interval:
140+
Maximum server-side long-poll wait in seconds when `mode="async"`.
141+
:param job_timeout:
142+
Maximum time in seconds to wait for each async conversion job.
111143
"""
144+
if poll_interval <= 0:
145+
msg = "poll_interval must be greater than 0."
146+
raise ValueError(msg)
147+
if job_timeout <= 0:
148+
msg = "job_timeout must be greater than 0."
149+
raise ValueError(msg)
150+
112151
self.base_url = base_url.rstrip("/")
113152
self.export_type = ExportType(export_type)
114153
self.convert_options = dict(convert_options) if convert_options else {}
115154
self.timeout = timeout
116155
self.api_key = api_key
156+
self.mode = ConversionMode(mode)
157+
self.poll_interval = poll_interval
158+
self.job_timeout = job_timeout
117159

118160
def to_dict(self) -> dict[str, Any]:
119161
"""
@@ -129,6 +171,9 @@ def to_dict(self) -> dict[str, Any]:
129171
convert_options=self.convert_options,
130172
timeout=self.timeout,
131173
api_key=self.api_key.to_dict() if self.api_key else None,
174+
mode=self.mode.value,
175+
poll_interval=self.poll_interval,
176+
job_timeout=self.job_timeout,
132177
)
133178

134179
@classmethod
@@ -166,6 +211,41 @@ def _extract_content(self, data: dict[str, Any]) -> str | None:
166211
return json.dumps(content) if content is not None else None
167212
return None
168213

214+
def _raise_for_failed_conversion(self, data: dict[str, Any]) -> None:
215+
status = data.get("status")
216+
if status not in _FAILED_CONVERSION_STATUSES:
217+
return
218+
219+
errors = data.get("errors") or []
220+
details = "; ".join(
221+
str(error.get("error_message", error)) if isinstance(error, dict) else str(error) for error in errors
222+
)
223+
msg = f"DoclingServe conversion finished with status '{status}'"
224+
if details:
225+
msg = f"{msg}: {details}"
226+
raise DoclingServeConversionError(msg)
227+
228+
def _extract_task_id(self, data: dict[str, Any]) -> str:
229+
task_id = data.get("task_id")
230+
if not isinstance(task_id, str) or not task_id:
231+
msg = "DoclingServe async task response did not include a task_id."
232+
raise DoclingServeConversionError(msg)
233+
return task_id
234+
235+
def _raise_for_failed_task(self, data: dict[str, Any]) -> None:
236+
if data.get("task_status") != "failure":
237+
return
238+
239+
error_message = data.get("error_message") or "DoclingServe async task failed."
240+
raise DoclingServeConversionError(str(error_message))
241+
242+
def _async_source_payload(self, url: str) -> dict[str, Any]:
243+
return {
244+
"options": {**self.convert_options, "to_formats": [self._to_format()]},
245+
"sources": [{"kind": "http", "url": url}],
246+
"target": {"kind": "inbody"},
247+
}
248+
169249
def _post_file(self, client: httpx.Client, source: str | Path | ByteStream) -> dict[str, Any]:
170250
filename = _resolve_filename(source)
171251
file_bytes = source.data if isinstance(source, ByteStream) else Path(source).read_bytes()
@@ -184,6 +264,24 @@ def _post_file(self, client: httpx.Client, source: str | Path | ByteStream) -> d
184264
response.raise_for_status()
185265
return response.json()
186266

267+
def _submit_file_job(self, client: httpx.Client, source: str | Path | ByteStream) -> str:
268+
filename = _resolve_filename(source)
269+
file_bytes = source.data if isinstance(source, ByteStream) else Path(source).read_bytes()
270+
mime_type = (
271+
(source.mime_type or _guess_mime_type(filename))
272+
if isinstance(source, ByteStream)
273+
else _guess_mime_type(filename)
274+
)
275+
options = {**self.convert_options, "to_formats": self._to_format(), "target_type": "inbody"}
276+
response = client.post(
277+
f"{self.base_url}{_FILE_CONVERT_PATH}/async",
278+
files={"files": (filename, file_bytes, mime_type)},
279+
data=options,
280+
headers=self._headers(),
281+
)
282+
response.raise_for_status()
283+
return self._extract_task_id(response.json())
284+
187285
def _post_url(self, client: httpx.Client, url: str) -> dict[str, Any]:
188286
payload: dict[str, Any] = {
189287
"options": {**self.convert_options, "to_formats": [self._to_format()]},
@@ -197,6 +295,64 @@ def _post_url(self, client: httpx.Client, url: str) -> dict[str, Any]:
197295
response.raise_for_status()
198296
return response.json()
199297

298+
def _submit_url_job(self, client: httpx.Client, url: str) -> str:
299+
response = client.post(
300+
f"{self.base_url}{_SOURCE_CONVERT_PATH}/async",
301+
json=self._async_source_payload(url),
302+
headers=self._headers(),
303+
)
304+
response.raise_for_status()
305+
return self._extract_task_id(response.json())
306+
307+
def _poll_job_status(self, client: httpx.Client, task_id: str, wait: float) -> dict[str, Any]:
308+
response = client.get(
309+
f"{self.base_url}{_STATUS_POLL_PATH}/{task_id}",
310+
params={"wait": wait},
311+
headers=self._headers(),
312+
)
313+
response.raise_for_status()
314+
return response.json()
315+
316+
def _wait_for_job(self, client: httpx.Client, task_id: str) -> None:
317+
deadline = time.monotonic() + self.job_timeout
318+
while True:
319+
remaining = deadline - time.monotonic()
320+
if remaining <= 0:
321+
msg = f"Timed out waiting for DoclingServe task {task_id} after {self.job_timeout:.2f}s."
322+
raise DoclingServeConversionError(msg)
323+
324+
wait = min(self.poll_interval, remaining)
325+
poll_started = time.monotonic()
326+
status = self._poll_job_status(client, task_id, wait)
327+
task_status = status.get("task_status")
328+
if task_status in _TERMINAL_TASK_STATUSES:
329+
self._raise_for_failed_task(status)
330+
return
331+
332+
sleep_for = min(self.poll_interval, remaining) - (time.monotonic() - poll_started)
333+
if sleep_for > 0:
334+
time.sleep(sleep_for)
335+
336+
def _fetch_job_result(self, client: httpx.Client, task_id: str) -> dict[str, Any]:
337+
response = client.get(
338+
f"{self.base_url}{_RESULT_PATH}/{task_id}",
339+
headers=self._headers(),
340+
)
341+
response.raise_for_status()
342+
data = response.json()
343+
self._raise_for_failed_conversion(data)
344+
return data
345+
346+
def _post_file_job(self, client: httpx.Client, source: str | Path | ByteStream) -> dict[str, Any]:
347+
task_id = self._submit_file_job(client, source)
348+
self._wait_for_job(client, task_id)
349+
return self._fetch_job_result(client, task_id)
350+
351+
def _post_url_job(self, client: httpx.Client, url: str) -> dict[str, Any]:
352+
task_id = self._submit_url_job(client, url)
353+
self._wait_for_job(client, task_id)
354+
return self._fetch_job_result(client, task_id)
355+
200356
async def _post_file_async(self, client: httpx.AsyncClient, source: str | Path | ByteStream) -> dict[str, Any]:
201357
filename = _resolve_filename(source)
202358
file_bytes = source.data if isinstance(source, ByteStream) else Path(source).read_bytes()
@@ -215,6 +371,24 @@ async def _post_file_async(self, client: httpx.AsyncClient, source: str | Path |
215371
response.raise_for_status()
216372
return response.json()
217373

374+
async def _submit_file_job_async(self, client: httpx.AsyncClient, source: str | Path | ByteStream) -> str:
375+
filename = _resolve_filename(source)
376+
file_bytes = source.data if isinstance(source, ByteStream) else Path(source).read_bytes()
377+
mime_type = (
378+
(source.mime_type or _guess_mime_type(filename))
379+
if isinstance(source, ByteStream)
380+
else _guess_mime_type(filename)
381+
)
382+
options = {**self.convert_options, "to_formats": self._to_format(), "target_type": "inbody"}
383+
response = await client.post(
384+
f"{self.base_url}{_FILE_CONVERT_PATH}/async",
385+
files={"files": (filename, file_bytes, mime_type)},
386+
data=options,
387+
headers=self._headers(),
388+
)
389+
response.raise_for_status()
390+
return self._extract_task_id(response.json())
391+
218392
async def _post_url_async(self, client: httpx.AsyncClient, url: str) -> dict[str, Any]:
219393
payload: dict[str, Any] = {
220394
"options": {**self.convert_options, "to_formats": [self._to_format()]},
@@ -228,6 +402,64 @@ async def _post_url_async(self, client: httpx.AsyncClient, url: str) -> dict[str
228402
response.raise_for_status()
229403
return response.json()
230404

405+
async def _submit_url_job_async(self, client: httpx.AsyncClient, url: str) -> str:
406+
response = await client.post(
407+
f"{self.base_url}{_SOURCE_CONVERT_PATH}/async",
408+
json=self._async_source_payload(url),
409+
headers=self._headers(),
410+
)
411+
response.raise_for_status()
412+
return self._extract_task_id(response.json())
413+
414+
async def _poll_job_status_async(self, client: httpx.AsyncClient, task_id: str, wait: float) -> dict[str, Any]:
415+
response = await client.get(
416+
f"{self.base_url}{_STATUS_POLL_PATH}/{task_id}",
417+
params={"wait": wait},
418+
headers=self._headers(),
419+
)
420+
response.raise_for_status()
421+
return response.json()
422+
423+
async def _wait_for_job_async(self, client: httpx.AsyncClient, task_id: str) -> None:
424+
deadline = time.monotonic() + self.job_timeout
425+
while True:
426+
remaining = deadline - time.monotonic()
427+
if remaining <= 0:
428+
msg = f"Timed out waiting for DoclingServe task {task_id} after {self.job_timeout:.2f}s."
429+
raise DoclingServeConversionError(msg)
430+
431+
wait = min(self.poll_interval, remaining)
432+
poll_started = time.monotonic()
433+
status = await self._poll_job_status_async(client, task_id, wait)
434+
task_status = status.get("task_status")
435+
if task_status in _TERMINAL_TASK_STATUSES:
436+
self._raise_for_failed_task(status)
437+
return
438+
439+
sleep_for = min(self.poll_interval, remaining) - (time.monotonic() - poll_started)
440+
if sleep_for > 0:
441+
await asyncio.sleep(sleep_for)
442+
443+
async def _fetch_job_result_async(self, client: httpx.AsyncClient, task_id: str) -> dict[str, Any]:
444+
response = await client.get(
445+
f"{self.base_url}{_RESULT_PATH}/{task_id}",
446+
headers=self._headers(),
447+
)
448+
response.raise_for_status()
449+
data = response.json()
450+
self._raise_for_failed_conversion(data)
451+
return data
452+
453+
async def _post_file_job_async(self, client: httpx.AsyncClient, source: str | Path | ByteStream) -> dict[str, Any]:
454+
task_id = await self._submit_file_job_async(client, source)
455+
await self._wait_for_job_async(client, task_id)
456+
return await self._fetch_job_result_async(client, task_id)
457+
458+
async def _post_url_job_async(self, client: httpx.AsyncClient, url: str) -> dict[str, Any]:
459+
task_id = await self._submit_url_job_async(client, url)
460+
await self._wait_for_job_async(client, task_id)
461+
return await self._fetch_job_result_async(client, task_id)
462+
231463
@component.output_types(documents=list[Document])
232464
def run(
233465
self,
@@ -256,9 +488,17 @@ def run(
256488
merged_meta = {**bytestream_meta, **source_meta}
257489
try:
258490
if isinstance(source, str) and _is_url(source):
259-
result = self._post_url(client, source)
491+
result = (
492+
self._post_url_job(client, source)
493+
if self.mode == ConversionMode.ASYNC
494+
else self._post_url(client, source)
495+
)
260496
else:
261-
result = self._post_file(client, source)
497+
result = (
498+
self._post_file_job(client, source)
499+
if self.mode == ConversionMode.ASYNC
500+
else self._post_file(client, source)
501+
)
262502
content = self._extract_content(result)
263503
if content is not None:
264504
documents.append(Document(content=content, meta=merged_meta))
@@ -277,6 +517,12 @@ def run(
277517
source=source,
278518
error=e,
279519
)
520+
except DoclingServeConversionError as e:
521+
logger.warning(
522+
"DoclingServe conversion failed for {source}. Skipping it. Error: {error}",
523+
source=source,
524+
error=e,
525+
)
280526

281527
return {"documents": documents}
282528

@@ -310,9 +556,17 @@ async def run_async(
310556
merged_meta = {**bytestream_meta, **source_meta}
311557
try:
312558
if isinstance(source, str) and _is_url(source):
313-
result = await self._post_url_async(client, source)
559+
result = (
560+
await self._post_url_job_async(client, source)
561+
if self.mode == ConversionMode.ASYNC
562+
else await self._post_url_async(client, source)
563+
)
314564
else:
315-
result = await self._post_file_async(client, source)
565+
result = (
566+
await self._post_file_job_async(client, source)
567+
if self.mode == ConversionMode.ASYNC
568+
else await self._post_file_async(client, source)
569+
)
316570
content = self._extract_content(result)
317571
if content is not None:
318572
documents.append(Document(content=content, meta=merged_meta))
@@ -331,5 +585,11 @@ async def run_async(
331585
source=source,
332586
error=e,
333587
)
588+
except DoclingServeConversionError as e:
589+
logger.warning(
590+
"DoclingServe conversion failed for {source}. Skipping it. Error: {error}",
591+
source=source,
592+
error=e,
593+
)
334594

335595
return {"documents": documents}

0 commit comments

Comments
 (0)