Skip to content

Commit 00dcd88

Browse files
misc fixes for psycopg instrumentation
1 parent ec4c189 commit 00dcd88

3 files changed

Lines changed: 28 additions & 14 deletions

File tree

drift/instrumentation/psycopg/instrumentation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,9 +483,11 @@ def _traced_executemany(
483483
# RECORD mode: Execute real query and record span
484484
time.time()
485485
error = None
486+
# Convert to list BEFORE executing to avoid iterator exhaustion
487+
params_list = list(params_seq)
486488

487489
try:
488-
result = original_executemany(query, params_seq, **kwargs)
490+
result = original_executemany(query, params_list, **kwargs)
489491
return result
490492
except Exception as e:
491493
error = e
@@ -494,7 +496,6 @@ def _traced_executemany(
494496
# Always create span in RECORD mode (including pre-app-start queries)
495497
# Pre-app-start queries are marked with is_pre_app_start=true flag
496498
if sdk.mode == TuskDriftMode.RECORD:
497-
params_list = list(params_seq)
498499
self._finalize_query_span(
499500
span,
500501
cursor,

drift/instrumentation/psycopg2/instrumentation.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,13 +211,17 @@ def __init__(self, connection: Any, instrumentation: Psycopg2Instrumentation, sd
211211
object.__setattr__(self, "_connection", connection)
212212
object.__setattr__(self, "_instrumentation", instrumentation)
213213
object.__setattr__(self, "_sdk", sdk)
214+
# Preserve the connection's default cursor_factory (set at connect() time)
215+
object.__setattr__(self, "_default_cursor_factory", getattr(connection, "cursor_factory", None))
214216

215217
def cursor(self, name: str | None = None, cursor_factory: Any = None, *args: Any, **kwargs: Any) -> Any:
216218
"""Intercept cursor creation to wrap user-provided cursor_factory."""
217-
# Create instrumented cursor factory (wrapping user's factory if provided)
219+
# Use cursor_factory from cursor() call, or fall back to connection's default
220+
base_factory = cursor_factory if cursor_factory is not None else self._default_cursor_factory
221+
# Create instrumented cursor factory (wrapping the base factory)
218222
wrapped_factory = self._instrumentation._create_cursor_factory(
219223
self._sdk,
220-
cursor_factory, # This becomes the base class (None uses default)
224+
base_factory,
221225
)
222226
return self._connection.cursor(*args, name=name, cursor_factory=wrapped_factory, **kwargs)
223227

@@ -606,11 +610,12 @@ def _traced_executemany(
606610

607611
# For all other queries (pre-app-start OR within a request trace), get mock
608612
# Wrap in {"_batch": ...} to match the recording format
613+
# Normalize to list to match RECORD mode and avoid iterator/serialization issues
609614
is_pre_app_start = not sdk.app_ready
610615
mock_result = self._try_get_mock(
611616
sdk,
612617
query,
613-
{"_batch": params_list},
618+
{"_batch": list(params_list)},
614619
trace_id,
615620
span_id,
616621
parent_span_id,
@@ -642,9 +647,11 @@ def _traced_executemany(
642647

643648
# RECORD mode: Execute real query and record span
644649
error = None
650+
# Convert to list BEFORE executing to avoid iterator exhaustion
651+
params_as_list = list(params_list)
645652

646653
try:
647-
result = original_executemany(query, params_list)
654+
result = original_executemany(query, params_as_list)
648655
return result
649656
except Exception as e:
650657
error = e
@@ -656,7 +663,7 @@ def _traced_executemany(
656663
span,
657664
cursor,
658665
query,
659-
{"_batch": list(params_list)},
666+
{"_batch": params_as_list},
660667
error,
661668
)
662669
finally:

drift/instrumentation/utils/psycopg_utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,26 @@ def deserialize_db_value(val: Any) -> Any:
1313
During replay, we need to convert them back to datetime objects so that Flask/Django
1414
serializes them the same way (e.g., RFC 2822 vs ISO 8601 format).
1515
16+
Only parses strings that contain a time component (T or space separator with :) to avoid
17+
incorrectly converting date-only strings or text that happens to look like dates.
18+
1619
Args:
1720
val: A value from the mocked database rows. Can be a string, list, dict, or any other type.
1821
1922
Returns:
2023
The value with ISO datetime strings converted back to datetime objects.
2124
"""
2225
if isinstance(val, str):
23-
# Try to parse as ISO datetime
24-
try:
25-
# Handle Z suffix for UTC
26-
parsed = dt.datetime.fromisoformat(val.replace("Z", "+00:00"))
27-
return parsed
28-
except ValueError:
29-
pass
26+
# Only parse strings that look like full datetime (must have time component)
27+
# This avoids converting date-only strings like "2024-01-15" or text columns
28+
# that happen to match date patterns
29+
if ("T" in val or (" " in val and ":" in val)) and "-" in val:
30+
try:
31+
# Handle Z suffix for UTC
32+
parsed = dt.datetime.fromisoformat(val.replace("Z", "+00:00"))
33+
return parsed
34+
except ValueError:
35+
pass
3036
elif isinstance(val, list):
3137
return [deserialize_db_value(v) for v in val]
3238
elif isinstance(val, dict):

0 commit comments

Comments
 (0)