Skip to content

Commit c284ba4

Browse files
[cross-repo from server#265] Conformance: signals/queries still time out on server 0.2.125 (#79)
1 parent 6a3ebff commit c284ba4

3 files changed

Lines changed: 175 additions & 15 deletions

File tree

src/durable_workflow/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,6 +1171,7 @@ def __init__(
11711171
self.control_token = control_token
11721172
self.worker_token = worker_token
11731173
self.namespace = namespace
1174+
self.timeout = timeout
11741175
self.retry_policy = retry_policy or TransportRetryPolicy()
11751176
self.metrics = metrics or NOOP_METRICS
11761177
self.payload_size_warning_config = (

src/durable_workflow/worker.py

Lines changed: 109 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import inspect
2222
import logging
2323
import sys
24+
import threading
2425
import time
2526
import traceback
2627
import uuid
@@ -347,6 +348,8 @@ def __init__(
347348
self._shutdown_timeout = shutdown_timeout
348349
self._in_flight: set[asyncio.Task[Any]] = set()
349350
self._query_tasks_supported = False
351+
self._query_thread_stop = threading.Event()
352+
self._query_thread: threading.Thread | None = None
350353
# In-flight slot accounting feeds the periodic heartbeat so operators
351354
# see free-slot counts without the worker having to re-derive them at
352355
# shutdown. Counters are bumped/decremented around dispatch.
@@ -897,15 +900,15 @@ async def call_interceptor(
897900

898901
return await handler(context)
899902

900-
async def _run_query_task(self, task: dict[str, Any]) -> str:
903+
async def _run_query_task(self, task: dict[str, Any], *, client: Client | None = None) -> str:
901904
context = QueryTaskInterceptorContext(
902905
worker_id=self.worker_id,
903906
task_queue=self.task_queue,
904907
task=task,
905908
)
906909

907910
async def call_core(ctx: QueryTaskInterceptorContext) -> str:
908-
return await self._run_query_task_core(ctx.task)
911+
return await self._run_query_task_core(ctx.task, client=client)
909912

910913
handler = call_core
911914
for interceptor in reversed(self.interceptors):
@@ -923,7 +926,8 @@ async def call_interceptor(
923926

924927
return await handler(context)
925928

926-
async def _run_query_task_core(self, task: dict[str, Any]) -> str:
929+
async def _run_query_task_core(self, task: dict[str, Any], *, client: Client | None = None) -> str:
930+
client = client or self.client
927931
query_task_id: str = task["query_task_id"]
928932
attempt: int = task.get("query_task_attempt", 1)
929933
wf_type: str = task.get("workflow_type", "")
@@ -939,6 +943,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
939943
reason="query_payload_decode_failed",
940944
failure_type=type(e).__name__,
941945
stack_trace=traceback.format_exc(),
946+
client=client,
942947
)
943948
return "failed"
944949

@@ -953,6 +958,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
953958
f"no workflow registered for type {wf_type!r}",
954959
reason="query_workflow_type_not_registered",
955960
failure_type="WorkflowTypeNotRegistered",
961+
client=client,
956962
)
957963
return "failed"
958964

@@ -992,6 +998,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
992998
reason="query_payload_decode_failed",
993999
failure_type=type(e).__name__,
9941000
stack_trace=traceback.format_exc(),
1001+
client=client,
9951002
)
9961003
return "failed"
9971004
except QueryFailed as e:
@@ -1003,6 +1010,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
10031010
reason=reason,
10041011
failure_type=type(e).__name__,
10051012
stack_trace=traceback.format_exc(),
1013+
client=client,
10061014
)
10071015
return "failed"
10081016
except Exception as e:
@@ -1013,11 +1021,12 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
10131021
reason="query_rejected",
10141022
failure_type=type(e).__name__,
10151023
stack_trace=traceback.format_exc(),
1024+
client=client,
10161025
)
10171026
return "failed"
10181027

10191028
try:
1020-
await self.client.complete_query_task(
1029+
await client.complete_query_task(
10211030
query_task_id=query_task_id,
10221031
lease_owner=self.worker_id,
10231032
query_task_attempt=attempt,
@@ -1044,6 +1053,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
10441053
reason=server_reason if server_reason else "query_result_completion_failed",
10451054
failure_type=type(e).__name__,
10461055
stack_trace=traceback.format_exc(),
1056+
client=client,
10471057
)
10481058
return "failed"
10491059
except AvroNotInstalledError as e:
@@ -1057,6 +1067,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
10571067
reason="query_result_encode_failed",
10581068
failure_type=type(e).__name__,
10591069
stack_trace=traceback.format_exc(),
1070+
client=client,
10601071
)
10611072
return "failed"
10621073
except (TypeError, ValueError) as e:
@@ -1067,6 +1078,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
10671078
reason="query_result_encode_failed",
10681079
failure_type=type(e).__name__,
10691080
stack_trace=traceback.format_exc(),
1081+
client=client,
10701082
)
10711083
return "failed"
10721084
except Exception as e:
@@ -1077,6 +1089,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
10771089
reason="query_result_completion_failed",
10781090
failure_type=type(e).__name__,
10791091
stack_trace=traceback.format_exc(),
1092+
client=client,
10801093
)
10811094
return "failed"
10821095

