diff --git a/.changelog/4505.fixed b/.changelog/4505.fixed new file mode 100644 index 0000000000..392e04d041 --- /dev/null +++ b/.changelog/4505.fixed @@ -0,0 +1 @@ +`opentelemetry-instrumentation-celery`: add null guards and type-safe helper handling around Celery context propagation internals \ No newline at end of file diff --git a/docs/nitpick-exceptions.ini b/docs/nitpick-exceptions.ini index 128e406a70..6b858a3726 100644 --- a/docs/nitpick-exceptions.ini +++ b/docs/nitpick-exceptions.ini @@ -32,6 +32,7 @@ py-class= httpx.URL httpx.Headers aiohttp.web_request.Request + celery.worker.request.Request yarl.URL cimpl.Producer cimpl.Consumer diff --git a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py index 1ebae832b6..e6b1d37fbf 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py @@ -48,13 +48,16 @@ def add(x, y): --- """ +from __future__ import annotations + import logging +from collections.abc import Collection, Iterable from timeit import default_timer -from typing import Collection, Iterable from billiard import VERSION from billiard.einfo import ExceptionInfo from celery import signals # pylint: disable=no-name-in-module +from celery.worker.request import Request # pylint: disable=no-name-in-module from opentelemetry import context as context_api from opentelemetry import trace @@ -88,8 +91,8 @@ def add(x, y): _TASK_NAME_KEY = "celery.task_name" -class CeleryGetter(Getter): - def get(self, carrier, key): +class CeleryGetter(Getter[Request]): + def get(self, carrier: Request, key: str) -> list[str] | None: value = getattr(carrier, key, None) if value is None: return None @@ -98,16 +101,12 @@ def get(self, carrier, key): # of ints). The TextMapPropagator contract requires string # values, so coerce anything that isn't already a string. if isinstance(value, str): - value = (value,) - elif isinstance(value, Iterable): - value = tuple( - str(v) if not isinstance(v, str) else v for v in value - ) - else: - value = (str(value),) - return value + return [value] + if isinstance(value, Iterable): + return [str(v) if not isinstance(v, str) else v for v in value] + return [str(value)] - def keys(self, carrier): + def keys(self, carrier: Request) -> list[str]: return [] diff --git a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py index a1882b2363..be9d073f6a 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py +++ b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py @@ -4,7 +4,8 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, Tuple +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Optional, Protocol, cast from celery import registry # pylint: disable=no-name-in-module from celery.app.task import Task @@ -18,6 +19,15 @@ if TYPE_CHECKING: from contextlib import AbstractContextManager +ContextKey = tuple[str, bool] +ContextTuple = tuple[Span, "AbstractContextManager[Span]", object | None] +ContextDict = dict[ContextKey, ContextTuple] + + +class ContextCarrier(Protocol): + def get(self, key: str, default: Any = None) -> Any: ... + + logger = logging.getLogger(__name__) # Celery Context key @@ -48,7 +58,10 @@ # pylint:disable=too-many-branches -def set_attributes_from_context(span, context): +def set_attributes_from_context( + span: Span, + context: ContextCarrier, +) -> None: """Helper to extract meta values from a Celery Context""" if not span.is_recording(): return @@ -144,7 +157,7 @@ def attach_context( if task is None: return - ctx_dict = getattr(task, CTX_KEY, None) + ctx_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None)) if ctx_dict is None: ctx_dict = {} @@ -153,12 +166,17 @@ def attach_context( ctx_dict[(task_id, is_publish)] = (span, activation, token) -def detach_context(task, task_id, is_publish=False) -> None: +def detach_context( + task: Optional[Task], task_id: str, is_publish: bool = False +) -> None: """Helper to remove `Span`, `ContextManager` and context token in a Celery task when it's propagated. This function handles tasks where no values are attached to the `Task`. """ - span_dict = getattr(task, CTX_KEY, None) + if task is None: + return + + span_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None)) if span_dict is None: return @@ -167,12 +185,15 @@ def detach_context(task, task_id, is_publish=False) -> None: def retrieve_context( - task, task_id, is_publish=False -) -> Optional[Tuple[Span, AbstractContextManager[Span], Optional[object]]]: + task: Optional[Task], task_id: str, is_publish: bool = False +) -> Optional[ContextTuple]: """Helper to retrieve an active `Span`, `ContextManager` and context token stored in a `Task` instance """ - span_dict = getattr(task, CTX_KEY, None) + if task is None: + return None + + span_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None)) if span_dict is None: return None @@ -180,14 +201,14 @@ def retrieve_context( return span_dict.get((task_id, is_publish), None) -def retrieve_task(kwargs): +def retrieve_task(kwargs: Mapping[str, Any]) -> Optional[Task]: task = kwargs.get("task") if task is None: logger.debug("Unable to retrieve task from signal arguments") - return task + return cast(Optional[Task], task) -def retrieve_task_from_sender(kwargs): +def retrieve_task_from_sender(kwargs: Mapping[str, Any]) -> Optional[Task]: sender = kwargs.get("sender") if sender is None: logger.debug("Unable to retrieve the sender from signal arguments") @@ -199,30 +220,31 @@ def retrieve_task_from_sender(kwargs): if sender is None: logger.debug("Unable to retrieve the task from sender=%s", sender) - return sender + return cast(Optional[Task], sender) -def retrieve_task_id(kwargs): +def retrieve_task_id(kwargs: Mapping[str, Any]) -> Optional[str]: task_id = kwargs.get("task_id") if task_id is None: logger.debug("Unable to retrieve task_id from signal arguments") - return task_id + return cast(Optional[str], task_id) -def retrieve_task_id_from_request(kwargs): +def retrieve_task_id_from_request(kwargs: Mapping[str, Any]) -> Optional[str]: # retry signal does not include task_id as argument so use request argument request = kwargs.get("request") if request is None: logger.debug("Unable to retrieve the request from signal arguments") + return None - task_id = getattr(request, "id") + task_id = cast(Optional[str], getattr(request, "id", None)) if task_id is None: logger.debug("Unable to retrieve the task_id from the request") return task_id -def retrieve_task_id_from_message(kwargs): +def retrieve_task_id_from_message(kwargs: Mapping[str, Any]) -> Optional[str]: """Helper to retrieve the `Task` identifier from the message `body`. This helper supports Protocol Version 1 and 2. The Protocol is well detailed in the official documentation: @@ -232,12 +254,14 @@ def retrieve_task_id_from_message(kwargs): body = kwargs.get("body") if headers is not None and len(headers) > 0: # Protocol Version 2 (default from Celery 4.0) - return headers.get("id") + return cast(Optional[str], headers.get("id")) # Protocol Version 1 - return body.get("id") + if body is None: + return None + return cast(Optional[str], body.get("id")) -def retrieve_reason(kwargs): +def retrieve_reason(kwargs: Mapping[str, Any]) -> Optional[object]: reason = kwargs.get("reason") if not reason: logger.debug("Unable to retrieve the retry reason") diff --git a/instrumentation/opentelemetry-instrumentation-celery/tests/test_getter.py b/instrumentation/opentelemetry-instrumentation-celery/tests/test_getter.py index fd8189030b..44c6ece596 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/tests/test_getter.py +++ b/instrumentation/opentelemetry-instrumentation-celery/tests/test_getter.py @@ -8,24 +8,27 @@ class TestCeleryGetter(TestCase): def test_get_none(self): + """Missing attribute on carrier should return None.""" getter = CeleryGetter() carrier = {} val = getter.get(carrier, "test") self.assertIsNone(val) def test_get_str(self): + """String attribute should be wrapped in a single-element list.""" mock_obj = mock.Mock() getter = CeleryGetter() mock_obj.test = "val" val = getter.get(mock_obj, "test") - self.assertEqual(val, ("val",)) + self.assertEqual(val, ["val"]) def test_get_iter(self): + """Iterable attribute should be returned as a list.""" mock_obj = mock.Mock() getter = CeleryGetter() mock_obj.test = ["val"] val = getter.get(mock_obj, "test") - self.assertEqual(val, ("val",)) + self.assertEqual(val, ["val"]) def test_get_int(self): """Non-string scalar values should be coerced to strings. @@ -39,7 +42,7 @@ def test_get_int(self): getter = CeleryGetter() mock_obj.test = 42 val = getter.get(mock_obj, "test") - self.assertEqual(val, ("42",)) + self.assertEqual(val, ["42"]) def test_get_iter_with_non_string_elements(self): """Iterable values containing non-strings should be coerced. @@ -50,7 +53,7 @@ def test_get_iter_with_non_string_elements(self): getter = CeleryGetter() mock_obj.test = (300, 60) val = getter.get(mock_obj, "test") - self.assertEqual(val, ("300", "60")) + self.assertEqual(val, ["300", "60"]) def test_get_iter_with_mixed_types(self): """Iterables with a mix of strings and non-strings.""" @@ -58,9 +61,18 @@ def test_get_iter_with_mixed_types(self): getter = CeleryGetter() mock_obj.test = ["val", 123] val = getter.get(mock_obj, "test") - self.assertEqual(val, ("val", "123")) + self.assertEqual(val, ["val", "123"]) + + def test_get_non_str_non_iterable(self): + """Non-string, non-iterable value should be coerced to [str(value)].""" + getter = CeleryGetter() + mock_obj = mock.Mock() + mock_obj.key = 42 + val = getter.get(mock_obj, "key") + self.assertEqual(val, ["42"]) def test_keys(self): + """keys() should return an empty list for any carrier.""" getter = CeleryGetter() keys = getter.keys({}) self.assertEqual(keys, [])