Skip to content

Commit 5a4052b

Browse files
authored
Merge branch 'main' into chore/celery-housekeeping
2 parents 0e7ec3a + 65a1134 commit 5a4052b

10 files changed

Lines changed: 409 additions & 11 deletions

File tree

.changelog/4216.fixed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`opentelemetry-instrumentation-aws-lambda`: fix improper handling of header casing

.changelog/4504.fixed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`opentelemetry-instrumentation-celery`: clear completed task ids from `task_id_to_start_time`

.github/workflows/changelog.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030

3131
- name: Ensure no direct changes to CHANGELOG.md
3232
run: |
33-
if [[ $(git diff --name-only FETCH_HEAD -- '**/CHANGELOG.md') ]]
33+
if [[ $(git diff --name-only FETCH_HEAD -- 'CHANGELOG.md' '**/CHANGELOG.md') ]]
3434
then
3535
echo "CHANGELOG.md files should not be directly modified."
3636
echo "Please add a changelog fragment file to the appropriate .changelog/ directory instead."

instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def custom_event_context_extractor(lambda_event):
7373
from opentelemetry.context.context import Context
7474
from opentelemetry.instrumentation.aws_lambda.package import _instruments
7575
from opentelemetry.instrumentation.aws_lambda.version import __version__
76+
from opentelemetry.instrumentation.cidict import CIDict
7677
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
7778
from opentelemetry.instrumentation.utils import unwrap
7879
from opentelemetry.metrics import MeterProvider, get_meter_provider
@@ -165,7 +166,9 @@ def _default_event_context_extractor(lambda_event: Any) -> Context:
165166
)
166167
if not isinstance(headers, dict):
167168
headers = {}
168-
return get_global_textmap().extract(headers)
169+
return get_global_textmap().extract(
170+
CIDict(headers),
171+
)
169172

170173

171174
def _determine_parent_context(
@@ -205,20 +208,21 @@ def _set_api_gateway_v1_proxy_attributes(
205208
span.set_attribute(HTTP_METHOD, lambda_event.get("httpMethod"))
206209

207210
if lambda_event.get("headers"):
208-
if "User-Agent" in lambda_event["headers"]:
211+
headers = CIDict(lambda_event["headers"])
212+
if "User-Agent" in headers:
209213
span.set_attribute(
210214
HTTP_USER_AGENT,
211-
lambda_event["headers"]["User-Agent"],
215+
headers["User-Agent"],
212216
)
213-
if "X-Forwarded-Proto" in lambda_event["headers"]:
217+
if "X-Forwarded-Proto" in headers:
214218
span.set_attribute(
215219
HTTP_SCHEME,
216-
lambda_event["headers"]["X-Forwarded-Proto"],
220+
headers["X-Forwarded-Proto"],
217221
)
218-
if "Host" in lambda_event["headers"]:
222+
if "Host" in headers:
219223
span.set_attribute(
220224
NET_HOST_NAME,
221-
lambda_event["headers"]["Host"],
225+
headers["Host"],
222226
)
223227
if "resource" in lambda_event:
224228
span.set_attribute(HTTP_ROUTE, lambda_event["resource"])

instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_aws_lambda_instrumentation_manual.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,36 @@ def custom_event_context_extractor(lambda_event):
314314
expected_baggage=MOCK_W3C_BAGGAGE_VALUE,
315315
propagators="tracecontext,baggage",
316316
),
317+
TestCase(
318+
name="case_insensitive_headers_uppercase",
319+
custom_extractor=None,
320+
context={
321+
"headers": {
322+
TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME.upper(): MOCK_W3C_TRACE_CONTEXT_SAMPLED,
323+
TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME.upper(): f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
324+
}
325+
},
326+
expected_traceid=MOCK_W3C_TRACE_ID,
327+
expected_parentid=MOCK_W3C_PARENT_SPAN_ID,
328+
expected_trace_state_len=3,
329+
expected_state_value=MOCK_W3C_TRACE_STATE_VALUE,
330+
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED,
331+
),
332+
TestCase(
333+
name="case_insensitive_headers_mixedcase",
334+
custom_extractor=None,
335+
context={
336+
"headers": {
337+
"TraceParent": MOCK_W3C_TRACE_CONTEXT_SAMPLED,
338+
"tRaCeStAtE": f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
339+
}
340+
},
341+
expected_traceid=MOCK_W3C_TRACE_ID,
342+
expected_parentid=MOCK_W3C_PARENT_SPAN_ID,
343+
expected_trace_state_len=3,
344+
expected_state_value=MOCK_W3C_TRACE_STATE_VALUE,
345+
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED,
346+
),
317347
]
318348
for test in tests:
319349
with self.subTest(test_name=test.name):
@@ -389,6 +419,57 @@ def test_lambda_no_error_with_invalid_flush_timeout(self):
389419

