2424 TuskDriftMode ,
2525)
2626from ..base import InstrumentationBase
27+ from ..sqlalchemy .context import sqlalchemy_execution_active_context , sqlalchemy_replay_mock_context
2728from ..utils .psycopg_utils import deserialize_db_value , restore_row_integer_types
2829from ..utils .serialization import serialize_value
2930from .mocks import MockConnection , MockCopy
@@ -531,6 +532,15 @@ def _traced_execute(
531532 if sdk .mode == TuskDriftMode .DISABLED :
532533 return original_execute (query , params , ** kwargs )
533534
535+ # SQLAlchemy replay source-of-truth path: consume SQLAlchemy-resolved
536+ # payload and skip driver-level mock matching/span creation.
537+ if sdk .mode == TuskDriftMode .REPLAY and sqlalchemy_execution_active_context .get ():
538+ mock_result = sqlalchemy_replay_mock_context .get ()
539+ if mock_result is not None :
540+ self ._raise_replay_error_if_present (mock_result )
541+ self ._mock_execute_with_data (cursor , mock_result )
542+ return cursor
543+
534544 query_str = self ._query_to_string (query , cursor )
535545
536546 if sdk .mode == TuskDriftMode .REPLAY :
@@ -577,6 +587,7 @@ def _replay_execute(self, cursor: Any, sdk: TuskDrift, query_str: str, params: A
577587 f"Query: { query_str [:100 ]} ..."
578588 )
579589
590+ self ._raise_replay_error_if_present (mock_result )
580591 self ._mock_execute_with_data (cursor , mock_result , is_async = is_async )
581592 span_info .span .end ()
582593 return cursor
@@ -593,6 +604,21 @@ def _record_execute(
593604 kwargs : dict ,
594605 ) -> Any :
595606 """Handle RECORD mode for execute - create span and execute query."""
607+ # Under SQLAlchemy instrumentation, skip creating/exporting a driver span
608+ # but keep cursor-state capture so SQLAlchemy span can include result data.
609+ if sqlalchemy_execution_active_context .get ():
610+ error = None
611+ try :
612+ return original_execute (query , params , ** kwargs )
613+ except Exception as e :
614+ error = e
615+ raise
616+ finally :
617+ try :
618+ self ._finalize_query_span (trace .INVALID_SPAN , cursor , query_str , params , error )
619+ except Exception as e :
620+ logger .error (f"Error in SQLAlchemy-scoped psycopg record finalization: { e } " )
621+
596622 # Reset cursor state from any previous execute() on this cursor.
597623 # Delete instance attribute overrides to expose original class methods.
598624 # This is safer than saving/restoring bound methods which can become stale.
@@ -663,6 +689,18 @@ def _traced_executemany(
663689 if sdk .mode == TuskDriftMode .DISABLED :
664690 return original_executemany (query , params_seq , ** kwargs )
665691
692+ # SQLAlchemy replay source-of-truth path: consume SQLAlchemy-resolved
693+ # payload and skip driver-level mock matching/span creation.
694+ if sdk .mode == TuskDriftMode .REPLAY and sqlalchemy_execution_active_context .get ():
695+ mock_result = sqlalchemy_replay_mock_context .get ()
696+ if mock_result is not None :
697+ self ._raise_replay_error_if_present (mock_result )
698+ if mock_result .get ("executemany_returning" ):
699+ self ._mock_executemany_returning_with_data (cursor , mock_result )
700+ else :
701+ self ._mock_execute_with_data (cursor , mock_result )
702+ return cursor
703+
666704 query_str = self ._query_to_string (query , cursor )
667705 # Convert to list BEFORE executing to avoid iterator exhaustion
668706 params_list = list (params_seq )
@@ -713,6 +751,7 @@ def _replay_executemany(
713751 f"Query: { query_str [:100 ]} ..."
714752 )
715753
754+ self ._raise_replay_error_if_present (mock_result )
716755 # Check if this is executemany_returning format (multiple result sets)
717756 if mock_result .get ("executemany_returning" ):
718757 self ._mock_executemany_returning_with_data (cursor , mock_result )
@@ -723,6 +762,17 @@ def _replay_executemany(
723762 span_info .span .end ()
724763 return cursor
725764
765+ def _raise_replay_error_if_present (self , mock_result : dict [str , Any ]) -> None :
766+ """Raise recorded DB error in replay instead of emulating success."""
767+ if not isinstance (mock_result , dict ):
768+ return
769+ error_message = mock_result .get ("errorMessage" )
770+ if error_message :
771+ raise RuntimeError (str (error_message ))
772+ error_name = mock_result .get ("errorName" )
773+ if error_name :
774+ raise RuntimeError (str (error_name ))
775+
726776 def _record_executemany (
727777 self ,
728778 cursor : Any ,
@@ -736,6 +786,36 @@ def _record_executemany(
736786 returning : bool = False ,
737787 ) -> Any :
738788 """Handle RECORD mode for executemany - create span and execute query."""
789+ # Under SQLAlchemy instrumentation, skip driver span export while preserving
790+ # result capture needed for SQLAlchemy source-of-truth spans.
791+ if sqlalchemy_execution_active_context .get ():
792+ error = None
793+ try :
794+ return original_executemany (query , params_list , ** kwargs )
795+ except Exception as e :
796+ error = e
797+ raise
798+ finally :
799+ try :
800+ if returning and error is None :
801+ self ._finalize_executemany_returning_span (
802+ trace .INVALID_SPAN ,
803+ cursor ,
804+ query_str ,
805+ {"_batch" : params_list , "_returning" : True },
806+ error ,
807+ )
808+ else :
809+ self ._finalize_query_span (
810+ trace .INVALID_SPAN ,
811+ cursor ,
812+ query_str ,
813+ {"_batch" : params_list },
814+ error ,
815+ )
816+ except Exception as e :
817+ logger .error (f"Error in SQLAlchemy-scoped psycopg executemany finalization: { e } " )
818+
739819 span_info = self ._create_query_span (sdk , "query" , is_pre_app_start )
740820
741821 if not span_info :
@@ -792,6 +872,13 @@ async def _traced_async_execute(
792872 if sdk .mode == TuskDriftMode .DISABLED :
793873 return await original_execute (query , params , ** kwargs )
794874
875+ if sdk .mode == TuskDriftMode .REPLAY and sqlalchemy_execution_active_context .get ():
876+ mock_result = sqlalchemy_replay_mock_context .get ()
877+ if mock_result is not None :
878+ self ._raise_replay_error_if_present (mock_result )
879+ self ._mock_execute_with_data (cursor , mock_result , is_async = True )
880+ return cursor
881+
795882 query_str = self ._query_to_string (query , cursor )
796883
797884 if sdk .mode == TuskDriftMode .REPLAY :
@@ -812,6 +899,19 @@ async def _record_async_execute(
812899 kwargs : dict ,
813900 ) -> Any :
814901 """Handle RECORD mode for async execute - create span and execute query."""
902+ if sqlalchemy_execution_active_context .get ():
903+ error = None
904+ try :
905+ return await original_execute (query , params , ** kwargs )
906+ except Exception as e :
907+ error = e
908+ raise
909+ finally :
910+ try :
911+ self ._finalize_query_span (trace .INVALID_SPAN , cursor , query_str , params , error )
912+ except Exception as e :
913+ logger .error (f"Error in SQLAlchemy-scoped async psycopg finalization: { e } " )
914+
815915 is_pre_app_start = not sdk .app_ready
816916
817917 # Reset cursor state from any previous execute() on this cursor
@@ -867,6 +967,16 @@ async def _traced_async_executemany(
867967 if sdk .mode == TuskDriftMode .DISABLED :
868968 return await original_executemany (query , params_seq , ** kwargs )
869969
970+ if sdk .mode == TuskDriftMode .REPLAY and sqlalchemy_execution_active_context .get ():
971+ mock_result = sqlalchemy_replay_mock_context .get ()
972+ if mock_result is not None :
973+ self ._raise_replay_error_if_present (mock_result )
974+ if mock_result .get ("executemany_returning" ):
975+ self ._mock_executemany_returning_with_data (cursor , mock_result )
976+ else :
977+ self ._mock_execute_with_data (cursor , mock_result , is_async = True )
978+ return cursor
979+
870980 query_str = self ._query_to_string (query , cursor )
871981 params_list = list (params_seq )
872982 returning = kwargs .get ("returning" , False )
@@ -892,6 +1002,34 @@ async def _record_async_executemany(
8921002 returning : bool = False ,
8931003 ) -> Any :
8941004 """Handle RECORD mode for async executemany - create span and execute query."""
1005+ if sqlalchemy_execution_active_context .get ():
1006+ error = None
1007+ try :
1008+ return await original_executemany (query , params_list , ** kwargs )
1009+ except Exception as e :
1010+ error = e
1011+ raise
1012+ finally :
1013+ try :
1014+ if returning and error is None :
1015+ self ._finalize_executemany_returning_span (
1016+ trace .INVALID_SPAN ,
1017+ cursor ,
1018+ query_str ,
1019+ {"_batch" : params_list , "_returning" : True },
1020+ error ,
1021+ )
1022+ else :
1023+ self ._finalize_query_span (
1024+ trace .INVALID_SPAN ,
1025+ cursor ,
1026+ query_str ,
1027+ {"_batch" : params_list },
1028+ error ,
1029+ )
1030+ except Exception as e :
1031+ logger .error (f"Error in SQLAlchemy-scoped async psycopg executemany finalization: { e } " )
1032+
8951033 is_pre_app_start = not sdk .app_ready
8961034 span_info = self ._create_query_span (sdk , "query" , is_pre_app_start )
8971035
@@ -1657,6 +1795,18 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any], is_asy
16571795 except AttributeError :
16581796 object .__setattr__ (cursor , "rowcount" , actual_data .get ("rowcount" , - 1 ))
16591797
1798+ # Preserve insert metadata for ORM write paths.
1799+ lastrowid = actual_data .get ("lastrowid" )
1800+ if lastrowid is not None :
1801+ try :
1802+ cursor ._mock_lastrowid = lastrowid
1803+ except Exception :
1804+ pass
1805+ try :
1806+ object .__setattr__ (cursor , "lastrowid" , lastrowid )
1807+ except Exception :
1808+ pass
1809+
16601810 description_data = actual_data .get ("description" )
16611811 self ._set_cursor_description (cursor , description_data )
16621812
@@ -1822,9 +1972,10 @@ def fetchone():
18221972 return fetchone
18231973
18241974 def make_fetchmany (cn , RC ):
1825- def fetchmany (size = cursor .arraysize ):
1975+ def fetchmany (size = None ):
1976+ effective_size = cursor .arraysize if size is None else size
18261977 rows = []
1827- for _ in range (size ):
1978+ for _ in range (effective_size ):
18281979 if cursor ._mock_index < len (cursor ._mock_rows ): # pyright: ignore[reportAttributeAccessIssue]
18291980 row = cursor ._mock_rows [cursor ._mock_index ] # pyright: ignore[reportAttributeAccessIssue]
18301981 cursor ._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue]
@@ -1877,9 +2028,10 @@ def fetchone():
18772028 return fetchone
18782029
18792030 def make_fetchmany_replay (cn , RC ):
1880- def fetchmany (size = cursor .arraysize ):
2031+ def fetchmany (size = None ):
2032+ effective_size = cursor .arraysize if size is None else size
18812033 rows = []
1882- for _ in range (size ):
2034+ for _ in range (effective_size ):
18832035 if cursor ._mock_index < len (cursor ._mock_rows ): # pyright: ignore[reportAttributeAccessIssue]
18842036 row = cursor ._mock_rows [cursor ._mock_index ] # pyright: ignore[reportAttributeAccessIssue]
18852037 cursor ._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue]
@@ -2060,6 +2212,8 @@ def _finalize_query_span(
20602212 output_value = {
20612213 "rowcount" : cursor .rowcount if hasattr (cursor , "rowcount" ) else - 1 ,
20622214 }
2215+ if hasattr (cursor , "lastrowid" ) and cursor .lastrowid is not None :
2216+ output_value ["lastrowid" ] = serialize_value (cursor .lastrowid )
20632217
20642218 # Capture statusmessage for replay
20652219 if hasattr (cursor , "statusmessage" ) and cursor .statusmessage is not None :
@@ -2772,8 +2926,12 @@ def patched_fetchone():
27722926 return row
27732927 return None
27742928
2775- def patched_fetchmany (size = cursor .arraysize ):
2776- result = cursor ._tusk_rows [cursor ._tusk_index : cursor ._tusk_index + size ] # pyright: ignore[reportAttributeAccessIssue]
2929+ def patched_fetchmany (size = None ):
2930+ effective_size = cursor .arraysize if size is None else size
2931+ result = cursor ._tusk_rows [ # pyright: ignore[reportAttributeAccessIssue]
2932+ cursor ._tusk_index : cursor ._tusk_index
2933+ + effective_size # pyright: ignore[reportAttributeAccessIssue]
2934+ ]
27772935 cursor ._tusk_index += len (result ) # pyright: ignore[reportAttributeAccessIssue]
27782936 return result
27792937
@@ -2808,8 +2966,12 @@ def patched_fetchone():
28082966 return patched_fetchone
28092967
28102968 def make_patched_fetchmany_record ():
2811- def patched_fetchmany (size = cursor .arraysize ):
2812- result = cursor ._tusk_rows [cursor ._tusk_index : cursor ._tusk_index + size ] # pyright: ignore[reportAttributeAccessIssue]
2969+ def patched_fetchmany (size = None ):
2970+ effective_size = cursor .arraysize if size is None else size
2971+ result = cursor ._tusk_rows [ # pyright: ignore[reportAttributeAccessIssue]
2972+ cursor ._tusk_index : cursor ._tusk_index
2973+ + effective_size # pyright: ignore[reportAttributeAccessIssue]
2974+ ]
28132975 cursor ._tusk_index += len (result ) # pyright: ignore[reportAttributeAccessIssue]
28142976 return result
28152977
0 commit comments