Skip to content

Commit 16ab001

Browse files
[cross-repo from server#435] Conformance finding: namespaces coverage remains below full parity (#156)
1 parent fbc1f4d commit 16ab001

2 files changed

Lines changed: 68 additions & 6 deletions

File tree

src/durable_workflow/worker.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,30 @@ def _server_supports_query_tasks(info: dict[str, Any]) -> bool:
580580
return isinstance(capabilities, dict) and capabilities.get("query_tasks") is True
581581

582582

583+
def _server_long_poll_timeout(info: dict[str, Any]) -> float | None:
584+
worker_protocol = info.get("worker_protocol")
585+
if not isinstance(worker_protocol, dict):
586+
return None
587+
588+
capabilities = worker_protocol.get("server_capabilities")
589+
if not isinstance(capabilities, dict):
590+
return None
591+
592+
timeout = capabilities.get("long_poll_timeout")
593+
if isinstance(timeout, bool):
594+
return None
595+
if isinstance(timeout, (int, float)):
596+
return float(timeout) if timeout > 0 else None
597+
if isinstance(timeout, str):
598+
try:
599+
parsed = float(timeout)
600+
except ValueError:
601+
return None
602+
return parsed if parsed > 0 else None
603+
604+
return None
605+
606+
583607
def _contract_version_matches(value: Any, expected: int) -> bool:
584608
if isinstance(value, int):
585609
return value == expected
@@ -635,6 +659,7 @@ def __init__(
635659
raise ValueError("heartbeat_interval must be positive")
636660

637661
self._poll_timeout = poll_timeout
662+
self._poll_http_timeout = poll_timeout
638663
self.max_concurrent_workflow_tasks = max_concurrent_workflow_tasks
639664
self.max_concurrent_activity_tasks = max_concurrent_activity_tasks
640665
self._stop = asyncio.Event()
@@ -757,6 +782,9 @@ async def _register(self) -> None:
757782

758783
_validate_server_compatibility(info)
759784
self._query_tasks_supported = _server_supports_query_tasks(info)
785+
server_long_poll_timeout = _server_long_poll_timeout(info)
786+
if server_long_poll_timeout is not None:
787+
self._poll_http_timeout = max(self._poll_http_timeout, server_long_poll_timeout + 5.0)
760788
log.debug(
761789
"server compatibility accepted: app_version=%s control_plane=%s worker_protocol=%s",
762790
info.get("version", "unknown"),
@@ -1536,7 +1564,7 @@ async def _poll_workflow_tasks(self) -> None:
15361564
task = await self.client.poll_workflow_task(
15371565
worker_id=self.worker_id,
15381566
task_queue=self.task_queue,
1539-
timeout=self._poll_timeout,
1567+
timeout=self._poll_http_timeout,
15401568
)
15411569
except Exception as e:
15421570
self._wf_semaphore.release()
@@ -1613,7 +1641,7 @@ async def _poll_activity_tasks(self) -> None:
16131641
task = await self.client.poll_activity_task(
16141642
worker_id=self.worker_id,
16151643
task_queue=self.task_queue,
1616-
timeout=self._poll_timeout,
1644+
timeout=self._poll_http_timeout,
16171645
)
16181646
except Exception as e:
16191647
self._act_semaphore.release()
@@ -1650,7 +1678,7 @@ async def _poll_query_tasks(self, *, client: Client | None = None, track_tasks:
16501678
task = await client.poll_query_task(
16511679
worker_id=self.worker_id,
16521680
task_queue=self.task_queue,
1653-
timeout=self._poll_timeout,
1681+
timeout=self._poll_http_timeout,
16541682
)
16551683
except Exception as e:
16561684
self._record_poll_metrics("query", "error", time.perf_counter() - poll_start)
@@ -1739,7 +1767,7 @@ async def _stop_query_task_thread(self) -> None:
17391767

17401768
await asyncio.to_thread(
17411769
thread.join,
1742-
min(self._shutdown_timeout, self._poll_timeout + 1),
1770+
min(self._shutdown_timeout, self._poll_http_timeout + 1),
17431771
)
17441772

17451773
async def run(self) -> None:
@@ -1908,7 +1936,7 @@ async def run_until(
19081936
task = await self.client.poll_workflow_task(
19091937
worker_id=self.worker_id,
19101938
task_queue=self.task_queue,
1911-
timeout=self._poll_timeout,
1939+
timeout=self._poll_http_timeout,
19121940
)
19131941
self._record_poll_metrics(
19141942
"workflow",
@@ -1938,7 +1966,7 @@ async def run_until(
19381966
task = await self.client.poll_activity_task(
19391967
worker_id=self.worker_id,
19401968
task_queue=self.task_queue,
1941-
timeout=self._poll_timeout,
1969+
timeout=self._poll_http_timeout,
19421970
)
19431971
self._record_poll_metrics(
19441972
"activity",

tests/test_worker.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def compatible_cluster_info(**overrides: object) -> dict[str, object]:
257257
"version": PROTOCOL_VERSION,
258258
"server_capabilities": {
259259
"query_tasks": True,
260+
"long_poll_timeout": 30,
260261
},
261262
},
262263
}
@@ -314,6 +315,39 @@ async def test_register(self, mock_client: AsyncMock) -> None:
314315
assert process_metrics["process_id"] > 0
315316
assert "process_started_at" in process_metrics
316317

318+
@pytest.mark.asyncio
319+
async def test_register_keeps_http_timeout_above_server_long_poll(self, mock_client: AsyncMock) -> None:
320+
mock_client.get_cluster_info = AsyncMock(
321+
return_value=compatible_cluster_info(
322+
worker_protocol={
323+
"version": PROTOCOL_VERSION,
324+
"server_capabilities": {
325+
"query_tasks": True,
326+
"long_poll_timeout": 12,
327+
},
328+
}
329+
)
330+
)
331+
worker = Worker(
332+
mock_client,
333+
task_queue="q1",
334+
workflows=[TestWorkflow],
335+
activities=[echo_activity],
336+
worker_id="w-short-poll",
337+
poll_timeout=0.01,
338+
)
339+
340+
async def poll_once(**_: object) -> None:
341+
worker._stop.set()
342+
return None
343+
344+
mock_client.poll_workflow_task.side_effect = poll_once
345+
346+
await worker._register()
347+
await worker._poll_workflow_tasks()
348+
349+
assert mock_client.poll_workflow_task.call_args.kwargs["timeout"] == 17.0
350+
317351
@pytest.mark.asyncio
318352
async def test_register_omits_query_task_capability_when_server_does_not_support_it(
319353
self, mock_client: AsyncMock

0 commit comments

Comments
 (0)