390420
test_env_patch.stop()
391421

422+
def test_api_gateway_v1_attributes_case_insensitivity(self):
423+
AwsLambdaInstrumentor().instrument()
424+
425+
mock_execute_lambda(
426+
{
427+
"httpMethod": "GET",
428+
"headers": {
429+
"user-agent": "lowercase-agent",
430+
"host": "lowercase-host",
431+
"x-forwarded-proto": "http",
432+
},
433+
"resource": "/test",
434+
"requestContext": {
435+
"version": "1.0",
436+
},
437+
}
438+
)
439+
440+
spans = self.memory_exporter.get_finished_spans()
441+
self.assertEqual(len(spans), 1)
442+
span = spans[0]
443+
self.assertEqual(
444+
span.attributes.get(HTTP_USER_AGENT), "lowercase-agent"
445+
)
446+
self.assertEqual(span.attributes.get(NET_HOST_NAME), "lowercase-host")
447+
self.assertEqual(span.attributes.get(HTTP_SCHEME), "http")
448+
449+
self.memory_exporter.clear()
450+
451+
mock_execute_lambda(
452+
{
453+
"httpMethod": "GET",
454+
"headers": {
455+
"uSeR-aGeNt": "mixed-agent",
456+
"hOsT": "mixed-host",
457+
"X-fOrWaRdEd-PrOtO": "https",
458+
},
459+
"resource": "/test",
460+
"requestContext": {
461+
"version": "1.0",
462+
},
463+
}
464+
)
465+
466+
spans = self.memory_exporter.get_finished_spans()
467+
self.assertEqual(len(spans), 1)
468+
span = spans[0]
469+
self.assertEqual(span.attributes.get(HTTP_USER_AGENT), "mixed-agent")
470+
self.assertEqual(span.attributes.get(NET_HOST_NAME), "mixed-host")
471+
self.assertEqual(span.attributes.get(HTTP_SCHEME), "https")
472+
392473
def test_lambda_handles_multiple_consumers(self):
393474
test_env_patch = mock.patch.dict(
394475
"os.environ",

instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,12 @@ def keys(self, carrier: Request) -> list[str]:
114114

115115

116116
class CeleryInstrumentor(BaseInstrumentor):
117-
metrics = None
118-
task_id_to_start_time = {}
117+
def __init__(self):
118+
super().__init__()
119+
if not hasattr(self, "metrics"):
120+
self.metrics = None
121+
if not hasattr(self, "task_id_to_start_time"):
122+
self.task_id_to_start_time = {}
119123

120124
def instrumentation_dependencies(self) -> Collection[str]:
121125
return _instruments
@@ -138,6 +142,7 @@ def _instrument(self, **kwargs):
138142
schema_url="https://opentelemetry.io/schemas/1.11.0",
139143
)
140144

145+
self.task_id_to_start_time = {}
141146
self.create_celery_metrics(meter)
142147

143148
signals.task_prerun.connect(self._trace_prerun, weak=False)
@@ -158,6 +163,7 @@ def _uninstrument(self, **kwargs):
158163
signals.after_task_publish.disconnect(self._trace_after_publish)
159164
signals.task_failure.disconnect(self._trace_failure)
160165
signals.task_retry.disconnect(self._trace_retry)
166+
self.task_id_to_start_time = {}
161167

