Skip to content

Commit 1acf8e1

Browse files
[cross-repo from server#137] Worker Heartbeats / Status surface: every SDK emits periodic heartbeat; CLI+UI list workers per task queue (Temporal-parity)
1 parent 3b2ac01 commit 1acf8e1

3 files changed

Lines changed: 251 additions & 2 deletions

File tree

src/durable_workflow/client.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,12 +609,17 @@ class WorkerDescription:
609609
last_heartbeat_at: str | None = None
610610
registered_at: str | None = None
611611
updated_at: str | None = None
612+
task_slots: dict[str, int | None] | None = None
613+
process_metrics: dict[str, Any] | None = None
614+
heartbeat_interval_seconds: int | None = None
612615
raw: dict[str, Any] | None = None
613616

614617
@classmethod
615618
def from_dict(cls, data: dict[str, Any], *, worker_id: str | None = None) -> WorkerDescription:
616619
workflow_types = data.get("supported_workflow_types")
617620
activity_types = data.get("supported_activity_types")
621+
task_slots = data.get("task_slots")
622+
process_metrics = data.get("process_metrics")
618623

619624
return cls(
620625
worker_id=data.get("worker_id", worker_id or ""),
@@ -631,6 +636,13 @@ def from_dict(cls, data: dict[str, Any], *, worker_id: str | None = None) -> Wor
631636
last_heartbeat_at=data.get("last_heartbeat_at"),
632637
registered_at=data.get("registered_at"),
633638
updated_at=data.get("updated_at"),
639+
task_slots=task_slots if isinstance(task_slots, dict) else None,
640+
process_metrics=process_metrics if isinstance(process_metrics, dict) else None,
641+
heartbeat_interval_seconds=(
642+
int(data["heartbeat_interval_seconds"])
643+
if isinstance(data.get("heartbeat_interval_seconds"), int)
644+
else None
645+
),
634646
raw=data,
635647
)
636648

@@ -641,6 +653,7 @@ class WorkerList:
641653

642654
namespace: str | None
643655
workers: list[WorkerDescription]
656+
stale_after_seconds: int | None = None
644657

645658

646659
@dataclass
@@ -1818,6 +1831,11 @@ async def list_workers(
18181831
for item in items
18191832
if isinstance(item, dict)
18201833
],
1834+
stale_after_seconds=(
1835+
int(data["stale_after_seconds"])
1836+
if isinstance(data.get("stale_after_seconds"), int)
1837+
else None
1838+
),
18211839
)
18221840

