Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion drift/instrumentation/django/e2e-tests/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,8 @@ services:
- PYTHONUNBUFFERED=1
- DJANGO_SETTINGS_MODULE=settings
working_dir: /app

volumes:
# Mount app source for development
- ./src:/app/src
# Mount .tusk folder to persist traces
- ./.tusk:/app/.tusk
5 changes: 5 additions & 0 deletions drift/instrumentation/fastapi/e2e-tests/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,9 @@ services:
- TUSK_ANALYTICS_DISABLED=1
- PYTHONUNBUFFERED=1
working_dir: /app
volumes:
# Mount app source for development
- ./src:/app/src
# Mount .tusk folder to persist traces
- ./.tusk:/app/.tusk

5 changes: 5 additions & 0 deletions drift/instrumentation/flask/e2e-tests/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@ services:
- TUSK_ANALYTICS_DISABLED=1
- PYTHONUNBUFFERED=1
working_dir: /app
volumes:
# Mount app source for development
- ./src:/app/src
# Mount .tusk folder to persist traces
- ./.tusk:/app/.tusk
5 changes: 5 additions & 0 deletions drift/instrumentation/psycopg/e2e-tests/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@ services:
- TUSK_ANALYTICS_DISABLED=1
- PYTHONUNBUFFERED=1
working_dir: /app
volumes:
# Mount app source for development
- ./src:/app/src
# Mount .tusk folder to persist traces
- ./.tusk:/app/.tusk
6 changes: 5 additions & 1 deletion drift/instrumentation/psycopg/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
replay_trace_id_context,
)
from ..base import InstrumentationBase
from ..utils.psycopg_utils import deserialize_db_value

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -453,11 +454,12 @@ def _traced_executemany(

# For all other queries (pre-app-start OR within a request trace), get mock
# Convert params_seq to list for serialization
# Wrap in {"_batch": ...} to match the recording format
params_list = list(params_seq)
mock_result = self._try_get_mock(
sdk,
query_str,
params_list,
{"_batch": params_list},
trace_id,
span_id,
parent_span_id,
Expand Down Expand Up @@ -625,6 +627,8 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non
pass

mock_rows = actual_data.get("rows", [])
# Deserialize datetime strings back to datetime objects for consistent Flask serialization
mock_rows = [deserialize_db_value(row) for row in mock_rows]
cursor._mock_rows = mock_rows # pyright: ignore[reportAttributeAccessIssue]
cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue]

Expand Down
5 changes: 5 additions & 0 deletions drift/instrumentation/psycopg2/e2e-tests/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@ services:
- TUSK_ANALYTICS_DISABLED=1
- PYTHONUNBUFFERED=1
working_dir: /app
volumes:
# Mount app source for development
- ./src:/app/src
# Mount .tusk folder to persist traces
- ./.tusk:/app/.tusk
106 changes: 87 additions & 19 deletions drift/instrumentation/psycopg2/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
replay_trace_id_context,
)
from ..base import InstrumentationBase
from ..utils.psycopg_utils import deserialize_db_value

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -198,6 +199,44 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return False


