Skip to content

Commit 70e6fe5

Browse files
fix: fix failing e2e tests (#17)
* fix some e2e tests * fix psycopg2 * remove comments * fix lint
1 parent 47c93a6 commit 70e6fe5

9 files changed

Lines changed: 156 additions & 21 deletions

File tree

drift/instrumentation/django/e2e-tests/docker-compose.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,8 @@ services:
1111
- PYTHONUNBUFFERED=1
1212
- DJANGO_SETTINGS_MODULE=settings
1313
working_dir: /app
14-
14+
volumes:
15+
# Mount app source for development
16+
- ./src:/app/src
17+
# Mount .tusk folder to persist traces
18+
- ./.tusk:/app/.tusk

drift/instrumentation/fastapi/e2e-tests/docker-compose.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,9 @@ services:
1010
- TUSK_ANALYTICS_DISABLED=1
1111
- PYTHONUNBUFFERED=1
1212
working_dir: /app
13+
volumes:
14+
# Mount app source for development
15+
- ./src:/app/src
16+
# Mount .tusk folder to persist traces
17+
- ./.tusk:/app/.tusk
1318

drift/instrumentation/flask/e2e-tests/docker-compose.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,8 @@ services:
1010
- TUSK_ANALYTICS_DISABLED=1
1111
- PYTHONUNBUFFERED=1
1212
working_dir: /app
13+
volumes:
14+
# Mount app source for development
15+
- ./src:/app/src
16+
# Mount .tusk folder to persist traces
17+
- ./.tusk:/app/.tusk

drift/instrumentation/psycopg/e2e-tests/docker-compose.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,8 @@ services:
3030
- TUSK_ANALYTICS_DISABLED=1
3131
- PYTHONUNBUFFERED=1
3232
working_dir: /app
33+
volumes:
34+
# Mount app source for development
35+
- ./src:/app/src
36+
# Mount .tusk folder to persist traces
37+
- ./.tusk:/app/.tusk

drift/instrumentation/psycopg/instrumentation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
replay_trace_id_context,
2929
)
3030
from ..base import InstrumentationBase
31+
from ..utils.psycopg_utils import deserialize_db_value
3132

3233
logger = logging.getLogger(__name__)
3334

@@ -453,11 +454,12 @@ def _traced_executemany(
453454

454455
# For all other queries (pre-app-start OR within a request trace), get mock
455456
# Convert params_seq to list for serialization
457+
# Wrap in {"_batch": ...} to match the recording format
456458
params_list = list(params_seq)
457459
mock_result = self._try_get_mock(
458460
sdk,
459461
query_str,
460-
params_list,
462+
{"_batch": params_list},
461463
trace_id,
462464
span_id,
463465
parent_span_id,
@@ -625,6 +627,8 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non
625627
pass
626628

627629
mock_rows = actual_data.get("rows", [])
630+
# Deserialize datetime strings back to datetime objects for consistent Flask serialization
631+
mock_rows = [deserialize_db_value(row) for row in mock_rows]
628632
cursor._mock_rows = mock_rows # pyright: ignore[reportAttributeAccessIssue]
629633
cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue]
630634

drift/instrumentation/psycopg2/e2e-tests/docker-compose.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,8 @@ services:
3030
- TUSK_ANALYTICS_DISABLED=1
3131
- PYTHONUNBUFFERED=1
3232
working_dir: /app
33+
volumes:
34+
# Mount app source for development
35+
- ./src:/app/src
36+
# Mount .tusk folder to persist traces
37+
- ./.tusk:/app/.tusk

drift/instrumentation/psycopg2/instrumentation.py

Lines changed: 87 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
replay_trace_id_context,
3737
)
3838
from ..base import InstrumentationBase
39+
from ..utils.psycopg_utils import deserialize_db_value
3940

4041
logger = logging.getLogger(__name__)
4142

@@ -198,6 +199,44 @@ def __exit__(self, exc_type, exc_val, exc_tb):
198199
return False
199200

200201

