Skip to content

Commit 2c48668

Browse files
feat: add ReceiverObserver protocol and Prometheus receiver metrics
Add production observability for the Receiver via an observer protocol that tracks prefetch queue depth, semaphore availability, active task count, unknown task lookups, and deserialization errors.
- Add ReceiverObserver protocol (taskiq/receiver/observer.py)
- Instrument Receiver with guarded observer callbacks at 5 sites
- Add PrometheusReceiverObserver implementation with Gauges/Counters
- Wire observer from middleware to receiver via broker attribute
- Remove redundant in_flight_tasks gauge (replaced by active_tasks_count)
1 parent 0343759 commit 2c48668

File tree

5 files changed

+125
-11
lines changed

5 files changed

+125
-11
lines changed

taskiq/cli/worker/run.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from taskiq.cli.utils import import_object, import_tasks
1414
from taskiq.cli.worker.args import WorkerArgs
1515
from taskiq.cli.worker.process_manager import ProcessManager
16-
from taskiq.receiver import Receiver
16+
from taskiq.receiver import Receiver, ReceiverObserver
1717

1818
try:
1919
import uvloop
@@ -163,6 +163,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
163163
receiver = receiver_type(
164164
broker=broker,
165165
executor=pool,
166+
observer=getattr(broker, "_receiver_observer", None),
166167
validate_params=not args.no_parse,
167168
max_async_tasks=args.max_async_tasks,
168169
max_prefetch=args.max_prefetch,

taskiq/middlewares/prometheus_middleware.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from taskiq.abc.middleware import TaskiqMiddleware
88
from taskiq.message import TaskiqMessage
99
from taskiq.result import TaskiqResult
10+
from taskiq.receiver.observer import ReceiverObserver
1011

1112
logger = getLogger("taskiq.prometheus")
1213

@@ -75,12 +76,6 @@ def __init__(
7576
["task_name"],
7677
)
7778

78-
self.in_flight_tasks = Gauge(
79-
"in_flight_tasks",
80-
"Number of tasks in flight",
81-
["task_name"],
82-
multiprocess_mode="livesum",
83-
)
8479
self.queue_wait_seconds = Histogram(
8580
"queue_wait_seconds",
8681
"time task spent in message queue",
@@ -166,7 +161,6 @@ def pre_execute(
166161
time_delta,
167162
)
168163

169-
self.in_flight_tasks.labels(message.task_name).inc()
170164
self.received_tasks.labels(message.task_name).inc()
171165
return message
172166

@@ -203,9 +197,12 @@ def post_execute(
203197
self.found_errors.labels(message.task_name).inc()
204198
else:
205199
self.success_tasks.labels(message.task_name).inc()
206-
self.in_flight_tasks.labels(message.task_name).dec()
207200
self.execution_time.labels(message.task_name).observe(result.execution_time)
208201

202+
def set_broker(self, broker: "AsyncBroker") -> None:  # noqa: F821 pyright: ignore[reportUnknownVariableType]
    """
    Attach the middleware to a broker and expose a receiver observer on it.

    The observer is stored on the broker as ``_receiver_observer`` so that
    whoever builds the receiver can pick it up from the broker instance.

    :param broker: broker this middleware is being attached to.
    """
    super().set_broker(broker)
    # Creating the observer registers Prometheus collectors. Doing that again
    # for the same broker would try to register the same metric names a second
    # time, and would also silently replace an observer that was installed
    # earlier. Only create one when the broker does not carry an observer yet.
    if getattr(broker, "_receiver_observer", None) is None:
        broker._receiver_observer = PrometheusReceiverObserver()
205+
209206
def post_save(
210207
self,
211208
message: "TaskiqMessage",
@@ -218,3 +215,55 @@ def post_save(
218215
:param result: result of execution.
219216
"""
220217
self.saved_results.labels(message.task_name).inc()
218+
219+
220+
class PrometheusReceiverObserver(ReceiverObserver):
    """
    Receiver observer that publishes receiver state as Prometheus metrics.

    Every callback of the observer protocol is mapped onto one metric:
    the three "current value" callbacks update gauges, the two "event"
    callbacks increment counters.
    """

    def __init__(self) -> None:
        """
        Create the gauges and counters used by the callbacks.

        :raises ImportError: if ``prometheus_client`` is not installed.
        """
        # Imported lazily so the module can be imported without the optional
        # metrics dependency; the failure is reported only on instantiation.
        try:
            from prometheus_client import Counter, Gauge  # noqa: PLC0415
        except ImportError as exc:
            raise ImportError(
                "Cannot initialize metrics. Please install 'taskiq[metrics]'.",
            ) from exc

        # Gauges: overwritten with the latest value on every callback.
        self.prefetch_queue_size = Gauge(
            "prefetch_queue_size",
            "The number of task in the prefetch queue.",
            multiprocess_mode="livesum",
        )
        self.semaphore_available = Gauge(
            "semaphore_available",
            "Number of semaphore slots available in broker",
            multiprocess_mode="livesum",
        )
        self.active_tasks_count = Gauge(
            "worker_active_tasks_count",
            "Number of active tasks in worker",
            multiprocess_mode="livesum",
        )

        # Counters: only ever incremented, once per reported event.
        self.task_not_found_total = Counter(
            "task_not_found_total",
            "Number of times the worker got a task not registered",
            ["task_name"],
        )
        self.deserialize_error = Counter(
            "deserialize_error_count",
            "Number of times broker faced a desrialization error",
        )

    def on_prefetch_queue_size(self, size: int) -> None:
        """
        Record the current size of the prefetch queue.

        :param size: number of items currently in the prefetch queue.
        """
        self.prefetch_queue_size.set(size)

    def on_semaphore_status(self, available: int) -> None:
        """
        Record how many semaphore slots are currently available.

        :param available: number of free semaphore slots.
        """
        self.semaphore_available.set(available)

    def on_active_tasks_count(self, count: int) -> None:
        """
        Record the number of currently active tasks.

        :param count: number of active tasks.
        """
        self.active_tasks_count.set(count)

    def on_task_not_found(self, task_name: str) -> None:
        """
        Count a message that referenced an unknown task.

        :param task_name: name of the task that could not be found;
            used as the ``task_name`` label of the counter.
        """
        self.task_not_found_total.labels(task_name).inc()

    def on_deserialize_error(self, raw: bytes, error: Exception) -> None:
        """
        Count a message that could not be deserialized.

        Only the occurrence is counted; the payload and the exception
        are accepted to satisfy the observer interface but are not used.

        :param raw: raw message data that failed to parse.
        :param error: exception raised while parsing.
        """
        self.deserialize_error.inc()

taskiq/receiver/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Package for message receiver."""
22

33
from taskiq.receiver.receiver import Receiver
4+
from taskiq.receiver.observer import ReceiverObserver
45

5-
__all__ = ["Receiver"]
6+
__all__ = ["Receiver", "ReceiverObserver"]

taskiq/receiver/observer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import Protocol, runtime_checkable
2+
3+
4+
@runtime_checkable
class ReceiverObserver(Protocol):
    """
    Observer for receiver stats.

    This class is used to observe/collect metrics for the receiver.
    This includes semaphore usage, tasks in queue, etc.

    Metrics tracked:
    - Number of tasks in the prefetch queue
    - Number of free semaphore slots
    - Number of tasks in execution
    - Messages that reference an unknown task
    - Messages that could not be deserialized

    The protocol is runtime-checkable, so ``isinstance`` can be used to
    verify that an object provides all of the callbacks below.
    """

    def on_prefetch_queue_size(self, size: int) -> None:
        """
        Report the current size of the prefetch queue.

        :param size: number of items currently in the prefetch queue.
        """

    def on_semaphore_status(self, available: int) -> None:
        """
        Report how many semaphore slots are currently free.

        :param available: number of free semaphore slots.
        """

    def on_active_tasks_count(self, count: int) -> None:
        """
        Report the number of tasks that are currently being executed.

        :param count: number of active tasks.
        """

    def on_task_not_found(self, task_name: str) -> None:
        """
        Report a received message whose task is not known.

        :param task_name: name of the task that could not be found.
        """

    def on_deserialize_error(self, raw: bytes, error: Exception) -> None:
        """
        Report a received message that could not be parsed.

        :param raw: raw message data that failed to parse.
        :param error: exception raised while parsing.
        """

taskiq/receiver/receiver.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from taskiq.context import Context
1919
from taskiq.exceptions import NoResultError
2020
from taskiq.message import TaskiqMessage
21+
from taskiq.receiver.observer import ReceiverObserver
2122
from taskiq.receiver.params_parser import parse_params
2223
from taskiq.result import TaskiqResult
2324
from taskiq.state import TaskiqState
@@ -35,6 +36,7 @@ def __init__(
3536
self,
3637
broker: AsyncBroker,
3738
executor: Executor | None = None,
39+
observer: ReceiverObserver | None = None,
3840
validate_params: bool = True,
3941
max_async_tasks: "int | None" = None,
4042
max_prefetch: int = 0,
@@ -54,6 +56,7 @@ def __init__(
5456
self.dependency_graphs: dict[str, DependencyGraph] = {}
5557
self.propagate_exceptions = propagate_exceptions
5658
self.on_exit = on_exit
59+
self.observer = observer
5760
self.ack_time = ack_type or AcknowledgeType.WHEN_SAVED
5861
self.known_tasks: set[str] = set()
5962
self.max_tasks_to_execute = max_tasks_to_execute
@@ -92,6 +95,11 @@ async def callback( # noqa: C901, PLR0912
9295
taskiq_msg = self.broker.formatter.loads(message=message_data)
9396
taskiq_msg.parse_labels()
9497
except Exception as exc:
98+
if self.observer is not None:
99+
self.observer.on_deserialize_error(
100+
raw=message_data,
101+
error=exc,
102+
)
95103
logger.warning(
96104
"Cannot parse message: %s. Skipping execution.\n %s",
97105
message_data,
@@ -102,6 +110,11 @@ async def callback( # noqa: C901, PLR0912
102110
logger.debug(f"Received message: {taskiq_msg}")
103111
task = self.broker.find_task(taskiq_msg.task_name)
104112
if task is None:
113+
if self.observer is not None:
114+
self.observer.on_task_not_found(
115+
taskiq_msg.task_name,
116+
)
117+
105118
logger.warning(
106119
'task "%s" is not found. Maybe you forgot to import it?',
107120
taskiq_msg.task_name,
@@ -363,6 +376,7 @@ async def prefetcher(
363376
break
364377
try:
365378
await self.sem_prefetch.acquire()
379+
366380
if (
367381
self.max_tasks_to_execute
368382
and fetched_tasks >= self.max_tasks_to_execute
@@ -376,13 +390,20 @@ async def prefetcher(
376390
# and continue the loop. So it will check if finished event was set.
377391
if not done:
378392
self.sem_prefetch.release()
393+
379394
continue
380395
# We're done, so now we need to check
381396
# whether task has returned an error.
382397
message = current_message.result()
383398
current_message = asyncio.create_task(iterator.__anext__()) # type: ignore
384399
fetched_tasks += 1
385400
await queue.put(message)
401+
402+
if self.observer is not None:
403+
self.observer.on_prefetch_queue_size(
404+
queue.qsize(),
405+
)
406+
386407
except (asyncio.CancelledError, StopAsyncIteration):
387408
break
388409
# We don't want to fetch new messages if we are shutting down.
@@ -413,17 +434,35 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
413434
:param task: finished task
414435
"""
415436
tasks.discard(task)
437+
if self.observer is not None:
438+
self.observer.on_active_tasks_count(
439+
len(tasks),
440+
)
441+
416442
if self.sem is not None:
417443
self.sem.release()
418444

445+
if self.observer is not None:
446+
self.observer.on_semaphore_status(
447+
self.sem._value # noqa
448+
)
449+
419450
while True:
420451
try:
421452
# Waits for semaphore to be released.
422453
if self.sem is not None:
423454
await self.sem.acquire()
455+
if self.observer is not None:
456+
self.observer.on_semaphore_status(
457+
self.sem._value # noqa
458+
)
424459

425460
self.sem_prefetch.release()
426461
message = await queue.get()
462+
if self.observer is not None:
463+
self.observer.on_prefetch_queue_size(
464+
queue.qsize() # noqa
465+
)
427466
if message is QUEUE_DONE:
428467
# asyncio.wait will throw an error if there is nothing to wait for
429468
if tasks:
@@ -438,7 +477,10 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
438477
self.callback(message=message, raise_err=False),
439478
)
440479
tasks.add(task)
441-
480+
if self.observer is not None:
481+
self.observer.on_active_tasks_count(
482+
len(tasks),
483+
)
442484
# We want the task to remove itself from the set when it's done.
443485
#
444486
# Because if we won't save it anywhere,

0 commit comments

Comments
 (0)