class InstrumentedConnection:
    """Proxy around a real psycopg2 connection that instruments every cursor.

    ``cursor()`` is intercepted so that the cursor factory in effect (the
    caller's own ``cursor_factory``, or the default when none is given) is
    always wrapped by the instrumentation, even when the factory is passed to
    ``cursor()`` instead of ``connect()``. Every other attribute read or write
    is forwarded to the wrapped connection.
    """

    # Positional parameters psycopg2's ``connection.cursor`` accepts after
    # ``name`` and ``cursor_factory``, in order.
    _EXTRA_CURSOR_PARAMS = ("scrollable", "withhold")

    def __init__(self, connection: Any, instrumentation: Psycopg2Instrumentation, sdk: TuskDrift) -> None:
        """Store the wrapped connection and what is needed to build cursor factories.

        Args:
            connection: The real connection object to delegate to.
            instrumentation: Object providing ``_create_cursor_factory``.
            sdk: SDK instance handed to ``_create_cursor_factory``.
        """
        # __setattr__ is overridden to forward to the wrapped connection, so the
        # proxy's own state has to be written through object.__setattr__.
        object.__setattr__(self, "_connection", connection)
        object.__setattr__(self, "_instrumentation", instrumentation)
        object.__setattr__(self, "_sdk", sdk)

    def cursor(self, name: str | None = None, cursor_factory: Any = None, *args: Any, **kwargs: Any) -> Any:
        """Create a cursor whose factory is wrapped by the instrumentation.

        Args:
            name: Optional cursor name, forwarded unchanged.
            cursor_factory: Optional user cursor class; it becomes the base of
                the instrumented factory (``None`` selects the default base).
            *args: Further positional arguments, interpreted as psycopg2's
                ``scrollable`` and ``withhold`` in that order.
            **kwargs: Further keyword arguments, forwarded unchanged.

        Returns:
            Whatever the wrapped connection's ``cursor()`` returns.

        Raises:
            TypeError: If more positional arguments are given than psycopg2's
                ``cursor()`` accepts, or one is also supplied by keyword.
        """
        wrapped_factory = self._instrumentation._create_cursor_factory(
            self._sdk,
            cursor_factory,  # This becomes the base class (None uses default)
        )

        # Extra positionals cannot simply be forwarded as ``*args`` next to the
        # ``name=`` / ``cursor_factory=`` keywords: the first one would be bound
        # to ``name`` and the call would fail with "multiple values". Translate
        # them to their keyword names instead, so name and cursor_factory can
        # still be passed by keyword.
        extra_names = self._EXTRA_CURSOR_PARAMS
        if len(args) > len(extra_names):
            raise TypeError(
                f"cursor() takes at most {len(extra_names) + 2} positional arguments ({len(args) + 2} given)"
            )
        for param, value in zip(extra_names, args):
            if param in kwargs:
                raise TypeError(f"cursor() got multiple values for argument '{param}'")
            kwargs[param] = value

        return self._connection.cursor(name=name, cursor_factory=wrapped_factory, **kwargs)

    def __getattr__(self, name: str) -> Any:
        """Proxy all other methods/attributes to the real connection."""
        # Fetch the wrapped connection without going back through __getattr__:
        # on an instance whose __init__ never ran (copy, unpickle, __new__) a
        # plain ``self._connection`` would recurse until RecursionError; this
        # raises a normal AttributeError instead.
        connection = object.__getattribute__(self, "_connection")
        return getattr(connection, name)

    def __setattr__(self, name: str, value: Any) -> None:
        """Proxy attribute setting to the real connection."""
        setattr(self._connection, name, value)

    def __enter__(self) -> InstrumentedConnection:
        """Enter the wrapped connection's context and return this proxy."""
        self._connection.__enter__()
        # Return the proxy, not the raw connection, so cursors created inside
        # the ``with`` block are still instrumented.
        return self

    def __exit__(self, *args: Any) -> Any:
        """Delegate context exit (and its return value) to the wrapped connection."""
        return self._connection.__exit__(*args)


def _query_to_str(query: QueryType) -> str:
"""Convert a query (str, bytes, or Composable) to a string."""
if isinstance(query, str):
Expand Down Expand Up @@ -287,35 +326,28 @@ def patched_connect(*args, **kwargs):
logger.debug("[PATCHED_CONNECT] SDK disabled, passing through")
return original_connect(*args, **kwargs)

# Use cursor_factory to wrap cursors
# Save any user-provided cursor_factory
user_cursor_factory = kwargs.pop("cursor_factory", None)

# Create our instrumented cursor factory
cursor_factory = instrumentation._create_cursor_factory(sdk, user_cursor_factory)

# In REPLAY mode, try to connect but fall back to mock connection if DB is unavailable
if sdk.mode == TuskDriftMode.REPLAY:
try:
kwargs["cursor_factory"] = cursor_factory
logger.debug("[PATCHED_CONNECT] REPLAY mode: Attempting real DB connection...")
connection = original_connect(*args, **kwargs)
logger.info("[PATCHED_CONNECT] REPLAY mode: Successfully connected to real database")
return connection
# Wrap connection to intercept cursor() calls
return InstrumentedConnection(connection, instrumentation, sdk)
except Exception as e:
logger.info(
f"[PATCHED_CONNECT] REPLAY mode: Database connection failed ({e}), using mock connection"
)
# Return mock connection that doesn't require a real database
return MockConnection(sdk, instrumentation, cursor_factory)
# MockConnection already handles cursor_factory correctly in its cursor() method
return MockConnection(sdk, instrumentation, None)

# In RECORD mode, always require real connection
kwargs["cursor_factory"] = cursor_factory
logger.debug("[PATCHED_CONNECT] RECORD mode: Connecting to database...")
connection = original_connect(*args, **kwargs)
logger.info("[PATCHED_CONNECT] RECORD mode: Connected to database successfully")

return connection
# Wrap connection to intercept cursor() calls
return InstrumentedConnection(connection, instrumentation, sdk)

# Apply patch
module.connect = patched_connect # type: ignore[attr-defined]
Expand Down Expand Up @@ -573,11 +605,12 @@ def _traced_executemany(
return None

# For all other queries (pre-app-start OR within a request trace), get mock
# Wrap in {"_batch": ...} to match the recording format
is_pre_app_start = not sdk.app_ready
mock_result = self._try_get_mock(
sdk,
query,
params_list,
{"_batch": params_list},
Comment thread
sohankshirsagar marked this conversation as resolved.
trace_id,
span_id,
parent_span_id,
Expand Down Expand Up @@ -803,6 +836,24 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non

# Store mock rows for fetching
mock_rows = actual_data.get("rows", [])
# Deserialize datetime strings back to datetime objects for consistent Flask/Django serialization
mock_rows = [deserialize_db_value(row) for row in mock_rows]
Comment thread
sohankshirsagar marked this conversation as resolved.

# Check if this is a dict-cursor (like RealDictCursor) by checking if cursor class
# inherits from a dict-returning cursor type
is_dict_cursor = False
try:
import psycopg2.extras

is_dict_cursor = isinstance(cursor, (psycopg2.extras.RealDictCursor, psycopg2.extras.DictCursor))
except (ImportError, AttributeError):
pass

# If it's a dict cursor and we have description, convert rows to dicts
if is_dict_cursor and description_data:
column_names = [col["name"] for col in description_data]
mock_rows = [dict(zip(column_names, row, strict=True)) for row in mock_rows]

cursor._mock_rows = mock_rows # pyright: ignore[reportAttributeAccessIssue]
cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue]

Expand All @@ -815,7 +866,9 @@ def mock_fetchone():
if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue]
row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue]
cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue]
# Convert list to tuple to match psycopg2 behavior
# Return as-is for dict cursors, convert to tuple for regular cursors
if isinstance(row, dict):
return row
return tuple(row) if isinstance(row, list) else row
return None

