Skip to content

Commit 1fc33ff

Browse files
committed
fix(asyncpg): instrument prepared statements
Assisted-by: Claude Sonnet 4.6
1 parent d9d586c commit 1fc33ff

3 files changed

Lines changed: 179 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
5858
([#4049](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4049))
5959
- `opentelemetry-instrumentation-sqlalchemy`: implement new semantic convention opt-in migration
6060
([#4110](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4110))
61+
- `opentelemetry-instrumentation-asyncpg`: instrument prepared statements
62+
([#4529](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4529))
6163

6264
### Fixed
6365

instrumentation/opentelemetry-instrumentation-asyncpg/src/opentelemetry/instrumentation/asyncpg/__init__.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ async def main():
5353
from typing import Collection
5454

5555
import asyncpg
56+
import asyncpg.prepared_stmt
5657
import wrapt
5758

5859
from opentelemetry import trace
@@ -76,6 +77,14 @@ async def main():
7677
from opentelemetry.trace import SpanKind
7778
from opentelemetry.trace.status import Status, StatusCode
7879

80+
_PREPARED_STMT_METHODS = (
81+
"fetch",
82+
"fetchval",
83+
"fetchrow",
84+
"executemany",
85+
"fetchmany",
86+
)
87+
7988

8089
def _hydrate_span_from_args(connection, query, parameters) -> dict:
8190
"""Get network and database attributes from connection."""
@@ -153,6 +162,14 @@ def _instrument(self, **kwargs):
153162
"asyncpg.cursor", method, self._do_cursor_execute
154163
)
155164

165+
for method in _PREPARED_STMT_METHODS:
166+
if hasattr(asyncpg.prepared_stmt.PreparedStatement, method):
167+
wrapt.wrap_function_wrapper(
168+
"asyncpg.prepared_stmt",
169+
f"PreparedStatement.{method}",
170+
self._do_prepared_execute,
171+
)
172+
156173
def _uninstrument(self, **__):
157174
for cls, methods in [
158175
(
@@ -165,6 +182,10 @@ def _uninstrument(self, **__):
165182
for method_name in methods:
166183
unwrap(cls, method_name)
167184

185+
for method_name in _PREPARED_STMT_METHODS:
186+
if hasattr(asyncpg.prepared_stmt.PreparedStatement, method_name):
187+
unwrap(asyncpg.prepared_stmt.PreparedStatement, method_name)
188+
168189
async def _do_execute(self, func, instance, args, kwargs):
169190
exception = None
170191
params = getattr(instance, "_params", None)
@@ -243,3 +264,32 @@ async def _do_cursor_execute(self, func, instance, args, kwargs):
243264
if not stop:
244265
return result
245266
raise StopAsyncIteration
267+
268+
async def _do_prepared_execute(self, func, instance, args, kwargs):
269+
exception = None
270+
query = instance._query or ""
271+
272+
try:
273+
name = self._leading_comment_remover.sub("", query).split()[0]
274+
except IndexError:
275+
name = ""
276+
277+
span_attributes = _hydrate_span_from_args(
278+
instance._connection,
279+
query,
280+
args if self.capture_parameters else None,
281+
)
282+
283+
with self._tracer.start_as_current_span(
284+
name, kind=SpanKind.CLIENT, attributes=span_attributes
285+
) as span:
286+
try:
287+
result = await func(*args, **kwargs)
288+
except Exception as exc: # pylint: disable=W0703
289+
exception = exc
290+
raise
291+
finally:
292+
if span.is_recording() and exception is not None:
293+
span.set_status(Status(StatusCode.ERROR))
294+
295+
return result

instrumentation/opentelemetry-instrumentation-asyncpg/tests/test_asyncpg_wrapper.py

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55
from asyncpg import Connection, Record, cursor
6+
from asyncpg.prepared_stmt import PreparedStatement
67

78
try:
89
# wrapt 2.0.0+
@@ -11,11 +12,18 @@
1112
from wrapt import ObjectProxy as BaseObjectProxy
1213

1314
from opentelemetry import trace as trace_api
14-
from opentelemetry.instrumentation.asyncpg import AsyncPGInstrumentor
15+
from opentelemetry.instrumentation.asyncpg import (
16+
_PREPARED_STMT_METHODS,
17+
AsyncPGInstrumentor,
18+
)
1519
from opentelemetry.test.test_base import TestBase
1620

1721

1822
class TestAsyncPGInstrumentation(TestBase):
23+
def tearDown(self):
24+
super().tearDown()
25+
AsyncPGInstrumentor().uninstrument()
26+
1927
def test_duplicated_instrumentation_can_be_uninstrumented(self):
2028
AsyncPGInstrumentor().instrument()
2129
AsyncPGInstrumentor().instrument()
@@ -144,3 +152,121 @@ async def exec_mock(*args, **kwargs):
144152

145153
spans = self.memory_exporter.get_finished_spans()
146154
self.assertEqual(len(spans), 0)
155+
156+
def test_prepared_statement_instrumentation(self):
157+
methods = [
158+
m for m in _PREPARED_STMT_METHODS if hasattr(PreparedStatement, m)
159+
]
160+
161+
for method_name in methods:
162+
with self.subTest(method=method_name, phase="before"):
163+
self.assertFalse(
164+
isinstance(
165+
getattr(PreparedStatement, method_name),
166+
BaseObjectProxy,
167+
)
168+
)
169+
170+
AsyncPGInstrumentor().instrument()
171+
172+
for method_name in methods:
173+
with self.subTest(method=method_name, phase="instrumented"):
174+
self.assertTrue(
175+
isinstance(
176+
getattr(PreparedStatement, method_name),
177+
BaseObjectProxy,
178+
)
179+
)
180+
181+
AsyncPGInstrumentor().uninstrument()
182+
183+
for method_name in methods:
184+
with self.subTest(method=method_name, phase="uninstrumented"):
185+
self.assertFalse(
186+
isinstance(
187+
getattr(PreparedStatement, method_name),
188+
BaseObjectProxy,
189+
)
190+
)
191+
192+
def _make_prepared_stmt_conn(self):
193+
async def bind_execute_mock(*args, **kwargs):
194+
return [], b"SELECT 1", True
195+
196+
async def bind_execute_many_mock(*args, **kwargs):
197+
return None
198+
199+
conn = mock.Mock()
200+
conn._pool_release_ctr = 0
201+
conn.is_closed = lambda: False
202+
conn._protocol = mock.Mock()
203+
conn._protocol.bind_execute = bind_execute_mock
204+
conn._protocol.bind_execute_many = bind_execute_many_mock
205+
206+
state = mock.Mock()
207+
state.closed = False
208+
return conn, state
209+
210+
def test_prepared_statement_span(self):
211+
# Per-method: (query, call_args, expected_span_name)
212+
method_cases = {
213+
"fetch": ("SELECT * FROM users", (), "SELECT"),
214+
"fetchval": ("SELECT id FROM users WHERE id=$1", (1,), "SELECT"),
215+
"fetchrow": ("SELECT * FROM t WHERE v=$1", ("x",), "SELECT"),
216+
"executemany": (
217+
"INSERT INTO t (v) VALUES ($1)",
218+
([("a",), ("b",)],),
219+
"INSERT",
220+
),
221+
"fetchmany": ("SELECT * FROM t", ([],), "SELECT"),
222+
}
223+
224+
for method_name in _PREPARED_STMT_METHODS:
225+
if not hasattr(PreparedStatement, method_name):
226+
continue
227+
query, call_args, expected_name = method_cases[method_name]
228+
with self.subTest(method=method_name):
229+
self.memory_exporter.clear()
230+
conn, state = self._make_prepared_stmt_conn()
231+
apg = AsyncPGInstrumentor()
232+
apg.instrument(tracer_provider=self.tracer_provider)
233+
234+
stmt = PreparedStatement(conn, query, state)
235+
asyncio.run(getattr(stmt, method_name)(*call_args))
236+
237+
spans = self.memory_exporter.get_finished_spans()
238+
self.assertEqual(len(spans), 1)
239+
self.assertEqual(spans[0].name, expected_name)
240+
self.assertTrue(spans[0].status.is_ok)
241+
self.assertEqual(
242+
spans[0].attributes.get("db.statement"), query
243+
)
244+
self.assertEqual(
245+
spans[0].attributes.get("db.system"), "postgresql"
246+
)
247+
248+
apg.uninstrument()
249+
250+
def test_prepared_statement_error_span(self):
251+
async def bind_execute_error(*args, **kwargs):
252+
raise RuntimeError("db error")
253+
254+
conn = mock.Mock()
255+
conn._pool_release_ctr = 0
256+
conn.is_closed = lambda: False
257+
conn._protocol = mock.Mock()
258+
conn._protocol.bind_execute = bind_execute_error
259+
260+
state = mock.Mock()
261+
state.closed = False
262+
263+
apg = AsyncPGInstrumentor()
264+
apg.instrument(tracer_provider=self.tracer_provider)
265+
266+
stmt = PreparedStatement(conn, "SELECT 1", state)
267+
with self.assertRaises(RuntimeError):
268+
asyncio.run(stmt.fetch())
269+
270+
spans = self.memory_exporter.get_finished_spans()
271+
self.assertEqual(len(spans), 1)
272+
self.assertFalse(spans[0].status.is_ok)

0 commit comments

Comments
 (0)