3636 replay_trace_id_context ,
3737)
3838from ..base import InstrumentationBase
39+ from ..utils .psycopg_utils import deserialize_db_value
3940
4041logger = logging .getLogger (__name__ )
4142
@@ -198,6 +199,44 @@ def __exit__(self, exc_type, exc_val, exc_tb):
198199 return False
199200
200201
202+ class InstrumentedConnection :
203+ """Wraps a real psycopg2 connection to intercept cursor() calls.
204+
205+ This ensures that even when users pass cursor_factory to cursor() instead of
206+ connect(), the cursor is still instrumented for tracing.
207+ """
208+
209+ def __init__ (self , connection : Any , instrumentation : Psycopg2Instrumentation , sdk : TuskDrift ) -> None :
210+ # Use object.__setattr__ to avoid triggering __getattr__
211+ object .__setattr__ (self , "_connection" , connection )
212+ object .__setattr__ (self , "_instrumentation" , instrumentation )
213+ object .__setattr__ (self , "_sdk" , sdk )
214+
215+ def cursor (self , name : str | None = None , cursor_factory : Any = None , * args : Any , ** kwargs : Any ) -> Any :
216+ """Intercept cursor creation to wrap user-provided cursor_factory."""
217+ # Create instrumented cursor factory (wrapping user's factory if provided)
218+ wrapped_factory = self ._instrumentation ._create_cursor_factory (
219+ self ._sdk ,
220+ cursor_factory , # This becomes the base class (None uses default)
221+ )
222+ return self ._connection .cursor (* args , name = name , cursor_factory = wrapped_factory , ** kwargs )
223+
224+ def __getattr__ (self , name : str ) -> Any :
225+ """Proxy all other methods/attributes to the real connection."""
226+ return getattr (self ._connection , name )
227+
228+ def __setattr__ (self , name : str , value : Any ) -> None :
229+ """Proxy attribute setting to the real connection."""
230+ setattr (self ._connection , name , value )
231+
232+ def __enter__ (self ) -> InstrumentedConnection :
233+ self ._connection .__enter__ ()
234+ return self
235+
236+ def __exit__ (self , * args : Any ) -> Any :
237+ return self ._connection .__exit__ (* args )
238+
239+
201240def _query_to_str (query : QueryType ) -> str :
202241 """Convert a query (str, bytes, or Composable) to a string."""
203242 if isinstance (query , str ):
@@ -287,35 +326,28 @@ def patched_connect(*args, **kwargs):
287326 logger .debug ("[PATCHED_CONNECT] SDK disabled, passing through" )
288327 return original_connect (* args , ** kwargs )
289328
290- # Use cursor_factory to wrap cursors
291- # Save any user-provided cursor_factory
292- user_cursor_factory = kwargs .pop ("cursor_factory" , None )
293-
294- # Create our instrumented cursor factory
295- cursor_factory = instrumentation ._create_cursor_factory (sdk , user_cursor_factory )
296-
297329 # In REPLAY mode, try to connect but fall back to mock connection if DB is unavailable
298330 if sdk .mode == TuskDriftMode .REPLAY :
299331 try :
300- kwargs ["cursor_factory" ] = cursor_factory
301332 logger .debug ("[PATCHED_CONNECT] REPLAY mode: Attempting real DB connection..." )
302333 connection = original_connect (* args , ** kwargs )
303334 logger .info ("[PATCHED_CONNECT] REPLAY mode: Successfully connected to real database" )
304- return connection
335+ # Wrap connection to intercept cursor() calls
336+ return InstrumentedConnection (connection , instrumentation , sdk )
305337 except Exception as e :
306338 logger .info (
307339 f"[PATCHED_CONNECT] REPLAY mode: Database connection failed ({ e } ), using mock connection"
308340 )
309341 # Return mock connection that doesn't require a real database
310- return MockConnection (sdk , instrumentation , cursor_factory )
342+ # MockConnection already handles cursor_factory correctly in its cursor() method
343+ return MockConnection (sdk , instrumentation , None )
311344
312345 # In RECORD mode, always require real connection
313- kwargs ["cursor_factory" ] = cursor_factory
314346 logger .debug ("[PATCHED_CONNECT] RECORD mode: Connecting to database..." )
315347 connection = original_connect (* args , ** kwargs )
316348 logger .info ("[PATCHED_CONNECT] RECORD mode: Connected to database successfully" )
317-
318- return connection
349+ # Wrap connection to intercept cursor() calls
350+ return InstrumentedConnection ( connection , instrumentation , sdk )
319351
320352 # Apply patch
321353 module .connect = patched_connect # type: ignore[attr-defined]
@@ -573,11 +605,12 @@ def _traced_executemany(
573605 return None
574606
575607 # For all other queries (pre-app-start OR within a request trace), get mock
608+ # Wrap in {"_batch": ...} to match the recording format
576609 is_pre_app_start = not sdk .app_ready
577610 mock_result = self ._try_get_mock (
578611 sdk ,
579612 query ,
580- params_list ,
613+ { "_batch" : params_list } ,
581614 trace_id ,
582615 span_id ,
583616 parent_span_id ,
@@ -803,6 +836,24 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non
803836
804837 # Store mock rows for fetching
805838 mock_rows = actual_data .get ("rows" , [])
839+ # Deserialize datetime strings back to datetime objects for consistent Flask/Django serialization
840+ mock_rows = [deserialize_db_value (row ) for row in mock_rows ]
841+
842+ # Check if this is a dict-cursor (like RealDictCursor) by checking if cursor class
843+ # inherits from a dict-returning cursor type
844+ is_dict_cursor = False
845+ try :
846+ import psycopg2 .extras
847+
848+ is_dict_cursor = isinstance (cursor , (psycopg2 .extras .RealDictCursor , psycopg2 .extras .DictCursor ))
849+ except (ImportError , AttributeError ):
850+ pass
851+
852+ # If it's a dict cursor and we have description, convert rows to dicts
853+ if is_dict_cursor and description_data :
854+ column_names = [col ["name" ] for col in description_data ]
855+ mock_rows = [dict (zip (column_names , row , strict = True )) for row in mock_rows ]
856+
806857 cursor ._mock_rows = mock_rows # pyright: ignore[reportAttributeAccessIssue]
807858 cursor ._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue]
808859
@@ -815,7 +866,9 @@ def mock_fetchone():
815866 if cursor ._mock_index < len (cursor ._mock_rows ): # pyright: ignore[reportAttributeAccessIssue]
816867 row = cursor ._mock_rows [cursor ._mock_index ] # pyright: ignore[reportAttributeAccessIssue]
817868 cursor ._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue]
818- # Convert list to tuple to match psycopg2 behavior
869+ # Return as-is for dict cursors, convert to tuple for regular cursors
870+ if isinstance (row , dict ):
871+ return row
819872 return tuple (row ) if isinstance (row , list ) else row
820873 return None
821874
@@ -832,8 +885,15 @@ def mock_fetchall():
832885 logger .debug (f"[MOCK] fetchall called, returning { len (cursor ._mock_rows [cursor ._mock_index :])} rows" ) # pyright: ignore[reportAttributeAccessIssue]
833886 rows = cursor ._mock_rows [cursor ._mock_index :] # pyright: ignore[reportAttributeAccessIssue]
834887 cursor ._mock_index = len (cursor ._mock_rows ) # pyright: ignore[reportAttributeAccessIssue]
835- # Convert lists to tuples to match psycopg2 behavior
836- result = [tuple (row ) if isinstance (row , list ) else row for row in rows ]
888+ # Return as-is for dict rows, convert lists to tuples for regular cursors
889+ result = []
890+ for row in rows :
891+ if isinstance (row , dict ):
892+ result .append (row )
893+ elif isinstance (row , list ):
894+ result .append (tuple (row ))
895+ else :
896+ result .append (row )
837897 logger .debug (f"[MOCK] fetchall returning: { result } " )
838898 return result
839899
@@ -906,8 +966,16 @@ def serialize_value(val):
906966 # We need to capture these for replay mode
907967 try :
908968 all_rows = cursor .fetchall ()
909- # Convert tuples to lists for JSON serialization
910- rows = [list (row ) for row in all_rows ]
969+ # Convert rows to lists for JSON serialization
970+ # Handle both tuple rows (regular cursor) and dict rows (RealDictCursor)
971+ rows = []
972+ for row in all_rows :
973+ if isinstance (row , dict ):
974+ # RealDictCursor returns dict-like rows - extract values in column order
975+ rows .append ([row [desc [0 ]] for desc in cursor .description ])
976+ else :
977+ # Regular cursor returns tuples
978+ rows .append (list (row ))
911979
912980 # CRITICAL: Re-populate cursor so user code can still fetch
913981 # We'll store the rows and patch fetch methods
0 commit comments