Skip to content

Commit da7329d

Browse files
fix cursor iteration
1 parent 8723755 commit da7329d

3 files changed

Lines changed: 155 additions & 15 deletions

File tree

drift/instrumentation/psycopg/e2e-tests/src/app.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,11 +250,6 @@ def test_pipeline_mode():
250250
except Exception as e:
251251
return jsonify({"error": str(e)}), 500
252252

253-
254-
# ==========================================
255-
# Bug Hunt Test Endpoints
256-
# ==========================================
257-
258253
@app.route("/test/dict-row-factory")
259254
def test_dict_row_factory():
260255
"""Test dict_row row factory.
@@ -300,7 +295,6 @@ def test_namedtuple_row_factory():
300295
except Exception as e:
301296
return jsonify({"error": str(e)}), 500
302297

303-
304298
@app.route("/test/cursor-iteration")
305299
def test_cursor_iteration():
306300
"""Test direct cursor iteration (for row in cursor).

drift/instrumentation/psycopg/e2e-tests/src/test_requests.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,10 @@ def make_request(method, endpoint, **kwargs):
7474

7575
make_request("GET", "/test/pipeline-mode")
7676

77-
# BUG 2: Dict row factory - rows returned as column names
7877
make_request("GET", "/test/dict-row-factory")
7978

80-
# BUG 3: Namedtuple row factory - rows returned as plain tuples
8179
make_request("GET", "/test/namedtuple-row-factory")
8280

83-
# BUG 4: Cursor iteration - "no result available" in replay mode
8481
make_request("GET", "/test/cursor-iteration")
8582

8683
print("\nAll requests completed successfully")

drift/instrumentation/psycopg/instrumentation.py