Expand All @@ -832,8 +885,15 @@ def mock_fetchall():
logger.debug(f"[MOCK] fetchall called, returning {len(cursor._mock_rows[cursor._mock_index :])} rows") # pyright: ignore[reportAttributeAccessIssue]
rows = cursor._mock_rows[cursor._mock_index :] # pyright: ignore[reportAttributeAccessIssue]
cursor._mock_index = len(cursor._mock_rows) # pyright: ignore[reportAttributeAccessIssue]
# Convert lists to tuples to match psycopg2 behavior
result = [tuple(row) if isinstance(row, list) else row for row in rows]
# Return as-is for dict rows, convert lists to tuples for regular cursors
result = []
for row in rows:
if isinstance(row, dict):
result.append(row)
elif isinstance(row, list):
result.append(tuple(row))
else:
result.append(row)
logger.debug(f"[MOCK] fetchall returning: {result}")
return result

Expand Down Expand Up @@ -906,8 +966,16 @@ def serialize_value(val):
# We need to capture these for replay mode
try:
all_rows = cursor.fetchall()
# Convert tuples to lists for JSON serialization
rows = [list(row) for row in all_rows]
# Convert rows to lists for JSON serialization
# Handle both tuple rows (regular cursor) and dict rows (RealDictCursor)
rows = []
for row in all_rows:
if isinstance(row, dict):
# RealDictCursor returns dict-like rows - extract values in column order
rows.append([row[desc[0]] for desc in cursor.description])
else:
# Regular cursor returns tuples
rows.append(list(row))

# CRITICAL: Re-populate cursor so user code can still fetch
# We'll store the rows and patch fetch methods
Expand Down
5 changes: 5 additions & 0 deletions drift/instrumentation/redis/e2e-tests/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,8 @@ services:
- TUSK_ANALYTICS_DISABLED=1
- PYTHONUNBUFFERED=1
working_dir: /app
volumes:
# Mount app source for development
- ./src:/app/src
# Mount .tusk folder to persist traces
- ./.tusk:/app/.tusk
34 changes: 34 additions & 0 deletions drift/instrumentation/utils/psycopg_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Shared utilities for psycopg, psycopg2"""

from __future__ import annotations

import datetime as dt
from typing import Any


def deserialize_db_value(val: Any) -> Any:
    """Convert ISO datetime strings back to datetime objects for consistent serialization.

    During recording, datetime objects from the database are serialized to ISO format strings.
    During replay, we need to convert them back to datetime objects so that Flask/Django
    serializes them the same way (e.g., RFC 2822 vs ISO 8601 format).

    Only strings in the extended ``YYYY-MM-DD...`` layout (the layout
    ``isoformat()`` produces) are considered. Other layouts are returned
    untouched, so that plain text such as ``"20240101"`` is not turned into a
    datetime on Python 3.11+, where ``fromisoformat`` accepts them.

    NOTE(review): a text column whose value happens to look like
    ``YYYY-MM-DD...`` is still converted; telling it apart from a real
    timestamp would need column type information that is not available here.

    Args:
        val: A value from the mocked database rows. Can be a string, list, dict, or any other type.

    Returns:
        The value with ISO datetime strings converted back to datetime objects.
        Lists and dicts are rebuilt recursively; anything else is returned as-is.
    """
    if isinstance(val, str):
        # Cheap shape check before parsing: four-digit year, then "-" at the
        # month and day boundaries. Rejects basic ("20240101") and week-date
        # ("2011-W01-2") layouts regardless of interpreter version.
        if len(val) < 10 or val[4] != "-" or val[7] != "-":
            return val
        # Python < 3.11 does not understand the "Z" UTC designator. Rewrite it
        # only when it is the suffix, never in the middle of the string.
        candidate = val[:-1] + "+00:00" if val.endswith("Z") else val
        try:
            return dt.datetime.fromisoformat(candidate)
        except ValueError:
            # Not a datetime after all; keep the original string.
            return val
    if isinstance(val, list):
        return [deserialize_db_value(item) for item in val]
    if isinstance(val, dict):
        return {key: deserialize_db_value(item) for key, item in val.items()}
    return val