Skip to content

Commit a80c7da

Browse files
authored
fix(celery): type getter and harden utility helpers (#4505)
* chore(celery): type getter and harden utility helpers
  Assisted-by: GPT-5.4
* docs: add celery request nitpick exception
  Assisted-by: GPT-5.4
* chore(celery): refine utility helper typing
  Assisted-by: GPT-5.4
* Add Celery housekeeping changelog fragment
  Assisted-by: GitHub Copilot
1 parent 65a1134 commit a80c7da

5 files changed

Lines changed: 74 additions & 37 deletions

File tree

.changelog/4505.fixed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`opentelemetry-instrumentation-celery`: add null guards and type-safe helper handling around Celery context propagation internals

docs/nitpick-exceptions.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ py-class=
3232
httpx.URL
3333
httpx.Headers
3434
aiohttp.web_request.Request
35+
celery.worker.request.Request
3536
yarl.URL
3637
cimpl.Producer
3738
cimpl.Consumer

instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,16 @@ def add(x, y):
4848
---
4949
"""
5050

51+
from __future__ import annotations
52+
5153
import logging
54+
from collections.abc import Collection, Iterable
5255
from timeit import default_timer
53-
from typing import Collection, Iterable
5456

5557
from billiard import VERSION
5658
from billiard.einfo import ExceptionInfo
5759
from celery import signals # pylint: disable=no-name-in-module
60+
from celery.worker.request import Request # pylint: disable=no-name-in-module
5861

5962
from opentelemetry import context as context_api
6063
from opentelemetry import trace
@@ -88,8 +91,8 @@ def add(x, y):
8891
_TASK_NAME_KEY = "celery.task_name"
8992

9093

91-
class CeleryGetter(Getter):
92-
def get(self, carrier, key):
94+
class CeleryGetter(Getter[Request]):
95+
def get(self, carrier: Request, key: str) -> list[str] | None:
9396
value = getattr(carrier, key, None)
9497
if value is None:
9598
return None
@@ -98,16 +101,12 @@ def get(self, carrier, key):
98101
# of ints). The TextMapPropagator contract requires string
99102
# values, so coerce anything that isn't already a string.
100103
if isinstance(value, str):
101-
value = (value,)
102-
elif isinstance(value, Iterable):
103-
value = tuple(
104-
str(v) if not isinstance(v, str) else v for v in value
105-
)
106-
else:
107-
value = (str(value),)
108-
return value
104+
return [value]
105+
if isinstance(value, Iterable):
106+
return [str(v) if not isinstance(v, str) else v for v in value]
107+
return [str(value)]
109108

110-
def keys(self, carrier):
109+
def keys(self, carrier: Request) -> list[str]:
111110
return []
112111

113112

instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from __future__ import annotations
55

66
import logging
7-
from typing import TYPE_CHECKING, Optional, Tuple
7+
from collections.abc import Mapping
8+
from typing import TYPE_CHECKING, Any, Optional, Protocol, cast
89

910
from celery import registry # pylint: disable=no-name-in-module
1011
from celery.app.task import Task
@@ -18,6 +19,15 @@
1819
if TYPE_CHECKING:
1920
from contextlib import AbstractContextManager
2021

22+
ContextKey = tuple[str, bool]
23+
ContextTuple = tuple[Span, "AbstractContextManager[Span]", object | None]
24+
ContextDict = dict[ContextKey, ContextTuple]
25+
26+
27+
class ContextCarrier(Protocol):
28+
def get(self, key: str, default: Any = None) -> Any: ...
29+
30+
2131
logger = logging.getLogger(__name__)
2232

2333
# Celery Context key
@@ -48,7 +58,10 @@
4858

4959

5060
# pylint:disable=too-many-branches
51-
def set_attributes_from_context(span, context):
61+
def set_attributes_from_context(
62+
span: Span,
63+
context: ContextCarrier,
64+
) -> None:
5265
"""Helper to extract meta values from a Celery Context"""
5366
if not span.is_recording():
5467
return
@@ -144,7 +157,7 @@ def attach_context(
144157
if task is None:
145158
return
146159

147-
ctx_dict = getattr(task, CTX_KEY, None)
160+
ctx_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None))
148161

149162
if ctx_dict is None:
150163
ctx_dict = {}
@@ -153,12 +166,17 @@ def attach_context(
153166
ctx_dict[(task_id, is_publish)] = (span, activation, token)
154167

155168

156-
def detach_context(task, task_id, is_publish=False) -> None:
169+
def detach_context(
170+
task: Optional[Task], task_id: str, is_publish: bool = False
171+
) -> None:
157172
"""Helper to remove `Span`, `ContextManager` and context token in a
158173
Celery task when it's propagated.
159174
This function handles tasks where no values are attached to the `Task`.
160175
"""
161-
span_dict = getattr(task, CTX_KEY, None)
176+
if task is None:
177+
return
178+
179+
span_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None))
162180
if span_dict is None:
163181
return
164182

@@ -167,27 +185,30 @@ def detach_context(task, task_id, is_publish=False) -> None:
167185

168186

169187
def retrieve_context(
170-
task, task_id, is_publish=False
171-
) -> Optional[Tuple[Span, AbstractContextManager[Span], Optional[object]]]:
188+
task: Optional[Task], task_id: str, is_publish: bool = False
189+
) -> Optional[ContextTuple]:
172190
"""Helper to retrieve an active `Span`, `ContextManager` and context token
173191
stored in a `Task` instance
174192
"""
175-
span_dict = getattr(task, CTX_KEY, None)
193+
if task is None:
194+
return None
195+
196+
span_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None))
176197
if span_dict is None:
177198
return None
178199

179200
# See note in `attach_context` for key info
180201
return span_dict.get((task_id, is_publish), None)
181202

182203

183-
def retrieve_task(kwargs):
204+
def retrieve_task(kwargs: Mapping[str, Any]) -> Optional[Task]:
184205
task = kwargs.get("task")
185206
if task is None:
186207
logger.debug("Unable to retrieve task from signal arguments")
187-
return task
208+
return cast(Optional[Task], task)
188209

189210

190-
def retrieve_task_from_sender(kwargs):
211+
def retrieve_task_from_sender(kwargs: Mapping[str, Any]) -> Optional[Task]:
191212
sender = kwargs.get("sender")
192213
if sender is None:
193214
logger.debug("Unable to retrieve the sender from signal arguments")
@@ -199,30 +220,31 @@ def retrieve_task_from_sender(kwargs):
199220
if sender is None:
200221
logger.debug("Unable to retrieve the task from sender=%s", sender)
201222

202-
return sender
223+
return cast(Optional[Task], sender)
203224

204225

205-
def retrieve_task_id(kwargs):
226+
def retrieve_task_id(kwargs: Mapping[str, Any]) -> Optional[str]:
206227
task_id = kwargs.get("task_id")
207228
if task_id is None:
208229
logger.debug("Unable to retrieve task_id from signal arguments")
209-
return task_id
230+
return cast(Optional[str], task_id)
210231

211232

212-
def retrieve_task_id_from_request(kwargs):
233+
def retrieve_task_id_from_request(kwargs: Mapping[str, Any]) -> Optional[str]:
213234
# retry signal does not include task_id as argument so use request argument
214235
request = kwargs.get("request")
215236
if request is None:
216237
logger.debug("Unable to retrieve the request from signal arguments")
238+
return None
217239

218-
task_id = getattr(request, "id")
240+
task_id = cast(Optional[str], getattr(request, "id", None))
219241
if task_id is None:
220242
logger.debug("Unable to retrieve the task_id from the request")
221243

222244
return task_id
223245

224246

225-
def retrieve_task_id_from_message(kwargs):
247+
def retrieve_task_id_from_message(kwargs: Mapping[str, Any]) -> Optional[str]:
226248
"""Helper to retrieve the `Task` identifier from the message `body`.
227249
This helper supports Protocol Version 1 and 2. The Protocol is well
228250
detailed in the official documentation:
@@ -232,12 +254,14 @@ def retrieve_task_id_from_message(kwargs):
232254
body = kwargs.get("body")
233255
if headers is not None and len(headers) > 0:
234256
# Protocol Version 2 (default from Celery 4.0)
235-
return headers.get("id")
257+
return cast(Optional[str], headers.get("id"))
236258
# Protocol Version 1
237-
return body.get("id")
259+
if body is None:
260+
return None
261+
return cast(Optional[str], body.get("id"))
238262

239263

240-
def retrieve_reason(kwargs):
264+
def retrieve_reason(kwargs: Mapping[str, Any]) -> Optional[object]:
241265
reason = kwargs.get("reason")
242266
if not reason:
243267
logger.debug("Unable to retrieve the retry reason")

instrumentation/opentelemetry-instrumentation-celery/tests/test_getter.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,27 @@
88

99
class TestCeleryGetter(TestCase):
1010
def test_get_none(self):
11+
"""Missing attribute on carrier should return None."""
1112
getter = CeleryGetter()
1213
carrier = {}
1314
val = getter.get(carrier, "test")
1415
self.assertIsNone(val)
1516

1617
def test_get_str(self):
18+
"""String attribute should be wrapped in a single-element list."""
1719
mock_obj = mock.Mock()
1820
getter = CeleryGetter()
1921
mock_obj.test = "val"
2022
val = getter.get(mock_obj, "test")
21-
self.assertEqual(val, ("val",))
23+
self.assertEqual(val, ["val"])
2224

2325
def test_get_iter(self):
26+
"""Iterable attribute should be returned as a list."""
2427
mock_obj = mock.Mock()
2528
getter = CeleryGetter()
2629
mock_obj.test = ["val"]
2730
val = getter.get(mock_obj, "test")
28-
self.assertEqual(val, ("val",))
31+
self.assertEqual(val, ["val"])
2932

3033
def test_get_int(self):
3134
"""Non-string scalar values should be coerced to strings.
@@ -39,7 +42,7 @@ def test_get_int(self):
3942
getter = CeleryGetter()
4043
mock_obj.test = 42
4144
val = getter.get(mock_obj, "test")
42-
self.assertEqual(val, ("42",))
45+
self.assertEqual(val, ["42"])
4346

4447
def test_get_iter_with_non_string_elements(self):
4548
"""Iterable values containing non-strings should be coerced.
@@ -50,17 +53,26 @@ def test_get_iter_with_non_string_elements(self):
5053
getter = CeleryGetter()
5154
mock_obj.test = (300, 60)
5255
val = getter.get(mock_obj, "test")
53-
self.assertEqual(val, ("300", "60"))
56+
self.assertEqual(val, ["300", "60"])
5457

5558
def test_get_iter_with_mixed_types(self):
5659
"""Iterables with a mix of strings and non-strings."""
5760
mock_obj = mock.Mock()
5861
getter = CeleryGetter()
5962
mock_obj.test = ["val", 123]
6063
val = getter.get(mock_obj, "test")
61-
self.assertEqual(val, ("val", "123"))
64+
self.assertEqual(val, ["val", "123"])
65+
66+
def test_get_non_str_non_iterable(self):
67+
"""Non-string, non-iterable value should be coerced to [str(value)]."""
68+
getter = CeleryGetter()
69+
mock_obj = mock.Mock()
70+
mock_obj.key = 42
71+
val = getter.get(mock_obj, "key")
72+
self.assertEqual(val, ["42"])
6273

6374
def test_keys(self):
75+
"""keys() should return an empty list for any carrier."""
6476
getter = CeleryGetter()
6577
keys = getter.keys({})
6678
self.assertEqual(keys, [])

0 commit comments

Comments
 (0)