Lines changed: 155 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,11 @@ class MockConnection:
7676
All queries are mocked at the cursor.execute() level.
7777
"""
7878

79-
def __init__(self, sdk: TuskDrift, instrumentation: PsycopgInstrumentation, cursor_factory):
79+
def __init__(self, sdk: TuskDrift, instrumentation: PsycopgInstrumentation, cursor_factory, row_factory=None):
8080
self.sdk = sdk
8181
self.instrumentation = instrumentation
8282
self.cursor_factory = cursor_factory
83+
self.row_factory = row_factory # Store row_factory for cursor creation
8384
self.closed = False
8485
self.autocommit = False
8586

@@ -233,6 +234,18 @@ def stream(self, query, params=None, **kwargs):
233234
"""Will be replaced by instrumentation."""
234235
return iter([])
235236

237+
def __iter__(self):
238+
"""Support direct cursor iteration (for row in cursor)."""
239+
return self
240+
241+
def __next__(self):
242+
"""Return next row for iteration."""
243+
if self._mock_index < len(self._mock_rows):
244+
row = self._mock_rows[self._mock_index]
245+
self._mock_index += 1
246+
return tuple(row) if isinstance(row, list) else row
247+
raise StopIteration
248+
236249
def close(self):
237250
pass
238251

@@ -439,6 +452,7 @@ def patched_connect(*args, **kwargs):
439452
return original_connect(*args, **kwargs)
440453

441454
user_cursor_factory = kwargs.pop("cursor_factory", None)
455+
user_row_factory = kwargs.pop("row_factory", None)
442456
cursor_factory = instrumentation._create_cursor_factory(sdk, user_cursor_factory)
443457

444458
# Create server cursor factory for named cursors (conn.cursor(name="..."))
@@ -448,6 +462,8 @@ def patched_connect(*args, **kwargs):
448462
if sdk.mode == TuskDriftMode.REPLAY:
449463
try:
450464
kwargs["cursor_factory"] = cursor_factory
465+
if user_row_factory is not None:
466+
kwargs["row_factory"] = user_row_factory
451467
connection = original_connect(*args, **kwargs)
452468
# Set server cursor factory on the connection for named cursors
453469
if server_cursor_factory:
@@ -459,10 +475,12 @@ def patched_connect(*args, **kwargs):
459475
f"[PATCHED_CONNECT] REPLAY mode: Database connection failed ({e}), using mock connection (psycopg3)"
460476
)
461477
# Return mock connection that doesn't require a real database
462-
return MockConnection(sdk, instrumentation, cursor_factory)
478+
return MockConnection(sdk, instrumentation, cursor_factory, row_factory=user_row_factory)
463479

464480
# In RECORD mode, always require real connection
465481
kwargs["cursor_factory"] = cursor_factory
482+
if user_row_factory is not None:
483+
kwargs["row_factory"] = user_row_factory
466484
connection = original_connect(*args, **kwargs)
467485
# Set server cursor factory on the connection for named cursors
468486
if server_cursor_factory:
@@ -558,6 +576,38 @@ def stream(self, query, params=None, **kwargs):
558576
def copy(self, query, params=None, **kwargs):
559577
return instrumentation._traced_copy(self, super().copy, sdk, query, params, **kwargs)
560578

579+
def __iter__(self):
580+
# Support direct cursor iteration (for row in cursor)
581+
# Iterate over self when replay mock rows (_mock_rows) or recorded rows (_tusk_rows) are present; otherwise defer to the base cursor
582+
if hasattr(self, '_mock_rows') and self._mock_rows is not None:
583+
return self
584+
if hasattr(self, '_tusk_rows') and self._tusk_rows is not None:
585+
return self
586+
return super().__iter__()
587+
588+
def __next__(self):
589+
# In replay mode with mock data, iterate over mock rows
590+
if hasattr(self, '_mock_rows') and self._mock_rows is not None:
591+
if self._mock_index < len(self._mock_rows):
592+
row = self._mock_rows[self._mock_index]
593+
self._mock_index += 1
594+
# Apply row transformation if fetchone is patched
595+
if hasattr(self, 'fetchone') and callable(self.fetchone):
596+
# Reset index, get transformed row, restore index
597+
self._mock_index -= 1
598+
result = self.fetchone()
599+
return result
600+
return tuple(row) if isinstance(row, list) else row
601+
raise StopIteration
602+
# In record mode with captured data, iterate over stored rows
603+
if hasattr(self, '_tusk_rows') and self._tusk_rows is not None:
604+
if self._tusk_index < len(self._tusk_rows):
605+
row = self._tusk_rows[self._tusk_index]
606+
self._tusk_index += 1
607+
return row
608+
raise StopIteration
609+
return super().__next__()
610+
561611
return InstrumentedCursor
562612

563613
def _create_server_cursor_factory(self, sdk: TuskDrift, base_factory=None):
@@ -594,6 +644,38 @@ def execute(self, query, params=None, **kwargs):
594644
# Note: ServerCursor doesn't support executemany()
595645
# Note: ServerCursor has stream-like iteration via fetchmany/itersize
596646

647+
def __iter__(self):
648+
# Support direct cursor iteration (for row in cursor)
649+
# Iterate over self when replay mock rows (_mock_rows) or recorded rows (_tusk_rows) are present; otherwise defer to the base cursor
650+
if hasattr(self, '_mock_rows') and self._mock_rows is not None:
651+
return self
652+
if hasattr(self, '_tusk_rows') and self._tusk_rows is not None:
653+
return self
654+
return super().__iter__()
655+
656+
def __next__(self):
657+
# In replay mode with mock data, iterate over mock rows
658+
if hasattr(self, '_mock_rows') and self._mock_rows is not None:
659+
if self._mock_index < len(self._mock_rows):
660+
row = self._mock_rows[self._mock_index]
661+
self._mock_index += 1
662+
# Apply row transformation if fetchone is patched
663+
if hasattr(self, 'fetchone') and callable(self.fetchone):
664+
# Reset index, get transformed row, restore index
665+
self._mock_index -= 1
666+
result = self.fetchone()
667+
return result
668+
return tuple(row) if isinstance(row, list) else row
669+
raise StopIteration
670+
# In record mode with captured data, iterate over stored rows
671+
if hasattr(self, '_tusk_rows') and self._tusk_rows is not None:
672+
if self._tusk_index < len(self._tusk_rows):
673+
row = self._tusk_rows[self._tusk_index]
674+
self._tusk_index += 1
675+
return row
676+
raise StopIteration
677+
return super().__next__()
678+
597679
return InstrumentedServerCursor
598680

599681
def _traced_execute(
@@ -1304,6 +1386,28 @@ def _query_to_string(self, query: Any, cursor: Any) -> str:
13041386

13051387
return str(query) if not isinstance(query, str) else query
13061388

1389+
def _detect_row_factory_type(self, row_factory: Any) -> str:
1390+
"""Detect the type of row factory for mock transformations.
1391+
1392+
Returns:
1393+
"dict" for dict_row, "namedtuple" for namedtuple_row, "tuple" otherwise
1394+
"""
1395+
if row_factory is None:
1396+
return "tuple"
1397+
1398+
# Check by function/class name
1399+
factory_name = getattr(row_factory, '__name__', '')
1400+
if not factory_name:
1401+
factory_name = str(type(row_factory).__name__)
1402+
1403+
factory_name_lower = factory_name.lower()
1404+
if 'dict' in factory_name_lower:
1405+
return "dict"
1406+
elif 'namedtuple' in factory_name_lower:
1407+
return "namedtuple"
1408+
1409+
return "tuple"
1410+
13071411
def _is_in_pipeline_mode(self, cursor: Any) -> bool:
13081412
"""Check if the cursor's connection is currently in pipeline mode.
13091413
@@ -1443,6 +1547,36 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non
14431547
except AttributeError:
14441548
pass
14451549

