diff --git a/CHANGELOG.md b/CHANGELOG.md index 93f0cab4fc..2b45a9ba75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -58,6 +58,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#4049](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4049)) - `opentelemetry-instrumentation-sqlalchemy`: implement new semantic convention opt-in migration ([#4110](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4110)) +- `opentelemetry-instrumentation-asyncpg`: instrument prepared statements + ([#4529](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4529)) ### Fixed diff --git a/instrumentation/opentelemetry-instrumentation-asyncpg/src/opentelemetry/instrumentation/asyncpg/__init__.py b/instrumentation/opentelemetry-instrumentation-asyncpg/src/opentelemetry/instrumentation/asyncpg/__init__.py index c8aba9bbf3..f0c47c8a40 100644 --- a/instrumentation/opentelemetry-instrumentation-asyncpg/src/opentelemetry/instrumentation/asyncpg/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asyncpg/src/opentelemetry/instrumentation/asyncpg/__init__.py @@ -53,6 +53,7 @@ async def main(): from typing import Collection import asyncpg +import asyncpg.prepared_stmt import wrapt from opentelemetry import trace @@ -76,6 +77,14 @@ async def main(): from opentelemetry.trace import SpanKind from opentelemetry.trace.status import Status, StatusCode +_PREPARED_STMT_METHODS = ( + "fetch", + "fetchval", + "fetchrow", + "executemany", + "fetchmany", +) + def _hydrate_span_from_args(connection, query, parameters) -> dict: """Get network and database attributes from connection.""" @@ -153,6 +162,14 @@ def _instrument(self, **kwargs): "asyncpg.cursor", method, self._do_cursor_execute ) + for method in _PREPARED_STMT_METHODS: + if hasattr(asyncpg.prepared_stmt.PreparedStatement, method): + wrapt.wrap_function_wrapper( + "asyncpg.prepared_stmt", + f"PreparedStatement.{method}", + self._do_prepared_execute, + ) + def _uninstrument(self, **__): for cls, methods in [ ( @@ -165,6 +182,10 @@ def _uninstrument(self, **__): for method_name in methods: unwrap(cls, method_name) + for method_name in _PREPARED_STMT_METHODS: + if hasattr(asyncpg.prepared_stmt.PreparedStatement, method_name): + unwrap(asyncpg.prepared_stmt.PreparedStatement, method_name) + async def _do_execute(self, func, instance, args, kwargs): exception = None params = getattr(instance, "_params", None) @@ -243,3 +264,32 @@ async def _do_cursor_execute(self, func, instance, args, kwargs): if not stop: return result raise StopAsyncIteration + + async def _do_prepared_execute(self, func, instance, args, kwargs): + exception = None + query = instance._query or "" + + try: + name = self._leading_comment_remover.sub("", query).split()[0] + except IndexError: + name = "" + + span_attributes = _hydrate_span_from_args( + instance._connection, + query, + args if self.capture_parameters else None, + ) + + with self._tracer.start_as_current_span( + name, kind=SpanKind.CLIENT, attributes=span_attributes + ) as span: + try: + result = await func(*args, **kwargs) + except Exception as exc: # pylint: disable=W0703 + exception = exc + raise + finally: + if span.is_recording() and exception is not None: + span.set_status(Status(StatusCode.ERROR)) + + return result diff --git a/instrumentation/opentelemetry-instrumentation-asyncpg/tests/test_asyncpg_wrapper.py b/instrumentation/opentelemetry-instrumentation-asyncpg/tests/test_asyncpg_wrapper.py index f8c905a2c7..36109d6f70 100644 --- a/instrumentation/opentelemetry-instrumentation-asyncpg/tests/test_asyncpg_wrapper.py +++ b/instrumentation/opentelemetry-instrumentation-asyncpg/tests/test_asyncpg_wrapper.py @@ -3,6 +3,7 @@ import pytest from asyncpg import Connection, Record, cursor +from asyncpg.prepared_stmt import PreparedStatement try: # wrapt 2.0.0+ @@ -11,11 +12,18 @@ from wrapt import ObjectProxy as BaseObjectProxy from opentelemetry import trace as trace_api -from opentelemetry.instrumentation.asyncpg import AsyncPGInstrumentor +from opentelemetry.instrumentation.asyncpg import ( + _PREPARED_STMT_METHODS, + AsyncPGInstrumentor, +) from opentelemetry.test.test_base import TestBase class TestAsyncPGInstrumentation(TestBase): + def tearDown(self): + super().tearDown() + AsyncPGInstrumentor().uninstrument() + def test_duplicated_instrumentation_can_be_uninstrumented(self): AsyncPGInstrumentor().instrument() AsyncPGInstrumentor().instrument() @@ -144,3 +152,122 @@ async def exec_mock(*args, **kwargs): spans = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans), 0) + + def test_prepared_statement_instrumentation(self): + methods = [ + m for m in _PREPARED_STMT_METHODS if hasattr(PreparedStatement, m) + ] + + for method_name in methods: + with self.subTest(method=method_name, phase="before"): + self.assertFalse( + isinstance( + getattr(PreparedStatement, method_name), + BaseObjectProxy, + ) + ) + + AsyncPGInstrumentor().instrument() + + for method_name in methods: + with self.subTest(method=method_name, phase="instrumented"): + self.assertTrue( + isinstance( + getattr(PreparedStatement, method_name), + BaseObjectProxy, + ) + ) + + AsyncPGInstrumentor().uninstrument() + + for method_name in methods: + with self.subTest(method=method_name, phase="uninstrumented"): + self.assertFalse( + isinstance( + getattr(PreparedStatement, method_name), + BaseObjectProxy, + ) + ) + + @staticmethod + def _make_prepared_stmt_conn(): + async def bind_execute_mock(*args, **kwargs): + return [], b"SELECT 1", True + + async def bind_execute_many_mock(*args, **kwargs): + return None + + conn = mock.Mock() + conn._pool_release_ctr = 0 + conn.is_closed = lambda: False + conn._protocol = mock.Mock() + conn._protocol.bind_execute = bind_execute_mock + conn._protocol.bind_execute_many = bind_execute_many_mock + + state = mock.Mock() + state.closed = False + return conn, state + + def test_prepared_statement_span(self): + # Per-method: (query, call_args, expected_span_name) + method_cases = { + "fetch": ("SELECT * FROM users", (), "SELECT"), + "fetchval": ("SELECT id FROM users WHERE id=$1", (1,), "SELECT"), + "fetchrow": ("SELECT * FROM t WHERE v=$1", ("x",), "SELECT"), + "executemany": ( + "INSERT INTO t (v) VALUES ($1)", + ([("a",), ("b",)],), + "INSERT", + ), + "fetchmany": ("SELECT * FROM t", ([],), "SELECT"), + } + + for method_name in _PREPARED_STMT_METHODS: + if not hasattr(PreparedStatement, method_name): + continue + query, call_args, expected_name = method_cases[method_name] + with self.subTest(method=method_name): + self.memory_exporter.clear() + conn, state = self._make_prepared_stmt_conn() + apg = AsyncPGInstrumentor() + apg.instrument(tracer_provider=self.tracer_provider) + + stmt = PreparedStatement(conn, query, state) + asyncio.run(getattr(stmt, method_name)(*call_args)) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + self.assertEqual(spans[0].name, expected_name) + self.assertTrue(spans[0].status.is_ok) + self.assertEqual( + spans[0].attributes.get("db.statement"), query + ) + self.assertEqual( + spans[0].attributes.get("db.system"), "postgresql" + ) + + apg.uninstrument() + + def test_prepared_statement_error_span(self): + async def bind_execute_error(*args, **kwargs): + raise RuntimeError("db error") + + conn = mock.Mock() + conn._pool_release_ctr = 0 + conn.is_closed = lambda: False + conn._protocol = mock.Mock() + conn._protocol.bind_execute = bind_execute_error + + state = mock.Mock() + state.closed = False + + apg = AsyncPGInstrumentor() + apg.instrument(tracer_provider=self.tracer_provider) + + stmt = PreparedStatement(conn, "SELECT 1", state) + with self.assertRaises(RuntimeError): + asyncio.run(stmt.fetch()) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + self.assertFalse(spans[0].status.is_ok)