diff --git a/CHANGELOG.md b/CHANGELOG.md index 64da93791c..8ca1bf44ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#4244](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4244)) - `opentelemetry-instrumentation-sqlite3`: Add uninstrument, error status, suppress, and no-op tests ([#4335](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4335)) +- `opentelemetry-instrumentation-celery`: Add task and worker lifecycle metrics matching Celery Flower + ([#4439](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4439)) ### Fixed @@ -26,6 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#4427](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4427)) - `opentelemetry-instrumentation-flask`: Clean up environ keys in `_teardown_request` to prevent duplicate execution ([#4341](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4341)) +- `opentelemetry-instrumentation-celery`: Fix memory leak in `task_id_to_start_time` dict never being cleaned up + ([#4439](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4439)) ### Breaking changes diff --git a/docs/nitpick-exceptions.ini b/docs/nitpick-exceptions.ini index 128e406a70..32b68f51ed 100644 --- a/docs/nitpick-exceptions.ini +++ b/docs/nitpick-exceptions.ini @@ -48,6 +48,7 @@ py-class= fastapi.applications.FastAPI starlette.applications.Starlette _contextvars.Token + celery.worker.request.Request any= ; API diff --git a/instrumentation/opentelemetry-instrumentation-celery/pyproject.toml b/instrumentation/opentelemetry-instrumentation-celery/pyproject.toml index 21a58ea35e..20d3b2011b 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/pyproject.toml +++ b/instrumentation/opentelemetry-instrumentation-celery/pyproject.toml @@ -37,6 +37,7 @@ instruments = [ [project.entry-points.opentelemetry_instrumentor] celery = "opentelemetry.instrumentation.celery:CeleryInstrumentor" +celery_worker = "opentelemetry.instrumentation.celery:CeleryWorkerInstrumentor" [project.urls] Homepage = "https://github.com/open-telemetry/opentelemetry-python-contrib/tree/main/instrumentation/opentelemetry-instrumentation-celery" diff --git a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py index ca61e9b455..08bb40b6fe 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py @@ -55,20 +55,52 @@ def add(x, y): such as the BatchSpanProcessor. Celery provides a signal called ``worker_process_init`` that can be used to accomplish this as shown in the example above. +Worker-level metrics +-------------------- + +To collect worker lifecycle metrics (online/offline status), use +``CeleryWorkerInstrumentor`` in the **main worker process** via the +``celeryd_after_setup`` signal. This signal fires before +``worker_ready``, ensuring the handler is connected in time. + +.. code:: python + + from opentelemetry.instrumentation.celery import ( + CeleryInstrumentor, + CeleryWorkerInstrumentor, + ) + from celery.signals import celeryd_after_setup, worker_process_init + + @celeryd_after_setup.connect(weak=False) + def init_worker_metrics(sender, instance, conf, **kwargs): + CeleryWorkerInstrumentor().instrument() + + @worker_process_init.connect(weak=False) + def init_celery_tracing(*args, **kwargs): + CeleryInstrumentor().instrument() + API --- """ +from __future__ import annotations + import logging +from collections.abc import Collection, Iterable +from dataclasses import dataclass from timeit import default_timer -from typing import Collection, Iterable +from typing import TYPE_CHECKING, Any, Optional, cast from billiard import VERSION from billiard.einfo import ExceptionInfo -from celery import signals # pylint: disable=no-name-in-module +from celery import signals # pylint: disable=import-self +from celery import ( + states as celery_states, # pylint: disable=import-self, no-name-in-module +) +from celery.worker.request import Request # pylint: disable=no-name-in-module from opentelemetry import context as context_api -from opentelemetry import trace +from opentelemetry import metrics, trace from opentelemetry.instrumentation.celery import utils from opentelemetry.instrumentation.celery.package import _instruments from opentelemetry.instrumentation.celery.version import __version__ @@ -81,6 +113,10 @@ def add(x, y): ) from opentelemetry.trace.status import Status, StatusCode +if TYPE_CHECKING: + from celery.app.task import Task + + from opentelemetry.metrics import Counter, Histogram, Meter, UpDownCounter if VERSION >= (4, 0, 1): from billiard.einfo import ExceptionWithTraceback else: @@ -94,13 +130,11 @@ def add(x, y): _TASK_RUN = "run" _TASK_RETRY_REASON_KEY = "celery.retry.reason" -_TASK_REVOKED_REASON_KEY = "celery.revoked.reason" -_TASK_REVOKED_TERMINATED_SIGNAL_KEY = "celery.terminated.signal" _TASK_NAME_KEY = "celery.task_name" -class CeleryGetter(Getter): - def get(self, carrier, key): +class CeleryGetter(Getter[Request]): + def get(self, carrier: "Request", key: str) -> list[str] | None: value = getattr(carrier, key, None) if value is None: return None @@ -109,31 +143,158 @@ def get(self, carrier, key): # of ints). The TextMapPropagator contract requires string # values, so coerce anything that isn't already a string. if isinstance(value, str): - value = (value,) - elif isinstance(value, Iterable): - value = tuple( - str(v) if not isinstance(v, str) else v for v in value - ) - else: - value = (str(value),) - return value + return [value] + if isinstance(value, Iterable): + return [str(v) if not isinstance(v, str) else v for v in value] + return [str(value)] - def keys(self, carrier): + def keys(self, carrier: "Request") -> list[str]: return [] celery_getter = CeleryGetter() +def _log_signal( + signal_name: str, + task_id: Optional[str] = None, + task_name: Optional[str] = None, + worker: Optional[str] = None, +) -> None: + """Log Celery signal execution context.""" + logger.debug( + "%s signal received task_id=%s task_name=%s worker=%s", + signal_name, + task_id, + task_name, + worker, + ) + + +def _retrieve_task_name( + task: "Optional[Task]" = None, + request: "Optional[Request]" = None, +) -> "Optional[str]": + """Retrieve the task name from the task or request objects.""" + if task is not None and getattr(task, "name", None) is not None: + return task.name + if request is not None: + request_task = getattr(request, "task", None) + if request_task is not None: + # request.task may be a Task object or a string name + if hasattr(request_task, "name"): + return request_task.name + return str(request_task) + request_name = getattr(request, "name", None) + if request_name is not None: + return str(request_name) + return None + + +def _retrieve_worker_name( + task: "Optional[Task]" = None, + request: "Optional[Request]" = None, + sender: Optional[object] = None, +) -> Optional[str]: + """Retrieve the worker name from the task, request, or sender objects.""" + task_request = task.request if task is not None else None + if task_request is not None: + task_request_worker = getattr(task_request, "hostname", None) + if task_request_worker is not None: + return task_request_worker + + if request is not None: + request_worker = cast( + Optional[str], getattr(request, "hostname", None) + ) + if request_worker is not None: + return request_worker + + if sender is not None: + return cast(Optional[str], getattr(sender, "hostname", None)) + return None + + +@dataclass(frozen=True) +class _CeleryTaskMetricNames: + """Canonical metric names for Celery task instrumentation.""" + + events_total: str = "flower.events.total" + task_runtime_seconds: str = "flower.task.runtime.seconds" + worker_currently_executing_tasks: str = ( + "flower.worker.number.of.currently.executing.tasks" + ) + + +@dataclass(frozen=True) +class _CeleryWorkerMetricNames: + """Canonical metric names for Celery worker lifecycle instrumentation.""" + + events_total: str = "flower.events.total" + worker_online: str = "flower.worker.online" + + +_TASK_METRIC_NAMES = _CeleryTaskMetricNames() +_WORKER_METRIC_NAMES = _CeleryWorkerMetricNames() + + +@dataclass(frozen=True) +class _CeleryEventTypes: + """Celery event type identifiers used as metric label values.""" + + task_sent: str = "task-sent" + task_received: str = "task-received" + task_started: str = "task-started" + task_succeeded: str = "task-succeeded" + task_failed: str = "task-failed" + task_retried: str = "task-retried" + task_revoked: str = "task-revoked" + + +_EVENT_TYPES = _CeleryEventTypes() + + +@dataclass +class CeleryTaskMetrics: + """Metrics for tracking Celery task events and states.""" + + events_total: "Counter" + task_runtime_seconds: "Histogram" + worker_currently_executing_tasks: "UpDownCounter" + + +@dataclass +class CeleryWorkerMetrics: + """Metrics for tracking Celery worker lifecycle.""" + + events_total: "Counter" + worker_online: "UpDownCounter" + + class CeleryInstrumentor(BaseInstrumentor): - metrics = None - task_id_to_start_time = {} + """An instrumentor for Celery task execution. + + Traces task publish, run, failure, retry, and revocation. + Tracks task-level metrics (event counts, runtime, + currently executing tasks). + + Must be initialized in the worker subprocess via the + ``worker_process_init`` signal.""" + + def __init__(self) -> None: + super().__init__() + self.metrics: Optional[CeleryTaskMetrics] = None + self.task_id_to_start_time: dict = {} + self.executing_task_id_to_worker: dict = {} def instrumentation_dependencies(self) -> Collection[str]: return _instruments - def _instrument(self, **kwargs): - tracer_provider = kwargs.get("tracer_provider") + def _instrument(self, **kwargs: object) -> None: + """Connect Celery signal handlers and create task-level metrics.""" + tracer_provider = cast( + Optional[trace.TracerProvider], kwargs.get("tracer_provider") + ) # pylint: disable=attribute-defined-outside-init self._tracer = trace.get_tracer( @@ -143,7 +304,9 @@ def _instrument(self, **kwargs): schema_url="https://opentelemetry.io/schemas/1.11.0", ) - meter_provider = kwargs.get("meter_provider") + meter_provider = cast( + Optional[metrics.MeterProvider], kwargs.get("meter_provider") + ) meter = get_meter( __name__, __version__, @@ -151,8 +314,12 @@ def _instrument(self, **kwargs): schema_url="https://opentelemetry.io/schemas/1.11.0", ) - self.create_celery_metrics(meter) + self.task_id_to_start_time = {} + self.executing_task_id_to_worker = {} + self.metrics = self.create_task_metrics(meter) + + # Connect signal handlers to trace Celery events and track task states signals.task_prerun.connect(self._trace_prerun, weak=False) signals.task_postrun.connect(self._trace_postrun, weak=False) signals.before_task_publish.connect( @@ -164,17 +331,79 @@ def _instrument(self, **kwargs): signals.task_failure.connect(self._trace_failure, weak=False) signals.task_retry.connect(self._trace_retry, weak=False) - def _uninstrument(self, **kwargs): + def _uninstrument(self, **kwargs: object) -> None: + """Uninstrument Celery by disconnecting all signal handlers and clearing metrics and state.""" signals.task_prerun.disconnect(self._trace_prerun) signals.task_postrun.disconnect(self._trace_postrun) signals.before_task_publish.disconnect(self._trace_before_publish) signals.after_task_publish.disconnect(self._trace_after_publish) signals.task_failure.disconnect(self._trace_failure) signals.task_retry.disconnect(self._trace_retry) + self.metrics = None + self.task_id_to_start_time = {} + self.executing_task_id_to_worker = {} + + def _metrics(self) -> CeleryTaskMetrics: + """Retrieve the Celery metrics object, raising an error if not initialized.""" + if self.metrics is not None: + return self.metrics + raise RuntimeError("Celery metrics are not initialized") + + def _record_event_count( + self, + event_type: str, + task_name: Optional[str] = None, + worker: Optional[str] = None, + ) -> None: + """Record a Celery event by incrementing the events counter.""" + if task_name is None: + return + + attributes: dict[str, str] = { + "task": task_name, + "type": event_type, + } + if worker is not None: + attributes["worker"] = worker - def _trace_prerun(self, *args, **kwargs): + self._metrics().events_total.add( + 1, + attributes=attributes, + ) + + def _track_executing_task( + self, + task_id: Optional[str], + worker: Optional[str], + ) -> None: + """Track an executing task by recording its worker and incrementing the executing tasks counter.""" + if task_id is None or worker is None: + return + + self.executing_task_id_to_worker[task_id] = worker + self._metrics().worker_currently_executing_tasks.add( + 1, + attributes={"worker": worker}, + ) + + def _untrack_executing_task(self, task_id: str) -> None: + """Untrack an executing task by removing its worker and decrementing the executing tasks counter.""" + worker = self.executing_task_id_to_worker.pop(task_id, None) + if worker is None: + return + + self._metrics().worker_currently_executing_tasks.add( + -1, + attributes={"worker": worker}, + ) + + def _trace_prerun(self, *args: object, **kwargs: object) -> None: + """Start a span for a task about to be executed and track the executing task by recording its start time and incrementing the executing tasks counter.""" task = utils.retrieve_task(kwargs) task_id = utils.retrieve_task_id(kwargs) + task_name = task.name if task is not None else None + worker = _retrieve_worker_name(task=task) if task is not None else None + _log_signal("task_prerun", task_id, task_name, worker) if task is None or task_id is None: return @@ -184,8 +413,6 @@ def _trace_prerun(self, *args, **kwargs): tracectx = extract(request, getter=celery_getter) or None token = context_api.attach(tracectx) if tracectx is not None else None - logger.debug("prerun signal start task_id=%s", task_id) - operation_name = f"{_TASK_RUN}/{task.name}" span = self._tracer.start_span( operation_name, context=tracectx, kind=trace.SpanKind.CONSUMER @@ -195,15 +422,24 @@ def _trace_prerun(self, *args, **kwargs): activation.__enter__() # pylint: disable=unnecessary-dunder-call utils.attach_context(task, task_id, span, activation, token) - def _trace_postrun(self, *args, **kwargs): + worker = _retrieve_worker_name(task=task) + self._track_executing_task(task_id, worker) + self._record_event_count(_EVENT_TYPES.task_started, task.name, worker) + + def _trace_postrun(self, *args: object, **kwargs: object) -> None: + """Finish a span for a task that has been executed and untrack the executing task by recording its end time and decrementing the executing tasks counter. + + https://docs.celeryq.dev/en/main/userguide/signals.html#task-postrun + """ task = utils.retrieve_task(kwargs) task_id = utils.retrieve_task_id(kwargs) + task_name = task.name if task is not None else None + worker = _retrieve_worker_name(task=task) if task is not None else None + _log_signal("task_postrun", task_id, task_name, worker) if task is None or task_id is None: return - logger.debug("postrun signal task_id=%s", task_id) - # retrieve and finish the Span ctx = utils.retrieve_context(task, task_id) @@ -213,26 +449,56 @@ def _trace_postrun(self, *args, **kwargs): span, activation, token = ctx + # Type safety: task.name could be None, but the span attribute requires a string. In this case we can use "unknown" as a fallback task name for the span attribute since it's better to have an "unknown" task name than to have no span at all. + task_name = task.name or "unknown" + # request context tags if span.is_recording(): span.set_attribute(_TASK_TAG_KEY, _TASK_RUN) utils.set_attributes_from_context(span, kwargs) utils.set_attributes_from_context(span, task.request) - span.set_attribute(_TASK_NAME_KEY, task.name) + span.set_attribute(_TASK_NAME_KEY, task_name) activation.__exit__(None, None, None) utils.detach_context(task, task_id) self.update_task_duration_time(task_id) - labels = {"task": task.name, "worker": task.request.hostname} + task_state = cast( + Optional[str], + kwargs.get("state", getattr(task.request, "state", None)), + ) + labels = {"task": task_name, "worker": task.request.hostname} self._record_histograms(task_id, labels) + self.task_id_to_start_time.pop(task_id, None) + self._untrack_executing_task(task_id) + + # Update event counts based on task state + if task_state == celery_states.SUCCESS: + _log_signal( + "task_succeeded", + task_id, + task_name, + worker, + ) + self._record_event_count( + _EVENT_TYPES.task_succeeded, task_name, task.request.hostname + ) + # if the process sending the task is not instrumented # there's no incoming context and no token to detach if token is not None: context_api.detach(token) - def _trace_before_publish(self, *args, **kwargs): + def _trace_before_publish(self, *args: object, **kwargs: object) -> None: + """Start a span for a task about to be published and track the publishing task by recording its start time and incrementing the publishing tasks counter.""" task = utils.retrieve_task_from_sender(kwargs) task_id = utils.retrieve_task_id_from_message(kwargs) + task_name = task.name if task is not None else None + _log_signal( + "before_task_publish", + task_id, + task_name, + None, + ) if task_id is None: return @@ -241,9 +507,10 @@ def _trace_before_publish(self, *args, **kwargs): # task is an anonymous task send using send_task or using canvas workflow # Signatures() to send to a task not in the current processes dependency # tree - task_name = kwargs.get("sender", "unknown") + sender = kwargs.get("sender") + task_name = str(sender) if sender is not None else "unknown" else: - task_name = task.name + task_name = task.name or "unknown" operation_name = f"{_TASK_APPLY_ASYNC}/{task_name}" span = self._tracer.start_span( operation_name, kind=trace.SpanKind.PRODUCER @@ -267,10 +534,15 @@ def _trace_before_publish(self, *args, **kwargs): if headers: inject(headers) + self._record_event_count(_EVENT_TYPES.task_sent, task_name, None) + @staticmethod - def _trace_after_publish(*args, **kwargs): + def _trace_after_publish(*args: object, **kwargs: object) -> None: + """Finish a span for a task that has been published and untrack the publishing task by recording its end time and decrementing the publishing tasks counter.""" task = utils.retrieve_task_from_sender(kwargs) task_id = utils.retrieve_task_id_from_message(kwargs) + task_name = task.name if task is not None else None + _log_signal("after_task_publish", task_id, task_name, None) if task is None or task_id is None: return @@ -287,10 +559,13 @@ def _trace_after_publish(*args, **kwargs): activation.__exit__(None, None, None) # pylint: disable=unnecessary-dunder-call utils.detach_context(task, task_id, is_publish=True) - @staticmethod - def _trace_failure(*args, **kwargs): + def _trace_failure(self, *args: object, **kwargs: object) -> None: + """Trace a task failure event by recording the exception and incrementing the failure event counter.""" task = utils.retrieve_task_from_sender(kwargs) task_id = utils.retrieve_task_id(kwargs) + task_name = task.name if task is not None else None + worker = _retrieve_worker_name(task=task) if task is not None else None + _log_signal("task_failure", task_id, task_name, worker) if task is None or task_id is None: return @@ -305,7 +580,7 @@ def _trace_failure(*args, **kwargs): if not span.is_recording(): return - status_kwargs = {"status_code": StatusCode.ERROR} + status_description: Optional[str] = None ex = kwargs.get("einfo") @@ -329,15 +604,25 @@ def _trace_failure(*args, **kwargs): ): ex = ex.exc - status_kwargs["description"] = str(ex) + status_description = str(ex) span.record_exception(ex) - span.set_status(Status(**status_kwargs)) + span.set_status( + Status( + status_code=StatusCode.ERROR, + description=status_description, + ) + ) + worker = _retrieve_worker_name(task=task) + self._record_event_count(_EVENT_TYPES.task_failed, task.name, worker) - @staticmethod - def _trace_retry(*args, **kwargs): + def _trace_retry(self, *args: object, **kwargs: object) -> None: + """Trace a task retry event by recording its reason and incrementing the retry event counter.""" task = utils.retrieve_task_from_sender(kwargs) task_id = utils.retrieve_task_id_from_request(kwargs) reason = utils.retrieve_reason(kwargs) + task_name = task.name if task is not None else None + worker = _retrieve_worker_name(task=task) if task is not None else None + _log_signal("task_retry", task_id, task_name, worker) if task is None or task_id is None or reason is None: return @@ -356,8 +641,11 @@ def _trace_retry(*args, **kwargs): # Use `str(reason)` instead of `reason.message` in case we get # something that isn't an `Exception` span.set_attribute(_TASK_RETRY_REASON_KEY, str(reason)) + worker = _retrieve_worker_name(task=task) + self._record_event_count(_EVENT_TYPES.task_retried, task.name, worker) - def update_task_duration_time(self, task_id): + def update_task_duration_time(self, task_id: str) -> None: + """Update the duration time for a task by calculating the time since it was last started or updated.""" cur_time = default_timer() task_duration_time_until_now = ( cur_time - self.task_id_to_start_time[task_id] @@ -366,20 +654,220 @@ def update_task_duration_time(self, task_id): ) self.task_id_to_start_time[task_id] = task_duration_time_until_now - def _record_histograms(self, task_id, metric_attributes): + def _record_histograms( + self, + task_id: Optional[str], + metric_attributes: dict[str, str], + ) -> None: + """Record histogram metrics for a task by using its duration time and provided attributes.""" if task_id is None: return - self.metrics["flower.task.runtime.seconds"].record( - self.task_id_to_start_time.get(task_id), - attributes=metric_attributes, - ) + task_duration = self.task_id_to_start_time.get(task_id) + if task_duration is not None: + self._metrics().task_runtime_seconds.record( + task_duration, + attributes=metric_attributes, + ) - def create_celery_metrics(self, meter) -> None: - self.metrics = { - "flower.task.runtime.seconds": meter.create_histogram( - name="flower.task.runtime.seconds", + @staticmethod + def create_task_metrics(meter: "Meter") -> CeleryTaskMetrics: + """Create the metrics for tracking Celery task events and states.""" + return CeleryTaskMetrics( + events_total=meter.create_counter( + name=_TASK_METRIC_NAMES.events_total, + unit="{event}", + description=( + "Number of task and worker events recorded " + "by Celery instrumentation." + ), + ), + task_runtime_seconds=meter.create_histogram( + name=_TASK_METRIC_NAMES.task_runtime_seconds, unit="seconds", description="The time it took to run the task.", - ) + ), + worker_currently_executing_tasks=meter.create_up_down_counter( + name=_TASK_METRIC_NAMES.worker_currently_executing_tasks, + unit="{task}", + description="Number of tasks currently executing at this worker.", + ), + ) + + +class CeleryWorkerInstrumentor(BaseInstrumentor): + """An instrumentor for Celery worker lifecycle metrics. + + Tracks worker online/offline status via the ``worker_ready`` and + ``worker_shutdown`` signals. These signals fire in the **main worker + process**, so this instrumentor must be initialized there — typically + via the ``celeryd_after_setup`` signal (which fires before + ``worker_ready``, giving the handler time to connect). + + Usage:: + + from opentelemetry.instrumentation.celery import CeleryWorkerInstrumentor + from celery.signals import celeryd_after_setup + + @celeryd_after_setup.connect(weak=False) + def init_worker_metrics(sender, instance, conf, **kwargs): + CeleryWorkerInstrumentor().instrument() + """ + + def __init__(self) -> None: + super().__init__() + self.metrics: Optional[CeleryWorkerMetrics] = None + self.online_workers: set = set() + + def instrumentation_dependencies(self) -> Collection[str]: + return _instruments + + def _instrument(self, **kwargs: object) -> None: + """Connect worker lifecycle signal handlers and create worker metrics.""" + meter_provider = cast( + Optional[metrics.MeterProvider], kwargs.get("meter_provider") + ) + meter = get_meter( + __name__, + __version__, + meter_provider, + schema_url="https://opentelemetry.io/schemas/1.11.0", + ) + + self.online_workers = set() + self.metrics = self._create_worker_metrics(meter) + + signals.worker_ready.connect(self._trace_worker_ready, weak=False) + signals.worker_shutdown.connect( + self._trace_worker_shutdown, weak=False + ) + signals.task_received.connect(self._trace_task_received, weak=False) + signals.task_revoked.connect(self._trace_task_revoked, weak=False) + + def _uninstrument(self, **kwargs: object) -> None: + """Disconnect worker lifecycle signal handlers.""" + signals.worker_ready.disconnect(self._trace_worker_ready) + signals.worker_shutdown.disconnect(self._trace_worker_shutdown) + signals.task_received.disconnect(self._trace_task_received) + signals.task_revoked.disconnect(self._trace_task_revoked) + self.metrics = None + self.online_workers = set() + + def _worker_metrics(self) -> CeleryWorkerMetrics: + """Return the worker metrics, raising if not yet initialized.""" + if self.metrics is not None: + return self.metrics + raise RuntimeError("Worker metrics are not initialized") + + def _trace_worker_ready(self, *args: object, **kwargs: object) -> None: + """Track a worker coming online.""" + worker = _retrieve_worker_name(sender=kwargs.get("sender")) + _log_signal("worker_ready", None, None, worker) + if worker is None or worker in self.online_workers: + return + + self.online_workers.add(worker) + self._worker_metrics().worker_online.add( + 1, + attributes={"worker": worker}, + ) + _log_signal("worker_ready", None, None, worker) + + def _trace_worker_shutdown(self, *args: object, **kwargs: object) -> None: + """Track a worker going offline.""" + worker = _retrieve_worker_name(sender=kwargs.get("sender")) + _log_signal("worker_shutdown", None, None, worker) + if worker is None or worker not in self.online_workers: + return + + self.online_workers.remove(worker) + self._worker_metrics().worker_online.add( + -1, + attributes={"worker": worker}, + ) + + def _record_event_count( + self, + event_type: str, + task_name: Optional[str] = None, + worker: Optional[str] = None, + ) -> None: + """Record a Celery event by incrementing the events counter.""" + if task_name is None: + return + + attributes: dict[str, str] = { + "type": event_type, + "task": task_name, } + if worker is not None: + attributes["worker"] = worker + self._worker_metrics().events_total.add( + 1, + attributes=attributes, + ) + + def _trace_task_received( + self, *args: object, **kwargs: dict[str, Any] + ) -> None: + """Track a received task by recording its received time and incrementing the task received event counter. + + https://docs.celeryq.dev/en/main/userguide/signals.html#task-received + """ + request = kwargs.get("request") + task_id = getattr(request, "id", None) + task_name = _retrieve_task_name(request=request) + worker = _retrieve_worker_name( + request=request, sender=kwargs.get("sender") + ) + _log_signal( + "task_received", + task_id, + task_name, + worker, + ) + if task_id is None or not isinstance(request, Request): + return + + self._record_event_count(_EVENT_TYPES.task_received, task_name, worker) + + def _trace_task_revoked(self, *args: object, **kwargs: object) -> None: + """Trace a task revoked event by untracking the task and incrementing the revoked event counter. + + https://docs.celeryq.dev/en/main/userguide/signals.html#task-revoked + """ + request = kwargs.get("request") + task_id = getattr(request, "id", None) + task_name = _retrieve_task_name(request=request) + worker = _retrieve_worker_name( + request=request, sender=kwargs.get("sender") + ) + _log_signal( + "task_revoked", + task_id, + task_name, + worker, + ) + if task_id is None or request is None: + return + + self._record_event_count(_EVENT_TYPES.task_revoked, task_name, worker) + + @staticmethod + def _create_worker_metrics(meter: "Meter") -> CeleryWorkerMetrics: + """Create the metrics for tracking Celery worker lifecycle.""" + return CeleryWorkerMetrics( + events_total=meter.create_counter( + name=_WORKER_METRIC_NAMES.events_total, + unit="{event}", + description=( + "Number of task and worker events recorded " + "by Celery instrumentation." + ), + ), + worker_online=meter.create_up_down_counter( + name=_WORKER_METRIC_NAMES.worker_online, + unit="{worker}", + description="Shows celery worker online status.", + ), + ) diff --git a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py index c6bd7b9d86..39ae6db881 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py +++ b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py @@ -15,7 +15,8 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, Tuple +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Optional, Tuple, cast from celery import registry # pylint: disable=no-name-in-module from celery.app.task import Task @@ -29,6 +30,9 @@ if TYPE_CHECKING: from contextlib import AbstractContextManager +ContextTuple = Tuple[Span, "AbstractContextManager[Span]", Optional[object]] +ContextDict = dict[tuple[str, bool], ContextTuple] + logger = logging.getLogger(__name__) # Celery Context key @@ -59,7 +63,10 @@ # pylint:disable=too-many-branches -def set_attributes_from_context(span, context): +def set_attributes_from_context( + span: Span, + context: Mapping[str, Any], +) -> None: """Helper to extract meta values from a Celery Context""" if not span.is_recording(): return @@ -155,7 +162,7 @@ def attach_context( if task is None: return - ctx_dict = getattr(task, CTX_KEY, None) + ctx_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None)) if ctx_dict is None: ctx_dict = {} @@ -164,12 +171,17 @@ def attach_context( ctx_dict[(task_id, is_publish)] = (span, activation, token) -def detach_context(task, task_id, is_publish=False) -> None: +def detach_context( + task: Optional[Task], task_id: str, is_publish: bool = False +) -> None: """Helper to remove `Span`, `ContextManager` and context token in a Celery task when it's propagated. This function handles tasks where no values are attached to the `Task`. """ - span_dict = getattr(task, CTX_KEY, None) + if task is None: + return + + span_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None)) if span_dict is None: return @@ -178,12 +190,15 @@ def detach_context(task, task_id, is_publish=False) -> None: def retrieve_context( - task, task_id, is_publish=False -) -> Optional[Tuple[Span, AbstractContextManager[Span], Optional[object]]]: + task: Optional[Task], task_id: str, is_publish: bool = False +) -> Optional[ContextTuple]: """Helper to retrieve an active `Span`, `ContextManager` and context token stored in a `Task` instance """ - span_dict = getattr(task, CTX_KEY, None) + if task is None: + return None + + span_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None)) if span_dict is None: return None @@ -191,14 +206,14 @@ def retrieve_context( return span_dict.get((task_id, is_publish), None) -def retrieve_task(kwargs): +def retrieve_task(kwargs: Mapping[str, Any]) -> Optional[Task]: task = kwargs.get("task") if task is None: logger.debug("Unable to retrieve task from signal arguments") - return task + return cast(Optional[Task], task) -def retrieve_task_from_sender(kwargs): +def retrieve_task_from_sender(kwargs: Mapping[str, Any]) -> Optional[Task]: sender = kwargs.get("sender") if sender is None: logger.debug("Unable to retrieve the sender from signal arguments") @@ -210,30 +225,31 @@ def retrieve_task_from_sender(kwargs): if sender is None: logger.debug("Unable to retrieve the task from sender=%s", sender) - return sender + return cast(Optional[Task], sender) -def retrieve_task_id(kwargs): +def retrieve_task_id(kwargs: Mapping[str, Any]) -> Optional[str]: task_id = kwargs.get("task_id") if task_id is None: logger.debug("Unable to retrieve task_id from signal arguments") - return task_id + return cast(Optional[str], task_id) -def retrieve_task_id_from_request(kwargs): +def retrieve_task_id_from_request(kwargs: Mapping[str, Any]) -> Optional[str]: # retry signal does not include task_id as argument so use request argument request = kwargs.get("request") if request is None: logger.debug("Unable to retrieve the request from signal arguments") + return None - task_id = getattr(request, "id") + task_id = cast(Optional[str], getattr(request, "id", None)) if task_id is None: logger.debug("Unable to retrieve the task_id from the request") return task_id -def retrieve_task_id_from_message(kwargs): +def retrieve_task_id_from_message(kwargs: Mapping[str, Any]) -> Optional[str]: """Helper to retrieve the `Task` identifier from the message `body`. This helper supports Protocol Version 1 and 2. The Protocol is well detailed in the official documentation: @@ -243,12 +259,14 @@ def retrieve_task_id_from_message(kwargs): body = kwargs.get("body") if headers is not None and len(headers) > 0: # Protocol Version 2 (default from Celery 4.0) - return headers.get("id") + return cast(Optional[str], headers.get("id")) # Protocol Version 1 - return body.get("id") + if body is None: + return None + return cast(Optional[str], body.get("id")) -def retrieve_reason(kwargs): +def retrieve_reason(kwargs: Mapping[str, Any]) -> Optional[object]: reason = kwargs.get("reason") if not reason: logger.debug("Unable to retrieve the retry reason") diff --git a/instrumentation/opentelemetry-instrumentation-celery/tests/test_duplicate.py b/instrumentation/opentelemetry-instrumentation-celery/tests/test_duplicate.py index ab1f7804cf..84fb0511fd 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/tests/test_duplicate.py +++ b/instrumentation/opentelemetry-instrumentation-celery/tests/test_duplicate.py @@ -18,13 +18,13 @@ class TestUtils(unittest.TestCase): - def test_duplicate_instrumentaion(self): + def test_duplicate_instrumentation(self): first = CeleryInstrumentor() first.instrument() second = CeleryInstrumentor() second.instrument() CeleryInstrumentor().uninstrument() - self.assertIsNotNone(first.metrics) - self.assertIsNotNone(second.metrics) + self.assertIsNone(first.metrics) + self.assertIsNone(second.metrics) self.assertEqual(first.task_id_to_start_time, {}) self.assertEqual(second.task_id_to_start_time, {}) diff --git a/instrumentation/opentelemetry-instrumentation-celery/tests/test_getter.py b/instrumentation/opentelemetry-instrumentation-celery/tests/test_getter.py index eb3d632f6e..63e0b44c8b 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/tests/test_getter.py +++ b/instrumentation/opentelemetry-instrumentation-celery/tests/test_getter.py @@ -19,24 +19,27 @@ class TestCeleryGetter(TestCase): def test_get_none(self): + """Missing attribute on carrier should return None.""" getter = CeleryGetter() carrier = {} val = getter.get(carrier, "test") self.assertIsNone(val) def test_get_str(self): + """String attribute should be wrapped in a single-element list.""" mock_obj = mock.Mock() getter = CeleryGetter() mock_obj.test = "val" val = getter.get(mock_obj, "test") - self.assertEqual(val, ("val",)) + self.assertEqual(val, ["val"]) def test_get_iter(self): + """Iterable attribute should be returned as a list.""" mock_obj = mock.Mock() getter = CeleryGetter() mock_obj.test = ["val"] val = getter.get(mock_obj, "test") - self.assertEqual(val, ("val",)) + self.assertEqual(val, ["val"]) def test_get_int(self): """Non-string scalar values should be coerced to strings. @@ -50,7 +53,7 @@ def test_get_int(self): getter = CeleryGetter() mock_obj.test = 42 val = getter.get(mock_obj, "test") - self.assertEqual(val, ("42",)) + self.assertEqual(val, ["42"]) def test_get_iter_with_non_string_elements(self): """Iterable values containing non-strings should be coerced. @@ -61,7 +64,7 @@ def test_get_iter_with_non_string_elements(self): getter = CeleryGetter() mock_obj.test = (300, 60) val = getter.get(mock_obj, "test") - self.assertEqual(val, ("300", "60")) + self.assertEqual(val, ["300", "60"]) def test_get_iter_with_mixed_types(self): """Iterables with a mix of strings and non-strings.""" @@ -69,9 +72,18 @@ def test_get_iter_with_mixed_types(self): getter = CeleryGetter() mock_obj.test = ["val", 123] val = getter.get(mock_obj, "test") - self.assertEqual(val, ("val", "123")) + self.assertEqual(val, ["val", "123"]) + + def test_get_non_str_non_iterable(self): + """Non-string, non-iterable value should be coerced to [str(value)].""" + getter = CeleryGetter() + mock_obj = mock.Mock() + mock_obj.key = 42 + val = getter.get(mock_obj, "key") + self.assertEqual(val, ["42"]) def test_keys(self): + """keys() should return an empty list for any carrier.""" getter = CeleryGetter() keys = getter.keys({}) self.assertEqual(keys, []) diff --git a/instrumentation/opentelemetry-instrumentation-celery/tests/test_metrics.py b/instrumentation/opentelemetry-instrumentation-celery/tests/test_metrics.py index 57ceb51d36..d26722a111 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/tests/test_metrics.py +++ b/instrumentation/opentelemetry-instrumentation-celery/tests/test_metrics.py @@ -1,18 +1,58 @@ +from __future__ import annotations + import threading import time from platform import python_implementation from timeit import default_timer +from typing import TYPE_CHECKING +from unittest.mock import MagicMock +from celery.worker.request import Request from pytest import mark -from opentelemetry.instrumentation.celery import CeleryInstrumentor +from opentelemetry.instrumentation.celery import ( + CeleryInstrumentor, + CeleryWorkerInstrumentor, + _retrieve_task_name, + _retrieve_worker_name, +) from opentelemetry.test.test_base import TestBase from .celery_test_tasks import app, task_add +if TYPE_CHECKING: + from opentelemetry.sdk.metrics.export import Metric + SCOPE = "opentelemetry.instrumentation.celery" +def _find_metric(metrics: list[Metric], name: str) -> Metric | None: + """Find a metric by name in the list of metrics.""" + for metric in metrics: + if metric.name == name: + return metric + return None + + +def _make_request(task_id: str = "test-id-123", hostname: str = "celery@test"): + """Create a minimal celery Request for testing.""" + msg = MagicMock() + msg.headers = {"id": task_id, "task": task_add.name} + msg.payload = ( + [], + {}, + { + "callbacks": None, + "errbacks": None, + "chain": None, + "chord": None, + }, + ) + msg.delivery_info = {} + msg.properties = {} + return Request(msg, app=app, hostname=hostname, task=task_add) + + class TestMetrics(TestBase): def setUp(self): super().setUp() @@ -39,17 +79,16 @@ def get_metrics(self): return self.get_sorted_metrics(SCOPE) def test_basic_metric(self): + """Executing a task should record a task runtime histogram.""" CeleryInstrumentor().instrument() start_time = default_timer() task_runtime_estimated = (default_timer() - start_time) * 1000 metrics = self.get_metrics() CeleryInstrumentor().uninstrument() - self.assertEqual(len(metrics), 1) - task_runtime = metrics[0] - print(task_runtime) - self.assertEqual(task_runtime.name, "flower.task.runtime.seconds") + task_runtime = _find_metric(metrics, "flower.task.runtime.seconds") + self.assertIsNotNone(task_runtime) self.assert_metric_expected( task_runtime, [ @@ -71,24 +110,414 @@ def test_basic_metric(self): python_implementation() == "PyPy", reason="Fails randomly in pypy" ) def test_metric_uninstrument(self): + """After uninstrument, no new metric data points should be recorded.""" CeleryInstrumentor().instrument() metrics = self.get_metrics() + task_runtime = _find_metric(metrics, "flower.task.runtime.seconds") + self.assertIsNotNone(task_runtime) self.assertEqual( - metrics[0].data.data_points[0].bucket_counts[1], + task_runtime.data.data_points[0].bucket_counts[1], 1, ) metrics = self.get_metrics() + task_runtime = _find_metric(metrics, "flower.task.runtime.seconds") + self.assertIsNotNone(task_runtime) self.assertEqual( - metrics[0].data.data_points[0].bucket_counts[1], + task_runtime.data.data_points[0].bucket_counts[1], 2, ) CeleryInstrumentor().uninstrument() metrics = self.get_metrics() + task_runtime = _find_metric(metrics, "flower.task.runtime.seconds") + self.assertIsNotNone(task_runtime) self.assertEqual( - metrics[0].data.data_points[0].bucket_counts[1], + task_runtime.data.data_points[0].bucket_counts[1], 2, ) + + +class TestMetricsIntegration(TestBase): + """End-to-end integration tests: real worker, real task, full signal chain.""" + + def setUp(self): + super().setUp() + self.instrumentor = CeleryInstrumentor() + self.instrumentor.instrument() + self.worker_instrumentor = CeleryWorkerInstrumentor() + self.worker_instrumentor.instrument() + self._worker = app.Worker( + app=app, pool="solo", concurrency=1, hostname="celery@e2e" + ) + self._thread = threading.Thread(target=self._worker.start) + self._thread.daemon = True + self._thread.start() + + def tearDown(self): + self._worker.stop() + self._thread.join() + self.instrumentor.uninstrument() + self.worker_instrumentor.uninstrument() + super().tearDown() + + @staticmethod + def _run_task(): + """Execute task_add through a real worker and wait for completion.""" + result = task_add.delay(1, 2) + timeout = time.time() + 60 + while not result.ready(): + if time.time() > timeout: + break + time.sleep(0.05) + return result + + def test_events_total_recorded(self): + """A completed task should record task-sent, task-received, task-started, task-succeeded events.""" + self._run_task() + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertIsNotNone(events) + + recorded_types = { + dp.attributes["type"] for dp in events.data.data_points + } + for expected in ( + "task-sent", + "task-received", + "task-started", + "task-succeeded", + ): + self.assertIn( + expected, + recorded_types, + f"Expected event type '{expected}' not found in {recorded_types}", + ) + + def test_task_runtime_histogram_recorded(self): + """A completed task should produce a flower.task.runtime.seconds histogram.""" + self._run_task() + metrics = self.get_sorted_metrics(SCOPE) + runtime = _find_metric(metrics, "flower.task.runtime.seconds") + self.assertIsNotNone(runtime) + self.assertGreater(len(runtime.data.data_points), 0) + + def test_executing_tasks_gauge_returns_to_zero(self): + """After task completes, executing gauge should be back to zero.""" + self._run_task() + metrics = self.get_sorted_metrics(SCOPE) + executing = _find_metric( + metrics, "flower.worker.number.of.currently.executing.tasks" + ) + self.assertIsNotNone(executing) + self.assertEqual(executing.data.data_points[0].value, 0) + + def test_metric_attributes_contain_task_and_worker(self): + """Event metrics should carry both task and worker attributes.""" + self._run_task() + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertIsNotNone(events) + + # Check a data point that should have both task and worker (e.g. task-started) + started = [ + dp + for dp in events.data.data_points + if dp.attributes.get("type") == "task-started" + ] + self.assertEqual(len(started), 1) + self.assertEqual( + started[0].attributes["task"], + "tests.celery_test_tasks.task_add", + ) + self.assertEqual(started[0].attributes["worker"], "celery@e2e") + + +class TestWorkerMetricsIntegration(TestBase): + """End-to-end integration tests for CeleryWorkerInstrumentor with a real worker.""" + + def setUp(self): + super().setUp() + self.instrumentor = CeleryWorkerInstrumentor() + self.instrumentor.instrument() + + def tearDown(self): + self.instrumentor.uninstrument() + super().tearDown() + + def test_worker_online_on_start(self): + """Starting a real worker should set flower.worker.online to 1.""" + worker = app.Worker( + app=app, pool="solo", concurrency=1, hostname="celery@e2e-worker" + ) + thread = threading.Thread(target=worker.start) + thread.daemon = True + thread.start() + + # Give the worker time to emit worker_ready signal + time.sleep(0.5) + + metrics = self.get_sorted_metrics(SCOPE) + worker_online = _find_metric(metrics, "flower.worker.online") + self.assertIsNotNone(worker_online) + dp = worker_online.data.data_points[0] + self.assertEqual(dp.value, 1) + self.assertEqual(dp.attributes["worker"], "celery@e2e-worker") + + worker.stop() + thread.join() + + def test_worker_offline_on_stop(self): + """Stopping a real worker should set flower.worker.online back to 0.""" + worker = app.Worker( + app=app, pool="solo", concurrency=1, hostname="celery@e2e-worker2" + ) + thread = threading.Thread(target=worker.start) + thread.daemon = True + thread.start() + + time.sleep(0.5) + worker.stop() + thread.join() + + metrics = self.get_sorted_metrics(SCOPE) + worker_online = _find_metric(metrics, "flower.worker.online") + self.assertIsNotNone(worker_online) + self.assertEqual(worker_online.data.data_points[0].value, 0) + + +class TestWorkerMetrics(TestBase): + """Tests for CeleryWorkerInstrumentor worker lifecycle metrics.""" + + def setUp(self): + super().setUp() + self.instrumentor = CeleryWorkerInstrumentor() + self.instrumentor.instrument() + + def tearDown(self): + self.instrumentor.uninstrument() + super().tearDown() + + def test_worker_ready_increments_online(self): + """Worker ready signal should increment flower.worker.online gauge.""" + sender = type("Sender", (), {"hostname": "celery@worker1"})() + self.instrumentor._trace_worker_ready(sender=sender) + + metrics = self.get_sorted_metrics(SCOPE) + worker_online = _find_metric(metrics, "flower.worker.online") + self.assertIsNotNone(worker_online) + self.assertEqual(len(worker_online.data.data_points), 1) + self.assertEqual(worker_online.data.data_points[0].value, 1) + self.assertEqual( + dict(worker_online.data.data_points[0].attributes), + {"worker": "celery@worker1"}, + ) + + def test_worker_ready_idempotent(self): + """Duplicate worker ready signals should not double-count.""" + sender = type("Sender", (), {"hostname": "celery@worker1"})() + self.instrumentor._trace_worker_ready(sender=sender) + self.instrumentor._trace_worker_ready(sender=sender) + + metrics = self.get_sorted_metrics(SCOPE) + worker_online = _find_metric(metrics, "flower.worker.online") + self.assertIsNotNone(worker_online) + # Still 1 — second call was a no-op + self.assertEqual(worker_online.data.data_points[0].value, 1) + + def test_worker_shutdown_decrements_online(self): + """Worker shutdown should decrement the online gauge back to zero.""" + sender = type("Sender", (), {"hostname": "celery@worker1"})() + self.instrumentor._trace_worker_ready(sender=sender) + self.instrumentor._trace_worker_shutdown(sender=sender) + + metrics = self.get_sorted_metrics(SCOPE) + worker_online = _find_metric(metrics, "flower.worker.online") + self.assertIsNotNone(worker_online) + self.assertEqual(worker_online.data.data_points[0].value, 0) + + def test_worker_shutdown_unknown_worker_noop(self): + """Shutdown for an unknown worker should not raise or record anything.""" + sender = type("Sender", (), {"hostname": "celery@unknown"})() + self.instrumentor._trace_worker_shutdown(sender=sender) + + metrics = self.get_sorted_metrics(SCOPE) + worker_online = _find_metric(metrics, "flower.worker.online") + # No data points recorded + self.assertTrue( + worker_online is None or len(worker_online.data.data_points) == 0 + ) + + def test_uninstrument_disconnects_signals(self): + """Uninstrumenting should disconnect signals and reset state.""" + sender = type("Sender", (), {"hostname": "celery@test-disconnect"})() + self.instrumentor._trace_worker_ready(sender=sender) + self.assertIn( + "celery@test-disconnect", self.instrumentor.online_workers + ) + + # Uninstrument and verify signal no longer fires our handler + self.instrumentor.uninstrument() + + # Re-instrument to check the signal was truly disconnected + # (online_workers was reset) + self.instrumentor.instrument() + self.assertNotIn( + "celery@test-disconnect", self.instrumentor.online_workers + ) + + +class TestTaskReceivedMetrics(TestBase): + """Tests for _trace_task_received signal handler and its metric side-effects.""" + + def setUp(self): + super().setUp() + self.instrumentor = CeleryWorkerInstrumentor() + self.instrumentor.instrument() + + def tearDown(self): + self.instrumentor.uninstrument() + super().tearDown() + + def test_task_received_records_event(self): + """Firing task_received should increment events_total.""" + request = _make_request(task_id="rcv-1", hostname="celery@w1") + sender = type("Sender", (), {"hostname": "celery@w1"})() + + self.instrumentor._trace_task_received(request=request, sender=sender) + + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertIsNotNone(events) + self.assert_metric_expected( + events, + [ + self.create_number_data_point( + value=1, + attributes={ + "task": task_add.name, + "type": "task-received", + "worker": "celery@w1", + }, + ) + ], + ) + + def test_task_received_invalid_request_is_noop(self): + """Non-Request objects should be silently ignored.""" + self.instrumentor._trace_task_received(request=None, sender=None) + self.instrumentor._trace_task_received( + request="not-a-request", sender=None + ) + + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertTrue(events is None or len(events.data.data_points) == 0) + + +class TestTaskRevokedMetrics(TestBase): + """Tests for _trace_task_revoked signal handler.""" + + def setUp(self): + super().setUp() + self.instrumentor = CeleryWorkerInstrumentor() + self.instrumentor.instrument() + + def tearDown(self): + self.instrumentor.uninstrument() + super().tearDown() + + def test_task_revoked_records_event(self): + """Revoking a task should record a task-revoked event.""" + request = _make_request(task_id="rev-1", hostname="celery@w1") + sender = type("Sender", (), {"hostname": "celery@w1"})() + + self.instrumentor._trace_task_revoked(request=request, sender=sender) + + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertIsNotNone(events) + self.assert_metric_expected( + events, + [ + self.create_number_data_point( + value=1, + attributes={ + "task": task_add.name, + "type": "task-revoked", + "worker": "celery@w1", + }, + ) + ], + ) + + def test_task_revoked_invalid_request_is_noop(self): + """Non-Request objects should be silently ignored.""" + self.instrumentor._trace_task_revoked(request=None, sender=None) + + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertTrue(events is None or len(events.data.data_points) == 0) + + +class TestRetrieveTaskName(TestBase): + """Tests for the _retrieve_task_name helper function.""" + + def test_from_task_name(self): + """Task with a name attribute should return that name.""" + task = type("Task", (), {"name": "my.task"})() + self.assertEqual(_retrieve_task_name(task=task), "my.task") + + def test_from_request_task_object_with_name(self): + """Request.task as an object with .name should be used.""" + inner_task = type("Task", (), {"name": "inner.task"})() + request = type("Request", (), {"task": inner_task, "name": None})() + self.assertEqual(_retrieve_task_name(request=request), "inner.task") + + def test_from_request_task_string(self): + """Request.task as a string should be returned directly.""" + request = type("Request", (), {"task": "string.task", "name": None})() + self.assertEqual(_retrieve_task_name(request=request), "string.task") + + def test_from_request_name(self): + """Request.name should be used as fallback when task is None.""" + request = type("Request", (), {"task": None, "name": "req.name"})() + self.assertEqual(_retrieve_task_name(request=request), "req.name") + + def test_returns_none_when_no_info(self): + """No arguments or a task with name=None should return None.""" + self.assertIsNone(_retrieve_task_name()) + self.assertIsNone( + _retrieve_task_name(task=type("T", (), {"name": None})()) + ) + + +class TestRetrieveWorkerName(TestBase): + """Tests for the _retrieve_worker_name helper function.""" + + def test_from_task_request_hostname(self): + """task.request.hostname should be preferred source.""" + inner_req = type("Req", (), {"hostname": "celery@from-task"})() + task = type("Task", (), {"request": inner_req})() + self.assertEqual(_retrieve_worker_name(task=task), "celery@from-task") + + def test_from_request_hostname(self): + """Request object hostname should be used when no task is given.""" + request = _make_request(hostname="celery@from-req") + self.assertEqual( + _retrieve_worker_name(request=request), "celery@from-req" + ) + + def test_from_sender_hostname(self): + """Sender hostname should be used as last resort.""" + sender = type("Sender", (), {"hostname": "celery@from-sender"})() + self.assertEqual( + _retrieve_worker_name(sender=sender), "celery@from-sender" + ) + + def test_returns_none_when_no_info(self): + """No arguments or sender without hostname should return None.""" + self.assertIsNone(_retrieve_worker_name()) + self.assertIsNone(_retrieve_worker_name(sender=type("S", (), {})())) diff --git a/instrumentation/opentelemetry-instrumentation-celery/tests/test_metrics_edge_cases.py b/instrumentation/opentelemetry-instrumentation-celery/tests/test_metrics_edge_cases.py new file mode 100644 index 0000000000..9d4fe6b2a8 --- /dev/null +++ b/instrumentation/opentelemetry-instrumentation-celery/tests/test_metrics_edge_cases.py @@ -0,0 +1,617 @@ +from __future__ import annotations + +from timeit import default_timer +from typing import TYPE_CHECKING +from unittest.mock import MagicMock + +from billiard.einfo import ExceptionInfo +from celery.worker.request import Request + +from opentelemetry.instrumentation.celery import ( + CeleryInstrumentor, + CeleryWorkerInstrumentor, + utils, +) +from opentelemetry.test.test_base import TestBase +from opentelemetry.trace.status import StatusCode + +from .celery_test_tasks import app, task_add + +if TYPE_CHECKING: + from opentelemetry.sdk.metrics.export import Metric + +SCOPE = "opentelemetry.instrumentation.celery" + + +def _find_metric(metrics: list[Metric], name: str) -> Metric | None: + """Find a metric by name in the list of metrics.""" + for metric in metrics: + if metric.name == name: + return metric + return None + + +def _make_request(task_id: str = "test-id-123", hostname: str = "celery@test"): + """Create a minimal celery Request for testing.""" + msg = MagicMock() + msg.headers = {"id": task_id, "task": task_add.name} + msg.payload = ( + [], + {}, + { + "callbacks": None, + "errbacks": None, + "chain": None, + "chord": None, + }, + ) + msg.delivery_info = {} + msg.properties = {} + return Request(msg, app=app, hostname=hostname, task=task_add) + + +class TestMetricsNotInitialized(TestBase): + """Tests for error paths when metrics are not initialized.""" + + def test_task_metrics_raises_before_instrument(self): + """Accessing _metrics() before instrument() should raise RuntimeError.""" + instrumentor = CeleryInstrumentor() + instrumentor.metrics = None + with self.assertRaises(RuntimeError): + instrumentor._metrics() + + def test_worker_metrics_raises_before_instrument(self): + """Accessing _worker_metrics() before instrument() should raise RuntimeError.""" + instrumentor = CeleryWorkerInstrumentor() + instrumentor.metrics = None + with self.assertRaises(RuntimeError): + instrumentor._worker_metrics() + + +class TestRecordEventCountGuards(TestBase): + """Tests for _record_event_count early-return when task_name is None.""" + + def setUp(self): + super().setUp() + self.instrumentor = CeleryInstrumentor() + self.instrumentor.instrument() + + def tearDown(self): + self.instrumentor.uninstrument() + super().tearDown() + + def test_record_event_count_none_task_name_is_noop(self): + """None task_name should cause an early return with no metric recorded.""" + self.instrumentor._record_event_count("task-received", task_name=None) + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertTrue(events is None or len(events.data.data_points) == 0) + + def test_record_event_count_without_worker(self): + """Event should be recorded without worker attribute when worker is None.""" + self.instrumentor._record_event_count( + "task-sent", task_name="my.task", worker=None + ) + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertIsNotNone(events) + self.assert_metric_expected( + events, + [ + self.create_number_data_point( + value=1, + attributes={"task": "my.task", "type": "task-sent"}, + ) + ], + ) + + +class TestTrackingGuards(TestBase): + """Tests for guard paths in track/untrack helpers.""" + + def setUp(self): + super().setUp() + self.task_instrumentor = CeleryInstrumentor() + self.task_instrumentor.instrument() + self.worker_instrumentor = CeleryWorkerInstrumentor() + self.worker_instrumentor.instrument() + + def tearDown(self): + self.task_instrumentor.uninstrument() + self.worker_instrumentor.uninstrument() + super().tearDown() + + def test_track_executing_task_none_args_is_noop(self): + """None task_id or worker should cause an early return.""" + self.task_instrumentor._track_executing_task(None, "worker") + self.task_instrumentor._track_executing_task("id", None) + self.assertEqual( + len(self.task_instrumentor.executing_task_id_to_worker), 0 + ) + + def test_untrack_executing_task_unknown_id_is_noop(self): + """Untracking an unknown task_id should not raise or record.""" + self.task_instrumentor._untrack_executing_task("nonexistent-id") + metrics = self.get_sorted_metrics(SCOPE) + executing = _find_metric( + metrics, "flower.worker.number.of.currently.executing.tasks" + ) + self.assertTrue( + executing is None or len(executing.data.data_points) == 0 + ) + + def test_record_histograms_none_task_id_is_noop(self): + """None task_id should skip histogram recording.""" + self.task_instrumentor._record_histograms( + None, {"task": "t", "worker": "w"} + ) + metrics = self.get_sorted_metrics(SCOPE) + runtime = _find_metric(metrics, "flower.task.runtime.seconds") + self.assertTrue(runtime is None or len(runtime.data.data_points) == 0) + + +class TestTraceRetry(TestBase): + """Tests for _trace_retry signal handler.""" + + def setUp(self): + super().setUp() + self.instrumentor = CeleryInstrumentor() + self.instrumentor.instrument() + + def tearDown(self): + self.instrumentor.uninstrument() + super().tearDown() + + def _setup_task_with_span(self): + """Create a task with an attached span context, as _trace_retry expects.""" + task = task_add + task_id = "retry-1" + tracer = self.tracer_provider.get_tracer("test") + span = tracer.start_span("test") + activation = MagicMock() + activation.__enter__ = MagicMock(return_value=span) + activation.__exit__ = MagicMock(return_value=False) + utils.attach_context(task, task_id, span, activation, None) + return task, task_id, span + + def test_trace_retry_records_reason_and_event(self): + """Retry with attached span should set reason attribute and record event.""" + task, task_id, span = self._setup_task_with_span() + request = type("Request", (), {"id": task_id})() + + self.instrumentor._trace_retry( + sender=task, + task_id=task_id, + request=request, + reason=Exception("connection lost"), + ) + + self.assertTrue(span.is_recording()) + self.assertEqual( + span.attributes.get("celery.retry.reason"), "connection lost" + ) + + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertIsNotNone(events) + self.assert_metric_expected( + events, + [ + self.create_number_data_point( + value=1, + attributes={ + "task": task_add.name, + "type": "task-retried", + }, + ) + ], + ) + # cleanup + utils.detach_context(task, task_id) + + def test_trace_retry_no_task_is_noop(self): + """Retry with no sender/task should be silently ignored.""" + self.instrumentor._trace_retry( + sender=None, task_id=None, request=None, reason=None + ) + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertTrue(events is None or len(events.data.data_points) == 0) + + def test_trace_retry_no_context_is_noop(self): + """If no span context is attached, retry should bail out silently.""" + request = type("Request", (), {"id": "no-ctx"})() + self.instrumentor._trace_retry( + sender=task_add, + request=request, + reason=Exception("retry"), + ) + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertTrue(events is None or len(events.data.data_points) == 0) + + +class TestTracePrerunPostrunGuards(TestBase): + """Tests for early-return guards in _trace_prerun and _trace_postrun.""" + + def setUp(self): + super().setUp() + self.instrumentor = CeleryInstrumentor() + self.instrumentor.instrument() + + def tearDown(self): + self.instrumentor.uninstrument() + super().tearDown() + + def test_prerun_no_task_is_noop(self): + """Prerun with no task should produce no spans.""" + self.instrumentor._trace_prerun(task=None, task_id=None) + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + + def test_postrun_no_task_is_noop(self): + """Postrun with no task should produce no spans.""" + self.instrumentor._trace_postrun(task=None, task_id=None) + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + + def test_postrun_no_context_is_noop(self): + """When no span was attached, postrun should log and return.""" + self.instrumentor._trace_postrun(task=task_add, task_id="no-ctx") + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + + +class TestTracePostrunStateMetrics(TestBase): + """Tests for postrun event counting based on task state.""" + + def setUp(self): + super().setUp() + self.instrumentor = CeleryInstrumentor() + self.instrumentor.instrument() + + def tearDown(self): + self.instrumentor.uninstrument() + super().tearDown() + + def _attach_context_to_task(self, task, task_id: str): + """Attach a live span context to a task for postrun tests.""" + tracer = self.tracer_provider.get_tracer("test") + span = tracer.start_span("test") + activation = MagicMock() + activation.__enter__ = MagicMock(return_value=span) + activation.__exit__ = MagicMock(return_value=False) + utils.attach_context(task, task_id, span, activation, None) + + def test_postrun_records_task_succeeded_only_for_success_state(self): + """Postrun should record task-succeeded when the task state is SUCCESS.""" + task_id = "postrun-success" + request = type( + "RequestContext", + (dict,), + {"hostname": "celery@w1", "state": "SUCCESS"}, + )({}) + task = type("Task", (), {"name": task_add.name, "request": request})() + self.instrumentor.task_id_to_start_time[task_id] = default_timer() + self.instrumentor._track_executing_task(task_id, "celery@w1") + self._attach_context_to_task(task, task_id) + + self.instrumentor._trace_postrun(task=task, task_id=task_id) + + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertIsNotNone(events) + self.assert_metric_expected( + events, + [ + self.create_number_data_point( + value=1, + attributes={ + "task": task_add.name, + "type": "task-succeeded", + "worker": "celery@w1", + }, + ) + ], + ) + + def test_postrun_skips_task_succeeded_for_non_success_state(self): + """Postrun should not record task-succeeded when the task state is not SUCCESS.""" + task_id = "postrun-failure" + request = type( + "RequestContext", + (dict,), + {"hostname": "celery@w1", "state": "FAILURE"}, + )({}) + task = type("Task", (), {"name": task_add.name, "request": request})() + self.instrumentor.task_id_to_start_time[task_id] = default_timer() + self.instrumentor._track_executing_task(task_id, "celery@w1") + self._attach_context_to_task(task, task_id) + + self.instrumentor._trace_postrun(task=task, task_id=task_id) + + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertTrue(events is None or len(events.data.data_points) == 0) + + +class TestTracePublishGuards(TestBase): + """Tests for guard paths in _trace_before_publish and _trace_after_publish.""" + + def setUp(self): + super().setUp() + self.instrumentor = CeleryInstrumentor() + self.instrumentor.instrument() + + def tearDown(self): + self.instrumentor.uninstrument() + super().tearDown() + + def test_before_publish_no_task_id_is_noop(self): + """Before publish with no task ID in headers should produce no spans.""" + self.instrumentor._trace_before_publish( + sender=None, headers={}, body=None + ) + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + + def test_before_publish_anonymous_task(self): + """When task is not found in registry, sender string is used as name.""" + self.instrumentor._trace_before_publish( + sender="some.unknown.task", + headers={"id": "pub-1"}, + body=None, + ) + # Span started but not finished (no after_publish called) + # Verify via event count instead + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertIsNotNone(events) + self.assert_metric_expected( + events, + [ + self.create_number_data_point( + value=1, + attributes={ + "task": "some.unknown.task", + "type": "task-sent", + }, + ) + ], + ) + + def test_after_publish_no_task_is_noop(self): + """After publish with no sender should produce no spans.""" + CeleryInstrumentor._trace_after_publish( + sender=None, headers={}, body=None + ) + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + + def test_after_publish_no_context_is_noop(self): + """When no span was attached, after_publish should log and return.""" + CeleryInstrumentor._trace_after_publish( + sender=task_add, headers={"id": "missing-ctx"}, body=None + ) + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + + +class TestTraceFailureGuards(TestBase): + """Tests for guard paths in _trace_failure.""" + + def setUp(self): + super().setUp() + self.instrumentor = CeleryInstrumentor() + self.instrumentor.instrument() + + def tearDown(self): + self.instrumentor.uninstrument() + super().tearDown() + + def test_failure_no_task_is_noop(self): + """Failure with no sender should not record any events.""" + self.instrumentor._trace_failure(sender=None, task_id=None, einfo=None) + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertTrue(events is None or len(events.data.data_points) == 0) + + def test_failure_no_context_is_noop(self): + """Failure with no attached span context should not record events.""" + self.instrumentor._trace_failure( + sender=task_add, task_id="no-ctx", einfo=None + ) + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertTrue(events is None or len(events.data.data_points) == 0) + + def test_failure_not_recording_is_noop(self): + """When span is not recording, failure should bail.""" + span = MagicMock() + span.is_recording.return_value = False + activation = MagicMock() + utils.attach_context(task_add, "nr-1", span, activation, None) + + self.instrumentor._trace_failure( + sender=task_add, task_id="nr-1", einfo=None + ) + span.set_status.assert_not_called() + utils.detach_context(task_add, "nr-1") + + def test_failure_task_throws_skipped(self): + """Exceptions listed in task.throws should not be recorded.""" + tracer = self.tracer_provider.get_tracer("test") + span = tracer.start_span("test") + activation = MagicMock() + activation.__enter__ = MagicMock(return_value=span) + activation.__exit__ = MagicMock(return_value=False) + + task = task_add + task_id = "throws-1" + # Temporarily set throws + original_throws = getattr(task, "throws", ()) + task.throws = (ValueError,) + utils.attach_context(task, task_id, span, activation, None) + + einfo = MagicMock() + einfo.exception = ValueError("expected") + + self.instrumentor._trace_failure( + sender=task, task_id=task_id, einfo=einfo + ) + # Status should NOT be set to ERROR for throws exceptions + self.assertNotEqual( + span.status.status_code, + 2, # StatusCode.ERROR + ) + utils.detach_context(task, task_id) + task.throws = original_throws + + +class TestMemoryLeakPrevention(TestBase): + """Tests that verify internal state dicts are cleaned up to prevent memory leaks.""" + + def setUp(self): + super().setUp() + self.instrumentor = CeleryInstrumentor() + self.instrumentor.instrument() + + def tearDown(self): + self.instrumentor.uninstrument() + super().tearDown() + + def _attach_context_to_task(self, task, task_id: str): + """Attach a live span context to a task.""" + tracer = self.tracer_provider.get_tracer("test") + span = tracer.start_span("test") + activation = MagicMock() + activation.__enter__ = MagicMock(return_value=span) + activation.__exit__ = MagicMock(return_value=False) + utils.attach_context(task, task_id, span, activation, None) + + def test_postrun_cleans_up_start_time(self): + """After postrun, task_id_to_start_time should not retain the task_id.""" + task_id = "leak-1" + request = type( + "RequestContext", + (dict,), + {"hostname": "celery@w1", "state": "SUCCESS"}, + )({}) + task = type("Task", (), {"name": task_add.name, "request": request})() + self.instrumentor.task_id_to_start_time[task_id] = default_timer() + self._attach_context_to_task(task, task_id) + + self.instrumentor._trace_postrun(task=task, task_id=task_id) + + self.assertNotIn(task_id, self.instrumentor.task_id_to_start_time) + + +class TestTraceRetryNotRecording(TestBase): + """Tests for _trace_retry when span is not recording.""" + + def setUp(self): + super().setUp() + self.instrumentor = CeleryInstrumentor() + self.instrumentor.instrument() + + def tearDown(self): + self.instrumentor.uninstrument() + super().tearDown() + + def test_trace_retry_not_recording_is_noop(self): + """Retry with a non-recording span should not set attributes or record events.""" + span = MagicMock() + span.is_recording.return_value = False + activation = MagicMock() + task_id = "retry-nr" + utils.attach_context(task_add, task_id, span, activation, None) + request = type("Request", (), {"id": task_id})() + + self.instrumentor._trace_retry( + sender=task_add, + request=request, + reason=Exception("retry reason"), + ) + + span.set_attribute.assert_not_called() + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertTrue(events is None or len(events.data.data_points) == 0) + + utils.detach_context(task_add, task_id) + + +class TestTraceFailureRecordsEvent(TestBase): + """Tests for _trace_failure happy path — exception recorded and event counted.""" + + def setUp(self): + super().setUp() + self.instrumentor = CeleryInstrumentor() + self.instrumentor.instrument() + + def tearDown(self): + self.instrumentor.uninstrument() + super().tearDown() + + def test_failure_records_exception_and_event(self): + """A genuine failure should set ERROR status and record a task-failed event.""" + tracer = self.tracer_provider.get_tracer("test") + span = tracer.start_span("test") + activation = MagicMock() + activation.__enter__ = MagicMock(return_value=span) + activation.__exit__ = MagicMock(return_value=False) + task_id = "fail-1" + utils.attach_context(task_add, task_id, span, activation, None) + + einfo = None + try: + raise RuntimeError("something broke") + except RuntimeError: + einfo = ExceptionInfo() + + self.instrumentor._trace_failure( + sender=task_add, task_id=task_id, einfo=einfo + ) + + self.assertEqual(span.status.status_code, StatusCode.ERROR) + + metrics = self.get_sorted_metrics(SCOPE) + events = _find_metric(metrics, "flower.events.total") + self.assertIsNotNone(events) + recorded_types = { + dp.attributes["type"] for dp in events.data.data_points + } + self.assertIn("task-failed", recorded_types) + + utils.detach_context(task_add, task_id) + + +class TestUninstrumentClearsState(TestBase): + """Tests that uninstrument resets all internal state.""" + + def test_uninstrument_clears_task_instrumentor_state(self): + """After uninstrument, all tracking dicts and metrics should be reset.""" + instrumentor = CeleryInstrumentor() + instrumentor.instrument() + + # Simulate some accumulated state + instrumentor.task_id_to_start_time["t1"] = 1.0 + instrumentor.executing_task_id_to_worker["t1"] = "celery@w" + + instrumentor.uninstrument() + + self.assertIsNone(instrumentor.metrics) + self.assertEqual(instrumentor.task_id_to_start_time, {}) + self.assertEqual(instrumentor.executing_task_id_to_worker, {}) + + def test_uninstrument_clears_worker_instrumentor_state(self): + """After uninstrument, online_workers, tracking dicts, and metrics should be reset.""" + instrumentor = CeleryWorkerInstrumentor() + instrumentor.instrument() + + instrumentor.online_workers.add("celery@w1") + + instrumentor.uninstrument() + + self.assertIsNone(instrumentor.metrics) + self.assertEqual(instrumentor.online_workers, set())