Skip to content

Commit 2cbd8e9

Browse files
authored
Merge branch 'main' into move-to-mocket
2 parents 0343d57 + a80c7da commit 2cbd8e9

10 files changed

Lines changed: 453 additions & 44 deletions

File tree

.changelog/4216.fixed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`opentelemetry-instrumentation-aws-lambda`: fix improper handling of header casing

.changelog/4505.fixed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`opentelemetry-instrumentation-celery`: add null guards and type-safe helper handling around Celery context propagation internals

docs/nitpick-exceptions.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ py-class=
3232
httpx.URL
3333
httpx.Headers
3434
aiohttp.web_request.Request
35+
celery.worker.request.Request
3536
yarl.URL
3637
cimpl.Producer
3738
cimpl.Consumer

instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def custom_event_context_extractor(lambda_event):
7373
from opentelemetry.context.context import Context
7474
from opentelemetry.instrumentation.aws_lambda.package import _instruments
7575
from opentelemetry.instrumentation.aws_lambda.version import __version__
76+
from opentelemetry.instrumentation.cidict import CIDict
7677
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
7778
from opentelemetry.instrumentation.utils import unwrap
7879
from opentelemetry.metrics import MeterProvider, get_meter_provider
@@ -165,7 +166,9 @@ def _default_event_context_extractor(lambda_event: Any) -> Context:
165166
)
166167
if not isinstance(headers, dict):
167168
headers = {}
168-
return get_global_textmap().extract(headers)
169+
return get_global_textmap().extract(
170+
CIDict(headers),
171+
)
169172

170173

171174
def _determine_parent_context(
@@ -205,20 +208,21 @@ def _set_api_gateway_v1_proxy_attributes(
205208
span.set_attribute(HTTP_METHOD, lambda_event.get("httpMethod"))
206209

207210
if lambda_event.get("headers"):
208-
if "User-Agent" in lambda_event["headers"]:
211+
headers = CIDict(lambda_event["headers"])
212+
if "User-Agent" in headers:
209213
span.set_attribute(
210214
HTTP_USER_AGENT,
211-
lambda_event["headers"]["User-Agent"],
215+
headers["User-Agent"],
212216
)
213-
if "X-Forwarded-Proto" in lambda_event["headers"]:
217+
if "X-Forwarded-Proto" in headers:
214218
span.set_attribute(
215219
HTTP_SCHEME,
216-
lambda_event["headers"]["X-Forwarded-Proto"],
220+
headers["X-Forwarded-Proto"],
217221
)
218-
if "Host" in lambda_event["headers"]:
222+
if "Host" in headers:
219223
span.set_attribute(
220224
NET_HOST_NAME,
221-
lambda_event["headers"]["Host"],
225+
headers["Host"],
222226
)
223227
if "resource" in lambda_event:
224228
span.set_attribute(HTTP_ROUTE, lambda_event["resource"])

instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_aws_lambda_instrumentation_manual.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,36 @@ def custom_event_context_extractor(lambda_event):
314314
expected_baggage=MOCK_W3C_BAGGAGE_VALUE,
315315
propagators="tracecontext,baggage",
316316
),
317+
TestCase(
318+
name="case_insensitive_headers_uppercase",
319+
custom_extractor=None,
320+
context={
321+
"headers": {
322+
TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME.upper(): MOCK_W3C_TRACE_CONTEXT_SAMPLED,
323+
TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME.upper(): f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
324+
}
325+
},
326+
expected_traceid=MOCK_W3C_TRACE_ID,
327+
expected_parentid=MOCK_W3C_PARENT_SPAN_ID,
328+
expected_trace_state_len=3,
329+
expected_state_value=MOCK_W3C_TRACE_STATE_VALUE,
330+
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED,
331+
),
332+
TestCase(
333+
name="case_insensitive_headers_mixedcase",
334+
custom_extractor=None,
335+
context={
336+
"headers": {
337+
"TraceParent": MOCK_W3C_TRACE_CONTEXT_SAMPLED,
338+
"tRaCeStAtE": f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
339+
}
340+
},
341+
expected_traceid=MOCK_W3C_TRACE_ID,
342+
expected_parentid=MOCK_W3C_PARENT_SPAN_ID,
343+
expected_trace_state_len=3,
344+
expected_state_value=MOCK_W3C_TRACE_STATE_VALUE,
345+
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED,
346+
),
317347
]
318348
for test in tests:
319349
with self.subTest(test_name=test.name):
@@ -389,6 +419,57 @@ def test_lambda_no_error_with_invalid_flush_timeout(self):
389419

390420
test_env_patch.stop()
391421

