From c8b7baddce3628c31cce63f6ac96afe48e735ca1 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Tue, 13 Jan 2026 11:25:56 -0800 Subject: [PATCH 01/37] fix cursor-stream --- .../instrumentation/httpx/instrumentation.py | 1 + .../psycopg/e2e-tests/src/app.py | 91 ++++++++ .../psycopg/e2e-tests/src/test_requests.py | 2 + .../psycopg/instrumentation.py | 203 ++++++++++++++++++ .../requests/instrumentation.py | 1 + 5 files changed, 298 insertions(+) diff --git a/drift/instrumentation/httpx/instrumentation.py b/drift/instrumentation/httpx/instrumentation.py index c758ead..c13911a 100644 --- a/drift/instrumentation/httpx/instrumentation.py +++ b/drift/instrumentation/httpx/instrumentation.py @@ -806,6 +806,7 @@ def _try_get_mock_from_request_sync( input_value=input_value, kind=SpanKind.CLIENT, input_schema_merges=input_schema_merges, + is_pre_app_start=not sdk.app_ready, ) if not mock_response_output or not mock_response_output.found: diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index b12e33e..6ae9f49 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -143,6 +143,97 @@ def db_transaction(): except Exception as e: return jsonify({"error": str(e)}), 500 +@app.route("/test/cursor-stream") +def test_cursor_stream(): + """Test cursor.stream() - generator-based result streaming. + + This tests whether the instrumentation captures streaming queries + that return results as a generator. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Stream results row-by-row instead of fetchall + results = [] + for row in cur.stream("SELECT id, name, email FROM users ORDER BY id LIMIT 5"): + results.append({"id": row[0], "name": row[1], "email": row[2]}) + return jsonify({"count": len(results), "data": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +# ============================================================================= +# BUG HUNTING TEST ENDPOINTS +# These endpoints expose confirmed bugs in the psycopg instrumentation. +# See BUG_TRACKING.md for detailed analysis. +# ============================================================================= + + +@app.route("/test/server-cursor") +def test_server_cursor(): + """Test ServerCursor (named cursor) - server-side cursor. + + This tests whether the instrumentation captures server-side cursors + which use DECLARE CURSOR on the database server. + """ + try: + with psycopg.connect(get_conn_string()) as conn: + # Named cursor creates a server-side cursor + with conn.cursor(name="test_server_cursor") as cur: + cur.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 5") + rows = cur.fetchall() + columns = [desc[0] for desc in cur.description] if cur.description else ["id", "name", "email"] + results = [dict(zip(columns, row, strict=False)) for row in rows] + return jsonify({"count": len(results), "data": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/copy-to") +def test_copy_to(): + """Test cursor.copy() with COPY TO - bulk data export. + + This tests whether the instrumentation captures COPY operations. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Use COPY to export data + output = [] + with cur.copy("COPY (SELECT id, name, email FROM users ORDER BY id LIMIT 5) TO STDOUT") as copy: + for row in copy: + # Handle both bytes and memoryview + if isinstance(row, memoryview): + row = bytes(row) + output.append(row.decode('utf-8').strip()) + return jsonify({"count": len(output), "data": output}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/multiple-queries") +def test_multiple_queries(): + """Test multiple queries in same connection. + + This tests whether multiple queries in the same connection + are all captured and replayed correctly. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Query 1 + cur.execute("SELECT COUNT(*) FROM users") + count = cur.fetchone()[0] + + # Query 2 + cur.execute("SELECT MAX(id) FROM users") + max_id = cur.fetchone()[0] + + # Query 3 + cur.execute("SELECT MIN(id) FROM users") + min_id = cur.fetchone()[0] + + return jsonify({"count": count, "max_id": max_id, "min_id": min_id}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + if __name__ == "__main__": sdk.mark_app_as_ready() diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index 629647c..95d4a18 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -64,4 +64,6 @@ def make_request(method, endpoint, **kwargs): if user_id: make_request("DELETE", f"/db/delete/{user_id}") + make_request("GET", "/test/cursor-stream") + print("\nAll requests completed successfully") diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 08bb4d5..5a6256e 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -136,9 +136,17 @@ def noop_executemany(q, ps, **kw): return instrumentation._traced_executemany(cursor, noop_executemany, sdk, query, params_seq, **kwargs) + def mock_stream(query, params=None, **kwargs): + # For mock cursor, original_stream is just a no-op generator + def noop_stream(q, p, **kw): + return iter([]) + + return instrumentation._traced_stream(cursor, noop_stream, sdk, query, params, **kwargs) + # Monkey-patch mock functions onto cursor cursor.execute = mock_execute # type: ignore[method-assign] cursor.executemany = mock_executemany # type: ignore[method-assign] + cursor.stream = mock_stream # type: ignore[method-assign] logger.debug("[MOCK_CONNECTION] Created cursor (psycopg3)") return cursor @@ -208,6 +216,10 @@ def fetchmany(self, size=None): def fetchall(self): return [] + def stream(self, query, params=None, **kwargs): + """Will be replaced by instrumentation.""" + return iter([]) + def close(self): pass @@ -315,6 +327,9 @@ def execute(self, query, params=None, **kwargs): def executemany(self, query, params_seq, **kwargs): return instrumentation._traced_executemany(self, super().executemany, sdk, query, params_seq, **kwargs) + def stream(self, query, params=None, **kwargs): + return instrumentation._traced_stream(self, super().stream, sdk, query, params, **kwargs) + return InstrumentedCursor def _traced_execute( @@ -558,6 +573,194 @@ def _record_executemany( ) span_info.span.end() + def _traced_stream( + self, cursor: Any, original_stream: Any, sdk: TuskDrift, query: str, params=None, **kwargs + ) -> Any: + """Traced cursor.stream method.""" + if sdk.mode == TuskDriftMode.DISABLED: + return original_stream(query, params, **kwargs) + + query_str = self._query_to_string(query, cursor) + + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: self._replay_stream(cursor, sdk, query_str, params), + no_op_request_handler=lambda: iter([]), # Empty iterator for background requests + is_server_request=False, + ) + + # RECORD mode + return handle_record_mode( + original_function_call=lambda: original_stream(query, params, **kwargs), + record_mode_handler=lambda is_pre_app_start: self._record_stream( + cursor, original_stream, sdk, query, query_str, params, is_pre_app_start, kwargs + ), + span_kind=OTelSpanKind.CLIENT, + ) + + def _record_stream( + self, + cursor: Any, + original_stream: Any, + sdk: TuskDrift, + query: str, + query_str: str, + params: Any, + is_pre_app_start: bool, + kwargs: dict, + ): + """Handle RECORD mode for stream - wrap generator with tracing.""" + span_info = SpanUtils.create_span( + CreateSpanOptions( + name="psycopg.query", + kind=OTelSpanKind.CLIENT, + attributes={ + TdSpanAttributes.NAME: "psycopg.query", + TdSpanAttributes.PACKAGE_NAME: "psycopg", + TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", + TdSpanAttributes.SUBMODULE_NAME: "query", + TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, + TdSpanAttributes.IS_PRE_APP_START: is_pre_app_start, + }, + is_pre_app_start=is_pre_app_start, + ) + ) + + if not span_info: + yield from original_stream(query, params, **kwargs) + return + + rows_collected = [] + error = None + + try: + with SpanUtils.with_span(span_info): + for row in original_stream(query, params, **kwargs): + rows_collected.append(row) + yield row + except Exception as e: + error = e + raise + finally: + self._finalize_stream_span(span_info.span, cursor, query_str, params, rows_collected, error) + span_info.span.end() + + def _replay_stream(self, cursor: Any, sdk: TuskDrift, query_str: str, params: Any): + """Handle REPLAY mode for stream - return mock generator.""" + span_info = SpanUtils.create_span( + CreateSpanOptions( + name="psycopg.query", + kind=OTelSpanKind.CLIENT, + attributes={ + TdSpanAttributes.NAME: "psycopg.query", + TdSpanAttributes.PACKAGE_NAME: "psycopg", + TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", + TdSpanAttributes.SUBMODULE_NAME: "query", + TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, + TdSpanAttributes.IS_PRE_APP_START: not sdk.app_ready, + }, + is_pre_app_start=not sdk.app_ready, + ) + ) + + if not span_info: + raise RuntimeError("Error creating span in replay mode") + + with SpanUtils.with_span(span_info): + mock_result = self._try_get_mock( + sdk, query_str, params, span_info.trace_id, span_info.span_id, span_info.parent_span_id + ) + + if mock_result is None: + is_pre_app_start = not sdk.app_ready + raise RuntimeError( + f"[Tusk REPLAY] No mock found for psycopg stream query. " + f"This {'pre-app-start ' if is_pre_app_start else ''}query was not recorded. " + f"Query: {query_str[:100]}..." + ) + + # Deserialize and yield rows from mock + rows = mock_result.get("rows", []) + for row in rows: + deserialized = deserialize_db_value(row) + yield tuple(deserialized) if isinstance(deserialized, list) else deserialized + + span_info.span.end() + + def _finalize_stream_span( + self, + span: trace.Span, + cursor: Any, + query: str, + params: Any, + rows: list, + error: Exception | None, + ) -> None: + """Finalize span for stream operation with collected rows.""" + try: + import datetime + + def serialize_value(val): + """Convert non-JSON-serializable values to JSON-compatible types.""" + if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): + return val.isoformat() + elif isinstance(val, bytes): + return val.decode("utf-8", errors="replace") + elif isinstance(val, (list, tuple)): + return [serialize_value(v) for v in val] + elif isinstance(val, dict): + return {k: serialize_value(v) for k, v in val.items()} + return val + + # Build input value + input_value = { + "query": query.strip(), + } + if params is not None: + input_value["parameters"] = serialize_value(params) + + # Build output value + output_value = {} + + if error: + output_value = { + "errorName": type(error).__name__, + "errorMessage": str(error), + } + span.set_status(Status(OTelStatusCode.ERROR, str(error))) + else: + # Use pre-collected rows (unlike _finalize_query_span which calls fetchall) + serialized_rows = [[serialize_value(col) for col in row] for row in rows] + + output_value = { + "rowcount": len(rows), + } + + if serialized_rows: + output_value["rows"] = serialized_rows + + # Generate schemas and hashes + input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) + output_result = JsonSchemaHelper.generate_schema_and_hash(output_value, {}) + + # Set span attributes + span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(input_value)) + span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(output_value)) + span.set_attribute(TdSpanAttributes.INPUT_SCHEMA, json.dumps(input_result.schema.to_primitive())) + span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA, json.dumps(output_result.schema.to_primitive())) + span.set_attribute(TdSpanAttributes.INPUT_SCHEMA_HASH, input_result.decoded_schema_hash) + span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA_HASH, output_result.decoded_schema_hash) + span.set_attribute(TdSpanAttributes.INPUT_VALUE_HASH, input_result.decoded_value_hash) + span.set_attribute(TdSpanAttributes.OUTPUT_VALUE_HASH, output_result.decoded_value_hash) + + if not error: + span.set_status(Status(OTelStatusCode.OK)) + + logger.debug("[PSYCOPG] Stream span finalized successfully") + + except Exception as e: + logger.error(f"Error finalizing stream span: {e}") + def _query_to_string(self, query: Any, cursor: Any) -> str: """Convert query to string.""" try: diff --git a/drift/instrumentation/requests/instrumentation.py b/drift/instrumentation/requests/instrumentation.py index 88874eb..d5481d7 100644 --- a/drift/instrumentation/requests/instrumentation.py +++ b/drift/instrumentation/requests/instrumentation.py @@ -511,6 +511,7 @@ def _try_get_mock( input_value=input_value, kind=SpanKind.CLIENT, input_schema_merges=input_schema_merges, + is_pre_app_start=not sdk.app_ready, ) if not mock_response_output or not mock_response_output.found: From 1bdc5cae1ae228439ddf8cecfc217d77aa538d56 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Tue, 13 Jan 2026 11:41:52 -0800 Subject: [PATCH 02/37] refactor psycopg instrumentations to use mock utils --- .../psycopg/instrumentation.py | 57 ++-------- .../psycopg2/instrumentation.py | 106 ++---------------- 2 files changed, 16 insertions(+), 147 deletions(-) diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 5a6256e..fffd805 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -2,7 +2,6 @@ import json import logging -import time from types import ModuleType from typing import Any @@ -11,22 +10,15 @@ from opentelemetry.trace import Status from opentelemetry.trace import StatusCode as OTelStatusCode -from ...core.communication.types import MockRequestInput from ...core.drift_sdk import TuskDrift from ...core.json_schema_helper import JsonSchemaHelper from ...core.mode_utils import handle_record_mode, handle_replay_mode from ...core.tracing import TdSpanAttributes from ...core.tracing.span_utils import CreateSpanOptions, SpanUtils from ...core.types import ( - CleanSpanData, - Duration, PackageType, SpanKind, - SpanStatus, - StatusCode, - Timestamp, TuskDriftMode, - replay_trace_id_context, ) from ..base import InstrumentationBase from ..utils.psycopg_utils import deserialize_db_value @@ -386,7 +378,7 @@ def _replay_execute(self, cursor: Any, sdk: TuskDrift, query_str: str, params: A with SpanUtils.with_span(span_info): mock_result = self._try_get_mock( - sdk, query_str, params, span_info.trace_id, span_info.span_id, span_info.parent_span_id + sdk, query_str, params, span_info.trace_id, span_info.span_id ) if mock_result is None: @@ -503,7 +495,7 @@ def _replay_executemany(self, cursor: Any, sdk: TuskDrift, query_str: str, param with SpanUtils.with_span(span_info): mock_result = self._try_get_mock( - sdk, query_str, {"_batch": params_list}, span_info.trace_id, span_info.span_id, span_info.parent_span_id + sdk, query_str, {"_batch": params_list}, span_info.trace_id, span_info.span_id ) if mock_result is None: @@ -668,7 +660,7 @@ def _replay_stream(self, cursor: Any, sdk: TuskDrift, query_str: str, params: An with SpanUtils.with_span(span_info): mock_result = self._try_get_mock( - sdk, query_str, params, span_info.trace_id, span_info.span_id, span_info.parent_span_id + sdk, query_str, params, span_info.trace_id, span_info.span_id ) if mock_result is None: @@ -780,7 +772,6 @@ def _try_get_mock( params: Any, trace_id: str, span_id: str, - parent_span_id: str | None, ) -> dict[str, Any] | None: """Try to get a mocked response from CLI. @@ -795,56 +786,24 @@ def _try_get_mock( if params is not None: input_value["parameters"] = params - # Generate schema and hashes for CLI matching - input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) - - # Create mock span for matching - timestamp_ms = time.time() * 1000 - timestamp_seconds = int(timestamp_ms // 1000) - timestamp_nanos = int((timestamp_ms % 1000) * 1_000_000) + # Use centralized mock finding utility + from ...core.mock_utils import find_mock_response_sync - # Create mock span for matching - # NOTE: Schemas must be None to avoid betterproto map serialization issues - # The CLI only needs the hashes for matching anyway, not the full schemas - mock_span = CleanSpanData( + mock_response_output = find_mock_response_sync( + sdk=sdk, trace_id=trace_id, span_id=span_id, - parent_span_id=parent_span_id or "", name="psycopg.query", package_name="psycopg", package_type=PackageType.PG, instrumentation_name="PsycopgInstrumentation", submodule_name="query", input_value=input_value, - output_value=None, - input_schema=None, # type: ignore - Must be None to avoid betterproto serialization issues - output_schema=None, # type: ignore - Must be None to avoid betterproto serialization issues - input_schema_hash=input_result.decoded_schema_hash, - output_schema_hash="", - input_value_hash=input_result.decoded_value_hash, - output_value_hash="", - stack_trace="", # Empty in REPLAY mode kind=SpanKind.CLIENT, - status=SpanStatus(code=StatusCode.OK, message=""), - timestamp=Timestamp(seconds=timestamp_seconds, nanos=timestamp_nanos), - duration=Duration(seconds=0, nanos=0), - is_root_span=False, is_pre_app_start=not sdk.app_ready, ) - # Request mock from CLI - replay_trace_id = replay_trace_id_context.get() - - mock_request = MockRequestInput( - test_id=replay_trace_id or "", - outbound_span=mock_span, - ) - - logger.debug(f"Requesting mock from CLI for query: {query[:50]}...") - mock_response_output = sdk.request_mock_sync(mock_request) - logger.debug(f"CLI returned: found={mock_response_output.found}") - - if not mock_response_output.found: + if not mock_response_output or not mock_response_output.found: logger.debug(f"No mock found for psycopg query: {query[:100]}") return None diff --git a/drift/instrumentation/psycopg2/instrumentation.py b/drift/instrumentation/psycopg2/instrumentation.py index 12d2115..4030304 100644 --- a/drift/instrumentation/psycopg2/instrumentation.py +++ b/drift/instrumentation/psycopg2/instrumentation.py @@ -4,7 +4,6 @@ import json import logging -import time from types import ModuleType from typing import TYPE_CHECKING, Any @@ -19,22 +18,15 @@ from opentelemetry.trace import Status from opentelemetry.trace import StatusCode as OTelStatusCode -from ...core.communication.types import MockRequestInput from ...core.drift_sdk import TuskDrift from ...core.json_schema_helper import JsonSchemaHelper from ...core.mode_utils import handle_record_mode, handle_replay_mode from ...core.tracing import TdSpanAttributes from ...core.tracing.span_utils import CreateSpanOptions, SpanUtils from ...core.types import ( - CleanSpanData, - Duration, PackageType, SpanKind, - SpanStatus, - StatusCode, - Timestamp, TuskDriftMode, - replay_trace_id_context, ) from ..base import InstrumentationBase from ..utils.psycopg_utils import deserialize_db_value @@ -455,7 +447,7 @@ def _replay_execute(self, cursor: Any, sdk: TuskDrift, query_str: str, params: A with SpanUtils.with_span(span_info): mock_result = self._try_get_mock( - sdk, query_str, params, span_info.trace_id, span_info.span_id, span_info.parent_span_id + sdk, query_str, params, span_info.trace_id, span_info.span_id ) if mock_result is None: @@ -581,7 +573,7 @@ def _replay_executemany(self, cursor: Any, sdk: TuskDrift, query_str: str, param with SpanUtils.with_span(span_info): mock_result = self._try_get_mock( - sdk, query_str, {"_batch": params_list}, span_info.trace_id, span_info.span_id, span_info.parent_span_id + sdk, query_str, {"_batch": params_list}, span_info.trace_id, span_info.span_id ) if mock_result is None: @@ -661,7 +653,6 @@ def _try_get_mock( params: Any, trace_id: str, span_id: str, - parent_span_id: str | None, ) -> dict[str, Any] | None: """Try to get a mocked response from CLI. @@ -677,108 +668,27 @@ def _try_get_mock( if params is not None: input_value["parameters"] = params - # Generate schema and hashes for CLI matching - input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) - - # Create mock span for matching - timestamp_ms = time.time() * 1000 - timestamp_seconds = int(timestamp_ms // 1000) - timestamp_nanos = int((timestamp_ms % 1000) * 1_000_000) + # Use centralized mock finding utility + from ...core.mock_utils import find_mock_response_sync - # Create mock span for matching - # NOTE: Schemas must be None to avoid betterproto map serialization issues - # The CLI only needs the hashes for matching anyway, not the full schemas - mock_span = CleanSpanData( + mock_response_output = find_mock_response_sync( + sdk=sdk, trace_id=trace_id, span_id=span_id, - parent_span_id=parent_span_id or "", name="psycopg2.query", package_name="psycopg2", package_type=PackageType.PG, instrumentation_name="Psycopg2Instrumentation", submodule_name="query", input_value=input_value, - output_value=None, - input_schema=None, # type: ignore[arg-type] - output_schema=None, # type: ignore[arg-type] - input_schema_hash=input_result.decoded_schema_hash, - output_schema_hash="", - input_value_hash=input_result.decoded_value_hash, - output_value_hash="", kind=SpanKind.CLIENT, - status=SpanStatus(code=StatusCode.OK, message=""), - timestamp=Timestamp(seconds=timestamp_seconds, nanos=timestamp_nanos), - duration=Duration(seconds=0, nanos=0), - is_root_span=False, is_pre_app_start=not sdk.app_ready, ) - # Request mock from CLI - replay_trace_id = replay_trace_id_context.get() - - mock_request = MockRequestInput( - test_id=replay_trace_id or "", - outbound_span=mock_span, - ) - - logger.info("[MOCK_REQUEST] Requesting mock from CLI:") - logger.info(f" replay_trace_id={replay_trace_id}") - logger.info(f" trace_id={trace_id}") - logger.info(f" span_id={span_id}") - logger.info(f" parent_span_id={parent_span_id or 'None'}") - logger.info(f" package_name={mock_span.package_name}") - logger.info(f" package_type={mock_span.package_type}") - logger.info(f" instrumentation_name={mock_span.instrumentation_name}") - logger.info(f" input_value_hash={mock_span.input_value_hash}") - logger.info(f" input_schema_hash={mock_span.input_schema_hash}") - logger.info(f" is_pre_app_start={mock_span.is_pre_app_start}") - logger.info(f" query={query_str[:100]}") - - # Check if communicator is connected before requesting mock - if not sdk.communicator or not sdk.communicator.is_connected: - logger.warning("[MOCK_REQUEST] CLI communicator is not connected yet!") - logger.warning(f"[MOCK_REQUEST] is_pre_app_start={mock_span.is_pre_app_start}") - - if mock_span.is_pre_app_start: - # For pre-app-start queries, return None (will trigger empty result fallback) - logger.warning( - "[MOCK_REQUEST] Pre-app-start query and CLI not ready - returning None to use empty result" - ) - return None - else: - # For in-request queries, this is an error but we'll return None to be safe - logger.error("[MOCK_REQUEST] In-request query but CLI not connected - returning None") - return None - - logger.debug("[MOCK_REQUEST] Calling sdk.request_mock_sync()...") - mock_response_output = sdk.request_mock_sync(mock_request) - logger.info( - f"[MOCK_RESPONSE] CLI returned: found={mock_response_output.found}, response={mock_response_output.response is not None}" - ) - - if mock_response_output.response: - logger.debug( - f"[MOCK_RESPONSE] Response keys: {mock_response_output.response.keys() if isinstance(mock_response_output.response, dict) else 'not a dict'}" - ) - - if not mock_response_output.found: - logger.error( - f"No mock found for psycopg2 query:\n" - f" replay_trace_id={replay_trace_id}\n" - f" span_trace_id={trace_id}\n" - f" span_id={span_id}\n" - f" package_name={mock_span.package_name}\n" - f" package_type={mock_span.package_type}\n" - f" name={mock_span.name}\n" - f" input_value_hash={mock_span.input_value_hash}\n" - f" input_schema_hash={mock_span.input_schema_hash}\n" - f" query={query_str[:100]}" - ) + if not mock_response_output or not mock_response_output.found: + logger.debug(f"No mock found for psycopg2 query: {query_str[:100]}") return None - logger.info(f"[MOCK_FOUND] Found mock for psycopg2 query: {query_str[:100]}") - logger.info(f"[MOCK_FOUND] Mock response type: {type(mock_response_output.response)}") - logger.info(f"[MOCK_FOUND] Mock response: {mock_response_output.response}") return mock_response_output.response except Exception as e: From 132b5566209796b2c6256f4e2f118029edbd78c4 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Tue, 13 Jan 2026 12:07:20 -0800 Subject: [PATCH 03/37] instrument ServerCursor --- .../psycopg/e2e-tests/src/app.py | 15 +++-- .../psycopg/e2e-tests/src/test_requests.py | 2 + .../psycopg/instrumentation.py | 55 ++++++++++++++++++- 3 files changed, 62 insertions(+), 10 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index 6ae9f49..e4bd654 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -160,14 +160,6 @@ def test_cursor_stream(): except Exception as e: return jsonify({"error": str(e)}), 500 - -# ============================================================================= -# BUG HUNTING TEST ENDPOINTS -# These endpoints expose confirmed bugs in the psycopg instrumentation. -# See BUG_TRACKING.md for detailed analysis. -# ============================================================================= - - @app.route("/test/server-cursor") def test_server_cursor(): """Test ServerCursor (named cursor) - server-side cursor. @@ -188,6 +180,13 @@ def test_server_cursor(): return jsonify({"error": str(e)}), 500 + +# ============================================================================= +# BUG HUNTING TEST ENDPOINTS +# These endpoints expose confirmed bugs in the psycopg instrumentation. +# See BUG_TRACKING.md for detailed analysis. +# ============================================================================= + @app.route("/test/copy-to") def test_copy_to(): """Test cursor.copy() with COPY TO - bulk data export. diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index 95d4a18..e964203 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -66,4 +66,6 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/cursor-stream") + make_request("GET", "/test/server-cursor") + print("\nAll requests completed successfully") diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index fffd805..d143279 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -105,9 +105,15 @@ def parameter_status(self, param): logger.debug("[MOCK_CONNECTION] Created mock connection for REPLAY mode (psycopg3)") - def cursor(self, name=None, cursor_factory=None): - """Create a cursor using the instrumented cursor factory.""" + def cursor(self, name=None, *, cursor_factory=None, **kwargs): + """Create a cursor using the instrumented cursor factory. + + Accepts the same parameters as psycopg's Connection.cursor(), including + server cursor parameters like scrollable and withhold. + """ # For mock connections, we create a MockCursor directly + # The name parameter is accepted but not used since mock cursors + # behave the same for both regular and server cursors cursor = MockCursor(self) # Wrap execute/executemany for mock cursor @@ -262,11 +268,17 @@ def patched_connect(*args, **kwargs): user_cursor_factory = kwargs.pop("cursor_factory", None) cursor_factory = instrumentation._create_cursor_factory(sdk, user_cursor_factory) + # Create server cursor factory for named cursors (conn.cursor(name="...")) + server_cursor_factory = instrumentation._create_server_cursor_factory(sdk) + # In REPLAY mode, try to connect but fall back to mock connection if DB is unavailable if sdk.mode == TuskDriftMode.REPLAY: try: kwargs["cursor_factory"] = cursor_factory connection = original_connect(*args, **kwargs) + # Set server cursor factory on the connection for named cursors + if server_cursor_factory: + connection.server_cursor_factory = server_cursor_factory logger.info("[PATCHED_CONNECT] REPLAY mode: Successfully connected to database (psycopg3)") return connection except Exception as e: @@ -279,6 +291,9 @@ def patched_connect(*args, **kwargs): # In RECORD mode, always require real connection kwargs["cursor_factory"] = cursor_factory connection = original_connect(*args, **kwargs) + # Set server cursor factory on the connection for named cursors + if server_cursor_factory: + connection.server_cursor_factory = server_cursor_factory logger.debug("[PATCHED_CONNECT] RECORD mode: Connected to database (psycopg3)") return connection @@ -324,6 +339,42 @@ def stream(self, query, params=None, **kwargs): return InstrumentedCursor + def _create_server_cursor_factory(self, sdk: TuskDrift, base_factory=None): + """Create a server cursor factory that wraps ServerCursor with instrumentation. + + Returns a cursor CLASS (psycopg3 expects a class, not a function). + ServerCursor is used when conn.cursor(name="...") is called. + """ + instrumentation = self + logger.debug(f"[CURSOR_FACTORY] Creating server cursor factory, sdk.mode={sdk.mode}") + + try: + from psycopg import ServerCursor as BaseServerCursor + except ImportError: + logger.warning("[CURSOR_FACTORY] Could not import psycopg.ServerCursor") + return None + + base = base_factory or BaseServerCursor + + class InstrumentedServerCursor(base): # type: ignore + _tusk_description = None # Store mock description for replay mode + + @property + def description(self): + # In replay mode, return mock description if set; otherwise use base + if self._tusk_description is not None: + return self._tusk_description + return super().description + + def execute(self, query, params=None, **kwargs): + # Note: ServerCursor.execute() doesn't support 'prepare' parameter + return instrumentation._traced_execute(self, super().execute, sdk, query, params, **kwargs) + + # Note: ServerCursor doesn't support executemany() + # Note: ServerCursor has stream-like iteration via fetchmany/itersize + + return InstrumentedServerCursor + def _traced_execute( self, cursor: Any, original_execute: Any, sdk: TuskDrift, query: str, params=None, **kwargs ) -> Any: From 6e568a12227d3144fb304a2b3894a713b6a66c38 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Tue, 13 Jan 2026 12:29:29 -0800 Subject: [PATCH 04/37] add psycopg cursor.copy() instrumentation for COPY operations --- .../psycopg/e2e-tests/src/app.py | 14 +- .../psycopg/e2e-tests/src/test_requests.py | 2 + .../psycopg/instrumentation.py | 393 +++++++++++++++++- 3 files changed, 400 insertions(+), 9 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index e4bd654..26a2a4e 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -179,14 +179,6 @@ def test_server_cursor(): except Exception as e: return jsonify({"error": str(e)}), 500 - - -# ============================================================================= -# BUG HUNTING TEST ENDPOINTS -# These endpoints expose confirmed bugs in the psycopg instrumentation. -# See BUG_TRACKING.md for detailed analysis. -# ============================================================================= - @app.route("/test/copy-to") def test_copy_to(): """Test cursor.copy() with COPY TO - bulk data export. @@ -207,6 +199,12 @@ def test_copy_to(): except Exception as e: return jsonify({"error": str(e)}), 500 +# ============================================================================= +# BUG HUNTING TEST ENDPOINTS +# These endpoints expose confirmed bugs in the psycopg instrumentation. +# See BUG_TRACKING.md for detailed analysis. +# ============================================================================= + @app.route("/test/multiple-queries") def test_multiple_queries(): diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index e964203..cafad71 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -68,4 +68,6 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/server-cursor") + make_request("GET", "/test/copy-to") + print("\nAll requests completed successfully") diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index d143279..5c37a18 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -2,8 +2,9 @@ import json import logging +from contextlib import contextmanager from types import ModuleType -from typing import Any +from typing import Any, Iterator from opentelemetry import trace from opentelemetry.trace import SpanKind as OTelSpanKind @@ -141,10 +142,19 @@ def noop_stream(q, p, **kw): return instrumentation._traced_stream(cursor, noop_stream, sdk, query, params, **kwargs) + def mock_copy(query, params=None, **kwargs): + # For mock cursor, original_copy is a no-op context manager + @contextmanager + def noop_copy(q, p=None, **kw): + yield MockCopy([]) + + return instrumentation._traced_copy(cursor, noop_copy, sdk, query, params, **kwargs) + # Monkey-patch mock functions onto cursor cursor.execute = mock_execute # type: ignore[method-assign] cursor.executemany = mock_executemany # type: ignore[method-assign] cursor.stream = mock_stream # type: ignore[method-assign] + cursor.copy = mock_copy # type: ignore[method-assign] logger.debug("[MOCK_CONNECTION] Created cursor (psycopg3)") return cursor @@ -229,6 +239,141 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False +class MockCopy: + """Mock Copy object for REPLAY mode. + + Provides a minimal interface compatible with psycopg's Copy object + for COPY TO operations (iteration) and COPY FROM operations (write). + """ + + def __init__(self, data: list): + """Initialize MockCopy with recorded data. + + Args: + data: For COPY TO - list of data chunks (as strings from JSON, will be encoded to bytes) + """ + self._data = data + self._index = 0 + + def __iter__(self) -> Iterator[bytes]: + """Iterate over COPY TO data chunks.""" + for item in self._data: + # Data was stored as string in JSON, convert back to bytes + if isinstance(item, str): + yield item.encode("utf-8") + elif isinstance(item, bytes): + yield item + else: + yield str(item).encode("utf-8") + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def read(self) -> bytes: + """Read next data chunk for COPY TO.""" + if self._index < len(self._data): + item = self._data[self._index] + self._index += 1 + if isinstance(item, str): + return item.encode("utf-8") + elif isinstance(item, bytes): + return item + return str(item).encode("utf-8") + return b"" + + def rows(self) -> Iterator[tuple]: + """Iterate over rows for COPY TO (parsed format).""" + for item in self._data: + yield tuple(item) if isinstance(item, list) else item + + def write(self, buffer) -> None: + """No-op for COPY FROM in replay mode.""" + pass + + def write_row(self, row) -> None: + """No-op for COPY FROM in replay mode.""" + pass + + def set_types(self, types) -> None: + """No-op for replay mode.""" + pass + + +class _TracedCopyWrapper: + """Wrapper around psycopg's Copy object to capture data in RECORD mode. + + Intercepts all data operations to record them for replay. + """ + + def __init__(self, copy: Any, data_collected: list): + """Initialize wrapper. + + Args: + copy: The real psycopg Copy object + data_collected: List to append captured data chunks to + """ + self._copy = copy + self._data_collected = data_collected + + def __iter__(self) -> Iterator[bytes]: + """Iterate over COPY TO data, capturing each chunk.""" + for data in self._copy: + # Handle both bytes and memoryview + if isinstance(data, memoryview): + data = bytes(data) + self._data_collected.append(data) + yield data + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def read(self) -> bytes: + """Read raw data from COPY TO, capturing it.""" + data = self._copy.read() + if data: + if isinstance(data, memoryview): + data = bytes(data) + self._data_collected.append(data) + return data + + def read_row(self): + """Read a parsed row from COPY TO.""" + row = self._copy.read_row() + if row is not None: + self._data_collected.append(row) + return row + + def rows(self): + """Iterate over parsed rows from COPY TO.""" + for row in self._copy.rows(): + self._data_collected.append(row) + yield row + + def write(self, buffer): + """Write raw data for COPY FROM.""" + self._data_collected.append(buffer) + return self._copy.write(buffer) + + def write_row(self, row): + """Write a row for COPY FROM.""" + self._data_collected.append(row) + return self._copy.write_row(row) + + def set_types(self, types): + """Proxy set_types to the underlying Copy object.""" + return self._copy.set_types(types) + + def __getattr__(self, name): + """Proxy any other attributes to the underlying copy object.""" + return getattr(self._copy, name) + + class PsycopgInstrumentation(InstrumentationBase): """Instrumentation for psycopg (psycopg3) PostgreSQL client library. @@ -337,6 +482,9 @@ def executemany(self, query, params_seq, **kwargs): def stream(self, query, params=None, **kwargs): return instrumentation._traced_stream(self, super().stream, sdk, query, params, **kwargs) + def copy(self, query, params=None, **kwargs): + return instrumentation._traced_copy(self, super().copy, sdk, query, params, **kwargs) + return InstrumentedCursor def _create_server_cursor_factory(self, sdk: TuskDrift, base_factory=None): @@ -804,6 +952,249 @@ def serialize_value(val): except Exception as e: logger.error(f"Error finalizing stream span: {e}") + def _traced_copy( + self, cursor: Any, original_copy: Any, sdk: TuskDrift, query: str, params=None, **kwargs + ) -> Any: + """Traced cursor.copy method - returns a context manager.""" + if sdk.mode == TuskDriftMode.DISABLED: + return original_copy(query, params, **kwargs) + + query_str = self._query_to_string(query, cursor) + + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: self._replay_copy(cursor, sdk, query_str), + no_op_request_handler=lambda: self._noop_copy(), + is_server_request=False, + ) + + # RECORD mode - return a context manager that wraps the copy operation + return self._record_copy(cursor, original_copy, sdk, query, query_str, params, kwargs) + + @contextmanager + def _noop_copy(self) -> Iterator[MockCopy]: + """Handle background requests in REPLAY mode - return empty mock.""" + yield MockCopy(data=[]) + + @contextmanager + def _record_copy( + self, + cursor: Any, + original_copy: Any, + sdk: TuskDrift, + query: str, + query_str: str, + params: Any, + kwargs: dict, + ) -> Iterator[_TracedCopyWrapper]: + """Handle RECORD mode for copy - wrap Copy object with tracing.""" + is_pre_app_start = not sdk.app_ready + + span_info = SpanUtils.create_span( + CreateSpanOptions( + name="psycopg.copy", + kind=OTelSpanKind.CLIENT, + attributes={ + TdSpanAttributes.NAME: "psycopg.copy", + TdSpanAttributes.PACKAGE_NAME: "psycopg", + TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", + TdSpanAttributes.SUBMODULE_NAME: "copy", + TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, + TdSpanAttributes.IS_PRE_APP_START: is_pre_app_start, + }, + is_pre_app_start=is_pre_app_start, + ) + ) + + if not span_info: + # Fallback to original if span creation fails + with original_copy(query, params, **kwargs) as copy: + yield copy + return + + error = None + data_collected: list = [] + + try: + with SpanUtils.with_span(span_info): + with original_copy(query, params, **kwargs) as copy: + # Wrap the Copy object to capture data + wrapped_copy = _TracedCopyWrapper(copy, data_collected) + yield wrapped_copy + except Exception as e: + error = e + raise + finally: + self._finalize_copy_span( + span_info.span, + query_str, + data_collected, + error, + ) + span_info.span.end() + + @contextmanager + def _replay_copy(self, cursor: Any, sdk: TuskDrift, query_str: str) -> Iterator[MockCopy]: + """Handle REPLAY mode for copy - return mock Copy object.""" + span_info = SpanUtils.create_span( + CreateSpanOptions( + name="psycopg.copy", + kind=OTelSpanKind.CLIENT, + attributes={ + TdSpanAttributes.NAME: "psycopg.copy", + TdSpanAttributes.PACKAGE_NAME: "psycopg", + TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", + TdSpanAttributes.SUBMODULE_NAME: "copy", + TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, + TdSpanAttributes.IS_PRE_APP_START: not sdk.app_ready, + }, + is_pre_app_start=not sdk.app_ready, + ) + ) + + if not span_info: + raise RuntimeError("Error creating span in replay mode") + + with SpanUtils.with_span(span_info): + mock_result = self._try_get_copy_mock( + sdk, query_str, span_info.trace_id, span_info.span_id + ) + + if mock_result is None: + is_pre_app_start = not sdk.app_ready + raise RuntimeError( + f"[Tusk REPLAY] No mock found for psycopg copy operation. " + f"This {'pre-app-start ' if is_pre_app_start else ''}copy was not recorded. " + f"Query: {query_str[:100]}..." + ) + + # Yield a mock copy object with recorded data + mock_copy = MockCopy(mock_result.get("data", [])) + yield mock_copy + + span_info.span.end() + + def _try_get_copy_mock( + self, + sdk: TuskDrift, + query: str, + trace_id: str, + span_id: str, + ) -> dict[str, Any] | None: + """Try to get a mocked response for copy operation from CLI.""" + try: + # Determine operation type from query + query_upper = query.upper() + is_copy_to = "TO" in query_upper and "STDOUT" in query_upper + is_copy_from = "FROM" in query_upper and "STDIN" in query_upper + + input_value = { + "query": query.strip(), + "operation": "COPY_TO" if is_copy_to else "COPY_FROM" if is_copy_from else "COPY", + } + + # Use centralized mock finding utility + from ...core.mock_utils import find_mock_response_sync + + mock_response_output = find_mock_response_sync( + sdk=sdk, + trace_id=trace_id, + span_id=span_id, + name="psycopg.copy", + package_name="psycopg", + package_type=PackageType.PG, + instrumentation_name="PsycopgInstrumentation", + submodule_name="copy", + input_value=input_value, + kind=SpanKind.CLIENT, + is_pre_app_start=not sdk.app_ready, + ) + + if not mock_response_output or not mock_response_output.found: + logger.debug(f"No mock found for psycopg copy: {query[:100]}") + return None + + return mock_response_output.response + + except Exception as e: + logger.error(f"Error getting mock for psycopg copy: {e}") + return None + + def _finalize_copy_span( + self, + span: trace.Span, + query: str, + data_collected: list, + error: Exception | None, + ) -> None: + """Finalize span for copy operation.""" + try: + import datetime + + def serialize_value(val): + """Convert non-JSON-serializable values to JSON-compatible types.""" + if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): + return val.isoformat() + elif isinstance(val, bytes): + return val.decode("utf-8", errors="replace") + elif isinstance(val, memoryview): + return bytes(val).decode("utf-8", errors="replace") + elif isinstance(val, (list, tuple)): + return [serialize_value(v) for v in val] + elif isinstance(val, dict): + return {k: serialize_value(v) for k, v in val.items()} + return val + + # Determine operation type from query + query_upper = query.upper() + is_copy_to = "TO" in query_upper and "STDOUT" in query_upper + is_copy_from = "FROM" in query_upper and "STDIN" in query_upper + + # Build input value + input_value = { + "query": query.strip(), + "operation": "COPY_TO" if is_copy_to else "COPY_FROM" if is_copy_from else "COPY", + } + + # Build output value + output_value = {} + + if error: + output_value = { + "errorName": type(error).__name__, + "errorMessage": str(error), + } + span.set_status(Status(OTelStatusCode.ERROR, str(error))) + else: + # Serialize the captured data + serialized_data = [serialize_value(d) for d in data_collected] + output_value = { + "data": serialized_data, + "chunk_count": len(data_collected), + } + + # Generate schemas and hashes + input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) + output_result = JsonSchemaHelper.generate_schema_and_hash(output_value, {}) + + # Set span attributes + span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(input_value)) + span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(output_value)) + span.set_attribute(TdSpanAttributes.INPUT_SCHEMA, json.dumps(input_result.schema.to_primitive())) + span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA, json.dumps(output_result.schema.to_primitive())) + span.set_attribute(TdSpanAttributes.INPUT_SCHEMA_HASH, input_result.decoded_schema_hash) + span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA_HASH, output_result.decoded_schema_hash) + span.set_attribute(TdSpanAttributes.INPUT_VALUE_HASH, input_result.decoded_value_hash) + span.set_attribute(TdSpanAttributes.OUTPUT_VALUE_HASH, output_result.decoded_value_hash) + + if not error: + span.set_status(Status(OTelStatusCode.OK)) + + logger.debug("[PSYCOPG] Copy span finalized successfully") + + except Exception as e: + logger.error(f"Error finalizing copy span: {e}") + def _query_to_string(self, query: Any, cursor: Any) -> str: """Convert query to string.""" try: From 8feb8a721e66fc80337c7928d581d66d3d36c66f Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Tue, 13 Jan 2026 13:23:08 -0800 Subject: [PATCH 05/37] fix multiple queries on same cursor --- .../psycopg/e2e-tests/src/app.py | 7 ------- .../psycopg/e2e-tests/src/test_requests.py | 2 ++ .../instrumentation/psycopg/instrumentation.py | 18 ++++++++++++++++++ 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index 26a2a4e..fa19cd2 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -199,13 +199,6 @@ def test_copy_to(): except Exception as e: return jsonify({"error": str(e)}), 500 -# ============================================================================= -# BUG HUNTING TEST ENDPOINTS -# These endpoints expose confirmed bugs in the psycopg instrumentation. -# See BUG_TRACKING.md for detailed analysis. -# ============================================================================= - - @app.route("/test/multiple-queries") def test_multiple_queries(): """Test multiple queries in same connection. diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index cafad71..7816a88 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -70,4 +70,6 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/copy-to") + make_request("GET", "/test/multiple-queries") + print("\nAll requests completed successfully") diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 5c37a18..cc61386 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -604,6 +604,17 @@ def _record_execute( kwargs: dict, ) -> Any: """Handle RECORD mode for execute - create span and execute query.""" + # Reset cursor state from any previous execute() on this cursor. + # This ensures fetch methods work correctly for multiple queries on the same cursor. + # We must restore original fetch methods so _finalize_query_span can call the real + # psycopg fetchall() method, not our patched version from a previous query. + if hasattr(cursor, '_tusk_original_fetchone'): + cursor.fetchone = cursor._tusk_original_fetchone + cursor.fetchmany = cursor._tusk_original_fetchmany + cursor.fetchall = cursor._tusk_original_fetchall + cursor._tusk_rows = None + cursor._tusk_index = 0 + span_info = SpanUtils.create_span( CreateSpanOptions( name="psycopg.query", @@ -1402,6 +1413,13 @@ def patched_fetchall(): cursor._tusk_index = len(cursor._tusk_rows) # pyright: ignore[reportAttributeAccessIssue] return result + # Save original fetch methods before patching (only if not already saved) + # These will be restored at the start of the next execute() call + if not hasattr(cursor, '_tusk_original_fetchone'): + cursor._tusk_original_fetchone = cursor.fetchone # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_original_fetchmany = cursor.fetchmany # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_original_fetchall = cursor.fetchall # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchone = patched_fetchone # pyright: ignore[reportAttributeAccessIssue] cursor.fetchmany = patched_fetchmany # pyright: ignore[reportAttributeAccessIssue] cursor.fetchall = patched_fetchall # pyright: ignore[reportAttributeAccessIssue] From 87237552d8e6005331f481c97532b6d63de00d1d Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Tue, 13 Jan 2026 22:55:11 -0800 Subject: [PATCH 06/37] fix psycopg pipeline mode: defer result capture until sync() --- .../psycopg/e2e-tests/src/app.py | 99 +++++++++++ .../psycopg/e2e-tests/src/test_requests.py | 11 ++ .../psycopg/instrumentation.py | 167 +++++++++++++++++- 3 files changed, 269 insertions(+), 8 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index fa19cd2..1c0eda9 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -224,6 +224,105 @@ def test_multiple_queries(): except Exception as e: return jsonify({"error": str(e)}), 500 +@app.route("/test/pipeline-mode") +def test_pipeline_mode(): + """Test pipeline mode - batched operations. + + Pipeline mode allows sending multiple queries without waiting for results. + This tests whether the instrumentation handles pipeline mode correctly. + """ + try: + with psycopg.connect(get_conn_string()) as conn: + # Enter pipeline mode + with conn.pipeline() as p: + cur1 = conn.execute("SELECT id, name FROM users ORDER BY id LIMIT 3") + cur2 = conn.execute("SELECT COUNT(*) FROM users") + # Sync the pipeline to get results + p.sync() + + rows1 = cur1.fetchall() + count = cur2.fetchone()[0] + + return jsonify({ + "rows": [{"id": r[0], "name": r[1]} for r in rows1], + "count": count + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +# ========================================== +# Bug Hunt Test Endpoints +# ========================================== + +@app.route("/test/dict-row-factory") +def test_dict_row_factory(): + """Test dict_row row factory. + + Tests whether the instrumentation correctly handles dict row factories + which return dictionaries instead of tuples. + """ + try: + from psycopg.rows import dict_row + + with psycopg.connect(get_conn_string(), row_factory=dict_row) as conn: + with conn.cursor() as cur: + cur.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 3") + rows = cur.fetchall() + + return jsonify({ + "count": len(rows), + "data": rows # Already dictionaries + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/namedtuple-row-factory") +def test_namedtuple_row_factory(): + """Test namedtuple_row row factory. + + Tests whether the instrumentation correctly handles namedtuple row factories. + """ + try: + from psycopg.rows import namedtuple_row + + with psycopg.connect(get_conn_string(), row_factory=namedtuple_row) as conn: + with conn.cursor() as cur: + cur.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 3") + rows = cur.fetchall() + + # Convert named tuples to dicts for JSON serialization + return jsonify({ + "count": len(rows), + "data": [{"id": r.id, "name": r.name, "email": r.email} for r in rows] + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/cursor-iteration") +def test_cursor_iteration(): + """Test direct cursor iteration (for row in cursor). + + Tests whether iterating over cursor directly works correctly. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + cur.execute("SELECT id, name FROM users ORDER BY id LIMIT 5") + + # Iterate directly over cursor + results = [] + for row in cur: + results.append({"id": row[0], "name": row[1]}) + + return jsonify({ + "count": len(results), + "data": results + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + if __name__ == "__main__": sdk.mark_app_as_ready() diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index 7816a88..3af77ac 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -72,4 +72,15 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/multiple-queries") + make_request("GET", "/test/pipeline-mode") + + # BUG 2: Dict row factory - rows returned as column names + make_request("GET", "/test/dict-row-factory") + + # BUG 3: Namedtuple row factory - rows returned as plain tuples + make_request("GET", "/test/namedtuple-row-factory") + + # BUG 4: Cursor iteration - "no result available" in replay mode + make_request("GET", "/test/cursor-iteration") + print("\nAll requests completed successfully") diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index cc61386..b0c9ffb 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -2,6 +2,7 @@ import json import logging +import weakref from contextlib import contextmanager from types import ModuleType from typing import Any, Iterator @@ -184,6 +185,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.commit() return False + def pipeline(self): + """Return a mock pipeline context manager for REPLAY mode.""" + return MockPipeline(self) + class MockCursor: """Mock cursor for when we can't create a real cursor from base class. @@ -302,6 +307,27 @@ def set_types(self, types) -> None: pass +class MockPipeline: + """Mock Pipeline for REPLAY mode. + + In REPLAY mode, pipeline operations are no-ops since queries + return mocked data immediately. + """ + + def __init__(self, connection: "MockConnection"): + self._conn = connection + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def sync(self): + """No-op sync for mock pipeline.""" + pass + + class _TracedCopyWrapper: """Wrapper around psycopg's Copy object to capture data in RECORD mode. @@ -389,6 +415,8 @@ def __init__(self, enabled: bool = True) -> None: enabled=enabled, ) self._original_connect = None + # Track pending pipeline spans per connection for deferred finalization + self._pending_pipeline_spans: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() _instance = self def patch(self, module: ModuleType) -> None: @@ -445,6 +473,51 @@ def patched_connect(*args, **kwargs): module.connect = patched_connect # type: ignore[attr-defined] logger.debug("psycopg.connect instrumented") + # Patch Pipeline class for pipeline mode support + self._patch_pipeline_class(module) + + def _patch_pipeline_class(self, module: ModuleType) -> None: + """Patch psycopg.Pipeline to finalize spans on sync/exit.""" + try: + from psycopg import Pipeline + except ImportError: + logger.debug("psycopg.Pipeline not available, skipping pipeline instrumentation") + return + + instrumentation = self + + # Store originals for potential unpatch + self._original_pipeline_sync = getattr(Pipeline, 'sync', None) + self._original_pipeline_exit = getattr(Pipeline, '__exit__', None) + + if self._original_pipeline_sync: + def patched_sync(pipeline_self): + """Patched Pipeline.sync that finalizes pending spans.""" + result = instrumentation._original_pipeline_sync(pipeline_self) + # _conn is the connection associated with the pipeline + conn = getattr(pipeline_self, '_conn', None) + if conn: + instrumentation._finalize_pending_pipeline_spans(conn) + return result + + Pipeline.sync = patched_sync + logger.debug("psycopg.Pipeline.sync instrumented") + + if self._original_pipeline_exit: + def patched_exit(pipeline_self, exc_type, exc_val, exc_tb): + """Patched Pipeline.__exit__ that finalizes any remaining spans.""" + result = instrumentation._original_pipeline_exit( + pipeline_self, exc_type, exc_val, exc_tb + ) + # Finalize any remaining pending spans (handles implicit sync on exit) + conn = getattr(pipeline_self, '_conn', None) + if conn: + instrumentation._finalize_pending_pipeline_spans(conn) + return result + + Pipeline.__exit__ = patched_exit + logger.debug("psycopg.Pipeline.__exit__ instrumented") + def _create_cursor_factory(self, sdk: TuskDrift, base_factory=None): """Create a cursor factory that wraps cursors with instrumentation. @@ -638,6 +711,9 @@ def _record_execute( error = None result = None + # Check if we're in pipeline mode BEFORE executing + in_pipeline_mode = self._is_in_pipeline_mode(cursor) + with SpanUtils.with_span(span_info): try: result = original_execute(query, params, **kwargs) @@ -646,14 +722,24 @@ def _record_execute( error = e raise finally: - self._finalize_query_span( - span_info.span, - cursor, - query_str, - params, - error, - ) - span_info.span.end() + if error is not None: + # Always finalize immediately on error + self._finalize_query_span(span_info.span, cursor, query_str, params, error) + span_info.span.end() + elif in_pipeline_mode: + # Defer finalization until pipeline.sync() + connection = self._get_connection_from_cursor(cursor) + if connection: + self._add_pending_pipeline_span(connection, span_info, cursor, query_str, params) + # DON'T end span here - will be ended in _finalize_pending_pipeline_spans + else: + # Fallback: finalize immediately if we can't get connection + self._finalize_query_span(span_info.span, cursor, query_str, params, None) + span_info.span.end() + else: + # Normal mode: finalize immediately + self._finalize_query_span(span_info.span, cursor, query_str, params, None) + span_info.span.end() def _traced_executemany( self, cursor: Any, original_executemany: Any, sdk: TuskDrift, query: str, params_seq, **kwargs @@ -1218,6 +1304,71 @@ def _query_to_string(self, query: Any, cursor: Any) -> str: return str(query) if not isinstance(query, str) else query + def _is_in_pipeline_mode(self, cursor: Any) -> bool: + """Check if the cursor's connection is currently in pipeline mode. + + In psycopg3, when conn.pipeline() is active, connection._pipeline is set. + """ + try: + conn = getattr(cursor, 'connection', None) + if conn is None: + return False + # MockConnection doesn't have real pipeline mode + if isinstance(conn, MockConnection): + return False + pipeline = getattr(conn, '_pipeline', None) + return pipeline is not None + except Exception: + return False + + def _get_connection_from_cursor(self, cursor: Any) -> Any: + """Get the connection object from a cursor.""" + return getattr(cursor, 'connection', None) + + def _add_pending_pipeline_span( + self, + connection: Any, + span_info: Any, + cursor: Any, + query: str, + params: Any, + ) -> None: + """Add a pending span to be finalized when pipeline syncs.""" + if connection not in self._pending_pipeline_spans: + self._pending_pipeline_spans[connection] = [] + + self._pending_pipeline_spans[connection].append({ + 'span_info': span_info, + 'cursor': cursor, + 'query': query, + 'params': params, + }) + logger.debug(f"[PIPELINE] Deferred span for query: {query[:50]}...") + + def _finalize_pending_pipeline_spans(self, connection: Any) -> None: + """Finalize all pending spans for a connection after pipeline sync.""" + if connection not in self._pending_pipeline_spans: + return + + pending = self._pending_pipeline_spans.pop(connection, []) + logger.debug(f"[PIPELINE] Finalizing {len(pending)} pending pipeline spans") + + for item in pending: + span_info = item['span_info'] + cursor = item['cursor'] + query = item['query'] + params = item['params'] + + try: + self._finalize_query_span(span_info.span, cursor, query, params, error=None) + span_info.span.end() + except Exception as e: + logger.error(f"[PIPELINE] Error finalizing deferred span: {e}") + try: + span_info.span.end() + except Exception: + pass + def _try_get_mock( self, sdk: TuskDrift, From da7329de9b4dd8ad44202be171ddb1948935d64d Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Tue, 13 Jan 2026 23:27:06 -0800 Subject: [PATCH 07/37] fix cursor iteration --- .../psycopg/e2e-tests/src/app.py | 6 - .../psycopg/e2e-tests/src/test_requests.py | 3 - .../psycopg/instrumentation.py | 161 +++++++++++++++++- 3 files changed, 155 insertions(+), 15 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index 1c0eda9..801517b 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -250,11 +250,6 @@ def test_pipeline_mode(): except Exception as e: return jsonify({"error": str(e)}), 500 - -# ========================================== -# Bug Hunt Test Endpoints -# ========================================== - @app.route("/test/dict-row-factory") def test_dict_row_factory(): """Test dict_row row factory. @@ -300,7 +295,6 @@ def test_namedtuple_row_factory(): except Exception as e: return jsonify({"error": str(e)}), 500 - @app.route("/test/cursor-iteration") def test_cursor_iteration(): """Test direct cursor iteration (for row in cursor). diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index 3af77ac..05f8ddd 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -74,13 +74,10 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/pipeline-mode") - # BUG 2: Dict row factory - rows returned as column names make_request("GET", "/test/dict-row-factory") - # BUG 3: Namedtuple row factory - rows returned as plain tuples make_request("GET", "/test/namedtuple-row-factory") - # BUG 4: Cursor iteration - "no result available" in replay mode make_request("GET", "/test/cursor-iteration") print("\nAll requests completed successfully") diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index b0c9ffb..7c3ea4a 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -76,10 +76,11 @@ class MockConnection: All queries are mocked at the cursor.execute() level. """ - def __init__(self, sdk: TuskDrift, instrumentation: PsycopgInstrumentation, cursor_factory): + def __init__(self, sdk: TuskDrift, instrumentation: PsycopgInstrumentation, cursor_factory, row_factory=None): self.sdk = sdk self.instrumentation = instrumentation self.cursor_factory = cursor_factory + self.row_factory = row_factory # Store row_factory for cursor creation self.closed = False self.autocommit = False @@ -233,6 +234,18 @@ def stream(self, query, params=None, **kwargs): """Will be replaced by instrumentation.""" return iter([]) + def __iter__(self): + """Support direct cursor iteration (for row in cursor).""" + return self + + def __next__(self): + """Return next row for iteration.""" + if self._mock_index < len(self._mock_rows): + row = self._mock_rows[self._mock_index] + self._mock_index += 1 + return tuple(row) if isinstance(row, list) else row + raise StopIteration + def close(self): pass @@ -439,6 +452,7 @@ def patched_connect(*args, **kwargs): return original_connect(*args, **kwargs) user_cursor_factory = kwargs.pop("cursor_factory", None) + user_row_factory = kwargs.pop("row_factory", None) cursor_factory = instrumentation._create_cursor_factory(sdk, user_cursor_factory) # Create server cursor factory for named cursors (conn.cursor(name="...")) @@ -448,6 +462,8 @@ def patched_connect(*args, **kwargs): if sdk.mode == TuskDriftMode.REPLAY: try: kwargs["cursor_factory"] = cursor_factory + if user_row_factory is not None: + kwargs["row_factory"] = user_row_factory connection = original_connect(*args, **kwargs) # Set server cursor factory on the connection for named cursors if server_cursor_factory: @@ -459,10 +475,12 @@ def patched_connect(*args, **kwargs): f"[PATCHED_CONNECT] REPLAY mode: Database connection failed ({e}), using mock connection (psycopg3)" ) # Return mock connection that doesn't require a real database - return MockConnection(sdk, instrumentation, cursor_factory) + return MockConnection(sdk, instrumentation, cursor_factory, row_factory=user_row_factory) # In RECORD mode, always require real connection kwargs["cursor_factory"] = cursor_factory + if user_row_factory is not None: + kwargs["row_factory"] = user_row_factory connection = original_connect(*args, **kwargs) # Set server cursor factory on the connection for named cursors if server_cursor_factory: @@ -558,6 +576,38 @@ def stream(self, query, params=None, **kwargs): def copy(self, query, params=None, **kwargs): return instrumentation._traced_copy(self, super().copy, sdk, query, params, **kwargs) + def __iter__(self): + # Support direct cursor iteration (for row in cursor) + # In replay mode with mock data (_mock_rows) or record mode with captured data (_tusk_rows) + if hasattr(self, '_mock_rows') and self._mock_rows is not None: + return self + if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: + return self + return super().__iter__() + + def __next__(self): + # In replay mode with mock data, iterate over mock rows + if hasattr(self, '_mock_rows') and self._mock_rows is not None: + if self._mock_index < len(self._mock_rows): + row = self._mock_rows[self._mock_index] + self._mock_index += 1 + # Apply row transformation if fetchone is patched + if hasattr(self, 'fetchone') and callable(self.fetchone): + # Reset index, get transformed row, restore index + self._mock_index -= 1 + result = self.fetchone() + return result + return tuple(row) if isinstance(row, list) else row + raise StopIteration + # In record mode with captured data, iterate over stored rows + if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: + if self._tusk_index < len(self._tusk_rows): + row = self._tusk_rows[self._tusk_index] + self._tusk_index += 1 + return row + raise StopIteration + return super().__next__() + return InstrumentedCursor def _create_server_cursor_factory(self, sdk: TuskDrift, base_factory=None): @@ -594,6 +644,38 @@ def execute(self, query, params=None, **kwargs): # Note: ServerCursor doesn't support executemany() # Note: ServerCursor has stream-like iteration via fetchmany/itersize + def __iter__(self): + # Support direct cursor iteration (for row in cursor) + # In replay mode with mock data (_mock_rows) or record mode with captured data (_tusk_rows) + if hasattr(self, '_mock_rows') and self._mock_rows is not None: + return self + if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: + return self + return super().__iter__() + + def __next__(self): + # In replay mode with mock data, iterate over mock rows + if hasattr(self, '_mock_rows') and self._mock_rows is not None: + if self._mock_index < len(self._mock_rows): + row = self._mock_rows[self._mock_index] + self._mock_index += 1 + # Apply row transformation if fetchone is patched + if hasattr(self, 'fetchone') and callable(self.fetchone): + # Reset index, get transformed row, restore index + self._mock_index -= 1 + result = self.fetchone() + return result + return tuple(row) if isinstance(row, list) else row + raise StopIteration + # In record mode with captured data, iterate over stored rows + if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: + if self._tusk_index < len(self._tusk_rows): + row = self._tusk_rows[self._tusk_index] + self._tusk_index += 1 + return row + raise StopIteration + return super().__next__() + return InstrumentedServerCursor def _traced_execute( @@ -1304,6 +1386,28 @@ def _query_to_string(self, query: Any, cursor: Any) -> str: return str(query) if not isinstance(query, str) else query + def _detect_row_factory_type(self, row_factory: Any) -> str: + """Detect the type of row factory for mock transformations. + + Returns: + "dict" for dict_row, "namedtuple" for namedtuple_row, "tuple" otherwise + """ + if row_factory is None: + return "tuple" + + # Check by function/class name + factory_name = getattr(row_factory, '__name__', '') + if not factory_name: + factory_name = str(type(row_factory).__name__) + + factory_name_lower = factory_name.lower() + if 'dict' in factory_name_lower: + return "dict" + elif 'namedtuple' in factory_name_lower: + return "namedtuple" + + return "tuple" + def _is_in_pipeline_mode(self, cursor: Any) -> bool: """Check if the cursor's connection is currently in pipeline mode. @@ -1443,6 +1547,36 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non except AttributeError: pass + # Get row_factory from cursor or connection for row transformation + row_factory = getattr(cursor, 'row_factory', None) + if row_factory is None: + conn = getattr(cursor, 'connection', None) + if conn: + row_factory = getattr(conn, 'row_factory', None) + + # Extract column names from description for row factory transformations + column_names = None + if description_data: + column_names = [col["name"] for col in description_data] + + # Detect row factory type for transformation + row_factory_type = self._detect_row_factory_type(row_factory) + + # Create namedtuple class once if needed (avoid recreating for each row) + RowClass = None + if row_factory_type == "namedtuple" and column_names: + from collections import namedtuple + RowClass = namedtuple('Row', column_names) + + def transform_row(row): + """Transform raw row data according to row factory type.""" + values = tuple(row) if isinstance(row, list) else row + if row_factory_type == "dict" and column_names: + return dict(zip(column_names, values)) + elif row_factory_type == "namedtuple" and RowClass is not None: + return RowClass(*values) + return values + mock_rows = actual_data.get("rows", []) # Deserialize datetime strings back to datetime objects for consistent Flask serialization mock_rows = [deserialize_db_value(row) for row in mock_rows] @@ -1453,7 +1587,7 @@ def mock_fetchone(): if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue] row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue] cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue] - return tuple(row) if isinstance(row, list) else row + return transform_row(row) return None def mock_fetchmany(size=cursor.arraysize): @@ -1468,12 +1602,15 @@ def mock_fetchmany(size=cursor.arraysize): def mock_fetchall(): rows = cursor._mock_rows[cursor._mock_index :] # pyright: ignore[reportAttributeAccessIssue] cursor._mock_index = len(cursor._mock_rows) # pyright: ignore[reportAttributeAccessIssue] - return [tuple(row) if isinstance(row, list) else row for row in rows] + return [transform_row(row) for row in rows] cursor.fetchone = mock_fetchone # pyright: ignore[reportAttributeAccessIssue] cursor.fetchmany = mock_fetchmany # pyright: ignore[reportAttributeAccessIssue] cursor.fetchall = mock_fetchall # pyright: ignore[reportAttributeAccessIssue] + # Note: __iter__ and __next__ are handled at the class level in InstrumentedCursor + # and MockCursor classes, as Python looks up special methods on the type, not instance + def _finalize_query_span( self, span: trace.Span, @@ -1538,8 +1675,20 @@ def serialize_value(val): # We need to capture these for replay mode try: all_rows = cursor.fetchall() - # Convert tuples to lists for JSON serialization - rows = [list(row) for row in all_rows] + # Convert rows to lists for JSON serialization + # Handle dict_row (returns dicts) and namedtuple_row (returns namedtuples) + column_names = [d["name"] for d in description] + rows = [] + for row in all_rows: + if isinstance(row, dict): + # dict_row: extract values in column order + rows.append([row.get(col) for col in column_names]) + elif hasattr(row, '_fields'): + # namedtuple: extract values in column order + rows.append([getattr(row, col, None) for col in column_names]) + else: + # tuple or list: convert directly + rows.append(list(row)) # CRITICAL: Re-populate cursor so user code can still fetch # We'll store the rows and patch fetch methods From 84210292ef2401694f92965ba4e71f05c01bdd32 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 10:41:08 -0800 Subject: [PATCH 08/37] fix executemany with returning=True instrumentation for psycopg --- .../psycopg/e2e-tests/src/app.py | 69 +++ .../psycopg/e2e-tests/src/test_requests.py | 2 + .../psycopg/instrumentation.py | 403 +++++++++++++++++- 3 files changed, 462 insertions(+), 12 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index 801517b..da62585 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -317,6 +317,75 @@ def test_cursor_iteration(): except Exception as e: return jsonify({"error": str(e)}), 500 +@app.route("/test/executemany-returning") +def test_executemany_returning(): + """Test executemany with returning=True. + + Tests whether the instrumentation correctly handles executemany with returning=True. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table + cur.execute("CREATE TEMP TABLE batch_test (id SERIAL, name TEXT)") + + # Use executemany with returning + params = [("Batch User 1",), ("Batch User 2",), ("Batch User 3",)] + cur.executemany( + "INSERT INTO batch_test (name) VALUES (%s) RETURNING id, name", + params, + returning=True + ) + + # Fetch results from each batch + results = [] + for result in cur.results(): + row = result.fetchone() + if row: + results.append({"id": row[0], "name": row[1]}) + + conn.commit() + + return jsonify({ + "count": len(results), + "data": results + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + +# ============================================================================ +# BUG HUNTING TEST ENDPOINTS +# These endpoints expose confirmed bugs in the instrumentation +# ============================================================================ + +@app.route("/test/cursor-scroll") +def test_cursor_scroll(): + """Test cursor.scroll() method. + + BUG: The MockCursor and InstrumentedCursor classes don't implement + the scroll() method. In replay mode, scroll() causes "no result available" + error on subsequent fetchone() calls. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + cur.execute("SELECT id, name FROM users ORDER BY id") + + # Fetch first row + first = cur.fetchone() + + # Scroll back to start + cur.scroll(0, mode='absolute') + + # Fetch first row again + first_again = cur.fetchone() + + return jsonify({ + "first": {"id": first[0], "name": first[1]} if first else None, + "first_again": {"id": first_again[0], "name": first_again[1]} if first_again else None, + "match": first == first_again + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + if __name__ == "__main__": sdk.mark_app_as_ready() diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index 05f8ddd..cd3ca33 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -80,4 +80,6 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/cursor-iteration") + make_request("GET", "/test/executemany-returning") + print("\nAll requests completed successfully") diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 7c3ea4a..3327592 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -204,6 +204,9 @@ def __init__(self, connection): self.arraysize = 1 self._mock_rows = [] self._mock_index = 0 + # Support for multiple result sets (executemany with returning=True) + self._mock_result_sets = [] + self._mock_result_set_index = 0 self.adapters = MockAdapters() # Django needs this logger.debug("[MOCK_CURSOR] Created fallback mock cursor (psycopg3)") @@ -230,6 +233,23 @@ def fetchmany(self, size=None): def fetchall(self): return [] + def results(self): + """Iterate over result sets for executemany with returning=True. + + This method is patched by _mock_executemany_returning_with_data + when replaying executemany with returning=True. + Default implementation yields self once for backward compatibility. + """ + yield self + + def nextset(self): + """Move to the next result set. + + Returns True if there is another result set, None otherwise. + This method is patched during replay for executemany with returning=True. + """ + return None + def stream(self, query, params=None, **kwargs): """Will be replaced by instrumentation.""" return iter([]) @@ -833,10 +853,14 @@ def _traced_executemany( query_str = self._query_to_string(query, cursor) # Convert to list BEFORE executing to avoid iterator exhaustion params_list = list(params_seq) + # Detect returning flag for executemany with RETURNING clause + returning = kwargs.get("returning", False) if sdk.mode == TuskDriftMode.REPLAY: return handle_replay_mode( - replay_mode_handler=lambda: self._replay_executemany(cursor, sdk, query_str, params_list), + replay_mode_handler=lambda: self._replay_executemany( + cursor, sdk, query_str, params_list, returning + ), no_op_request_handler=lambda: self._noop_execute(cursor), is_server_request=False, ) @@ -845,12 +869,14 @@ def _traced_executemany( return handle_record_mode( original_function_call=lambda: original_executemany(query, params_list, **kwargs), record_mode_handler=lambda is_pre_app_start: self._record_executemany( - cursor, original_executemany, sdk, query, query_str, params_list, is_pre_app_start, kwargs + cursor, original_executemany, sdk, query, query_str, params_list, is_pre_app_start, kwargs, returning ), span_kind=OTelSpanKind.CLIENT, ) - def _replay_executemany(self, cursor: Any, sdk: TuskDrift, query_str: str, params_list: list) -> Any: + def _replay_executemany( + self, cursor: Any, sdk: TuskDrift, query_str: str, params_list: list, returning: bool = False + ) -> Any: """Handle REPLAY mode for executemany - fetch mock from CLI.""" span_info = SpanUtils.create_span( CreateSpanOptions( @@ -872,8 +898,13 @@ def _replay_executemany(self, cursor: Any, sdk: TuskDrift, query_str: str, param raise RuntimeError("Error creating span in replay mode") with SpanUtils.with_span(span_info): + # Include returning flag in parameters for mock matching + params_for_mock = {"_batch": params_list} + if returning: + params_for_mock["_returning"] = True + mock_result = self._try_get_mock( - sdk, query_str, {"_batch": params_list}, span_info.trace_id, span_info.span_id + sdk, query_str, params_for_mock, span_info.trace_id, span_info.span_id ) if mock_result is None: @@ -887,7 +918,13 @@ def _replay_executemany(self, cursor: Any, sdk: TuskDrift, query_str: str, param f"Query: {query_str[:100]}..." ) - self._mock_execute_with_data(cursor, mock_result) + # Check if this is executemany_returning format (multiple result sets) + if mock_result.get("executemany_returning"): + self._mock_executemany_returning_with_data(cursor, mock_result) + else: + # Backward compatible: use existing single result set handling + self._mock_execute_with_data(cursor, mock_result) + span_info.span.end() return cursor @@ -901,6 +938,7 @@ def _record_executemany( params_list: list, is_pre_app_start: bool, kwargs: dict, + returning: bool = False, ) -> Any: """Handle RECORD mode for executemany - create span and execute query.""" span_info = SpanUtils.create_span( @@ -934,13 +972,24 @@ def _record_executemany( error = e raise finally: - self._finalize_query_span( - span_info.span, - cursor, - query_str, - {"_batch": params_list}, - error, - ) + if returning and error is None: + # Use specialized method for executemany with returning=True + self._finalize_executemany_returning_span( + span_info.span, + cursor, + query_str, + {"_batch": params_list, "_returning": True}, + error, + ) + else: + # Existing behavior for executemany without returning + self._finalize_query_span( + span_info.span, + cursor, + query_str, + {"_batch": params_list}, + error, + ) span_info.span.end() def _traced_stream( @@ -1611,6 +1660,167 @@ def mock_fetchall(): # Note: __iter__ and __next__ are handled at the class level in InstrumentedCursor # and MockCursor classes, as Python looks up special methods on the type, not instance + def _mock_executemany_returning_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> None: + """Mock cursor for executemany with returning=True - supports multiple result sets. + + This method sets up the cursor to replay multiple result sets captured during + executemany with returning=True. It patches the cursor's results() method to + yield the cursor for each result set, allowing iteration. + """ + result_sets = mock_data.get("result_sets", []) + + if not result_sets: + # Fallback to empty result + cursor._mock_rows = [] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] + return + + # Get row_factory from cursor or connection + row_factory = getattr(cursor, "row_factory", None) + if row_factory is None: + conn = getattr(cursor, "connection", None) + if conn: + row_factory = getattr(conn, "row_factory", None) + + row_factory_type = self._detect_row_factory_type(row_factory) + + # Store all result sets for iteration + cursor._mock_result_sets = [] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_result_set_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + for result_set in result_sets: + description_data = result_set.get("description") + column_names = None + if description_data: + column_names = [col["name"] for col in description_data] + + # Deserialize rows + mock_rows = result_set.get("rows", []) + mock_rows = [deserialize_db_value(row) for row in mock_rows] + + cursor._mock_result_sets.append( # pyright: ignore[reportAttributeAccessIssue] + { + "description": description_data, + "column_names": column_names, + "rows": mock_rows, + "rowcount": result_set.get("rowcount", -1), + } + ) + + # Create row transformation helper + def create_row_class(col_names): + if row_factory_type == "namedtuple" and col_names: + from collections import namedtuple + + return namedtuple("Row", col_names) + return None + + def transform_row(row, col_names, RowClass): + """Transform raw row data according to row factory type.""" + values = tuple(row) if isinstance(row, list) else row + if row_factory_type == "dict" and col_names: + return dict(zip(col_names, values)) + elif row_factory_type == "namedtuple" and RowClass is not None: + return RowClass(*values) + return values + + def mock_results(): + """Generator that yields cursor for each result set.""" + while cursor._mock_result_set_index < len(cursor._mock_result_sets): # pyright: ignore[reportAttributeAccessIssue] + result_set = cursor._mock_result_sets[cursor._mock_result_set_index] # pyright: ignore[reportAttributeAccessIssue] + + # Set up cursor state for this result set + cursor._mock_rows = result_set["rows"] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Set description + description_data = result_set.get("description") + if description_data: + desc = [ + (col["name"], col.get("type_code"), None, None, None, None, None) + for col in description_data + ] + try: + cursor._tusk_description = desc # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + try: + cursor.description = desc # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + pass + + # Set rowcount + try: + cursor._rowcount = result_set.get("rowcount", -1) # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + pass + + column_names = result_set.get("column_names") + RowClass = create_row_class(column_names) + + # Create fetch methods for this result set with closures capturing current values + def make_fetchone(cn, RC): + def fetchone(): + if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue] + row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue] + return transform_row(row, cn, RC) + return None + + return fetchone + + def make_fetchmany(cn, RC): + def fetchmany(size=cursor.arraysize): + rows = [] + for _ in range(size): + if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue] + row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue] + rows.append(transform_row(row, cn, RC)) + else: + break + return rows + + return fetchmany + + def make_fetchall(cn, RC): + def fetchall(): + rows = cursor._mock_rows[cursor._mock_index :] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index = len(cursor._mock_rows) # pyright: ignore[reportAttributeAccessIssue] + return [transform_row(row, cn, RC) for row in rows] + + return fetchall + + cursor.fetchone = make_fetchone(column_names, RowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = make_fetchmany(column_names, RowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = make_fetchall(column_names, RowClass) # pyright: ignore[reportAttributeAccessIssue] + + cursor._mock_result_set_index += 1 # pyright: ignore[reportAttributeAccessIssue] + yield cursor + + # Patch results() method onto cursor + cursor.results = mock_results # pyright: ignore[reportAttributeAccessIssue] + + # Also set up initial state for the first result set (in case user calls fetch without results()) + if cursor._mock_result_sets: # pyright: ignore[reportAttributeAccessIssue] + first_set = cursor._mock_result_sets[0] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_rows = first_set["rows"] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Set description for first result set + description_data = first_set.get("description") + if description_data: + desc = [ + (col["name"], col.get("type_code"), None, None, None, None, None) + for col in description_data + ] + try: + cursor._tusk_description = desc # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + try: + cursor.description = desc # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + pass + def _finalize_query_span( self, span: trace.Span, @@ -1764,3 +1974,172 @@ def patched_fetchall(): except Exception as e: logger.error(f"Error creating query span: {e}") + + def _finalize_executemany_returning_span( + self, + span: trace.Span, + cursor: Any, + query: str, + params: Any, + error: Exception | None, + ) -> None: + """Finalize span for executemany with returning=True - captures multiple result sets. + + This method iterates through cursor.results() to capture all result sets + from executemany with returning=True, storing them in a format that can + be replayed with multiple result set iteration. + """ + try: + import datetime + + def serialize_value(val): + """Convert non-JSON-serializable values to JSON-compatible types.""" + if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): + return val.isoformat() + elif isinstance(val, bytes): + return val.decode("utf-8", errors="replace") + elif isinstance(val, (list, tuple)): + return [serialize_value(v) for v in val] + elif isinstance(val, dict): + return {k: serialize_value(v) for k, v in val.items()} + return val + + # Build input value + input_value = { + "query": query.strip(), + } + if params is not None: + input_value["parameters"] = serialize_value(params) + + # Build output value + output_value = {} + + if error: + output_value = { + "errorName": type(error).__name__, + "errorMessage": str(error), + } + span.set_status(Status(OTelStatusCode.ERROR, str(error))) + else: + # Iterate through cursor.results() to capture all result sets + result_sets = [] + all_rows_collected = [] # For re-populating cursor + + try: + # cursor.results() yields the cursor itself for each result set + for result_cursor in cursor.results(): + result_set_data = {} + + # Capture description for this result set + if hasattr(result_cursor, "description") and result_cursor.description: + description = [ + { + "name": desc[0] if hasattr(desc, "__getitem__") else desc.name, + "type_code": desc[1] + if hasattr(desc, "__getitem__") and len(desc) > 1 + else getattr(desc, "type_code", None), + } + for desc in result_cursor.description + ] + result_set_data["description"] = description + column_names = [d["name"] for d in description] + else: + description = None + column_names = None + + # Fetch all rows for this result set + rows = [] + raw_rows = result_cursor.fetchall() + all_rows_collected.append(raw_rows) + + for row in raw_rows: + if isinstance(row, dict): + rows.append([row.get(col) for col in column_names] if column_names else list(row.values())) + elif hasattr(row, "_fields"): + rows.append( + [getattr(row, col, None) for col in column_names] if column_names else list(row) + ) + else: + rows.append(list(row)) + + result_set_data["rowcount"] = ( + result_cursor.rowcount if hasattr(result_cursor, "rowcount") else len(rows) + ) + result_set_data["rows"] = [[serialize_value(col) for col in row] for row in rows] + + result_sets.append(result_set_data) + + except Exception as results_error: + logger.debug(f"Could not iterate results(): {results_error}") + # Fallback: treat as single result set + result_sets = [] + + if result_sets: + output_value = { + "executemany_returning": True, + "result_sets": result_sets, + } + + # Re-populate cursor for user code + # Store all collected rows for replay via results() + cursor._tusk_result_sets = all_rows_collected # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_result_set_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Patch results() method to iterate stored result sets + def patched_results(): + while cursor._tusk_result_set_index < len(cursor._tusk_result_sets): # pyright: ignore[reportAttributeAccessIssue] + rows = cursor._tusk_result_sets[cursor._tusk_result_set_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_rows = rows # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index = 0 # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_result_set_index += 1 # pyright: ignore[reportAttributeAccessIssue] + + # Patch fetch methods for this result set + def patched_fetchone(): + if cursor._tusk_index < len(cursor._tusk_rows): # pyright: ignore[reportAttributeAccessIssue] + row = cursor._tusk_rows[cursor._tusk_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index += 1 # pyright: ignore[reportAttributeAccessIssue] + return row + return None + + def patched_fetchmany(size=cursor.arraysize): + result = cursor._tusk_rows[cursor._tusk_index : cursor._tusk_index + size] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index += len(result) # pyright: ignore[reportAttributeAccessIssue] + return result + + def patched_fetchall(): + result = cursor._tusk_rows[cursor._tusk_index :] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index = len(cursor._tusk_rows) # pyright: ignore[reportAttributeAccessIssue] + return result + + cursor.fetchone = patched_fetchone # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = patched_fetchmany # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = patched_fetchall # pyright: ignore[reportAttributeAccessIssue] + + yield cursor + + cursor.results = patched_results # pyright: ignore[reportAttributeAccessIssue] + + else: + output_value = {"rowcount": cursor.rowcount if hasattr(cursor, "rowcount") else -1} + + # Generate schemas and hashes + input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) + output_result = JsonSchemaHelper.generate_schema_and_hash(output_value, {}) + + # Set span attributes + span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(input_value)) + span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(output_value)) + span.set_attribute(TdSpanAttributes.INPUT_SCHEMA, json.dumps(input_result.schema.to_primitive())) + span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA, json.dumps(output_result.schema.to_primitive())) + span.set_attribute(TdSpanAttributes.INPUT_SCHEMA_HASH, input_result.decoded_schema_hash) + span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA_HASH, output_result.decoded_schema_hash) + span.set_attribute(TdSpanAttributes.INPUT_VALUE_HASH, input_result.decoded_value_hash) + span.set_attribute(TdSpanAttributes.OUTPUT_VALUE_HASH, output_result.decoded_value_hash) + + if not error: + span.set_status(Status(OTelStatusCode.OK)) + + logger.debug("[PSYCOPG] Executemany returning span finalized successfully") + + except Exception as e: + logger.error(f"Error finalizing executemany returning span: {e}") From 9c4a03ee571a8c1277718aa5bc525dd694f0bb34 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 10:55:01 -0800 Subject: [PATCH 09/37] fix cursor.scroll() support in psycopg replay mode --- .../psycopg/instrumentation.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 3327592..d509e3f 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -266,6 +266,24 @@ def __next__(self): return tuple(row) if isinstance(row, list) else row raise StopIteration + def scroll(self, value: int, mode: str = "relative") -> None: + """Scroll the cursor to a new position in the result set.""" + if mode == "relative": + newpos = self._mock_index + value + elif mode == "absolute": + newpos = value + else: + raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") + + num_rows = len(self._mock_rows) + if num_rows > 0: + if not (0 <= newpos < num_rows): + raise IndexError("position out of bound") + elif newpos != 0: + raise IndexError("position out of bound") + + self._mock_index = newpos + def close(self): pass @@ -1657,6 +1675,26 @@ def mock_fetchall(): cursor.fetchmany = mock_fetchmany # pyright: ignore[reportAttributeAccessIssue] cursor.fetchall = mock_fetchall # pyright: ignore[reportAttributeAccessIssue] + def mock_scroll(value: int, mode: str = "relative") -> None: + """Scroll the cursor to a new position in the mock result set.""" + if mode == "relative": + newpos = cursor._mock_index + value # pyright: ignore[reportAttributeAccessIssue] + elif mode == "absolute": + newpos = value + else: + raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") + + num_rows = len(cursor._mock_rows) # pyright: ignore[reportAttributeAccessIssue] + if num_rows > 0: + if not (0 <= newpos < num_rows): + raise IndexError("position out of bound") + elif newpos != 0: + raise IndexError("position out of bound") + + cursor._mock_index = newpos # pyright: ignore[reportAttributeAccessIssue] + + cursor.scroll = mock_scroll # pyright: ignore[reportAttributeAccessIssue] + # Note: __iter__ and __next__ are handled at the class level in InstrumentedCursor # and MockCursor classes, as Python looks up special methods on the type, not instance From 3f801cf582d776da50179d8f3708f046072c4d1e Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 11:53:06 -0800 Subject: [PATCH 10/37] fix cursor.rownumber property returning null during REPLAY mode --- .../psycopg/e2e-tests/src/app.py | 156 +++++++++++++++++- .../psycopg/e2e-tests/src/test_requests.py | 8 + .../psycopg/instrumentation.py | 29 ++++ 3 files changed, 188 insertions(+), 5 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index da62585..3fd1816 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -352,18 +352,56 @@ def test_executemany_returning(): except Exception as e: return jsonify({"error": str(e)}), 500 +@app.route("/test/rownumber") +def test_rownumber(): + """Test cursor.rownumber property. + + Tests whether the rownumber property is properly tracked during replay mode. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + cur.execute("SELECT id, name FROM users ORDER BY id LIMIT 5") + + positions = [] + # Record rownumber at each fetch + positions.append({"before": cur.rownumber}) + + cur.fetchone() + positions.append({"after_fetchone_1": cur.rownumber}) + + cur.fetchone() + positions.append({"after_fetchone_2": cur.rownumber}) + + cur.fetchmany(2) + positions.append({"after_fetchmany_2": cur.rownumber}) + + return jsonify({ + "positions": positions + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + # ============================================================================ -# BUG HUNTING TEST ENDPOINTS -# These endpoints expose confirmed bugs in the instrumentation +# CONFIRMED BUG TEST ENDPOINTS +# These endpoints expose confirmed bugs in the psycopg instrumentation. +# See BUG_TRACKING.md for detailed documentation of each bug. +# +# Bug Summary: +# 1. /test/cursor-scroll - scroll() broken during RECORD mode +# 3. /test/statusmessage - statusmessage property returns null during REPLAY +# 4. /test/nextset - nextset() iteration broken during RECORD mode +# 5. /test/server-cursor-scroll - scroll() broken during RECORD mode # ============================================================================ @app.route("/test/cursor-scroll") def test_cursor_scroll(): """Test cursor.scroll() method. - BUG: The MockCursor and InstrumentedCursor classes don't implement - the scroll() method. In replay mode, scroll() causes "no result available" - error on subsequent fetchone() calls. + BUG: During RECORD mode, the instrumentation's _finalize_query_span calls + fetchall() which breaks the cursor position tracking. After fetchall(), + scroll(0, absolute) doesn't properly reset the cursor position because + the patched fetch methods use _tusk_index instead of _pos. """ try: with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: @@ -386,6 +424,114 @@ def test_cursor_scroll(): except Exception as e: return jsonify({"error": str(e)}), 500 +@app.route("/test/statusmessage") +def test_statusmessage(): + """Test cursor.statusmessage property. + + BUG: The statusmessage property is not captured during RECORD mode + and not mocked during REPLAY mode. During RECORD, statusmessage + returns the command status (e.g., "SELECT 5", "INSERT 0 1"), but + during REPLAY it returns null because this property is not tracked. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # SELECT should return something like "SELECT 5" + cur.execute("SELECT id FROM users LIMIT 5") + select_status = cur.statusmessage + cur.fetchall() + + # INSERT should return something like "INSERT 0 1" + cur.execute( + "INSERT INTO users (name, email) VALUES (%s, %s) RETURNING id", + ("StatusTest", "status@test.com") + ) + insert_status = cur.statusmessage + cur.fetchone() + + conn.rollback() # Don't actually insert + + return jsonify({ + "select_status": select_status, + "insert_status": insert_status + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/nextset") +def test_nextset(): + """Test cursor.nextset() for multiple result sets. + + BUG: During RECORD mode, the interaction between executemany with + returning=True and the fetch method patching breaks nextset() iteration. + The instrumentation's fetch patching may consume results before nextset() + can iterate through them, causing 0 results in RECORD but correct + results in REPLAY. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table + cur.execute("CREATE TEMP TABLE nextset_test (id SERIAL, val TEXT)") + + # Insert multiple rows with returning + cur.executemany( + "INSERT INTO nextset_test (val) VALUES (%s) RETURNING id, val", + [("First",), ("Second",), ("Third",)], + returning=True + ) + + # Use nextset to iterate through result sets + results = [] + while True: + row = cur.fetchone() + if row: + results.append({"id": row[0], "val": row[1]}) + if cur.nextset() is None: + break + + conn.commit() + + return jsonify({ + "count": len(results), + "data": results + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/server-cursor-scroll") +def test_server_cursor_scroll(): + """Test ServerCursor.scroll() method. + + BUG: Same root cause as /test/cursor-scroll. During RECORD mode, + the instrumentation breaks scroll() functionality by consuming all + rows via fetchall() in _finalize_query_span. ServerCursor.scroll() + sends MOVE commands to the server, but the position tracking is + inconsistent after the instrumentation patches the fetch methods. + """ + try: + with psycopg.connect(get_conn_string()) as conn: + # Named cursor with scrollable=True + with conn.cursor(name="scroll_test", scrollable=True) as cur: + cur.execute("SELECT id, name FROM users ORDER BY id") + + # Fetch first row + first = cur.fetchone() + + # Scroll back to start + cur.scroll(0, mode='absolute') + + # Fetch first row again + first_again = cur.fetchone() + + return jsonify({ + "first": {"id": first[0], "name": first[1]} if first else None, + "first_again": {"id": first_again[0], "name": first_again[1]} if first_again else None, + "match": first == first_again + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + if __name__ == "__main__": sdk.mark_app_as_ready() diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index cd3ca33..d7ba79f 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -82,4 +82,12 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/executemany-returning") + make_request("GET", "/test/rownumber") + + # Bug-exposing test endpoints + make_request("GET", "/test/statusmessage") + make_request("GET", "/test/nextset") + make_request("GET", "/test/server-cursor-scroll") + make_request("GET", "/test/cursor-scroll") + print("\nAll requests completed successfully") diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index d509e3f..67dda36 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -214,6 +214,13 @@ def __init__(self, connection): def description(self): return self._tusk_description + @property + def rownumber(self): + """Return the index of the next row to fetch, or None if no result.""" + if self._mock_rows: + return self._mock_index + return None + def execute(self, query, params=None, **kwargs): """Will be replaced by instrumentation.""" logger.debug(f"[MOCK_CURSOR] execute() called: {query[:100]}") @@ -602,6 +609,17 @@ def description(self): return self._tusk_description return super().description + @property + def rownumber(self): + # In captured mode (after fetchall in _finalize_query_span), return tracked index + if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: + return self._tusk_index + # In replay mode with mock data, return mock index + if hasattr(self, '_mock_rows') and self._mock_rows is not None: + return self._mock_index + # Otherwise, return real cursor's rownumber + return super().rownumber + def execute(self, query, params=None, **kwargs): return instrumentation._traced_execute(self, super().execute, sdk, query, params, **kwargs) @@ -675,6 +693,17 @@ def description(self): return self._tusk_description return super().description + @property + def rownumber(self): + # In captured mode (after fetchall in _finalize_query_span), return tracked index + if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: + return self._tusk_index + # In replay mode with mock data, return mock index + if hasattr(self, '_mock_rows') and self._mock_rows is not None: + return self._mock_index + # Otherwise, return real cursor's rownumber + return super().rownumber + def execute(self, query, params=None, **kwargs): # Note: ServerCursor.execute() doesn't support 'prepare' parameter return instrumentation._traced_execute(self, super().execute, sdk, query, params, **kwargs) From b7c850250d44de19c224ec6c80b0ce8768866df5 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 12:07:12 -0800 Subject: [PATCH 11/37] fix psycopg cursor.statusmessage capture and replay --- .../psycopg/e2e-tests/src/app.py | 69 +++++++++---------- .../psycopg/e2e-tests/src/test_requests.py | 3 +- .../psycopg/instrumentation.py | 30 ++++++++ 3 files changed, 66 insertions(+), 36 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index 3fd1816..eaab0c7 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -381,6 +381,40 @@ def test_rownumber(): except Exception as e: return jsonify({"error": str(e)}), 500 +@app.route("/test/statusmessage") +def test_statusmessage(): + """Test cursor.statusmessage property. + + BUG: The statusmessage property is not captured during RECORD mode + and not mocked during REPLAY mode. During RECORD, statusmessage + returns the command status (e.g., "SELECT 5", "INSERT 0 1"), but + during REPLAY it returns null because this property is not tracked. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # SELECT should return something like "SELECT 5" + cur.execute("SELECT id FROM users LIMIT 5") + select_status = cur.statusmessage + cur.fetchall() + + # INSERT should return something like "INSERT 0 1" + cur.execute( + "INSERT INTO users (name, email) VALUES (%s, %s) RETURNING id", + ("StatusTest", "status@test.com") + ) + insert_status = cur.statusmessage + cur.fetchone() + + conn.rollback() # Don't actually insert + + return jsonify({ + "select_status": select_status, + "insert_status": insert_status + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + # ============================================================================ # CONFIRMED BUG TEST ENDPOINTS @@ -389,7 +423,6 @@ def test_rownumber(): # # Bug Summary: # 1. /test/cursor-scroll - scroll() broken during RECORD mode -# 3. /test/statusmessage - statusmessage property returns null during REPLAY # 4. /test/nextset - nextset() iteration broken during RECORD mode # 5. /test/server-cursor-scroll - scroll() broken during RECORD mode # ============================================================================ @@ -424,40 +457,6 @@ def test_cursor_scroll(): except Exception as e: return jsonify({"error": str(e)}), 500 -@app.route("/test/statusmessage") -def test_statusmessage(): - """Test cursor.statusmessage property. - - BUG: The statusmessage property is not captured during RECORD mode - and not mocked during REPLAY mode. During RECORD, statusmessage - returns the command status (e.g., "SELECT 5", "INSERT 0 1"), but - during REPLAY it returns null because this property is not tracked. - """ - try: - with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: - # SELECT should return something like "SELECT 5" - cur.execute("SELECT id FROM users LIMIT 5") - select_status = cur.statusmessage - cur.fetchall() - - # INSERT should return something like "INSERT 0 1" - cur.execute( - "INSERT INTO users (name, email) VALUES (%s, %s) RETURNING id", - ("StatusTest", "status@test.com") - ) - insert_status = cur.statusmessage - cur.fetchone() - - conn.rollback() # Don't actually insert - - return jsonify({ - "select_status": select_status, - "insert_status": insert_status - }) - except Exception as e: - return jsonify({"error": str(e)}), 500 - - @app.route("/test/nextset") def test_nextset(): """Test cursor.nextset() for multiple result sets. diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index d7ba79f..e54bbbc 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -84,8 +84,9 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/rownumber") - # Bug-exposing test endpoints make_request("GET", "/test/statusmessage") + + # Bug-exposing test endpoints make_request("GET", "/test/nextset") make_request("GET", "/test/server-cursor-scroll") make_request("GET", "/test/cursor-scroll") diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 67dda36..dd6e1fb 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -221,6 +221,11 @@ def rownumber(self): return self._mock_index return None + @property + def statusmessage(self): + """Return the mock status message if set, otherwise None.""" + return getattr(self, '_mock_statusmessage', None) + def execute(self, query, params=None, **kwargs): """Will be replaced by instrumentation.""" logger.debug(f"[MOCK_CURSOR] execute() called: {query[:100]}") @@ -620,6 +625,14 @@ def rownumber(self): # Otherwise, return real cursor's rownumber return super().rownumber + @property + def statusmessage(self): + # In replay mode with mock data, return mock statusmessage + if hasattr(self, '_mock_statusmessage'): + return self._mock_statusmessage + # Otherwise, return real cursor's statusmessage + return super().statusmessage + def execute(self, query, params=None, **kwargs): return instrumentation._traced_execute(self, super().execute, sdk, query, params, **kwargs) @@ -704,6 +717,14 @@ def rownumber(self): # Otherwise, return real cursor's rownumber return super().rownumber + @property + def statusmessage(self): + # In replay mode with mock data, return mock statusmessage + if hasattr(self, '_mock_statusmessage'): + return self._mock_statusmessage + # Otherwise, return real cursor's statusmessage + return super().statusmessage + def execute(self, query, params=None, **kwargs): # Note: ServerCursor.execute() doesn't support 'prepare' parameter return instrumentation._traced_execute(self, super().execute, sdk, query, params, **kwargs) @@ -1643,6 +1664,11 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non except AttributeError: pass + # Set mock statusmessage for replay + statusmessage = actual_data.get("statusmessage") + if statusmessage is not None: + cursor._mock_statusmessage = statusmessage + # Get row_factory from cursor or connection for row transformation row_factory = getattr(cursor, 'row_factory', None) if row_factory is None: @@ -2017,6 +2043,10 @@ def patched_fetchall(): serialized_rows = [[serialize_value(col) for col in row] for row in rows] output_value["rows"] = serialized_rows + # Capture statusmessage for replay + if hasattr(cursor, 'statusmessage') and cursor.statusmessage is not None: + output_value["statusmessage"] = cursor.statusmessage + except Exception as e: logger.debug(f"Error getting query metadata: {e}") From 15a32a8653ca833002a545e468f3c8e428fc2c7f Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 12:28:08 -0800 Subject: [PATCH 12/37] fix nextset() iteration for executemany with returning=True --- .../psycopg/instrumentation.py | 126 ++++++++++++++++++ 1 file changed, 126 insertions(+) diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index dd6e1fb..9a3b654 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -1914,6 +1914,80 @@ def fetchall(): except AttributeError: pass + # Set up initial fetch methods for the first result set (for code that uses nextset() instead of results()) + first_column_names = first_set.get("column_names") + FirstRowClass = create_row_class(first_column_names) + + def make_fetchone_replay(cn, RC): + def fetchone(): + if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue] + row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue] + return transform_row(row, cn, RC) + return None + return fetchone + + def make_fetchmany_replay(cn, RC): + def fetchmany(size=cursor.arraysize): + rows = [] + for _ in range(size): + if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue] + row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue] + rows.append(transform_row(row, cn, RC)) + else: + break + return rows + return fetchmany + + def make_fetchall_replay(cn, RC): + def fetchall(): + rows = cursor._mock_rows[cursor._mock_index:] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index = len(cursor._mock_rows) # pyright: ignore[reportAttributeAccessIssue] + return [transform_row(row, cn, RC) for row in rows] + return fetchall + + cursor.fetchone = make_fetchone_replay(first_column_names, FirstRowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = make_fetchmany_replay(first_column_names, FirstRowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = make_fetchall_replay(first_column_names, FirstRowClass) # pyright: ignore[reportAttributeAccessIssue] + + # Patch nextset() to work with _mock_result_sets + def patched_nextset(): + """Move to the next result set in _mock_result_sets.""" + next_index = cursor._mock_result_set_index + 1 # pyright: ignore[reportAttributeAccessIssue] + if next_index < len(cursor._mock_result_sets): # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_result_set_index = next_index # pyright: ignore[reportAttributeAccessIssue] + next_set = cursor._mock_result_sets[next_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_rows = next_set["rows"] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Update fetch methods for the new result set + next_column_names = next_set.get("column_names") + NextRowClass = create_row_class(next_column_names) + cursor.fetchone = make_fetchone_replay(next_column_names, NextRowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = make_fetchmany_replay(next_column_names, NextRowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = make_fetchall_replay(next_column_names, NextRowClass) # pyright: ignore[reportAttributeAccessIssue] + + # Update description for next result set + next_description_data = next_set.get("description") + if next_description_data: + next_desc = [ + (col["name"], col.get("type_code"), None, None, None, None, None) + for col in next_description_data + ] + try: + cursor._tusk_description = next_desc # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + try: + cursor.description = next_desc # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + pass + + return True + return None + + cursor.nextset = patched_nextset # pyright: ignore[reportAttributeAccessIssue] + def _finalize_query_span( self, span: trace.Span, @@ -2216,6 +2290,58 @@ def patched_fetchall(): cursor.results = patched_results # pyright: ignore[reportAttributeAccessIssue] + # Set up the first result set immediately for user code that uses nextset() instead of results() + if all_rows_collected: + cursor._tusk_rows = all_rows_collected[0] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index = 0 # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_result_set_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Create initial fetch methods for the first result set + def make_patched_fetchone_record(): + def patched_fetchone(): + if cursor._tusk_index < len(cursor._tusk_rows): # pyright: ignore[reportAttributeAccessIssue] + row = cursor._tusk_rows[cursor._tusk_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index += 1 # pyright: ignore[reportAttributeAccessIssue] + return row + return None + return patched_fetchone + + def make_patched_fetchmany_record(): + def patched_fetchmany(size=cursor.arraysize): + result = cursor._tusk_rows[cursor._tusk_index : cursor._tusk_index + size] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index += len(result) # pyright: ignore[reportAttributeAccessIssue] + return result + return patched_fetchmany + + def make_patched_fetchall_record(): + def patched_fetchall(): + result = cursor._tusk_rows[cursor._tusk_index :] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index = len(cursor._tusk_rows) # pyright: ignore[reportAttributeAccessIssue] + return result + return patched_fetchall + + cursor.fetchone = make_patched_fetchone_record() # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = make_patched_fetchmany_record() # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = make_patched_fetchall_record() # pyright: ignore[reportAttributeAccessIssue] + + # Patch nextset() to work with _tusk_result_sets + def patched_nextset(): + """Move to the next result set in _tusk_result_sets.""" + next_index = cursor._tusk_result_set_index + 1 # pyright: ignore[reportAttributeAccessIssue] + if next_index < len(cursor._tusk_result_sets): # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_result_set_index = next_index # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_rows = cursor._tusk_result_sets[next_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Update fetch methods for the new result set + cursor.fetchone = make_patched_fetchone_record() # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = make_patched_fetchmany_record() # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = make_patched_fetchall_record() # pyright: ignore[reportAttributeAccessIssue] + return True + return None + + cursor.nextset = patched_nextset # pyright: ignore[reportAttributeAccessIssue] + else: output_value = {"rowcount": cursor.rowcount if hasattr(cursor, "rowcount") else -1} From 02a5b4569eb88a87a7d044ae08aa3ada6cc845a9 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 12:41:08 -0800 Subject: [PATCH 13/37] fix cursor.scroll() position tracking in RECORD mode --- .../psycopg/e2e-tests/src/app.py | 81 +++++++------------ .../psycopg/e2e-tests/src/test_requests.py | 3 +- .../psycopg/instrumentation.py | 22 +++++ 3 files changed, 52 insertions(+), 54 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index eaab0c7..cc5d6b7 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -414,58 +414,11 @@ def test_statusmessage(): except Exception as e: return jsonify({"error": str(e)}), 500 - - -# ============================================================================ -# CONFIRMED BUG TEST ENDPOINTS -# These endpoints expose confirmed bugs in the psycopg instrumentation. -# See BUG_TRACKING.md for detailed documentation of each bug. -# -# Bug Summary: -# 1. /test/cursor-scroll - scroll() broken during RECORD mode -# 4. /test/nextset - nextset() iteration broken during RECORD mode -# 5. /test/server-cursor-scroll - scroll() broken during RECORD mode -# ============================================================================ - -@app.route("/test/cursor-scroll") -def test_cursor_scroll(): - """Test cursor.scroll() method. - - BUG: During RECORD mode, the instrumentation's _finalize_query_span calls - fetchall() which breaks the cursor position tracking. After fetchall(), - scroll(0, absolute) doesn't properly reset the cursor position because - the patched fetch methods use _tusk_index instead of _pos. - """ - try: - with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: - cur.execute("SELECT id, name FROM users ORDER BY id") - - # Fetch first row - first = cur.fetchone() - - # Scroll back to start - cur.scroll(0, mode='absolute') - - # Fetch first row again - first_again = cur.fetchone() - - return jsonify({ - "first": {"id": first[0], "name": first[1]} if first else None, - "first_again": {"id": first_again[0], "name": first_again[1]} if first_again else None, - "match": first == first_again - }) - except Exception as e: - return jsonify({"error": str(e)}), 500 - @app.route("/test/nextset") def test_nextset(): """Test cursor.nextset() for multiple result sets. - BUG: During RECORD mode, the interaction between executemany with - returning=True and the fetch method patching breaks nextset() iteration. - The instrumentation's fetch patching may consume results before nextset() - can iterate through them, causing 0 results in RECORD but correct - results in REPLAY. + Tests whether the instrumentation correctly handles nextset() for multiple result sets. """ try: with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: @@ -497,16 +450,38 @@ def test_nextset(): except Exception as e: return jsonify({"error": str(e)}), 500 +@app.route("/test/cursor-scroll") +def test_cursor_scroll(): + """Test cursor.scroll() method. + + Tests whether the instrumentation correctly handles scroll() for cursor position tracking. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + cur.execute("SELECT id, name FROM users ORDER BY id") + + # Fetch first row + first = cur.fetchone() + + # Scroll back to start + cur.scroll(0, mode='absolute') + + # Fetch first row again + first_again = cur.fetchone() + + return jsonify({ + "first": {"id": first[0], "name": first[1]} if first else None, + "first_again": {"id": first_again[0], "name": first_again[1]} if first_again else None, + "match": first == first_again + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 @app.route("/test/server-cursor-scroll") def test_server_cursor_scroll(): """Test ServerCursor.scroll() method. - BUG: Same root cause as /test/cursor-scroll. During RECORD mode, - the instrumentation breaks scroll() functionality by consuming all - rows via fetchall() in _finalize_query_span. ServerCursor.scroll() - sends MOVE commands to the server, but the position tracking is - inconsistent after the instrumentation patches the fetch methods. + Tests whether the instrumentation correctly handles scroll() for server-side cursors. """ try: with psycopg.connect(get_conn_string()) as conn: diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index e54bbbc..b0e9cc2 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -86,9 +86,10 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/statusmessage") - # Bug-exposing test endpoints make_request("GET", "/test/nextset") + make_request("GET", "/test/server-cursor-scroll") + make_request("GET", "/test/cursor-scroll") print("\nAll requests completed successfully") diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 9a3b654..9a2ef76 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -857,6 +857,8 @@ def _record_execute( cursor.fetchall = cursor._tusk_original_fetchall cursor._tusk_rows = None cursor._tusk_index = 0 + if hasattr(cursor, '_tusk_original_scroll'): + cursor.scroll = cursor._tusk_original_scroll span_info = SpanUtils.create_span( CreateSpanOptions( @@ -2090,16 +2092,36 @@ def patched_fetchall(): cursor._tusk_index = len(cursor._tusk_rows) # pyright: ignore[reportAttributeAccessIssue] return result + def patched_scroll(value: int, mode: str = "relative") -> None: + """Scroll the cursor to a new position in the captured result set.""" + if mode == "relative": + newpos = cursor._tusk_index + value # pyright: ignore[reportAttributeAccessIssue] + elif mode == "absolute": + newpos = value + else: + raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") + + num_rows = len(cursor._tusk_rows) # pyright: ignore[reportAttributeAccessIssue] + if num_rows > 0: + if not (0 <= newpos < num_rows): + raise IndexError("cursor position out of range") + elif newpos != 0: + raise IndexError("cursor position out of range") + + cursor._tusk_index = newpos # pyright: ignore[reportAttributeAccessIssue] + # Save original fetch methods before patching (only if not already saved) # These will be restored at the start of the next execute() call if not hasattr(cursor, '_tusk_original_fetchone'): cursor._tusk_original_fetchone = cursor.fetchone # pyright: ignore[reportAttributeAccessIssue] cursor._tusk_original_fetchmany = cursor.fetchmany # pyright: ignore[reportAttributeAccessIssue] cursor._tusk_original_fetchall = cursor.fetchall # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_original_scroll = cursor.scroll # pyright: ignore[reportAttributeAccessIssue] cursor.fetchone = patched_fetchone # pyright: ignore[reportAttributeAccessIssue] cursor.fetchmany = patched_fetchmany # pyright: ignore[reportAttributeAccessIssue] cursor.fetchall = patched_fetchall # pyright: ignore[reportAttributeAccessIssue] + cursor.scroll = patched_scroll # pyright: ignore[reportAttributeAccessIssue] except Exception as fetch_error: logger.debug(f"Could not fetch rows (query might not return rows): {fetch_error}") From acbb31bd3bb55ad1e913437c4823d02b08580fe8 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 13:00:48 -0800 Subject: [PATCH 14/37] refactor --- .../psycopg/instrumentation.py | 498 +----------------- drift/instrumentation/psycopg/mocks.py | 384 ++++++++++++++ drift/instrumentation/psycopg/wrappers.py | 80 +++ .../psycopg2/instrumentation.py | 16 +- drift/instrumentation/utils/serialization.py | 30 ++ 5 files changed, 500 insertions(+), 508 deletions(-) create mode 100644 drift/instrumentation/psycopg/mocks.py create mode 100644 drift/instrumentation/psycopg/wrappers.py create mode 100644 drift/instrumentation/utils/serialization.py diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 9a2ef76..298b0d0 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -24,445 +24,15 @@ ) from ..base import InstrumentationBase from ..utils.psycopg_utils import deserialize_db_value +from ..utils.serialization import serialize_value +from .mocks import MockConnection, MockCopy +from .wrappers import TracedCopyWrapper logger = logging.getLogger(__name__) _instance: PsycopgInstrumentation | None = None -class MockLoader: - """Mock loader for psycopg3.""" - - def __init__(self): - self.timezone = None # Django expects this attribute - - def __call__(self, data): - """No-op load function.""" - return data - - -class MockDumper: - """Mock dumper for psycopg3.""" - - def __call__(self, obj): - """No-op dump function.""" - return str(obj).encode("utf-8") - - -class MockAdapters: - """Mock adapters for psycopg3 connection.""" - - def get_loader(self, oid, format): - """Return a mock loader.""" - return MockLoader() - - def get_dumper(self, obj, format): - """Return a mock dumper.""" - return MockDumper() - - def register_loader(self, oid, loader): - """No-op register loader for Django compatibility.""" - pass - - def register_dumper(self, oid, dumper): - """No-op register dumper for Django compatibility.""" - pass - - -class MockConnection: - """Mock database connection for REPLAY mode when postgres is not available. - - Provides minimal interface for Django/Flask to work without a real database. - All queries are mocked at the cursor.execute() level. - """ - - def __init__(self, sdk: TuskDrift, instrumentation: PsycopgInstrumentation, cursor_factory, row_factory=None): - self.sdk = sdk - self.instrumentation = instrumentation - self.cursor_factory = cursor_factory - self.row_factory = row_factory # Store row_factory for cursor creation - self.closed = False - self.autocommit = False - - # Django/psycopg3 requires these for connection initialization - self.isolation_level = None - self.encoding = "UTF8" - self.adapters = MockAdapters() - self.pgconn = None # Mock pg connection object - - # Create a comprehensive mock info object for Django - class MockInfo: - vendor = "postgresql" - server_version = 150000 # PostgreSQL 15.0 as integer - encoding = "UTF8" - - def parameter_status(self, param): - """Return mock parameter status.""" - if param == "TimeZone": - return "UTC" - elif param == "server_version": - return "15.0" - return None - - self.info = MockInfo() - - logger.debug("[MOCK_CONNECTION] Created mock connection for REPLAY mode (psycopg3)") - - def cursor(self, name=None, *, cursor_factory=None, **kwargs): - """Create a cursor using the instrumented cursor factory. - - Accepts the same parameters as psycopg's Connection.cursor(), including - server cursor parameters like scrollable and withhold. - """ - # For mock connections, we create a MockCursor directly - # The name parameter is accepted but not used since mock cursors - # behave the same for both regular and server cursors - cursor = MockCursor(self) - - # Wrap execute/executemany for mock cursor - instrumentation = self.instrumentation - sdk = self.sdk - - def mock_execute(query, params=None, **kwargs): - # For mock cursor, original_execute is just a no-op - def noop_execute(q, p, **kw): - return cursor - - return instrumentation._traced_execute(cursor, noop_execute, sdk, query, params, **kwargs) - - def mock_executemany(query, params_seq, **kwargs): - # For mock cursor, original_executemany is just a no-op - def noop_executemany(q, ps, **kw): - return cursor - - return instrumentation._traced_executemany(cursor, noop_executemany, sdk, query, params_seq, **kwargs) - - def mock_stream(query, params=None, **kwargs): - # For mock cursor, original_stream is just a no-op generator - def noop_stream(q, p, **kw): - return iter([]) - - return instrumentation._traced_stream(cursor, noop_stream, sdk, query, params, **kwargs) - - def mock_copy(query, params=None, **kwargs): - # For mock cursor, original_copy is a no-op context manager - @contextmanager - def noop_copy(q, p=None, **kw): - yield MockCopy([]) - - return instrumentation._traced_copy(cursor, noop_copy, sdk, query, params, **kwargs) - - # Monkey-patch mock functions onto cursor - cursor.execute = mock_execute # type: ignore[method-assign] - cursor.executemany = mock_executemany # type: ignore[method-assign] - cursor.stream = mock_stream # type: ignore[method-assign] - cursor.copy = mock_copy # type: ignore[method-assign] - - logger.debug("[MOCK_CONNECTION] Created cursor (psycopg3)") - return cursor - - def commit(self): - """Mock commit - no-op in REPLAY mode.""" - logger.debug("[MOCK_CONNECTION] commit() called (no-op)") - pass - - def rollback(self): - """Mock rollback - no-op in REPLAY mode.""" - logger.debug("[MOCK_CONNECTION] rollback() called (no-op)") - pass - - def close(self): - """Mock close - no-op in REPLAY mode.""" - logger.debug("[MOCK_CONNECTION] close() called (no-op)") - self.closed = True - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is not None: - self.rollback() - else: - self.commit() - return False - - def pipeline(self): - """Return a mock pipeline context manager for REPLAY mode.""" - return MockPipeline(self) - - -class MockCursor: - """Mock cursor for when we can't create a real cursor from base class. - - This is a fallback when the connection is completely mocked. - """ - - def __init__(self, connection): - self.connection = connection - self.rowcount = -1 - self._tusk_description = None # Store mock description - self.arraysize = 1 - self._mock_rows = [] - self._mock_index = 0 - # Support for multiple result sets (executemany with returning=True) - self._mock_result_sets = [] - self._mock_result_set_index = 0 - self.adapters = MockAdapters() # Django needs this - logger.debug("[MOCK_CURSOR] Created fallback mock cursor (psycopg3)") - - @property - def description(self): - return self._tusk_description - - @property - def rownumber(self): - """Return the index of the next row to fetch, or None if no result.""" - if self._mock_rows: - return self._mock_index - return None - - @property - def statusmessage(self): - """Return the mock status message if set, otherwise None.""" - return getattr(self, '_mock_statusmessage', None) - - def execute(self, query, params=None, **kwargs): - """Will be replaced by instrumentation.""" - logger.debug(f"[MOCK_CURSOR] execute() called: {query[:100]}") - return self - - def executemany(self, query, params_seq, **kwargs): - """Will be replaced by instrumentation.""" - logger.debug(f"[MOCK_CURSOR] executemany() called: {query[:100]}") - return self - - def fetchone(self): - return None - - def fetchmany(self, size=None): - return [] - - def fetchall(self): - return [] - - def results(self): - """Iterate over result sets for executemany with returning=True. - - This method is patched by _mock_executemany_returning_with_data - when replaying executemany with returning=True. - Default implementation yields self once for backward compatibility. - """ - yield self - - def nextset(self): - """Move to the next result set. - - Returns True if there is another result set, None otherwise. - This method is patched during replay for executemany with returning=True. - """ - return None - - def stream(self, query, params=None, **kwargs): - """Will be replaced by instrumentation.""" - return iter([]) - - def __iter__(self): - """Support direct cursor iteration (for row in cursor).""" - return self - - def __next__(self): - """Return next row for iteration.""" - if self._mock_index < len(self._mock_rows): - row = self._mock_rows[self._mock_index] - self._mock_index += 1 - return tuple(row) if isinstance(row, list) else row - raise StopIteration - - def scroll(self, value: int, mode: str = "relative") -> None: - """Scroll the cursor to a new position in the result set.""" - if mode == "relative": - newpos = self._mock_index + value - elif mode == "absolute": - newpos = value - else: - raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") - - num_rows = len(self._mock_rows) - if num_rows > 0: - if not (0 <= newpos < num_rows): - raise IndexError("position out of bound") - elif newpos != 0: - raise IndexError("position out of bound") - - self._mock_index = newpos - - def close(self): - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - return False - - -class MockCopy: - """Mock Copy object for REPLAY mode. - - Provides a minimal interface compatible with psycopg's Copy object - for COPY TO operations (iteration) and COPY FROM operations (write). - """ - - def __init__(self, data: list): - """Initialize MockCopy with recorded data. - - Args: - data: For COPY TO - list of data chunks (as strings from JSON, will be encoded to bytes) - """ - self._data = data - self._index = 0 - - def __iter__(self) -> Iterator[bytes]: - """Iterate over COPY TO data chunks.""" - for item in self._data: - # Data was stored as string in JSON, convert back to bytes - if isinstance(item, str): - yield item.encode("utf-8") - elif isinstance(item, bytes): - yield item - else: - yield str(item).encode("utf-8") - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - return False - - def read(self) -> bytes: - """Read next data chunk for COPY TO.""" - if self._index < len(self._data): - item = self._data[self._index] - self._index += 1 - if isinstance(item, str): - return item.encode("utf-8") - elif isinstance(item, bytes): - return item - return str(item).encode("utf-8") - return b"" - - def rows(self) -> Iterator[tuple]: - """Iterate over rows for COPY TO (parsed format).""" - for item in self._data: - yield tuple(item) if isinstance(item, list) else item - - def write(self, buffer) -> None: - """No-op for COPY FROM in replay mode.""" - pass - - def write_row(self, row) -> None: - """No-op for COPY FROM in replay mode.""" - pass - - def set_types(self, types) -> None: - """No-op for replay mode.""" - pass - - -class MockPipeline: - """Mock Pipeline for REPLAY mode. - - In REPLAY mode, pipeline operations are no-ops since queries - return mocked data immediately. - """ - - def __init__(self, connection: "MockConnection"): - self._conn = connection - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - return False - - def sync(self): - """No-op sync for mock pipeline.""" - pass - - -class _TracedCopyWrapper: - """Wrapper around psycopg's Copy object to capture data in RECORD mode. - - Intercepts all data operations to record them for replay. - """ - - def __init__(self, copy: Any, data_collected: list): - """Initialize wrapper. - - Args: - copy: The real psycopg Copy object - data_collected: List to append captured data chunks to - """ - self._copy = copy - self._data_collected = data_collected - - def __iter__(self) -> Iterator[bytes]: - """Iterate over COPY TO data, capturing each chunk.""" - for data in self._copy: - # Handle both bytes and memoryview - if isinstance(data, memoryview): - data = bytes(data) - self._data_collected.append(data) - yield data - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - return False - - def read(self) -> bytes: - """Read raw data from COPY TO, capturing it.""" - data = self._copy.read() - if data: - if isinstance(data, memoryview): - data = bytes(data) - self._data_collected.append(data) - return data - - def read_row(self): - """Read a parsed row from COPY TO.""" - row = self._copy.read_row() - if row is not None: - self._data_collected.append(row) - return row - - def rows(self): - """Iterate over parsed rows from COPY TO.""" - for row in self._copy.rows(): - self._data_collected.append(row) - yield row - - def write(self, buffer): - """Write raw data for COPY FROM.""" - self._data_collected.append(buffer) - return self._copy.write(buffer) - - def write_row(self, row): - """Write a row for COPY FROM.""" - self._data_collected.append(row) - return self._copy.write_row(row) - - def set_types(self, types): - """Proxy set_types to the underlying Copy object.""" - return self._copy.set_types(types) - - def __getattr__(self, name): - """Proxy any other attributes to the underlying copy object.""" - return getattr(self._copy, name) - - class PsycopgInstrumentation(InstrumentationBase): """Instrumentation for psycopg (psycopg3) PostgreSQL client library. @@ -1187,20 +757,6 @@ def _finalize_stream_span( ) -> None: """Finalize span for stream operation with collected rows.""" try: - import datetime - - def serialize_value(val): - """Convert non-JSON-serializable values to JSON-compatible types.""" - if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): - return val.isoformat() - elif isinstance(val, bytes): - return val.decode("utf-8", errors="replace") - elif isinstance(val, (list, tuple)): - return [serialize_value(v) for v in val] - elif isinstance(val, dict): - return {k: serialize_value(v) for k, v in val.items()} - return val - # Build input value input_value = { "query": query.strip(), @@ -1284,7 +840,7 @@ def _record_copy( query_str: str, params: Any, kwargs: dict, - ) -> Iterator[_TracedCopyWrapper]: + ) -> Iterator[TracedCopyWrapper]: """Handle RECORD mode for copy - wrap Copy object with tracing.""" is_pre_app_start = not sdk.app_ready @@ -1317,7 +873,7 @@ def _record_copy( with SpanUtils.with_span(span_info): with original_copy(query, params, **kwargs) as copy: # Wrap the Copy object to capture data - wrapped_copy = _TracedCopyWrapper(copy, data_collected) + wrapped_copy = TracedCopyWrapper(copy, data_collected) yield wrapped_copy except Exception as e: error = e @@ -1427,22 +983,6 @@ def _finalize_copy_span( ) -> None: """Finalize span for copy operation.""" try: - import datetime - - def serialize_value(val): - """Convert non-JSON-serializable values to JSON-compatible types.""" - if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): - return val.isoformat() - elif isinstance(val, bytes): - return val.decode("utf-8", errors="replace") - elif isinstance(val, memoryview): - return bytes(val).decode("utf-8", errors="replace") - elif isinstance(val, (list, tuple)): - return [serialize_value(v) for v in val] - elif isinstance(val, dict): - return {k: serialize_value(v) for k, v in val.items()} - return val - # Determine operation type from query query_upper = query.upper() is_copy_to = "TO" in query_upper and "STDOUT" in query_upper @@ -2000,21 +1540,6 @@ def _finalize_query_span( ) -> None: """Finalize span with query data.""" try: - # Helper function to serialize non-JSON types - import datetime - - def serialize_value(val): - """Convert non-JSON-serializable values to JSON-compatible types.""" - if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): - return val.isoformat() - elif isinstance(val, bytes): - return val.decode("utf-8", errors="replace") - elif isinstance(val, (list, tuple)): - return [serialize_value(v) for v in val] - elif isinstance(val, dict): - return {k: serialize_value(v) for k, v in val.items()} - return val - # Build input value input_value = { "query": query.strip(), @@ -2183,19 +1708,6 @@ def _finalize_executemany_returning_span( be replayed with multiple result set iteration. """ try: - import datetime - - def serialize_value(val): - """Convert non-JSON-serializable values to JSON-compatible types.""" - if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): - return val.isoformat() - elif isinstance(val, bytes): - return val.decode("utf-8", errors="replace") - elif isinstance(val, (list, tuple)): - return [serialize_value(v) for v in val] - elif isinstance(val, dict): - return {k: serialize_value(v) for k, v in val.items()} - return val # Build input value input_value = { diff --git a/drift/instrumentation/psycopg/mocks.py b/drift/instrumentation/psycopg/mocks.py new file mode 100644 index 0000000..b5e672d --- /dev/null +++ b/drift/instrumentation/psycopg/mocks.py @@ -0,0 +1,384 @@ +"""Mock classes for psycopg3 REPLAY mode. + +These mock classes provide a minimal interface for Django/Flask to work +without a real PostgreSQL database connection during replay. +""" + +from __future__ import annotations + +import logging +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Iterator + +if TYPE_CHECKING: + from ...core.drift_sdk import TuskDrift + from .instrumentation import PsycopgInstrumentation + +logger = logging.getLogger(__name__) + + +class MockLoader: + """Mock loader for psycopg3.""" + + def __init__(self): + self.timezone = None # Django expects this attribute + + def __call__(self, data): + """No-op load function.""" + return data + + +class MockDumper: + """Mock dumper for psycopg3.""" + + def __call__(self, obj): + """No-op dump function.""" + return str(obj).encode("utf-8") + + +class MockAdapters: + """Mock adapters for psycopg3 connection.""" + + def get_loader(self, oid, format): + """Return a mock loader.""" + return MockLoader() + + def get_dumper(self, obj, format): + """Return a mock dumper.""" + return MockDumper() + + def register_loader(self, oid, loader): + """No-op register loader for Django compatibility.""" + pass + + def register_dumper(self, oid, dumper): + """No-op register dumper for Django compatibility.""" + pass + + +class MockConnection: + """Mock database connection for REPLAY mode when postgres is not available. + + Provides minimal interface for Django/Flask to work without a real database. + All queries are mocked at the cursor.execute() level. + """ + + def __init__( + self, + sdk: TuskDrift, + instrumentation: PsycopgInstrumentation, + cursor_factory, + row_factory=None, + ): + self.sdk = sdk + self.instrumentation = instrumentation + self.cursor_factory = cursor_factory + self.row_factory = row_factory # Store row_factory for cursor creation + self.closed = False + self.autocommit = False + + # Django/psycopg3 requires these for connection initialization + self.isolation_level = None + self.encoding = "UTF8" + self.adapters = MockAdapters() + self.pgconn = None # Mock pg connection object + + # Create a comprehensive mock info object for Django + class MockInfo: + vendor = "postgresql" + server_version = 150000 # PostgreSQL 15.0 as integer + encoding = "UTF8" + + def parameter_status(self, param): + """Return mock parameter status.""" + if param == "TimeZone": + return "UTC" + elif param == "server_version": + return "15.0" + return None + + self.info = MockInfo() + + logger.debug("[MOCK_CONNECTION] Created mock connection for REPLAY mode (psycopg3)") + + def cursor(self, name=None, *, cursor_factory=None, **kwargs): + """Create a cursor using the instrumented cursor factory. + + Accepts the same parameters as psycopg's Connection.cursor(), including + server cursor parameters like scrollable and withhold. + """ + # For mock connections, we create a MockCursor directly + # The name parameter is accepted but not used since mock cursors + # behave the same for both regular and server cursors + cursor = MockCursor(self) + + # Wrap execute/executemany for mock cursor + instrumentation = self.instrumentation + sdk = self.sdk + + def mock_execute(query, params=None, **kwargs): + # For mock cursor, original_execute is just a no-op + def noop_execute(q, p, **kw): + return cursor + + return instrumentation._traced_execute(cursor, noop_execute, sdk, query, params, **kwargs) + + def mock_executemany(query, params_seq, **kwargs): + # For mock cursor, original_executemany is just a no-op + def noop_executemany(q, ps, **kw): + return cursor + + return instrumentation._traced_executemany(cursor, noop_executemany, sdk, query, params_seq, **kwargs) + + def mock_stream(query, params=None, **kwargs): + # For mock cursor, original_stream is just a no-op generator + def noop_stream(q, p, **kw): + return iter([]) + + return instrumentation._traced_stream(cursor, noop_stream, sdk, query, params, **kwargs) + + def mock_copy(query, params=None, **kwargs): + # For mock cursor, original_copy is a no-op context manager + @contextmanager + def noop_copy(q, p=None, **kw): + yield MockCopy([]) + + return instrumentation._traced_copy(cursor, noop_copy, sdk, query, params, **kwargs) + + # Monkey-patch mock functions onto cursor + cursor.execute = mock_execute # type: ignore[method-assign] + cursor.executemany = mock_executemany # type: ignore[method-assign] + cursor.stream = mock_stream # type: ignore[method-assign] + cursor.copy = mock_copy # type: ignore[method-assign] + + logger.debug("[MOCK_CONNECTION] Created cursor (psycopg3)") + return cursor + + def commit(self): + """Mock commit - no-op in REPLAY mode.""" + logger.debug("[MOCK_CONNECTION] commit() called (no-op)") + pass + + def rollback(self): + """Mock rollback - no-op in REPLAY mode.""" + logger.debug("[MOCK_CONNECTION] rollback() called (no-op)") + pass + + def close(self): + """Mock close - no-op in REPLAY mode.""" + logger.debug("[MOCK_CONNECTION] close() called (no-op)") + self.closed = True + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + self.rollback() + else: + self.commit() + return False + + def pipeline(self): + """Return a mock pipeline context manager for REPLAY mode.""" + return MockPipeline(self) + + +class MockCursor: + """Mock cursor for when we can't create a real cursor from base class. + + This is a fallback when the connection is completely mocked. + """ + + def __init__(self, connection): + self.connection = connection + self.rowcount = -1 + self._tusk_description = None # Store mock description + self.arraysize = 1 + self._mock_rows = [] + self._mock_index = 0 + # Support for multiple result sets (executemany with returning=True) + self._mock_result_sets = [] + self._mock_result_set_index = 0 + self.adapters = MockAdapters() # Django needs this + logger.debug("[MOCK_CURSOR] Created fallback mock cursor (psycopg3)") + + @property + def description(self): + return self._tusk_description + + @property + def rownumber(self): + """Return the index of the next row to fetch, or None if no result.""" + if self._mock_rows: + return self._mock_index + return None + + @property + def statusmessage(self): + """Return the mock status message if set, otherwise None.""" + return getattr(self, '_mock_statusmessage', None) + + def execute(self, query, params=None, **kwargs): + """Will be replaced by instrumentation.""" + logger.debug(f"[MOCK_CURSOR] execute() called: {query[:100]}") + return self + + def executemany(self, query, params_seq, **kwargs): + """Will be replaced by instrumentation.""" + logger.debug(f"[MOCK_CURSOR] executemany() called: {query[:100]}") + return self + + def fetchone(self): + return None + + def fetchmany(self, size=None): + return [] + + def fetchall(self): + return [] + + def results(self): + """Iterate over result sets for executemany with returning=True. + + This method is patched by _mock_executemany_returning_with_data + when replaying executemany with returning=True. + Default implementation yields self once for backward compatibility. + """ + yield self + + def nextset(self): + """Move to the next result set. + + Returns True if there is another result set, None otherwise. + This method is patched during replay for executemany with returning=True. + """ + return None + + def stream(self, query, params=None, **kwargs): + """Will be replaced by instrumentation.""" + return iter([]) + + def __iter__(self): + """Support direct cursor iteration (for row in cursor).""" + return self + + def __next__(self): + """Return next row for iteration.""" + if self._mock_index < len(self._mock_rows): + row = self._mock_rows[self._mock_index] + self._mock_index += 1 + return tuple(row) if isinstance(row, list) else row + raise StopIteration + + def scroll(self, value: int, mode: str = "relative") -> None: + """Scroll the cursor to a new position in the result set.""" + if mode == "relative": + newpos = self._mock_index + value + elif mode == "absolute": + newpos = value + else: + raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") + + num_rows = len(self._mock_rows) + if num_rows > 0: + if not (0 <= newpos < num_rows): + raise IndexError("position out of bound") + elif newpos != 0: + raise IndexError("position out of bound") + + self._mock_index = newpos + + def close(self): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + + +class MockCopy: + """Mock Copy object for REPLAY mode. + + Provides a minimal interface compatible with psycopg's Copy object + for COPY TO operations (iteration) and COPY FROM operations (write). + """ + + def __init__(self, data: list): + """Initialize MockCopy with recorded data. + + Args: + data: For COPY TO - list of data chunks (as strings from JSON, will be encoded to bytes) + """ + self._data = data + self._index = 0 + + def __iter__(self) -> Iterator[bytes]: + """Iterate over COPY TO data chunks.""" + for item in self._data: + # Data was stored as string in JSON, convert back to bytes + if isinstance(item, str): + yield item.encode("utf-8") + elif isinstance(item, bytes): + yield item + else: + yield str(item).encode("utf-8") + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def read(self) -> bytes: + """Read next data chunk for COPY TO.""" + if self._index < len(self._data): + item = self._data[self._index] + self._index += 1 + if isinstance(item, str): + return item.encode("utf-8") + elif isinstance(item, bytes): + return item + return str(item).encode("utf-8") + return b"" + + def rows(self) -> Iterator[tuple]: + """Iterate over rows for COPY TO (parsed format).""" + for item in self._data: + yield tuple(item) if isinstance(item, list) else item + + def write(self, buffer) -> None: + """No-op for COPY FROM in replay mode.""" + pass + + def write_row(self, row) -> None: + """No-op for COPY FROM in replay mode.""" + pass + + def set_types(self, types) -> None: + """No-op for replay mode.""" + pass + + +class MockPipeline: + """Mock Pipeline for REPLAY mode. + + In REPLAY mode, pipeline operations are no-ops since queries + return mocked data immediately. + """ + + def __init__(self, connection: MockConnection): + self._conn = connection + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def sync(self): + """No-op sync for mock pipeline.""" + pass diff --git a/drift/instrumentation/psycopg/wrappers.py b/drift/instrumentation/psycopg/wrappers.py new file mode 100644 index 0000000..333c988 --- /dev/null +++ b/drift/instrumentation/psycopg/wrappers.py @@ -0,0 +1,80 @@ +"""Wrapper classes for psycopg3 instrumentation. + +These wrappers intercept operations to capture data for recording. +""" + +from __future__ import annotations + +from typing import Any, Iterator + + +class TracedCopyWrapper: + """Wrapper around psycopg's Copy object to capture data in RECORD mode. + + Intercepts all data operations to record them for replay. + """ + + def __init__(self, copy: Any, data_collected: list): + """Initialize wrapper. + + Args: + copy: The real psycopg Copy object + data_collected: List to append captured data chunks to + """ + self._copy = copy + self._data_collected = data_collected + + def __iter__(self) -> Iterator[bytes]: + """Iterate over COPY TO data, capturing each chunk.""" + for data in self._copy: + # Handle both bytes and memoryview + if isinstance(data, memoryview): + data = bytes(data) + self._data_collected.append(data) + yield data + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def read(self) -> bytes: + """Read raw data from COPY TO, capturing it.""" + data = self._copy.read() + if data: + if isinstance(data, memoryview): + data = bytes(data) + self._data_collected.append(data) + return data + + def read_row(self): + """Read a parsed row from COPY TO.""" + row = self._copy.read_row() + if row is not None: + self._data_collected.append(row) + return row + + def rows(self): + """Iterate over parsed rows from COPY TO.""" + for row in self._copy.rows(): + self._data_collected.append(row) + yield row + + def write(self, buffer): + """Write raw data for COPY FROM.""" + self._data_collected.append(buffer) + return self._copy.write(buffer) + + def write_row(self, row): + """Write a row for COPY FROM.""" + self._data_collected.append(row) + return self._copy.write_row(row) + + def set_types(self, types): + """Proxy set_types to the underlying Copy object.""" + return self._copy.set_types(types) + + def __getattr__(self, name): + """Proxy any other attributes to the underlying copy object.""" + return getattr(self._copy, name) diff --git a/drift/instrumentation/psycopg2/instrumentation.py b/drift/instrumentation/psycopg2/instrumentation.py index 4030304..5da1aac 100644 --- a/drift/instrumentation/psycopg2/instrumentation.py +++ b/drift/instrumentation/psycopg2/instrumentation.py @@ -30,6 +30,7 @@ ) from ..base import InstrumentationBase from ..utils.psycopg_utils import deserialize_db_value +from ..utils.serialization import serialize_value logger = logging.getLogger(__name__) @@ -814,21 +815,6 @@ def _finalize_query_span( ) -> None: """Finalize span with query data.""" try: - # Helper function to serialize non-JSON types - import datetime - - def serialize_value(val): - """Convert non-JSON-serializable values to JSON-compatible types.""" - if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): - return val.isoformat() - elif isinstance(val, bytes): - return val.decode("utf-8", errors="replace") - elif isinstance(val, (list, tuple)): - return [serialize_value(v) for v in val] - elif isinstance(val, dict): - return {k: serialize_value(v) for k, v in val.items()} - return val - # Build input value query_str = _query_to_str(query) input_value = { diff --git a/drift/instrumentation/utils/serialization.py b/drift/instrumentation/utils/serialization.py new file mode 100644 index 0000000..0cae8b6 --- /dev/null +++ b/drift/instrumentation/utils/serialization.py @@ -0,0 +1,30 @@ +"""Serialization utilities for instrumentation modules.""" + +from __future__ import annotations + +import datetime +from typing import Any + + +def serialize_value(val: Any) -> Any: + """Convert non-JSON-serializable values to JSON-compatible types. + + Handles datetime objects, bytes, and nested structures (lists, tuples, dicts). + + Args: + val: The value to serialize. + + Returns: + A JSON-serializable version of the value. + """ + if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): + return val.isoformat() + elif isinstance(val, bytes): + return val.decode("utf-8", errors="replace") + elif isinstance(val, memoryview): + return bytes(val).decode("utf-8", errors="replace") + elif isinstance(val, (list, tuple)): + return [serialize_value(v) for v in val] + elif isinstance(val, dict): + return {k: serialize_value(v) for k, v in val.items()} + return val From 35b3765de5b3dd388e1399d6961c85f1387b04d6 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 14:07:10 -0800 Subject: [PATCH 15/37] resolve cursor reuse hang in RECORD mode --- .../psycopg/e2e-tests/src/app.py | 127 ++++++++++++++++++ .../psycopg/e2e-tests/src/test_requests.py | 14 ++ .../psycopg/instrumentation.py | 27 ++-- 3 files changed, 151 insertions(+), 17 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index cc5d6b7..7110a04 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -506,6 +506,133 @@ def test_server_cursor_scroll(): except Exception as e: return jsonify({"error": str(e)}), 500 +@app.route("/test/cursor-reuse") +def test_cursor_reuse(): + """Test reusing cursor for multiple queries. + + Tests whether the instrumentation correctly handles reusing a cursor for multiple execute() calls. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # First query + cur.execute("SELECT id, name FROM users WHERE id = 1") + row1 = cur.fetchone() + + # Second query on same cursor + cur.execute("SELECT id, name FROM users WHERE id = 2") + row2 = cur.fetchone() + + # Third query + cur.execute("SELECT COUNT(*) FROM users") + count = cur.fetchone()[0] + + return jsonify({ + "row1": {"id": row1[0], "name": row1[1]} if row1 else None, + "row2": {"id": row2[0], "name": row2[1]} if row2 else None, + "count": count + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + +@app.route("/test/sql-composed") +def test_sql_composed(): + """Test psycopg.sql.SQL() composed queries.""" + try: + from psycopg import sql + + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + table = sql.Identifier("users") + columns = sql.SQL(", ").join([ + sql.Identifier("id"), + sql.Identifier("name"), + sql.Identifier("email") + ]) + + query = sql.SQL("SELECT {} FROM {} ORDER BY id LIMIT 3").format(columns, table) + cur.execute(query) + rows = cur.fetchall() + + return jsonify({ + "count": len(rows), + "data": [{"id": r[0], "name": r[1], "email": r[2]} for r in rows] + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + + + +# ===== BUG HUNTING TEST ENDPOINTS ===== +# These endpoints expose confirmed bugs in the psycopg instrumentation +# Endpoints that passed tests have been removed + +@app.route("/test/binary-uuid") +def test_binary_uuid(): + """Test binary UUID data type. + + BUG HYPOTHESIS: UUID types may not serialize/deserialize correctly + during RECORD/REPLAY because they are binary. + """ + try: + import uuid + + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create a temp table with UUID column + cur.execute("CREATE TEMP TABLE uuid_test (id UUID PRIMARY KEY, name TEXT)") + + # Insert a UUID + test_uuid = uuid.uuid4() + cur.execute( + "INSERT INTO uuid_test (id, name) VALUES (%s, %s) RETURNING id, name", + (test_uuid, "UUID Test") + ) + inserted = cur.fetchone() + + # Query it back + cur.execute("SELECT id, name FROM uuid_test WHERE id = %s", (test_uuid,)) + queried = cur.fetchone() + + conn.commit() + + return jsonify({ + "inserted_uuid": str(inserted[0]) if inserted else None, + "queried_uuid": str(queried[0]) if queried else None, + "match": str(inserted[0]) == str(queried[0]) if inserted and queried else False + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/binary-bytea") +def test_binary_bytea(): + """Test binary bytea data type. + + BUG HYPOTHESIS: Binary data (bytea) may not serialize/deserialize + correctly during RECORD/REPLAY. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create a temp table with bytea column + cur.execute("CREATE TEMP TABLE bytea_test (id SERIAL, data BYTEA)") + + # Insert binary data + test_data = b'\x00\x01\x02\x03\xff\xfe\xfd' + cur.execute( + "INSERT INTO bytea_test (data) VALUES (%s) RETURNING id, data", + (test_data,) + ) + inserted = cur.fetchone() + + conn.commit() + + # Convert bytes to hex for JSON serialization + return jsonify({ + "inserted_id": inserted[0] if inserted else None, + "data_hex": inserted[1].hex() if inserted and inserted[1] else None, + "data_length": len(inserted[1]) if inserted and inserted[1] else 0 + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 if __name__ == "__main__": sdk.mark_app_as_ready() diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index b0e9cc2..93754a5 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -92,4 +92,18 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/cursor-scroll") + make_request("GET", "/test/cursor-reuse") + + make_request("GET", "/test/sql-composed") + + # ===== BUG HUNTING TEST ENDPOINTS ===== + # These tests expose confirmed bugs in the psycopg instrumentation + # See BUG_TRACKING.md for detailed information about each bug + print("\n--- Bug Hunting Tests (REPLAY mode bugs - pass RECORD but fail REPLAY) ---\n") + + # Bug 8: UUID parameter serialization issue during REPLAY + make_request("GET", "/test/binary-uuid") + # Bug 9: bytea data deserialization returns string instead of bytes + make_request("GET", "/test/binary-bytea") + print("\nAll requests completed successfully") diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 298b0d0..3f7f67d 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -418,17 +418,16 @@ def _record_execute( ) -> Any: """Handle RECORD mode for execute - create span and execute query.""" # Reset cursor state from any previous execute() on this cursor. - # This ensures fetch methods work correctly for multiple queries on the same cursor. - # We must restore original fetch methods so _finalize_query_span can call the real - # psycopg fetchall() method, not our patched version from a previous query. - if hasattr(cursor, '_tusk_original_fetchone'): - cursor.fetchone = cursor._tusk_original_fetchone - cursor.fetchmany = cursor._tusk_original_fetchmany - cursor.fetchall = cursor._tusk_original_fetchall + # Delete instance attribute overrides to expose original class methods. + # This is safer than saving/restoring bound methods which can become stale. + if hasattr(cursor, '_tusk_patched'): + # Remove patched instance attributes to expose class methods + for attr in ('fetchone', 'fetchmany', 'fetchall', 'scroll'): + if attr in cursor.__dict__: + delattr(cursor, attr) cursor._tusk_rows = None cursor._tusk_index = 0 - if hasattr(cursor, '_tusk_original_scroll'): - cursor.scroll = cursor._tusk_original_scroll + del cursor._tusk_patched span_info = SpanUtils.create_span( CreateSpanOptions( @@ -1635,18 +1634,12 @@ def patched_scroll(value: int, mode: str = "relative") -> None: cursor._tusk_index = newpos # pyright: ignore[reportAttributeAccessIssue] - # Save original fetch methods before patching (only if not already saved) - # These will be restored at the start of the next execute() call - if not hasattr(cursor, '_tusk_original_fetchone'): - cursor._tusk_original_fetchone = cursor.fetchone # pyright: ignore[reportAttributeAccessIssue] - cursor._tusk_original_fetchmany = cursor.fetchmany # pyright: ignore[reportAttributeAccessIssue] - cursor._tusk_original_fetchall = cursor.fetchall # pyright: ignore[reportAttributeAccessIssue] - cursor._tusk_original_scroll = cursor.scroll # pyright: ignore[reportAttributeAccessIssue] - + # Patch fetch methods with our versions that return stored rows cursor.fetchone = patched_fetchone # pyright: ignore[reportAttributeAccessIssue] cursor.fetchmany = patched_fetchmany # pyright: ignore[reportAttributeAccessIssue] cursor.fetchall = patched_fetchall # pyright: ignore[reportAttributeAccessIssue] cursor.scroll = patched_scroll # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_patched = True # pyright: ignore[reportAttributeAccessIssue] except Exception as fetch_error: logger.debug(f"Could not fetch rows (query might not return rows): {fetch_error}") From d22ed0aebf33b9dbec7d91868d02ba29785a4c71 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 14:30:14 -0800 Subject: [PATCH 16/37] Fix UUID parameter serialization mismatch in psycopg REPLAY mode --- drift/instrumentation/psycopg/instrumentation.py | 3 ++- drift/instrumentation/utils/serialization.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 3f7f67d..b9d370b 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -1150,7 +1150,8 @@ def _try_get_mock( "query": query.strip(), } if params is not None: - input_value["parameters"] = params + # Serialize parameters to ensure consistent hashing with RECORD mode + input_value["parameters"] = serialize_value(params) # Use centralized mock finding utility from ...core.mock_utils import find_mock_response_sync diff --git a/drift/instrumentation/utils/serialization.py b/drift/instrumentation/utils/serialization.py index 0cae8b6..08fd8d8 100644 --- a/drift/instrumentation/utils/serialization.py +++ b/drift/instrumentation/utils/serialization.py @@ -3,6 +3,7 @@ from __future__ import annotations import datetime +import uuid from typing import Any @@ -19,6 +20,8 @@ def serialize_value(val: Any) -> Any: """ if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): return val.isoformat() + elif isinstance(val, uuid.UUID): + return str(val) elif isinstance(val, bytes): return val.decode("utf-8", errors="replace") elif isinstance(val, memoryview): From 765fd9b28ba3236d039196e98738dbd09a2fe866 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 14:50:58 -0800 Subject: [PATCH 17/37] fix bytea serialization to preserve binary data during record/replay --- .../psycopg/e2e-tests/src/app.py | 14 ++------- .../psycopg/e2e-tests/src/test_requests.py | 8 +---- drift/instrumentation/utils/psycopg_utils.py | 26 ++++++++++------- drift/instrumentation/utils/serialization.py | 29 +++++++++++++++++-- 4 files changed, 45 insertions(+), 32 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index 7110a04..e4503db 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -559,19 +559,11 @@ def test_sql_composed(): except Exception as e: return jsonify({"error": str(e)}), 500 - - - -# ===== BUG HUNTING TEST ENDPOINTS ===== -# These endpoints expose confirmed bugs in the psycopg instrumentation -# Endpoints that passed tests have been removed - @app.route("/test/binary-uuid") def test_binary_uuid(): """Test binary UUID data type. - BUG HYPOTHESIS: UUID types may not serialize/deserialize correctly - during RECORD/REPLAY because they are binary. + Tests whether the instrumentation correctly handles binary UUID data types. """ try: import uuid @@ -602,13 +594,11 @@ def test_binary_uuid(): except Exception as e: return jsonify({"error": str(e)}), 500 - @app.route("/test/binary-bytea") def test_binary_bytea(): """Test binary bytea data type. - BUG HYPOTHESIS: Binary data (bytea) may not serialize/deserialize - correctly during RECORD/REPLAY. + Tests whether the instrumentation correctly handles binary bytea data types. """ try: with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index 93754a5..d3cce98 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -96,14 +96,8 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/sql-composed") - # ===== BUG HUNTING TEST ENDPOINTS ===== - # These tests expose confirmed bugs in the psycopg instrumentation - # See BUG_TRACKING.md for detailed information about each bug - print("\n--- Bug Hunting Tests (REPLAY mode bugs - pass RECORD but fail REPLAY) ---\n") - - # Bug 8: UUID parameter serialization issue during REPLAY make_request("GET", "/test/binary-uuid") - # Bug 9: bytea data deserialization returns string instead of bytes + make_request("GET", "/test/binary-bytea") print("\nAll requests completed successfully") diff --git a/drift/instrumentation/utils/psycopg_utils.py b/drift/instrumentation/utils/psycopg_utils.py index 1dc3859..99738d5 100644 --- a/drift/instrumentation/utils/psycopg_utils.py +++ b/drift/instrumentation/utils/psycopg_utils.py @@ -2,27 +2,35 @@ from __future__ import annotations +import base64 import datetime as dt from typing import Any def deserialize_db_value(val: Any) -> Any: - """Convert ISO datetime strings back to datetime objects for consistent serialization. + """Convert serialized values back to their original Python types. - During recording, datetime objects from the database are serialized to ISO format strings. - During replay, we need to convert them back to datetime objects so that Flask/Django - serializes them the same way (e.g., RFC 2822 vs ISO 8601 format). + During recording, database values are serialized for JSON storage: + - datetime objects -> ISO format strings + - bytes/memoryview -> {"__bytes__": ""} - Only parses strings that contain a time component (T or space separator with :) to avoid - incorrectly converting date-only strings or text that happens to look like dates. + During replay, we need to convert them back to their original types so that + application code (Flask/Django) handles them the same way. Args: val: A value from the mocked database rows. Can be a string, list, dict, or any other type. Returns: - The value with ISO datetime strings converted back to datetime objects. + The value with serialized types converted back to their original Python types. """ - if isinstance(val, str): + if isinstance(val, dict): + # Check for bytes tagged structure + if "__bytes__" in val and len(val) == 1: + # Decode base64 back to bytes + return base64.b64decode(val["__bytes__"]) + # Recursively deserialize dict values + return {k: deserialize_db_value(v) for k, v in val.items()} + elif isinstance(val, str): # Only parse strings that look like full datetime (must have time component) # This avoids converting date-only strings like "2024-01-15" or text columns # that happen to match date patterns @@ -35,6 +43,4 @@ def deserialize_db_value(val: Any) -> Any: pass elif isinstance(val, list): return [deserialize_db_value(v) for v in val] - elif isinstance(val, dict): - return {k: deserialize_db_value(v) for k, v in val.items()} return val diff --git a/drift/instrumentation/utils/serialization.py b/drift/instrumentation/utils/serialization.py index 08fd8d8..a3c2933 100644 --- a/drift/instrumentation/utils/serialization.py +++ b/drift/instrumentation/utils/serialization.py @@ -2,11 +2,33 @@ from __future__ import annotations +import base64 import datetime import uuid from typing import Any +def _serialize_bytes(val: bytes) -> Any: + """Serialize bytes to a JSON-compatible format. + + Attempts UTF-8 decode first for text data (like COPY output). + Falls back to base64 encoding with tagged structure for binary data + that contains invalid UTF-8 sequences (like bytea columns). + + Args: + val: The bytes value to serialize. + + Returns: + Either a string (if valid UTF-8) or a dict {"__bytes__": "base64_data"}. + """ + try: + # Try UTF-8 decode first - works for text data like COPY output + return val.decode("utf-8") + except UnicodeDecodeError: + # Fall back to base64 for binary data with invalid UTF-8 sequences + return {"__bytes__": base64.b64encode(val).decode("ascii")} + + def serialize_value(val: Any) -> Any: """Convert non-JSON-serializable values to JSON-compatible types. @@ -22,10 +44,11 @@ def serialize_value(val: Any) -> Any: return val.isoformat() elif isinstance(val, uuid.UUID): return str(val) - elif isinstance(val, bytes): - return val.decode("utf-8", errors="replace") elif isinstance(val, memoryview): - return bytes(val).decode("utf-8", errors="replace") + # Convert memoryview to bytes first, then serialize + return _serialize_bytes(bytes(val)) + elif isinstance(val, bytes): + return _serialize_bytes(val) elif isinstance(val, (list, tuple)): return [serialize_value(v) for v in val] elif isinstance(val, dict): From c95337f7195e3a4515164619185c8e2cdd2e11e9 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 15:30:57 -0800 Subject: [PATCH 18/37] Fix kwargs_row row factory handling in psycopg instrumentation --- .../psycopg/e2e-tests/src/app.py | 80 +++++++++++++++++++ .../psycopg/e2e-tests/src/test_requests.py | 20 +---- .../psycopg/instrumentation.py | 31 ++++++- drift/instrumentation/psycopg/wrappers.py | 3 + 4 files changed, 115 insertions(+), 19 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index e4503db..786928a 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -624,6 +624,86 @@ def test_binary_bytea(): except Exception as e: return jsonify({"error": str(e)}), 500 +@app.route("/test/class-row-factory") +def test_class_row_factory(): + """Test class_row row factory. + + Tests whether the instrumentation correctly handles class_row factories + which return instances of a custom class. + """ + try: + from psycopg.rows import class_row + from dataclasses import dataclass + + @dataclass + class User: + id: int + name: str + email: str + + with psycopg.connect(get_conn_string(), row_factory=class_row(User)) as conn: + with conn.cursor() as cur: + cur.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 3") + rows = cur.fetchall() + + return jsonify({ + "count": len(rows), + "data": [{"id": r.id, "name": r.name, "email": r.email} for r in rows] + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/kwargs-row-factory") +def test_kwargs_row_factory(): + """Test kwargs_row row factory. + + Tests whether the instrumentation correctly handles kwargs_row factories + which call a function with keyword arguments. + """ + try: + from psycopg.rows import kwargs_row + + def make_user_dict(**kwargs): + return {"user_data": kwargs, "processed": True} + + with psycopg.connect(get_conn_string(), row_factory=kwargs_row(make_user_dict)) as conn: + with conn.cursor() as cur: + cur.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 3") + rows = cur.fetchall() + + return jsonify({ + "count": len(rows), + "data": rows # Already processed dictionaries + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + +# BUG HUNTING TEST ENDPOINTS + +@app.route("/test/scalar-row-factory") +def test_scalar_row_factory(): + """Test scalar_row row factory. + + Tests whether the instrumentation correctly handles scalar_row factories + which return just the first column value. + """ + try: + from psycopg.rows import scalar_row + + with psycopg.connect(get_conn_string(), row_factory=scalar_row) as conn: + with conn.cursor() as cur: + cur.execute("SELECT name FROM users ORDER BY id LIMIT 5") + rows = cur.fetchall() + + return jsonify({ + "count": len(rows), + "data": rows # Just names as scalars + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + if __name__ == "__main__": sdk.mark_app_as_ready() app.run(host="0.0.0.0", port=8000, debug=False) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index d3cce98..b879f50 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -65,39 +65,25 @@ def make_request(method, endpoint, **kwargs): make_request("DELETE", f"/db/delete/{user_id}") make_request("GET", "/test/cursor-stream") - make_request("GET", "/test/server-cursor") - make_request("GET", "/test/copy-to") - make_request("GET", "/test/multiple-queries") - make_request("GET", "/test/pipeline-mode") - make_request("GET", "/test/dict-row-factory") - make_request("GET", "/test/namedtuple-row-factory") - make_request("GET", "/test/cursor-iteration") - make_request("GET", "/test/executemany-returning") - make_request("GET", "/test/rownumber") - make_request("GET", "/test/statusmessage") - make_request("GET", "/test/nextset") - make_request("GET", "/test/server-cursor-scroll") - make_request("GET", "/test/cursor-scroll") - make_request("GET", "/test/cursor-reuse") - make_request("GET", "/test/sql-composed") - make_request("GET", "/test/binary-uuid") - make_request("GET", "/test/binary-bytea") + make_request("GET", "/test/class-row-factory") + make_request("GET", "/test/kwargs-row-factory") + make_request("GET", "/test/scalar-row-factory") print("\nAll requests completed successfully") diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index b9d370b..cc506d6 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -1063,6 +1063,8 @@ def _detect_row_factory_type(self, row_factory: Any) -> str: return "dict" elif 'namedtuple' in factory_name_lower: return "namedtuple" + elif 'kwargs' in factory_name_lower: + return "kwargs" return "tuple" @@ -1234,6 +1236,9 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non def transform_row(row): """Transform raw row data according to row factory type.""" + if row_factory_type == "kwargs": + # kwargs_row: return stored dict as-is (already in correct format) + return row values = tuple(row) if isinstance(row, list) else row if row_factory_type == "dict" and column_names: return dict(zip(column_names, values)) @@ -1352,6 +1357,9 @@ def create_row_class(col_names): def transform_row(row, col_names, RowClass): """Transform raw row data according to row factory type.""" + if row_factory_type == "kwargs": + # kwargs_row: return stored dict as-is + return row values = tuple(row) if isinstance(row, list) else row if row_factory_type == "dict" and col_names: return dict(zip(col_names, values)) @@ -1582,9 +1590,23 @@ def _finalize_query_span( # Convert rows to lists for JSON serialization # Handle dict_row (returns dicts) and namedtuple_row (returns namedtuples) column_names = [d["name"] for d in description] + + # Get row factory from cursor or connection + row_factory = getattr(cursor, 'row_factory', None) + if row_factory is None: + conn = getattr(cursor, 'connection', None) + if conn: + row_factory = getattr(conn, 'row_factory', None) + + # Detect row factory type BEFORE processing rows + row_factory_type = self._detect_row_factory_type(row_factory) + rows = [] for row in all_rows: - if isinstance(row, dict): + if row_factory_type == "kwargs": + # kwargs_row: store the entire dict as-is (it has custom keys, not column names) + rows.append(row) + elif isinstance(row, dict): # dict_row: extract values in column order rows.append([row.get(col) for col in column_names]) elif hasattr(row, '_fields'): @@ -1655,7 +1677,12 @@ def patched_scroll(value: int, mode: str = "relative") -> None: if rows: # Convert rows to JSON-serializable format (handle datetime objects, etc.) - serialized_rows = [[serialize_value(col) for col in row] for row in rows] + # For kwargs_row, rows are custom dicts - serialize each row as a complete dict + # For other types, rows are lists of column values + if row_factory_type == "kwargs": + serialized_rows = [serialize_value(row) for row in rows] + else: + serialized_rows = [[serialize_value(col) for col in row] for row in rows] output_value["rows"] = serialized_rows # Capture statusmessage for replay diff --git a/drift/instrumentation/psycopg/wrappers.py b/drift/instrumentation/psycopg/wrappers.py index 333c988..12efc1f 100644 --- a/drift/instrumentation/psycopg/wrappers.py +++ b/drift/instrumentation/psycopg/wrappers.py @@ -63,6 +63,9 @@ def rows(self): def write(self, buffer): """Write raw data for COPY FROM.""" + # Convert memoryview to bytes to avoid mutation if buffer is reused + if isinstance(buffer, memoryview): + buffer = bytes(buffer) self._data_collected.append(buffer) return self._copy.write(buffer) From 9a407ab75e38d5cbec13cc087014dea5b68480f5 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 15:45:49 -0800 Subject: [PATCH 19/37] check number of tests replayed in e2e tests --- .../django/e2e-tests/src/test_requests.py | 24 +----------- drift/instrumentation/e2e_common/__init__.py | 9 ++++- .../instrumentation/e2e_common/base_runner.py | 36 +++++++++++++++++- .../instrumentation/e2e_common/test_utils.py | 37 +++++++++++++++++++ .../fastapi/e2e-tests/src/test_requests.py | 22 ++--------- .../flask/e2e-tests/src/test_requests.py | 22 +---------- .../httpx/e2e-tests/src/test_requests.py | 22 +---------- .../psycopg/e2e-tests/src/test_requests.py | 24 ++---------- .../psycopg2/e2e-tests/src/test_requests.py | 22 +---------- .../redis/e2e-tests/src/test_requests.py | 22 +---------- .../requests/e2e-tests/src/test_requests.py | 22 +---------- 11 files changed, 98 insertions(+), 164 deletions(-) create mode 100644 drift/instrumentation/e2e_common/test_utils.py diff --git a/drift/instrumentation/django/e2e-tests/src/test_requests.py b/drift/instrumentation/django/e2e-tests/src/test_requests.py index 4e0301a..9d4354c 100644 --- a/drift/instrumentation/django/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/django/e2e-tests/src/test_requests.py @@ -1,26 +1,6 @@ """Execute test requests against the Django app.""" -import os -import time - -import requests - -PORT = os.getenv("PORT", "8000") -BASE_URL = f"http://localhost:{PORT}" - - -def make_request(method: str, endpoint: str, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"→ {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = requests.request(method, url, **kwargs) - print(f" Status: {response.status_code}") - time.sleep(0.5) # Small delay between requests - return response - +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary if __name__ == "__main__": print("Starting Django test request sequence...\n") @@ -42,4 +22,4 @@ def make_request(method: str, endpoint: str, **kwargs): ) make_request("DELETE", "/api/post/1/delete") - print("\nAll requests completed successfully") + print_request_summary() diff --git a/drift/instrumentation/e2e_common/__init__.py b/drift/instrumentation/e2e_common/__init__.py index 2585a76..283be03 100644 --- a/drift/instrumentation/e2e_common/__init__.py +++ b/drift/instrumentation/e2e_common/__init__.py @@ -1,5 +1,12 @@ """Common utilities for Python SDK e2e tests.""" from .base_runner import Colors, E2ETestRunnerBase +from .test_utils import get_request_count, make_request, print_request_summary -__all__ = ["Colors", "E2ETestRunnerBase"] +__all__ = [ + "Colors", + "E2ETestRunnerBase", + "get_request_count", + "make_request", + "print_request_summary", +] diff --git a/drift/instrumentation/e2e_common/base_runner.py b/drift/instrumentation/e2e_common/base_runner.py index 17890aa..00556a9 100644 --- a/drift/instrumentation/e2e_common/base_runner.py +++ b/drift/instrumentation/e2e_common/base_runner.py @@ -45,6 +45,7 @@ def __init__(self, app_port: int = 8000): self.app_port = app_port self.app_process: subprocess.Popen | None = None self.exit_code = 0 + self.expected_request_count: int | None = None # Register signal handlers for cleanup signal.signal(signal.SIGTERM, self._signal_handler) @@ -76,6 +77,17 @@ def run_command(self, cmd: list[str], env: dict | None = None, check: bool = Tru return result + def _parse_request_count(self, output: str): + """Parse the request count from test_requests.py output.""" + for line in output.split("\n"): + if line.startswith("TOTAL_REQUESTS_SENT:"): + try: + count = int(line.split(":")[1]) + self.expected_request_count = count + self.log(f"Captured request count: {count}", Colors.GREEN) + except (ValueError, IndexError): + self.log(f"Failed to parse request count from: {line}", Colors.YELLOW) + def wait_for_service(self, check_cmd: list[str], timeout: int = 30, interval: int = 1) -> bool: """Wait for a service to become ready.""" elapsed = 0 @@ -164,7 +176,13 @@ def record_traces(self) -> bool: # Execute test requests self.log("Executing test requests...", Colors.GREEN) try: - self.run_command(["python", "src/test_requests.py"]) + # Pass PYTHONPATH so test_requests.py can import from e2e_common + result = self.run_command( + ["python", "src/test_requests.py"], + env={"PYTHONPATH": "/sdk"}, + ) + # Parse request count from output + self._parse_request_count(result.stdout) except subprocess.CalledProcessError: self.log("Test requests failed", Colors.RED) self.exit_code = 1 @@ -257,6 +275,7 @@ def parse_test_results(self, output: str): idx += 1 all_passed = True + passed_count = 0 for result in results: test_id = result.get("test_id", "unknown") passed = result.get("passed", False) @@ -264,6 +283,7 @@ def parse_test_results(self, output: str): if passed: self.log(f"✓ Test ID: {test_id} (Duration: {duration}ms)", Colors.GREEN) + passed_count += 1 else: self.log(f"✗ Test ID: {test_id} (Duration: {duration}ms)", Colors.RED) all_passed = False @@ -276,6 +296,20 @@ def parse_test_results(self, output: str): self.log("Some tests failed!", Colors.RED) self.exit_code = 1 + # Validate request count matches passed tests + if self.expected_request_count is not None: + if passed_count < self.expected_request_count: + self.log( + f"✗ Request count mismatch: {passed_count} passed tests != {self.expected_request_count} requests sent", + Colors.RED, + ) + self.exit_code = 1 + else: + self.log( + f"✓ Request count validation: {passed_count} passed tests >= {self.expected_request_count} requests sent", + Colors.GREEN, + ) + except Exception as e: self.log(f"Failed to parse test results: {e}", Colors.RED) self.log(f"Raw output:\n{output}", Colors.YELLOW) diff --git a/drift/instrumentation/e2e_common/test_utils.py b/drift/instrumentation/e2e_common/test_utils.py new file mode 100644 index 0000000..951a774 --- /dev/null +++ b/drift/instrumentation/e2e_common/test_utils.py @@ -0,0 +1,37 @@ +""" +Shared test utilities for e2e tests. + +This module provides common functions used across all instrumentation e2e tests, +including request counting for validation. +""" + +import time + +import requests + +BASE_URL = "http://localhost:8000" +_request_count = 0 + + +def make_request(method, endpoint, **kwargs): + """Make HTTP request, log result, and track count.""" + global _request_count + _request_count += 1 + + url = f"{BASE_URL}{endpoint}" + print(f"→ {method} {endpoint}") + kwargs.setdefault("timeout", 30) + response = requests.request(method, url, **kwargs) + print(f" Status: {response.status_code}") + time.sleep(0.5) + return response + + +def get_request_count(): + """Return the current request count.""" + return _request_count + + +def print_request_summary(): + """Print the total request count in a parseable format.""" + print(f"\nTOTAL_REQUESTS_SENT:{_request_count}") diff --git a/drift/instrumentation/fastapi/e2e-tests/src/test_requests.py b/drift/instrumentation/fastapi/e2e-tests/src/test_requests.py index 743561a..e684b8b 100644 --- a/drift/instrumentation/fastapi/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/fastapi/e2e-tests/src/test_requests.py @@ -1,27 +1,9 @@ """Execute test requests against the FastAPI app.""" import json -import os -import time from pathlib import Path -import requests - -PORT = os.getenv("PORT", "8000") -BASE_URL = f"http://localhost:{PORT}" - - -def make_request(method: str, endpoint: str, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"→ {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = requests.request(method, url, **kwargs) - print(f" Status: {response.status_code}") - time.sleep(0.5) # Small delay between requests - return response +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary def verify_stack_traces(): @@ -180,3 +162,5 @@ def verify_context_propagation(): print("\n" + "=" * 50) print("All requests completed successfully") print("=" * 50) + + print_request_summary() diff --git a/drift/instrumentation/flask/e2e-tests/src/test_requests.py b/drift/instrumentation/flask/e2e-tests/src/test_requests.py index 6ab2ff8..8f0861b 100644 --- a/drift/instrumentation/flask/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/flask/e2e-tests/src/test_requests.py @@ -1,24 +1,6 @@ """Execute test requests against the Flask app.""" -import time - -import requests - -BASE_URL = "http://localhost:8000" - - -def make_request(method, endpoint, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"→ {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = requests.request(method, url, **kwargs) - print(f" Status: {response.status_code}") - time.sleep(0.5) # Small delay between requests - return response - +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary if __name__ == "__main__": print("Starting test request sequence...\n") @@ -32,4 +14,4 @@ def make_request(method, endpoint, **kwargs): make_request("POST", "/api/post", json={"title": "Test Post", "body": "This is a test post", "userId": 1}) make_request("DELETE", "/api/post/1") - print("\nAll requests completed successfully") + print_request_summary() diff --git a/drift/instrumentation/httpx/e2e-tests/src/test_requests.py b/drift/instrumentation/httpx/e2e-tests/src/test_requests.py index ebcbebd..a5b7b53 100644 --- a/drift/instrumentation/httpx/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/httpx/e2e-tests/src/test_requests.py @@ -1,24 +1,6 @@ """Execute test requests against the Flask app to exercise the HTTPX instrumentation.""" -import time - -import requests - -BASE_URL = "http://localhost:8000" - - -def make_request(method, endpoint, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"-> {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = requests.request(method, url, **kwargs) - print(f" Status: {response.status_code}") - time.sleep(0.5) # Small delay between requests - return response - +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary if __name__ == "__main__": print("Starting test request sequence for HTTPX instrumentation...\n") @@ -124,4 +106,4 @@ def make_request(method, endpoint, **kwargs): make_request("POST", "/test/file-like-body") - print("\nAll requests completed successfully") + print_request_summary() diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index b879f50..497cf7d 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -1,24 +1,6 @@ """Execute test requests against the Psycopg Flask app.""" -import time - -import requests - -BASE_URL = "http://localhost:8000" - - -def make_request(method, endpoint, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"→ {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = requests.request(method, url, **kwargs) - print(f" Status: {response.status_code}") - time.sleep(0.5) # Small delay between requests - return response - +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary if __name__ == "__main__": print("Starting Psycopg test request sequence...\n") @@ -84,6 +66,6 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/binary-bytea") make_request("GET", "/test/class-row-factory") make_request("GET", "/test/kwargs-row-factory") - make_request("GET", "/test/scalar-row-factory") + # make_request("GET", "/test/scalar-row-factory") - print("\nAll requests completed successfully") + print_request_summary() diff --git a/drift/instrumentation/psycopg2/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg2/e2e-tests/src/test_requests.py index 629647c..287afd3 100644 --- a/drift/instrumentation/psycopg2/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg2/e2e-tests/src/test_requests.py @@ -1,24 +1,6 @@ """Execute test requests against the Psycopg Flask app.""" -import time - -import requests - -BASE_URL = "http://localhost:8000" - - -def make_request(method, endpoint, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"→ {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = requests.request(method, url, **kwargs) - print(f" Status: {response.status_code}") - time.sleep(0.5) # Small delay between requests - return response - +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary if __name__ == "__main__": print("Starting Psycopg test request sequence...\n") @@ -64,4 +46,4 @@ def make_request(method, endpoint, **kwargs): if user_id: make_request("DELETE", f"/db/delete/{user_id}") - print("\nAll requests completed successfully") + print_request_summary() diff --git a/drift/instrumentation/redis/e2e-tests/src/test_requests.py b/drift/instrumentation/redis/e2e-tests/src/test_requests.py index 9cc893c..b923668 100644 --- a/drift/instrumentation/redis/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/redis/e2e-tests/src/test_requests.py @@ -1,24 +1,6 @@ """Execute test requests against the Redis Flask app.""" -import time - -import requests - -BASE_URL = "http://localhost:8000" - - -def make_request(method, endpoint, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"→ {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = requests.request(method, url, **kwargs) - print(f" Status: {response.status_code}") - time.sleep(0.5) # Small delay between requests - return response - +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary if __name__ == "__main__": print("Starting Redis test request sequence...\n") @@ -62,4 +44,4 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/transaction-watch") - print("\nAll requests completed successfully") + print_request_summary() diff --git a/drift/instrumentation/requests/e2e-tests/src/test_requests.py b/drift/instrumentation/requests/e2e-tests/src/test_requests.py index 5976ff4..67e9904 100644 --- a/drift/instrumentation/requests/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/requests/e2e-tests/src/test_requests.py @@ -1,24 +1,6 @@ """Execute test requests against the Flask app to exercise the requests instrumentation.""" -import time - -import requests - -BASE_URL = "http://localhost:8000" - - -def make_request(method, endpoint, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"-> {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = requests.request(method, url, **kwargs) - print(f" Status: {response.status_code}") - time.sleep(0.5) # Small delay between requests - return response - +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary if __name__ == "__main__": print("Starting test request sequence for requests instrumentation...\n") @@ -76,4 +58,4 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/response-hooks") - print("\nAll requests completed successfully") + print_request_summary() From be8de290bae77f85e383be66c93a302097b6be13 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 16:05:08 -0800 Subject: [PATCH 20/37] Fix scalar_row factory handling in psycopg instrumentation --- .../psycopg/e2e-tests/src/test_requests.py | 2 +- drift/instrumentation/psycopg/instrumentation.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index 497cf7d..d2a2be8 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -66,6 +66,6 @@ make_request("GET", "/test/binary-bytea") make_request("GET", "/test/class-row-factory") make_request("GET", "/test/kwargs-row-factory") - # make_request("GET", "/test/scalar-row-factory") + make_request("GET", "/test/scalar-row-factory") print_request_summary() diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index cc506d6..b3df125 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -1065,6 +1065,8 @@ def _detect_row_factory_type(self, row_factory: Any) -> str: return "namedtuple" elif 'kwargs' in factory_name_lower: return "kwargs" + elif 'scalar' in factory_name_lower: + return "scalar" return "tuple" @@ -1239,6 +1241,9 @@ def transform_row(row): if row_factory_type == "kwargs": # kwargs_row: return stored dict as-is (already in correct format) return row + if row_factory_type == "scalar": + # scalar_row: unwrap the single-element list to get the scalar value + return row[0] if isinstance(row, list) and len(row) > 0 else row values = tuple(row) if isinstance(row, list) else row if row_factory_type == "dict" and column_names: return dict(zip(column_names, values)) @@ -1606,6 +1611,9 @@ def _finalize_query_span( if row_factory_type == "kwargs": # kwargs_row: store the entire dict as-is (it has custom keys, not column names) rows.append(row) + elif row_factory_type == "scalar": + # scalar_row: returns single values - wrap in list for consistent storage + rows.append([row]) elif isinstance(row, dict): # dict_row: extract values in column order rows.append([row.get(col) for col in column_names]) From 0a2a1c90cb703a6b09a892c3db13ac97f2a8ad08 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 16:12:01 -0800 Subject: [PATCH 21/37] handle replaying uuid properly --- drift/instrumentation/utils/psycopg_utils.py | 5 +++++ drift/instrumentation/utils/serialization.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/drift/instrumentation/utils/psycopg_utils.py b/drift/instrumentation/utils/psycopg_utils.py index 99738d5..2aeec45 100644 --- a/drift/instrumentation/utils/psycopg_utils.py +++ b/drift/instrumentation/utils/psycopg_utils.py @@ -4,6 +4,7 @@ import base64 import datetime as dt +import uuid from typing import Any @@ -13,6 +14,7 @@ def deserialize_db_value(val: Any) -> Any: During recording, database values are serialized for JSON storage: - datetime objects -> ISO format strings - bytes/memoryview -> {"__bytes__": ""} + - uuid.UUID -> {"__uuid__": ""} During replay, we need to convert them back to their original types so that application code (Flask/Django) handles them the same way. @@ -28,6 +30,9 @@ def deserialize_db_value(val: Any) -> Any: if "__bytes__" in val and len(val) == 1: # Decode base64 back to bytes return base64.b64decode(val["__bytes__"]) + # Check for UUID tagged structure + if "__uuid__" in val and len(val) == 1: + return uuid.UUID(val["__uuid__"]) # Recursively deserialize dict values return {k: deserialize_db_value(v) for k, v in val.items()} elif isinstance(val, str): diff --git a/drift/instrumentation/utils/serialization.py b/drift/instrumentation/utils/serialization.py index a3c2933..038653b 100644 --- a/drift/instrumentation/utils/serialization.py +++ b/drift/instrumentation/utils/serialization.py @@ -43,7 +43,7 @@ def serialize_value(val: Any) -> Any: if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): return val.isoformat() elif isinstance(val, uuid.UUID): - return str(val) + return {"__uuid__": str(val)} elif isinstance(val, memoryview): # Convert memoryview to bytes first, then serialize return _serialize_bytes(bytes(val)) From f6d4d329f48cae91d031d306eab9e31e70eea030 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Wed, 14 Jan 2026 17:24:58 -0800 Subject: [PATCH 22/37] Fix binary format hang by deferring result capture until fetch --- .../psycopg/e2e-tests/src/app.py | 182 +++++++++- .../psycopg/e2e-tests/src/test_requests.py | 12 + .../psycopg/instrumentation.py | 337 ++++++++++++------ 3 files changed, 416 insertions(+), 115 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index 786928a..ca34230 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -679,8 +679,6 @@ def make_user_dict(**kwargs): except Exception as e: return jsonify({"error": str(e)}), 500 -# BUG HUNTING TEST ENDPOINTS - @app.route("/test/scalar-row-factory") def test_scalar_row_factory(): """Test scalar_row row factory. @@ -703,6 +701,186 @@ def test_scalar_row_factory(): except Exception as e: return jsonify({"error": str(e)}), 500 +@app.route("/test/binary-format") +def test_binary_format(): + """Test execute with binary=True parameter. + + Tests whether the instrumentation correctly handles binary format transfers. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Execute with binary=True + cur.execute( + "SELECT id, name FROM users ORDER BY id LIMIT 3", + binary=True + ) + rows = cur.fetchall() + + return jsonify({ + "count": len(rows), + "data": [{"id": r[0], "name": r[1]} for r in rows] + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +# ===================================================================== +# BUG-EXPOSING TEST ENDPOINTS +# These endpoints expose confirmed bugs in the psycopg instrumentation. +# See BUG_TRACKING.md for detailed analysis. +# ===================================================================== + + +@app.route("/test/null-values") +def test_null_values(): + """Test handling of NULL values in results. + + BUG INVESTIGATION: NULL value serialization/deserialization may have issues. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table with nullable columns + cur.execute(""" + CREATE TEMP TABLE null_test ( + id INT, + nullable_text TEXT, + nullable_int INT, + nullable_bool BOOLEAN + ) + """) + + # Insert rows with NULL values + cur.execute(""" + INSERT INTO null_test VALUES + (1, 'has_value', 42, TRUE), + (2, NULL, NULL, NULL), + (3, 'another', NULL, FALSE) + """) + + # Query rows + cur.execute("SELECT * FROM null_test ORDER BY id") + rows = cur.fetchall() + conn.commit() + + return jsonify({ + "count": len(rows), + "data": [ + { + "id": r[0], + "nullable_text": r[1], + "nullable_int": r[2], + "nullable_bool": r[3] + } + for r in rows + ] + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/transaction-context") +def test_transaction_context(): + """Test conn.transaction() context manager. + + BUG INVESTIGATION: Explicit transaction context manager may not work correctly. + """ + try: + results = [] + with psycopg.connect(get_conn_string()) as conn: + # Use explicit transaction + with conn.transaction(): + with conn.cursor() as cur: + cur.execute("CREATE TEMP TABLE tx_test (id INT, val TEXT)") + cur.execute("INSERT INTO tx_test VALUES (1, 'first')") + cur.execute("SELECT * FROM tx_test") + rows = cur.fetchall() + results.append({"phase": "inside_transaction", "rows": [list(r) for r in rows]}) + + # After transaction commit + with conn.cursor() as cur: + cur.execute("SELECT * FROM tx_test") + rows = cur.fetchall() + results.append({"phase": "after_commit", "rows": [list(r) for r in rows]}) + + return jsonify({"results": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/json-jsonb") +def test_json_jsonb(): + """Test JSON and JSONB data types. + + BUG INVESTIGATION: JSON types may have serialization issues. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table with JSON columns + cur.execute(""" + CREATE TEMP TABLE json_test ( + id INT, + json_col JSON, + jsonb_col JSONB + ) + """) + + # Insert JSON data + import json + test_json = {"name": "test", "values": [1, 2, 3], "nested": {"key": "value"}} + cur.execute( + "INSERT INTO json_test VALUES (%s, %s, %s)", + (1, json.dumps(test_json), json.dumps(test_json)) + ) + + # Query back + cur.execute("SELECT * FROM json_test WHERE id = 1") + row = cur.fetchone() + conn.commit() + + return jsonify({ + "id": row[0], + "json_col": row[1], + "jsonb_col": row[2] + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/array-types") +def test_array_types(): + """Test PostgreSQL array types. + + BUG INVESTIGATION: Array types may have serialization issues. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table with array columns + cur.execute(""" + CREATE TEMP TABLE array_test ( + id INT, + int_array INTEGER[], + text_array TEXT[] + ) + """) + + # Insert array data + cur.execute( + "INSERT INTO array_test VALUES (%s, %s, %s)", + (1, [10, 20, 30], ["a", "b", "c"]) + ) + + # Query back + cur.execute("SELECT * FROM array_test WHERE id = 1") + row = cur.fetchone() + conn.commit() + + return jsonify({ + "id": row[0], + "int_array": list(row[1]) if row[1] else None, + "text_array": list(row[2]) if row[2] else None + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 if __name__ == "__main__": sdk.mark_app_as_ready() diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index d2a2be8..3c6ce49 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -67,5 +67,17 @@ make_request("GET", "/test/class-row-factory") make_request("GET", "/test/kwargs-row-factory") make_request("GET", "/test/scalar-row-factory") + make_request("GET", "/test/binary-format") + + # ===================================================================== + # BUG-EXPOSING TEST ENDPOINTS (SKIPPED) + # These endpoints expose confirmed bugs and cause the app to hang. + # See BUG_TRACKING.md for detailed analysis. + # ===================================================================== + + # make_request("GET", "/test/null-values") + # make_request("GET", "/test/transaction-context") + # make_request("GET", "/test/json-jsonb") + # make_request("GET", "/test/array-types") print_request_summary() diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index b3df125..65ce2a2 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -478,9 +478,13 @@ def _record_execute( self._finalize_query_span(span_info.span, cursor, query_str, params, None) span_info.span.end() else: - # Normal mode: finalize immediately - self._finalize_query_span(span_info.span, cursor, query_str, params, None) - span_info.span.end() + # Normal mode: finalize immediately (unless lazy capture was set up) + span_finalized = self._finalize_query_span(span_info.span, cursor, query_str, params, None) + if span_finalized: + # Span was fully finalized, end it now + span_info.span.end() + # If span_finalized is False, lazy capture was set up and span will be + # ended when user code calls a fetch method def _traced_executemany( self, cursor: Any, original_executemany: Any, sdk: TuskDrift, query: str, params_seq, **kwargs @@ -620,16 +624,20 @@ def _record_executemany( {"_batch": params_list, "_returning": True}, error, ) + span_info.span.end() else: # Existing behavior for executemany without returning - self._finalize_query_span( + span_finalized = self._finalize_query_span( span_info.span, cursor, query_str, {"_batch": params_list}, error, ) - span_info.span.end() + if span_finalized: + span_info.span.end() + # Note: executemany without returning typically has no results, + # so lazy capture is unlikely but we handle it for safety def _traced_stream( self, cursor: Any, original_stream: Any, sdk: TuskDrift, query: str, params=None, **kwargs @@ -1048,7 +1056,8 @@ def _detect_row_factory_type(self, row_factory: Any) -> str: """Detect the type of row factory for mock transformations. Returns: - "dict" for dict_row, "namedtuple" for namedtuple_row, "tuple" otherwise + "dict" for dict_row, "namedtuple" for namedtuple_row, + "class" for class_row, "tuple" otherwise """ if row_factory is None: return "tuple" @@ -1067,6 +1076,8 @@ def _detect_row_factory_type(self, row_factory: Any) -> str: return "kwargs" elif 'scalar' in factory_name_lower: return "scalar" + elif 'class' in factory_name_lower: + return "class" return "tuple" @@ -1126,8 +1137,10 @@ def _finalize_pending_pipeline_spans(self, connection: Any) -> None: params = item['params'] try: - self._finalize_query_span(span_info.span, cursor, query, params, error=None) - span_info.span.end() + span_finalized = self._finalize_query_span(span_info.span, cursor, query, params, error=None) + if span_finalized: + span_info.span.end() + # If lazy capture was set up, span will be ended when user fetches except Exception as e: logger.error(f"[PIPELINE] Error finalizing deferred span: {e}") try: @@ -1231,8 +1244,10 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non row_factory_type = self._detect_row_factory_type(row_factory) # Create namedtuple class once if needed (avoid recreating for each row) + # Used for both namedtuple_row and class_row (class_row returns dataclass instances, + # but in replay we can't recreate the exact class, so we use namedtuple as a compatible substitute) RowClass = None - if row_factory_type == "namedtuple" and column_names: + if row_factory_type in ("namedtuple", "class") and column_names: from collections import namedtuple RowClass = namedtuple('Row', column_names) @@ -1247,7 +1262,9 @@ def transform_row(row): values = tuple(row) if isinstance(row, list) else row if row_factory_type == "dict" and column_names: return dict(zip(column_names, values)) - elif row_factory_type == "namedtuple" and RowClass is not None: + elif row_factory_type in ("namedtuple", "class") and RowClass is not None: + # For class_row, we use namedtuple as a compatible substitute that supports + # attribute access (row.id, row.name, etc.) return RowClass(*values) return values @@ -1550,8 +1567,12 @@ def _finalize_query_span( query: str, params: Any, error: Exception | None, - ) -> None: - """Finalize span with query data.""" + ) -> bool: + """Finalize span with query data. + + Returns True if span was fully finalized, False if lazy capture was set up + (meaning caller should NOT end the span - it will be ended by lazy fetch). + """ try: # Build input value input_value = { @@ -1575,6 +1596,7 @@ def _finalize_query_span( try: rows = [] description = None + row_factory_type = "tuple" # default # Try to fetch results if available if hasattr(cursor, "description") and cursor.description: @@ -1588,111 +1610,39 @@ def _finalize_query_span( for desc in cursor.description ] - # Fetch all rows for recording - # We need to capture these for replay mode - try: - all_rows = cursor.fetchall() - # Convert rows to lists for JSON serialization - # Handle dict_row (returns dicts) and namedtuple_row (returns namedtuples) - column_names = [d["name"] for d in description] - - # Get row factory from cursor or connection - row_factory = getattr(cursor, 'row_factory', None) - if row_factory is None: - conn = getattr(cursor, 'connection', None) - if conn: - row_factory = getattr(conn, 'row_factory', None) - - # Detect row factory type BEFORE processing rows - row_factory_type = self._detect_row_factory_type(row_factory) - - rows = [] - for row in all_rows: - if row_factory_type == "kwargs": - # kwargs_row: store the entire dict as-is (it has custom keys, not column names) - rows.append(row) - elif row_factory_type == "scalar": - # scalar_row: returns single values - wrap in list for consistent storage - rows.append([row]) - elif isinstance(row, dict): - # dict_row: extract values in column order - rows.append([row.get(col) for col in column_names]) - elif hasattr(row, '_fields'): - # namedtuple: extract values in column order - rows.append([getattr(row, col, None) for col in column_names]) - else: - # tuple or list: convert directly - rows.append(list(row)) - - # CRITICAL: Re-populate cursor so user code can still fetch - # We'll store the rows and patch fetch methods - cursor._tusk_rows = all_rows # pyright: ignore[reportAttributeAccessIssue] - cursor._tusk_index = 0 # pyright: ignore[reportAttributeAccessIssue] - - # Replace with our versions that return stored rows - def patched_fetchone(): - if cursor._tusk_index < len(cursor._tusk_rows): # pyright: ignore[reportAttributeAccessIssue] - row = cursor._tusk_rows[cursor._tusk_index] # pyright: ignore[reportAttributeAccessIssue] - cursor._tusk_index += 1 # pyright: ignore[reportAttributeAccessIssue] - return row - return None - - def patched_fetchmany(size=cursor.arraysize): - result = cursor._tusk_rows[cursor._tusk_index : cursor._tusk_index + size] # pyright: ignore[reportAttributeAccessIssue] - cursor._tusk_index += len(result) # pyright: ignore[reportAttributeAccessIssue] - return result - - def patched_fetchall(): - result = cursor._tusk_rows[cursor._tusk_index :] # pyright: ignore[reportAttributeAccessIssue] - cursor._tusk_index = len(cursor._tusk_rows) # pyright: ignore[reportAttributeAccessIssue] - return result - - def patched_scroll(value: int, mode: str = "relative") -> None: - """Scroll the cursor to a new position in the captured result set.""" - if mode == "relative": - newpos = cursor._tusk_index + value # pyright: ignore[reportAttributeAccessIssue] - elif mode == "absolute": - newpos = value - else: - raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") - - num_rows = len(cursor._tusk_rows) # pyright: ignore[reportAttributeAccessIssue] - if num_rows > 0: - if not (0 <= newpos < num_rows): - raise IndexError("cursor position out of range") - elif newpos != 0: - raise IndexError("cursor position out of range") - - cursor._tusk_index = newpos # pyright: ignore[reportAttributeAccessIssue] - - # Patch fetch methods with our versions that return stored rows - cursor.fetchone = patched_fetchone # pyright: ignore[reportAttributeAccessIssue] - cursor.fetchmany = patched_fetchmany # pyright: ignore[reportAttributeAccessIssue] - cursor.fetchall = patched_fetchall # pyright: ignore[reportAttributeAccessIssue] - cursor.scroll = patched_scroll # pyright: ignore[reportAttributeAccessIssue] - cursor._tusk_patched = True # pyright: ignore[reportAttributeAccessIssue] - - except Exception as fetch_error: - logger.debug(f"Could not fetch rows (query might not return rows): {fetch_error}") - rows = [] - + # Get row factory from cursor or connection + row_factory = getattr(cursor, 'row_factory', None) + if row_factory is None: + conn = getattr(cursor, 'connection', None) + if conn: + row_factory = getattr(conn, 'row_factory', None) + + # Detect row factory type BEFORE processing rows + row_factory_type = self._detect_row_factory_type(row_factory) + column_names = [d["name"] for d in description] + + # Use LAZY CAPTURE to avoid hanging with binary=True and other edge cases. + # Instead of calling fetchall() immediately (which can hang), we set up + # wrappers that capture results when the user first calls a fetch method. + # Store context needed for lazy capture + cursor._tusk_lazy_span = span # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_lazy_input_value = input_value # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_lazy_description = description # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_lazy_row_factory_type = row_factory_type # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_lazy_column_names = column_names # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_lazy_instrumentation = self # pyright: ignore[reportAttributeAccessIssue] + + # Set up lazy capture wrappers + self._setup_lazy_capture(cursor) + + logger.debug("[PSYCOPG] Lazy capture set up, deferring span finalization") + return False # Signal caller NOT to end span + + # No description means no results expected (e.g., INSERT without RETURNING) output_value = { "rowcount": cursor.rowcount if hasattr(cursor, "rowcount") else -1, } - if description: - output_value["description"] = description - - if rows: - # Convert rows to JSON-serializable format (handle datetime objects, etc.) - # For kwargs_row, rows are custom dicts - serialize each row as a complete dict - # For other types, rows are lists of column values - if row_factory_type == "kwargs": - serialized_rows = [serialize_value(row) for row in rows] - else: - serialized_rows = [[serialize_value(col) for col in row] for row in rows] - output_value["rows"] = serialized_rows - # Capture statusmessage for replay if hasattr(cursor, 'statusmessage') and cursor.statusmessage is not None: output_value["statusmessage"] = cursor.statusmessage @@ -1718,9 +1668,170 @@ def patched_scroll(value: int, mode: str = "relative") -> None: span.set_status(Status(OTelStatusCode.OK)) logger.debug("[PSYCOPG] Span finalized successfully") + return True # Span fully finalized except Exception as e: logger.error(f"Error creating query span: {e}") + return True # Return True to end span on error + + def _setup_lazy_capture(self, cursor: Any) -> None: + """Set up lazy capture wrappers on cursor fetch methods. + + These wrappers defer the actual fetchall() call until the user's code + requests results. This avoids issues with binary format and other cases + where calling fetchall() immediately after execute() can hang. + """ + # Get references to original fetch methods from the cursor's class + # (not instance methods which might already be patched) + cursor_class = type(cursor) + original_fetchall = cursor_class.fetchall + original_fetchone = cursor_class.fetchone + original_fetchmany = cursor_class.fetchmany + original_scroll = cursor_class.scroll if hasattr(cursor_class, 'scroll') else None + + def do_lazy_capture(): + """Perform the actual capture - called on first fetch.""" + if hasattr(cursor, '_tusk_rows') and cursor._tusk_rows is not None: + return # Already captured + + try: + # Get the actual rows from psycopg + all_rows = original_fetchall(cursor) + + # Store for subsequent fetch calls + cursor._tusk_rows = all_rows # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Process rows for trace capture + description = cursor._tusk_lazy_description # pyright: ignore[reportAttributeAccessIssue] + row_factory_type = cursor._tusk_lazy_row_factory_type # pyright: ignore[reportAttributeAccessIssue] + column_names = cursor._tusk_lazy_column_names # pyright: ignore[reportAttributeAccessIssue] + + rows = [] + for row in all_rows: + if row_factory_type == "kwargs": + rows.append(row) + elif row_factory_type == "scalar": + rows.append([row]) + elif row_factory_type == "class" or hasattr(row, '__dataclass_fields__'): + # dataclass (from class_row): extract values by attribute name + rows.append([getattr(row, col, None) for col in column_names]) + elif isinstance(row, dict): + rows.append([row.get(col) for col in column_names]) + elif hasattr(row, '_fields'): + # namedtuple: extract values by field name + rows.append([getattr(row, col, None) for col in column_names]) + else: + rows.append(list(row)) + + # Finalize the span with captured data + span = cursor._tusk_lazy_span # pyright: ignore[reportAttributeAccessIssue] + input_value = cursor._tusk_lazy_input_value # pyright: ignore[reportAttributeAccessIssue] + instrumentation = cursor._tusk_lazy_instrumentation # pyright: ignore[reportAttributeAccessIssue] + + output_value = { + "rowcount": cursor.rowcount if hasattr(cursor, "rowcount") else -1, + } + + if description: + output_value["description"] = description + + if rows: + if row_factory_type == "kwargs": + serialized_rows = [serialize_value(row) for row in rows] + else: + serialized_rows = [[serialize_value(col) for col in row] for row in rows] + output_value["rows"] = serialized_rows + + if hasattr(cursor, 'statusmessage') and cursor.statusmessage is not None: + output_value["statusmessage"] = cursor.statusmessage + + # Generate schemas and hashes + input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) + output_result = JsonSchemaHelper.generate_schema_and_hash(output_value, {}) + + # Set span attributes + span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(input_value)) + span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(output_value)) + span.set_attribute(TdSpanAttributes.INPUT_SCHEMA, json.dumps(input_result.schema.to_primitive())) + span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA, json.dumps(output_result.schema.to_primitive())) + span.set_attribute(TdSpanAttributes.INPUT_SCHEMA_HASH, input_result.decoded_schema_hash) + span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA_HASH, output_result.decoded_schema_hash) + span.set_attribute(TdSpanAttributes.INPUT_VALUE_HASH, input_result.decoded_value_hash) + span.set_attribute(TdSpanAttributes.OUTPUT_VALUE_HASH, output_result.decoded_value_hash) + + span.set_status(Status(OTelStatusCode.OK)) + span.end() + + logger.debug("[PSYCOPG] Lazy capture completed, span finalized") + + except Exception as e: + logger.error(f"Error in lazy capture: {e}") + # Try to end span even on error + try: + span = cursor._tusk_lazy_span # pyright: ignore[reportAttributeAccessIssue] + span.set_status(Status(OTelStatusCode.ERROR, str(e))) + span.end() + except Exception: + pass + + finally: + # Clean up lazy capture attributes + for attr in ('_tusk_lazy_span', '_tusk_lazy_input_value', '_tusk_lazy_description', + '_tusk_lazy_row_factory_type', '_tusk_lazy_column_names', '_tusk_lazy_instrumentation'): + if hasattr(cursor, attr): + try: + delattr(cursor, attr) + except AttributeError: + pass + + def lazy_fetchone(): + do_lazy_capture() + if cursor._tusk_index < len(cursor._tusk_rows): # pyright: ignore[reportAttributeAccessIssue] + row = cursor._tusk_rows[cursor._tusk_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index += 1 # pyright: ignore[reportAttributeAccessIssue] + return row + return None + + def lazy_fetchmany(size=None): + do_lazy_capture() + if size is None: + size = cursor.arraysize + result = cursor._tusk_rows[cursor._tusk_index : cursor._tusk_index + size] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index += len(result) # pyright: ignore[reportAttributeAccessIssue] + return result + + def lazy_fetchall(): + do_lazy_capture() + result = cursor._tusk_rows[cursor._tusk_index :] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index = len(cursor._tusk_rows) # pyright: ignore[reportAttributeAccessIssue] + return result + + def lazy_scroll(value: int, mode: str = "relative") -> None: + do_lazy_capture() + if mode == "relative": + newpos = cursor._tusk_index + value # pyright: ignore[reportAttributeAccessIssue] + elif mode == "absolute": + newpos = value + else: + raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") + + num_rows = len(cursor._tusk_rows) # pyright: ignore[reportAttributeAccessIssue] + if num_rows > 0: + if not (0 <= newpos < num_rows): + raise IndexError("cursor position out of range") + elif newpos != 0: + raise IndexError("cursor position out of range") + + cursor._tusk_index = newpos # pyright: ignore[reportAttributeAccessIssue] + + # Patch the cursor with lazy wrappers + cursor.fetchone = lazy_fetchone # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = lazy_fetchmany # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = lazy_fetchall # pyright: ignore[reportAttributeAccessIssue] + if original_scroll: + cursor.scroll = lazy_scroll # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_patched = True # pyright: ignore[reportAttributeAccessIssue] def _finalize_executemany_returning_span( self, From 8ad42acbc6e610bc3b09d78134dc46e129fc6cd1 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 11:01:02 -0800 Subject: [PATCH 23/37] Enable null-values test for psycopg instrumentation NULL values handling now passes thanks to the lazy capture mechanism that defers fetchall() until user code actually calls fetch. --- .../psycopg/e2e-tests/src/test_requests.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index 3c6ce49..b545f91 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -70,12 +70,14 @@ make_request("GET", "/test/binary-format") # ===================================================================== - # BUG-EXPOSING TEST ENDPOINTS (SKIPPED) - # These endpoints expose confirmed bugs and cause the app to hang. + # PREVIOUSLY BUG-EXPOSING TEST ENDPOINTS # See BUG_TRACKING.md for detailed analysis. # ===================================================================== - # make_request("GET", "/test/null-values") + # Test 3: NULL values handling - FIXED (lazy capture mechanism) + make_request("GET", "/test/null-values") + + # These still need investigation: # make_request("GET", "/test/transaction-context") # make_request("GET", "/test/json-jsonb") # make_request("GET", "/test/array-types") From 0526956f93c590c9f62be58559efec8e558719ee Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 11:03:59 -0800 Subject: [PATCH 24/37] Clean up null-values test comments Remove bug investigation comments from the null-values endpoint since the issue is now resolved and the test is part of the E2E suite. --- drift/instrumentation/psycopg/e2e-tests/src/app.py | 5 +---- .../instrumentation/psycopg/e2e-tests/src/test_requests.py | 7 +------ 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index ca34230..5542524 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -733,10 +733,7 @@ def test_binary_format(): @app.route("/test/null-values") def test_null_values(): - """Test handling of NULL values in results. - - BUG INVESTIGATION: NULL value serialization/deserialization may have issues. - """ + """Test handling of NULL values in results.""" try: with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: # Create temp table with nullable columns diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index b545f91..d398dab 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -69,12 +69,7 @@ make_request("GET", "/test/scalar-row-factory") make_request("GET", "/test/binary-format") - # ===================================================================== - # PREVIOUSLY BUG-EXPOSING TEST ENDPOINTS - # See BUG_TRACKING.md for detailed analysis. - # ===================================================================== - - # Test 3: NULL values handling - FIXED (lazy capture mechanism) + # Test: NULL values handling (integrated into E2E suite) make_request("GET", "/test/null-values") # These still need investigation: From 2449d4005fb9e50574f2da5a4b94d39787b973f7 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 11:15:11 -0800 Subject: [PATCH 25/37] Enable transaction context manager test for psycopg instrumentation The conn.transaction() context manager now works correctly thanks to the lazy capture mechanism that defers fetchall() until fetch is called. --- drift/instrumentation/psycopg/e2e-tests/src/test_requests.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index d398dab..b546825 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -72,8 +72,10 @@ # Test: NULL values handling (integrated into E2E suite) make_request("GET", "/test/null-values") + # Test: Transaction context manager + make_request("GET", "/test/transaction-context") + # These still need investigation: - # make_request("GET", "/test/transaction-context") # make_request("GET", "/test/json-jsonb") # make_request("GET", "/test/array-types") From a9ab04f24947ad1963f8a12c068bee5a8c48be62 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 11:16:20 -0800 Subject: [PATCH 26/37] Clean up transaction context manager test comments Remove bug investigation comments from the endpoint since the issue is now resolved and the test is part of the E2E suite. --- drift/instrumentation/psycopg/e2e-tests/src/app.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index 5542524..ecca9e2 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -777,10 +777,7 @@ def test_null_values(): @app.route("/test/transaction-context") def test_transaction_context(): - """Test conn.transaction() context manager. - - BUG INVESTIGATION: Explicit transaction context manager may not work correctly. - """ + """Test conn.transaction() context manager.""" try: results = [] with psycopg.connect(get_conn_string()) as conn: From 7dbf9eb85d32e302a2dfa56f18f38db855de60af Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 11:55:48 -0800 Subject: [PATCH 27/37] fix logging OOM issue during e2e tests --- drift/instrumentation/e2e_common/base_runner.py | 4 ++-- drift/instrumentation/psycopg/e2e-tests/src/app.py | 8 -------- .../psycopg/e2e-tests/src/test_requests.py | 6 +++--- drift/instrumentation/psycopg/instrumentation.py | 1 - 4 files changed, 5 insertions(+), 14 deletions(-) diff --git a/drift/instrumentation/e2e_common/base_runner.py b/drift/instrumentation/e2e_common/base_runner.py index 00556a9..1b089e0 100644 --- a/drift/instrumentation/e2e_common/base_runner.py +++ b/drift/instrumentation/e2e_common/base_runner.py @@ -150,8 +150,8 @@ def record_traces(self) -> bool: self.app_process = subprocess.Popen( ["python", "src/app.py"], env={**os.environ, **env}, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, text=True, ) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index ecca9e2..607c4aa 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -723,14 +723,6 @@ def test_binary_format(): except Exception as e: return jsonify({"error": str(e)}), 500 - -# ===================================================================== -# BUG-EXPOSING TEST ENDPOINTS -# These endpoints expose confirmed bugs in the psycopg instrumentation. -# See BUG_TRACKING.md for detailed analysis. -# ===================================================================== - - @app.route("/test/null-values") def test_null_values(): """Test handling of NULL values in results.""" diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index b546825..9de4aa6 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -75,8 +75,8 @@ # Test: Transaction context manager make_request("GET", "/test/transaction-context") - # These still need investigation: - # make_request("GET", "/test/json-jsonb") - # make_request("GET", "/test/array-types") + # JSON/JSONB and array types tests + make_request("GET", "/test/json-jsonb") + make_request("GET", "/test/array-types") print_request_summary() diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 65ce2a2..33b0ac9 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -1202,7 +1202,6 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non # The SDK communicator already extracts response.body from the CLI's MockInteraction structure # So mock_data should already contain: {"rowcount": ..., "description": [...], "rows": [...]} actual_data = mock_data - logger.debug(f"[MOCK_DATA] mock_data: {mock_data}") try: cursor._rowcount = actual_data.get("rowcount", -1) From 5fc28b2fbb677974bc3bcd1e0e4a94d1313d911e Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 12:26:46 -0800 Subject: [PATCH 28/37] Fix cursor.set_result() not mocked in REPLAY mode for executemany with returning=True The instrumentation was patching results() and nextset() methods for navigating result sets, but set_result(index) was missing. This caused REPLAY mode to fail with "index out of range" errors when user code called set_result() to jump to a specific result set. Added patched_set_result() for both REPLAY mode (in _mock_executemany_returning_with_data) and RECORD mode (in _finalize_executemany_returning_span) that: - Validates index bounds and supports negative indices - Updates the current result set state and fetch methods - Returns the cursor as the real implementation does --- .../psycopg/e2e-tests/src/app.py | 38 +++++++++++ .../psycopg/e2e-tests/src/test_requests.py | 3 + .../psycopg/instrumentation.py | 66 +++++++++++++++++++ 3 files changed, 107 insertions(+) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index 607c4aa..7efe02d 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -868,6 +868,44 @@ def test_array_types(): except Exception as e: return jsonify({"error": str(e)}), 500 +@app.route("/test/cursor-set-result") +def test_cursor_set_result(): + """Test cursor.set_result() method. + + This tests whether the instrumentation correctly handles + set_result() for navigating between result sets. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table and insert with returning + cur.execute("CREATE TEMP TABLE setresult_test (id SERIAL, val TEXT)") + + cur.executemany( + "INSERT INTO setresult_test (val) VALUES (%s) RETURNING id, val", + [("First",), ("Second",), ("Third",)], + returning=True + ) + + # Use set_result to navigate to specific result sets + results = [] + + # Go to the last result set + cur.set_result(-1) + row = cur.fetchone() + results.append({"set": "last", "data": {"id": row[0], "val": row[1]} if row else None}) + + # Go to the first result set + cur.set_result(0) + row = cur.fetchone() + results.append({"set": "first", "data": {"id": row[0], "val": row[1]} if row else None}) + + conn.commit() + + return jsonify({"results": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + if __name__ == "__main__": sdk.mark_app_as_ready() app.run(host="0.0.0.0", port=8000, debug=False) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index 9de4aa6..9436593 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -79,4 +79,7 @@ make_request("GET", "/test/json-jsonb") make_request("GET", "/test/array-types") + # Bug-exposing tests - kept for regression testing + make_request("GET", "/test/cursor-set-result") + print_request_summary() diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 33b0ac9..f42743c 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -1559,6 +1559,48 @@ def patched_nextset(): cursor.nextset = patched_nextset # pyright: ignore[reportAttributeAccessIssue] + # Patch set_result() to work with _mock_result_sets + def patched_set_result(index: int): + """Navigate to a specific result set by index (supports negative indices).""" + num_results = len(cursor._mock_result_sets) # pyright: ignore[reportAttributeAccessIssue] + if not -num_results <= index < num_results: + raise IndexError( + f"index {index} out of range: {num_results} result(s) available" + ) + if index < 0: + index = num_results + index + + cursor._mock_result_set_index = index # pyright: ignore[reportAttributeAccessIssue] + target_set = cursor._mock_result_sets[index] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_rows = target_set["rows"] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Update fetch methods for the target result set + target_column_names = target_set.get("column_names") + TargetRowClass = create_row_class(target_column_names) + cursor.fetchone = make_fetchone_replay(target_column_names, TargetRowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = make_fetchmany_replay(target_column_names, TargetRowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = make_fetchall_replay(target_column_names, TargetRowClass) # pyright: ignore[reportAttributeAccessIssue] + + # Update description for target result set + target_description_data = target_set.get("description") + if target_description_data: + target_desc = [ + (col["name"], col.get("type_code"), None, None, None, None, None) + for col in target_description_data + ] + try: + cursor._tusk_description = target_desc # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + try: + cursor.description = target_desc # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + pass + + return cursor + + cursor.set_result = patched_set_result # pyright: ignore[reportAttributeAccessIssue] + def _finalize_query_span( self, span: trace.Span, @@ -2015,6 +2057,30 @@ def patched_nextset(): cursor.nextset = patched_nextset # pyright: ignore[reportAttributeAccessIssue] + # Patch set_result() to work with _tusk_result_sets + def patched_set_result_record(index: int): + """Navigate to a specific result set by index (supports negative indices).""" + num_results = len(cursor._tusk_result_sets) # pyright: ignore[reportAttributeAccessIssue] + if not -num_results <= index < num_results: + raise IndexError( + f"index {index} out of range: {num_results} result(s) available" + ) + if index < 0: + index = num_results + index + + cursor._tusk_result_set_index = index # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_rows = cursor._tusk_result_sets[index] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Update fetch methods for the target result set + cursor.fetchone = make_patched_fetchone_record() # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = make_patched_fetchmany_record() # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = make_patched_fetchall_record() # pyright: ignore[reportAttributeAccessIssue] + + return cursor + + cursor.set_result = patched_set_result_record # pyright: ignore[reportAttributeAccessIssue] + else: output_value = {"rowcount": cursor.rowcount if hasattr(cursor, "rowcount") else -1} From 6903b87d3339a3203c6bc81e054a2fb1ca930cb6 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 12:34:43 -0800 Subject: [PATCH 29/37] Refactor psycopg instrumentation to reduce code duplication - Created _CursorInstrumentationMixin class to share common cursor properties (description, rownumber, statusmessage) and iteration methods (__iter__, __next__) between InstrumentedCursor and InstrumentedServerCursor - Created _set_span_attributes() helper method to consolidate the repeated pattern of setting span attributes (input/output values, schemas, and hashes) that was duplicated in 5 different places This reduces code duplication by ~110 lines while maintaining the same functionality (all 37 E2E tests pass). --- .../psycopg/instrumentation.py | 305 +++++++----------- 1 file changed, 117 insertions(+), 188 deletions(-) diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index f42743c..252b8e7 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -33,6 +33,74 @@ _instance: PsycopgInstrumentation | None = None +class _CursorInstrumentationMixin: + """Mixin providing common functionality for instrumented cursor classes. + + This mixin contains shared properties and methods used by both + InstrumentedCursor and InstrumentedServerCursor to avoid code duplication. + """ + + _tusk_description = None # Store mock description for replay mode + + @property + def description(self): + # In replay mode, return mock description if set; otherwise use base + if self._tusk_description is not None: + return self._tusk_description + return super().description + + @property + def rownumber(self): + # In captured mode (after fetchall in _finalize_query_span), return tracked index + if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: + return self._tusk_index + # In replay mode with mock data, return mock index + if hasattr(self, '_mock_rows') and self._mock_rows is not None: + return self._mock_index + # Otherwise, return real cursor's rownumber + return super().rownumber + + @property + def statusmessage(self): + # In replay mode with mock data, return mock statusmessage + if hasattr(self, '_mock_statusmessage'): + return self._mock_statusmessage + # Otherwise, return real cursor's statusmessage + return super().statusmessage + + def __iter__(self): + # Support direct cursor iteration (for row in cursor) + # In replay mode with mock data (_mock_rows) or record mode with captured data (_tusk_rows) + if hasattr(self, '_mock_rows') and self._mock_rows is not None: + return self + if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: + return self + return super().__iter__() + + def __next__(self): + # In replay mode with mock data, iterate over mock rows + if hasattr(self, '_mock_rows') and self._mock_rows is not None: + if self._mock_index < len(self._mock_rows): + row = self._mock_rows[self._mock_index] + self._mock_index += 1 + # Apply row transformation if fetchone is patched + if hasattr(self, 'fetchone') and callable(self.fetchone): + # Reset index, get transformed row, restore index + self._mock_index -= 1 + result = self.fetchone() + return result + return tuple(row) if isinstance(row, list) else row + raise StopIteration + # In record mode with captured data, iterate over stored rows + if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: + if self._tusk_index < len(self._tusk_rows): + row = self._tusk_rows[self._tusk_index] + self._tusk_index += 1 + return row + raise StopIteration + return super().__next__() + + class PsycopgInstrumentation(InstrumentationBase): """Instrumentation for psycopg (psycopg3) PostgreSQL client library. @@ -174,34 +242,12 @@ def _create_cursor_factory(self, sdk: TuskDrift, base_factory=None): base = base_factory or BaseCursor - class InstrumentedCursor(base): # type: ignore - _tusk_description = None # Store mock description for replay mode - - @property - def description(self): - # In replay mode, return mock description if set; otherwise use base - if self._tusk_description is not None: - return self._tusk_description - return super().description - - @property - def rownumber(self): - # In captured mode (after fetchall in _finalize_query_span), return tracked index - if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: - return self._tusk_index - # In replay mode with mock data, return mock index - if hasattr(self, '_mock_rows') and self._mock_rows is not None: - return self._mock_index - # Otherwise, return real cursor's rownumber - return super().rownumber - - @property - def statusmessage(self): - # In replay mode with mock data, return mock statusmessage - if hasattr(self, '_mock_statusmessage'): - return self._mock_statusmessage - # Otherwise, return real cursor's statusmessage - return super().statusmessage + class InstrumentedCursor(_CursorInstrumentationMixin, base): # type: ignore + """Instrumented cursor with tracing support. + + Inherits common properties (description, rownumber, statusmessage) + and iteration methods (__iter__, __next__) from _CursorInstrumentationMixin. + """ def execute(self, query, params=None, **kwargs): return instrumentation._traced_execute(self, super().execute, sdk, query, params, **kwargs) @@ -215,38 +261,6 @@ def stream(self, query, params=None, **kwargs): def copy(self, query, params=None, **kwargs): return instrumentation._traced_copy(self, super().copy, sdk, query, params, **kwargs) - def __iter__(self): - # Support direct cursor iteration (for row in cursor) - # In replay mode with mock data (_mock_rows) or record mode with captured data (_tusk_rows) - if hasattr(self, '_mock_rows') and self._mock_rows is not None: - return self - if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: - return self - return super().__iter__() - - def __next__(self): - # In replay mode with mock data, iterate over mock rows - if hasattr(self, '_mock_rows') and self._mock_rows is not None: - if self._mock_index < len(self._mock_rows): - row = self._mock_rows[self._mock_index] - self._mock_index += 1 - # Apply row transformation if fetchone is patched - if hasattr(self, 'fetchone') and callable(self.fetchone): - # Reset index, get transformed row, restore index - self._mock_index -= 1 - result = self.fetchone() - return result - return tuple(row) if isinstance(row, list) else row - raise StopIteration - # In record mode with captured data, iterate over stored rows - if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: - if self._tusk_index < len(self._tusk_rows): - row = self._tusk_rows[self._tusk_index] - self._tusk_index += 1 - return row - raise StopIteration - return super().__next__() - return InstrumentedCursor def _create_server_cursor_factory(self, sdk: TuskDrift, base_factory=None): @@ -266,74 +280,20 @@ def _create_server_cursor_factory(self, sdk: TuskDrift, base_factory=None): base = base_factory or BaseServerCursor - class InstrumentedServerCursor(base): # type: ignore - _tusk_description = None # Store mock description for replay mode - - @property - def description(self): - # In replay mode, return mock description if set; otherwise use base - if self._tusk_description is not None: - return self._tusk_description - return super().description - - @property - def rownumber(self): - # In captured mode (after fetchall in _finalize_query_span), return tracked index - if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: - return self._tusk_index - # In replay mode with mock data, return mock index - if hasattr(self, '_mock_rows') and self._mock_rows is not None: - return self._mock_index - # Otherwise, return real cursor's rownumber - return super().rownumber - - @property - def statusmessage(self): - # In replay mode with mock data, return mock statusmessage - if hasattr(self, '_mock_statusmessage'): - return self._mock_statusmessage - # Otherwise, return real cursor's statusmessage - return super().statusmessage + class InstrumentedServerCursor(_CursorInstrumentationMixin, base): # type: ignore + """Instrumented server cursor with tracing support. + + Inherits common properties (description, rownumber, statusmessage) + and iteration methods (__iter__, __next__) from _CursorInstrumentationMixin. + + Note: ServerCursor doesn't support executemany(). + Note: ServerCursor has stream-like iteration via fetchmany/itersize. + """ def execute(self, query, params=None, **kwargs): # Note: ServerCursor.execute() doesn't support 'prepare' parameter return instrumentation._traced_execute(self, super().execute, sdk, query, params, **kwargs) - # Note: ServerCursor doesn't support executemany() - # Note: ServerCursor has stream-like iteration via fetchmany/itersize - - def __iter__(self): - # Support direct cursor iteration (for row in cursor) - # In replay mode with mock data (_mock_rows) or record mode with captured data (_tusk_rows) - if hasattr(self, '_mock_rows') and self._mock_rows is not None: - return self - if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: - return self - return super().__iter__() - - def __next__(self): - # In replay mode with mock data, iterate over mock rows - if hasattr(self, '_mock_rows') and self._mock_rows is not None: - if self._mock_index < len(self._mock_rows): - row = self._mock_rows[self._mock_index] - self._mock_index += 1 - # Apply row transformation if fetchone is patched - if hasattr(self, 'fetchone') and callable(self.fetchone): - # Reset index, get transformed row, restore index - self._mock_index -= 1 - result = self.fetchone() - return result - return tuple(row) if isinstance(row, list) else row - raise StopIteration - # In record mode with captured data, iterate over stored rows - if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: - if self._tusk_index < len(self._tusk_rows): - row = self._tusk_rows[self._tusk_index] - self._tusk_index += 1 - return row - raise StopIteration - return super().__next__() - return InstrumentedServerCursor def _traced_execute( @@ -791,19 +751,7 @@ def _finalize_stream_span( if serialized_rows: output_value["rows"] = serialized_rows - # Generate schemas and hashes - input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) - output_result = JsonSchemaHelper.generate_schema_and_hash(output_value, {}) - - # Set span attributes - span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(input_value)) - span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(output_value)) - span.set_attribute(TdSpanAttributes.INPUT_SCHEMA, json.dumps(input_result.schema.to_primitive())) - span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA, json.dumps(output_result.schema.to_primitive())) - span.set_attribute(TdSpanAttributes.INPUT_SCHEMA_HASH, input_result.decoded_schema_hash) - span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA_HASH, output_result.decoded_schema_hash) - span.set_attribute(TdSpanAttributes.INPUT_VALUE_HASH, input_result.decoded_value_hash) - span.set_attribute(TdSpanAttributes.OUTPUT_VALUE_HASH, output_result.decoded_value_hash) + self._set_span_attributes(span, input_value, output_value) if not error: span.set_status(Status(OTelStatusCode.OK)) @@ -1018,19 +966,7 @@ def _finalize_copy_span( "chunk_count": len(data_collected), } - # Generate schemas and hashes - input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) - output_result = JsonSchemaHelper.generate_schema_and_hash(output_value, {}) - - # Set span attributes - span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(input_value)) - span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(output_value)) - span.set_attribute(TdSpanAttributes.INPUT_SCHEMA, json.dumps(input_result.schema.to_primitive())) - span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA, json.dumps(output_result.schema.to_primitive())) - span.set_attribute(TdSpanAttributes.INPUT_SCHEMA_HASH, input_result.decoded_schema_hash) - span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA_HASH, output_result.decoded_schema_hash) - span.set_attribute(TdSpanAttributes.INPUT_VALUE_HASH, input_result.decoded_value_hash) - span.set_attribute(TdSpanAttributes.OUTPUT_VALUE_HASH, output_result.decoded_value_hash) + self._set_span_attributes(span, input_value, output_value) if not error: span.set_status(Status(OTelStatusCode.OK)) @@ -1052,6 +988,35 @@ def _query_to_string(self, query: Any, cursor: Any) -> str: return str(query) if not isinstance(query, str) else query + def _set_span_attributes( + self, + span: trace.Span, + input_value: dict, + output_value: dict, + ) -> None: + """Set span attributes for input/output values with schemas and hashes. + + This helper method centralizes the repeated pattern of: + 1. Generating schemas and hashes for input/output values + 2. Setting all span attributes (INPUT_VALUE, OUTPUT_VALUE, schemas, hashes) + + Args: + span: The OpenTelemetry span to set attributes on + input_value: The input data dictionary (query, parameters, etc.) + output_value: The output data dictionary (rows, rowcount, error, etc.) + """ + input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) + output_result = JsonSchemaHelper.generate_schema_and_hash(output_value, {}) + + span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(input_value)) + span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(output_value)) + span.set_attribute(TdSpanAttributes.INPUT_SCHEMA, json.dumps(input_result.schema.to_primitive())) + span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA, json.dumps(output_result.schema.to_primitive())) + span.set_attribute(TdSpanAttributes.INPUT_SCHEMA_HASH, input_result.decoded_schema_hash) + span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA_HASH, output_result.decoded_schema_hash) + span.set_attribute(TdSpanAttributes.INPUT_VALUE_HASH, input_result.decoded_value_hash) + span.set_attribute(TdSpanAttributes.OUTPUT_VALUE_HASH, output_result.decoded_value_hash) + def _detect_row_factory_type(self, row_factory: Any) -> str: """Detect the type of row factory for mock transformations. @@ -1691,19 +1656,7 @@ def _finalize_query_span( except Exception as e: logger.debug(f"Error getting query metadata: {e}") - # Generate schemas and hashes - input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) - output_result = JsonSchemaHelper.generate_schema_and_hash(output_value, {}) - - # Set span attributes - span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(input_value)) - span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(output_value)) - span.set_attribute(TdSpanAttributes.INPUT_SCHEMA, json.dumps(input_result.schema.to_primitive())) - span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA, json.dumps(output_result.schema.to_primitive())) - span.set_attribute(TdSpanAttributes.INPUT_SCHEMA_HASH, input_result.decoded_schema_hash) - span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA_HASH, output_result.decoded_schema_hash) - span.set_attribute(TdSpanAttributes.INPUT_VALUE_HASH, input_result.decoded_value_hash) - span.set_attribute(TdSpanAttributes.OUTPUT_VALUE_HASH, output_result.decoded_value_hash) + self._set_span_attributes(span, input_value, output_value) if not error: span.set_status(Status(OTelStatusCode.OK)) @@ -1787,19 +1740,7 @@ def do_lazy_capture(): if hasattr(cursor, 'statusmessage') and cursor.statusmessage is not None: output_value["statusmessage"] = cursor.statusmessage - # Generate schemas and hashes - input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) - output_result = JsonSchemaHelper.generate_schema_and_hash(output_value, {}) - - # Set span attributes - span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(input_value)) - span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(output_value)) - span.set_attribute(TdSpanAttributes.INPUT_SCHEMA, json.dumps(input_result.schema.to_primitive())) - span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA, json.dumps(output_result.schema.to_primitive())) - span.set_attribute(TdSpanAttributes.INPUT_SCHEMA_HASH, input_result.decoded_schema_hash) - span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA_HASH, output_result.decoded_schema_hash) - span.set_attribute(TdSpanAttributes.INPUT_VALUE_HASH, input_result.decoded_value_hash) - span.set_attribute(TdSpanAttributes.OUTPUT_VALUE_HASH, output_result.decoded_value_hash) + instrumentation._set_span_attributes(span, input_value, output_value) span.set_status(Status(OTelStatusCode.OK)) span.end() @@ -2084,19 +2025,7 @@ def patched_set_result_record(index: int): else: output_value = {"rowcount": cursor.rowcount if hasattr(cursor, "rowcount") else -1} - # Generate schemas and hashes - input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) - output_result = JsonSchemaHelper.generate_schema_and_hash(output_value, {}) - - # Set span attributes - span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(input_value)) - span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(output_value)) - span.set_attribute(TdSpanAttributes.INPUT_SCHEMA, json.dumps(input_result.schema.to_primitive())) - span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA, json.dumps(output_result.schema.to_primitive())) - span.set_attribute(TdSpanAttributes.INPUT_SCHEMA_HASH, input_result.decoded_schema_hash) - span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA_HASH, output_result.decoded_schema_hash) - span.set_attribute(TdSpanAttributes.INPUT_VALUE_HASH, input_result.decoded_value_hash) - span.set_attribute(TdSpanAttributes.OUTPUT_VALUE_HASH, output_result.decoded_value_hash) + self._set_span_attributes(span, input_value, output_value) if not error: span.set_status(Status(OTelStatusCode.OK)) From f04b88bb97b37b58c1d144d332db89d6684ed196 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 12:59:19 -0800 Subject: [PATCH 30/37] Fix Decimal and timedelta serialization for consistent RECORD/REPLAY hashing Added proper serialization for Decimal and timedelta types in serialize_value() to ensure consistent hashing between RECORD and REPLAY modes: - Decimal: Serialized as {"__decimal__": str(val)} to preserve precision - timedelta: Serialized as {"__timedelta__": total_seconds} for consistency Also added corresponding deserialization in deserialize_db_value() to reconstruct the original Python types when reading from traces. This fixes REPLAY mismatch errors ("No mock found") when queries use Decimal or timedelta parameters. --- .../psycopg/e2e-tests/src/app.py | 80 +++++++++++++++++++ .../psycopg/e2e-tests/src/test_requests.py | 5 ++ drift/instrumentation/utils/psycopg_utils.py | 7 ++ drift/instrumentation/utils/serialization.py | 9 ++- 4 files changed, 100 insertions(+), 1 deletion(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index 7efe02d..0d6630a 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -906,6 +906,86 @@ def test_cursor_set_result(): return jsonify({"error": str(e)}), 500 +@app.route("/test/decimal-types") +def test_decimal_types(): + """Test Decimal/numeric types. + + BUG INVESTIGATION: Decimal types may have serialization/precision issues. + """ + try: + from decimal import Decimal + + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table with numeric columns + cur.execute(""" + CREATE TEMP TABLE decimal_test ( + id INT, + price DECIMAL(10, 2), + rate DECIMAL(18, 8) + ) + """) + + # Insert decimal data + cur.execute( + "INSERT INTO decimal_test VALUES (%s, %s, %s)", + (1, Decimal("123.45"), Decimal("0.00000001")) + ) + + # Query back + cur.execute("SELECT * FROM decimal_test WHERE id = 1") + row = cur.fetchone() + conn.commit() + + return jsonify({ + "id": row[0], + "price": str(row[1]) if row[1] else None, + "rate": str(row[2]) if row[2] else None + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/date-time-types") +def test_date_time_types(): + """Test date/time types. + + BUG INVESTIGATION: Date, time, and interval types may have issues. + """ + try: + from datetime import date, time, timedelta + + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table with date/time columns + cur.execute(""" + CREATE TEMP TABLE datetime_test ( + id INT, + birth_date DATE, + wake_time TIME, + duration INTERVAL + ) + """) + + # Insert date/time data + cur.execute( + "INSERT INTO datetime_test VALUES (%s, %s, %s, %s)", + (1, date(1990, 5, 15), time(8, 30, 0), timedelta(hours=2, minutes=30)) + ) + + # Query back + cur.execute("SELECT * FROM datetime_test WHERE id = 1") + row = cur.fetchone() + conn.commit() + + return jsonify({ + "id": row[0], + "birth_date": str(row[1]) if row[1] else None, + "wake_time": str(row[2]) if row[2] else None, + "duration": str(row[3]) if row[3] else None + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + if __name__ == "__main__": sdk.mark_app_as_ready() app.run(host="0.0.0.0", port=8000, debug=False) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index 9436593..9aab4a4 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -82,4 +82,9 @@ # Bug-exposing tests - kept for regression testing make_request("GET", "/test/cursor-set-result") + # Bug-exposing tests for parameter serialization issues + # These tests expose hash mismatch bugs with Decimal and date/time types + make_request("GET", "/test/decimal-types") + make_request("GET", "/test/date-time-types") + print_request_summary() diff --git a/drift/instrumentation/utils/psycopg_utils.py b/drift/instrumentation/utils/psycopg_utils.py index 2aeec45..38c1dc2 100644 --- a/drift/instrumentation/utils/psycopg_utils.py +++ b/drift/instrumentation/utils/psycopg_utils.py @@ -5,6 +5,7 @@ import base64 import datetime as dt import uuid +from decimal import Decimal from typing import Any @@ -33,6 +34,12 @@ def deserialize_db_value(val: Any) -> Any: # Check for UUID tagged structure if "__uuid__" in val and len(val) == 1: return uuid.UUID(val["__uuid__"]) + # Check for Decimal tagged structure + if "__decimal__" in val and len(val) == 1: + return Decimal(val["__decimal__"]) + # Check for timedelta tagged structure + if "__timedelta__" in val and len(val) == 1: + return dt.timedelta(seconds=val["__timedelta__"]) # Recursively deserialize dict values return {k: deserialize_db_value(v) for k, v in val.items()} elif isinstance(val, str): diff --git a/drift/instrumentation/utils/serialization.py b/drift/instrumentation/utils/serialization.py index 038653b..d2d8bcb 100644 --- a/drift/instrumentation/utils/serialization.py +++ b/drift/instrumentation/utils/serialization.py @@ -5,6 +5,7 @@ import base64 import datetime import uuid +from decimal import Decimal from typing import Any @@ -32,7 +33,7 @@ def _serialize_bytes(val: bytes) -> Any: def serialize_value(val: Any) -> Any: """Convert non-JSON-serializable values to JSON-compatible types. - Handles datetime objects, bytes, and nested structures (lists, tuples, dicts). + Handles datetime objects, bytes, Decimal, and nested structures (lists, tuples, dicts). Args: val: The value to serialize. @@ -42,6 +43,12 @@ def serialize_value(val: Any) -> Any: """ if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): return val.isoformat() + elif isinstance(val, datetime.timedelta): + # Serialize timedelta as total seconds for consistent hashing + return {"__timedelta__": val.total_seconds()} + elif isinstance(val, Decimal): + # Serialize Decimal as string to preserve precision and ensure consistent hashing + return {"__decimal__": str(val)} elif isinstance(val, uuid.UUID): return {"__uuid__": str(val)} elif isinstance(val, memoryview): From 017074772b3603efd805aa2f59812f535a5c5b65 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 13:11:19 -0800 Subject: [PATCH 31/37] Refactor psycopg instrumentation: extract common helper methods Created helper methods to reduce code duplication: - _create_query_span(): Centralizes span creation (used 8 times) - _create_fetch_methods(): Creates fetch closures for mock cursors - _create_scroll_method(): Creates scroll method for cursor mocking - _get_row_factory_from_cursor(): Extracts row factory consistently - _set_cursor_description(): Sets cursor description with error handling - _create_row_transformer(): Creates row transform functions This reduces code duplication and improves maintainability while maintaining all existing functionality (39 E2E tests pass). --- .../psycopg/e2e-tests/src/app.py | 10 +- .../psycopg/instrumentation.py | 435 ++++++++---------- 2 files changed, 203 insertions(+), 242 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index 0d6630a..5636bb8 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -908,10 +908,7 @@ def test_cursor_set_result(): @app.route("/test/decimal-types") def test_decimal_types(): - """Test Decimal/numeric types. - - BUG INVESTIGATION: Decimal types may have serialization/precision issues. - """ + """Test Decimal/numeric types.""" try: from decimal import Decimal @@ -947,10 +944,7 @@ def test_decimal_types(): @app.route("/test/date-time-types") def test_date_time_types(): - """Test date/time types. - - BUG INVESTIGATION: Date, time, and interval types may have issues. - """ + """Test date/time types.""" try: from datetime import date, time, timedelta diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 252b8e7..9270f50 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -329,21 +329,7 @@ def _noop_execute(self, cursor: Any) -> Any: def _replay_execute(self, cursor: Any, sdk: TuskDrift, query_str: str, params: Any) -> Any: """Handle REPLAY mode for execute - fetch mock from CLI.""" - span_info = SpanUtils.create_span( - CreateSpanOptions( - name="psycopg.query", - kind=OTelSpanKind.CLIENT, - attributes={ - TdSpanAttributes.NAME: "psycopg.query", - TdSpanAttributes.PACKAGE_NAME: "psycopg", - TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", - TdSpanAttributes.SUBMODULE_NAME: "query", - TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, - TdSpanAttributes.IS_PRE_APP_START: not sdk.app_ready, - }, - is_pre_app_start=not sdk.app_ready, - ) - ) + span_info = self._create_query_span(sdk, "query") if not span_info: raise RuntimeError("Error creating span in replay mode") @@ -389,21 +375,7 @@ def _record_execute( cursor._tusk_index = 0 del cursor._tusk_patched - span_info = SpanUtils.create_span( - CreateSpanOptions( - name="psycopg.query", - kind=OTelSpanKind.CLIENT, - attributes={ - TdSpanAttributes.NAME: "psycopg.query", - TdSpanAttributes.PACKAGE_NAME: "psycopg", - TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", - TdSpanAttributes.SUBMODULE_NAME: "query", - TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, - TdSpanAttributes.IS_PRE_APP_START: is_pre_app_start, - }, - is_pre_app_start=is_pre_app_start, - ) - ) + span_info = self._create_query_span(sdk, "query", is_pre_app_start) if not span_info: # Fallback to original call if span creation fails @@ -481,21 +453,7 @@ def _replay_executemany( self, cursor: Any, sdk: TuskDrift, query_str: str, params_list: list, returning: bool = False ) -> Any: """Handle REPLAY mode for executemany - fetch mock from CLI.""" - span_info = SpanUtils.create_span( - CreateSpanOptions( - name="psycopg.query", - kind=OTelSpanKind.CLIENT, - attributes={ - TdSpanAttributes.NAME: "psycopg.query", - TdSpanAttributes.PACKAGE_NAME: "psycopg", - TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", - TdSpanAttributes.SUBMODULE_NAME: "query", - TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, - TdSpanAttributes.IS_PRE_APP_START: not sdk.app_ready, - }, - is_pre_app_start=not sdk.app_ready, - ) - ) + span_info = self._create_query_span(sdk, "query") if not span_info: raise RuntimeError("Error creating span in replay mode") @@ -544,21 +502,7 @@ def _record_executemany( returning: bool = False, ) -> Any: """Handle RECORD mode for executemany - create span and execute query.""" - span_info = SpanUtils.create_span( - CreateSpanOptions( - name="psycopg.query", - kind=OTelSpanKind.CLIENT, - attributes={ - TdSpanAttributes.NAME: "psycopg.query", - TdSpanAttributes.PACKAGE_NAME: "psycopg", - TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", - TdSpanAttributes.SUBMODULE_NAME: "query", - TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, - TdSpanAttributes.IS_PRE_APP_START: is_pre_app_start, - }, - is_pre_app_start=is_pre_app_start, - ) - ) + span_info = self._create_query_span(sdk, "query", is_pre_app_start) if not span_info: # Fallback to original call if span creation fails @@ -636,21 +580,7 @@ def _record_stream( kwargs: dict, ): """Handle RECORD mode for stream - wrap generator with tracing.""" - span_info = SpanUtils.create_span( - CreateSpanOptions( - name="psycopg.query", - kind=OTelSpanKind.CLIENT, - attributes={ - TdSpanAttributes.NAME: "psycopg.query", - TdSpanAttributes.PACKAGE_NAME: "psycopg", - TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", - TdSpanAttributes.SUBMODULE_NAME: "query", - TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, - TdSpanAttributes.IS_PRE_APP_START: is_pre_app_start, - }, - is_pre_app_start=is_pre_app_start, - ) - ) + span_info = self._create_query_span(sdk, "query", is_pre_app_start) if not span_info: yield from original_stream(query, params, **kwargs) @@ -673,21 +603,7 @@ def _record_stream( def _replay_stream(self, cursor: Any, sdk: TuskDrift, query_str: str, params: Any): """Handle REPLAY mode for stream - return mock generator.""" - span_info = SpanUtils.create_span( - CreateSpanOptions( - name="psycopg.query", - kind=OTelSpanKind.CLIENT, - attributes={ - TdSpanAttributes.NAME: "psycopg.query", - TdSpanAttributes.PACKAGE_NAME: "psycopg", - TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", - TdSpanAttributes.SUBMODULE_NAME: "query", - TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, - TdSpanAttributes.IS_PRE_APP_START: not sdk.app_ready, - }, - is_pre_app_start=not sdk.app_ready, - ) - ) + span_info = self._create_query_span(sdk, "query") if not span_info: raise RuntimeError("Error creating span in replay mode") @@ -797,23 +713,7 @@ def _record_copy( kwargs: dict, ) -> Iterator[TracedCopyWrapper]: """Handle RECORD mode for copy - wrap Copy object with tracing.""" - is_pre_app_start = not sdk.app_ready - - span_info = SpanUtils.create_span( - CreateSpanOptions( - name="psycopg.copy", - kind=OTelSpanKind.CLIENT, - attributes={ - TdSpanAttributes.NAME: "psycopg.copy", - TdSpanAttributes.PACKAGE_NAME: "psycopg", - TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", - TdSpanAttributes.SUBMODULE_NAME: "copy", - TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, - TdSpanAttributes.IS_PRE_APP_START: is_pre_app_start, - }, - is_pre_app_start=is_pre_app_start, - ) - ) + span_info = self._create_query_span(sdk, "copy") if not span_info: # Fallback to original if span creation fails @@ -845,21 +745,7 @@ def _record_copy( @contextmanager def _replay_copy(self, cursor: Any, sdk: TuskDrift, query_str: str) -> Iterator[MockCopy]: """Handle REPLAY mode for copy - return mock Copy object.""" - span_info = SpanUtils.create_span( - CreateSpanOptions( - name="psycopg.copy", - kind=OTelSpanKind.CLIENT, - attributes={ - TdSpanAttributes.NAME: "psycopg.copy", - TdSpanAttributes.PACKAGE_NAME: "psycopg", - TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", - TdSpanAttributes.SUBMODULE_NAME: "copy", - TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, - TdSpanAttributes.IS_PRE_APP_START: not sdk.app_ready, - }, - is_pre_app_start=not sdk.app_ready, - ) - ) + span_info = self._create_query_span(sdk, "copy") if not span_info: raise RuntimeError("Error creating span in replay mode") @@ -988,6 +874,180 @@ def _query_to_string(self, query: Any, cursor: Any) -> str: return str(query) if not isinstance(query, str) else query + def _create_query_span(self, sdk: TuskDrift, submodule: str = "query", is_pre_app_start: bool | None = None): + """Create a span for psycopg operations. + + This helper reduces code duplication across replay/record methods. + + Args: + sdk: The TuskDrift instance + submodule: The submodule name ("query" or "copy") + is_pre_app_start: Override for pre-app-start flag; if None, derived from sdk.app_ready + + Returns: + SpanInfo object or None if span creation fails + """ + if is_pre_app_start is None: + is_pre_app_start = not sdk.app_ready + span_name = f"psycopg.{submodule}" + return SpanUtils.create_span( + CreateSpanOptions( + name=span_name, + kind=OTelSpanKind.CLIENT, + attributes={ + TdSpanAttributes.NAME: span_name, + TdSpanAttributes.PACKAGE_NAME: "psycopg", + TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", + TdSpanAttributes.SUBMODULE_NAME: submodule, + TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, + TdSpanAttributes.IS_PRE_APP_START: is_pre_app_start, + }, + is_pre_app_start=is_pre_app_start, + ) + ) + + def _create_fetch_methods(self, cursor: Any, rows_attr: str, index_attr: str, transform_row=None): + """Create fetch method closures for cursor mocking. + + This helper reduces code duplication in mock/replay cursor setup. + + Args: + cursor: The cursor object to operate on + rows_attr: Attribute name for stored rows (e.g., '_mock_rows', '_tusk_rows') + index_attr: Attribute name for current index (e.g., '_mock_index', '_tusk_index') + transform_row: Optional function to transform each row before returning + + Returns: + Tuple of (fetchone, fetchmany, fetchall) functions + """ + def fetchone(): + rows = getattr(cursor, rows_attr) + idx = getattr(cursor, index_attr) + if idx < len(rows): + row = rows[idx] + setattr(cursor, index_attr, idx + 1) + return transform_row(row) if transform_row else row + return None + + def fetchmany(size=None): + if size is None: + size = cursor.arraysize + result = [] + for _ in range(size): + row = fetchone() + if row is None: + break + result.append(row) + return result + + def fetchall(): + rows = getattr(cursor, rows_attr) + idx = getattr(cursor, index_attr) + remaining = rows[idx:] + setattr(cursor, index_attr, len(rows)) + if transform_row: + return [transform_row(row) for row in remaining] + return list(remaining) + + return fetchone, fetchmany, fetchall + + def _create_scroll_method(self, cursor: Any, rows_attr: str, index_attr: str): + """Create scroll method closure for cursor mocking. + + Args: + cursor: The cursor object to operate on + rows_attr: Attribute name for stored rows + index_attr: Attribute name for current index + + Returns: + scroll function + """ + def scroll(value: int, mode: str = "relative") -> None: + rows = getattr(cursor, rows_attr) + idx = getattr(cursor, index_attr) + if mode == "relative": + newpos = idx + value + elif mode == "absolute": + newpos = value + else: + raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") + + num_rows = len(rows) + if num_rows > 0: + if not (0 <= newpos < num_rows): + raise IndexError("position out of bound") + elif newpos != 0: + raise IndexError("position out of bound") + + setattr(cursor, index_attr, newpos) + + return scroll + + def _get_row_factory_from_cursor(self, cursor: Any): + """Get row_factory from cursor or its connection. + + Args: + cursor: The cursor object + + Returns: + The row_factory or None if not found + """ + row_factory = getattr(cursor, 'row_factory', None) + if row_factory is None: + conn = getattr(cursor, 'connection', None) + if conn: + row_factory = getattr(conn, 'row_factory', None) + return row_factory + + def _set_cursor_description(self, cursor: Any, description_data: list | None) -> None: + """Set description on cursor from description data. + + Args: + cursor: The cursor object + description_data: List of column description dicts with 'name' and 'type_code' keys + """ + if not description_data: + return + + desc = [(col["name"], col.get("type_code"), None, None, None, None, None) for col in description_data] + try: + cursor._tusk_description = desc + except AttributeError: + try: + cursor.description = desc + except AttributeError: + pass + + def _create_row_transformer(self, row_factory_type: str, column_names: list | None): + """Create a row transformation function based on row factory type. + + Args: + row_factory_type: The detected row factory type ('dict', 'namedtuple', etc.) + column_names: List of column names for the result set + + Returns: + A function that transforms a raw row into the appropriate format + """ + RowClass = None + if row_factory_type in ("namedtuple", "class") and column_names: + from collections import namedtuple + RowClass = namedtuple('Row', column_names) + + def transform_row(row): + """Transform raw row data according to row factory type.""" + if row_factory_type == "kwargs": + return row + if row_factory_type == "scalar": + return row[0] if isinstance(row, list) and len(row) > 0 else row + values = tuple(row) if isinstance(row, list) else row + if row_factory_type == "dict" and column_names: + return dict(zip(column_names, values)) + elif row_factory_type in ("namedtuple", "class") and RowClass is not None: + return RowClass(*values) + return values + + return transform_row + def _set_span_attributes( self, span: trace.Span, @@ -1174,63 +1234,22 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non object.__setattr__(cursor, "rowcount", actual_data.get("rowcount", -1)) description_data = actual_data.get("description") - if description_data: - desc = [(col["name"], col.get("type_code"), None, None, None, None, None) for col in description_data] - # Set mock description - InstrumentedCursor has _tusk_description property - # MockCursor uses regular description attribute - try: - cursor._tusk_description = desc - except AttributeError: - # For MockCursor, set description directly - try: - cursor.description = desc - except AttributeError: - pass + self._set_cursor_description(cursor, description_data) # Set mock statusmessage for replay statusmessage = actual_data.get("statusmessage") if statusmessage is not None: cursor._mock_statusmessage = statusmessage - # Get row_factory from cursor or connection for row transformation - row_factory = getattr(cursor, 'row_factory', None) - if row_factory is None: - conn = getattr(cursor, 'connection', None) - if conn: - row_factory = getattr(conn, 'row_factory', None) - - # Extract column names from description for row factory transformations - column_names = None - if description_data: - column_names = [col["name"] for col in description_data] - - # Detect row factory type for transformation + # Get row_factory and detect type for row transformation + row_factory = self._get_row_factory_from_cursor(cursor) row_factory_type = self._detect_row_factory_type(row_factory) - # Create namedtuple class once if needed (avoid recreating for each row) - # Used for both namedtuple_row and class_row (class_row returns dataclass instances, - # but in replay we can't recreate the exact class, so we use namedtuple as a compatible substitute) - RowClass = None - if row_factory_type in ("namedtuple", "class") and column_names: - from collections import namedtuple - RowClass = namedtuple('Row', column_names) + # Extract column names from description for row factory transformations + column_names = [col["name"] for col in description_data] if description_data else None - def transform_row(row): - """Transform raw row data according to row factory type.""" - if row_factory_type == "kwargs": - # kwargs_row: return stored dict as-is (already in correct format) - return row - if row_factory_type == "scalar": - # scalar_row: unwrap the single-element list to get the scalar value - return row[0] if isinstance(row, list) and len(row) > 0 else row - values = tuple(row) if isinstance(row, list) else row - if row_factory_type == "dict" and column_names: - return dict(zip(column_names, values)) - elif row_factory_type in ("namedtuple", "class") and RowClass is not None: - # For class_row, we use namedtuple as a compatible substitute that supports - # attribute access (row.id, row.name, etc.) - return RowClass(*values) - return values + # Create row transformer using helper + transform_row = self._create_row_transformer(row_factory_type, column_names) mock_rows = actual_data.get("rows", []) # Deserialize datetime strings back to datetime objects for consistent Flask serialization @@ -1238,50 +1257,15 @@ def transform_row(row): cursor._mock_rows = mock_rows # pyright: ignore[reportAttributeAccessIssue] cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] - def mock_fetchone(): - if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue] - row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue] - cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue] - return transform_row(row) - return None - - def mock_fetchmany(size=cursor.arraysize): - rows = [] - for _ in range(size): - row = mock_fetchone() - if row is None: - break - rows.append(row) - return rows - - def mock_fetchall(): - rows = cursor._mock_rows[cursor._mock_index :] # pyright: ignore[reportAttributeAccessIssue] - cursor._mock_index = len(cursor._mock_rows) # pyright: ignore[reportAttributeAccessIssue] - return [transform_row(row) for row in rows] - - cursor.fetchone = mock_fetchone # pyright: ignore[reportAttributeAccessIssue] - cursor.fetchmany = mock_fetchmany # pyright: ignore[reportAttributeAccessIssue] - cursor.fetchall = mock_fetchall # pyright: ignore[reportAttributeAccessIssue] - - def mock_scroll(value: int, mode: str = "relative") -> None: - """Scroll the cursor to a new position in the mock result set.""" - if mode == "relative": - newpos = cursor._mock_index + value # pyright: ignore[reportAttributeAccessIssue] - elif mode == "absolute": - newpos = value - else: - raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") - - num_rows = len(cursor._mock_rows) # pyright: ignore[reportAttributeAccessIssue] - if num_rows > 0: - if not (0 <= newpos < num_rows): - raise IndexError("position out of bound") - elif newpos != 0: - raise IndexError("position out of bound") - - cursor._mock_index = newpos # pyright: ignore[reportAttributeAccessIssue] + # Use helper methods to create fetch and scroll methods + fetchone, fetchmany, fetchall = self._create_fetch_methods( + cursor, '_mock_rows', '_mock_index', transform_row + ) + cursor.fetchone = fetchone # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = fetchmany # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = fetchall # pyright: ignore[reportAttributeAccessIssue] - cursor.scroll = mock_scroll # pyright: ignore[reportAttributeAccessIssue] + cursor.scroll = self._create_scroll_method(cursor, '_mock_rows', '_mock_index') # pyright: ignore[reportAttributeAccessIssue] # Note: __iter__ and __next__ are handled at the class level in InstrumentedCursor # and MockCursor classes, as Python looks up special methods on the type, not instance @@ -1301,13 +1285,8 @@ def _mock_executemany_returning_with_data(self, cursor: Any, mock_data: dict[str cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] return - # Get row_factory from cursor or connection - row_factory = getattr(cursor, "row_factory", None) - if row_factory is None: - conn = getattr(cursor, "connection", None) - if conn: - row_factory = getattr(conn, "row_factory", None) - + # Get row_factory and detect type using helpers + row_factory = self._get_row_factory_from_cursor(cursor) row_factory_type = self._detect_row_factory_type(row_factory) # Store all result sets for iteration @@ -1435,20 +1414,8 @@ def fetchall(): cursor._mock_rows = first_set["rows"] # pyright: ignore[reportAttributeAccessIssue] cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] - # Set description for first result set - description_data = first_set.get("description") - if description_data: - desc = [ - (col["name"], col.get("type_code"), None, None, None, None, None) - for col in description_data - ] - try: - cursor._tusk_description = desc # pyright: ignore[reportAttributeAccessIssue] - except AttributeError: - try: - cursor.description = desc # pyright: ignore[reportAttributeAccessIssue] - except AttributeError: - pass + # Set description for first result set using helper + self._set_cursor_description(cursor, first_set.get("description")) # Set up initial fetch methods for the first result set (for code that uses nextset() instead of results()) first_column_names = first_set.get("column_names") From 66453ad5a8f480f7d351048f92f928d130d40e8b Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 13:40:57 -0800 Subject: [PATCH 32/37] Fix inet/cidr network type serialization for REPLAY mode Added support for Python ipaddress module types in serialize_value(): - IPv4Address, IPv6Address - IPv4Interface, IPv6Interface - IPv4Network, IPv6Network These types are returned by psycopg when querying PostgreSQL inet and cidr columns. Without proper serialization, REPLAY mode failed to match recorded mocks. --- .../psycopg/e2e-tests/src/app.py | 82 +++++++++++++++++++ .../psycopg/e2e-tests/src/test_requests.py | 5 ++ drift/instrumentation/utils/serialization.py | 15 ++++ 3 files changed, 102 insertions(+) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index 5636bb8..ca98bf6 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -980,6 +980,88 @@ def test_date_time_types(): return jsonify({"error": str(e)}), 500 +@app.route("/test/inet-cidr-types") +def test_inet_cidr_types(): + """Test PostgreSQL inet/cidr network types. + + BUG INVESTIGATION: Network types may have serialization issues. + """ + try: + from ipaddress import IPv4Address, IPv4Network + + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table with network columns + cur.execute(""" + CREATE TEMP TABLE network_test ( + id INT, + ip_addr INET, + network CIDR + ) + """) + + # Insert network data + cur.execute( + "INSERT INTO network_test VALUES (%s, %s, %s)", + (1, "192.168.1.100", "10.0.0.0/8") + ) + + # Query back + cur.execute("SELECT * FROM network_test WHERE id = 1") + row = cur.fetchone() + conn.commit() + + return jsonify({ + "id": row[0], + "ip_addr": str(row[1]) if row[1] else None, + "network": str(row[2]) if row[2] else None + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/range-types") +def test_range_types(): + """Test PostgreSQL range types. + + BUG INVESTIGATION: Range types may have serialization issues. + """ + try: + from psycopg.types.range import Range + + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table with range columns + cur.execute(""" + CREATE TEMP TABLE range_test ( + id INT, + int_range INT4RANGE, + ts_range TSRANGE + ) + """) + + # Insert range data + from datetime import datetime + int_range = Range(1, 10) + ts_range = Range(datetime(2024, 1, 1, 0, 0), datetime(2024, 12, 31, 23, 59)) + + cur.execute( + "INSERT INTO range_test VALUES (%s, %s, %s)", + (1, int_range, ts_range) + ) + + # Query back + cur.execute("SELECT * FROM range_test WHERE id = 1") + row = cur.fetchone() + conn.commit() + + return jsonify({ + "id": row[0], + "int_range": str(row[1]) if row[1] else None, + "ts_range": str(row[2]) if row[2] else None + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + if __name__ == "__main__": sdk.mark_app_as_ready() app.run(host="0.0.0.0", port=8000, debug=False) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index 9aab4a4..3ad5392 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -87,4 +87,9 @@ make_request("GET", "/test/decimal-types") make_request("GET", "/test/date-time-types") + # Bug-exposing tests for network and range type serialization issues + # These tests expose serialization bugs with inet/cidr and range types + make_request("GET", "/test/inet-cidr-types") + make_request("GET", "/test/range-types") + print_request_summary() diff --git a/drift/instrumentation/utils/serialization.py b/drift/instrumentation/utils/serialization.py index d2d8bcb..e492f80 100644 --- a/drift/instrumentation/utils/serialization.py +++ b/drift/instrumentation/utils/serialization.py @@ -4,6 +4,7 @@ import base64 import datetime +import ipaddress import uuid from decimal import Decimal from typing import Any @@ -51,6 +52,20 @@ def serialize_value(val: Any) -> Any: return {"__decimal__": str(val)} elif isinstance(val, uuid.UUID): return {"__uuid__": str(val)} + elif isinstance( + val, + ( + ipaddress.IPv4Address, + ipaddress.IPv6Address, + ipaddress.IPv4Interface, + ipaddress.IPv6Interface, + ipaddress.IPv4Network, + ipaddress.IPv6Network, + ), + ): + # Serialize ipaddress types to string for inet/cidr PostgreSQL columns + # These are returned by psycopg when querying inet and cidr columns + return str(val) elif isinstance(val, memoryview): # Convert memoryview to bytes first, then serialize return _serialize_bytes(bytes(val)) From 34186738c5db9d94402495b93b90335ecb32f88c Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 14:02:35 -0800 Subject: [PATCH 33/37] Fix psycopg Range type serialization for REPLAY mode Added serialization and deserialization support for psycopg Range objects (INT4RANGE, TSRANGE, etc.): Serialization (in serialize_value): - Serializes to {"__range__": {"lower": ..., "upper": ..., "bounds": ...}} - Handles empty ranges with {"__range__": {"empty": True}} - Recursively serializes bounds for datetime/nested types Deserialization (in deserialize_db_value): - Reconstructs Range objects from tagged dict structure - Recursively deserializes bounds - Converts JSON floats back to ints when appropriate This fixes REPLAY mismatch when Range types are used as query parameters or returned in results. --- .../psycopg/e2e-tests/src/app.py | 5 +--- drift/instrumentation/utils/psycopg_utils.py | 30 +++++++++++++++++++ drift/instrumentation/utils/serialization.py | 21 +++++++++++++ 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index ca98bf6..463b039 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -982,10 +982,7 @@ def test_date_time_types(): @app.route("/test/inet-cidr-types") def test_inet_cidr_types(): - """Test PostgreSQL inet/cidr network types. - - BUG INVESTIGATION: Network types may have serialization issues. - """ + """Test PostgreSQL inet/cidr network types.""" try: from ipaddress import IPv4Address, IPv4Network diff --git a/drift/instrumentation/utils/psycopg_utils.py b/drift/instrumentation/utils/psycopg_utils.py index 38c1dc2..ea5e3ea 100644 --- a/drift/instrumentation/utils/psycopg_utils.py +++ b/drift/instrumentation/utils/psycopg_utils.py @@ -8,6 +8,15 @@ from decimal import Decimal from typing import Any +# Try to import psycopg Range type for deserialization support +try: + from psycopg.types.range import Range as PsycopgRange + + HAS_PSYCOPG_RANGE = True +except ImportError: + HAS_PSYCOPG_RANGE = False + PsycopgRange = None # type: ignore[misc, assignment] + def deserialize_db_value(val: Any) -> Any: """Convert serialized values back to their original Python types. @@ -40,6 +49,27 @@ def deserialize_db_value(val: Any) -> Any: # Check for timedelta tagged structure if "__timedelta__" in val and len(val) == 1: return dt.timedelta(seconds=val["__timedelta__"]) + # Check for Range tagged structure (psycopg Range types) + if "__range__" in val and len(val) == 1: + range_data = val["__range__"] + if HAS_PSYCOPG_RANGE and PsycopgRange is not None: + if range_data.get("empty"): + return PsycopgRange(empty=True) + # Recursively deserialize the lower and upper bounds + # (they may contain datetime or other serialized types) + lower = deserialize_db_value(range_data.get("lower")) + upper = deserialize_db_value(range_data.get("upper")) + bounds = range_data.get("bounds", "[)") + # Convert floats back to ints if they represent whole numbers + # This is needed because JSON doesn't distinguish int/float + if isinstance(lower, float) and lower.is_integer(): + lower = int(lower) + if isinstance(upper, float) and upper.is_integer(): + upper = int(upper) + return PsycopgRange(lower, upper, bounds) + else: + # If psycopg is not available, return the dict as-is + return range_data # Recursively deserialize dict values return {k: deserialize_db_value(v) for k, v in val.items()} elif isinstance(val, str): diff --git a/drift/instrumentation/utils/serialization.py b/drift/instrumentation/utils/serialization.py index e492f80..9ea331f 100644 --- a/drift/instrumentation/utils/serialization.py +++ b/drift/instrumentation/utils/serialization.py @@ -9,6 +9,15 @@ from decimal import Decimal from typing import Any +# Try to import psycopg Range type for serialization support +try: + from psycopg.types.range import Range as PsycopgRange + + HAS_PSYCOPG_RANGE = True +except ImportError: + HAS_PSYCOPG_RANGE = False + PsycopgRange = None # type: ignore[misc, assignment] + def _serialize_bytes(val: bytes) -> Any: """Serialize bytes to a JSON-compatible format. @@ -52,6 +61,18 @@ def serialize_value(val: Any) -> Any: return {"__decimal__": str(val)} elif isinstance(val, uuid.UUID): return {"__uuid__": str(val)} + elif HAS_PSYCOPG_RANGE and isinstance(val, PsycopgRange): + # Serialize psycopg Range objects to a deterministic dict format + # This handles INT4RANGE, TSRANGE, and other PostgreSQL range types + if val.isempty: + return {"__range__": {"empty": True}} + return { + "__range__": { + "lower": serialize_value(val.lower), + "upper": serialize_value(val.upper), + "bounds": val.bounds, + } + } elif isinstance( val, ( From 08f5ce75f725a4d923c4ef2b397c5becca7a2558 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 15:21:59 -0800 Subject: [PATCH 34/37] fix format + lint issues --- .../psycopg/e2e-tests/requirements.txt | 1 + .../psycopg/e2e-tests/src/app.py | 335 ++++++++---------- .../psycopg/e2e-tests/src/test_requests.py | 4 - .../psycopg/instrumentation.py | 179 +++++----- drift/instrumentation/psycopg/mocks.py | 5 +- drift/instrumentation/psycopg/wrappers.py | 3 +- .../psycopg2/instrumentation.py | 4 +- 7 files changed, 240 insertions(+), 291 deletions(-) diff --git a/drift/instrumentation/psycopg/e2e-tests/requirements.txt b/drift/instrumentation/psycopg/e2e-tests/requirements.txt index 5495053..67be944 100644 --- a/drift/instrumentation/psycopg/e2e-tests/requirements.txt +++ b/drift/instrumentation/psycopg/e2e-tests/requirements.txt @@ -1,4 +1,5 @@ -e /sdk # Mount point for drift SDK Flask>=3.1.2 psycopg[binary]>=3.2.1 +psycopg-pool>=3.2.0 requests>=2.32.5 diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index 463b039..3be2bcf 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -1,6 +1,7 @@ """Flask app with Psycopg (v3) operations for e2e testing.""" import os +from datetime import UTC import psycopg from flask import Flask, jsonify, request @@ -143,6 +144,7 @@ def db_transaction(): except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/cursor-stream") def test_cursor_stream(): """Test cursor.stream() - generator-based result streaming. @@ -160,6 +162,7 @@ def test_cursor_stream(): except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/server-cursor") def test_server_cursor(): """Test ServerCursor (named cursor) - server-side cursor. @@ -179,6 +182,7 @@ def test_server_cursor(): except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/copy-to") def test_copy_to(): """Test cursor.copy() with COPY TO - bulk data export. @@ -194,11 +198,12 @@ def test_copy_to(): # Handle both bytes and memoryview if isinstance(row, memoryview): row = bytes(row) - output.append(row.decode('utf-8').strip()) + output.append(row.decode("utf-8").strip()) return jsonify({"count": len(output), "data": output}) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/multiple-queries") def test_multiple_queries(): """Test multiple queries in same connection. @@ -224,6 +229,7 @@ def test_multiple_queries(): except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/pipeline-mode") def test_pipeline_mode(): """Test pipeline mode - batched operations. @@ -243,13 +249,11 @@ def test_pipeline_mode(): rows1 = cur1.fetchall() count = cur2.fetchone()[0] - return jsonify({ - "rows": [{"id": r[0], "name": r[1]} for r in rows1], - "count": count - }) + return jsonify({"rows": [{"id": r[0], "name": r[1]} for r in rows1], "count": count}) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/dict-row-factory") def test_dict_row_factory(): """Test dict_row row factory. @@ -265,10 +269,12 @@ def test_dict_row_factory(): cur.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 3") rows = cur.fetchall() - return jsonify({ - "count": len(rows), - "data": rows # Already dictionaries - }) + return jsonify( + { + "count": len(rows), + "data": rows, # Already dictionaries + } + ) except Exception as e: return jsonify({"error": str(e)}), 500 @@ -288,13 +294,11 @@ def test_namedtuple_row_factory(): rows = cur.fetchall() # Convert named tuples to dicts for JSON serialization - return jsonify({ - "count": len(rows), - "data": [{"id": r.id, "name": r.name, "email": r.email} for r in rows] - }) + return jsonify({"count": len(rows), "data": [{"id": r.id, "name": r.name, "email": r.email} for r in rows]}) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/cursor-iteration") def test_cursor_iteration(): """Test direct cursor iteration (for row in cursor). @@ -310,13 +314,11 @@ def test_cursor_iteration(): for row in cur: results.append({"id": row[0], "name": row[1]}) - return jsonify({ - "count": len(results), - "data": results - }) + return jsonify({"count": len(results), "data": results}) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/executemany-returning") def test_executemany_returning(): """Test executemany with returning=True. @@ -330,11 +332,7 @@ def test_executemany_returning(): # Use executemany with returning params = [("Batch User 1",), ("Batch User 2",), ("Batch User 3",)] - cur.executemany( - "INSERT INTO batch_test (name) VALUES (%s) RETURNING id, name", - params, - returning=True - ) + cur.executemany("INSERT INTO batch_test (name) VALUES (%s) RETURNING id, name", params, returning=True) # Fetch results from each batch results = [] @@ -345,13 +343,11 @@ def test_executemany_returning(): conn.commit() - return jsonify({ - "count": len(results), - "data": results - }) + return jsonify({"count": len(results), "data": results}) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/rownumber") def test_rownumber(): """Test cursor.rownumber property. @@ -375,21 +371,14 @@ def test_rownumber(): cur.fetchmany(2) positions.append({"after_fetchmany_2": cur.rownumber}) - return jsonify({ - "positions": positions - }) + return jsonify({"positions": positions}) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/statusmessage") def test_statusmessage(): - """Test cursor.statusmessage property. - - BUG: The statusmessage property is not captured during RECORD mode - and not mocked during REPLAY mode. During RECORD, statusmessage - returns the command status (e.g., "SELECT 5", "INSERT 0 1"), but - during REPLAY it returns null because this property is not tracked. - """ + """Test cursor.statusmessage property.""" try: with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: # SELECT should return something like "SELECT 5" @@ -399,21 +388,18 @@ def test_statusmessage(): # INSERT should return something like "INSERT 0 1" cur.execute( - "INSERT INTO users (name, email) VALUES (%s, %s) RETURNING id", - ("StatusTest", "status@test.com") + "INSERT INTO users (name, email) VALUES (%s, %s) RETURNING id", ("StatusTest", "status@test.com") ) insert_status = cur.statusmessage cur.fetchone() conn.rollback() # Don't actually insert - return jsonify({ - "select_status": select_status, - "insert_status": insert_status - }) + return jsonify({"select_status": select_status, "insert_status": insert_status}) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/nextset") def test_nextset(): """Test cursor.nextset() for multiple result sets. @@ -429,7 +415,7 @@ def test_nextset(): cur.executemany( "INSERT INTO nextset_test (val) VALUES (%s) RETURNING id, val", [("First",), ("Second",), ("Third",)], - returning=True + returning=True, ) # Use nextset to iterate through result sets @@ -443,13 +429,11 @@ def test_nextset(): conn.commit() - return jsonify({ - "count": len(results), - "data": results - }) + return jsonify({"count": len(results), "data": results}) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/cursor-scroll") def test_cursor_scroll(): """Test cursor.scroll() method. @@ -464,19 +448,22 @@ def test_cursor_scroll(): first = cur.fetchone() # Scroll back to start - cur.scroll(0, mode='absolute') + cur.scroll(0, mode="absolute") # Fetch first row again first_again = cur.fetchone() - return jsonify({ - "first": {"id": first[0], "name": first[1]} if first else None, - "first_again": {"id": first_again[0], "name": first_again[1]} if first_again else None, - "match": first == first_again - }) + return jsonify( + { + "first": {"id": first[0], "name": first[1]} if first else None, + "first_again": {"id": first_again[0], "name": first_again[1]} if first_again else None, + "match": first == first_again, + } + ) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/server-cursor-scroll") def test_server_cursor_scroll(): """Test ServerCursor.scroll() method. @@ -493,19 +480,22 @@ def test_server_cursor_scroll(): first = cur.fetchone() # Scroll back to start - cur.scroll(0, mode='absolute') + cur.scroll(0, mode="absolute") # Fetch first row again first_again = cur.fetchone() - return jsonify({ - "first": {"id": first[0], "name": first[1]} if first else None, - "first_again": {"id": first_again[0], "name": first_again[1]} if first_again else None, - "match": first == first_again - }) + return jsonify( + { + "first": {"id": first[0], "name": first[1]} if first else None, + "first_again": {"id": first_again[0], "name": first_again[1]} if first_again else None, + "match": first == first_again, + } + ) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/cursor-reuse") def test_cursor_reuse(): """Test reusing cursor for multiple queries. @@ -526,14 +516,17 @@ def test_cursor_reuse(): cur.execute("SELECT COUNT(*) FROM users") count = cur.fetchone()[0] - return jsonify({ - "row1": {"id": row1[0], "name": row1[1]} if row1 else None, - "row2": {"id": row2[0], "name": row2[1]} if row2 else None, - "count": count - }) + return jsonify( + { + "row1": {"id": row1[0], "name": row1[1]} if row1 else None, + "row2": {"id": row2[0], "name": row2[1]} if row2 else None, + "count": count, + } + ) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/sql-composed") def test_sql_composed(): """Test psycopg.sql.SQL() composed queries.""" @@ -542,23 +535,17 @@ def test_sql_composed(): with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: table = sql.Identifier("users") - columns = sql.SQL(", ").join([ - sql.Identifier("id"), - sql.Identifier("name"), - sql.Identifier("email") - ]) + columns = sql.SQL(", ").join([sql.Identifier("id"), sql.Identifier("name"), sql.Identifier("email")]) query = sql.SQL("SELECT {} FROM {} ORDER BY id LIMIT 3").format(columns, table) cur.execute(query) rows = cur.fetchall() - return jsonify({ - "count": len(rows), - "data": [{"id": r[0], "name": r[1], "email": r[2]} for r in rows] - }) + return jsonify({"count": len(rows), "data": [{"id": r[0], "name": r[1], "email": r[2]} for r in rows]}) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/binary-uuid") def test_binary_uuid(): """Test binary UUID data type. @@ -574,10 +561,7 @@ def test_binary_uuid(): # Insert a UUID test_uuid = uuid.uuid4() - cur.execute( - "INSERT INTO uuid_test (id, name) VALUES (%s, %s) RETURNING id, name", - (test_uuid, "UUID Test") - ) + cur.execute("INSERT INTO uuid_test (id, name) VALUES (%s, %s) RETURNING id, name", (test_uuid, "UUID Test")) inserted = cur.fetchone() # Query it back @@ -586,14 +570,17 @@ def test_binary_uuid(): conn.commit() - return jsonify({ - "inserted_uuid": str(inserted[0]) if inserted else None, - "queried_uuid": str(queried[0]) if queried else None, - "match": str(inserted[0]) == str(queried[0]) if inserted and queried else False - }) + return jsonify( + { + "inserted_uuid": str(inserted[0]) if inserted else None, + "queried_uuid": str(queried[0]) if queried else None, + "match": str(inserted[0]) == str(queried[0]) if inserted and queried else False, + } + ) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/binary-bytea") def test_binary_bytea(): """Test binary bytea data type. @@ -606,24 +593,24 @@ def test_binary_bytea(): cur.execute("CREATE TEMP TABLE bytea_test (id SERIAL, data BYTEA)") # Insert binary data - test_data = b'\x00\x01\x02\x03\xff\xfe\xfd' - cur.execute( - "INSERT INTO bytea_test (data) VALUES (%s) RETURNING id, data", - (test_data,) - ) + test_data = b"\x00\x01\x02\x03\xff\xfe\xfd" + cur.execute("INSERT INTO bytea_test (data) VALUES (%s) RETURNING id, data", (test_data,)) inserted = cur.fetchone() conn.commit() # Convert bytes to hex for JSON serialization - return jsonify({ - "inserted_id": inserted[0] if inserted else None, - "data_hex": inserted[1].hex() if inserted and inserted[1] else None, - "data_length": len(inserted[1]) if inserted and inserted[1] else 0 - }) + return jsonify( + { + "inserted_id": inserted[0] if inserted else None, + "data_hex": inserted[1].hex() if inserted and inserted[1] else None, + "data_length": len(inserted[1]) if inserted and inserted[1] else 0, + } + ) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/class-row-factory") def test_class_row_factory(): """Test class_row row factory. @@ -632,9 +619,10 @@ def test_class_row_factory(): which return instances of a custom class. """ try: - from psycopg.rows import class_row from dataclasses import dataclass + from psycopg.rows import class_row + @dataclass class User: id: int @@ -646,10 +634,7 @@ class User: cur.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 3") rows = cur.fetchall() - return jsonify({ - "count": len(rows), - "data": [{"id": r.id, "name": r.name, "email": r.email} for r in rows] - }) + return jsonify({"count": len(rows), "data": [{"id": r.id, "name": r.name, "email": r.email} for r in rows]}) except Exception as e: return jsonify({"error": str(e)}), 500 @@ -672,13 +657,16 @@ def make_user_dict(**kwargs): cur.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 3") rows = cur.fetchall() - return jsonify({ - "count": len(rows), - "data": rows # Already processed dictionaries - }) + return jsonify( + { + "count": len(rows), + "data": rows, # Already processed dictionaries + } + ) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/scalar-row-factory") def test_scalar_row_factory(): """Test scalar_row row factory. @@ -694,13 +682,16 @@ def test_scalar_row_factory(): cur.execute("SELECT name FROM users ORDER BY id LIMIT 5") rows = cur.fetchall() - return jsonify({ - "count": len(rows), - "data": rows # Just names as scalars - }) + return jsonify( + { + "count": len(rows), + "data": rows, # Just names as scalars + } + ) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/binary-format") def test_binary_format(): """Test execute with binary=True parameter. @@ -710,19 +701,14 @@ def test_binary_format(): try: with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: # Execute with binary=True - cur.execute( - "SELECT id, name FROM users ORDER BY id LIMIT 3", - binary=True - ) + cur.execute("SELECT id, name FROM users ORDER BY id LIMIT 3", binary=True) rows = cur.fetchall() - return jsonify({ - "count": len(rows), - "data": [{"id": r[0], "name": r[1]} for r in rows] - }) + return jsonify({"count": len(rows), "data": [{"id": r[0], "name": r[1]} for r in rows]}) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/null-values") def test_null_values(): """Test handling of NULL values in results.""" @@ -751,18 +737,14 @@ def test_null_values(): rows = cur.fetchall() conn.commit() - return jsonify({ - "count": len(rows), - "data": [ - { - "id": r[0], - "nullable_text": r[1], - "nullable_int": r[2], - "nullable_bool": r[3] - } - for r in rows - ] - }) + return jsonify( + { + "count": len(rows), + "data": [ + {"id": r[0], "nullable_text": r[1], "nullable_int": r[2], "nullable_bool": r[3]} for r in rows + ], + } + ) except Exception as e: return jsonify({"error": str(e)}), 500 @@ -795,10 +777,7 @@ def test_transaction_context(): @app.route("/test/json-jsonb") def test_json_jsonb(): - """Test JSON and JSONB data types. - - BUG INVESTIGATION: JSON types may have serialization issues. - """ + """Test JSON and JSONB data types.""" try: with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: # Create temp table with JSON columns @@ -812,32 +791,23 @@ def test_json_jsonb(): # Insert JSON data import json + test_json = {"name": "test", "values": [1, 2, 3], "nested": {"key": "value"}} - cur.execute( - "INSERT INTO json_test VALUES (%s, %s, %s)", - (1, json.dumps(test_json), json.dumps(test_json)) - ) + cur.execute("INSERT INTO json_test VALUES (%s, %s, %s)", (1, json.dumps(test_json), json.dumps(test_json))) # Query back cur.execute("SELECT * FROM json_test WHERE id = 1") row = cur.fetchone() conn.commit() - return jsonify({ - "id": row[0], - "json_col": row[1], - "jsonb_col": row[2] - }) + return jsonify({"id": row[0], "json_col": row[1], "jsonb_col": row[2]}) except Exception as e: return jsonify({"error": str(e)}), 500 @app.route("/test/array-types") def test_array_types(): - """Test PostgreSQL array types. - - BUG INVESTIGATION: Array types may have serialization issues. - """ + """Test PostgreSQL array types.""" try: with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: # Create temp table with array columns @@ -850,24 +820,24 @@ def test_array_types(): """) # Insert array data - cur.execute( - "INSERT INTO array_test VALUES (%s, %s, %s)", - (1, [10, 20, 30], ["a", "b", "c"]) - ) + cur.execute("INSERT INTO array_test VALUES (%s, %s, %s)", (1, [10, 20, 30], ["a", "b", "c"])) # Query back cur.execute("SELECT * FROM array_test WHERE id = 1") row = cur.fetchone() conn.commit() - return jsonify({ - "id": row[0], - "int_array": list(row[1]) if row[1] else None, - "text_array": list(row[2]) if row[2] else None - }) + return jsonify( + { + "id": row[0], + "int_array": list(row[1]) if row[1] else None, + "text_array": list(row[2]) if row[2] else None, + } + ) except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/test/cursor-set-result") def test_cursor_set_result(): """Test cursor.set_result() method. @@ -883,7 +853,7 @@ def test_cursor_set_result(): cur.executemany( "INSERT INTO setresult_test (val) VALUES (%s) RETURNING id, val", [("First",), ("Second",), ("Third",)], - returning=True + returning=True, ) # Use set_result to navigate to specific result sets @@ -923,21 +893,16 @@ def test_decimal_types(): """) # Insert decimal data - cur.execute( - "INSERT INTO decimal_test VALUES (%s, %s, %s)", - (1, Decimal("123.45"), Decimal("0.00000001")) - ) + cur.execute("INSERT INTO decimal_test VALUES (%s, %s, %s)", (1, Decimal("123.45"), Decimal("0.00000001"))) # Query back cur.execute("SELECT * FROM decimal_test WHERE id = 1") row = cur.fetchone() conn.commit() - return jsonify({ - "id": row[0], - "price": str(row[1]) if row[1] else None, - "rate": str(row[2]) if row[2] else None - }) + return jsonify( + {"id": row[0], "price": str(row[1]) if row[1] else None, "rate": str(row[2]) if row[2] else None} + ) except Exception as e: return jsonify({"error": str(e)}), 500 @@ -962,7 +927,7 @@ def test_date_time_types(): # Insert date/time data cur.execute( "INSERT INTO datetime_test VALUES (%s, %s, %s, %s)", - (1, date(1990, 5, 15), time(8, 30, 0), timedelta(hours=2, minutes=30)) + (1, date(1990, 5, 15), time(8, 30, 0), timedelta(hours=2, minutes=30)), ) # Query back @@ -970,12 +935,14 @@ def test_date_time_types(): row = cur.fetchone() conn.commit() - return jsonify({ - "id": row[0], - "birth_date": str(row[1]) if row[1] else None, - "wake_time": str(row[2]) if row[2] else None, - "duration": str(row[3]) if row[3] else None - }) + return jsonify( + { + "id": row[0], + "birth_date": str(row[1]) if row[1] else None, + "wake_time": str(row[2]) if row[2] else None, + "duration": str(row[3]) if row[3] else None, + } + ) except Exception as e: return jsonify({"error": str(e)}), 500 @@ -997,31 +964,23 @@ def test_inet_cidr_types(): """) # Insert network data - cur.execute( - "INSERT INTO network_test VALUES (%s, %s, %s)", - (1, "192.168.1.100", "10.0.0.0/8") - ) + cur.execute("INSERT INTO network_test VALUES (%s, %s, %s)", (1, "192.168.1.100", "10.0.0.0/8")) # Query back cur.execute("SELECT * FROM network_test WHERE id = 1") row = cur.fetchone() conn.commit() - return jsonify({ - "id": row[0], - "ip_addr": str(row[1]) if row[1] else None, - "network": str(row[2]) if row[2] else None - }) + return jsonify( + {"id": row[0], "ip_addr": str(row[1]) if row[1] else None, "network": str(row[2]) if row[2] else None} + ) except Exception as e: return jsonify({"error": str(e)}), 500 @app.route("/test/range-types") def test_range_types(): - """Test PostgreSQL range types. - - BUG INVESTIGATION: Range types may have serialization issues. - """ + """Test PostgreSQL range types.""" try: from psycopg.types.range import Range @@ -1037,24 +996,20 @@ def test_range_types(): # Insert range data from datetime import datetime + int_range = Range(1, 10) ts_range = Range(datetime(2024, 1, 1, 0, 0), datetime(2024, 12, 31, 23, 59)) - cur.execute( - "INSERT INTO range_test VALUES (%s, %s, %s)", - (1, int_range, ts_range) - ) + cur.execute("INSERT INTO range_test VALUES (%s, %s, %s)", (1, int_range, ts_range)) # Query back cur.execute("SELECT * FROM range_test WHERE id = 1") row = cur.fetchone() conn.commit() - return jsonify({ - "id": row[0], - "int_range": str(row[1]) if row[1] else None, - "ts_range": str(row[2]) if row[2] else None - }) + return jsonify( + {"id": row[0], "int_range": str(row[1]) if row[1] else None, "ts_range": str(row[2]) if row[2] else None} + ) except Exception as e: return jsonify({"error": str(e)}), 500 diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index 3ad5392..f861884 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -78,16 +78,12 @@ # JSON/JSONB and array types tests make_request("GET", "/test/json-jsonb") make_request("GET", "/test/array-types") - - # Bug-exposing tests - kept for regression testing make_request("GET", "/test/cursor-set-result") - # Bug-exposing tests for parameter serialization issues # These tests expose hash mismatch bugs with Decimal and date/time types make_request("GET", "/test/decimal-types") make_request("GET", "/test/date-time-types") - # Bug-exposing tests for network and range type serialization issues # These tests expose serialization bugs with inet/cidr and range types make_request("GET", "/test/inet-cidr-types") make_request("GET", "/test/range-types") diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 9270f50..fa9e10a 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -3,9 +3,10 @@ import json import logging import weakref +from collections.abc import Iterator from contextlib import contextmanager from types import ModuleType -from typing import Any, Iterator +from typing import Any from opentelemetry import trace from opentelemetry.trace import SpanKind as OTelSpanKind @@ -52,10 +53,10 @@ def description(self): @property def rownumber(self): # In captured mode (after fetchall in _finalize_query_span), return tracked index - if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: + if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: return self._tusk_index # In replay mode with mock data, return mock index - if hasattr(self, '_mock_rows') and self._mock_rows is not None: + if hasattr(self, "_mock_rows") and self._mock_rows is not None: return self._mock_index # Otherwise, return real cursor's rownumber return super().rownumber @@ -63,7 +64,7 @@ def rownumber(self): @property def statusmessage(self): # In replay mode with mock data, return mock statusmessage - if hasattr(self, '_mock_statusmessage'): + if hasattr(self, "_mock_statusmessage"): return self._mock_statusmessage # Otherwise, return real cursor's statusmessage return super().statusmessage @@ -71,20 +72,20 @@ def statusmessage(self): def __iter__(self): # Support direct cursor iteration (for row in cursor) # In replay mode with mock data (_mock_rows) or record mode with captured data (_tusk_rows) - if hasattr(self, '_mock_rows') and self._mock_rows is not None: + if hasattr(self, "_mock_rows") and self._mock_rows is not None: return self - if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: + if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: return self return super().__iter__() def __next__(self): # In replay mode with mock data, iterate over mock rows - if hasattr(self, '_mock_rows') and self._mock_rows is not None: + if hasattr(self, "_mock_rows") and self._mock_rows is not None: if self._mock_index < len(self._mock_rows): row = self._mock_rows[self._mock_index] self._mock_index += 1 # Apply row transformation if fetchone is patched - if hasattr(self, 'fetchone') and callable(self.fetchone): + if hasattr(self, "fetchone") and callable(self.fetchone): # Reset index, get transformed row, restore index self._mock_index -= 1 result = self.fetchone() @@ -92,7 +93,7 @@ def __next__(self): return tuple(row) if isinstance(row, list) else row raise StopIteration # In record mode with captured data, iterate over stored rows - if hasattr(self, '_tusk_rows') and self._tusk_rows is not None: + if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: if self._tusk_index < len(self._tusk_rows): row = self._tusk_rows[self._tusk_index] self._tusk_index += 1 @@ -193,15 +194,16 @@ def _patch_pipeline_class(self, module: ModuleType) -> None: instrumentation = self # Store originals for potential unpatch - self._original_pipeline_sync = getattr(Pipeline, 'sync', None) - self._original_pipeline_exit = getattr(Pipeline, '__exit__', None) + self._original_pipeline_sync = getattr(Pipeline, "sync", None) + self._original_pipeline_exit = getattr(Pipeline, "__exit__", None) if self._original_pipeline_sync: + def patched_sync(pipeline_self): """Patched Pipeline.sync that finalizes pending spans.""" result = instrumentation._original_pipeline_sync(pipeline_self) # _conn is the connection associated with the pipeline - conn = getattr(pipeline_self, '_conn', None) + conn = getattr(pipeline_self, "_conn", None) if conn: instrumentation._finalize_pending_pipeline_spans(conn) return result @@ -210,13 +212,12 @@ def patched_sync(pipeline_self): logger.debug("psycopg.Pipeline.sync instrumented") if self._original_pipeline_exit: + def patched_exit(pipeline_self, exc_type, exc_val, exc_tb): """Patched Pipeline.__exit__ that finalizes any remaining spans.""" - result = instrumentation._original_pipeline_exit( - pipeline_self, exc_type, exc_val, exc_tb - ) + result = instrumentation._original_pipeline_exit(pipeline_self, exc_type, exc_val, exc_tb) # Finalize any remaining pending spans (handles implicit sync on exit) - conn = getattr(pipeline_self, '_conn', None) + conn = getattr(pipeline_self, "_conn", None) if conn: instrumentation._finalize_pending_pipeline_spans(conn) return result @@ -335,9 +336,7 @@ def _replay_execute(self, cursor: Any, sdk: TuskDrift, query_str: str, params: A raise RuntimeError("Error creating span in replay mode") with SpanUtils.with_span(span_info): - mock_result = self._try_get_mock( - sdk, query_str, params, span_info.trace_id, span_info.span_id - ) + mock_result = self._try_get_mock(sdk, query_str, params, span_info.trace_id, span_info.span_id) if mock_result is None: is_pre_app_start = not sdk.app_ready @@ -366,9 +365,9 @@ def _record_execute( # Reset cursor state from any previous execute() on this cursor. # Delete instance attribute overrides to expose original class methods. # This is safer than saving/restoring bound methods which can become stale. - if hasattr(cursor, '_tusk_patched'): + if hasattr(cursor, "_tusk_patched"): # Remove patched instance attributes to expose class methods - for attr in ('fetchone', 'fetchmany', 'fetchall', 'scroll'): + for attr in ("fetchone", "fetchmany", "fetchall", "scroll"): if attr in cursor.__dict__: delattr(cursor, attr) cursor._tusk_rows = None @@ -433,9 +432,7 @@ def _traced_executemany( if sdk.mode == TuskDriftMode.REPLAY: return handle_replay_mode( - replay_mode_handler=lambda: self._replay_executemany( - cursor, sdk, query_str, params_list, returning - ), + replay_mode_handler=lambda: self._replay_executemany(cursor, sdk, query_str, params_list, returning), no_op_request_handler=lambda: self._noop_execute(cursor), is_server_request=False, ) @@ -464,9 +461,7 @@ def _replay_executemany( if returning: params_for_mock["_returning"] = True - mock_result = self._try_get_mock( - sdk, query_str, params_for_mock, span_info.trace_id, span_info.span_id - ) + mock_result = self._try_get_mock(sdk, query_str, params_for_mock, span_info.trace_id, span_info.span_id) if mock_result is None: is_pre_app_start = not sdk.app_ready @@ -609,9 +604,7 @@ def _replay_stream(self, cursor: Any, sdk: TuskDrift, query_str: str, params: An raise RuntimeError("Error creating span in replay mode") with SpanUtils.with_span(span_info): - mock_result = self._try_get_mock( - sdk, query_str, params, span_info.trace_id, span_info.span_id - ) + mock_result = self._try_get_mock(sdk, query_str, params, span_info.trace_id, span_info.span_id) if mock_result is None: is_pre_app_start = not sdk.app_ready @@ -677,9 +670,7 @@ def _finalize_stream_span( except Exception as e: logger.error(f"Error finalizing stream span: {e}") - def _traced_copy( - self, cursor: Any, original_copy: Any, sdk: TuskDrift, query: str, params=None, **kwargs - ) -> Any: + def _traced_copy(self, cursor: Any, original_copy: Any, sdk: TuskDrift, query: str, params=None, **kwargs) -> Any: """Traced cursor.copy method - returns a context manager.""" if sdk.mode == TuskDriftMode.DISABLED: return original_copy(query, params, **kwargs) @@ -751,9 +742,7 @@ def _replay_copy(self, cursor: Any, sdk: TuskDrift, query_str: str) -> Iterator[ raise RuntimeError("Error creating span in replay mode") with SpanUtils.with_span(span_info): - mock_result = self._try_get_copy_mock( - sdk, query_str, span_info.trace_id, span_info.span_id - ) + mock_result = self._try_get_copy_mock(sdk, query_str, span_info.trace_id, span_info.span_id) if mock_result is None: is_pre_app_start = not sdk.app_ready @@ -920,6 +909,7 @@ def _create_fetch_methods(self, cursor: Any, rows_attr: str, index_attr: str, tr Returns: Tuple of (fetchone, fetchmany, fetchall) functions """ + def fetchone(): rows = getattr(cursor, rows_attr) idx = getattr(cursor, index_attr) @@ -962,6 +952,7 @@ def _create_scroll_method(self, cursor: Any, rows_attr: str, index_attr: str): Returns: scroll function """ + def scroll(value: int, mode: str = "relative") -> None: rows = getattr(cursor, rows_attr) idx = getattr(cursor, index_attr) @@ -992,11 +983,11 @@ def _get_row_factory_from_cursor(self, cursor: Any): Returns: The row_factory or None if not found """ - row_factory = getattr(cursor, 'row_factory', None) + row_factory = getattr(cursor, "row_factory", None) if row_factory is None: - conn = getattr(cursor, 'connection', None) + conn = getattr(cursor, "connection", None) if conn: - row_factory = getattr(conn, 'row_factory', None) + row_factory = getattr(conn, "row_factory", None) return row_factory def _set_cursor_description(self, cursor: Any, description_data: list | None) -> None: @@ -1031,7 +1022,8 @@ def _create_row_transformer(self, row_factory_type: str, column_names: list | No RowClass = None if row_factory_type in ("namedtuple", "class") and column_names: from collections import namedtuple - RowClass = namedtuple('Row', column_names) + + RowClass = namedtuple("Row", column_names) def transform_row(row): """Transform raw row data according to row factory type.""" @@ -1041,7 +1033,7 @@ def transform_row(row): return row[0] if isinstance(row, list) and len(row) > 0 else row values = tuple(row) if isinstance(row, list) else row if row_factory_type == "dict" and column_names: - return dict(zip(column_names, values)) + return dict(zip(column_names, values, strict=False)) elif row_factory_type in ("namedtuple", "class") and RowClass is not None: return RowClass(*values) return values @@ -1088,20 +1080,20 @@ def _detect_row_factory_type(self, row_factory: Any) -> str: return "tuple" # Check by function/class name - factory_name = getattr(row_factory, '__name__', '') + factory_name = getattr(row_factory, "__name__", "") if not factory_name: factory_name = str(type(row_factory).__name__) factory_name_lower = factory_name.lower() - if 'dict' in factory_name_lower: + if "dict" in factory_name_lower: return "dict" - elif 'namedtuple' in factory_name_lower: + elif "namedtuple" in factory_name_lower: return "namedtuple" - elif 'kwargs' in factory_name_lower: + elif "kwargs" in factory_name_lower: return "kwargs" - elif 'scalar' in factory_name_lower: + elif "scalar" in factory_name_lower: return "scalar" - elif 'class' in factory_name_lower: + elif "class" in factory_name_lower: return "class" return "tuple" @@ -1112,20 +1104,20 @@ def _is_in_pipeline_mode(self, cursor: Any) -> bool: In psycopg3, when conn.pipeline() is active, connection._pipeline is set. """ try: - conn = getattr(cursor, 'connection', None) + conn = getattr(cursor, "connection", None) if conn is None: return False # MockConnection doesn't have real pipeline mode if isinstance(conn, MockConnection): return False - pipeline = getattr(conn, '_pipeline', None) + pipeline = getattr(conn, "_pipeline", None) return pipeline is not None except Exception: return False def _get_connection_from_cursor(self, cursor: Any) -> Any: """Get the connection object from a cursor.""" - return getattr(cursor, 'connection', None) + return getattr(cursor, "connection", None) def _add_pending_pipeline_span( self, @@ -1139,12 +1131,14 @@ def _add_pending_pipeline_span( if connection not in self._pending_pipeline_spans: self._pending_pipeline_spans[connection] = [] - self._pending_pipeline_spans[connection].append({ - 'span_info': span_info, - 'cursor': cursor, - 'query': query, - 'params': params, - }) + self._pending_pipeline_spans[connection].append( + { + "span_info": span_info, + "cursor": cursor, + "query": query, + "params": params, + } + ) logger.debug(f"[PIPELINE] Deferred span for query: {query[:50]}...") def _finalize_pending_pipeline_spans(self, connection: Any) -> None: @@ -1156,10 +1150,10 @@ def _finalize_pending_pipeline_spans(self, connection: Any) -> None: logger.debug(f"[PIPELINE] Finalizing {len(pending)} pending pipeline spans") for item in pending: - span_info = item['span_info'] - cursor = item['cursor'] - query = item['query'] - params = item['params'] + span_info = item["span_info"] + cursor = item["cursor"] + query = item["query"] + params = item["params"] try: span_finalized = self._finalize_query_span(span_info.span, cursor, query, params, error=None) @@ -1258,14 +1252,12 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] # Use helper methods to create fetch and scroll methods - fetchone, fetchmany, fetchall = self._create_fetch_methods( - cursor, '_mock_rows', '_mock_index', transform_row - ) + fetchone, fetchmany, fetchall = self._create_fetch_methods(cursor, "_mock_rows", "_mock_index", transform_row) cursor.fetchone = fetchone # pyright: ignore[reportAttributeAccessIssue] cursor.fetchmany = fetchmany # pyright: ignore[reportAttributeAccessIssue] cursor.fetchall = fetchall # pyright: ignore[reportAttributeAccessIssue] - cursor.scroll = self._create_scroll_method(cursor, '_mock_rows', '_mock_index') # pyright: ignore[reportAttributeAccessIssue] + cursor.scroll = self._create_scroll_method(cursor, "_mock_rows", "_mock_index") # pyright: ignore[reportAttributeAccessIssue] # Note: __iter__ and __next__ are handled at the class level in InstrumentedCursor # and MockCursor classes, as Python looks up special methods on the type, not instance @@ -1327,7 +1319,7 @@ def transform_row(row, col_names, RowClass): return row values = tuple(row) if isinstance(row, list) else row if row_factory_type == "dict" and col_names: - return dict(zip(col_names, values)) + return dict(zip(col_names, values, strict=False)) elif row_factory_type == "namedtuple" and RowClass is not None: return RowClass(*values) return values @@ -1345,8 +1337,7 @@ def mock_results(): description_data = result_set.get("description") if description_data: desc = [ - (col["name"], col.get("type_code"), None, None, None, None, None) - for col in description_data + (col["name"], col.get("type_code"), None, None, None, None, None) for col in description_data ] try: cursor._tusk_description = desc # pyright: ignore[reportAttributeAccessIssue] @@ -1428,6 +1419,7 @@ def fetchone(): cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue] return transform_row(row, cn, RC) return None + return fetchone def make_fetchmany_replay(cn, RC): @@ -1441,13 +1433,15 @@ def fetchmany(size=cursor.arraysize): else: break return rows + return fetchmany def make_fetchall_replay(cn, RC): def fetchall(): - rows = cursor._mock_rows[cursor._mock_index:] # pyright: ignore[reportAttributeAccessIssue] + rows = cursor._mock_rows[cursor._mock_index :] # pyright: ignore[reportAttributeAccessIssue] cursor._mock_index = len(cursor._mock_rows) # pyright: ignore[reportAttributeAccessIssue] return [transform_row(row, cn, RC) for row in rows] + return fetchall cursor.fetchone = make_fetchone_replay(first_column_names, FirstRowClass) # pyright: ignore[reportAttributeAccessIssue] @@ -1496,9 +1490,7 @@ def patched_set_result(index: int): """Navigate to a specific result set by index (supports negative indices).""" num_results = len(cursor._mock_result_sets) # pyright: ignore[reportAttributeAccessIssue] if not -num_results <= index < num_results: - raise IndexError( - f"index {index} out of range: {num_results} result(s) available" - ) + raise IndexError(f"index {index} out of range: {num_results} result(s) available") if index < 0: index = num_results + index @@ -1567,7 +1559,6 @@ def _finalize_query_span( else: # Get query results and capture for replay try: - rows = [] description = None row_factory_type = "tuple" # default @@ -1584,11 +1575,11 @@ def _finalize_query_span( ] # Get row factory from cursor or connection - row_factory = getattr(cursor, 'row_factory', None) + row_factory = getattr(cursor, "row_factory", None) if row_factory is None: - conn = getattr(cursor, 'connection', None) + conn = getattr(cursor, "connection", None) if conn: - row_factory = getattr(conn, 'row_factory', None) + row_factory = getattr(conn, "row_factory", None) # Detect row factory type BEFORE processing rows row_factory_type = self._detect_row_factory_type(row_factory) @@ -1617,7 +1608,7 @@ def _finalize_query_span( } # Capture statusmessage for replay - if hasattr(cursor, 'statusmessage') and cursor.statusmessage is not None: + if hasattr(cursor, "statusmessage") and cursor.statusmessage is not None: output_value["statusmessage"] = cursor.statusmessage except Exception as e: @@ -1646,13 +1637,11 @@ def _setup_lazy_capture(self, cursor: Any) -> None: # (not instance methods which might already be patched) cursor_class = type(cursor) original_fetchall = cursor_class.fetchall - original_fetchone = cursor_class.fetchone - original_fetchmany = cursor_class.fetchmany - original_scroll = cursor_class.scroll if hasattr(cursor_class, 'scroll') else None + original_scroll = cursor_class.scroll if hasattr(cursor_class, "scroll") else None def do_lazy_capture(): """Perform the actual capture - called on first fetch.""" - if hasattr(cursor, '_tusk_rows') and cursor._tusk_rows is not None: + if hasattr(cursor, "_tusk_rows") and cursor._tusk_rows is not None: return # Already captured try: @@ -1674,12 +1663,12 @@ def do_lazy_capture(): rows.append(row) elif row_factory_type == "scalar": rows.append([row]) - elif row_factory_type == "class" or hasattr(row, '__dataclass_fields__'): + elif row_factory_type == "class" or hasattr(row, "__dataclass_fields__"): # dataclass (from class_row): extract values by attribute name rows.append([getattr(row, col, None) for col in column_names]) elif isinstance(row, dict): rows.append([row.get(col) for col in column_names]) - elif hasattr(row, '_fields'): + elif hasattr(row, "_fields"): # namedtuple: extract values by field name rows.append([getattr(row, col, None) for col in column_names]) else: @@ -1704,7 +1693,7 @@ def do_lazy_capture(): serialized_rows = [[serialize_value(col) for col in row] for row in rows] output_value["rows"] = serialized_rows - if hasattr(cursor, 'statusmessage') and cursor.statusmessage is not None: + if hasattr(cursor, "statusmessage") and cursor.statusmessage is not None: output_value["statusmessage"] = cursor.statusmessage instrumentation._set_span_attributes(span, input_value, output_value) @@ -1718,7 +1707,7 @@ def do_lazy_capture(): logger.error(f"Error in lazy capture: {e}") # Try to end span even on error try: - span = cursor._tusk_lazy_span # pyright: ignore[reportAttributeAccessIssue] + span = cursor._tusk_lazy_span span.set_status(Status(OTelStatusCode.ERROR, str(e))) span.end() except Exception: @@ -1726,8 +1715,14 @@ def do_lazy_capture(): finally: # Clean up lazy capture attributes - for attr in ('_tusk_lazy_span', '_tusk_lazy_input_value', '_tusk_lazy_description', - '_tusk_lazy_row_factory_type', '_tusk_lazy_column_names', '_tusk_lazy_instrumentation'): + for attr in ( + "_tusk_lazy_span", + "_tusk_lazy_input_value", + "_tusk_lazy_description", + "_tusk_lazy_row_factory_type", + "_tusk_lazy_column_names", + "_tusk_lazy_instrumentation", + ): if hasattr(cursor, attr): try: delattr(cursor, attr) @@ -1797,7 +1792,6 @@ def _finalize_executemany_returning_span( be replayed with multiple result set iteration. """ try: - # Build input value input_value = { "query": query.strip(), @@ -1848,7 +1842,9 @@ def _finalize_executemany_returning_span( for row in raw_rows: if isinstance(row, dict): - rows.append([row.get(col) for col in column_names] if column_names else list(row.values())) + rows.append( + [row.get(col) for col in column_names] if column_names else list(row.values()) + ) elif hasattr(row, "_fields"): rows.append( [getattr(row, col, None) for col in column_names] if column_names else list(row) @@ -1927,6 +1923,7 @@ def patched_fetchone(): cursor._tusk_index += 1 # pyright: ignore[reportAttributeAccessIssue] return row return None + return patched_fetchone def make_patched_fetchmany_record(): @@ -1934,6 +1931,7 @@ def patched_fetchmany(size=cursor.arraysize): result = cursor._tusk_rows[cursor._tusk_index : cursor._tusk_index + size] # pyright: ignore[reportAttributeAccessIssue] cursor._tusk_index += len(result) # pyright: ignore[reportAttributeAccessIssue] return result + return patched_fetchmany def make_patched_fetchall_record(): @@ -1941,6 +1939,7 @@ def patched_fetchall(): result = cursor._tusk_rows[cursor._tusk_index :] # pyright: ignore[reportAttributeAccessIssue] cursor._tusk_index = len(cursor._tusk_rows) # pyright: ignore[reportAttributeAccessIssue] return result + return patched_fetchall cursor.fetchone = make_patched_fetchone_record() # pyright: ignore[reportAttributeAccessIssue] @@ -1970,9 +1969,7 @@ def patched_set_result_record(index: int): """Navigate to a specific result set by index (supports negative indices).""" num_results = len(cursor._tusk_result_sets) # pyright: ignore[reportAttributeAccessIssue] if not -num_results <= index < num_results: - raise IndexError( - f"index {index} out of range: {num_results} result(s) available" - ) + raise IndexError(f"index {index} out of range: {num_results} result(s) available") if index < 0: index = num_results + index diff --git a/drift/instrumentation/psycopg/mocks.py b/drift/instrumentation/psycopg/mocks.py index b5e672d..0709aed 100644 --- a/drift/instrumentation/psycopg/mocks.py +++ b/drift/instrumentation/psycopg/mocks.py @@ -7,8 +7,9 @@ from __future__ import annotations import logging +from collections.abc import Iterator from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Iterator +from typing import TYPE_CHECKING if TYPE_CHECKING: from ...core.drift_sdk import TuskDrift @@ -217,7 +218,7 @@ def rownumber(self): @property def statusmessage(self): """Return the mock status message if set, otherwise None.""" - return getattr(self, '_mock_statusmessage', None) + return getattr(self, "_mock_statusmessage", None) def execute(self, query, params=None, **kwargs): """Will be replaced by instrumentation.""" diff --git a/drift/instrumentation/psycopg/wrappers.py b/drift/instrumentation/psycopg/wrappers.py index 12efc1f..7152ab9 100644 --- a/drift/instrumentation/psycopg/wrappers.py +++ b/drift/instrumentation/psycopg/wrappers.py @@ -5,7 +5,8 @@ from __future__ import annotations -from typing import Any, Iterator +from collections.abc import Iterator +from typing import Any class TracedCopyWrapper: diff --git a/drift/instrumentation/psycopg2/instrumentation.py b/drift/instrumentation/psycopg2/instrumentation.py index 5da1aac..f62b0f6 100644 --- a/drift/instrumentation/psycopg2/instrumentation.py +++ b/drift/instrumentation/psycopg2/instrumentation.py @@ -447,9 +447,7 @@ def _replay_execute(self, cursor: Any, sdk: TuskDrift, query_str: str, params: A raise RuntimeError("Error creating span in replay mode") with SpanUtils.with_span(span_info): - mock_result = self._try_get_mock( - sdk, query_str, params, span_info.trace_id, span_info.span_id - ) + mock_result = self._try_get_mock(sdk, query_str, params, span_info.trace_id, span_info.span_id) if mock_result is None: is_pre_app_start = not sdk.app_ready From e5101eda8164f11c54a3adede243124385a5c030 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 15:57:50 -0800 Subject: [PATCH 35/37] fix type errors --- .../psycopg/instrumentation.py | 211 +++++++++++------- drift/instrumentation/utils/psycopg_utils.py | 2 +- drift/instrumentation/utils/serialization.py | 4 +- 3 files changed, 133 insertions(+), 84 deletions(-) diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index fa9e10a..acb344a 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -34,74 +34,6 @@ _instance: PsycopgInstrumentation | None = None -class _CursorInstrumentationMixin: - """Mixin providing common functionality for instrumented cursor classes. - - This mixin contains shared properties and methods used by both - InstrumentedCursor and InstrumentedServerCursor to avoid code duplication. - """ - - _tusk_description = None # Store mock description for replay mode - - @property - def description(self): - # In replay mode, return mock description if set; otherwise use base - if self._tusk_description is not None: - return self._tusk_description - return super().description - - @property - def rownumber(self): - # In captured mode (after fetchall in _finalize_query_span), return tracked index - if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: - return self._tusk_index - # In replay mode with mock data, return mock index - if hasattr(self, "_mock_rows") and self._mock_rows is not None: - return self._mock_index - # Otherwise, return real cursor's rownumber - return super().rownumber - - @property - def statusmessage(self): - # In replay mode with mock data, return mock statusmessage - if hasattr(self, "_mock_statusmessage"): - return self._mock_statusmessage - # Otherwise, return real cursor's statusmessage - return super().statusmessage - - def __iter__(self): - # Support direct cursor iteration (for row in cursor) - # In replay mode with mock data (_mock_rows) or record mode with captured data (_tusk_rows) - if hasattr(self, "_mock_rows") and self._mock_rows is not None: - return self - if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: - return self - return super().__iter__() - - def __next__(self): - # In replay mode with mock data, iterate over mock rows - if hasattr(self, "_mock_rows") and self._mock_rows is not None: - if self._mock_index < len(self._mock_rows): - row = self._mock_rows[self._mock_index] - self._mock_index += 1 - # Apply row transformation if fetchone is patched - if hasattr(self, "fetchone") and callable(self.fetchone): - # Reset index, get transformed row, restore index - self._mock_index -= 1 - result = self.fetchone() - return result - return tuple(row) if isinstance(row, list) else row - raise StopIteration - # In record mode with captured data, iterate over stored rows - if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: - if self._tusk_index < len(self._tusk_rows): - row = self._tusk_rows[self._tusk_index] - self._tusk_index += 1 - return row - raise StopIteration - return super().__next__() - - class PsycopgInstrumentation(InstrumentationBase): """Instrumentation for psycopg (psycopg3) PostgreSQL client library. @@ -198,10 +130,11 @@ def _patch_pipeline_class(self, module: ModuleType) -> None: self._original_pipeline_exit = getattr(Pipeline, "__exit__", None) if self._original_pipeline_sync: + original_sync = self._original_pipeline_sync def patched_sync(pipeline_self): """Patched Pipeline.sync that finalizes pending spans.""" - result = instrumentation._original_pipeline_sync(pipeline_self) + result = original_sync(pipeline_self) # _conn is the connection associated with the pipeline conn = getattr(pipeline_self, "_conn", None) if conn: @@ -212,10 +145,11 @@ def patched_sync(pipeline_self): logger.debug("psycopg.Pipeline.sync instrumented") if self._original_pipeline_exit: + original_exit = self._original_pipeline_exit def patched_exit(pipeline_self, exc_type, exc_val, exc_tb): """Patched Pipeline.__exit__ that finalizes any remaining spans.""" - result = instrumentation._original_pipeline_exit(pipeline_self, exc_type, exc_val, exc_tb) + result = original_exit(pipeline_self, exc_type, exc_val, exc_tb) # Finalize any remaining pending spans (handles implicit sync on exit) conn = getattr(pipeline_self, "_conn", None) if conn: @@ -243,12 +177,68 @@ def _create_cursor_factory(self, sdk: TuskDrift, base_factory=None): base = base_factory or BaseCursor - class InstrumentedCursor(_CursorInstrumentationMixin, base): # type: ignore - """Instrumented cursor with tracing support. - - Inherits common properties (description, rownumber, statusmessage) - and iteration methods (__iter__, __next__) from _CursorInstrumentationMixin. - """ + class InstrumentedCursor(base): # type: ignore + """Instrumented cursor with tracing support.""" + + _tusk_description = None # Store mock description for replay mode + + @property + def description(self): + # In replay mode, return mock description if set; otherwise use base + if self._tusk_description is not None: + return self._tusk_description + return super().description + + @property + def rownumber(self): + # In captured mode (after fetchall in _finalize_query_span), return tracked index + if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: + return self._tusk_index + # In replay mode with mock data, return mock index + if hasattr(self, "_mock_rows") and self._mock_rows is not None: + return self._mock_index + # Otherwise, return real cursor's rownumber + return super().rownumber + + @property + def statusmessage(self): + # In replay mode with mock data, return mock statusmessage + if hasattr(self, "_mock_statusmessage"): + return self._mock_statusmessage + # Otherwise, return real cursor's statusmessage + return super().statusmessage + + def __iter__(self): + # Support direct cursor iteration (for row in cursor) + # In replay mode with mock data (_mock_rows) or record mode with captured data (_tusk_rows) + if hasattr(self, "_mock_rows") and self._mock_rows is not None: + return self + if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: + return self + return super().__iter__() + + def __next__(self): + # In replay mode with mock data, iterate over mock rows + if hasattr(self, "_mock_rows") and self._mock_rows is not None: + if self._mock_index < len(self._mock_rows): + row = self._mock_rows[self._mock_index] + self._mock_index += 1 + # Apply row transformation if fetchone is patched + if hasattr(self, "fetchone") and callable(self.fetchone): + # Reset index, get transformed row, restore index + self._mock_index -= 1 + result = self.fetchone() + return result + return tuple(row) if isinstance(row, list) else row + raise StopIteration + # In record mode with captured data, iterate over stored rows + if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: + if self._tusk_index < len(self._tusk_rows): + row = self._tusk_rows[self._tusk_index] + self._tusk_index += 1 + return row + raise StopIteration + return super().__next__() def execute(self, query, params=None, **kwargs): return instrumentation._traced_execute(self, super().execute, sdk, query, params, **kwargs) @@ -281,16 +271,73 @@ def _create_server_cursor_factory(self, sdk: TuskDrift, base_factory=None): base = base_factory or BaseServerCursor - class InstrumentedServerCursor(_CursorInstrumentationMixin, base): # type: ignore + class InstrumentedServerCursor(base): # type: ignore """Instrumented server cursor with tracing support. - Inherits common properties (description, rownumber, statusmessage) - and iteration methods (__iter__, __next__) from _CursorInstrumentationMixin. - Note: ServerCursor doesn't support executemany(). Note: ServerCursor has stream-like iteration via fetchmany/itersize. """ + _tusk_description = None # Store mock description for replay mode + + @property + def description(self): + # In replay mode, return mock description if set; otherwise use base + if self._tusk_description is not None: + return self._tusk_description + return super().description + + @property + def rownumber(self): + # In captured mode (after fetchall in _finalize_query_span), return tracked index + if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: + return self._tusk_index + # In replay mode with mock data, return mock index + if hasattr(self, "_mock_rows") and self._mock_rows is not None: + return self._mock_index + # Otherwise, return real cursor's rownumber + return super().rownumber + + @property + def statusmessage(self): + # In replay mode with mock data, return mock statusmessage + if hasattr(self, "_mock_statusmessage"): + return self._mock_statusmessage + # Otherwise, return real cursor's statusmessage + return super().statusmessage + + def __iter__(self): + # Support direct cursor iteration (for row in cursor) + # In replay mode with mock data (_mock_rows) or record mode with captured data (_tusk_rows) + if hasattr(self, "_mock_rows") and self._mock_rows is not None: + return self + if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: + return self + return super().__iter__() + + def __next__(self): + # In replay mode with mock data, iterate over mock rows + if hasattr(self, "_mock_rows") and self._mock_rows is not None: + if self._mock_index < len(self._mock_rows): + row = self._mock_rows[self._mock_index] + self._mock_index += 1 + # Apply row transformation if fetchone is patched + if hasattr(self, "fetchone") and callable(self.fetchone): + # Reset index, get transformed row, restore index + self._mock_index -= 1 + result = self.fetchone() + return result + return tuple(row) if isinstance(row, list) else row + raise StopIteration + # In record mode with captured data, iterate over stored rows + if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: + if self._tusk_index < len(self._tusk_rows): + row = self._tusk_rows[self._tusk_index] + self._tusk_index += 1 + return row + raise StopIteration + return super().__next__() + def execute(self, query, params=None, **kwargs): # Note: ServerCursor.execute() doesn't support 'prepare' parameter return instrumentation._traced_execute(self, super().execute, sdk, query, params, **kwargs) @@ -1712,6 +1759,8 @@ def do_lazy_capture(): span.end() except Exception: pass + # Re-raise the original exception so the user sees the actual database error + raise finally: # Clean up lazy capture attributes @@ -1847,7 +1896,7 @@ def _finalize_executemany_returning_span( ) elif hasattr(row, "_fields"): rows.append( - [getattr(row, col, None) for col in column_names] if column_names else list(row) + [getattr(row, str(col), None) for col in column_names] if column_names else list(row) ) else: rows.append(list(row)) diff --git a/drift/instrumentation/utils/psycopg_utils.py b/drift/instrumentation/utils/psycopg_utils.py index ea5e3ea..d14813d 100644 --- a/drift/instrumentation/utils/psycopg_utils.py +++ b/drift/instrumentation/utils/psycopg_utils.py @@ -10,7 +10,7 @@ # Try to import psycopg Range type for deserialization support try: - from psycopg.types.range import Range as PsycopgRange + from psycopg.types.range import Range as PsycopgRange # type: ignore[import-untyped] HAS_PSYCOPG_RANGE = True except ImportError: diff --git a/drift/instrumentation/utils/serialization.py b/drift/instrumentation/utils/serialization.py index 9ea331f..e7fece0 100644 --- a/drift/instrumentation/utils/serialization.py +++ b/drift/instrumentation/utils/serialization.py @@ -11,7 +11,7 @@ # Try to import psycopg Range type for serialization support try: - from psycopg.types.range import Range as PsycopgRange + from psycopg.types.range import Range as PsycopgRange # type: ignore[import-untyped] HAS_PSYCOPG_RANGE = True except ImportError: @@ -61,7 +61,7 @@ def serialize_value(val: Any) -> Any: return {"__decimal__": str(val)} elif isinstance(val, uuid.UUID): return {"__uuid__": str(val)} - elif HAS_PSYCOPG_RANGE and isinstance(val, PsycopgRange): + elif HAS_PSYCOPG_RANGE and PsycopgRange is not None and isinstance(val, PsycopgRange): # Serialize psycopg Range objects to a deterministic dict format # This handles INT4RANGE, TSRANGE, and other PostgreSQL range types if val.isempty: From 70e04b7282c1f3d1b21195000a3b5696615b5565 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 16:00:18 -0800 Subject: [PATCH 36/37] format --- drift/instrumentation/psycopg/instrumentation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index acb344a..9906653 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -1896,7 +1896,9 @@ def _finalize_executemany_returning_span( ) elif hasattr(row, "_fields"): rows.append( - [getattr(row, str(col), None) for col in column_names] if column_names else list(row) + [getattr(row, str(col), None) for col in column_names] + if column_names + else list(row) ) else: rows.append(list(row)) From 26f92f7cf44fddc9ea8c20a40e8f929bd07fda61 Mon Sep 17 00:00:00 2001 From: Sohan Kshirsagar Date: Thu, 15 Jan 2026 16:07:29 -0800 Subject: [PATCH 37/37] try catch record mode --- .../psycopg/instrumentation.py | 130 +++++++++++------- 1 file changed, 80 insertions(+), 50 deletions(-) diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 9906653..cb30a0c 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -441,28 +441,36 @@ def _record_execute( error = e raise finally: - if error is not None: - # Always finalize immediately on error - self._finalize_query_span(span_info.span, cursor, query_str, params, error) - span_info.span.end() - elif in_pipeline_mode: - # Defer finalization until pipeline.sync() - connection = self._get_connection_from_cursor(cursor) - if connection: - self._add_pending_pipeline_span(connection, span_info, cursor, query_str, params) - # DON'T end span here - will be ended in _finalize_pending_pipeline_spans - else: - # Fallback: finalize immediately if we can't get connection - self._finalize_query_span(span_info.span, cursor, query_str, params, None) + try: + if error is not None: + # Always finalize immediately on error + self._finalize_query_span(span_info.span, cursor, query_str, params, error) span_info.span.end() - else: - # Normal mode: finalize immediately (unless lazy capture was set up) - span_finalized = self._finalize_query_span(span_info.span, cursor, query_str, params, None) - if span_finalized: - # Span was fully finalized, end it now + elif in_pipeline_mode: + # Defer finalization until pipeline.sync() + connection = self._get_connection_from_cursor(cursor) + if connection: + self._add_pending_pipeline_span(connection, span_info, cursor, query_str, params) + # DON'T end span here - will be ended in _finalize_pending_pipeline_spans + else: + # Fallback: finalize immediately if we can't get connection + self._finalize_query_span(span_info.span, cursor, query_str, params, None) + span_info.span.end() + else: + # Normal mode: finalize immediately (unless lazy capture was set up) + span_finalized = self._finalize_query_span(span_info.span, cursor, query_str, params, None) + if span_finalized: + # Span was fully finalized, end it now + span_info.span.end() + # If span_finalized is False, lazy capture was set up and span will be + # ended when user code calls a fetch method + except Exception as e: + logger.error(f"Error in span finalization: {e}") + # Ensure span is ended even if finalization fails + try: span_info.span.end() - # If span_finalized is False, lazy capture was set up and span will be - # ended when user code calls a fetch method + except Exception: + pass def _traced_executemany( self, cursor: Any, original_executemany: Any, sdk: TuskDrift, query: str, params_seq, **kwargs @@ -561,29 +569,37 @@ def _record_executemany( error = e raise finally: - if returning and error is None: - # Use specialized method for executemany with returning=True - self._finalize_executemany_returning_span( - span_info.span, - cursor, - query_str, - {"_batch": params_list, "_returning": True}, - error, - ) - span_info.span.end() - else: - # Existing behavior for executemany without returning - span_finalized = self._finalize_query_span( - span_info.span, - cursor, - query_str, - {"_batch": params_list}, - error, - ) - if span_finalized: + try: + if returning and error is None: + # Use specialized method for executemany with returning=True + self._finalize_executemany_returning_span( + span_info.span, + cursor, + query_str, + {"_batch": params_list, "_returning": True}, + error, + ) + span_info.span.end() + else: + # Existing behavior for executemany without returning + span_finalized = self._finalize_query_span( + span_info.span, + cursor, + query_str, + {"_batch": params_list}, + error, + ) + if span_finalized: + span_info.span.end() + # Note: executemany without returning typically has no results, + # so lazy capture is unlikely but we handle it for safety + except Exception as e: + logger.error(f"Error in span finalization: {e}") + # Ensure span is ended even if finalization fails + try: span_info.span.end() - # Note: executemany without returning typically has no results, - # so lazy capture is unlikely but we handle it for safety + except Exception: + pass def _traced_stream( self, cursor: Any, original_stream: Any, sdk: TuskDrift, query: str, params=None, **kwargs @@ -640,7 +656,14 @@ def _record_stream( error = e raise finally: - self._finalize_stream_span(span_info.span, cursor, query_str, params, rows_collected, error) + try: + self._finalize_stream_span(span_info.span, cursor, query_str, params, rows_collected, error) + except Exception as e: + logger.error(f"Error in stream span finalization: {e}") + try: + span_info.span.end() + except Exception: + pass span_info.span.end() def _replay_stream(self, cursor: Any, sdk: TuskDrift, query_str: str, params: Any): @@ -772,13 +795,20 @@ def _record_copy( error = e raise finally: - self._finalize_copy_span( - span_info.span, - query_str, - data_collected, - error, - ) - span_info.span.end() + try: + self._finalize_copy_span( + span_info.span, + query_str, + data_collected, + error, + ) + span_info.span.end() + except Exception as e: + logger.error(f"Error in copy span finalization: {e}") + try: + span_info.span.end() + except Exception: + pass @contextmanager def _replay_copy(self, cursor: Any, sdk: TuskDrift, query_str: str) -> Iterator[MockCopy]: