Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#4049](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4049))
- `opentelemetry-instrumentation-sqlalchemy`: implement new semantic convention opt-in migration
([#4110](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4110))
- `opentelemetry-instrumentation-asyncpg`: instrument prepared statements
([#4529](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4529))

### Fixed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ async def main():
from typing import Collection

import asyncpg
import asyncpg.prepared_stmt
import wrapt

from opentelemetry import trace
Expand All @@ -76,6 +77,14 @@ async def main():
from opentelemetry.trace import SpanKind
from opentelemetry.trace.status import Status, StatusCode

_PREPARED_STMT_METHODS = (
"fetch",
"fetchval",
"fetchrow",
"executemany",
"fetchmany",
)


def _hydrate_span_from_args(connection, query, parameters) -> dict:
"""Get network and database attributes from connection."""
Expand Down Expand Up @@ -153,6 +162,14 @@ def _instrument(self, **kwargs):
"asyncpg.cursor", method, self._do_cursor_execute
)

for method in _PREPARED_STMT_METHODS:
if hasattr(asyncpg.prepared_stmt.PreparedStatement, method):
wrapt.wrap_function_wrapper(
"asyncpg.prepared_stmt",
f"PreparedStatement.{method}",
self._do_prepared_execute,
)

def _uninstrument(self, **__):
for cls, methods in [
(
Expand All @@ -165,6 +182,10 @@ def _uninstrument(self, **__):
for method_name in methods:
unwrap(cls, method_name)

for method_name in _PREPARED_STMT_METHODS:
if hasattr(asyncpg.prepared_stmt.PreparedStatement, method_name):
unwrap(asyncpg.prepared_stmt.PreparedStatement, method_name)

async def _do_execute(self, func, instance, args, kwargs):
exception = None
params = getattr(instance, "_params", None)
Expand Down Expand Up @@ -243,3 +264,32 @@ async def _do_cursor_execute(self, func, instance, args, kwargs):
if not stop:
return result
raise StopAsyncIteration

async def _do_prepared_execute(self, func, instance, args, kwargs):
exception = None
query = instance._query or ""

try:
name = self._leading_comment_remover.sub("", query).split()[0]
except IndexError:
name = ""

span_attributes = _hydrate_span_from_args(
instance._connection,
query,
args if self.capture_parameters else None,
)

with self._tracer.start_as_current_span(
name, kind=SpanKind.CLIENT, attributes=span_attributes
) as span:
try:
result = await func(*args, **kwargs)
except Exception as exc: # pylint: disable=W0703
exception = exc
raise
finally:
if span.is_recording() and exception is not None:
span.set_status(Status(StatusCode.ERROR))

return result
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from asyncpg import Connection, Record, cursor
from asyncpg.prepared_stmt import PreparedStatement

try:
# wrapt 2.0.0+
Expand All @@ -11,11 +12,18 @@
from wrapt import ObjectProxy as BaseObjectProxy

from opentelemetry import trace as trace_api
from opentelemetry.instrumentation.asyncpg import AsyncPGInstrumentor
from opentelemetry.instrumentation.asyncpg import (
_PREPARED_STMT_METHODS,
AsyncPGInstrumentor,
)
from opentelemetry.test.test_base import TestBase


class TestAsyncPGInstrumentation(TestBase):
def tearDown(self):
super().tearDown()
AsyncPGInstrumentor().uninstrument()

def test_duplicated_instrumentation_can_be_uninstrumented(self):
AsyncPGInstrumentor().instrument()
AsyncPGInstrumentor().instrument()
Expand Down Expand Up @@ -144,3 +152,122 @@ async def exec_mock(*args, **kwargs):

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 0)

def test_prepared_statement_instrumentation(self):
methods = [
m for m in _PREPARED_STMT_METHODS if hasattr(PreparedStatement, m)
]

for method_name in methods:
with self.subTest(method=method_name, phase="before"):
self.assertFalse(
isinstance(
getattr(PreparedStatement, method_name),
BaseObjectProxy,
)
)

AsyncPGInstrumentor().instrument()

for method_name in methods:
with self.subTest(method=method_name, phase="instrumented"):
self.assertTrue(
isinstance(
getattr(PreparedStatement, method_name),
BaseObjectProxy,
)
)

AsyncPGInstrumentor().uninstrument()

for method_name in methods:
with self.subTest(method=method_name, phase="uninstrumented"):
self.assertFalse(
isinstance(
getattr(PreparedStatement, method_name),
BaseObjectProxy,
)
)

@staticmethod
def _make_prepared_stmt_conn():
async def bind_execute_mock(*args, **kwargs):
return [], b"SELECT 1", True

async def bind_execute_many_mock(*args, **kwargs):
return None

conn = mock.Mock()
conn._pool_release_ctr = 0
conn.is_closed = lambda: False
conn._protocol = mock.Mock()
conn._protocol.bind_execute = bind_execute_mock
conn._protocol.bind_execute_many = bind_execute_many_mock

state = mock.Mock()
state.closed = False
return conn, state

def test_prepared_statement_span(self):
# Per-method: (query, call_args, expected_span_name)
method_cases = {
"fetch": ("SELECT * FROM users", (), "SELECT"),
"fetchval": ("SELECT id FROM users WHERE id=$1", (1,), "SELECT"),
"fetchrow": ("SELECT * FROM t WHERE v=$1", ("x",), "SELECT"),
"executemany": (
"INSERT INTO t (v) VALUES ($1)",
([("a",), ("b",)],),
"INSERT",
),
"fetchmany": ("SELECT * FROM t", ([],), "SELECT"),
}

for method_name in _PREPARED_STMT_METHODS:
if not hasattr(PreparedStatement, method_name):
continue
query, call_args, expected_name = method_cases[method_name]
with self.subTest(method=method_name):
self.memory_exporter.clear()
conn, state = self._make_prepared_stmt_conn()
apg = AsyncPGInstrumentor()
apg.instrument(tracer_provider=self.tracer_provider)

stmt = PreparedStatement(conn, query, state)
asyncio.run(getattr(stmt, method_name)(*call_args))

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].name, expected_name)
self.assertTrue(spans[0].status.is_ok)
self.assertEqual(
spans[0].attributes.get("db.statement"), query
)
self.assertEqual(
spans[0].attributes.get("db.system"), "postgresql"
)

apg.uninstrument()

def test_prepared_statement_error_span(self):
async def bind_execute_error(*args, **kwargs):
raise RuntimeError("db error")

conn = mock.Mock()
conn._pool_release_ctr = 0
conn.is_closed = lambda: False
conn._protocol = mock.Mock()
conn._protocol.bind_execute = bind_execute_error

state = mock.Mock()
state.closed = False

apg = AsyncPGInstrumentor()
apg.instrument(tracer_provider=self.tracer_provider)

stmt = PreparedStatement(conn, "SELECT 1", state)
with self.assertRaises(RuntimeError):
asyncio.run(stmt.fetch())

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertFalse(spans[0].status.is_ok)