Skip to content

Commit 6d8dce2

Browse files
authored
fix: add cursor iteration support to psycopg2 MockCursor for django.contrib.postgres (#54)
1 parent 6ce723d commit 6d8dce2

File tree

5 files changed

+81
-1
lines changed

5 files changed

+81
-1
lines changed

drift/instrumentation/psycopg2/instrumentation.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,22 @@ def fetchall(self):
195195
def close(self):
196196
pass
197197

198+
def __iter__(self):
199+
"""Support direct cursor iteration (for row in cursor).
200+
201+
This is required by Django's django.contrib.postgres which iterates
202+
over cursor results to register type handlers (hstore, citext, etc.).
203+
"""
204+
return self
205+
206+
def __next__(self):
207+
"""Return next row for iteration."""
208+
if self._mock_index >= len(self._mock_rows):
209+
raise StopIteration
210+
row = self._mock_rows[self._mock_index]
211+
self._mock_index += 1
212+
return tuple(row) if isinstance(row, list) else row
213+
198214
def __enter__(self):
199215
return self
200216

@@ -486,6 +502,30 @@ def executemany(self, query: QueryType, vars_list: Any) -> Any:
486502
logger.debug("[INSTRUMENTED_CURSOR] executemany() called on instrumented cursor")
487503
return instrumentation._traced_executemany(self, super().executemany, sdk, query, vars_list)
488504

505+
def __iter__(self):
506+
"""Support direct cursor iteration (for row in cursor).
507+
508+
If _tusk_rows is set (from _finalize_query_span recording), use it.
509+
Otherwise fall back to the base cursor's iteration.
510+
"""
511+
if hasattr(self, "_tusk_rows"):
512+
return self
513+
return super().__iter__()
514+
515+
def __next__(self):
516+
"""Return next row for iteration.
517+
518+
If _tusk_rows is set (from _finalize_query_span recording), iterate over stored rows.
519+
Otherwise fall back to the base cursor's __next__.
520+
"""
521+
if hasattr(self, "_tusk_rows"):
522+
if self._tusk_index < len(self._tusk_rows): # pyright: ignore[reportAttributeAccessIssue]
523+
row = self._tusk_rows[self._tusk_index] # pyright: ignore[reportAttributeAccessIssue]
524+
self._tusk_index += 1 # pyright: ignore[reportAttributeAccessIssue]
525+
return row
526+
raise StopIteration
527+
return super().__next__()
528+
489529
return InstrumentedCursor
490530

491531
def _traced_execute(
@@ -1014,6 +1054,8 @@ def patched_fetchall():
10141054
cursor.fetchone = patched_fetchone # pyright: ignore[reportAttributeAccessIssue]
10151055
cursor.fetchmany = patched_fetchmany # pyright: ignore[reportAttributeAccessIssue]
10161056
cursor.fetchall = patched_fetchall # pyright: ignore[reportAttributeAccessIssue]
1057+
# Note: __iter__ and __next__ are handled at class level in InstrumentedCursor
1058+
# (instance-level dunder patching doesn't work for C extension cursors)
10171059

10181060
except Exception as fetch_error:
10191061
logger.debug(f"Could not fetch rows (query might not return rows): {fetch_error}")

drift/stack-tests/django-postgres/src/settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414
ALLOWED_HOSTS = ["*"]
1515

1616
# Application definition
17+
# NOTE: django.contrib.postgres is included to test MockCursor cursor iteration.
18+
# Django's postgres extension iterates over cursor results to register type handlers.
1719
INSTALLED_APPS = [
1820
"django.contrib.contenttypes",
1921
"django.contrib.auth",
2022
"django.contrib.sessions",
23+
"django.contrib.postgres",
2124
]
2225

2326
MIDDLEWARE = [

drift/stack-tests/django-postgres/src/test_requests.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
# Execute test sequence
99
make_request("GET", "/health")
1010

11+
# Cursor iteration test - validates MockCursor.__iter__ fix for django.contrib.postgres
12+
make_request("GET", "/db/cursor-iteration")
13+
1114
# Key integration test: register_default_jsonb on InstrumentedConnection
12-
# This is the main test for the bug fix
1315
make_request("GET", "/db/register-jsonb")
1416

1517
# Transaction test (rollback, doesn't return data)

drift/stack-tests/django-postgres/src/urls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
path("db/register-jsonb", views.db_register_jsonb, name="db_register_jsonb"),
1515
path("db/transaction", views.db_transaction, name="db_transaction"),
1616
path("db/raw-connection", views.db_raw_connection, name="db_raw_connection"),
17+
path("db/cursor-iteration", views.cursor_iteration, name="cursor_iteration"),
1718
]

drift/stack-tests/django-postgres/src/views.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,35 @@ def db_raw_connection(request):
221221
)
222222
except Exception as e:
223223
return JsonResponse({"error": str(e), "error_type": type(e).__name__}, status=500)
224+
225+
226+
@require_GET
227+
def cursor_iteration(request):
228+
"""Test cursor iteration using 'for row in cursor' syntax.
229+
230+
This validates that MockCursor implements
231+
__iter__ and __next__.
232+
"""
233+
try:
234+
with connection.cursor() as cursor:
235+
cursor.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 5")
236+
237+
rows = []
238+
for row in cursor:
239+
rows.append({"id": row[0], "name": row[1], "email": row[2]})
240+
241+
return JsonResponse(
242+
{"status": "success", "message": "Cursor iteration worked correctly", "count": len(rows), "data": rows}
243+
)
244+
except TypeError as e:
245+
# Error when MockCursor doesn't implement __iter__
246+
return JsonResponse(
247+
{
248+
"error": str(e),
249+
"error_type": "TypeError",
250+
"message": "Cursor iteration failed - MockCursor not iterable",
251+
},
252+
status=500,
253+
)
254+
except Exception as e:
255+
return JsonResponse({"error": str(e), "error_type": type(e).__name__}, status=500)

0 commit comments

Comments
 (0)