Skip to content

Commit 5fd2fbd

Browse files
committed
fix(asyncpg): instrument prepared statements
Assisted-by: Claude Sonnet 4.6
1 parent d9d586c commit 5fd2fbd

3 files changed

Lines changed: 180 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
5858
([#4049](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4049))
5959
- `opentelemetry-instrumentation-sqlalchemy`: implement new semantic convention opt-in migration
6060
([#4110](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4110))
61+
- `opentelemetry-instrumentation-asyncpg`: instrument prepared statements
62+
([#4529](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4529))
6163

6264
### Fixed
6365

instrumentation/opentelemetry-instrumentation-asyncpg/src/opentelemetry/instrumentation/asyncpg/__init__.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ async def main():
5353
from typing import Collection
5454

5555
import asyncpg
56+
import asyncpg.prepared_stmt
5657
import wrapt
5758

5859
from opentelemetry import trace
@@ -76,6 +77,14 @@ async def main():
7677
from opentelemetry.trace import SpanKind
7778
from opentelemetry.trace.status import Status, StatusCode
7879

80+
_PREPARED_STMT_METHODS = (
81+
"fetch",
82+
"fetchval",
83+
"fetchrow",
84+
"executemany",
85+
"fetchmany",
86+
)
87+
7988

8089
def _hydrate_span_from_args(connection, query, parameters) -> dict:
8190
"""Get network and database attributes from connection."""
@@ -153,6 +162,14 @@ def _instrument(self, **kwargs):
153162
"asyncpg.cursor", method, self._do_cursor_execute
154163
)
155164

165+
for method in _PREPARED_STMT_METHODS:
166+
if hasattr(asyncpg.prepared_stmt.PreparedStatement, method):
167+
wrapt.wrap_function_wrapper(
168+
"asyncpg.prepared_stmt",
169+
f"PreparedStatement.{method}",
170+
self._do_prepared_execute,
171+
)
172+
156173
def _uninstrument(self, **__):
157174
for cls, methods in [
158175
(
@@ -165,6 +182,10 @@ def _uninstrument(self, **__):
165182
for method_name in methods:
166183
unwrap(cls, method_name)
167184

185+
for method_name in _PREPARED_STMT_METHODS:
186+
if hasattr(asyncpg.prepared_stmt.PreparedStatement, method_name):
187+
unwrap(asyncpg.prepared_stmt.PreparedStatement, method_name)
188+
168189
async def _do_execute(self, func, instance, args, kwargs):
169190
exception = None
170191
params = getattr(instance, "_params", None)
@@ -243,3 +264,32 @@ async def _do_cursor_execute(self, func, instance, args, kwargs):
243264
if not stop:
244265
return result
245266
raise StopAsyncIteration
267+
268+
async def _do_prepared_execute(self, func, instance, args, kwargs):
269+
exception = None
270+
query = instance._query or ""
271+
272+
try:
273+
name = self._leading_comment_remover.sub("", query).split()[0]
274+
except IndexError:
275+
name = ""
276+
277+
span_attributes = _hydrate_span_from_args(
278+
instance._connection,
279+
query,
280+
args if self.capture_parameters else None,
281+
)
282+
283+
with self._tracer.start_as_current_span(
284+
name, kind=SpanKind.CLIENT, attributes=span_attributes
285+
) as span:
286+
try:
287+
result = await func(*args, **kwargs)
288+
except Exception as exc: # pylint: disable=W0703
289+
exception = exc
290+
raise
291+
finally:
292+
if span.is_recording() and exception is not None:
293+
span.set_status(Status(StatusCode.ERROR))
294+
295+
return result

instrumentation/opentelemetry-instrumentation-asyncpg/tests/test_asyncpg_wrapper.py

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55
from asyncpg import Connection, Record, cursor
6+
from asyncpg.prepared_stmt import PreparedStatement
67

78
try:
89
# wrapt 2.0.0+
@@ -11,11 +12,18 @@
1112
from wrapt import ObjectProxy as BaseObjectProxy
1213

1314
from opentelemetry import trace as trace_api
14-
from opentelemetry.instrumentation.asyncpg import AsyncPGInstrumentor
15+
from opentelemetry.instrumentation.asyncpg import (
16+
_PREPARED_STMT_METHODS,
17+
AsyncPGInstrumentor,
18+
)
1519
from opentelemetry.test.test_base import TestBase
1620

1721

1822
class TestAsyncPGInstrumentation(TestBase):
23+
def tearDown(self):
24+
super().tearDown()
25+
AsyncPGInstrumentor().uninstrument()
26+
1927
def test_duplicated_instrumentation_can_be_uninstrumented(self):
2028
AsyncPGInstrumentor().instrument()
2129
AsyncPGInstrumentor().instrument()
@@ -144,3 +152,122 @@ async def exec_mock(*args, **kwargs):
144152

145153
spans = self.memory_exporter.get_finished_spans()
146154
self.assertEqual(len(spans), 0)
155+
156+
def test_prepared_statement_instrumentation(self):
157+
methods = [
158+
m for m in _PREPARED_STMT_METHODS if hasattr(PreparedStatement, m)
159+
]
160+
161+
for method_name in methods:
162+
with self.subTest(method=method_name, phase="before"):
163+
self.assertFalse(
164+
isinstance(
165+
getattr(PreparedStatement, method_name),
166+
BaseObjectProxy,
167+
)
168+
)
169+
170+
AsyncPGInstrumentor().instrument()
171+
172+
for method_name in methods:
173+
with self.subTest(method=method_name, phase="instrumented"):
174+
self.assertTrue(
175+
isinstance(
176+
getattr(PreparedStatement, method_name),
177+
BaseObjectProxy,
178+
)
179+
)
180+
181+
AsyncPGInstrumentor().uninstrument()
182+
183+
for method_name in methods:
184+
with self.subTest(method=method_name, phase="uninstrumented"):
185+
self.assertFalse(
186+
isinstance(
187+
getattr(PreparedStatement, method_name),
188+
BaseObjectProxy,
189+
)
190+
)
191+
192+
@staticmethod
193+
def _make_prepared_stmt_conn():
194+
async def bind_execute_mock(*args, **kwargs):
195+
return [], b"SELECT 1", True
196+
197+
async def bind_execute_many_mock(*args, **kwargs):
198+
return None
199+
200+
conn = mock.Mock()
201+
conn._pool_release_ctr = 0
202+
conn.is_closed = lambda: False
203+
conn._protocol = mock.Mock()
204+
conn._protocol.bind_execute = bind_execute_mock
205+
conn._protocol.bind_execute_many = bind_execute_many_mock
206+
207+
state = mock.Mock()
208+
state.closed = False
209+
return conn, state
210+
211+
def test_prepared_statement_span(self):
212+
# Per-method: (query, call_args, expected_span_name)
213+
method_cases = {
214+
"fetch": ("SELECT * FROM users", (), "SELECT"),
215+
"fetchval": ("SELECT id FROM users WHERE id=$1", (1,), "SELECT"),
216+
"fetchrow": ("SELECT * FROM t WHERE v=$1", ("x",), "SELECT"),
217+
"executemany": (
218+
"INSERT INTO t (v) VALUES ($1)",
219+
([("a",), ("b",)],),
220+
"INSERT",
221+
),
222+
"fetchmany": ("SELECT * FROM t", ([],), "SELECT"),
223+
}
224+
225+
for method_name in _PREPARED_STMT_METHODS:
226+
if not hasattr(PreparedStatement, method_name):
227+
continue
228+
query, call_args, expected_name = method_cases[method_name]
229+
with self.subTest(method=method_name):
230+
self.memory_exporter.clear()
231+
conn, state = self._make_prepared_stmt_conn()
232+
apg = AsyncPGInstrumentor()
233+
apg.instrument(tracer_provider=self.tracer_provider)
234+
235+
stmt = PreparedStatement(conn, query, state)
236+
asyncio.run(getattr(stmt, method_name)(*call_args))
237+
238+
spans = self.memory_exporter.get_finished_spans()
239+
self.assertEqual(len(spans), 1)
240+
self.assertEqual(spans[0].name, expected_name)
241+
self.assertTrue(spans[0].status.is_ok)
242+
self.assertEqual(
243+
spans[0].attributes.get("db.statement"), query
244+
)
245+
self.assertEqual(
246+
spans[0].attributes.get("db.system"), "postgresql"
247+
)
248+
249+
apg.uninstrument()
250+
251+
def test_prepared_statement_error_span(self):
252+
async def bind_execute_error(*args, **kwargs):
253+
raise RuntimeError("db error")
254+
255+
conn = mock.Mock()
256+
conn._pool_release_ctr = 0
257+
conn.is_closed = lambda: False
258+
conn._protocol = mock.Mock()
259+
conn._protocol.bind_execute = bind_execute_error
260+
261+
state = mock.Mock()
262+
state.closed = False
263+
264+
apg = AsyncPGInstrumentor()
265+
apg.instrument(tracer_provider=self.tracer_provider)
266+
267+
stmt = PreparedStatement(conn, "SELECT 1", state)
268+
with self.assertRaises(RuntimeError):
269+
asyncio.run(stmt.fetch())
270+
271+
spans = self.memory_exporter.get_finished_spans()
272+
self.assertEqual(len(spans), 1)
273+
self.assertFalse(spans[0].status.is_ok)

0 commit comments

Comments
 (0)