Skip to content

Commit 71a3c49

Browse files
committed
chore(celery): type getter and harden utility helpers
Assisted-by: GPT-5.4
1 parent a651600 commit 71a3c49

3 files changed

Lines changed: 66 additions & 37 deletions

File tree

instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,16 @@ def add(x, y):
4848
---
4949
"""
5050

51+
from __future__ import annotations
52+
5153
import logging
54+
from collections.abc import Collection, Iterable
5255
from timeit import default_timer
53-
from typing import Collection, Iterable
5456

5557
from billiard import VERSION
5658
from billiard.einfo import ExceptionInfo
5759
from celery import signals # pylint: disable=no-name-in-module
60+
from celery.worker.request import Request # pylint: disable=no-name-in-module
5861

5962
from opentelemetry import context as context_api
6063
from opentelemetry import trace
@@ -88,8 +91,8 @@ def add(x, y):
8891
_TASK_NAME_KEY = "celery.task_name"
8992

9093

91-
class CeleryGetter(Getter):
92-
def get(self, carrier, key):
94+
class CeleryGetter(Getter[Request]):
95+
def get(self, carrier: Request, key: str) -> list[str] | None:
9396
value = getattr(carrier, key, None)
9497
if value is None:
9598
return None
@@ -98,16 +101,12 @@ def get(self, carrier, key):
98101
# of ints). The TextMapPropagator contract requires string
99102
# values, so coerce anything that isn't already a string.
100103
if isinstance(value, str):
101-
value = (value,)
102-
elif isinstance(value, Iterable):
103-
value = tuple(
104-
str(v) if not isinstance(v, str) else v for v in value
105-
)
106-
else:
107-
value = (str(value),)
108-
return value
104+
return [value]
105+
if isinstance(value, Iterable):
106+
return [str(v) if not isinstance(v, str) else v for v in value]
107+
return [str(value)]
109108

110-
def keys(self, carrier):
109+
def keys(self, carrier: Request) -> list[str]:
111110
return []
112111

113112

instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from __future__ import annotations
55

66
import logging
7-
from typing import TYPE_CHECKING, Optional, Tuple
7+
from collections.abc import Mapping
8+
from typing import TYPE_CHECKING, Any, Optional, Tuple, cast
89

910
from celery import registry # pylint: disable=no-name-in-module
1011
from celery.app.task import Task
@@ -18,6 +19,9 @@
1819
if TYPE_CHECKING:
1920
from contextlib import AbstractContextManager
2021

22+
ContextTuple = Tuple[Span, "AbstractContextManager[Span]", Optional[object]]
23+
ContextDict = dict[tuple[str, bool], ContextTuple]
24+
2125
logger = logging.getLogger(__name__)
2226

2327
# Celery Context key
@@ -48,7 +52,10 @@
4852

4953

5054
# pylint:disable=too-many-branches
51-
def set_attributes_from_context(span, context):
55+
def set_attributes_from_context(
56+
span: Span,
57+
context: Mapping[str, Any],
58+
) -> None:
5259
"""Helper to extract meta values from a Celery Context"""
5360
if not span.is_recording():
5461
return
@@ -144,7 +151,7 @@ def attach_context(
144151
if task is None:
145152
return
146153

147-
ctx_dict = getattr(task, CTX_KEY, None)
154+
ctx_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None))
148155

149156
if ctx_dict is None:
150157
ctx_dict = {}
@@ -153,12 +160,17 @@ def attach_context(
153160
ctx_dict[(task_id, is_publish)] = (span, activation, token)
154161

155162

156-
def detach_context(task, task_id, is_publish=False) -> None:
163+
def detach_context(
164+
task: Optional[Task], task_id: str, is_publish: bool = False
165+
) -> None:
157166
"""Helper to remove `Span`, `ContextManager` and context token in a
158167
Celery task when it's propagated.
159168
This function handles tasks where no values are attached to the `Task`.
160169
"""
161-
span_dict = getattr(task, CTX_KEY, None)
170+
if task is None:
171+
return
172+
173+
span_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None))
162174
if span_dict is None:
163175
return
164176

@@ -167,27 +179,30 @@ def detach_context(task, task_id, is_publish=False) -> None:
167179

168180

169181
def retrieve_context(
170-
task, task_id, is_publish=False
171-
) -> Optional[Tuple[Span, AbstractContextManager[Span], Optional[object]]]:
182+
task: Optional[Task], task_id: str, is_publish: bool = False
183+
) -> Optional[ContextTuple]:
172184
"""Helper to retrieve an active `Span`, `ContextManager` and context token
173185
stored in a `Task` instance
174186
"""
175-
span_dict = getattr(task, CTX_KEY, None)
187+
if task is None:
188+
return None
189+
190+
span_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None))
176191
if span_dict is None:
177192
return None
178193

179194
# See note in `attach_context` for key info
180195
return span_dict.get((task_id, is_publish), None)
181196

182197

183-
def retrieve_task(kwargs):
198+
def retrieve_task(kwargs: Mapping[str, Any]) -> Optional[Task]:
184199
task = kwargs.get("task")
185200
if task is None:
186201
logger.debug("Unable to retrieve task from signal arguments")
187-
return task
202+
return cast(Optional[Task], task)
188203

189204

190-
def retrieve_task_from_sender(kwargs):
205+
def retrieve_task_from_sender(kwargs: Mapping[str, Any]) -> Optional[Task]:
191206
sender = kwargs.get("sender")
192207
if sender is None:
193208
logger.debug("Unable to retrieve the sender from signal arguments")
@@ -199,30 +214,31 @@ def retrieve_task_from_sender(kwargs):
199214
if sender is None:
200215
logger.debug("Unable to retrieve the task from sender=%s", sender)
201216

202-
return sender
217+
return cast(Optional[Task], sender)
203218

204219

205-
def retrieve_task_id(kwargs):
220+
def retrieve_task_id(kwargs: Mapping[str, Any]) -> Optional[str]:
206221
task_id = kwargs.get("task_id")
207222
if task_id is None:
208223
logger.debug("Unable to retrieve task_id from signal arguments")
209-
return task_id
224+
return cast(Optional[str], task_id)
210225

211226

212-
def retrieve_task_id_from_request(kwargs):
227+
def retrieve_task_id_from_request(kwargs: Mapping[str, Any]) -> Optional[str]:
213228
# retry signal does not include task_id as argument so use request argument
214229
request = kwargs.get("request")
215230
if request is None:
216231
logger.debug("Unable to retrieve the request from signal arguments")
232+
return None
217233

218-
task_id = getattr(request, "id")
234+
task_id = cast(Optional[str], getattr(request, "id", None))
219235
if task_id is None:
220236
logger.debug("Unable to retrieve the task_id from the request")
221237

222238
return task_id
223239

224240

225-
def retrieve_task_id_from_message(kwargs):
241+
def retrieve_task_id_from_message(kwargs: Mapping[str, Any]) -> Optional[str]:
226242
"""Helper to retrieve the `Task` identifier from the message `body`.
227243
This helper supports Protocol Version 1 and 2. The Protocol is well
228244
detailed in the official documentation:
@@ -232,12 +248,14 @@ def retrieve_task_id_from_message(kwargs):
232248
body = kwargs.get("body")
233249
if headers is not None and len(headers) > 0:
234250
# Protocol Version 2 (default from Celery 4.0)
235-
return headers.get("id")
251+
return cast(Optional[str], headers.get("id"))
236252
# Protocol Version 1
237-
return body.get("id")
253+
if body is None:
254+
return None
255+
return cast(Optional[str], body.get("id"))
238256

239257

240-
def retrieve_reason(kwargs):
258+
def retrieve_reason(kwargs: Mapping[str, Any]) -> Optional[object]:
241259
reason = kwargs.get("reason")
242260
if not reason:
243261
logger.debug("Unable to retrieve the retry reason")

instrumentation/opentelemetry-instrumentation-celery/tests/test_getter.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,27 @@
88

99
class TestCeleryGetter(TestCase):
1010
def test_get_none(self):
11+
"""Missing attribute on carrier should return None."""
1112
getter = CeleryGetter()
1213
carrier = {}
1314
val = getter.get(carrier, "test")
1415
self.assertIsNone(val)
1516

1617
def test_get_str(self):
18+
"""String attribute should be wrapped in a single-element list."""
1719
mock_obj = mock.Mock()
1820
getter = CeleryGetter()
1921
mock_obj.test = "val"
2022
val = getter.get(mock_obj, "test")
21-
self.assertEqual(val, ("val",))
23+
self.assertEqual(val, ["val"])
2224

2325
def test_get_iter(self):
26+
"""Iterable attribute should be returned as a list."""
2427
mock_obj = mock.Mock()
2528
getter = CeleryGetter()
2629
mock_obj.test = ["val"]
2730
val = getter.get(mock_obj, "test")
28-
self.assertEqual(val, ("val",))
31+
self.assertEqual(val, ["val"])
2932

3033
def test_get_int(self):
3134
"""Non-string scalar values should be coerced to strings.
@@ -39,7 +42,7 @@ def test_get_int(self):
3942
getter = CeleryGetter()
4043
mock_obj.test = 42
4144
val = getter.get(mock_obj, "test")
42-
self.assertEqual(val, ("42",))
45+
self.assertEqual(val, ["42"])
4346

4447
def test_get_iter_with_non_string_elements(self):
4548
"""Iterable values containing non-strings should be coerced.
@@ -50,17 +53,26 @@ def test_get_iter_with_non_string_elements(self):
5053
getter = CeleryGetter()
5154
mock_obj.test = (300, 60)
5255
val = getter.get(mock_obj, "test")
53-
self.assertEqual(val, ("300", "60"))
56+
self.assertEqual(val, ["300", "60"])
5457

5558
def test_get_iter_with_mixed_types(self):
5659
"""Iterables with a mix of strings and non-strings."""
5760
mock_obj = mock.Mock()
5861
getter = CeleryGetter()
5962
mock_obj.test = ["val", 123]
6063
val = getter.get(mock_obj, "test")
61-
self.assertEqual(val, ("val", "123"))
64+
self.assertEqual(val, ["val", "123"])
65+
66+
def test_get_non_str_non_iterable(self):
67+
"""Non-string, non-iterable value should be coerced to [str(value)]."""
68+
getter = CeleryGetter()
69+
mock_obj = mock.Mock()
70+
mock_obj.key = 42
71+
val = getter.get(mock_obj, "key")
72+
self.assertEqual(val, ["42"])
6273

6374
def test_keys(self):
75+
"""keys() should return an empty list for any carrier."""
6476
getter = CeleryGetter()
6577
keys = getter.keys({})
6678
self.assertEqual(keys, [])

0 commit comments

Comments
 (0)