Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .changelog/4505.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`opentelemetry-instrumentation-celery`: add null guards and type-safe helper handling around Celery context propagation internals
1 change: 1 addition & 0 deletions docs/nitpick-exceptions.ini
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ py-class=
httpx.URL
httpx.Headers
aiohttp.web_request.Request
celery.worker.request.Request
yarl.URL
cimpl.Producer
cimpl.Consumer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,16 @@ def add(x, y):
---
"""

from __future__ import annotations

import logging
from collections.abc import Collection, Iterable
from timeit import default_timer
from typing import Collection, Iterable

from billiard import VERSION
from billiard.einfo import ExceptionInfo
from celery import signals # pylint: disable=no-name-in-module
from celery.worker.request import Request # pylint: disable=no-name-in-module

from opentelemetry import context as context_api
from opentelemetry import trace
Expand Down Expand Up @@ -88,8 +91,8 @@ def add(x, y):
_TASK_NAME_KEY = "celery.task_name"


class CeleryGetter(Getter):
def get(self, carrier, key):
class CeleryGetter(Getter[Request]):
def get(self, carrier: Request, key: str) -> list[str] | None:
value = getattr(carrier, key, None)
if value is None:
return None
Expand All @@ -98,16 +101,12 @@ def get(self, carrier, key):
# of ints). The TextMapPropagator contract requires string
# values, so coerce anything that isn't already a string.
if isinstance(value, str):
value = (value,)
elif isinstance(value, Iterable):
value = tuple(
str(v) if not isinstance(v, str) else v for v in value
)
else:
value = (str(value),)
return value
return [value]
if isinstance(value, Iterable):
return [str(v) if not isinstance(v, str) else v for v in value]
return [str(value)]

def keys(self, carrier):
def keys(self, carrier: Request) -> list[str]:
return []


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Optional, Tuple
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Optional, Protocol, cast

from celery import registry # pylint: disable=no-name-in-module
from celery.app.task import Task
Expand All @@ -18,6 +19,15 @@
if TYPE_CHECKING:
from contextlib import AbstractContextManager

ContextKey = tuple[str, bool]
ContextTuple = tuple[Span, "AbstractContextManager[Span]", object | None]
ContextDict = dict[ContextKey, ContextTuple]


class ContextCarrier(Protocol):
def get(self, key: str, default: Any = None) -> Any: ...


logger = logging.getLogger(__name__)

# Celery Context key
Expand Down Expand Up @@ -48,7 +58,10 @@


# pylint:disable=too-many-branches
def set_attributes_from_context(span, context):
def set_attributes_from_context(
span: Span,
context: ContextCarrier,
) -> None:
"""Helper to extract meta values from a Celery Context"""
if not span.is_recording():
return
Expand Down Expand Up @@ -144,7 +157,7 @@ def attach_context(
if task is None:
return

ctx_dict = getattr(task, CTX_KEY, None)
ctx_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None))

if ctx_dict is None:
ctx_dict = {}
Expand All @@ -153,12 +166,17 @@ def attach_context(
ctx_dict[(task_id, is_publish)] = (span, activation, token)


def detach_context(task, task_id, is_publish=False) -> None:
def detach_context(
task: Optional[Task], task_id: str, is_publish: bool = False
) -> None:
"""Helper to remove `Span`, `ContextManager` and context token in a
Celery task when it's propagated.
This function handles tasks where no values are attached to the `Task`.
"""
span_dict = getattr(task, CTX_KEY, None)
if task is None:
return

span_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None))
if span_dict is None:
return

Expand All @@ -167,27 +185,30 @@ def detach_context(task, task_id, is_publish=False) -> None:


def retrieve_context(
task, task_id, is_publish=False
) -> Optional[Tuple[Span, AbstractContextManager[Span], Optional[object]]]:
task: Optional[Task], task_id: str, is_publish: bool = False
) -> Optional[ContextTuple]:
"""Helper to retrieve an active `Span`, `ContextManager` and context token
stored in a `Task` instance
"""
span_dict = getattr(task, CTX_KEY, None)
if task is None:
return None

span_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None))
if span_dict is None:
return None

# See note in `attach_context` for key info
return span_dict.get((task_id, is_publish), None)


def retrieve_task(kwargs):
def retrieve_task(kwargs: Mapping[str, Any]) -> Optional[Task]:
task = kwargs.get("task")
if task is None:
logger.debug("Unable to retrieve task from signal arguments")
return task
return cast(Optional[Task], task)


def retrieve_task_from_sender(kwargs):
def retrieve_task_from_sender(kwargs: Mapping[str, Any]) -> Optional[Task]:
sender = kwargs.get("sender")
if sender is None:
logger.debug("Unable to retrieve the sender from signal arguments")
Expand All @@ -199,30 +220,31 @@ def retrieve_task_from_sender(kwargs):
if sender is None:
logger.debug("Unable to retrieve the task from sender=%s", sender)

return sender
return cast(Optional[Task], sender)


def retrieve_task_id(kwargs):
def retrieve_task_id(kwargs: Mapping[str, Any]) -> Optional[str]:
task_id = kwargs.get("task_id")
if task_id is None:
logger.debug("Unable to retrieve task_id from signal arguments")
return task_id
return cast(Optional[str], task_id)


def retrieve_task_id_from_request(kwargs):
def retrieve_task_id_from_request(kwargs: Mapping[str, Any]) -> Optional[str]:
# retry signal does not include task_id as argument so use request argument
request = kwargs.get("request")
if request is None:
logger.debug("Unable to retrieve the request from signal arguments")
return None

task_id = getattr(request, "id")
task_id = cast(Optional[str], getattr(request, "id", None))
if task_id is None:
logger.debug("Unable to retrieve the task_id from the request")

return task_id


def retrieve_task_id_from_message(kwargs):
def retrieve_task_id_from_message(kwargs: Mapping[str, Any]) -> Optional[str]:
"""Helper to retrieve the `Task` identifier from the message `body`.
This helper supports Protocol Version 1 and 2. The Protocol is well
detailed in the official documentation:
Expand All @@ -232,12 +254,14 @@ def retrieve_task_id_from_message(kwargs):
body = kwargs.get("body")
if headers is not None and len(headers) > 0:
# Protocol Version 2 (default from Celery 4.0)
return headers.get("id")
return cast(Optional[str], headers.get("id"))
# Protocol Version 1
return body.get("id")
if body is None:
return None
return cast(Optional[str], body.get("id"))


def retrieve_reason(kwargs):
def retrieve_reason(kwargs: Mapping[str, Any]) -> Optional[object]:
reason = kwargs.get("reason")
if not reason:
logger.debug("Unable to retrieve the retry reason")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,27 @@

class TestCeleryGetter(TestCase):
def test_get_none(self):
"""Missing attribute on carrier should return None."""
getter = CeleryGetter()
carrier = {}
val = getter.get(carrier, "test")
self.assertIsNone(val)

def test_get_str(self):
"""String attribute should be wrapped in a single-element list."""
mock_obj = mock.Mock()
getter = CeleryGetter()
mock_obj.test = "val"
val = getter.get(mock_obj, "test")
self.assertEqual(val, ("val",))
self.assertEqual(val, ["val"])

def test_get_iter(self):
"""Iterable attribute should be returned as a list."""
mock_obj = mock.Mock()
getter = CeleryGetter()
mock_obj.test = ["val"]
val = getter.get(mock_obj, "test")
self.assertEqual(val, ("val",))
self.assertEqual(val, ["val"])

def test_get_int(self):
"""Non-string scalar values should be coerced to strings.
Expand All @@ -39,7 +42,7 @@ def test_get_int(self):
getter = CeleryGetter()
mock_obj.test = 42
val = getter.get(mock_obj, "test")
self.assertEqual(val, ("42",))
self.assertEqual(val, ["42"])

def test_get_iter_with_non_string_elements(self):
"""Iterable values containing non-strings should be coerced.
Expand All @@ -50,17 +53,26 @@ def test_get_iter_with_non_string_elements(self):
getter = CeleryGetter()
mock_obj.test = (300, 60)
val = getter.get(mock_obj, "test")
self.assertEqual(val, ("300", "60"))
self.assertEqual(val, ["300", "60"])

def test_get_iter_with_mixed_types(self):
"""Iterables with a mix of strings and non-strings."""
mock_obj = mock.Mock()
getter = CeleryGetter()
mock_obj.test = ["val", 123]
val = getter.get(mock_obj, "test")
self.assertEqual(val, ("val", "123"))
self.assertEqual(val, ["val", "123"])

def test_get_non_str_non_iterable(self):
"""Non-string, non-iterable value should be coerced to [str(value)]."""
getter = CeleryGetter()
mock_obj = mock.Mock()
mock_obj.key = 42
val = getter.get(mock_obj, "key")
self.assertEqual(val, ["42"])

def test_keys(self):
"""keys() should return an empty list for any carrier."""
getter = CeleryGetter()
keys = getter.keys({})
self.assertEqual(keys, [])
Loading