Skip to content

Commit b68434d

Browse files
committed
fix(asyncpg): instrument prepared statements
Assisted-by: Claude Sonnet 4.6
1 parent d9d586c commit b68434d

3 files changed

Lines changed: 175 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
5858
([#4049](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4049))
5959
- `opentelemetry-instrumentation-sqlalchemy`: implement new semantic convention opt-in migration
6060
([#4110](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4110))
61+
- `opentelemetry-instrumentation-asyncpg`: instrument prepared statements
62+
([#4529](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4529))
6163

6264
### Fixed
6365

instrumentation/opentelemetry-instrumentation-asyncpg/src/opentelemetry/instrumentation/asyncpg/__init__.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ async def main():
5353
from typing import Collection
5454

5555
import asyncpg
56+
import asyncpg.prepared_stmt
5657
import wrapt
5758

5859
from opentelemetry import trace
@@ -76,6 +77,14 @@ async def main():
7677
from opentelemetry.trace import SpanKind
7778
from opentelemetry.trace.status import Status, StatusCode
7879

80+
_PREPARED_STMT_METHODS = (
81+
"fetch",
82+
"fetchval",
83+
"fetchrow",
84+
"executemany",
85+
"fetchmany",
86+
)
87+
7988

8089
def _hydrate_span_from_args(connection, query, parameters) -> dict:
8190
"""Get network and database attributes from connection."""
@@ -153,6 +162,14 @@ def _instrument(self, **kwargs):
153162
"asyncpg.cursor", method, self._do_cursor_execute
154163
)
155164

165+
for method in _PREPARED_STMT_METHODS:
166+
if hasattr(asyncpg.prepared_stmt.PreparedStatement, method):
167+
wrapt.wrap_function_wrapper(
168+
"asyncpg.prepared_stmt",
169+
f"PreparedStatement.{method}",
170+
self._do_prepared_execute,
171+
)
172+
156173
def _uninstrument(self, **__):
157174
for cls, methods in [
158175
(
@@ -165,6 +182,10 @@ def _uninstrument(self, **__):
165182
for method_name in methods:
166183
unwrap(cls, method_name)
167184

185+
for method_name in _PREPARED_STMT_METHODS:
186+
if hasattr(asyncpg.prepared_stmt.PreparedStatement, method_name):
187+
unwrap(asyncpg.prepared_stmt.PreparedStatement, method_name)
188+
168189
async def _do_execute(self, func, instance, args, kwargs):
169190
exception = None
170191
params = getattr(instance, "_params", None)
@@ -243,3 +264,32 @@ async def _do_cursor_execute(self, func, instance, args, kwargs):
243264
if not stop:
244265
return result
245266
raise StopAsyncIteration
267+
268+
async def _do_prepared_execute(self, func, instance, args, kwargs):
269+
exception = None
270+
query = instance._query or ""
271+
272+
try:
273+
name = self._leading_comment_remover.sub("", query).split()[0]
274+
except IndexError:
275+
name = ""
276+
277+
span_attributes = _hydrate_span_from_args(
278+
instance._connection,
279+
query,
280+
args if self.capture_parameters else None,
281+
)
282+
283+
with self._tracer.start_as_current_span(
284+
name, kind=SpanKind.CLIENT, attributes=span_attributes
285+
) as span:
286+
try:
287+
result = await func(*args, **kwargs)
288+
except Exception as exc: # pylint: disable=W0703
289+
exception = exc
290+
raise
291+
finally:
292+
if span.is_recording() and exception is not None:
293+
span.set_status(Status(StatusCode.ERROR))
294+
295+
return result

instrumentation/opentelemetry-instrumentation-asyncpg/tests/test_asyncpg_wrapper.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55
from asyncpg import Connection, Record, cursor
6+
from asyncpg.prepared_stmt import PreparedStatement
67

78
try:
89
# wrapt 2.0.0+
@@ -11,7 +12,10 @@
1112
from wrapt import ObjectProxy as BaseObjectProxy
1213

1314
from opentelemetry import trace as trace_api
14-
from opentelemetry.instrumentation.asyncpg import AsyncPGInstrumentor
15+
from opentelemetry.instrumentation.asyncpg import (
16+
_PREPARED_STMT_METHODS,
17+
AsyncPGInstrumentor,
18+
)
1519
from opentelemetry.test.test_base import TestBase
1620

1721

@@ -144,3 +148,121 @@ async def exec_mock(*args, **kwargs):
144148

145149
spans = self.memory_exporter.get_finished_spans()
146150
self.assertEqual(len(spans), 0)
151+
152+
def test_prepared_statement_instrumentation(self):
153+
methods = [
154+
m for m in _PREPARED_STMT_METHODS if hasattr(PreparedStatement, m)
155+
]
156+
157+
for method_name in methods:
158+
with self.subTest(method=method_name, phase="before"):
159+
self.assertFalse(
160+
isinstance(
161+
getattr(PreparedStatement, method_name),
162+
BaseObjectProxy,
163+
)
164+
)
165+
166+
AsyncPGInstrumentor().instrument()
167+
168+
for method_name in methods:
169+
with self.subTest(method=method_name, phase="instrumented"):
170+
self.assertTrue(
171+
isinstance(
172+
getattr(PreparedStatement, method_name),
173+
BaseObjectProxy,
174+
)
175+
)
176+
177+
AsyncPGInstrumentor().uninstrument()
178+
179+
for method_name in methods:
180+
with self.subTest(method=method_name, phase="uninstrumented"):
181+
self.assertFalse(
182+
isinstance(
183+
getattr(PreparedStatement, method_name),
184+
BaseObjectProxy,
185+
)
186+
)
187+
188+
def _make_prepared_stmt_conn(self):
189+
async def bind_execute_mock(*args, **kwargs):
190+
return [], b"SELECT 1", True
191+
192+
async def bind_execute_many_mock(*args, **kwargs):
193+
return None
194+
195+
conn = mock.Mock()
196+
conn._pool_release_ctr = 0
197+
conn.is_closed = lambda: False
198+
conn._protocol = mock.Mock()
199+
conn._protocol.bind_execute = bind_execute_mock
200+
conn._protocol.bind_execute_many = bind_execute_many_mock
201+
202+
state = mock.Mock()
203+
state.closed = False
204+
return conn, state
205+
206+
def test_prepared_statement_span(self):
207+
# Per-method: (query, call_args, expected_span_name)
208+
method_cases = {
209+
"fetch": ("SELECT * FROM users", (), "SELECT"),
210+
"fetchval": ("SELECT id FROM users WHERE id=$1", (1,), "SELECT"),
211+
"fetchrow": ("SELECT * FROM t WHERE v=$1", ("x",), "SELECT"),
212+
"executemany": (
213+
"INSERT INTO t (v) VALUES ($1)",
214+
([("a",), ("b",)],),
215+
"INSERT",
216+
),
217+
"fetchmany": ("SELECT * FROM t", ([],), "SELECT"),
218+
}
219+
220+
for method_name in _PREPARED_STMT_METHODS:
221+
if not hasattr(PreparedStatement, method_name):
222+
continue
223+
query, call_args, expected_name = method_cases[method_name]
224+
with self.subTest(method=method_name):
225+
self.memory_exporter.clear()
226+
conn, state = self._make_prepared_stmt_conn()
227+
apg = AsyncPGInstrumentor()
228+
apg.instrument(tracer_provider=self.tracer_provider)
229+
230+
stmt = PreparedStatement(conn, query, state)
231+
asyncio.run(getattr(stmt, method_name)(*call_args))
232+
233+
spans = self.memory_exporter.get_finished_spans()
234+
self.assertEqual(len(spans), 1)
235+
self.assertEqual(spans[0].name, expected_name)
236+
self.assertTrue(spans[0].status.is_ok)
237+
self.assertEqual(
238+
spans[0].attributes.get("db.statement"), query
239+
)
240+
self.assertEqual(
241+
spans[0].attributes.get("db.system"), "postgresql"
242+
)
243+
244+
apg.uninstrument()
245+
246+
def test_prepared_statement_error_span(self):
247+
async def bind_execute_error(*args, **kwargs):
248+
raise RuntimeError("db error")
249+
250+
conn = mock.Mock()
251+
conn._pool_release_ctr = 0
252+
conn.is_closed = lambda: False
253+
conn._protocol = mock.Mock()
254+
conn._protocol.bind_execute = bind_execute_error
255+
256+
state = mock.Mock()
257+
state.closed = False
258+
259+
apg = AsyncPGInstrumentor()
260+
apg.instrument(tracer_provider=self.tracer_provider)
261+
262+
stmt = PreparedStatement(conn, "SELECT 1", state)
263+
with self.assertRaises(RuntimeError):
264+
asyncio.run(stmt.fetch())
265+
266+
spans = self.memory_exporter.get_finished_spans()
267+
self.assertEqual(len(spans), 1)
268+
self.assertFalse(spans[0].status.is_ok)

0 commit comments

Comments
 (0)