202+
class InstrumentedConnection:
203+
"""Wraps a real psycopg2 connection to intercept cursor() calls.
204+
205+
This ensures that even when users pass cursor_factory to cursor() instead of
206+
connect(), the cursor is still instrumented for tracing.
207+
"""
208+
209+
def __init__(self, connection: Any, instrumentation: Psycopg2Instrumentation, sdk: TuskDrift) -> None:
210+
# Use object.__setattr__ to avoid triggering __getattr__
211+
object.__setattr__(self, "_connection", connection)
212+
object.__setattr__(self, "_instrumentation", instrumentation)
213+
object.__setattr__(self, "_sdk", sdk)
214+
215+
def cursor(self, name: str | None = None, cursor_factory: Any = None, *args: Any, **kwargs: Any) -> Any:
216+
"""Intercept cursor creation to wrap user-provided cursor_factory."""
217+
# Create instrumented cursor factory (wrapping user's factory if provided)
218+
wrapped_factory = self._instrumentation._create_cursor_factory(
219+
self._sdk,
220+
cursor_factory, # This becomes the base class (None uses default)
221+
)
222+
return self._connection.cursor(*args, name=name, cursor_factory=wrapped_factory, **kwargs)
223+
224+
def __getattr__(self, name: str) -> Any:
225+
"""Proxy all other methods/attributes to the real connection."""
226+
return getattr(self._connection, name)
227+
228+
def __setattr__(self, name: str, value: Any) -> None:
229+
"""Proxy attribute setting to the real connection."""
230+
setattr(self._connection, name, value)
231+
232+
def __enter__(self) -> InstrumentedConnection:
233+
self._connection.__enter__()
234+
return self
235+
236+
def __exit__(self, *args: Any) -> Any:
237+
return self._connection.__exit__(*args)
238+
239+
201240
def _query_to_str(query: QueryType) -> str:
202241
"""Convert a query (str, bytes, or Composable) to a string."""
203242
if isinstance(query, str):
@@ -287,35 +326,28 @@ def patched_connect(*args, **kwargs):
287326
logger.debug("[PATCHED_CONNECT] SDK disabled, passing through")
288327
return original_connect(*args, **kwargs)
289328

290-
# Use cursor_factory to wrap cursors
291-
# Save any user-provided cursor_factory
292-
user_cursor_factory = kwargs.pop("cursor_factory", None)
293-
294-
# Create our instrumented cursor factory
295-
cursor_factory = instrumentation._create_cursor_factory(sdk, user_cursor_factory)
296-
297329
# In REPLAY mode, try to connect but fall back to mock connection if DB is unavailable
298330
if sdk.mode == TuskDriftMode.REPLAY:
299331
try:
300-
kwargs["cursor_factory"] = cursor_factory
301332
logger.debug("[PATCHED_CONNECT] REPLAY mode: Attempting real DB connection...")
302333
connection = original_connect(*args, **kwargs)
303334
logger.info("[PATCHED_CONNECT] REPLAY mode: Successfully connected to real database")
304-
return connection
335+
# Wrap connection to intercept cursor() calls
336+
return InstrumentedConnection(connection, instrumentation, sdk)
305337
except Exception as e:
306338
logger.info(
307339
f"[PATCHED_CONNECT] REPLAY mode: Database connection failed ({e}), using mock connection"
308340
)
309341
# Return mock connection that doesn't require a real database
310-
return MockConnection(sdk, instrumentation, cursor_factory)
342+
# MockConnection already handles cursor_factory correctly in its cursor() method
343+
return MockConnection(sdk, instrumentation, None)
311344

312345
# In RECORD mode, always require real connection
313-
kwargs["cursor_factory"] = cursor_factory
314346
logger.debug("[PATCHED_CONNECT] RECORD mode: Connecting to database...")
315347
connection = original_connect(*args, **kwargs)
316348
logger.info("[PATCHED_CONNECT] RECORD mode: Connected to database successfully")
317-
318-
return connection
349+
# Wrap connection to intercept cursor() calls
350+
return InstrumentedConnection(connection, instrumentation, sdk)
319351

320352
# Apply patch
321353
module.connect = patched_connect # type: ignore[attr-defined]
@@ -573,11 +605,12 @@ def _traced_executemany(
573605
return None
574606

575607
# For all other queries (pre-app-start OR within a request trace), get mock
608+
# Wrap in {"_batch": ...} to match the recording format
576609
is_pre_app_start = not sdk.app_ready
577610
mock_result = self._try_get_mock(
578611
sdk,
579612
query,
580-
params_list,
613+
{"_batch": params_list},
581614
trace_id,
582615
span_id,
583616
parent_span_id,
@@ -803,6 +836,24 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non
803836

804837
# Store mock rows for fetching
805838
mock_rows = actual_data.get("rows", [])
839+
# Deserialize datetime strings back to datetime objects for consistent Flask/Django serialization
840+
mock_rows = [deserialize_db_value(row) for row in mock_rows]
841+
842+
# Check if this is a dict-cursor (like RealDictCursor) by checking if cursor class
843+
# inherits from a dict-returning cursor type
844+
is_dict_cursor = False
845+
try:
846+
import psycopg2.extras
847+
848+
is_dict_cursor = isinstance(cursor, (psycopg2.extras.RealDictCursor, psycopg2.extras.DictCursor))
849+
except (ImportError, AttributeError):
850+
pass
851+
852+
# If it's a dict cursor and we have description, convert rows to dicts
853+
if is_dict_cursor and description_data:
854+
column_names = [col["name"] for col in description_data]
855+
mock_rows = [dict(zip(column_names, row, strict=True)) for row in mock_rows]
856+
806857
cursor._mock_rows = mock_rows # pyright: ignore[reportAttributeAccessIssue]
807858
cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue]
808859

@@ -815,7 +866,9 @@ def mock_fetchone():
815866
if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue]
816867
row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue]
817868
cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue]
818-
# Convert list to tuple to match psycopg2 behavior
869+
# Return as-is for dict cursors, convert to tuple for regular cursors
870+
if isinstance(row, dict):
871+
return row
819872
return tuple(row) if isinstance(row, list) else row
820873
return None
821874

@@ -832,8 +885,15 @@ def mock_fetchall():
832885
logger.debug(f"[MOCK] fetchall called, returning {len(cursor._mock_rows[cursor._mock_index :])} rows") # pyright: ignore[reportAttributeAccessIssue]
833886
rows = cursor._mock_rows[cursor._mock_index :] # pyright: ignore[reportAttributeAccessIssue]
834887
cursor._mock_index = len(cursor._mock_rows) # pyright: ignore[reportAttributeAccessIssue]
835-
# Convert lists to tuples to match psycopg2 behavior
836-
result = [tuple(row) if isinstance(row, list) else row for row in rows]
888+
# Return as-is for dict rows, convert lists to tuples for regular cursors
889+
result = []
890+
for row in rows:
891+
if isinstance(row, dict):
892+
result.append(row)
893+
elif isinstance(row, list):
894+
result.append(tuple(row))
895+
else:
896+
result.append(row)
837897
logger.debug(f"[MOCK] fetchall returning: {result}")
838898
return result
839899

@@ -906,8 +966,16 @@ def serialize_value(val):
906966
# We need to capture these for replay mode
907967
try:
908968
all_rows = cursor.fetchall()
909-
# Convert tuples to lists for JSON serialization
910-
rows = [list(row) for row in all_rows]
969+
# Convert rows to lists for JSON serialization
970+
# Handle both tuple rows (regular cursor) and dict rows (RealDictCursor)
971+
rows = []
972+
for row in all_rows:
973+
if isinstance(row, dict):
974+
# RealDictCursor returns dict-like rows - extract values in column order
975+
rows.append([row[desc[0]] for desc in cursor.description])
976+
else:
977+
# Regular cursor returns tuples
978+
rows.append(list(row))
911979

912980
# CRITICAL: Re-populate cursor so user code can still fetch
913981
# We'll store the rows and patch fetch methods

drift/instrumentation/redis/e2e-tests/docker-compose.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,8 @@ services:
2323
- TUSK_ANALYTICS_DISABLED=1
2424
- PYTHONUNBUFFERED=1
2525
working_dir: /app
26+
volumes:
27+
# Mount app source for development
28+
- ./src:/app/src
29+
# Mount .tusk folder to persist traces
30+
- ./.tusk:/app/.tusk
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Shared utilities for psycopg, psycopg2"""
2+
3+
from __future__ import annotations
4+
5+
import datetime as dt
6+
from typing import Any
7+
8+
9+
def deserialize_db_value(val: Any) -> Any:
10+
"""Convert ISO datetime strings back to datetime objects for consistent serialization.
11+
12+
During recording, datetime objects from the database are serialized to ISO format strings.
13+
During replay, we need to convert them back to datetime objects so that Flask/Django
14+
serializes them the same way (e.g., RFC 2822 vs ISO 8601 format).
15+
16+
Args:
17+
val: A value from the mocked database rows. Can be a string, list, dict, or any other type.
18+
19+
Returns:
20+
The value with ISO datetime strings converted back to datetime objects.
21+
"""
22+
if isinstance(val, str):
23+
# Try to parse as ISO datetime
24+
try:
25+
# Handle Z suffix for UTC
26+
parsed = dt.datetime.fromisoformat(val.replace("Z", "+00:00"))
27+
return parsed
28+
except ValueError:
29+
pass
30+
elif isinstance(val, list):
31+
return [deserialize_db_value(v) for v in val]
32+
elif isinstance(val, dict):
33+
return {k: deserialize_db_value(v) for k, v in val.items()}
34+
return val

0 commit comments

Comments
 (0)