diff --git a/.changelog/4566.added b/.changelog/4566.added new file mode 100644 index 0000000000..475b24745c --- /dev/null +++ b/.changelog/4566.added @@ -0,0 +1 @@ +`opentelemetry-instrumentation-starlette`: Add missing configuration params from ASGI Middleware diff --git a/instrumentation/opentelemetry-instrumentation-starlette/src/opentelemetry/instrumentation/starlette/__init__.py b/instrumentation/opentelemetry-instrumentation-starlette/src/opentelemetry/instrumentation/starlette/__init__.py index afe264bcd2..bc07544e40 100644 --- a/instrumentation/opentelemetry-instrumentation-starlette/src/opentelemetry/instrumentation/starlette/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-starlette/src/opentelemetry/instrumentation/starlette/__init__.py @@ -165,7 +165,8 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A from __future__ import annotations -from typing import TYPE_CHECKING, Any, Collection, cast +import logging +from typing import TYPE_CHECKING, Any, Collection, Literal, cast from weakref import WeakSet from starlette import applications @@ -185,7 +186,7 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A HTTP_ROUTE, ) from opentelemetry.trace import TracerProvider, get_tracer -from opentelemetry.util.http import get_excluded_urls +from opentelemetry.util.http import get_excluded_urls, parse_excluded_urls if TYPE_CHECKING: from typing import TypedDict, Unpack @@ -196,9 +197,15 @@ class InstrumentKwargs(TypedDict, total=False): server_request_hook: ServerRequestHook client_request_hook: ClientRequestHook client_response_hook: ClientResponseHook + excluded_urls: str + http_capture_headers_server_request: list[str] + http_capture_headers_server_response: list[str] + http_capture_headers_sanitize_fields: list[str] + exclude_spans: list[Literal["receive", "send"]] -_excluded_urls = get_excluded_urls("STARLETTE") +_excluded_urls_from_env = get_excluded_urls("STARLETTE") +_logger = logging.getLogger(__name__) class StarletteInstrumentor(BaseInstrumentor): @@ -217,6 +224,11 @@ def instrument_app( client_response_hook: ClientResponseHook = None, meter_provider: MeterProvider | None = None, tracer_provider: TracerProvider | None = None, + excluded_urls: str | None = None, + http_capture_headers_server_request: list[str] | None = None, + http_capture_headers_server_response: list[str] | None = None, + http_capture_headers_sanitize_fields: list[str] | None = None, + exclude_spans: list[Literal["receive", "send"]] | None = None, ): """Instrument an uninstrumented Starlette application. @@ -232,23 +244,36 @@ def instrument_app( the current globally configured one is used. tracer_provider: The optional tracer provider to use. If omitted the current globally configured one is used. + excluded_urls: Optional comma delimited string of regexes to match URLs that should not be traced. + http_capture_headers_server_request: Optional list of HTTP headers to capture from the request. + http_capture_headers_server_response: Optional list of HTTP headers to capture from the response. + http_capture_headers_sanitize_fields: Optional list of HTTP headers to sanitize. + exclude_spans: Optionally exclude HTTP `send` and/or `receive` spans from the trace. """ - tracer = get_tracer( - __name__, - __version__, - tracer_provider, - schema_url="https://opentelemetry.io/schemas/1.11.0", - ) - meter = get_meter( - __name__, - __version__, - meter_provider, - schema_url="https://opentelemetry.io/schemas/1.11.0", - ) + if not hasattr(app, "_is_instrumented_by_opentelemetry"): + app._is_instrumented_by_opentelemetry = False + if not getattr(app, "_is_instrumented_by_opentelemetry", False): + if excluded_urls is None: + excluded_urls = _excluded_urls_from_env + else: + excluded_urls = parse_excluded_urls(excluded_urls) + tracer = get_tracer( + __name__, + __version__, + tracer_provider, + schema_url="https://opentelemetry.io/schemas/1.11.0", + ) + meter = get_meter( + __name__, + __version__, + meter_provider, + schema_url="https://opentelemetry.io/schemas/1.11.0", + ) + app.add_middleware( OpenTelemetryMiddleware, - excluded_urls=_excluded_urls, + excluded_urls=excluded_urls, default_span_details=_get_default_span_details, server_request_hook=server_request_hook, client_request_hook=client_request_hook, @@ -256,11 +281,19 @@ def instrument_app( # Pass in tracer/meter to get __name__and __version__ of starlette instrumentation tracer=tracer, meter=meter, + http_capture_headers_server_request=http_capture_headers_server_request, + http_capture_headers_server_response=http_capture_headers_server_response, + http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields, + exclude_spans=exclude_spans, ) app._is_instrumented_by_opentelemetry = True # adding apps to set for uninstrumenting _InstrumentedStarlette._instrumented_starlette_apps.add(app) + else: + _logger.warning( + "Attempting to instrument Starlette app while already instrumented" + ) @staticmethod def uninstrument_app(app: applications.Starlette): @@ -277,64 +310,31 @@ def instrumentation_dependencies(self) -> Collection[str]: def _instrument(self, **kwargs: Unpack[InstrumentKwargs]): self._original_starlette = applications.Starlette - _InstrumentedStarlette._tracer_provider = kwargs.get("tracer_provider") - _InstrumentedStarlette._server_request_hook = kwargs.get( - "server_request_hook" - ) - _InstrumentedStarlette._client_request_hook = kwargs.get( - "client_request_hook" - ) - _InstrumentedStarlette._client_response_hook = kwargs.get( - "client_response_hook" - ) - _InstrumentedStarlette._meter_provider = kwargs.get("meter_provider") - + _InstrumentedStarlette._instrument_kwargs = kwargs applications.Starlette = _InstrumentedStarlette def _uninstrument(self, **kwargs: Any): """uninstrumenting all created apps by user""" - for instance in _InstrumentedStarlette._instrumented_starlette_apps: + # Create a copy of the set to avoid RuntimeError during iteration + instances_to_uninstrument = list( + _InstrumentedStarlette._instrumented_starlette_apps + ) + for instance in instances_to_uninstrument: self.uninstrument_app(instance) _InstrumentedStarlette._instrumented_starlette_apps.clear() applications.Starlette = self._original_starlette class _InstrumentedStarlette(applications.Starlette): - _tracer_provider: TracerProvider | None = None - _meter_provider: MeterProvider | None = None - _server_request_hook: ServerRequestHook = None - _client_request_hook: ClientRequestHook = None - _client_response_hook: ClientResponseHook = None + _instrument_kwargs: dict[str, Any] = {} + # Track instrumented app instances using weak references to avoid GC leaks _instrumented_starlette_apps: WeakSet[applications.Starlette] = WeakSet() def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) - tracer = get_tracer( - __name__, - __version__, - _InstrumentedStarlette._tracer_provider, - schema_url="https://opentelemetry.io/schemas/1.11.0", - ) - meter = get_meter( - __name__, - __version__, - _InstrumentedStarlette._meter_provider, - schema_url="https://opentelemetry.io/schemas/1.11.0", - ) - self.add_middleware( - OpenTelemetryMiddleware, - excluded_urls=_excluded_urls, - default_span_details=_get_default_span_details, - server_request_hook=_InstrumentedStarlette._server_request_hook, - client_request_hook=_InstrumentedStarlette._client_request_hook, - client_response_hook=_InstrumentedStarlette._client_response_hook, - # Pass in tracer/meter to get __name__and __version__ of starlette instrumentation - tracer=tracer, - meter=meter, + StarletteInstrumentor.instrument_app( + self, **_InstrumentedStarlette._instrument_kwargs ) - self._is_instrumented_by_opentelemetry = True - # adding apps to set for uninstrumenting - _InstrumentedStarlette._instrumented_starlette_apps.add(self) def _get_route_details(scope: dict[str, Any]) -> str | None: diff --git a/instrumentation/opentelemetry-instrumentation-starlette/tests/test_starlette_instrumentation.py b/instrumentation/opentelemetry-instrumentation-starlette/tests/test_starlette_instrumentation.py index d534b1d7a3..0ac318da37 100644 --- a/instrumentation/opentelemetry-instrumentation-starlette/tests/test_starlette_instrumentation.py +++ b/instrumentation/opentelemetry-instrumentation-starlette/tests/test_starlette_instrumentation.py @@ -1,17 +1,20 @@ # Copyright The OpenTelemetry Authors # SPDX-License-Identifier: Apache-2.0 +# pylint: disable=too-many-lines + import unittest from timeit import default_timer from unittest.mock import patch from starlette import applications -from starlette.responses import PlainTextResponse +from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route from starlette.testclient import TestClient from starlette.websockets import WebSocket import opentelemetry.instrumentation.starlette as otel_starlette +from opentelemetry import trace from opentelemetry.sdk.metrics.export import ( HistogramDataPoint, NumberDataPoint, @@ -56,7 +59,7 @@ SCOPE = "opentelemetry.instrumentation.starlette" -class TestStarletteManualInstrumentation(TestBase): +class TestBaseStarlette(TestBase): def _create_app(self): app = self._create_starlette_app() self._instrumentor.instrument_app( @@ -67,6 +70,18 @@ def _create_app(self): ) return app + def _create_app_explicit_excluded_urls(self): + app = self._create_starlette_app() + to_exclude = "/user/123,/foobar" + self._instrumentor.instrument_app( + app=app, + excluded_urls=to_exclude, + server_request_hook=getattr(self, "server_request_hook", None), + client_request_hook=getattr(self, "client_request_hook", None), + client_response_hook=getattr(self, "client_response_hook", None), + ) + return app + def setUp(self): super().setUp() self.env_patch = patch.dict( @@ -75,7 +90,7 @@ def setUp(self): ) self.env_patch.start() self.exclude_patch = patch( - "opentelemetry.instrumentation.starlette._excluded_urls", + "opentelemetry.instrumentation.starlette._excluded_urls_from_env", get_excluded_urls("STARLETTE"), ) self.exclude_patch.start() @@ -88,6 +103,33 @@ def tearDown(self): self.env_patch.stop() self.exclude_patch.stop() + @staticmethod + def _create_starlette_app(): + def home(_): + return PlainTextResponse("hi") + + def health(_): + return PlainTextResponse("ok") + + def sub_home(_): + return PlainTextResponse("sub hi") + + sub_app = applications.Starlette(routes=[Route("/home", sub_home)]) + + app = applications.Starlette( + routes=[ + Route("/foobar", home), + Route("/user/{username}", home), + Route("/healthzz", health), + Mount("/sub", app=sub_app), + Host("testserver2", sub_app), + ], + ) + + return app + + +class TestStarletteManualInstrumentation(TestBaseStarlette): def test_basic_starlette_call(self): self._client.get("/foobar") spans = self.memory_exporter.get_finished_spans() @@ -167,6 +209,17 @@ def test_starlette_excluded_urls(self): spans = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans), 0) + def test_starlette_excluded_urls_not_env(self): + """Ensure that given starlette routes are excluded when passed explicitly (not in the environment)""" + app = self._create_app_explicit_excluded_urls() + client = TestClient(app) + client.get("/user/123") + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + client.get("/foobar") + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + def test_starlette_metrics(self): self._client.get("/foobar") self._client.get("/foobar") @@ -268,35 +321,8 @@ def test_metric_uninstrument_inherited_by_base(self): if isinstance(point, NumberDataPoint): self.assertEqual(point.value, 0) - @staticmethod - def _create_starlette_app(): - def home(_): - return PlainTextResponse("hi") - - def health(_): - return PlainTextResponse("ok") - - def sub_home(_): - return PlainTextResponse("sub hi") - - sub_app = applications.Starlette(routes=[Route("/home", sub_home)]) - - app = applications.Starlette( - routes=[ - Route("/foobar", home), - Route("/user/{username}", home), - Route("/healthzz", health), - Mount("/sub", app=sub_app), - Host("testserver2", sub_app), - ], - ) - - return app - -class TestStarletteManualInstrumentationHooks( - TestStarletteManualInstrumentation -): +class TestStarletteBaseHooks(TestBaseStarlette): _server_request_hook = None _client_request_hook = None _client_response_hook = None @@ -313,6 +339,8 @@ def client_response_hook(self, send_span, scope, message): if self._client_response_hook is not None: self._client_response_hook(send_span, scope, message) + +class TestStarletteManualInstrumentationHooks(TestStarletteBaseHooks): def test_hooks(self): def server_request_hook(span, scope): span.update_name("name from server hook") @@ -346,7 +374,7 @@ def client_response_hook(send_span, scope, message): ) -class TestAutoInstrumentation(TestStarletteManualInstrumentation): +class TestAutoInstrumentation(TestBaseStarlette): """Test the auto-instrumented variant Extending the manual instrumentation as most test cases apply @@ -474,7 +502,7 @@ def test_sub_app_starlette_call(self): ) -class TestAutoInstrumentationHooks(TestStarletteManualInstrumentationHooks): +class TestAutoInstrumentationHooks(TestStarletteBaseHooks): """ Test the auto-instrumented variant for request and response hooks """ @@ -573,7 +601,7 @@ def test_instrumentation(self): self.assertIs(original, should_be_original) -class TestConditonalServerSpanCreation(TestStarletteManualInstrumentation): +class TestConditonalServerSpanCreation(TestBaseStarlette): def test_mark_span_internal_in_presence_of_another_span(self): tracer = get_tracer(__name__) with tracer.start_as_current_span( @@ -950,3 +978,140 @@ def test_custom_header_not_present_in_non_recording_span(self): self.assertEqual(200, resp.status_code) span_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(span_list), 0) + + +class TestHTTPAppWithCustomHeadersParameters(TestBase): + """Minimal tests here since the behavior of this logic is tested above and in the ASGI tests.""" + + def setUp(self): + super().setUp() + self._instrumentor = otel_starlette.StarletteInstrumentor() + self.kwargs = { + "http_capture_headers_server_request": ["a.*", "b.*"], + "http_capture_headers_server_response": ["c.*", "d.*"], + "http_capture_headers_sanitize_fields": [".*secret.*"], + } + self.app = None + + def tearDown(self) -> None: + super().tearDown() + with self.disable_logging(): + if self.app: + self._instrumentor.uninstrument_app(self.app) + else: + self._instrumentor.uninstrument() + + @staticmethod + def _create_app(): + def home(_): + return JSONResponse( + content={"message": "hi"}, + headers={ + "carrot": "bar", + "date-secret": "yellow", + "egg": "ham", + }, + ) + + app = applications.Starlette(routes=[Route("/foobar", home)]) + + return app + + def test_http_custom_request_headers_in_span_attributes_app(self): + self.app = self._create_app() + self._instrumentor.instrument_app(self.app, **self.kwargs) + + resp = TestClient(self.app).get( + "/foobar", + headers={ + "apple": "red", + "banana-secret": "yellow", + "fig": "green", + }, + ) + self.assertEqual(200, resp.status_code) + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 3) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + expected = { + # apple should be included because it starts with a + "http.request.header.apple": ("red",), + # same with banana because it starts with b, + # redacted because it contains "secret" + "http.request.header.banana_secret": ("[REDACTED]",), + } + self.assertSpanHasAttributes(server_span, expected) + self.assertNotIn("http.request.header.fig", server_span.attributes) + + def test_http_custom_request_headers_in_span_attributes_instr(self): + """As above, but use instrument(), not instrument_app().""" + self._instrumentor.instrument(**self.kwargs) + + resp = TestClient(self._create_app()).get( + "/foobar", + headers={ + "apple": "red", + "banana-secret": "yellow", + "fig": "green", + }, + ) + self.assertEqual(200, resp.status_code) + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 3) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + expected = { + # apple should be included because it starts with a + "http.request.header.apple": ("red",), + # same with banana because it starts with b, + # redacted because it contains "secret" + "http.request.header.banana_secret": ("[REDACTED]",), + } + self.assertSpanHasAttributes(server_span, expected) + self.assertNotIn("http.request.header.fig", server_span.attributes) + + def test_http_custom_response_headers_in_span_attributes_app(self): + self.app = self._create_app() + self._instrumentor.instrument_app(self.app, **self.kwargs) + resp = TestClient(self.app).get("/foobar") + self.assertEqual(200, resp.status_code) + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 3) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + expected = { + "http.response.header.carrot": ("bar",), + "http.response.header.date_secret": ("[REDACTED]",), + } + self.assertSpanHasAttributes(server_span, expected) + self.assertNotIn("http.response.header.egg", server_span.attributes) + + def test_http_custom_response_headers_in_span_attributes_inst(self): + """As above, but use instrument(), not instrument_app().""" + self._instrumentor.instrument(**self.kwargs) + + resp = TestClient(self._create_app()).get("/foobar") + self.assertEqual(200, resp.status_code) + span_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(span_list), 3) + + server_span = [ + span for span in span_list if span.kind == trace.SpanKind.SERVER + ][0] + + expected = { + "http.response.header.carrot": ("bar",), + "http.response.header.date_secret": ("[REDACTED]",), + } + self.assertSpanHasAttributes(server_span, expected) + self.assertNotIn("http.response.header.egg", server_span.attributes)