1550+
# Get row_factory from cursor or connection for row transformation
1551+
row_factory = getattr(cursor, 'row_factory', None)
1552+
if row_factory is None:
1553+
conn = getattr(cursor, 'connection', None)
1554+
if conn:
1555+
row_factory = getattr(conn, 'row_factory', None)
1556+
1557+
# Extract column names from description for row factory transformations
1558+
column_names = None
1559+
if description_data:
1560+
column_names = [col["name"] for col in description_data]
1561+
1562+
# Detect row factory type for transformation
1563+
row_factory_type = self._detect_row_factory_type(row_factory)
1564+
1565+
# Create namedtuple class once if needed (avoid recreating for each row)
1566+
RowClass = None
1567+
if row_factory_type == "namedtuple" and column_names:
1568+
from collections import namedtuple
1569+
RowClass = namedtuple('Row', column_names)
1570+
1571+
def transform_row(row):
1572+
"""Transform raw row data according to row factory type."""
1573+
values = tuple(row) if isinstance(row, list) else row
1574+
if row_factory_type == "dict" and column_names:
1575+
return dict(zip(column_names, values))
1576+
elif row_factory_type == "namedtuple" and RowClass is not None:
1577+
return RowClass(*values)
1578+
return values
1579+
14461580
mock_rows = actual_data.get("rows", [])
14471581
# Deserialize datetime strings back to datetime objects for consistent Flask serialization
14481582
mock_rows = [deserialize_db_value(row) for row in mock_rows]
@@ -1453,7 +1587,7 @@ def mock_fetchone():
14531587
if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue]
14541588
row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue]
14551589
cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue]
1456-
return tuple(row) if isinstance(row, list) else row
1590+
return transform_row(row)
14571591
return None
14581592

14591593
def mock_fetchmany(size=cursor.arraysize):
@@ -1468,12 +1602,15 @@ def mock_fetchmany(size=cursor.arraysize):
14681602
def mock_fetchall():
14691603
rows = cursor._mock_rows[cursor._mock_index :] # pyright: ignore[reportAttributeAccessIssue]
14701604
cursor._mock_index = len(cursor._mock_rows) # pyright: ignore[reportAttributeAccessIssue]
1471-
return [tuple(row) if isinstance(row, list) else row for row in rows]
1605+
return [transform_row(row) for row in rows]
14721606

14731607
cursor.fetchone = mock_fetchone # pyright: ignore[reportAttributeAccessIssue]
14741608
cursor.fetchmany = mock_fetchmany # pyright: ignore[reportAttributeAccessIssue]
14751609
cursor.fetchall = mock_fetchall # pyright: ignore[reportAttributeAccessIssue]
14761610

1611+
# Note: __iter__ and __next__ are defined at the class level (InstrumentedCursor, InstrumentedServerCursor
1612+
# and MockCursor), because Python looks up special methods on the type, not the instance
1613+
14771614
def _finalize_query_span(
14781615
self,
14791616
span: trace.Span,
@@ -1538,8 +1675,20 @@ def serialize_value(val):
15381675
# We need to capture these for replay mode
15391676
try:
15401677
all_rows = cursor.fetchall()
1541-
# Convert tuples to lists for JSON serialization
1542-
rows = [list(row) for row in all_rows]
1678+
# Convert rows to lists for JSON serialization
1679+
# Handle dict_row (returns dicts) and namedtuple_row (returns namedtuples)
1680+
column_names = [d["name"] for d in description]
1681+
rows = []
1682+
for row in all_rows:
1683+
if isinstance(row, dict):
1684+
# dict_row: extract values in column order
1685+
rows.append([row.get(col) for col in column_names])
1686+
elif hasattr(row, '_fields'):
1687+
# namedtuple: extract values in column order
1688+
rows.append([getattr(row, col, None) for col in column_names])
1689+
else:
1690+
# tuple or list: convert directly
1691+
rows.append(list(row))
15431692

15441693
# CRITICAL: Re-populate cursor so user code can still fetch
15451694
# We'll store the rows and patch fetch methods

0 commit comments

Comments
 (0)