Skip to content

Commit 7113a6c

Browse files
authored
Merge branch 'main' into feat/strarlette_exclude_spans
2 parents 4f616fd + a80c7da commit 7113a6c

12 files changed

Lines changed: 481 additions & 46 deletions

File tree

.changelog/4216.fixed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`opentelemetry-instrumentation-aws-lambda`: fix improper handling of header casing

.changelog/4504.fixed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`opentelemetry-instrumentation-celery`: clear completed task ids from `task_id_to_start_time`

.changelog/4505.fixed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`opentelemetry-instrumentation-celery`: add null guards and type-safe helper handling around Celery context propagation internals

docs/nitpick-exceptions.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ py-class=
3232
httpx.URL
3333
httpx.Headers
3434
aiohttp.web_request.Request
35+
celery.worker.request.Request
3536
yarl.URL
3637
cimpl.Producer
3738
cimpl.Consumer

instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def custom_event_context_extractor(lambda_event):
7373
from opentelemetry.context.context import Context
7474
from opentelemetry.instrumentation.aws_lambda.package import _instruments
7575
from opentelemetry.instrumentation.aws_lambda.version import __version__
76+
from opentelemetry.instrumentation.cidict import CIDict
7677
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
7778
from opentelemetry.instrumentation.utils import unwrap
7879
from opentelemetry.metrics import MeterProvider, get_meter_provider
@@ -165,7 +166,9 @@ def _default_event_context_extractor(lambda_event: Any) -> Context:
165166
)
166167
if not isinstance(headers, dict):
167168
headers = {}
168-
return get_global_textmap().extract(headers)
169+
return get_global_textmap().extract(
170+
CIDict(headers),
171+
)
169172

170173

171174
def _determine_parent_context(
@@ -205,20 +208,21 @@ def _set_api_gateway_v1_proxy_attributes(
205208
span.set_attribute(HTTP_METHOD, lambda_event.get("httpMethod"))
206209

207210
if lambda_event.get("headers"):
208-
if "User-Agent" in lambda_event["headers"]:
211+
headers = CIDict(lambda_event["headers"])
212+
if "User-Agent" in headers:
209213
span.set_attribute(
210214
HTTP_USER_AGENT,
211-
lambda_event["headers"]["User-Agent"],
215+
headers["User-Agent"],
212216
)
213-
if "X-Forwarded-Proto" in lambda_event["headers"]:
217+
if "X-Forwarded-Proto" in headers:
214218
span.set_attribute(
215219
HTTP_SCHEME,
216-
lambda_event["headers"]["X-Forwarded-Proto"],
220+
headers["X-Forwarded-Proto"],
217221
)
218-
if "Host" in lambda_event["headers"]:
222+
if "Host" in headers:
219223
span.set_attribute(
220224
NET_HOST_NAME,
221-
lambda_event["headers"]["Host"],
225+
headers["Host"],
222226
)
223227
if "resource" in lambda_event:
224228
span.set_attribute(HTTP_ROUTE, lambda_event["resource"])

instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_aws_lambda_instrumentation_manual.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,36 @@ def custom_event_context_extractor(lambda_event):
314314
expected_baggage=MOCK_W3C_BAGGAGE_VALUE,
315315
propagators="tracecontext,baggage",
316316
),
317+
TestCase(
318+
name="case_insensitive_headers_uppercase",
319+
custom_extractor=None,
320+
context={
321+
"headers": {
322+
TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME.upper(): MOCK_W3C_TRACE_CONTEXT_SAMPLED,
323+
TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME.upper(): f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
324+
}
325+
},
326+
expected_traceid=MOCK_W3C_TRACE_ID,
327+
expected_parentid=MOCK_W3C_PARENT_SPAN_ID,
328+
expected_trace_state_len=3,
329+
expected_state_value=MOCK_W3C_TRACE_STATE_VALUE,
330+
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED,
331+
),
332+
TestCase(
333+
name="case_insensitive_headers_mixedcase",
334+
custom_extractor=None,
335+
context={
336+
"headers": {
337+
"TraceParent": MOCK_W3C_TRACE_CONTEXT_SAMPLED,
338+
"tRaCeStAtE": f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
339+
}
340+
},
341+
expected_traceid=MOCK_W3C_TRACE_ID,
342+
expected_parentid=MOCK_W3C_PARENT_SPAN_ID,
343+
expected_trace_state_len=3,
344+
expected_state_value=MOCK_W3C_TRACE_STATE_VALUE,
345+
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED,
346+
),
317347
]
318348
for test in tests:
319349
with self.subTest(test_name=test.name):
@@ -389,6 +419,57 @@ def test_lambda_no_error_with_invalid_flush_timeout(self):
389419

390420
test_env_patch.stop()
391421