162168
def _trace_prerun(self, *args, **kwargs):
163169
task = utils.retrieve_task(kwargs)
@@ -212,6 +218,7 @@ def _trace_postrun(self, *args, **kwargs):
212218
self.update_task_duration_time(task_id)
213219
labels = {"task": task.name, "worker": task.request.hostname}
214220
self._record_histograms(task_id, labels)
221+
self.task_id_to_start_time.pop(task_id, None)
215222
# if the process sending the task is not instrumented
216223
# there's no incoming context and no token to detach
217224
if token is not None:

instrumentation/opentelemetry-instrumentation-celery/tests/test_tasks.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,24 @@ def test_task(self):
9292
self.assertEqual(consumer.parent.span_id, producer.context.span_id)
9393
self.assertEqual(consumer.context.trace_id, producer.context.trace_id)
9494

95+
def test_task_clears_start_time_cache(self):
96+
"""Test that the `task_id_to_start_time` cache is cleared after a task finishes,
97+
to prevent memory leaks."""
98+
instrumentor = CeleryInstrumentor()
99+
instrumentor.instrument()
100+
101+
result = task_add.delay(1, 2)
102+
103+
timeout = time.time() + 60 * 1 # 1 minute from now
104+
while not result.ready():
105+
if time.time() > timeout:
106+
break
107+
time.sleep(0.05)
108+
109+
self.assertTrue(result.ready())
110+
self.assertEqual(result.result, 3)
111+
self.assertEqual(instrumentor.task_id_to_start_time, {})
112+
95113
def test_task_raises(self):
96114
CeleryInstrumentor().instrument()
97115

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright The OpenTelemetry Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from typing import (
7+
Any,
8+
Iterable,
9+
Iterator,
10+
Mapping,
11+
MutableMapping,
12+
Optional,
13+
Tuple,
14+
TypeVar,
15+
Union,
16+
)
17+
18+
KT = TypeVar("KT")
19+
VT = TypeVar("VT")
20+
21+
22+
class CIDict(MutableMapping[KT, VT]):
23+
def __init__(
24+
self,
25+
data: Optional[Union[Mapping[KT, VT], Iterable[Tuple[KT, VT]]]] = None,
26+
) -> None:
27+
self._data: dict[KT, Tuple[KT, VT]] = {}
28+
if data is None:
29+
data = {}
30+
self.update(data)
31+
32+
@staticmethod
33+
def _normalize_key(key: KT) -> KT:
34+
if isinstance(key, str):
35+
return key.lower() # type: ignore
36+
return key
37+
38+
def _get_entry(self, key: KT) -> Tuple[KT, VT]:
39+
normalized_key = self._normalize_key(key)
40+
if normalized_key in self._data:
41+
return self._data[normalized_key]
42+
raise KeyError(repr(key))
43+
44+
def original_key(self, key: KT) -> KT:
45+
return self._get_entry(key)[0]
46+
47+
def normalized_items(self) -> Iterable[Tuple[KT, VT]]:
48+
return ((key, value[1]) for key, value in self._data.items())
49+
50+
def __setitem__(self, key: KT, value: VT, /) -> None:
51+
self._data[self._normalize_key(key)] = (key, value)
52+
53+
def __delitem__(self, key: KT, /) -> None:
54+
try:
55+
del self._data[self._normalize_key(key)]
56+
except KeyError:
57+
raise KeyError(repr(key)) from None
58+
59+
def __getitem__(self, key: KT, /) -> VT:
60+
return self._get_entry(key)[1]
61+
62+
def __len__(self) -> int:
63+
return len(self._data)
64+
65+
def __iter__(self) -> Iterator[KT]:
66+
return (key for key, _ in self._data.values())
67+
68+
def __repr__(self) -> str:
69+
return f"{self.__class__.__name__}({dict(self.items())!r})"
70+
71+
def __eq__(self, other: Any) -> bool:
72+
if isinstance(other, CIDict):
73+
return dict(self.normalized_items()) == dict(
74+
other.normalized_items()
75+
)
76+
if not isinstance(other, Mapping):
77+
return False
78+
ciother: CIDict[Any, Any] = CIDict(other)
79+
return dict(self.normalized_items()) == dict(
80+
ciother.normalized_items()
81+
)

0 commit comments

Comments
 (0)