Commit e416d44

chore: implement task heartbeat (#909)
* chore: renew Redis stream message ownership during worker execution

  Add a worker heartbeat so long-running tasks are not reclaimed by XAUTOCLAIM while they are still running. ReceivedMessage now carries both ack() and renew() callbacks, with renew implemented via XCLAIM to reset the pending idle timer for the current consumer. The worker starts this heartbeat for broker-delivered messages and stops it when execution finishes. Interactive execution is unchanged.

* test: run_local.sh has parallel workers
1 parent f0bc240 commit e416d44

9 files changed

Lines changed: 176 additions & 25 deletions
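
For context, the pattern the commit message describes boils down to re-claiming your own pending message so its idle clock never reaches the autoclaim threshold. The sketch below is illustrative only and is not code from this commit: it uses redis-py's asyncio client directly, and the function name, parameters, and interval are made up for the example.

    import asyncio

    from redis.asyncio import Redis


    async def keep_ownership(
        redis: Redis,
        stream: str,
        group: str,
        consumer: str,
        msg_id: str,
        interval_s: float,
        done: asyncio.Event,
    ) -> None:
        """Reset the pending-entry idle timer for msg_id until `done` is set."""
        while not done.is_set():
            try:
                await asyncio.wait_for(done.wait(), timeout=interval_s)
                return  # processing finished; stop renewing
            except asyncio.TimeoutError:
                pass
            # XCLAIM with min_idle_time=0 re-claims our own pending message and
            # resets its idle clock, so an XAUTOCLAIM loop elsewhere will not
            # treat it as stale and hand it to another consumer.
            await redis.xclaim(
                stream, group, consumer, min_idle_time=0, message_ids=[msg_id]
            )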

diracx-tasks/src/diracx/tasks/plumbing/broker/models.py

Lines changed: 1 addition & 0 deletions
@@ -109,6 +109,7 @@ class ReceivedMessage(BaseModel):

     data: bytes
     ack: Callable[[], Awaitable[None]]
+    renew: Callable[[], Awaitable[None]]


 def _prepare_arg(arg: Any) -> Any:

diracx-tasks/src/diracx/tasks/plumbing/broker/redis_streams.py

Lines changed: 31 additions & 2 deletions
@@ -138,6 +138,28 @@ async def _ack() -> None:

         return _ack

+    def _renew_generator(
+        self, msg_id: str | bytes, queue_name: str | bytes
+    ) -> Callable[[], Awaitable[None]]:
+        """Return a coroutine that resets the PEL idle timer for a message.
+
+        Calls XCLAIM with min-idle-time=0, which always succeeds and resets
+        the idle clock — preventing the autoclaim loop from reclaiming a
+        message that is still being actively processed.
+        """
+
+        async def _renew() -> None:
+            async with Redis(connection_pool=self.connection_pool) as redis:
+                await redis.xclaim(
+                    queue_name,
+                    self.consumer_group_name,
+                    self.consumer_name,
+                    min_idle_time=0,
+                    message_ids=[msg_id],
+                )
+
+        return _renew
+
     async def listen(self) -> AsyncGenerator[ReceivedMessage, None]:
         """Yield messages from streams in strict priority order.
@@ -163,6 +185,9 @@ async def listen(self) -> AsyncGenerator[ReceivedMessage, None]:
                 yield ReceivedMessage(
                     data=msg[b"data"],
                     ack=self._ack_generator(msg_id=msg_id, queue_name=stream),
+                    renew=self._renew_generator(
+                        msg_id=msg_id, queue_name=stream
+                    ),
                 )

             # Reclaim unacknowledged messages (throttled to idle_timeout interval)
@@ -187,10 +212,11 @@ async def listen(self) -> AsyncGenerator[ReceivedMessage, None]:
                 )

                 if pending[1]:
-                    logger.debug(
-                        "Reclaimed %d unacked messages from %s",
+                    logger.info(
+                        "Reclaimed %d unacked messages from %s (message ids: %s)",
                         len(pending[1]),
                         sname,
+                        [msg_id for msg_id, msg in pending[1]],
                     )

                 for msg_id, msg in pending[1]:
@@ -199,4 +225,7 @@ async def listen(self) -> AsyncGenerator[ReceivedMessage, None]:
                         ack=self._ack_generator(
                             msg_id=msg_id, queue_name=sname
                         ),
+                        renew=self._renew_generator(
+                            msg_id=msg_id, queue_name=sname
+                        ),
                     )

diracx-tasks/src/diracx/tasks/plumbing/worker/worker.py

Lines changed: 57 additions & 15 deletions
@@ -6,7 +6,7 @@
 import logging
 from datetime import UTC, datetime, timedelta
 from time import time
-from typing import Any, Callable
+from typing import Any, Awaitable, Callable

 import msgpack
 from opentelemetry import metrics, trace
@@ -47,6 +47,28 @@
 _LOCK_RETRY_DELAY_SECONDS = 5


+async def _message_heartbeat(
+    renew: Callable[[], Awaitable[None]],
+    stop_event: asyncio.Event,
+    interval: float,
+) -> None:
+    """Periodically reset the PEL idle timer while a message is being processed.
+
+    Calls ``renew()`` at ``interval`` seconds so the autoclaim loop never
+    considers the message stale as long as execution is in progress.
+    """
+    while not stop_event.is_set():
+        try:
+            await asyncio.wait_for(stop_event.wait(), timeout=interval)
+            return
+        except asyncio.TimeoutError:
+            pass
+        try:
+            await renew()
+        except Exception:
+            logger.warning("Failed to renew message ownership", exc_info=True)
+
+
 class Worker:
     """Execute tasks consumed from a broker.

@@ -261,22 +283,42 @@ async def process_message(self, message: bytes | ReceivedMessage) -> None:
             "Executing task %s (ID: %s)", task_message.task_name, task_message.task_id
         )

-        result = await self.run_task(task_func, task_message)
-
-        # Handle failure: retry or dead letter queue
-        if result.is_err:
-            await self._handle_failure(task_message, result)
-        else:
-            await self._handle_success(task_message, result)
+        # Start heartbeat to keep PEL ownership while the task runs.
+        # Renew at half the idle_timeout so there's always a safety margin.
+        heartbeat_stop = asyncio.Event()
+        heartbeat_interval = self.broker.idle_timeout / 1000 / 2
+        heartbeat_task: asyncio.Task[None] | None = None
+        if isinstance(message, ReceivedMessage):
+            heartbeat_task = asyncio.create_task(
+                _message_heartbeat(message.renew, heartbeat_stop, heartbeat_interval)
+            )

-        # Always persist the result to the backend
         try:
-            if self.broker.result_backend:
-                await self.broker.result_backend.set_result(
-                    task_message.task_id, result
-                )
-        except Exception:
-            logger.exception("Failed to save result")
+            result = await self.run_task(task_func, task_message)
+
+            # Handle failure: retry or dead letter queue
+            if result.is_err:
+                await self._handle_failure(task_message, result)
+            else:
+                await self._handle_success(task_message, result)
+
+            # Always persist the result to the backend
+            try:
+                if self.broker.result_backend:
+                    await self.broker.result_backend.set_result(
+                        task_message.task_id, result
+                    )
+            except Exception:
+                logger.exception("Failed to save result")
+
+        finally:
+            if heartbeat_task is not None:
+                heartbeat_stop.set()
+                heartbeat_task.cancel()
+                try:
+                    await heartbeat_task
+                except asyncio.CancelledError:
+                    pass

         if isinstance(message, ReceivedMessage):
             await message.ack()

diracx-tasks/tests/test_redis_streams.py

Lines changed: 29 additions & 0 deletions
@@ -2,8 +2,11 @@

 from __future__ import annotations

+from unittest.mock import AsyncMock, patch
+
 from diracx.tasks.plumbing.broker.redis_streams import (
     ALL_STREAM_NAMES,
+    RedisStreamBroker,
     stream_name_for,
 )
 from diracx.tasks.plumbing.enums import Priority, Size
@@ -25,3 +28,29 @@ def test_all_stream_names():
     assert "diracx:tasks:realtime:small" in ALL_STREAM_NAMES
     assert "diracx:tasks:normal:medium" in ALL_STREAM_NAMES
     assert "diracx:tasks:background:large" in ALL_STREAM_NAMES
+
+
+async def test_renew_generator_calls_xclaim() -> None:
+    broker = RedisStreamBroker("redis://example.invalid")
+
+    redis_cm = AsyncMock()
+    redis = AsyncMock()
+    redis_cm.__aenter__.return_value = redis
+    redis_cm.__aexit__.return_value = False
+
+    with patch(
+        "diracx.tasks.plumbing.broker.redis_streams.Redis", return_value=redis_cm
+    ):
+        renew = broker._renew_generator(
+            msg_id="1234-0",
+            queue_name="diracx:tasks:normal:medium",
+        )
+        await renew()
+
+    redis.xclaim.assert_awaited_once_with(
+        "diracx:tasks:normal:medium",
+        broker.consumer_group_name,
+        broker.consumer_name,
+        min_idle_time=0,
+        message_ids=["1234-0"],
+    )

diracx-tasks/tests/test_worker_integration.py

Lines changed: 36 additions & 1 deletion
@@ -2,10 +2,11 @@

 from __future__ import annotations

+import asyncio
 from unittest.mock import AsyncMock, patch

 from diracx.tasks.plumbing.base_task import BaseTask
-from diracx.tasks.plumbing.broker.models import TaskMessage, TaskResult
+from diracx.tasks.plumbing.broker.models import ReceivedMessage, TaskMessage, TaskResult
 from diracx.tasks.plumbing.depends import CallbackSpawner
 from diracx.tasks.plumbing.enums import Priority, Size
 from diracx.tasks.plumbing.factory import wrap_task
@@ -322,6 +323,40 @@ async def test_process_message_acks_on_parse_error(
     await worker.process_message(b"not valid msgpack at all!!")


+async def test_process_message_renews_ownership_while_running(
+    broker, task_class_registry, wrapped_registry
+):
+    """Worker should heartbeat message ownership and ack once done."""
+    worker = Worker(
+        broker=broker,
+        task_registry=wrapped_registry,
+        task_class_registry=task_class_registry,
+    )
+    worker.broker.idle_timeout = 20  # ms
+
+    task_msg = TaskMessage(
+        task_id="t-renew",
+        task_name="test:SuccessTask",
+        labels={"priority": "normal", "size": "small"},
+        task_args=[],
+        task_kwargs={},
+    )
+
+    ack = AsyncMock()
+    renew = AsyncMock()
+    message = ReceivedMessage(data=task_msg.dumpb(), ack=ack, renew=renew)
+
+    async def _slow_run_task(*_args, **_kwargs):
+        await asyncio.sleep(0.05)
+        return TaskResult.from_value("ok", execution_time=0.05)
+
+    with patch.object(worker, "run_task", side_effect=_slow_run_task):
+        await worker.process_message(message)
+
+    assert renew.await_count >= 1
+    ack.assert_awaited_once()
+
+
 # ---------------------------------------------------------------------------
 # TaskBroker dependency injection resolution
 # ---------------------------------------------------------------------------

docs/dev/explanations/tasks/class-details.md

Lines changed: 5 additions & 4 deletions
@@ -148,7 +148,7 @@ classDiagram
     class RedisStreamBroker {
         9 streams: 3 priorities × 3 sizes
         +enqueue(TaskMessage)
-        +consume(size) ReceivedMessage
+        +listen() AsyncGenerator~ReceivedMessage~
         +startup()
     }

@@ -175,6 +175,7 @@
     class ReceivedMessage {
         +data: bytes
         +ack() awaitable
+        +renew() awaitable
     }

     class TaskBinding {
@@ -185,8 +186,8 @@
     }

     class RedisResultBackend {
-        +store(task_id, TaskResult)
-        +get(task_id) TaskResult
+        +set_result(task_id, TaskResult)
+        +get_result(task_id) TaskResult
     }

     RedisStreamBroker ..> TaskMessage : enqueues
@@ -195,7 +196,7 @@
     TaskBinding --> RedisStreamBroker : references
 ```

-`TaskMessage` is the wire-protocol message serialized to msgpack. `ReceivedMessage` wraps the raw bytes with an `ack()` callback so the worker can acknowledge the message after processing. `TaskBinding` maps a task class to its broker, providing the `submit()` method used by `BaseTask.schedule()`.
+`TaskMessage` is the wire-protocol message serialized to msgpack. `ReceivedMessage` wraps the raw bytes with `ack()` and `renew()` callbacks: `ack()` acknowledges completion, while `renew()` refreshes ownership of in-flight messages during long executions. `TaskBinding` maps a task class to its broker, providing the `submit()` method used by `BaseTask.schedule()`.

 ## Callback subsystem
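
To make the `ack()`/`renew()` split described above concrete, here is a minimal consumer-side sketch. It is not taken from the repository: `broker` is assumed to be an already started `RedisStreamBroker`, and the renewal interval is illustrative.

    import asyncio
    from contextlib import suppress


    async def consume(broker) -> None:
        async for msg in broker.listen():  # yields ReceivedMessage objects
            stop = asyncio.Event()

            async def heartbeat() -> None:
                # renew() keeps ownership of the in-flight message alive
                while not stop.is_set():
                    with suppress(asyncio.TimeoutError):
                        await asyncio.wait_for(stop.wait(), timeout=5.0)
                    if not stop.is_set():
                        await msg.renew()

            hb = asyncio.create_task(heartbeat())
            try:
                ...  # process msg.data here
            finally:
                stop.set()
                hb.cancel()
                with suppress(asyncio.CancelledError):
                    await hb
            await msg.ack()  # acknowledge completion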

docs/dev/explanations/tasks/index.md

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ Occasionally tasks must be scheduled to run at some point in the future, most co

 ## What is a task?

-Tasks are async Python functions which have extremely low overhead, allowing for many tasks to be spawned for even cheap operations.
+Tasks are Python classes derived from `BaseTask`. Each task class implements an async `execute(...)` method, and task instances can be scheduled with `task.schedule(...)`.
 Tasks can be executed in four different ways:

 - **Standalone tasks:** The task performs its work and then returns. For example synchronising the IAM to DiracX configuration queries IAM and then updates the DiracX CS.
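
For readers new to the framework, a minimal sketch of what the rewritten sentence describes follows. Only the `BaseTask` base class, the async `execute(...)` method, and `schedule(...)` come from the text above; the example class, its argument, and the exact call shape are assumptions.

    from diracx.tasks.plumbing.base_task import BaseTask


    class EchoTask(BaseTask):  # hypothetical example task
        async def execute(self, text: str) -> str:
            return text


    # An instance would then be scheduled along the lines of:
    #     await EchoTask().schedule("hello")
    # (whether schedule() is awaited, and its exact arguments, are assumptions here)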

docs/dev/reference/tasks.md

Lines changed: 14 additions & 0 deletions
@@ -93,6 +93,20 @@ ConcurrencyLimiter(obj, key, *extra_keys, ttl_ms=30000)

 `PeriodicBaseTask` overrides this with a `MutexLock` keyed by the task class name, preventing concurrent execution. `PeriodicVoAwareBaseTask` adds the VO name to the lock key, so each VO gets its own mutex.

+### Lock watchdog vs message heartbeat
+
+Two independent watchdog-style mechanisms exist while a worker executes tasks:
+
+- **Lock watchdog** (`_lock_watchdog` in `factory.py`) periodically calls `extend()` on acquired locks, so lock TTLs do not expire mid-execution.
+- **Message heartbeat** (`_message_heartbeat` in `worker.py`) periodically renews stream-message ownership while a task is running.
+
+The message heartbeat uses Redis `XCLAIM` with `min_idle_time=0` to reset the pending-entry idle timer for the in-flight message. This prevents `XAUTOCLAIM` from reclaiming a long-running task simply because it exceeded the idle timeout.
+
+These mechanisms protect different things:
+
+- lock watchdog protects lock ownership
+- message heartbeat protects consumer-group message ownership
+
 ______________________________________________________________________

 ## Schedules
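
For contrast with the message heartbeat shown in worker.py above, a rough sketch of the lock-watchdog side is given below. It is an assumption-laden illustration, not the real `_lock_watchdog`: the prose only states that `extend()` is called periodically on acquired locks, so the lock interface and the interval here are invented.

    import asyncio


    async def lock_watchdog_sketch(lock, stop: asyncio.Event, interval: float = 10.0) -> None:
        # Periodically extend the lock TTL so it cannot expire mid-execution.
        # Assumes the lock object exposes an awaitable extend() as described above.
        while not stop.is_set():
            try:
                await asyncio.wait_for(stop.wait(), timeout=interval)
                return
            except asyncio.TimeoutError:
                await lock.extend()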

run_local.sh

Lines changed: 2 additions & 2 deletions
@@ -174,9 +174,9 @@ uvicorn --factory diracx.testing.routers:create_app --reload > "${tmp_dir}/logs/
 diracx_pid=$!
 diracx-task-run scheduler > "${tmp_dir}/logs/scheduler.log" 2>&1 &
 scheduler_pid=$!
-diracx-task-run worker --worker-size small --max-concurrent-tasks 1 > "${tmp_dir}/logs/worker-sm.log" 2>&1 &
+diracx-task-run worker --worker-size small --max-concurrent-tasks 3 > "${tmp_dir}/logs/worker-sm.log" 2>&1 &
 worker_small_pid=$!
-diracx-task-run worker --worker-size medium --max-concurrent-tasks 1 > "${tmp_dir}/logs/worker-md.log" 2>&1 &
+diracx-task-run worker --worker-size medium --max-concurrent-tasks 2 > "${tmp_dir}/logs/worker-md.log" 2>&1 &
 worker_medium_pid=$!
 diracx-task-run worker --worker-size large --max-concurrent-tasks 1 > "${tmp_dir}/logs/worker-lg.log" 2>&1 &
 worker_large_pid=$!