18231841
async def describe_worker(self, worker_id: str) -> WorkerDescription:
@@ -2613,6 +2631,50 @@ async def register_worker(
26132631
body["max_concurrent_activity_tasks"] = max_concurrent_activity_tasks
26142632
return await self._request("POST", "/worker/register", worker=True, json=body)
26152633

2634+
async def heartbeat_worker(
2635+
self,
2636+
*,
2637+
worker_id: str,
2638+
task_slots: dict[str, int] | None = None,
2639+
process_metrics: dict[str, Any] | None = None,
2640+
heartbeat_interval_seconds: int | None = None,
2641+
) -> Any:
2642+
"""Send a worker-fleet heartbeat to refresh liveness and report state.
2643+
2644+
Workers should call this on a steady cadence (default 60s, advertised
2645+
by the server in the register/heartbeat acknowledgement) so operators
2646+
can answer "what workers are polling task queue X right now, what's
2647+
their slot capacity, when did each last check in" via the worker
2648+
management API, the CLI worker listing, and the operator Worker
2649+
Status view.
2650+
2651+
``task_slots`` is an optional dict with any subset of
2652+
``workflow_available``, ``activity_available``, ``session_available``
2653+
— the count of currently free slots for each family. The server
2654+
clamps each value into ``[0, max_concurrent_*]``.
2655+
2656+
``process_metrics`` is an optional dict with any subset of
2657+
``cpu_percent``, ``memory_bytes``, ``process_uptime_seconds``,
2658+
``process_id``, and ``host`` — the SDK reports only what it has
2659+
cheap access to, and the server records exactly what was reported.
2660+
2661+
Returns the server acknowledgement, which includes the advertised
2662+
``heartbeat_interval_seconds`` and ``stale_after_seconds`` so the
2663+
worker can adapt its cadence on the fly.
2664+
2665+
Most applications create a :class:`~durable_workflow.Worker`, which
2666+
runs this on a background asyncio task — call this directly only when
2667+
driving the worker protocol by hand (smoke tests, custom runtimes).
2668+
"""
2669+
body: dict[str, Any] = {"worker_id": worker_id}
2670+
if task_slots:
2671+
body["task_slots"] = task_slots
2672+
if process_metrics:
2673+
body["process_metrics"] = process_metrics
2674+
if heartbeat_interval_seconds is not None:
2675+
body["heartbeat_interval_seconds"] = heartbeat_interval_seconds
2676+
return await self._request("POST", "/worker/heartbeat", worker=True, json=body)
2677+
26162678
async def poll_workflow_task(
26172679
self, *, worker_id: str, task_queue: str, timeout: float = 35.0
26182680
) -> Any:

src/durable_workflow/worker.py

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import hashlib
2121
import inspect
2222
import logging
23+
import sys
2324
import time
2425
import traceback
2526
import uuid
@@ -290,6 +291,7 @@ def __init__(
290291
max_concurrent_workflow_tasks: int = 10,
291292
max_concurrent_activity_tasks: int = 10,
292293
shutdown_timeout: float = 30.0,
294+
heartbeat_interval: float = 60.0,
293295
metrics: MetricsRecorder | None = None,
294296
interceptors: Iterable[WorkerInterceptor] = (),
295297
) -> None:
@@ -313,6 +315,8 @@ def __init__(
313315
raise ValueError("max_concurrent_workflow_tasks must be at least 1")
314316
if max_concurrent_activity_tasks < 1:
315317
raise ValueError("max_concurrent_activity_tasks must be at least 1")
318+
if heartbeat_interval <= 0:
319+
raise ValueError("heartbeat_interval must be positive")
316320

317321
self._poll_timeout = poll_timeout
318322
self.max_concurrent_workflow_tasks = max_concurrent_workflow_tasks
@@ -323,6 +327,13 @@ def __init__(
323327
self._shutdown_timeout = shutdown_timeout
324328
self._in_flight: set[asyncio.Task[Any]] = set()
325329
self._query_tasks_supported = False
330+
# In-flight slot accounting feeds the periodic heartbeat so operators
331+
# see free-slot counts without the worker having to re-derive them at
332+
# shutdown. Counters are bumped/decremented around dispatch.
333+
self._workflow_inflight = 0
334+
self._activity_inflight = 0
335+
self._heartbeat_interval = float(heartbeat_interval)
336+
self._process_started_at = time.time()
326337
configured_metrics = metrics if metrics is not None else getattr(client, "metrics", NOOP_METRICS)
327338
self.metrics: MetricsRecorder = configured_metrics or NOOP_METRICS
328339
self.interceptors = tuple(interceptors)
@@ -388,7 +399,7 @@ async def _register(self) -> None:
388399
_manifest_version(info.get("worker_protocol")),
389400
)
390401

391-
await self.client.register_worker(
402+
ack = await self.client.register_worker(
392403
worker_id=self.worker_id,
393404
task_queue=self.task_queue,
394405
supported_workflow_types=list(self.workflows),
@@ -398,6 +409,14 @@ async def _register(self) -> None:
398409
max_concurrent_activity_tasks=self.max_concurrent_activity_tasks,
399410
build_id=self.build_id,
400411
)
412+
# Adapt to the server-advertised cadence when present so a cluster
413+
# can pin the worker fleet's heartbeat beat without each worker
414+
# passing the cadence explicitly. Falls back to the constructor
415+
# value when the server has not advertised a cadence.
416+
if isinstance(ack, dict):
417+
advertised = ack.get("heartbeat_interval_seconds")
418+
if isinstance(advertised, int) and advertised > 0:
419+
self._heartbeat_interval = float(advertised)
401420
log.info("worker %s registered on %s", self.worker_id, self.task_queue)
402421

403422
async def _run_workflow_task(self, task: dict[str, Any]) -> list[dict[str, Any]] | None:
@@ -993,12 +1012,14 @@ async def _poll_workflow_tasks(self) -> None:
9931012
async def _dispatch_workflow_task(self, task: dict[str, Any]) -> None:
9941013
task_start = time.perf_counter()
9951014
outcome = "error"
1015+
self._workflow_inflight += 1
9961016
try:
9971017
commands = await self._run_workflow_task(task)
9981018
outcome = "completed" if commands is not None else "failed"
9991019
except Exception:
10001020
log.exception("unhandled error in workflow task execution")
10011021
finally:
1022+
self._workflow_inflight = max(0, self._workflow_inflight - 1)
10021023
self._record_task_metrics("workflow", outcome, time.perf_counter() - task_start)
10031024
self._wf_semaphore.release()
10041025

@@ -1032,11 +1053,13 @@ async def _poll_activity_tasks(self) -> None:
10321053
async def _dispatch_activity_task(self, task: dict[str, Any]) -> None:
10331054
task_start = time.perf_counter()
10341055
outcome = "error"
1056+
self._activity_inflight += 1
10351057
try:
10361058
outcome = await self._run_activity_task(task)
10371059
except Exception:
10381060
log.exception("unhandled error in activity task execution")
10391061
finally:
1062+
self._activity_inflight = max(0, self._activity_inflight - 1)
10401063
self._record_task_metrics("activity", outcome, time.perf_counter() - task_start)
10411064
self._act_semaphore.release()
10421065

@@ -1076,12 +1099,96 @@ async def run(self) -> None:
10761099
await self._register()
10771100
wf_loop = asyncio.create_task(self._poll_workflow_tasks())
10781101
act_loop = asyncio.create_task(self._poll_activity_tasks())
1079-
loops = [wf_loop, act_loop]
1102+
hb_loop = asyncio.create_task(self._heartbeat_loop())
1103+
loops = [wf_loop, act_loop, hb_loop]
10801104
if self._query_tasks_supported:
10811105
loops.append(asyncio.create_task(self._poll_query_tasks()))
10821106
with contextlib.suppress(asyncio.CancelledError):
10831107
await asyncio.gather(*loops)
10841108

1109+
async def _heartbeat_loop(self) -> None:
1110+
"""Periodically refresh the server-side worker registration.
1111+
1112+
Reports current task-slot availability and basic process-level
1113+
metrics so the worker management API, CLI worker listing, and
1114+
Waterline Worker Status view can show free-slot counts and
1115+
process health alongside ``last_heartbeat_at``. Cadence is the
1116+
server-advertised ``heartbeat_interval_seconds`` (default 60s,
1117+
bounded to [1s, 1h] cluster-wide) so workers stop being
1118+
considered for task dispatch when they miss enough heartbeats.
1119+
"""
1120+
while not self._stop.is_set():
1121+
try:
1122+
await asyncio.wait_for(self._stop.wait(), timeout=self._heartbeat_interval)
1123+
except asyncio.TimeoutError:
1124+
pass
1125+
if self._stop.is_set():
1126+
return
1127+
try:
1128+
ack = await self.client.heartbeat_worker(
1129+
worker_id=self.worker_id,
1130+
task_slots=self._current_task_slots(),
1131+
process_metrics=self._current_process_metrics(),
1132+
)
1133+
except Exception as e:
1134+
log.warning("worker heartbeat failed: %s", e)
1135+
continue
1136+
if isinstance(ack, dict):
1137+
advertised = ack.get("heartbeat_interval_seconds")
1138+
if isinstance(advertised, int) and advertised > 0:
1139+
self._heartbeat_interval = float(advertised)
1140+
1141+
def _current_task_slots(self) -> dict[str, int]:
1142+
return {
1143+
"workflow_available": max(
1144+
0, self.max_concurrent_workflow_tasks - self._workflow_inflight
1145+
),
1146+
"activity_available": max(
1147+
0, self.max_concurrent_activity_tasks - self._activity_inflight
1148+
),
1149+
}
1150+
1151+
def _current_process_metrics(self) -> dict[str, Any]:
1152+
import os
1153+
import socket
1154+
1155+
metrics: dict[str, Any] = {
1156+
"process_uptime_seconds": int(time.time() - self._process_started_at),
1157+
"process_id": os.getpid(),
1158+
}
1159+
1160+
try:
1161+
import resource
1162+
1163+
usage = resource.getrusage(resource.RUSAGE_SELF)
1164+
# ru_maxrss is kilobytes on Linux and bytes on macOS — normalize
1165+
# to bytes. The server stores whatever is sent so the units stay
1166+
# consistent across SDKs.
1167+
if sys.platform == "darwin":
1168+
metrics["memory_bytes"] = int(usage.ru_maxrss)
1169+
else:
1170+
metrics["memory_bytes"] = int(usage.ru_maxrss) * 1024
1171+
1172+
cpu_seconds = float(usage.ru_utime) + float(usage.ru_stime)
1173+
wall_seconds = max(0.001, time.time() - self._process_started_at)
1174+
metrics["cpu_percent"] = max(
1175+
0.0, min(100.0, round((cpu_seconds / wall_seconds) * 100.0, 2))
1176+
)
1177+
except (ImportError, OSError):
1178+
# `resource` is POSIX-only — Windows skips getrusage but still
1179+
# reports pid + uptime + host so the operator surface remains
1180+
# populated.
1181+
pass
1182+
1183+
try:
1184+
host = socket.gethostname()
1185+
except Exception:
1186+
host = ""
1187+
if isinstance(host, str) and host != "":
1188+
metrics["host"] = host[:255]
1189+
1190+
return metrics
1191+
10851192
async def run_until(
10861193
self,
10871194
*,

tests/test_worker.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ async def echo_async_activity(val: str) -> str:
9090
def mock_client() -> AsyncMock:
9191
client = AsyncMock(spec=Client)
9292
client.register_worker = AsyncMock(return_value={"worker_id": "w1", "registered": True})
93+
client.heartbeat_worker = AsyncMock(
94+
return_value={"worker_id": "w1", "acknowledged": True, "heartbeat_interval_seconds": 60}
95+
)
9396
client.poll_workflow_task = AsyncMock(return_value=None)
9497
client.poll_activity_task = AsyncMock(return_value=None)
9598
client.poll_query_task = AsyncMock(return_value=None)
@@ -1357,6 +1360,83 @@ async def test_run_skips_query_loop_without_query_task_capability(self, mock_cli
13571360
mock_client.poll_query_task.assert_not_called()
13581361

13591362

1363+
class TestWorkerHeartbeats:
1364+
@pytest.mark.asyncio
1365+
async def test_run_drives_periodic_heartbeats_with_slot_state(
1366+
self, mock_client: AsyncMock
1367+
) -> None:
1368+
worker = Worker(
1369+
mock_client,
1370+
task_queue="q1",
1371+
workflows=[TestWorkflow],
1372+
activities=[echo_activity],
1373+
max_concurrent_workflow_tasks=4,
1374+
max_concurrent_activity_tasks=2,
1375+
poll_timeout=0.01,
1376+
heartbeat_interval=0.05,
1377+
)
1378+
run_task = asyncio.create_task(worker.run())
1379+
await asyncio.sleep(0.2)
1380+
await worker.stop()
1381+
run_task.cancel()
1382+
with contextlib.suppress(asyncio.CancelledError):
1383+
await run_task
1384+
1385+
assert mock_client.heartbeat_worker.call_count >= 1
1386+
kwargs = mock_client.heartbeat_worker.call_args.kwargs
1387+
assert kwargs["worker_id"] == worker.worker_id
1388+
assert kwargs["task_slots"]["workflow_available"] == 4
1389+
assert kwargs["task_slots"]["activity_available"] == 2
1390+
process_metrics = kwargs["process_metrics"]
1391+
assert "process_id" in process_metrics
1392+
assert process_metrics["process_id"] > 0
1393+
assert "process_uptime_seconds" in process_metrics
1394+
1395+
@pytest.mark.asyncio
1396+
async def test_register_adopts_server_advertised_heartbeat_cadence(
1397+
self, mock_client: AsyncMock
1398+
) -> None:
1399+
mock_client.register_worker = AsyncMock(
1400+
return_value={
1401+
"worker_id": "w1",
1402+
"registered": True,
1403+
"heartbeat_interval_seconds": 7,
1404+
}
1405+
)
1406+
worker = Worker(
1407+
mock_client,
1408+
task_queue="q1",
1409+
workflows=[TestWorkflow],
1410+
activities=[echo_activity],
1411+
heartbeat_interval=120.0,
1412+
)
1413+
await worker._register()
1414+
assert worker._heartbeat_interval == 7.0
1415+
1416+
@pytest.mark.asyncio
1417+
async def test_heartbeat_loop_survives_transient_errors(
1418+
self, mock_client: AsyncMock
1419+
) -> None:
1420+
mock_client.heartbeat_worker = AsyncMock(
1421+
side_effect=[RuntimeError("temporary"), {"acknowledged": True}]
1422+
)
1423+
worker = Worker(
1424+
mock_client,
1425+
task_queue="q1",
1426+
workflows=[TestWorkflow],
1427+
activities=[echo_activity],
1428+
poll_timeout=0.01,
1429+
heartbeat_interval=0.02,
1430+
)
1431+
run_task = asyncio.create_task(worker.run())
1432+
await asyncio.sleep(0.15)
1433+
await worker.stop()
1434+
run_task.cancel()
1435+
with contextlib.suppress(asyncio.CancelledError):
1436+
await run_task
1437+
assert mock_client.heartbeat_worker.call_count >= 2
1438+
1439+
13601440
class TestRunUntil:
13611441
@pytest.mark.asyncio
13621442
async def test_run_until_returns_terminal_description(self, mock_client: AsyncMock) -> None:

0 commit comments

Comments
 (0)