@@ -76,10 +76,11 @@ class MockConnection:
7676 All queries are mocked at the cursor.execute() level.
7777 """
7878
79- def __init__ (self , sdk : TuskDrift , instrumentation : PsycopgInstrumentation , cursor_factory ):
79+ def __init__ (self , sdk : TuskDrift , instrumentation : PsycopgInstrumentation , cursor_factory , row_factory = None ):
8080 self .sdk = sdk
8181 self .instrumentation = instrumentation
8282 self .cursor_factory = cursor_factory
83+ self .row_factory = row_factory # Store row_factory for cursor creation
8384 self .closed = False
8485 self .autocommit = False
8586
@@ -233,6 +234,18 @@ def stream(self, query, params=None, **kwargs):
233234 """Will be replaced by instrumentation."""
234235 return iter ([])
235236
237+ def __iter__ (self ):
238+ """Support direct cursor iteration (for row in cursor)."""
239+ return self
240+
241+ def __next__ (self ):
242+ """Return next row for iteration."""
243+ if self ._mock_index < len (self ._mock_rows ):
244+ row = self ._mock_rows [self ._mock_index ]
245+ self ._mock_index += 1
246+ return tuple (row ) if isinstance (row , list ) else row
247+ raise StopIteration
248+
236249 def close (self ):
237250 pass
238251
@@ -439,6 +452,7 @@ def patched_connect(*args, **kwargs):
439452 return original_connect (* args , ** kwargs )
440453
441454 user_cursor_factory = kwargs .pop ("cursor_factory" , None )
455+ user_row_factory = kwargs .pop ("row_factory" , None )
442456 cursor_factory = instrumentation ._create_cursor_factory (sdk , user_cursor_factory )
443457
444458 # Create server cursor factory for named cursors (conn.cursor(name="..."))
@@ -448,6 +462,8 @@ def patched_connect(*args, **kwargs):
448462 if sdk .mode == TuskDriftMode .REPLAY :
449463 try :
450464 kwargs ["cursor_factory" ] = cursor_factory
465+ if user_row_factory is not None :
466+ kwargs ["row_factory" ] = user_row_factory
451467 connection = original_connect (* args , ** kwargs )
452468 # Set server cursor factory on the connection for named cursors
453469 if server_cursor_factory :
@@ -459,10 +475,12 @@ def patched_connect(*args, **kwargs):
459475 f"[PATCHED_CONNECT] REPLAY mode: Database connection failed ({ e } ), using mock connection (psycopg3)"
460476 )
461477 # Return mock connection that doesn't require a real database
462- return MockConnection (sdk , instrumentation , cursor_factory )
478+ return MockConnection (sdk , instrumentation , cursor_factory , row_factory = user_row_factory )
463479
464480 # In RECORD mode, always require real connection
465481 kwargs ["cursor_factory" ] = cursor_factory
482+ if user_row_factory is not None :
483+ kwargs ["row_factory" ] = user_row_factory
466484 connection = original_connect (* args , ** kwargs )
467485 # Set server cursor factory on the connection for named cursors
468486 if server_cursor_factory :
@@ -558,6 +576,38 @@ def stream(self, query, params=None, **kwargs):
558576 def copy (self , query , params = None , ** kwargs ):
559577 return instrumentation ._traced_copy (self , super ().copy , sdk , query , params , ** kwargs )
560578
579+ def __iter__ (self ):
580+ # Support direct cursor iteration (for row in cursor)
581+ # In replay mode with mock data (_mock_rows) or record mode with captured data (_tusk_rows)
582+ if hasattr (self , '_mock_rows' ) and self ._mock_rows is not None :
583+ return self
584+ if hasattr (self , '_tusk_rows' ) and self ._tusk_rows is not None :
585+ return self
586+ return super ().__iter__ ()
587+
588+ def __next__ (self ):
589+ # In replay mode with mock data, iterate over mock rows
590+ if hasattr (self , '_mock_rows' ) and self ._mock_rows is not None :
591+ if self ._mock_index < len (self ._mock_rows ):
592+ row = self ._mock_rows [self ._mock_index ]
593+ self ._mock_index += 1
594+ # Apply row transformation if fetchone is patched
595+ if hasattr (self , 'fetchone' ) and callable (self .fetchone ):
596+ # Reset index, get transformed row, restore index
597+ self ._mock_index -= 1
598+ result = self .fetchone ()
599+ return result
600+ return tuple (row ) if isinstance (row , list ) else row
601+ raise StopIteration
602+ # In record mode with captured data, iterate over stored rows
603+ if hasattr (self , '_tusk_rows' ) and self ._tusk_rows is not None :
604+ if self ._tusk_index < len (self ._tusk_rows ):
605+ row = self ._tusk_rows [self ._tusk_index ]
606+ self ._tusk_index += 1
607+ return row
608+ raise StopIteration
609+ return super ().__next__ ()
610+
561611 return InstrumentedCursor
562612
563613 def _create_server_cursor_factory (self , sdk : TuskDrift , base_factory = None ):
@@ -594,6 +644,38 @@ def execute(self, query, params=None, **kwargs):
594644 # Note: ServerCursor doesn't support executemany()
595645 # Note: ServerCursor has stream-like iteration via fetchmany/itersize
596646
647+ def __iter__ (self ):
648+ # Support direct cursor iteration (for row in cursor)
649+ # In replay mode with mock data (_mock_rows) or record mode with captured data (_tusk_rows)
650+ if hasattr (self , '_mock_rows' ) and self ._mock_rows is not None :
651+ return self
652+ if hasattr (self , '_tusk_rows' ) and self ._tusk_rows is not None :
653+ return self
654+ return super ().__iter__ ()
655+
656+ def __next__ (self ):
657+ # In replay mode with mock data, iterate over mock rows
658+ if hasattr (self , '_mock_rows' ) and self ._mock_rows is not None :
659+ if self ._mock_index < len (self ._mock_rows ):
660+ row = self ._mock_rows [self ._mock_index ]
661+ self ._mock_index += 1
662+ # Apply row transformation if fetchone is patched
663+ if hasattr (self , 'fetchone' ) and callable (self .fetchone ):
664+ # Reset index, get transformed row, restore index
665+ self ._mock_index -= 1
666+ result = self .fetchone ()
667+ return result
668+ return tuple (row ) if isinstance (row , list ) else row
669+ raise StopIteration
670+ # In record mode with captured data, iterate over stored rows
671+ if hasattr (self , '_tusk_rows' ) and self ._tusk_rows is not None :
672+ if self ._tusk_index < len (self ._tusk_rows ):
673+ row = self ._tusk_rows [self ._tusk_index ]
674+ self ._tusk_index += 1
675+ return row
676+ raise StopIteration
677+ return super ().__next__ ()
678+
597679 return InstrumentedServerCursor
598680
599681 def _traced_execute (
@@ -1304,6 +1386,28 @@ def _query_to_string(self, query: Any, cursor: Any) -> str:
13041386
13051387 return str (query ) if not isinstance (query , str ) else query
13061388
1389+ def _detect_row_factory_type (self , row_factory : Any ) -> str :
1390+ """Detect the type of row factory for mock transformations.
1391+
1392+ Returns:
1393+ "dict" for dict_row, "namedtuple" for namedtuple_row, "tuple" otherwise
1394+ """
1395+ if row_factory is None :
1396+ return "tuple"
1397+
1398+ # Check by function/class name
1399+ factory_name = getattr (row_factory , '__name__' , '' )
1400+ if not factory_name :
1401+ factory_name = str (type (row_factory ).__name__ )
1402+
1403+ factory_name_lower = factory_name .lower ()
1404+ if 'dict' in factory_name_lower :
1405+ return "dict"
1406+ elif 'namedtuple' in factory_name_lower :
1407+ return "namedtuple"
1408+
1409+ return "tuple"
1410+
13071411 def _is_in_pipeline_mode (self , cursor : Any ) -> bool :
13081412 """Check if the cursor's connection is currently in pipeline mode.
13091413
@@ -1443,6 +1547,36 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non
14431547 except AttributeError :
14441548 pass
14451549
1550+ # Get row_factory from cursor or connection for row transformation
1551+ row_factory = getattr (cursor , 'row_factory' , None )
1552+ if row_factory is None :
1553+ conn = getattr (cursor , 'connection' , None )
1554+ if conn :
1555+ row_factory = getattr (conn , 'row_factory' , None )
1556+
1557+ # Extract column names from description for row factory transformations
1558+ column_names = None
1559+ if description_data :
1560+ column_names = [col ["name" ] for col in description_data ]
1561+
1562+ # Detect row factory type for transformation
1563+ row_factory_type = self ._detect_row_factory_type (row_factory )
1564+
1565+ # Create namedtuple class once if needed (avoid recreating for each row)
1566+ RowClass = None
1567+ if row_factory_type == "namedtuple" and column_names :
1568+ from collections import namedtuple
1569+ RowClass = namedtuple ('Row' , column_names )
1570+
1571+ def transform_row (row ):
1572+ """Transform raw row data according to row factory type."""
1573+ values = tuple (row ) if isinstance (row , list ) else row
1574+ if row_factory_type == "dict" and column_names :
1575+ return dict (zip (column_names , values ))
1576+ elif row_factory_type == "namedtuple" and RowClass is not None :
1577+ return RowClass (* values )
1578+ return values
1579+
14461580 mock_rows = actual_data .get ("rows" , [])
14471581 # Deserialize datetime strings back to datetime objects for consistent Flask serialization
14481582 mock_rows = [deserialize_db_value (row ) for row in mock_rows ]
@@ -1453,7 +1587,7 @@ def mock_fetchone():
14531587 if cursor ._mock_index < len (cursor ._mock_rows ): # pyright: ignore[reportAttributeAccessIssue]
14541588 row = cursor ._mock_rows [cursor ._mock_index ] # pyright: ignore[reportAttributeAccessIssue]
14551589 cursor ._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue]
1456- return tuple (row ) if isinstance ( row , list ) else row
1590+ return transform_row (row )
14571591 return None
14581592
14591593 def mock_fetchmany (size = cursor .arraysize ):
@@ -1468,12 +1602,15 @@ def mock_fetchmany(size=cursor.arraysize):
14681602 def mock_fetchall ():
14691603 rows = cursor ._mock_rows [cursor ._mock_index :] # pyright: ignore[reportAttributeAccessIssue]
14701604 cursor ._mock_index = len (cursor ._mock_rows ) # pyright: ignore[reportAttributeAccessIssue]
1471- return [tuple (row ) if isinstance ( row , list ) else row for row in rows ]
1605+ return [transform_row (row ) for row in rows ]
14721606
14731607 cursor .fetchone = mock_fetchone # pyright: ignore[reportAttributeAccessIssue]
14741608 cursor .fetchmany = mock_fetchmany # pyright: ignore[reportAttributeAccessIssue]
14751609 cursor .fetchall = mock_fetchall # pyright: ignore[reportAttributeAccessIssue]
14761610
1611+ # Note: __iter__ and __next__ are handled at the class level in InstrumentedCursor
1612+ # and MockCursor classes, as Python looks up special methods on the type, not instance
1613+
14771614 def _finalize_query_span (
14781615 self ,
14791616 span : trace .Span ,
@@ -1538,8 +1675,20 @@ def serialize_value(val):
15381675 # We need to capture these for replay mode
15391676 try :
15401677 all_rows = cursor .fetchall ()
1541- # Convert tuples to lists for JSON serialization
1542- rows = [list (row ) for row in all_rows ]
1678+ # Convert rows to lists for JSON serialization
1679+ # Handle dict_row (returns dicts) and namedtuple_row (returns namedtuples)
1680+ column_names = [d ["name" ] for d in description ]
1681+ rows = []
1682+ for row in all_rows :
1683+ if isinstance (row , dict ):
1684+ # dict_row: extract values in column order
1685+ rows .append ([row .get (col ) for col in column_names ])
1686+ elif hasattr (row , '_fields' ):
1687+ # namedtuple: extract values in column order
1688+ rows .append ([getattr (row , col , None ) for col in column_names ])
1689+ else :
1690+ # tuple or list: convert directly
1691+ rows .append (list (row ))
15431692
15441693 # CRITICAL: Re-populate cursor so user code can still fetch
15451694 # We'll store the rows and patch fetch methods
0 commit comments