422+
def test_api_gateway_v1_attributes_case_insensitivity(self):
423+
AwsLambdaInstrumentor().instrument()
424+
425+
mock_execute_lambda(
426+
{
427+
"httpMethod": "GET",
428+
"headers": {
429+
"user-agent": "lowercase-agent",
430+
"host": "lowercase-host",
431+
"x-forwarded-proto": "http",
432+
},
433+
"resource": "/test",
434+
"requestContext": {
435+
"version": "1.0",
436+
},
437+
}
438+
)
439+
440+
spans = self.memory_exporter.get_finished_spans()
441+
self.assertEqual(len(spans), 1)
442+
span = spans[0]
443+
self.assertEqual(
444+
span.attributes.get(HTTP_USER_AGENT), "lowercase-agent"
445+
)
446+
self.assertEqual(span.attributes.get(NET_HOST_NAME), "lowercase-host")
447+
self.assertEqual(span.attributes.get(HTTP_SCHEME), "http")
448+
449+
self.memory_exporter.clear()
450+
451+
mock_execute_lambda(
452+
{
453+
"httpMethod": "GET",
454+
"headers": {
455+
"uSeR-aGeNt": "mixed-agent",
456+
"hOsT": "mixed-host",
457+
"X-fOrWaRdEd-PrOtO": "https",
458+
},
459+
"resource": "/test",
460+
"requestContext": {
461+
"version": "1.0",
462+
},
463+
}
464+
)
465+
466+
spans = self.memory_exporter.get_finished_spans()
467+
self.assertEqual(len(spans), 1)
468+
span = spans[0]
469+
self.assertEqual(span.attributes.get(HTTP_USER_AGENT), "mixed-agent")
470+
self.assertEqual(span.attributes.get(NET_HOST_NAME), "mixed-host")
471+
self.assertEqual(span.attributes.get(HTTP_SCHEME), "https")
472+
392473
def test_lambda_handles_multiple_consumers(self):
393474
test_env_patch = mock.patch.dict(
394475
"os.environ",

instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,16 @@ def add(x, y):
4848
---
4949
"""
5050

51+
from __future__ import annotations
52+
5153
import logging
54+
from collections.abc import Collection, Iterable
5255
from timeit import default_timer
53-
from typing import Collection, Iterable
5456

5557
from billiard import VERSION
5658
from billiard.einfo import ExceptionInfo
5759
from celery import signals # pylint: disable=no-name-in-module
60+
from celery.worker.request import Request # pylint: disable=no-name-in-module
5861

5962
from opentelemetry import context as context_api
6063
from opentelemetry import trace
@@ -88,8 +91,8 @@ def add(x, y):
8891
_TASK_NAME_KEY = "celery.task_name"
8992

9093

91-
class CeleryGetter(Getter):
92-
def get(self, carrier, key):
94+
class CeleryGetter(Getter[Request]):
95+
def get(self, carrier: Request, key: str) -> list[str] | None:
9396
value = getattr(carrier, key, None)
9497
if value is None:
9598
return None
@@ -98,16 +101,12 @@ def get(self, carrier, key):
98101
# of ints). The TextMapPropagator contract requires string
99102
# values, so coerce anything that isn't already a string.
100103
if isinstance(value, str):
101-
value = (value,)
102-
elif isinstance(value, Iterable):
103-
value = tuple(
104-
str(v) if not isinstance(v, str) else v for v in value
105-
)
106-
else:
107-
value = (str(value),)
108-
return value
104+
return [value]
105+
if isinstance(value, Iterable):
106+
return [str(v) if not isinstance(v, str) else v for v in value]
107+
return [str(value)]
109108

110-
def keys(self, carrier):
109+
def keys(self, carrier: Request) -> list[str]:
111110
return []
112111

113112

instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from __future__ import annotations
55

66
import logging
7-
from typing import TYPE_CHECKING, Optional, Tuple
7+
from collections.abc import Mapping
8+
from typing import TYPE_CHECKING, Any, Optional, Protocol, cast
89

910
from celery import registry # pylint: disable=no-name-in-module
1011
from celery.app.task import Task
@@ -18,6 +19,15 @@
1819
if TYPE_CHECKING:
1920
from contextlib import AbstractContextManager
2021

22+
ContextKey = tuple[str, bool]
23+
ContextTuple = tuple[Span, "AbstractContextManager[Span]", object | None]
24+
ContextDict = dict[ContextKey, ContextTuple]
25+
26+
27+
class ContextCarrier(Protocol):
28+
def get(self, key: str, default: Any = None) -> Any: ...
29+
30+
2131
logger = logging.getLogger(__name__)
2232

2333
# Celery Context key
@@ -48,7 +58,10 @@
4858

4959

5060
# pylint:disable=too-many-branches
51-
def set_attributes_from_context(span, context):
61+
def set_attributes_from_context(
62+
span: Span,
63+
context: ContextCarrier,
64+
) -> None:
5265
"""Helper to extract meta values from a Celery Context"""
5366
if not span.is_recording():
5467
return
@@ -144,7 +157,7 @@ def attach_context(
144157
if task is None:
145158
return
146159

147-
ctx_dict = getattr(task, CTX_KEY, None)
160+
ctx_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None))
148161

149162
if ctx_dict is None:
150163
ctx_dict = {}
@@ -153,12 +166,17 @@ def attach_context(
153166
ctx_dict[(task_id, is_publish)] = (span, activation, token)
154167

155168

156-
def detach_context(task, task_id, is_publish=False) -> None:
169+
def detach_context(
170+
task: Optional[Task], task_id: str, is_publish: bool = False
171+
) -> None:
157172
"""Helper to remove `Span`, `ContextManager` and context token in a
158173
Celery task when it's propagated.
159174
This function handles tasks where no values are attached to the `Task`.
160175
"""
161-
span_dict = getattr(task, CTX_KEY, None)
176+
if task is None:
177+
return
178+
179+
span_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None))
162180
if span_dict is None:
163181
return
164182

@@ -167,27 +185,30 @@ def detach_context(task, task_id, is_publish=False) -> None:
167185

168186

169187
def retrieve_context(
170-
task, task_id, is_publish=False
171-
) -> Optional[Tuple[Span, AbstractContextManager[Span], Optional[object]]]:
188+
task: Optional[Task], task_id: str, is_publish: bool = False
189+
) -> Optional[ContextTuple]:
172190
"""Helper to retrieve an active `Span`, `ContextManager` and context token
173191
stored in a `Task` instance
174192
"""
175-
span_dict = getattr(task, CTX_KEY, None)
193+
if task is None:
194+
return None
195+
196+
span_dict = cast(Optional[ContextDict], getattr(task, CTX_KEY, None))
176197
if span_dict is None:
177198
return None
178199

179200
# See note in `attach_context` for key info
180201
return span_dict.get((task_id, is_publish), None)
181202

182203

183-
def retrieve_task(kwargs):
204+
def retrieve_task(kwargs: Mapping[str, Any]) -> Optional[Task]:
184205
task = kwargs.get("task")
185206
if task is None:
186207
logger.debug("Unable to retrieve task from signal arguments")
187-
return task
208+
return cast(Optional[Task], task)
188209

189210

190-
def retrieve_task_from_sender(kwargs):
211+
def retrieve_task_from_sender(kwargs: Mapping[str, Any]) -> Optional[Task]:
191212
sender = kwargs.get("sender")
192213
if sender is None:
193214
logger.debug("Unable to retrieve the sender from signal arguments")
@@ -199,30 +220,31 @@ def retrieve_task_from_sender(kwargs):
199220
if sender is None:
200221
logger.debug("Unable to retrieve the task from sender=%s", sender)
201222

202-
return sender
223+
return cast(Optional[Task], sender)
203224

204225

205-
def retrieve_task_id(kwargs):
226+
def retrieve_task_id(kwargs: Mapping[str, Any]) -> Optional[str]:
206227
task_id = kwargs.get("task_id")
207228
if task_id is None:
208229
logger.debug("Unable to retrieve task_id from signal arguments")
209-
return task_id
230+
return cast(Optional[str], task_id)
210231

211232

212-
def retrieve_task_id_from_request(kwargs):
233+
def retrieve_task_id_from_request(kwargs: Mapping[str, Any]) -> Optional[str]:
213234
# retry signal does not include task_id as argument so use request argument
214235
request = kwargs.get("request")
215236
if request is None:
216237
logger.debug("Unable to retrieve the request from signal arguments")
238+
return None
217239

218-
task_id = getattr(request, "id")
240+
task_id = cast(Optional[str], getattr(request, "id", None))
219241
if task_id is None:
220242
logger.debug("Unable to retrieve the task_id from the request")
221243

222244
return task_id
223245

224246

225-
def retrieve_task_id_from_message(kwargs):
247+
def retrieve_task_id_from_message(kwargs: Mapping[str, Any]) -> Optional[str]:
226248
"""Helper to retrieve the `Task` identifier from the message `body`.
227249
This helper supports Protocol Version 1 and 2. The Protocol is well
228250
detailed in the official documentation:
@@ -232,12 +254,14 @@ def retrieve_task_id_from_message(kwargs):
232254
body = kwargs.get("body")
233255
if headers is not None and len(headers) > 0:
234256
# Protocol Version 2 (default from Celery 4.0)
235-
return headers.get("id")
257+
return cast(Optional[str], headers.get("id"))
236258
# Protocol Version 1
237-
return body.get("id")
259+
if body is None:
260+
return None
261+
return cast(Optional[str], body.get("id"))
238262

239263

240-
def retrieve_reason(kwargs):
264+
def retrieve_reason(kwargs: Mapping[str, Any]) -> Optional[object]:
241265
reason = kwargs.get("reason")
242266
if not reason:
243267
logger.debug("Unable to retrieve the retry reason")

0 commit comments

Comments
 (0)