@@ -1107,9 +1120,11 @@ async def _fail_query_task(
11071120
reason: str,
11081121
failure_type: str | None = None,
11091122
stack_trace: str | None = None,
1123+
client: Client | None = None,
11101124
) -> None:
1125+
client = client or self.client
11111126
try:
1112-
await self.client.fail_query_task(
1127+
await client.fail_query_task(
11131128
query_task_id=query_task_id,
11141129
lease_owner=self.worker_id,
11151130
query_task_attempt=attempt,
@@ -1217,11 +1232,12 @@ async def _dispatch_activity_task(self, task: dict[str, Any]) -> None:
12171232
self._record_task_metrics("activity", outcome, time.perf_counter() - task_start)
12181233
self._act_semaphore.release()
12191234

1220-
async def _poll_query_tasks(self) -> None:
1221-
while not self._stop.is_set():
1235+
async def _poll_query_tasks(self, *, client: Client | None = None, track_tasks: bool = True) -> None:
1236+
client = client or self.client
1237+
while not self._stop.is_set() and not self._query_thread_stop.is_set():
12221238
try:
12231239
poll_start = time.perf_counter()
1224-
task = await self.client.poll_query_task(
1240+
task = await client.poll_query_task(
12251241
worker_id=self.worker_id,
12261242
task_queue=self.task_queue,
12271243
timeout=self._poll_timeout,
@@ -1236,18 +1252,86 @@ async def _poll_query_tasks(self) -> None:
12361252
await asyncio.sleep(0)
12371253
continue
12381254
self._record_poll_metrics("query", "task", time.perf_counter() - poll_start)
1239-
self._track(self._dispatch_query_task(task))
1255+
if track_tasks:
1256+
self._track(self._dispatch_query_task(task, client=client))
1257+
else:
1258+
await self._dispatch_query_task(task, client=client)
12401259

1241-
async def _dispatch_query_task(self, task: dict[str, Any]) -> None:
1260+
async def _dispatch_query_task(self, task: dict[str, Any], *, client: Client | None = None) -> None:
12421261
task_start = time.perf_counter()
12431262
outcome = "error"
12441263
try:
1245-
outcome = await self._run_query_task(task)
1264+
outcome = await self._run_query_task(task, client=client)
12461265
except Exception:
12471266
log.exception("unhandled error in query task execution")
12481267
finally:
12491268
self._record_task_metrics("query", outcome, time.perf_counter() - task_start)
12501269

1270+
def _clone_client_for_query_tasks(self) -> Client:
1271+
warning_config = self._payload_size_warning_config()
1272+
1273+
return Client(
1274+
self.client.base_url,
1275+
token=self.client.token,
1276+
control_token=self.client.control_token,
1277+
worker_token=self.client.worker_token,
1278+
namespace=self.client.namespace,
1279+
timeout=getattr(self.client, "timeout", 60.0),
1280+
retry_policy=self.client.retry_policy,
1281+
metrics=self.metrics,
1282+
payload_size_limit_bytes=(
1283+
warning_config.limit_bytes
1284+
if warning_config is not None
1285+
else serializer.DEFAULT_PAYLOAD_SIZE_BYTES
1286+
),
1287+
payload_size_warning_threshold_percent=(
1288+
warning_config.threshold_percent
1289+
if warning_config is not None
1290+
else serializer.DEFAULT_WARNING_THRESHOLD_PERCENT
1291+
),
1292+
payload_size_warnings=warning_config is not None,
1293+
external_storage=self.external_storage,
1294+
external_storage_threshold_bytes=self.external_storage_threshold_bytes,
1295+
external_storage_cache=self.external_storage_cache,
1296+
)
1297+
1298+
def _can_clone_client_for_query_tasks(self) -> bool:
1299+
return type(self.client) is Client
1300+
1301+
def _start_query_task_thread(self) -> None:
1302+
if self._query_thread is not None and self._query_thread.is_alive():
1303+
return
1304+
1305+
self._query_thread_stop.clear()
1306+
self._query_thread = threading.Thread(
1307+
target=self._run_query_task_thread,
1308+
name=f"durable-workflow-query-poller-{self.worker_id}",
1309+
daemon=True,
1310+
)
1311+
self._query_thread.start()
1312+
1313+
def _run_query_task_thread(self) -> None:
1314+
try:
1315+
asyncio.run(self._query_task_thread_main())
1316+
except Exception:
1317+
log.exception("query task poller thread stopped unexpectedly")
1318+
1319+
async def _query_task_thread_main(self) -> None:
1320+
async with self._clone_client_for_query_tasks() as client:
1321+
await self._poll_query_tasks(client=client, track_tasks=False)
1322+
1323+
async def _stop_query_task_thread(self) -> None:
1324+
self._query_thread_stop.set()
1325+
thread = self._query_thread
1326+
1327+
if thread is None or not thread.is_alive():
1328+
return
1329+
1330+
await asyncio.to_thread(
1331+
thread.join,
1332+
min(self._shutdown_timeout, self._poll_timeout + 1),
1333+
)
1334+
12511335
async def run(self) -> None:
12521336
"""Register the worker and poll until `stop()` is called or the task is cancelled."""
12531337
await self._register()
@@ -1256,9 +1340,15 @@ async def run(self) -> None:
12561340
hb_loop = asyncio.create_task(self._heartbeat_loop())
12571341
loops = [wf_loop, act_loop, hb_loop]
12581342
if self._query_tasks_supported:
1259-
loops.append(asyncio.create_task(self._poll_query_tasks()))
1260-
with contextlib.suppress(asyncio.CancelledError):
1261-
await asyncio.gather(*loops)
1343+
if self._can_clone_client_for_query_tasks():
1344+
self._start_query_task_thread()
1345+
else:
1346+
loops.append(asyncio.create_task(self._poll_query_tasks()))
1347+
try:
1348+
with contextlib.suppress(asyncio.CancelledError):
1349+
await asyncio.gather(*loops)
1350+
finally:
1351+
await self._stop_query_task_thread()
12621352

12631353
async def _heartbeat_loop(self) -> None:
12641354
"""Periodically refresh the server-side worker registration.
@@ -1386,7 +1476,10 @@ async def run_until(
13861476
await self._register()
13871477
background_tasks.append(asyncio.create_task(self._heartbeat_loop()))
13881478
if self._query_tasks_supported:
1389-
background_tasks.append(asyncio.create_task(self._poll_query_tasks()))
1479+
if self._can_clone_client_for_query_tasks():
1480+
self._start_query_task_thread()
1481+
else:
1482+
background_tasks.append(asyncio.create_task(self._poll_query_tasks()))
13901483

13911484
while True:
13921485
desc = await self.client.describe_workflow(workflow_id)
@@ -1465,6 +1558,7 @@ async def run_until(
14651558
async def stop(self) -> None:
14661559
"""Stop polling and drain in-flight tasks up to the configured shutdown timeout."""
14671560
self._stop.set()
1561+
await self._stop_query_task_thread()
14681562
if self._in_flight:
14691563
log.info("draining %d in-flight task(s)…", len(self._in_flight))
14701564
done, pending = await asyncio.wait(

tests/test_worker.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import contextlib
55
import logging
66
import sys
7+
import threading
78
from unittest.mock import AsyncMock
89

910
import pytest
@@ -1476,6 +1477,70 @@ async def test_run_skips_query_loop_without_query_task_capability(self, mock_cli
14761477
assert mock_client.poll_activity_task.call_count >= 1
14771478
mock_client.poll_query_task.assert_not_called()
14781479

1480+
@pytest.mark.asyncio
1481+
async def test_query_thread_processes_tasks_while_event_loop_is_blocked(
1482+
self, mock_client: AsyncMock
1483+
) -> None:
1484+
completed = threading.Event()
1485+
query_task = {
1486+
"query_task_id": "qt-thread",
1487+
"query_task_attempt": 1,
1488+
"workflow_type": "query-wf",
1489+
"workflow_id": "wf-1",
1490+
"run_id": "run-1",
1491+
"query_name": "status",
1492+
"payload_codec": "json",
1493+
"workflow_arguments": serializer.envelope([], codec="json"),
1494+
"query_arguments": serializer.envelope([], codec="json"),
1495+
"history_events": [],
1496+
}
1497+
1498+
class QueryThreadClient:
1499+
def __init__(self) -> None:
1500+
self.polled = False
1501+
self.completed_kwargs: dict[str, object] | None = None
1502+
1503+
async def __aenter__(self) -> "QueryThreadClient":
1504+
return self
1505+
1506+
async def __aexit__(self, *_: object) -> None:
1507+
return None
1508+
1509+
async def poll_query_task(self, **_: object) -> dict[str, object] | None:
1510+
if not self.polled:
1511+
self.polled = True
1512+
return query_task
1513+
await asyncio.sleep(0.01)
1514+
return None
1515+
1516+
async def complete_query_task(self, **kwargs: object) -> dict[str, str]:
1517+
self.completed_kwargs = kwargs
1518+
completed.set()
1519+
return {"outcome": "completed"}
1520+
1521+
async def fail_query_task(self, **_: object) -> dict[str, str]:
1522+
raise AssertionError("query task should complete")
1523+
1524+
query_client = QueryThreadClient()
1525+
worker = Worker(
1526+
mock_client,
1527+
task_queue="q1",
1528+
workflows=[QueryWorkflow],
1529+
activities=[],
1530+
poll_timeout=0.01,
1531+
shutdown_timeout=0.2,
1532+
)
1533+
worker._clone_client_for_query_tasks = lambda: query_client # type: ignore[method-assign]
1534+
1535+
worker._start_query_task_thread()
1536+
1537+
assert completed.wait(timeout=1.0)
1538+
await worker.stop()
1539+
1540+
assert query_client.completed_kwargs is not None
1541+
assert query_client.completed_kwargs["query_task_id"] == "qt-thread"
1542+
assert query_client.completed_kwargs["result"] == {"status": "ready"}
1543+
14791544

14801545
class TestWorkerHeartbeats:
14811546
@pytest.mark.asyncio

0 commit comments

Comments
 (0)