Skip to content

Commit 0e9d2c2

Browse files
authored
feat: add sqlalchemy instrumentation and replay hardening (#67)
1 parent fb2c3c9 commit 0e9d2c2

File tree

29 files changed

+1856
-38
lines changed

29 files changed

+1856
-38
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ Tusk Drift currently supports the following packages and versions:
5555
| grpcio (client-side only) | all versions |
5656
| psycopg | `>=3.1.12` |
5757
| psycopg2 | all versions |
58+
| SQLAlchemy | all versions |
5859
| Redis | `>=4.0.0` |
5960
| Kinde | `>=2.0.1` |
6061
| PyJWT | all versions |

drift/core/drift_sdk.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,16 @@ def _init_auto_instrumentations(self) -> None:
416416
except ImportError:
417417
pass
418418

419+
try:
420+
import sqlalchemy # type: ignore[unresolved-import]
421+
422+
from ..instrumentation.sqlalchemy import SqlAlchemyInstrumentation
423+
424+
_ = SqlAlchemyInstrumentation()
425+
logger.debug("SQLAlchemy instrumentation initialized")
426+
except ImportError:
427+
pass
428+
419429
# Initialize PostgreSQL instrumentation before Django
420430
# Instrument BOTH psycopg2 and psycopg if available
421431
# This allows apps to use either or both

drift/instrumentation/psycopg/e2e-tests/src/app.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,55 @@ def db_batch_insert():
124124
return jsonify({"error": str(e)}), 500
125125

126126

127+
@app.route("/db/fetchmany-arraysize")
128+
def db_fetchmany_arraysize():
129+
"""Test that fetchmany() respects runtime cursor.arraysize updates."""
130+
try:
131+
with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur:
132+
cur.execute("SELECT id, name FROM users ORDER BY id")
133+
cur.arraysize = 2
134+
batch = cur.fetchmany()
135+
136+
return jsonify(
137+
{
138+
"arraysize": 2,
139+
"batch_len": len(batch),
140+
"ids": [row[0] for row in batch],
141+
}
142+
)
143+
except Exception as e:
144+
return jsonify({"error": str(e)}), 500
145+
146+
147+
@app.route("/db/error-then-query")
148+
def db_error_then_query():
149+
"""Test error handling + follow-up query in same transaction context."""
150+
error_message = ""
151+
follow_up_count = 0
152+
try:
153+
with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur:
154+
try:
155+
cur.execute("SELECT * FROM users_missing_table")
156+
cur.fetchall()
157+
except Exception as exc:
158+
error_message = str(exc)
159+
conn.rollback()
160+
161+
cur.execute("SELECT id FROM users ORDER BY id LIMIT 1")
162+
rows = cur.fetchall()
163+
follow_up_count = len(rows)
164+
165+
return jsonify(
166+
{
167+
"had_error": bool(error_message),
168+
"error_contains_missing_table": "users_missing_table" in error_message,
169+
"follow_up_count": follow_up_count,
170+
}
171+
)
172+
except Exception as e:
173+
return jsonify({"error": str(e)}), 500
174+
175+
127176
@app.route("/db/transaction", methods=["POST"])
128177
def db_transaction():
129178
"""Test transaction with rollback."""

drift/instrumentation/psycopg/e2e-tests/src/test_requests.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,8 @@
8888
make_request("GET", "/test/inet-cidr-types")
8989
make_request("GET", "/test/range-types")
9090

91+
# Regression coverage for cursor fetch semantics and error replay fidelity
92+
make_request("GET", "/db/fetchmany-arraysize")
93+
make_request("GET", "/db/error-then-query")
94+
9195
print_request_summary()

drift/instrumentation/psycopg/instrumentation.py

Lines changed: 170 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
TuskDriftMode,
2525
)
2626
from ..base import InstrumentationBase
27+
from ..sqlalchemy.context import sqlalchemy_execution_active_context, sqlalchemy_replay_mock_context
2728
from ..utils.psycopg_utils import deserialize_db_value, restore_row_integer_types
2829
from ..utils.serialization import serialize_value
2930
from .mocks import MockConnection, MockCopy
@@ -531,6 +532,15 @@ def _traced_execute(
531532
if sdk.mode == TuskDriftMode.DISABLED:
532533
return original_execute(query, params, **kwargs)
533534

535+
# SQLAlchemy replay source-of-truth path: consume SQLAlchemy-resolved
536+
# payload and skip driver-level mock matching/span creation.
537+
if sdk.mode == TuskDriftMode.REPLAY and sqlalchemy_execution_active_context.get():
538+
mock_result = sqlalchemy_replay_mock_context.get()
539+
if mock_result is not None:
540+
self._raise_replay_error_if_present(mock_result)
541+
self._mock_execute_with_data(cursor, mock_result)
542+
return cursor
543+
534544
query_str = self._query_to_string(query, cursor)
535545

536546
if sdk.mode == TuskDriftMode.REPLAY:
@@ -577,6 +587,7 @@ def _replay_execute(self, cursor: Any, sdk: TuskDrift, query_str: str, params: A
577587
f"Query: {query_str[:100]}..."
578588
)
579589

590+
self._raise_replay_error_if_present(mock_result)
580591
self._mock_execute_with_data(cursor, mock_result, is_async=is_async)
581592
span_info.span.end()
582593
return cursor
@@ -593,6 +604,21 @@ def _record_execute(
593604
kwargs: dict,
594605
) -> Any:
595606
"""Handle RECORD mode for execute - create span and execute query."""
607+
# Under SQLAlchemy instrumentation, skip creating/exporting a driver span
608+
# but keep cursor-state capture so SQLAlchemy span can include result data.
609+
if sqlalchemy_execution_active_context.get():
610+
error = None
611+
try:
612+
return original_execute(query, params, **kwargs)
613+
except Exception as e:
614+
error = e
615+
raise
616+
finally:
617+
try:
618+
self._finalize_query_span(trace.INVALID_SPAN, cursor, query_str, params, error)
619+
except Exception as e:
620+
logger.error(f"Error in SQLAlchemy-scoped psycopg record finalization: {e}")
621+
596622
# Reset cursor state from any previous execute() on this cursor.
597623
# Delete instance attribute overrides to expose original class methods.
598624
# This is safer than saving/restoring bound methods which can become stale.
@@ -663,6 +689,18 @@ def _traced_executemany(
663689
if sdk.mode == TuskDriftMode.DISABLED:
664690
return original_executemany(query, params_seq, **kwargs)
665691

692+
# SQLAlchemy replay source-of-truth path: consume SQLAlchemy-resolved
693+
# payload and skip driver-level mock matching/span creation.
694+
if sdk.mode == TuskDriftMode.REPLAY and sqlalchemy_execution_active_context.get():
695+
mock_result = sqlalchemy_replay_mock_context.get()
696+
if mock_result is not None:
697+
self._raise_replay_error_if_present(mock_result)
698+
if mock_result.get("executemany_returning"):
699+
self._mock_executemany_returning_with_data(cursor, mock_result)
700+
else:
701+
self._mock_execute_with_data(cursor, mock_result)
702+
return cursor
703+
666704
query_str = self._query_to_string(query, cursor)
667705
# Convert to list BEFORE executing to avoid iterator exhaustion
668706
params_list = list(params_seq)
@@ -713,6 +751,7 @@ def _replay_executemany(
713751
f"Query: {query_str[:100]}..."
714752
)
715753

754+
self._raise_replay_error_if_present(mock_result)
716755
# Check if this is executemany_returning format (multiple result sets)
717756
if mock_result.get("executemany_returning"):
718757
self._mock_executemany_returning_with_data(cursor, mock_result)
@@ -723,6 +762,17 @@ def _replay_executemany(
723762
span_info.span.end()
724763
return cursor
725764

765+
def _raise_replay_error_if_present(self, mock_result: dict[str, Any]) -> None:
766+
"""Raise recorded DB error in replay instead of emulating success."""
767+
if not isinstance(mock_result, dict):
768+
return
769+
error_message = mock_result.get("errorMessage")
770+
if error_message:
771+
raise RuntimeError(str(error_message))
772+
error_name = mock_result.get("errorName")
773+
if error_name:
774+
raise RuntimeError(str(error_name))
775+
726776
def _record_executemany(
727777
self,
728778
cursor: Any,
@@ -736,6 +786,36 @@ def _record_executemany(
736786
returning: bool = False,
737787
) -> Any:
738788
"""Handle RECORD mode for executemany - create span and execute query."""
789+
# Under SQLAlchemy instrumentation, skip driver span export while preserving
790+
# result capture needed for SQLAlchemy source-of-truth spans.
791+
if sqlalchemy_execution_active_context.get():
792+
error = None
793+
try:
794+
return original_executemany(query, params_list, **kwargs)
795+
except Exception as e:
796+
error = e
797+
raise
798+
finally:
799+
try:
800+
if returning and error is None:
801+
self._finalize_executemany_returning_span(
802+
trace.INVALID_SPAN,
803+
cursor,
804+
query_str,
805+
{"_batch": params_list, "_returning": True},
806+
error,
807+
)
808+
else:
809+
self._finalize_query_span(
810+
trace.INVALID_SPAN,
811+
cursor,
812+
query_str,
813+
{"_batch": params_list},
814+
error,
815+
)
816+
except Exception as e:
817+
logger.error(f"Error in SQLAlchemy-scoped psycopg executemany finalization: {e}")
818+
739819
span_info = self._create_query_span(sdk, "query", is_pre_app_start)
740820

741821
if not span_info:
@@ -792,6 +872,13 @@ async def _traced_async_execute(
792872
if sdk.mode == TuskDriftMode.DISABLED:
793873
return await original_execute(query, params, **kwargs)
794874

875+
if sdk.mode == TuskDriftMode.REPLAY and sqlalchemy_execution_active_context.get():
876+
mock_result = sqlalchemy_replay_mock_context.get()
877+
if mock_result is not None:
878+
self._raise_replay_error_if_present(mock_result)
879+
self._mock_execute_with_data(cursor, mock_result, is_async=True)
880+
return cursor
881+
795882
query_str = self._query_to_string(query, cursor)
796883

797884
if sdk.mode == TuskDriftMode.REPLAY:
@@ -812,6 +899,19 @@ async def _record_async_execute(
812899
kwargs: dict,
813900
) -> Any:
814901
"""Handle RECORD mode for async execute - create span and execute query."""
902+
if sqlalchemy_execution_active_context.get():
903+
error = None
904+
try:
905+
return await original_execute(query, params, **kwargs)
906+
except Exception as e:
907+
error = e
908+
raise
909+
finally:
910+
try:
911+
self._finalize_query_span(trace.INVALID_SPAN, cursor, query_str, params, error)
912+
except Exception as e:
913+
logger.error(f"Error in SQLAlchemy-scoped async psycopg finalization: {e}")
914+
815915
is_pre_app_start = not sdk.app_ready
816916

817917
# Reset cursor state from any previous execute() on this cursor
@@ -867,6 +967,16 @@ async def _traced_async_executemany(
867967
if sdk.mode == TuskDriftMode.DISABLED:
868968
return await original_executemany(query, params_seq, **kwargs)
869969

970+
if sdk.mode == TuskDriftMode.REPLAY and sqlalchemy_execution_active_context.get():
971+
mock_result = sqlalchemy_replay_mock_context.get()
972+
if mock_result is not None:
973+
self._raise_replay_error_if_present(mock_result)
974+
if mock_result.get("executemany_returning"):
975+
self._mock_executemany_returning_with_data(cursor, mock_result)
976+
else:
977+
self._mock_execute_with_data(cursor, mock_result, is_async=True)
978+
return cursor
979+
870980
query_str = self._query_to_string(query, cursor)
871981
params_list = list(params_seq)
872982
returning = kwargs.get("returning", False)
@@ -892,6 +1002,34 @@ async def _record_async_executemany(
8921002
returning: bool = False,
8931003
) -> Any:
8941004
"""Handle RECORD mode for async executemany - create span and execute query."""
1005+
if sqlalchemy_execution_active_context.get():
1006+
error = None
1007+
try:
1008+
return await original_executemany(query, params_list, **kwargs)
1009+
except Exception as e:
1010+
error = e
1011+
raise
1012+
finally:
1013+
try:
1014+
if returning and error is None:
1015+
self._finalize_executemany_returning_span(
1016+
trace.INVALID_SPAN,
1017+
cursor,
1018+
query_str,
1019+
{"_batch": params_list, "_returning": True},
1020+
error,
1021+
)
1022+
else:
1023+
self._finalize_query_span(
1024+
trace.INVALID_SPAN,
1025+
cursor,
1026+
query_str,
1027+
{"_batch": params_list},
1028+
error,
1029+
)
1030+
except Exception as e:
1031+
logger.error(f"Error in SQLAlchemy-scoped async psycopg executemany finalization: {e}")
1032+
8951033
is_pre_app_start = not sdk.app_ready
8961034
span_info = self._create_query_span(sdk, "query", is_pre_app_start)
8971035

@@ -1657,6 +1795,18 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any], is_asy
16571795
except AttributeError:
16581796
object.__setattr__(cursor, "rowcount", actual_data.get("rowcount", -1))
16591797

1798+
# Preserve insert metadata for ORM write paths.
1799+
lastrowid = actual_data.get("lastrowid")
1800+
if lastrowid is not None:
1801+
try:
1802+
cursor._mock_lastrowid = lastrowid
1803+
except Exception:
1804+
pass
1805+
try:
1806+
object.__setattr__(cursor, "lastrowid", lastrowid)
1807+
except Exception:
1808+
pass
1809+
16601810
description_data = actual_data.get("description")
16611811
self._set_cursor_description(cursor, description_data)
16621812

@@ -1822,9 +1972,10 @@ def fetchone():
18221972
return fetchone
18231973

18241974
def make_fetchmany(cn, RC):
1825-
def fetchmany(size=cursor.arraysize):
1975+
def fetchmany(size=None):
1976+
effective_size = cursor.arraysize if size is None else size
18261977
rows = []
1827-
for _ in range(size):
1978+
for _ in range(effective_size):
18281979
if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue]
18291980
row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue]
18301981
cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue]
@@ -1877,9 +2028,10 @@ def fetchone():
18772028
return fetchone
18782029

18792030
def make_fetchmany_replay(cn, RC):
1880-
def fetchmany(size=cursor.arraysize):
2031+
def fetchmany(size=None):
2032+
effective_size = cursor.arraysize if size is None else size
18812033
rows = []
1882-
for _ in range(size):
2034+
for _ in range(effective_size):
18832035
if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue]
18842036
row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue]
18852037
cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue]
@@ -2060,6 +2212,8 @@ def _finalize_query_span(
20602212
output_value = {
20612213
"rowcount": cursor.rowcount if hasattr(cursor, "rowcount") else -1,
20622214
}
2215+
if hasattr(cursor, "lastrowid") and cursor.lastrowid is not None:
2216+
output_value["lastrowid"] = serialize_value(cursor.lastrowid)
20632217

20642218
# Capture statusmessage for replay
20652219
if hasattr(cursor, "statusmessage") and cursor.statusmessage is not None:
@@ -2772,8 +2926,12 @@ def patched_fetchone():
27722926
return row
27732927
return None
27742928

2775-
def patched_fetchmany(size=cursor.arraysize):
2776-
result = cursor._tusk_rows[cursor._tusk_index : cursor._tusk_index + size] # pyright: ignore[reportAttributeAccessIssue]
2929+
def patched_fetchmany(size=None):
2930+
effective_size = cursor.arraysize if size is None else size
2931+
result = cursor._tusk_rows[ # pyright: ignore[reportAttributeAccessIssue]
2932+
cursor._tusk_index : cursor._tusk_index
2933+
+ effective_size # pyright: ignore[reportAttributeAccessIssue]
2934+
]
27772935
cursor._tusk_index += len(result) # pyright: ignore[reportAttributeAccessIssue]
27782936
return result
27792937

@@ -2808,8 +2966,12 @@ def patched_fetchone():
28082966
return patched_fetchone
28092967

28102968
def make_patched_fetchmany_record():
2811-
def patched_fetchmany(size=cursor.arraysize):
2812-
result = cursor._tusk_rows[cursor._tusk_index : cursor._tusk_index + size] # pyright: ignore[reportAttributeAccessIssue]
2969+
def patched_fetchmany(size=None):
2970+
effective_size = cursor.arraysize if size is None else size
2971+
result = cursor._tusk_rows[ # pyright: ignore[reportAttributeAccessIssue]
2972+
cursor._tusk_index : cursor._tusk_index
2973+
+ effective_size # pyright: ignore[reportAttributeAccessIssue]
2974+
]
28132975
cursor._tusk_index += len(result) # pyright: ignore[reportAttributeAccessIssue]
28142976
return result
28152977

0 commit comments

Comments
 (0)