422+
def test_api_gateway_v1_attributes_case_insensitivity(self):
423+
AwsLambdaInstrumentor().instrument()
424+
425+
mock_execute_lambda(
426+
{
427+
"httpMethod": "GET",
428+
"headers": {
429+
"user-agent": "lowercase-agent",
430+
"host": "lowercase-host",
431+
"x-forwarded-proto": "http",
432+
},
433+
"resource": "/test",
434+
"requestContext": {
435+
"version": "1.0",
436+
},
437+
}
438+
)
439+
440+
spans = self.memory_exporter.get_finished_spans()
441+
self.assertEqual(len(spans), 1)
442+
span = spans[0]
443+
self.assertEqual(
444+
span.attributes.get(HTTP_USER_AGENT), "lowercase-agent"
445+
)
446+
self.assertEqual(span.attributes.get(NET_HOST_NAME), "lowercase-host")
447+
self.assertEqual(span.attributes.get(HTTP_SCHEME), "http")
448+
449+
self.memory_exporter.clear()
450+
451+
mock_execute_lambda(
452+
{
453+
"httpMethod": "GET",
454+
"headers": {
455+
"uSeR-aGeNt": "mixed-agent",
456+
"hOsT": "mixed-host",
457+
"X-fOrWaRdEd-PrOtO": "https",
458+
},
459+
"resource": "/test",
460+
"requestContext": {
461+
"version": "1.0",
462+
},
463+
}
464+
)
465+
466+
spans = self.memory_exporter.get_finished_spans()
467+
self.assertEqual(len(spans), 1)
468+
span = spans[0]
469+
self.assertEqual(span.attributes.get(HTTP_USER_AGENT), "mixed-agent")
470+
self.assertEqual(span.attributes.get(NET_HOST_NAME), "mixed-host")
471+
self.assertEqual(span.attributes.get(HTTP_SCHEME), "https")
472+
392473
def test_lambda_handles_multiple_consumers(self):
393474
test_env_patch = mock.patch.dict(
394475
"os.environ",

instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,16 @@ def add(x, y):
4848
---
4949
"""
5050

51+
from __future__ import annotations
52+
5153
import logging
54+
from collections.abc import Collection, Iterable
5255
from timeit import default_timer
53-
from typing import Collection, Iterable
5456

5557
from billiard import VERSION
5658
from billiard.einfo import ExceptionInfo
5759
from celery import signals # pylint: disable=no-name-in-module
60+
from celery.worker.request import Request # pylint: disable=no-name-in-module
5861

5962
from opentelemetry import context as context_api
6063
from opentelemetry import trace
@@ -88,8 +91,8 @@ def add(x, y):
8891
_TASK_NAME_KEY = "celery.task_name"
8992

9093

91-
class CeleryGetter(Getter):
92-
def get(self, carrier, key):
94+
class CeleryGetter(Getter[Request]):
95+
def get(self, carrier: Request, key: str) -> list[str] | None:
9396
value = getattr(carrier, key, None)
9497
if value is None:
9598
return None
@@ -98,25 +101,25 @@ def get(self, carrier, key):
98101
# of ints). The TextMapPropagator contract requires string
99102
# values, so coerce anything that isn't already a string.
100103
if isinstance(value, str):
101-
value = (value,)
102-
elif isinstance(value, Iterable):
103-
value = tuple(
104-
str(v) if not isinstance(v, str) else v for v in value
105-
)
106-
else:
107-
value = (str(value),)
108-
return value
104+
return [value]
105+
if isinstance(value, Iterable):
106+
return [str(v) if not isinstance(v, str) else v for v in value]
107+
return [str(value)]
109108

110-
def keys(self, carrier):
109+
def keys(self, carrier: Request) -> list[str]:
111110
return []
112111

113112

114113
celery_getter = CeleryGetter()
115114

116115

117116
class CeleryInstrumentor(BaseInstrumentor):
118-
metrics = None
119-
task_id_to_start_time = {}
117+
def __init__(self):
118+
super().__init__()
119+
if not hasattr(self, "metrics"):
120+
self.metrics = None
121+
if not hasattr(self, "task_id_to_start_time"):
122+
self.task_id_to_start_time = {}
120123

121124
def instrumentation_dependencies(self) -> Collection[str]:
122125
return _instruments
@@ -139,6 +142,7 @@ def _instrument(self, **kwargs):
139142
schema_url="https://opentelemetry.io/schemas/1.11.0",
140143
)
141144

145+
self.task_id_to_start_time = {}
142146
self.create_celery_metrics(meter)
143147

144148
signals.task_prerun.connect(self._trace_prerun, weak=False)
@@ -159,6 +163,7 @@ def _uninstrument(self, **kwargs):
159163
signals.after_task_publish.disconnect(self._trace_after_publish)
160164
signals.task_failure.disconnect(self._trace_failure)
161165
signals.task_retry.disconnect(self._trace_retry)
166+
self.task_id_to_start_time = {}
162167

163168
def _trace_prerun(self, *args, **kwargs):
164169
task = utils.retrieve_task(kwargs)
@@ -213,6 +218,7 @@ def _trace_postrun(self, *args, **kwargs):
213218
self.update_task_duration_time(task_id)
214219
labels = {"task": task.name, "worker": task.request.hostname}
215220
self._record_histograms(task_id, labels)
221+
self.task_id_to_start_time.pop(task_id, None)
216222
# if the process sending the task is not instrumented
217223
# there's no incoming context and no token to detach
218224
if token is not None:

0 commit comments

Comments
 (0)