Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import weakref
from collections import deque
from collections.abc import Callable, Generator
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager, suppress
from datetime import datetime, timezone
from http import HTTPStatus
Expand Down Expand Up @@ -541,6 +542,16 @@ class WatchedSubprocess:
start_time: float = attrs.field(factory=time.monotonic)
"""The start time of the child process."""

_request_thread_pool: ThreadPoolExecutor = attrs.field(
factory=lambda: ThreadPoolExecutor(max_workers=1, thread_name_prefix="supervisor-request"),
init=False,
repr=False,
)
"""Thread pool for offloading long-running API requests (e.g. large XCom uploads)."""

_pending_requests: deque[tuple[Future, int]] = attrs.field(factory=deque, init=False, repr=False)
"""Futures from offloaded requests, paired with their request IDs."""

@classmethod
def start(
cls,
Expand Down Expand Up @@ -791,6 +802,54 @@ def _cleanup_open_sockets(self):

self.selector.close()
self.stdin.close()
self._request_thread_pool.shutdown(wait=False)

def _drain_pending_requests(self):
    """
    Send responses for any offloaded requests that have completed.

    Walks ``self._pending_requests`` once. Futures that are still running are
    kept, in their original order, for a later call. For every finished future
    a reply carrying the original request ID is sent via ``self.send_msg``:
    an empty success reply when the call returned normally, or an
    ``ErrorResponse`` of type ``API_SERVER_ERROR`` when it raised.

    If ``send_msg`` itself raises, the exception propagates to the caller. The
    request whose reply could not be delivered is dropped (it is not retried),
    but every other entry -- both the unfinished futures already inspected and
    the entries not yet reached -- stays queued.
    """
    still_running: deque[tuple[Future, int]] = deque()
    try:
        while self._pending_requests:
            future, req_id = self._pending_requests.popleft()
            if not future.done():
                still_running.append((future, req_id))
                continue

            exc = future.exception()
            if exc is None:
                self.send_msg(msg=None, request_id=req_id)
                continue

            if isinstance(exc, ServerResponseError) and exc.response is not None:
                status_code = exc.response.status_code
                try:
                    error_details: dict | None = exc.response.json()
                except Exception:
                    # Best effort only: the error body may be absent or not valid JSON.
                    error_details = None
                log.error(
                    "API server error",
                    status_code=status_code,
                    detail=error_details,
                    message=str(exc),
                )
            else:
                # No HTTP response is attached, so there is no status or body to report.
                status_code = None
                error_details = None
                log.error("Offloaded request failed", exc_info=exc)

            self.send_msg(
                msg=None,
                error=ErrorResponse(
                    error=ErrorType.API_SERVER_ERROR,
                    detail={
                        "status_code": status_code,
                        "message": str(exc),
                        "detail": error_details,
                    },
                ),
                request_id=req_id,
            )
    finally:
        # Put unfinished futures back at the front, ahead of any entries that
        # were never reached because send_msg raised. Doing this in ``finally``
        # keeps them from being lost when the drain is interrupted.
        self._pending_requests.extendleft(reversed(still_running))

def kill(
self,
Expand Down Expand Up @@ -910,6 +969,8 @@ def _service_subprocess(
on_close(sock)
sock.close()

self._drain_pending_requests()

# Check if the subprocess has exited
return self._check_subprocess_exit(raise_on_timeout=raise_on_timeout, expect_signal=expect_signal)

Expand Down Expand Up @@ -1459,7 +1520,11 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
elif isinstance(msg, SkipDownstreamTasks):
self.client.task_instances.skip_downstream_tasks(self.id, msg)
elif isinstance(msg, SetXCom):
self.client.xcoms.set(
# Offload XCom upload to a thread so that large payloads do not block the
# supervisor event loop and prevent heartbeats from being sent.
# See: https://github.com/apache/airflow/issues/64628
future = self._request_thread_pool.submit(
self.client.xcoms.set,
msg.dag_id,
msg.run_id,
msg.task_id,
Expand All @@ -1469,6 +1534,8 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
dag_result=msg.dag_result,
mapped_length=msg.mapped_length,
)
self._pending_requests.append((future, req_id))
return
elif isinstance(msg, DeleteXCom):
self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
elif isinstance(msg, PutVariable):
Expand Down Expand Up @@ -1679,6 +1746,16 @@ def send(self, msg: BaseModel):
with set_supervisor_comms(None):
self.supervisor._handle_request(msg, log, 0) # type: ignore[arg-type]

# Some requests (e.g. SetXCom) are offloaded to the supervisor's thread pool to
# avoid blocking its event loop. In the in-process path there is no event loop
# calling _drain_pending_requests(), so we wait for any in-flight futures here
# and drain them so that the response is available before we try to pop it.
if self.supervisor._pending_requests:
from concurrent.futures import wait as futures_wait

futures_wait([f for f, _ in self.supervisor._pending_requests])
self.supervisor._drain_pending_requests()

return self._get_response()


Expand Down
10 changes: 10 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2679,6 +2679,16 @@ def test_handle_requests(
req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=message.model_dump())
generator.send(req_frame)

# SetXCom is offloaded to a thread to avoid blocking the event loop.
# We need to wait for the future and drain it before reading the response.
# Use concurrent.futures.wait() rather than shutdown() so the executor
# remains usable for subsequent SetXCom calls in the same test.
if isinstance(message, SetXCom):
from concurrent.futures import wait as futures_wait

futures_wait([f for f, _ in watched_subprocess._pending_requests])
watched_subprocess._drain_pending_requests()
Comment thread
skymensch marked this conversation as resolved.

if mask_secret_args is not None:
mock_mask_secret.assert_called_with(*mask_secret_args)

Expand Down
Loading