diff --git a/drift/instrumentation/django/e2e-tests/src/test_requests.py b/drift/instrumentation/django/e2e-tests/src/test_requests.py index 4e0301a..9d4354c 100644 --- a/drift/instrumentation/django/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/django/e2e-tests/src/test_requests.py @@ -1,26 +1,6 @@ """Execute test requests against the Django app.""" -import os -import time - -import requests - -PORT = os.getenv("PORT", "8000") -BASE_URL = f"http://localhost:{PORT}" - - -def make_request(method: str, endpoint: str, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"→ {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = requests.request(method, url, **kwargs) - print(f" Status: {response.status_code}") - time.sleep(0.5) # Small delay between requests - return response - +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary if __name__ == "__main__": print("Starting Django test request sequence...\n") @@ -42,4 +22,4 @@ def make_request(method: str, endpoint: str, **kwargs): ) make_request("DELETE", "/api/post/1/delete") - print("\nAll requests completed successfully") + print_request_summary() diff --git a/drift/instrumentation/e2e_common/__init__.py b/drift/instrumentation/e2e_common/__init__.py index 2585a76..283be03 100644 --- a/drift/instrumentation/e2e_common/__init__.py +++ b/drift/instrumentation/e2e_common/__init__.py @@ -1,5 +1,12 @@ """Common utilities for Python SDK e2e tests.""" from .base_runner import Colors, E2ETestRunnerBase +from .test_utils import get_request_count, make_request, print_request_summary -__all__ = ["Colors", "E2ETestRunnerBase"] +__all__ = [ + "Colors", + "E2ETestRunnerBase", + "get_request_count", + "make_request", + "print_request_summary", +] diff --git a/drift/instrumentation/e2e_common/base_runner.py b/drift/instrumentation/e2e_common/base_runner.py index 17890aa..1b089e0 
100644 --- a/drift/instrumentation/e2e_common/base_runner.py +++ b/drift/instrumentation/e2e_common/base_runner.py @@ -45,6 +45,7 @@ def __init__(self, app_port: int = 8000): self.app_port = app_port self.app_process: subprocess.Popen | None = None self.exit_code = 0 + self.expected_request_count: int | None = None # Register signal handlers for cleanup signal.signal(signal.SIGTERM, self._signal_handler) @@ -76,6 +77,17 @@ def run_command(self, cmd: list[str], env: dict | None = None, check: bool = Tru return result + def _parse_request_count(self, output: str): + """Parse the request count from test_requests.py output.""" + for line in output.split("\n"): + if line.startswith("TOTAL_REQUESTS_SENT:"): + try: + count = int(line.split(":")[1]) + self.expected_request_count = count + self.log(f"Captured request count: {count}", Colors.GREEN) + except (ValueError, IndexError): + self.log(f"Failed to parse request count from: {line}", Colors.YELLOW) + def wait_for_service(self, check_cmd: list[str], timeout: int = 30, interval: int = 1) -> bool: """Wait for a service to become ready.""" elapsed = 0 @@ -138,8 +150,8 @@ def record_traces(self) -> bool: self.app_process = subprocess.Popen( ["python", "src/app.py"], env={**os.environ, **env}, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, text=True, ) @@ -164,7 +176,13 @@ def record_traces(self) -> bool: # Execute test requests self.log("Executing test requests...", Colors.GREEN) try: - self.run_command(["python", "src/test_requests.py"]) + # Pass PYTHONPATH so test_requests.py can import from e2e_common + result = self.run_command( + ["python", "src/test_requests.py"], + env={"PYTHONPATH": "/sdk"}, + ) + # Parse request count from output + self._parse_request_count(result.stdout) except subprocess.CalledProcessError: self.log("Test requests failed", Colors.RED) self.exit_code = 1 @@ -257,6 +275,7 @@ def parse_test_results(self, output: str): idx += 1 
all_passed = True + passed_count = 0 for result in results: test_id = result.get("test_id", "unknown") passed = result.get("passed", False) @@ -264,6 +283,7 @@ def parse_test_results(self, output: str): if passed: self.log(f"✓ Test ID: {test_id} (Duration: {duration}ms)", Colors.GREEN) + passed_count += 1 else: self.log(f"✗ Test ID: {test_id} (Duration: {duration}ms)", Colors.RED) all_passed = False @@ -276,6 +296,20 @@ def parse_test_results(self, output: str): self.log("Some tests failed!", Colors.RED) self.exit_code = 1 + # Validate request count matches passed tests + if self.expected_request_count is not None: + if passed_count < self.expected_request_count: + self.log( + f"✗ Request count mismatch: {passed_count} passed tests != {self.expected_request_count} requests sent", + Colors.RED, + ) + self.exit_code = 1 + else: + self.log( + f"✓ Request count validation: {passed_count} passed tests >= {self.expected_request_count} requests sent", + Colors.GREEN, + ) + except Exception as e: self.log(f"Failed to parse test results: {e}", Colors.RED) self.log(f"Raw output:\n{output}", Colors.YELLOW) diff --git a/drift/instrumentation/e2e_common/test_utils.py b/drift/instrumentation/e2e_common/test_utils.py new file mode 100644 index 0000000..951a774 --- /dev/null +++ b/drift/instrumentation/e2e_common/test_utils.py @@ -0,0 +1,37 @@ +""" +Shared test utilities for e2e tests. + +This module provides common functions used across all instrumentation e2e tests, +including request counting for validation. 
+""" + +import time + +import requests + +BASE_URL = "http://localhost:8000" +_request_count = 0 + + +def make_request(method, endpoint, **kwargs): + """Make HTTP request, log result, and track count.""" + global _request_count + _request_count += 1 + + url = f"{BASE_URL}{endpoint}" + print(f"→ {method} {endpoint}") + kwargs.setdefault("timeout", 30) + response = requests.request(method, url, **kwargs) + print(f" Status: {response.status_code}") + time.sleep(0.5) + return response + + +def get_request_count(): + """Return the current request count.""" + return _request_count + + +def print_request_summary(): + """Print the total request count in a parseable format.""" + print(f"\nTOTAL_REQUESTS_SENT:{_request_count}") diff --git a/drift/instrumentation/fastapi/e2e-tests/src/test_requests.py b/drift/instrumentation/fastapi/e2e-tests/src/test_requests.py index 743561a..e684b8b 100644 --- a/drift/instrumentation/fastapi/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/fastapi/e2e-tests/src/test_requests.py @@ -1,27 +1,9 @@ """Execute test requests against the FastAPI app.""" import json -import os -import time from pathlib import Path -import requests - -PORT = os.getenv("PORT", "8000") -BASE_URL = f"http://localhost:{PORT}" - - -def make_request(method: str, endpoint: str, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"→ {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = requests.request(method, url, **kwargs) - print(f" Status: {response.status_code}") - time.sleep(0.5) # Small delay between requests - return response +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary def verify_stack_traces(): @@ -180,3 +162,5 @@ def verify_context_propagation(): print("\n" + "=" * 50) print("All requests completed successfully") print("=" * 50) + + print_request_summary() diff --git 
a/drift/instrumentation/flask/e2e-tests/src/test_requests.py b/drift/instrumentation/flask/e2e-tests/src/test_requests.py index 6ab2ff8..8f0861b 100644 --- a/drift/instrumentation/flask/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/flask/e2e-tests/src/test_requests.py @@ -1,24 +1,6 @@ """Execute test requests against the Flask app.""" -import time - -import requests - -BASE_URL = "http://localhost:8000" - - -def make_request(method, endpoint, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"→ {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = requests.request(method, url, **kwargs) - print(f" Status: {response.status_code}") - time.sleep(0.5) # Small delay between requests - return response - +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary if __name__ == "__main__": print("Starting test request sequence...\n") @@ -32,4 +14,4 @@ def make_request(method, endpoint, **kwargs): make_request("POST", "/api/post", json={"title": "Test Post", "body": "This is a test post", "userId": 1}) make_request("DELETE", "/api/post/1") - print("\nAll requests completed successfully") + print_request_summary() diff --git a/drift/instrumentation/httpx/e2e-tests/src/test_requests.py b/drift/instrumentation/httpx/e2e-tests/src/test_requests.py index ebcbebd..a5b7b53 100644 --- a/drift/instrumentation/httpx/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/httpx/e2e-tests/src/test_requests.py @@ -1,24 +1,6 @@ """Execute test requests against the Flask app to exercise the HTTPX instrumentation.""" -import time - -import requests - -BASE_URL = "http://localhost:8000" - - -def make_request(method, endpoint, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"-> {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = 
requests.request(method, url, **kwargs) - print(f" Status: {response.status_code}") - time.sleep(0.5) # Small delay between requests - return response - +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary if __name__ == "__main__": print("Starting test request sequence for HTTPX instrumentation...\n") @@ -124,4 +106,4 @@ def make_request(method, endpoint, **kwargs): make_request("POST", "/test/file-like-body") - print("\nAll requests completed successfully") + print_request_summary() diff --git a/drift/instrumentation/httpx/instrumentation.py b/drift/instrumentation/httpx/instrumentation.py index c758ead..c13911a 100644 --- a/drift/instrumentation/httpx/instrumentation.py +++ b/drift/instrumentation/httpx/instrumentation.py @@ -806,6 +806,7 @@ def _try_get_mock_from_request_sync( input_value=input_value, kind=SpanKind.CLIENT, input_schema_merges=input_schema_merges, + is_pre_app_start=not sdk.app_ready, ) if not mock_response_output or not mock_response_output.found: diff --git a/drift/instrumentation/psycopg/e2e-tests/requirements.txt b/drift/instrumentation/psycopg/e2e-tests/requirements.txt index 5495053..67be944 100644 --- a/drift/instrumentation/psycopg/e2e-tests/requirements.txt +++ b/drift/instrumentation/psycopg/e2e-tests/requirements.txt @@ -1,4 +1,5 @@ -e /sdk # Mount point for drift SDK Flask>=3.1.2 psycopg[binary]>=3.2.1 +psycopg-pool>=3.2.0 requests>=2.32.5 diff --git a/drift/instrumentation/psycopg/e2e-tests/src/app.py b/drift/instrumentation/psycopg/e2e-tests/src/app.py index b12e33e..3be2bcf 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/app.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/app.py @@ -1,6 +1,7 @@ """Flask app with Psycopg (v3) operations for e2e testing.""" import os +from datetime import UTC import psycopg from flask import Flask, jsonify, request @@ -144,6 +145,875 @@ def db_transaction(): return jsonify({"error": str(e)}), 500 +@app.route("/test/cursor-stream") +def 
test_cursor_stream(): + """Test cursor.stream() - generator-based result streaming. + + This tests whether the instrumentation captures streaming queries + that return results as a generator. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Stream results row-by-row instead of fetchall + results = [] + for row in cur.stream("SELECT id, name, email FROM users ORDER BY id LIMIT 5"): + results.append({"id": row[0], "name": row[1], "email": row[2]}) + return jsonify({"count": len(results), "data": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/server-cursor") +def test_server_cursor(): + """Test ServerCursor (named cursor) - server-side cursor. + + This tests whether the instrumentation captures server-side cursors + which use DECLARE CURSOR on the database server. + """ + try: + with psycopg.connect(get_conn_string()) as conn: + # Named cursor creates a server-side cursor + with conn.cursor(name="test_server_cursor") as cur: + cur.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 5") + rows = cur.fetchall() + columns = [desc[0] for desc in cur.description] if cur.description else ["id", "name", "email"] + results = [dict(zip(columns, row, strict=False)) for row in rows] + return jsonify({"count": len(results), "data": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/copy-to") +def test_copy_to(): + """Test cursor.copy() with COPY TO - bulk data export. + + This tests whether the instrumentation captures COPY operations. 
+ """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Use COPY to export data + output = [] + with cur.copy("COPY (SELECT id, name, email FROM users ORDER BY id LIMIT 5) TO STDOUT") as copy: + for row in copy: + # Handle both bytes and memoryview + if isinstance(row, memoryview): + row = bytes(row) + output.append(row.decode("utf-8").strip()) + return jsonify({"count": len(output), "data": output}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/multiple-queries") +def test_multiple_queries(): + """Test multiple queries in same connection. + + This tests whether multiple queries in the same connection + are all captured and replayed correctly. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Query 1 + cur.execute("SELECT COUNT(*) FROM users") + count = cur.fetchone()[0] + + # Query 2 + cur.execute("SELECT MAX(id) FROM users") + max_id = cur.fetchone()[0] + + # Query 3 + cur.execute("SELECT MIN(id) FROM users") + min_id = cur.fetchone()[0] + + return jsonify({"count": count, "max_id": max_id, "min_id": min_id}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/pipeline-mode") +def test_pipeline_mode(): + """Test pipeline mode - batched operations. + + Pipeline mode allows sending multiple queries without waiting for results. + This tests whether the instrumentation handles pipeline mode correctly. 
+ """ + try: + with psycopg.connect(get_conn_string()) as conn: + # Enter pipeline mode + with conn.pipeline() as p: + cur1 = conn.execute("SELECT id, name FROM users ORDER BY id LIMIT 3") + cur2 = conn.execute("SELECT COUNT(*) FROM users") + # Sync the pipeline to get results + p.sync() + + rows1 = cur1.fetchall() + count = cur2.fetchone()[0] + + return jsonify({"rows": [{"id": r[0], "name": r[1]} for r in rows1], "count": count}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/dict-row-factory") +def test_dict_row_factory(): + """Test dict_row row factory. + + Tests whether the instrumentation correctly handles dict row factories + which return dictionaries instead of tuples. + """ + try: + from psycopg.rows import dict_row + + with psycopg.connect(get_conn_string(), row_factory=dict_row) as conn: + with conn.cursor() as cur: + cur.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 3") + rows = cur.fetchall() + + return jsonify( + { + "count": len(rows), + "data": rows, # Already dictionaries + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/namedtuple-row-factory") +def test_namedtuple_row_factory(): + """Test namedtuple_row row factory. + + Tests whether the instrumentation correctly handles namedtuple row factories. + """ + try: + from psycopg.rows import namedtuple_row + + with psycopg.connect(get_conn_string(), row_factory=namedtuple_row) as conn: + with conn.cursor() as cur: + cur.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 3") + rows = cur.fetchall() + + # Convert named tuples to dicts for JSON serialization + return jsonify({"count": len(rows), "data": [{"id": r.id, "name": r.name, "email": r.email} for r in rows]}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/cursor-iteration") +def test_cursor_iteration(): + """Test direct cursor iteration (for row in cursor). 
+ + Tests whether iterating over cursor directly works correctly. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + cur.execute("SELECT id, name FROM users ORDER BY id LIMIT 5") + + # Iterate directly over cursor + results = [] + for row in cur: + results.append({"id": row[0], "name": row[1]}) + + return jsonify({"count": len(results), "data": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/executemany-returning") +def test_executemany_returning(): + """Test executemany with returning=True. + + Tests whether the instrumentation correctly handles executemany with returning=True. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table + cur.execute("CREATE TEMP TABLE batch_test (id SERIAL, name TEXT)") + + # Use executemany with returning + params = [("Batch User 1",), ("Batch User 2",), ("Batch User 3",)] + cur.executemany("INSERT INTO batch_test (name) VALUES (%s) RETURNING id, name", params, returning=True) + + # Fetch results from each batch + results = [] + for result in cur.results(): + row = result.fetchone() + if row: + results.append({"id": row[0], "name": row[1]}) + + conn.commit() + + return jsonify({"count": len(results), "data": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/rownumber") +def test_rownumber(): + """Test cursor.rownumber property. + + Tests whether the rownumber property is properly tracked during replay mode. 
+ """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + cur.execute("SELECT id, name FROM users ORDER BY id LIMIT 5") + + positions = [] + # Record rownumber at each fetch + positions.append({"before": cur.rownumber}) + + cur.fetchone() + positions.append({"after_fetchone_1": cur.rownumber}) + + cur.fetchone() + positions.append({"after_fetchone_2": cur.rownumber}) + + cur.fetchmany(2) + positions.append({"after_fetchmany_2": cur.rownumber}) + + return jsonify({"positions": positions}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/statusmessage") +def test_statusmessage(): + """Test cursor.statusmessage property.""" + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # SELECT should return something like "SELECT 5" + cur.execute("SELECT id FROM users LIMIT 5") + select_status = cur.statusmessage + cur.fetchall() + + # INSERT should return something like "INSERT 0 1" + cur.execute( + "INSERT INTO users (name, email) VALUES (%s, %s) RETURNING id", ("StatusTest", "status@test.com") + ) + insert_status = cur.statusmessage + cur.fetchone() + + conn.rollback() # Don't actually insert + + return jsonify({"select_status": select_status, "insert_status": insert_status}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/nextset") +def test_nextset(): + """Test cursor.nextset() for multiple result sets. + + Tests whether the instrumentation correctly handles nextset() for multiple result sets. 
+ """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table + cur.execute("CREATE TEMP TABLE nextset_test (id SERIAL, val TEXT)") + + # Insert multiple rows with returning + cur.executemany( + "INSERT INTO nextset_test (val) VALUES (%s) RETURNING id, val", + [("First",), ("Second",), ("Third",)], + returning=True, + ) + + # Use nextset to iterate through result sets + results = [] + while True: + row = cur.fetchone() + if row: + results.append({"id": row[0], "val": row[1]}) + if cur.nextset() is None: + break + + conn.commit() + + return jsonify({"count": len(results), "data": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/cursor-scroll") +def test_cursor_scroll(): + """Test cursor.scroll() method. + + Tests whether the instrumentation correctly handles scroll() for cursor position tracking. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + cur.execute("SELECT id, name FROM users ORDER BY id") + + # Fetch first row + first = cur.fetchone() + + # Scroll back to start + cur.scroll(0, mode="absolute") + + # Fetch first row again + first_again = cur.fetchone() + + return jsonify( + { + "first": {"id": first[0], "name": first[1]} if first else None, + "first_again": {"id": first_again[0], "name": first_again[1]} if first_again else None, + "match": first == first_again, + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/server-cursor-scroll") +def test_server_cursor_scroll(): + """Test ServerCursor.scroll() method. + + Tests whether the instrumentation correctly handles scroll() for server-side cursors. 
+ """ + try: + with psycopg.connect(get_conn_string()) as conn: + # Named cursor with scrollable=True + with conn.cursor(name="scroll_test", scrollable=True) as cur: + cur.execute("SELECT id, name FROM users ORDER BY id") + + # Fetch first row + first = cur.fetchone() + + # Scroll back to start + cur.scroll(0, mode="absolute") + + # Fetch first row again + first_again = cur.fetchone() + + return jsonify( + { + "first": {"id": first[0], "name": first[1]} if first else None, + "first_again": {"id": first_again[0], "name": first_again[1]} if first_again else None, + "match": first == first_again, + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/cursor-reuse") +def test_cursor_reuse(): + """Test reusing cursor for multiple queries. + + Tests whether the instrumentation correctly handles reusing a cursor for multiple execute() calls. + """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # First query + cur.execute("SELECT id, name FROM users WHERE id = 1") + row1 = cur.fetchone() + + # Second query on same cursor + cur.execute("SELECT id, name FROM users WHERE id = 2") + row2 = cur.fetchone() + + # Third query + cur.execute("SELECT COUNT(*) FROM users") + count = cur.fetchone()[0] + + return jsonify( + { + "row1": {"id": row1[0], "name": row1[1]} if row1 else None, + "row2": {"id": row2[0], "name": row2[1]} if row2 else None, + "count": count, + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/sql-composed") +def test_sql_composed(): + """Test psycopg.sql.SQL() composed queries.""" + try: + from psycopg import sql + + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + table = sql.Identifier("users") + columns = sql.SQL(", ").join([sql.Identifier("id"), sql.Identifier("name"), sql.Identifier("email")]) + + query = sql.SQL("SELECT {} FROM {} ORDER BY id LIMIT 3").format(columns, table) + cur.execute(query) + rows = 
cur.fetchall() + + return jsonify({"count": len(rows), "data": [{"id": r[0], "name": r[1], "email": r[2]} for r in rows]}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/binary-uuid") +def test_binary_uuid(): + """Test binary UUID data type. + + Tests whether the instrumentation correctly handles binary UUID data types. + """ + try: + import uuid + + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create a temp table with UUID column + cur.execute("CREATE TEMP TABLE uuid_test (id UUID PRIMARY KEY, name TEXT)") + + # Insert a UUID + test_uuid = uuid.uuid4() + cur.execute("INSERT INTO uuid_test (id, name) VALUES (%s, %s) RETURNING id, name", (test_uuid, "UUID Test")) + inserted = cur.fetchone() + + # Query it back + cur.execute("SELECT id, name FROM uuid_test WHERE id = %s", (test_uuid,)) + queried = cur.fetchone() + + conn.commit() + + return jsonify( + { + "inserted_uuid": str(inserted[0]) if inserted else None, + "queried_uuid": str(queried[0]) if queried else None, + "match": str(inserted[0]) == str(queried[0]) if inserted and queried else False, + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/binary-bytea") +def test_binary_bytea(): + """Test binary bytea data type. + + Tests whether the instrumentation correctly handles binary bytea data types. 
+ """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create a temp table with bytea column + cur.execute("CREATE TEMP TABLE bytea_test (id SERIAL, data BYTEA)") + + # Insert binary data + test_data = b"\x00\x01\x02\x03\xff\xfe\xfd" + cur.execute("INSERT INTO bytea_test (data) VALUES (%s) RETURNING id, data", (test_data,)) + inserted = cur.fetchone() + + conn.commit() + + # Convert bytes to hex for JSON serialization + return jsonify( + { + "inserted_id": inserted[0] if inserted else None, + "data_hex": inserted[1].hex() if inserted and inserted[1] else None, + "data_length": len(inserted[1]) if inserted and inserted[1] else 0, + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/class-row-factory") +def test_class_row_factory(): + """Test class_row row factory. + + Tests whether the instrumentation correctly handles class_row factories + which return instances of a custom class. + """ + try: + from dataclasses import dataclass + + from psycopg.rows import class_row + + @dataclass + class User: + id: int + name: str + email: str + + with psycopg.connect(get_conn_string(), row_factory=class_row(User)) as conn: + with conn.cursor() as cur: + cur.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 3") + rows = cur.fetchall() + + return jsonify({"count": len(rows), "data": [{"id": r.id, "name": r.name, "email": r.email} for r in rows]}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/kwargs-row-factory") +def test_kwargs_row_factory(): + """Test kwargs_row row factory. + + Tests whether the instrumentation correctly handles kwargs_row factories + which call a function with keyword arguments. 
+ """ + try: + from psycopg.rows import kwargs_row + + def make_user_dict(**kwargs): + return {"user_data": kwargs, "processed": True} + + with psycopg.connect(get_conn_string(), row_factory=kwargs_row(make_user_dict)) as conn: + with conn.cursor() as cur: + cur.execute("SELECT id, name, email FROM users ORDER BY id LIMIT 3") + rows = cur.fetchall() + + return jsonify( + { + "count": len(rows), + "data": rows, # Already processed dictionaries + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/scalar-row-factory") +def test_scalar_row_factory(): + """Test scalar_row row factory. + + Tests whether the instrumentation correctly handles scalar_row factories + which return just the first column value. + """ + try: + from psycopg.rows import scalar_row + + with psycopg.connect(get_conn_string(), row_factory=scalar_row) as conn: + with conn.cursor() as cur: + cur.execute("SELECT name FROM users ORDER BY id LIMIT 5") + rows = cur.fetchall() + + return jsonify( + { + "count": len(rows), + "data": rows, # Just names as scalars + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/binary-format") +def test_binary_format(): + """Test execute with binary=True parameter. + + Tests whether the instrumentation correctly handles binary format transfers. 
+ """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Execute with binary=True + cur.execute("SELECT id, name FROM users ORDER BY id LIMIT 3", binary=True) + rows = cur.fetchall() + + return jsonify({"count": len(rows), "data": [{"id": r[0], "name": r[1]} for r in rows]}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/null-values") +def test_null_values(): + """Test handling of NULL values in results.""" + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table with nullable columns + cur.execute(""" + CREATE TEMP TABLE null_test ( + id INT, + nullable_text TEXT, + nullable_int INT, + nullable_bool BOOLEAN + ) + """) + + # Insert rows with NULL values + cur.execute(""" + INSERT INTO null_test VALUES + (1, 'has_value', 42, TRUE), + (2, NULL, NULL, NULL), + (3, 'another', NULL, FALSE) + """) + + # Query rows + cur.execute("SELECT * FROM null_test ORDER BY id") + rows = cur.fetchall() + conn.commit() + + return jsonify( + { + "count": len(rows), + "data": [ + {"id": r[0], "nullable_text": r[1], "nullable_int": r[2], "nullable_bool": r[3]} for r in rows + ], + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/transaction-context") +def test_transaction_context(): + """Test conn.transaction() context manager.""" + try: + results = [] + with psycopg.connect(get_conn_string()) as conn: + # Use explicit transaction + with conn.transaction(): + with conn.cursor() as cur: + cur.execute("CREATE TEMP TABLE tx_test (id INT, val TEXT)") + cur.execute("INSERT INTO tx_test VALUES (1, 'first')") + cur.execute("SELECT * FROM tx_test") + rows = cur.fetchall() + results.append({"phase": "inside_transaction", "rows": [list(r) for r in rows]}) + + # After transaction commit + with conn.cursor() as cur: + cur.execute("SELECT * FROM tx_test") + rows = cur.fetchall() + results.append({"phase": "after_commit", "rows": [list(r) 
for r in rows]}) + + return jsonify({"results": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/json-jsonb") +def test_json_jsonb(): + """Test JSON and JSONB data types.""" + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table with JSON columns + cur.execute(""" + CREATE TEMP TABLE json_test ( + id INT, + json_col JSON, + jsonb_col JSONB + ) + """) + + # Insert JSON data + import json + + test_json = {"name": "test", "values": [1, 2, 3], "nested": {"key": "value"}} + cur.execute("INSERT INTO json_test VALUES (%s, %s, %s)", (1, json.dumps(test_json), json.dumps(test_json))) + + # Query back + cur.execute("SELECT * FROM json_test WHERE id = 1") + row = cur.fetchone() + conn.commit() + + return jsonify({"id": row[0], "json_col": row[1], "jsonb_col": row[2]}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/array-types") +def test_array_types(): + """Test PostgreSQL array types.""" + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table with array columns + cur.execute(""" + CREATE TEMP TABLE array_test ( + id INT, + int_array INTEGER[], + text_array TEXT[] + ) + """) + + # Insert array data + cur.execute("INSERT INTO array_test VALUES (%s, %s, %s)", (1, [10, 20, 30], ["a", "b", "c"])) + + # Query back + cur.execute("SELECT * FROM array_test WHERE id = 1") + row = cur.fetchone() + conn.commit() + + return jsonify( + { + "id": row[0], + "int_array": list(row[1]) if row[1] else None, + "text_array": list(row[2]) if row[2] else None, + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/cursor-set-result") +def test_cursor_set_result(): + """Test cursor.set_result() method. + + This tests whether the instrumentation correctly handles + set_result() for navigating between result sets. 
+ """ + try: + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table and insert with returning + cur.execute("CREATE TEMP TABLE setresult_test (id SERIAL, val TEXT)") + + cur.executemany( + "INSERT INTO setresult_test (val) VALUES (%s) RETURNING id, val", + [("First",), ("Second",), ("Third",)], + returning=True, + ) + + # Use set_result to navigate to specific result sets + results = [] + + # Go to the last result set + cur.set_result(-1) + row = cur.fetchone() + results.append({"set": "last", "data": {"id": row[0], "val": row[1]} if row else None}) + + # Go to the first result set + cur.set_result(0) + row = cur.fetchone() + results.append({"set": "first", "data": {"id": row[0], "val": row[1]} if row else None}) + + conn.commit() + + return jsonify({"results": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/decimal-types") +def test_decimal_types(): + """Test Decimal/numeric types.""" + try: + from decimal import Decimal + + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table with numeric columns + cur.execute(""" + CREATE TEMP TABLE decimal_test ( + id INT, + price DECIMAL(10, 2), + rate DECIMAL(18, 8) + ) + """) + + # Insert decimal data + cur.execute("INSERT INTO decimal_test VALUES (%s, %s, %s)", (1, Decimal("123.45"), Decimal("0.00000001"))) + + # Query back + cur.execute("SELECT * FROM decimal_test WHERE id = 1") + row = cur.fetchone() + conn.commit() + + return jsonify( + {"id": row[0], "price": str(row[1]) if row[1] else None, "rate": str(row[2]) if row[2] else None} + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/date-time-types") +def test_date_time_types(): + """Test date/time types.""" + try: + from datetime import date, time, timedelta + + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table with date/time columns + cur.execute(""" + 
CREATE TEMP TABLE datetime_test ( + id INT, + birth_date DATE, + wake_time TIME, + duration INTERVAL + ) + """) + + # Insert date/time data + cur.execute( + "INSERT INTO datetime_test VALUES (%s, %s, %s, %s)", + (1, date(1990, 5, 15), time(8, 30, 0), timedelta(hours=2, minutes=30)), + ) + + # Query back + cur.execute("SELECT * FROM datetime_test WHERE id = 1") + row = cur.fetchone() + conn.commit() + + return jsonify( + { + "id": row[0], + "birth_date": str(row[1]) if row[1] else None, + "wake_time": str(row[2]) if row[2] else None, + "duration": str(row[3]) if row[3] else None, + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/inet-cidr-types") +def test_inet_cidr_types(): + """Test PostgreSQL inet/cidr network types.""" + try: + from ipaddress import IPv4Address, IPv4Network + + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table with network columns + cur.execute(""" + CREATE TEMP TABLE network_test ( + id INT, + ip_addr INET, + network CIDR + ) + """) + + # Insert network data + cur.execute("INSERT INTO network_test VALUES (%s, %s, %s)", (1, "192.168.1.100", "10.0.0.0/8")) + + # Query back + cur.execute("SELECT * FROM network_test WHERE id = 1") + row = cur.fetchone() + conn.commit() + + return jsonify( + {"id": row[0], "ip_addr": str(row[1]) if row[1] else None, "network": str(row[2]) if row[2] else None} + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/range-types") +def test_range_types(): + """Test PostgreSQL range types.""" + try: + from psycopg.types.range import Range + + with psycopg.connect(get_conn_string()) as conn, conn.cursor() as cur: + # Create temp table with range columns + cur.execute(""" + CREATE TEMP TABLE range_test ( + id INT, + int_range INT4RANGE, + ts_range TSRANGE + ) + """) + + # Insert range data + from datetime import datetime + + int_range = Range(1, 10) + ts_range = Range(datetime(2024, 1, 1, 0, 0), 
datetime(2024, 12, 31, 23, 59)) + + cur.execute("INSERT INTO range_test VALUES (%s, %s, %s)", (1, int_range, ts_range)) + + # Query back + cur.execute("SELECT * FROM range_test WHERE id = 1") + row = cur.fetchone() + conn.commit() + + return jsonify( + {"id": row[0], "int_range": str(row[1]) if row[1] else None, "ts_range": str(row[2]) if row[2] else None} + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + if __name__ == "__main__": sdk.mark_app_as_ready() app.run(host="0.0.0.0", port=8000, debug=False) diff --git a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py index 629647c..f861884 100644 --- a/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/psycopg/e2e-tests/src/test_requests.py @@ -1,24 +1,6 @@ """Execute test requests against the Psycopg Flask app.""" -import time - -import requests - -BASE_URL = "http://localhost:8000" - - -def make_request(method, endpoint, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"→ {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = requests.request(method, url, **kwargs) - print(f" Status: {response.status_code}") - time.sleep(0.5) # Small delay between requests - return response - +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary if __name__ == "__main__": print("Starting Psycopg test request sequence...\n") @@ -64,4 +46,46 @@ def make_request(method, endpoint, **kwargs): if user_id: make_request("DELETE", f"/db/delete/{user_id}") - print("\nAll requests completed successfully") + make_request("GET", "/test/cursor-stream") + make_request("GET", "/test/server-cursor") + make_request("GET", "/test/copy-to") + make_request("GET", "/test/multiple-queries") + make_request("GET", "/test/pipeline-mode") + make_request("GET", "/test/dict-row-factory") + 
make_request("GET", "/test/namedtuple-row-factory") + make_request("GET", "/test/cursor-iteration") + make_request("GET", "/test/executemany-returning") + make_request("GET", "/test/rownumber") + make_request("GET", "/test/statusmessage") + make_request("GET", "/test/nextset") + make_request("GET", "/test/server-cursor-scroll") + make_request("GET", "/test/cursor-scroll") + make_request("GET", "/test/cursor-reuse") + make_request("GET", "/test/sql-composed") + make_request("GET", "/test/binary-uuid") + make_request("GET", "/test/binary-bytea") + make_request("GET", "/test/class-row-factory") + make_request("GET", "/test/kwargs-row-factory") + make_request("GET", "/test/scalar-row-factory") + make_request("GET", "/test/binary-format") + + # Test: NULL values handling (integrated into E2E suite) + make_request("GET", "/test/null-values") + + # Test: Transaction context manager + make_request("GET", "/test/transaction-context") + + # JSON/JSONB and array types tests + make_request("GET", "/test/json-jsonb") + make_request("GET", "/test/array-types") + make_request("GET", "/test/cursor-set-result") + + # These tests expose hash mismatch bugs with Decimal and date/time types + make_request("GET", "/test/decimal-types") + make_request("GET", "/test/date-time-types") + + # These tests expose serialization bugs with inet/cidr and range types + make_request("GET", "/test/inet-cidr-types") + make_request("GET", "/test/range-types") + + print_request_summary() diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 08bb4d5..cb30a0c 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -2,7 +2,9 @@ import json import logging -import time +import weakref +from collections.abc import Iterator +from contextlib import contextmanager from types import ModuleType from typing import Any @@ -11,214 +13,27 @@ from opentelemetry.trace import Status from 
opentelemetry.trace import StatusCode as OTelStatusCode -from ...core.communication.types import MockRequestInput from ...core.drift_sdk import TuskDrift from ...core.json_schema_helper import JsonSchemaHelper from ...core.mode_utils import handle_record_mode, handle_replay_mode from ...core.tracing import TdSpanAttributes from ...core.tracing.span_utils import CreateSpanOptions, SpanUtils from ...core.types import ( - CleanSpanData, - Duration, PackageType, SpanKind, - SpanStatus, - StatusCode, - Timestamp, TuskDriftMode, - replay_trace_id_context, ) from ..base import InstrumentationBase from ..utils.psycopg_utils import deserialize_db_value +from ..utils.serialization import serialize_value +from .mocks import MockConnection, MockCopy +from .wrappers import TracedCopyWrapper logger = logging.getLogger(__name__) _instance: PsycopgInstrumentation | None = None -class MockLoader: - """Mock loader for psycopg3.""" - - def __init__(self): - self.timezone = None # Django expects this attribute - - def __call__(self, data): - """No-op load function.""" - return data - - -class MockDumper: - """Mock dumper for psycopg3.""" - - def __call__(self, obj): - """No-op dump function.""" - return str(obj).encode("utf-8") - - -class MockAdapters: - """Mock adapters for psycopg3 connection.""" - - def get_loader(self, oid, format): - """Return a mock loader.""" - return MockLoader() - - def get_dumper(self, obj, format): - """Return a mock dumper.""" - return MockDumper() - - def register_loader(self, oid, loader): - """No-op register loader for Django compatibility.""" - pass - - def register_dumper(self, oid, dumper): - """No-op register dumper for Django compatibility.""" - pass - - -class MockConnection: - """Mock database connection for REPLAY mode when postgres is not available. - - Provides minimal interface for Django/Flask to work without a real database. - All queries are mocked at the cursor.execute() level. 
- """ - - def __init__(self, sdk: TuskDrift, instrumentation: PsycopgInstrumentation, cursor_factory): - self.sdk = sdk - self.instrumentation = instrumentation - self.cursor_factory = cursor_factory - self.closed = False - self.autocommit = False - - # Django/psycopg3 requires these for connection initialization - self.isolation_level = None - self.encoding = "UTF8" - self.adapters = MockAdapters() - self.pgconn = None # Mock pg connection object - - # Create a comprehensive mock info object for Django - class MockInfo: - vendor = "postgresql" - server_version = 150000 # PostgreSQL 15.0 as integer - encoding = "UTF8" - - def parameter_status(self, param): - """Return mock parameter status.""" - if param == "TimeZone": - return "UTC" - elif param == "server_version": - return "15.0" - return None - - self.info = MockInfo() - - logger.debug("[MOCK_CONNECTION] Created mock connection for REPLAY mode (psycopg3)") - - def cursor(self, name=None, cursor_factory=None): - """Create a cursor using the instrumented cursor factory.""" - # For mock connections, we create a MockCursor directly - cursor = MockCursor(self) - - # Wrap execute/executemany for mock cursor - instrumentation = self.instrumentation - sdk = self.sdk - - def mock_execute(query, params=None, **kwargs): - # For mock cursor, original_execute is just a no-op - def noop_execute(q, p, **kw): - return cursor - - return instrumentation._traced_execute(cursor, noop_execute, sdk, query, params, **kwargs) - - def mock_executemany(query, params_seq, **kwargs): - # For mock cursor, original_executemany is just a no-op - def noop_executemany(q, ps, **kw): - return cursor - - return instrumentation._traced_executemany(cursor, noop_executemany, sdk, query, params_seq, **kwargs) - - # Monkey-patch mock functions onto cursor - cursor.execute = mock_execute # type: ignore[method-assign] - cursor.executemany = mock_executemany # type: ignore[method-assign] - - logger.debug("[MOCK_CONNECTION] Created cursor (psycopg3)") - 
return cursor - - def commit(self): - """Mock commit - no-op in REPLAY mode.""" - logger.debug("[MOCK_CONNECTION] commit() called (no-op)") - pass - - def rollback(self): - """Mock rollback - no-op in REPLAY mode.""" - logger.debug("[MOCK_CONNECTION] rollback() called (no-op)") - pass - - def close(self): - """Mock close - no-op in REPLAY mode.""" - logger.debug("[MOCK_CONNECTION] close() called (no-op)") - self.closed = True - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is not None: - self.rollback() - else: - self.commit() - return False - - -class MockCursor: - """Mock cursor for when we can't create a real cursor from base class. - - This is a fallback when the connection is completely mocked. - """ - - def __init__(self, connection): - self.connection = connection - self.rowcount = -1 - self._tusk_description = None # Store mock description - self.arraysize = 1 - self._mock_rows = [] - self._mock_index = 0 - self.adapters = MockAdapters() # Django needs this - logger.debug("[MOCK_CURSOR] Created fallback mock cursor (psycopg3)") - - @property - def description(self): - return self._tusk_description - - def execute(self, query, params=None, **kwargs): - """Will be replaced by instrumentation.""" - logger.debug(f"[MOCK_CURSOR] execute() called: {query[:100]}") - return self - - def executemany(self, query, params_seq, **kwargs): - """Will be replaced by instrumentation.""" - logger.debug(f"[MOCK_CURSOR] executemany() called: {query[:100]}") - return self - - def fetchone(self): - return None - - def fetchmany(self, size=None): - return [] - - def fetchall(self): - return [] - - def close(self): - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - return False - - class PsycopgInstrumentation(InstrumentationBase): """Instrumentation for psycopg (psycopg3) PostgreSQL client library. 
@@ -234,6 +49,8 @@ def __init__(self, enabled: bool = True) -> None: enabled=enabled, ) self._original_connect = None + # Track pending pipeline spans per connection for deferred finalization + self._pending_pipeline_spans: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() _instance = self def patch(self, module: ModuleType) -> None: @@ -256,13 +73,22 @@ def patched_connect(*args, **kwargs): return original_connect(*args, **kwargs) user_cursor_factory = kwargs.pop("cursor_factory", None) + user_row_factory = kwargs.pop("row_factory", None) cursor_factory = instrumentation._create_cursor_factory(sdk, user_cursor_factory) + # Create server cursor factory for named cursors (conn.cursor(name="...")) + server_cursor_factory = instrumentation._create_server_cursor_factory(sdk) + # In REPLAY mode, try to connect but fall back to mock connection if DB is unavailable if sdk.mode == TuskDriftMode.REPLAY: try: kwargs["cursor_factory"] = cursor_factory + if user_row_factory is not None: + kwargs["row_factory"] = user_row_factory connection = original_connect(*args, **kwargs) + # Set server cursor factory on the connection for named cursors + if server_cursor_factory: + connection.server_cursor_factory = server_cursor_factory logger.info("[PATCHED_CONNECT] REPLAY mode: Successfully connected to database (psycopg3)") return connection except Exception as e: @@ -270,17 +96,69 @@ def patched_connect(*args, **kwargs): f"[PATCHED_CONNECT] REPLAY mode: Database connection failed ({e}), using mock connection (psycopg3)" ) # Return mock connection that doesn't require a real database - return MockConnection(sdk, instrumentation, cursor_factory) + return MockConnection(sdk, instrumentation, cursor_factory, row_factory=user_row_factory) # In RECORD mode, always require real connection kwargs["cursor_factory"] = cursor_factory + if user_row_factory is not None: + kwargs["row_factory"] = user_row_factory connection = original_connect(*args, **kwargs) + # Set server cursor factory on 
the connection for named cursors + if server_cursor_factory: + connection.server_cursor_factory = server_cursor_factory logger.debug("[PATCHED_CONNECT] RECORD mode: Connected to database (psycopg3)") return connection module.connect = patched_connect # type: ignore[attr-defined] logger.debug("psycopg.connect instrumented") + # Patch Pipeline class for pipeline mode support + self._patch_pipeline_class(module) + + def _patch_pipeline_class(self, module: ModuleType) -> None: + """Patch psycopg.Pipeline to finalize spans on sync/exit.""" + try: + from psycopg import Pipeline + except ImportError: + logger.debug("psycopg.Pipeline not available, skipping pipeline instrumentation") + return + + instrumentation = self + + # Store originals for potential unpatch + self._original_pipeline_sync = getattr(Pipeline, "sync", None) + self._original_pipeline_exit = getattr(Pipeline, "__exit__", None) + + if self._original_pipeline_sync: + original_sync = self._original_pipeline_sync + + def patched_sync(pipeline_self): + """Patched Pipeline.sync that finalizes pending spans.""" + result = original_sync(pipeline_self) + # _conn is the connection associated with the pipeline + conn = getattr(pipeline_self, "_conn", None) + if conn: + instrumentation._finalize_pending_pipeline_spans(conn) + return result + + Pipeline.sync = patched_sync + logger.debug("psycopg.Pipeline.sync instrumented") + + if self._original_pipeline_exit: + original_exit = self._original_pipeline_exit + + def patched_exit(pipeline_self, exc_type, exc_val, exc_tb): + """Patched Pipeline.__exit__ that finalizes any remaining spans.""" + result = original_exit(pipeline_self, exc_type, exc_val, exc_tb) + # Finalize any remaining pending spans (handles implicit sync on exit) + conn = getattr(pipeline_self, "_conn", None) + if conn: + instrumentation._finalize_pending_pipeline_spans(conn) + return result + + Pipeline.__exit__ = patched_exit + logger.debug("psycopg.Pipeline.__exit__ instrumented") + def 
_create_cursor_factory(self, sdk: TuskDrift, base_factory=None): """Create a cursor factory that wraps cursors with instrumentation. @@ -300,6 +178,8 @@ def _create_cursor_factory(self, sdk: TuskDrift, base_factory=None): base = base_factory or BaseCursor class InstrumentedCursor(base): # type: ignore + """Instrumented cursor with tracing support.""" + _tusk_description = None # Store mock description for replay mode @property @@ -309,14 +189,161 @@ def description(self): return self._tusk_description return super().description + @property + def rownumber(self): + # In captured mode (after fetchall in _finalize_query_span), return tracked index + if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: + return self._tusk_index + # In replay mode with mock data, return mock index + if hasattr(self, "_mock_rows") and self._mock_rows is not None: + return self._mock_index + # Otherwise, return real cursor's rownumber + return super().rownumber + + @property + def statusmessage(self): + # In replay mode with mock data, return mock statusmessage + if hasattr(self, "_mock_statusmessage"): + return self._mock_statusmessage + # Otherwise, return real cursor's statusmessage + return super().statusmessage + + def __iter__(self): + # Support direct cursor iteration (for row in cursor) + # In replay mode with mock data (_mock_rows) or record mode with captured data (_tusk_rows) + if hasattr(self, "_mock_rows") and self._mock_rows is not None: + return self + if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: + return self + return super().__iter__() + + def __next__(self): + # In replay mode with mock data, iterate over mock rows + if hasattr(self, "_mock_rows") and self._mock_rows is not None: + if self._mock_index < len(self._mock_rows): + row = self._mock_rows[self._mock_index] + self._mock_index += 1 + # Apply row transformation if fetchone is patched + if hasattr(self, "fetchone") and callable(self.fetchone): + # Reset index, get transformed row, 
restore index + self._mock_index -= 1 + result = self.fetchone() + return result + return tuple(row) if isinstance(row, list) else row + raise StopIteration + # In record mode with captured data, iterate over stored rows + if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: + if self._tusk_index < len(self._tusk_rows): + row = self._tusk_rows[self._tusk_index] + self._tusk_index += 1 + return row + raise StopIteration + return super().__next__() + def execute(self, query, params=None, **kwargs): return instrumentation._traced_execute(self, super().execute, sdk, query, params, **kwargs) def executemany(self, query, params_seq, **kwargs): return instrumentation._traced_executemany(self, super().executemany, sdk, query, params_seq, **kwargs) + def stream(self, query, params=None, **kwargs): + return instrumentation._traced_stream(self, super().stream, sdk, query, params, **kwargs) + + def copy(self, query, params=None, **kwargs): + return instrumentation._traced_copy(self, super().copy, sdk, query, params, **kwargs) + return InstrumentedCursor + def _create_server_cursor_factory(self, sdk: TuskDrift, base_factory=None): + """Create a server cursor factory that wraps ServerCursor with instrumentation. + + Returns a cursor CLASS (psycopg3 expects a class, not a function). + ServerCursor is used when conn.cursor(name="...") is called. + """ + instrumentation = self + logger.debug(f"[CURSOR_FACTORY] Creating server cursor factory, sdk.mode={sdk.mode}") + + try: + from psycopg import ServerCursor as BaseServerCursor + except ImportError: + logger.warning("[CURSOR_FACTORY] Could not import psycopg.ServerCursor") + return None + + base = base_factory or BaseServerCursor + + class InstrumentedServerCursor(base): # type: ignore + """Instrumented server cursor with tracing support. + + Note: ServerCursor doesn't support executemany(). + Note: ServerCursor has stream-like iteration via fetchmany/itersize. 
+ """ + + _tusk_description = None # Store mock description for replay mode + + @property + def description(self): + # In replay mode, return mock description if set; otherwise use base + if self._tusk_description is not None: + return self._tusk_description + return super().description + + @property + def rownumber(self): + # In captured mode (after fetchall in _finalize_query_span), return tracked index + if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: + return self._tusk_index + # In replay mode with mock data, return mock index + if hasattr(self, "_mock_rows") and self._mock_rows is not None: + return self._mock_index + # Otherwise, return real cursor's rownumber + return super().rownumber + + @property + def statusmessage(self): + # In replay mode with mock data, return mock statusmessage + if hasattr(self, "_mock_statusmessage"): + return self._mock_statusmessage + # Otherwise, return real cursor's statusmessage + return super().statusmessage + + def __iter__(self): + # Support direct cursor iteration (for row in cursor) + # In replay mode with mock data (_mock_rows) or record mode with captured data (_tusk_rows) + if hasattr(self, "_mock_rows") and self._mock_rows is not None: + return self + if hasattr(self, "_tusk_rows") and self._tusk_rows is not None: + return self + return super().__iter__() + + def __next__(self): + # In replay mode with mock data, iterate over mock rows + if hasattr(self, "_mock_rows") and self._mock_rows is not None: + if self._mock_index < len(self._mock_rows): + row = self._mock_rows[self._mock_index] + self._mock_index += 1 + # Apply row transformation if fetchone is patched + if hasattr(self, "fetchone") and callable(self.fetchone): + # Reset index, get transformed row, restore index + self._mock_index -= 1 + result = self.fetchone() + return result + return tuple(row) if isinstance(row, list) else row + raise StopIteration + # In record mode with captured data, iterate over stored rows + if hasattr(self, 
"_tusk_rows") and self._tusk_rows is not None: + if self._tusk_index < len(self._tusk_rows): + row = self._tusk_rows[self._tusk_index] + self._tusk_index += 1 + return row + raise StopIteration + return super().__next__() + + def execute(self, query, params=None, **kwargs): + # Note: ServerCursor.execute() doesn't support 'prepare' parameter + return instrumentation._traced_execute(self, super().execute, sdk, query, params, **kwargs) + + return InstrumentedServerCursor + def _traced_execute( self, cursor: Any, original_execute: Any, sdk: TuskDrift, query: str, params=None, **kwargs ) -> Any: @@ -350,29 +377,13 @@ def _noop_execute(self, cursor: Any) -> Any: def _replay_execute(self, cursor: Any, sdk: TuskDrift, query_str: str, params: Any) -> Any: """Handle REPLAY mode for execute - fetch mock from CLI.""" - span_info = SpanUtils.create_span( - CreateSpanOptions( - name="psycopg.query", - kind=OTelSpanKind.CLIENT, - attributes={ - TdSpanAttributes.NAME: "psycopg.query", - TdSpanAttributes.PACKAGE_NAME: "psycopg", - TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", - TdSpanAttributes.SUBMODULE_NAME: "query", - TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, - TdSpanAttributes.IS_PRE_APP_START: not sdk.app_ready, - }, - is_pre_app_start=not sdk.app_ready, - ) - ) + span_info = self._create_query_span(sdk, "query") if not span_info: raise RuntimeError("Error creating span in replay mode") with SpanUtils.with_span(span_info): - mock_result = self._try_get_mock( - sdk, query_str, params, span_info.trace_id, span_info.span_id, span_info.parent_span_id - ) + mock_result = self._try_get_mock(sdk, query_str, params, span_info.trace_id, span_info.span_id) if mock_result is None: is_pre_app_start = not sdk.app_ready @@ -398,21 +409,19 @@ def _record_execute( kwargs: dict, ) -> Any: """Handle RECORD mode for execute - create span and execute query.""" - span_info = SpanUtils.create_span( - CreateSpanOptions( - name="psycopg.query", - 
kind=OTelSpanKind.CLIENT, - attributes={ - TdSpanAttributes.NAME: "psycopg.query", - TdSpanAttributes.PACKAGE_NAME: "psycopg", - TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", - TdSpanAttributes.SUBMODULE_NAME: "query", - TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, - TdSpanAttributes.IS_PRE_APP_START: is_pre_app_start, - }, - is_pre_app_start=is_pre_app_start, - ) - ) + # Reset cursor state from any previous execute() on this cursor. + # Delete instance attribute overrides to expose original class methods. + # This is safer than saving/restoring bound methods which can become stale. + if hasattr(cursor, "_tusk_patched"): + # Remove patched instance attributes to expose class methods + for attr in ("fetchone", "fetchmany", "fetchall", "scroll"): + if attr in cursor.__dict__: + delattr(cursor, attr) + cursor._tusk_rows = None + cursor._tusk_index = 0 + del cursor._tusk_patched + + span_info = self._create_query_span(sdk, "query", is_pre_app_start) if not span_info: # Fallback to original call if span creation fails @@ -421,6 +430,9 @@ def _record_execute( error = None result = None + # Check if we're in pipeline mode BEFORE executing + in_pipeline_mode = self._is_in_pipeline_mode(cursor) + with SpanUtils.with_span(span_info): try: result = original_execute(query, params, **kwargs) @@ -429,14 +441,36 @@ def _record_execute( error = e raise finally: - self._finalize_query_span( - span_info.span, - cursor, - query_str, - params, - error, - ) - span_info.span.end() + try: + if error is not None: + # Always finalize immediately on error + self._finalize_query_span(span_info.span, cursor, query_str, params, error) + span_info.span.end() + elif in_pipeline_mode: + # Defer finalization until pipeline.sync() + connection = self._get_connection_from_cursor(cursor) + if connection: + self._add_pending_pipeline_span(connection, span_info, cursor, query_str, params) + # DON'T end span here - will be ended in _finalize_pending_pipeline_spans + else: + # 
Fallback: finalize immediately if we can't get connection + self._finalize_query_span(span_info.span, cursor, query_str, params, None) + span_info.span.end() + else: + # Normal mode: finalize immediately (unless lazy capture was set up) + span_finalized = self._finalize_query_span(span_info.span, cursor, query_str, params, None) + if span_finalized: + # Span was fully finalized, end it now + span_info.span.end() + # If span_finalized is False, lazy capture was set up and span will be + # ended when user code calls a fetch method + except Exception as e: + logger.error(f"Error in span finalization: {e}") + # Ensure span is ended even if finalization fails + try: + span_info.span.end() + except Exception: + pass def _traced_executemany( self, cursor: Any, original_executemany: Any, sdk: TuskDrift, query: str, params_seq, **kwargs @@ -448,10 +482,12 @@ def _traced_executemany( query_str = self._query_to_string(query, cursor) # Convert to list BEFORE executing to avoid iterator exhaustion params_list = list(params_seq) + # Detect returning flag for executemany with RETURNING clause + returning = kwargs.get("returning", False) if sdk.mode == TuskDriftMode.REPLAY: return handle_replay_mode( - replay_mode_handler=lambda: self._replay_executemany(cursor, sdk, query_str, params_list), + replay_mode_handler=lambda: self._replay_executemany(cursor, sdk, query_str, params_list, returning), no_op_request_handler=lambda: self._noop_execute(cursor), is_server_request=False, ) @@ -460,36 +496,27 @@ def _traced_executemany( return handle_record_mode( original_function_call=lambda: original_executemany(query, params_list, **kwargs), record_mode_handler=lambda is_pre_app_start: self._record_executemany( - cursor, original_executemany, sdk, query, query_str, params_list, is_pre_app_start, kwargs + cursor, original_executemany, sdk, query, query_str, params_list, is_pre_app_start, kwargs, returning ), span_kind=OTelSpanKind.CLIENT, ) - def _replay_executemany(self, cursor: Any, sdk: 
TuskDrift, query_str: str, params_list: list) -> Any: + def _replay_executemany( + self, cursor: Any, sdk: TuskDrift, query_str: str, params_list: list, returning: bool = False + ) -> Any: """Handle REPLAY mode for executemany - fetch mock from CLI.""" - span_info = SpanUtils.create_span( - CreateSpanOptions( - name="psycopg.query", - kind=OTelSpanKind.CLIENT, - attributes={ - TdSpanAttributes.NAME: "psycopg.query", - TdSpanAttributes.PACKAGE_NAME: "psycopg", - TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", - TdSpanAttributes.SUBMODULE_NAME: "query", - TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, - TdSpanAttributes.IS_PRE_APP_START: not sdk.app_ready, - }, - is_pre_app_start=not sdk.app_ready, - ) - ) + span_info = self._create_query_span(sdk, "query") if not span_info: raise RuntimeError("Error creating span in replay mode") with SpanUtils.with_span(span_info): - mock_result = self._try_get_mock( - sdk, query_str, {"_batch": params_list}, span_info.trace_id, span_info.span_id, span_info.parent_span_id - ) + # Include returning flag in parameters for mock matching + params_for_mock = {"_batch": params_list} + if returning: + params_for_mock["_returning"] = True + + mock_result = self._try_get_mock(sdk, query_str, params_for_mock, span_info.trace_id, span_info.span_id) if mock_result is None: is_pre_app_start = not sdk.app_ready @@ -502,7 +529,13 @@ def _replay_executemany(self, cursor: Any, sdk: TuskDrift, query_str: str, param f"Query: {query_str[:100]}..." 
) - self._mock_execute_with_data(cursor, mock_result) + # Check if this is executemany_returning format (multiple result sets) + if mock_result.get("executemany_returning"): + self._mock_executemany_returning_with_data(cursor, mock_result) + else: + # Backward compatible: use existing single result set handling + self._mock_execute_with_data(cursor, mock_result) + span_info.span.end() return cursor @@ -516,23 +549,10 @@ def _record_executemany( params_list: list, is_pre_app_start: bool, kwargs: dict, + returning: bool = False, ) -> Any: """Handle RECORD mode for executemany - create span and execute query.""" - span_info = SpanUtils.create_span( - CreateSpanOptions( - name="psycopg.query", - kind=OTelSpanKind.CLIENT, - attributes={ - TdSpanAttributes.NAME: "psycopg.query", - TdSpanAttributes.PACKAGE_NAME: "psycopg", - TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", - TdSpanAttributes.SUBMODULE_NAME: "query", - TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, - TdSpanAttributes.IS_PRE_APP_START: is_pre_app_start, - }, - is_pre_app_start=is_pre_app_start, - ) - ) + span_info = self._create_query_span(sdk, "query", is_pre_app_start) if not span_info: # Fallback to original call if span creation fails @@ -549,164 +569,1038 @@ def _record_executemany( error = e raise finally: - self._finalize_query_span( - span_info.span, - cursor, - query_str, - {"_batch": params_list}, - error, - ) - span_info.span.end() + try: + if returning and error is None: + # Use specialized method for executemany with returning=True + self._finalize_executemany_returning_span( + span_info.span, + cursor, + query_str, + {"_batch": params_list, "_returning": True}, + error, + ) + span_info.span.end() + else: + # Existing behavior for executemany without returning + span_finalized = self._finalize_query_span( + span_info.span, + cursor, + query_str, + {"_batch": params_list}, + error, + ) + if span_finalized: + span_info.span.end() + # Note: executemany without returning 
typically has no results, + # so lazy capture is unlikely but we handle it for safety + except Exception as e: + logger.error(f"Error in span finalization: {e}") + # Ensure span is ended even if finalization fails + try: + span_info.span.end() + except Exception: + pass + + def _traced_stream( + self, cursor: Any, original_stream: Any, sdk: TuskDrift, query: str, params=None, **kwargs + ) -> Any: + """Traced cursor.stream method.""" + if sdk.mode == TuskDriftMode.DISABLED: + return original_stream(query, params, **kwargs) - def _query_to_string(self, query: Any, cursor: Any) -> str: - """Convert query to string.""" - try: - from psycopg.sql import Composed + query_str = self._query_to_string(query, cursor) - if isinstance(query, Composed): - return query.as_string(cursor) - except ImportError: - pass + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: self._replay_stream(cursor, sdk, query_str, params), + no_op_request_handler=lambda: iter([]), # Empty iterator for background requests + is_server_request=False, + ) - return str(query) if not isinstance(query, str) else query + # RECORD mode + return handle_record_mode( + original_function_call=lambda: original_stream(query, params, **kwargs), + record_mode_handler=lambda is_pre_app_start: self._record_stream( + cursor, original_stream, sdk, query, query_str, params, is_pre_app_start, kwargs + ), + span_kind=OTelSpanKind.CLIENT, + ) - def _try_get_mock( + def _record_stream( self, + cursor: Any, + original_stream: Any, sdk: TuskDrift, query: str, + query_str: str, params: Any, - trace_id: str, - span_id: str, - parent_span_id: str | None, - ) -> dict[str, Any] | None: - """Try to get a mocked response from CLI. 
- - Returns: - Mocked response data if found, None otherwise - """ - try: - # Build input value - input_value = { - "query": query.strip(), - } - if params is not None: - input_value["parameters"] = params - - # Generate schema and hashes for CLI matching - input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) + is_pre_app_start: bool, + kwargs: dict, + ): + """Handle RECORD mode for stream - wrap generator with tracing.""" + span_info = self._create_query_span(sdk, "query", is_pre_app_start) - # Create mock span for matching - timestamp_ms = time.time() * 1000 - timestamp_seconds = int(timestamp_ms // 1000) - timestamp_nanos = int((timestamp_ms % 1000) * 1_000_000) + if not span_info: + yield from original_stream(query, params, **kwargs) + return - # Create mock span for matching - # NOTE: Schemas must be None to avoid betterproto map serialization issues - # The CLI only needs the hashes for matching anyway, not the full schemas - mock_span = CleanSpanData( - trace_id=trace_id, - span_id=span_id, - parent_span_id=parent_span_id or "", - name="psycopg.query", - package_name="psycopg", - package_type=PackageType.PG, - instrumentation_name="PsycopgInstrumentation", - submodule_name="query", - input_value=input_value, - output_value=None, - input_schema=None, # type: ignore - Must be None to avoid betterproto serialization issues - output_schema=None, # type: ignore - Must be None to avoid betterproto serialization issues - input_schema_hash=input_result.decoded_schema_hash, - output_schema_hash="", - input_value_hash=input_result.decoded_value_hash, - output_value_hash="", - stack_trace="", # Empty in REPLAY mode - kind=SpanKind.CLIENT, - status=SpanStatus(code=StatusCode.OK, message=""), - timestamp=Timestamp(seconds=timestamp_seconds, nanos=timestamp_nanos), - duration=Duration(seconds=0, nanos=0), - is_root_span=False, - is_pre_app_start=not sdk.app_ready, - ) + rows_collected = [] + error = None - # Request mock from CLI - replay_trace_id = 
replay_trace_id_context.get() + try: + with SpanUtils.with_span(span_info): + for row in original_stream(query, params, **kwargs): + rows_collected.append(row) + yield row + except Exception as e: + error = e + raise + finally: + try: + self._finalize_stream_span(span_info.span, cursor, query_str, params, rows_collected, error) + except Exception as e: + logger.error(f"Error in stream span finalization: {e}") + try: + span_info.span.end() + except Exception: + pass + span_info.span.end() - mock_request = MockRequestInput( - test_id=replay_trace_id or "", - outbound_span=mock_span, - ) + def _replay_stream(self, cursor: Any, sdk: TuskDrift, query_str: str, params: Any): + """Handle REPLAY mode for stream - return mock generator.""" + span_info = self._create_query_span(sdk, "query") - logger.debug(f"Requesting mock from CLI for query: {query[:50]}...") - mock_response_output = sdk.request_mock_sync(mock_request) - logger.debug(f"CLI returned: found={mock_response_output.found}") + if not span_info: + raise RuntimeError("Error creating span in replay mode") - if not mock_response_output.found: - logger.debug(f"No mock found for psycopg query: {query[:100]}") - return None + with SpanUtils.with_span(span_info): + mock_result = self._try_get_mock(sdk, query_str, params, span_info.trace_id, span_info.span_id) - return mock_response_output.response + if mock_result is None: + is_pre_app_start = not sdk.app_ready + raise RuntimeError( + f"[Tusk REPLAY] No mock found for psycopg stream query. " + f"This {'pre-app-start ' if is_pre_app_start else ''}query was not recorded. " + f"Query: {query_str[:100]}..." 
+ ) - except Exception as e: - logger.error(f"Error getting mock for psycopg query: {e}") - return None + # Deserialize and yield rows from mock + rows = mock_result.get("rows", []) + for row in rows: + deserialized = deserialize_db_value(row) + yield tuple(deserialized) if isinstance(deserialized, list) else deserialized - def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> None: - """Mock cursor execute by setting internal state.""" - # The SDK communicator already extracts response.body from the CLI's MockInteraction structure - # So mock_data should already contain: {"rowcount": ..., "description": [...], "rows": [...]} - actual_data = mock_data - logger.debug(f"[MOCK_DATA] mock_data: {mock_data}") + span_info.span.end() - try: - cursor._rowcount = actual_data.get("rowcount", -1) - except AttributeError: - object.__setattr__(cursor, "rowcount", actual_data.get("rowcount", -1)) + def _finalize_stream_span( + self, + span: trace.Span, + cursor: Any, + query: str, + params: Any, + rows: list, + error: Exception | None, + ) -> None: + """Finalize span for stream operation with collected rows.""" + try: + # Build input value + input_value = { + "query": query.strip(), + } + if params is not None: + input_value["parameters"] = serialize_value(params) - description_data = actual_data.get("description") - if description_data: - desc = [(col["name"], col.get("type_code"), None, None, None, None, None) for col in description_data] - # Set mock description - InstrumentedCursor has _tusk_description property - # MockCursor uses regular description attribute + # Build output value + output_value = {} + + if error: + output_value = { + "errorName": type(error).__name__, + "errorMessage": str(error), + } + span.set_status(Status(OTelStatusCode.ERROR, str(error))) + else: + # Use pre-collected rows (unlike _finalize_query_span which calls fetchall) + serialized_rows = [[serialize_value(col) for col in row] for row in rows] + + output_value = { + 
"rowcount": len(rows), + } + + if serialized_rows: + output_value["rows"] = serialized_rows + + self._set_span_attributes(span, input_value, output_value) + + if not error: + span.set_status(Status(OTelStatusCode.OK)) + + logger.debug("[PSYCOPG] Stream span finalized successfully") + + except Exception as e: + logger.error(f"Error finalizing stream span: {e}") + + def _traced_copy(self, cursor: Any, original_copy: Any, sdk: TuskDrift, query: str, params=None, **kwargs) -> Any: + """Traced cursor.copy method - returns a context manager.""" + if sdk.mode == TuskDriftMode.DISABLED: + return original_copy(query, params, **kwargs) + + query_str = self._query_to_string(query, cursor) + + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: self._replay_copy(cursor, sdk, query_str), + no_op_request_handler=lambda: self._noop_copy(), + is_server_request=False, + ) + + # RECORD mode - return a context manager that wraps the copy operation + return self._record_copy(cursor, original_copy, sdk, query, query_str, params, kwargs) + + @contextmanager + def _noop_copy(self) -> Iterator[MockCopy]: + """Handle background requests in REPLAY mode - return empty mock.""" + yield MockCopy(data=[]) + + @contextmanager + def _record_copy( + self, + cursor: Any, + original_copy: Any, + sdk: TuskDrift, + query: str, + query_str: str, + params: Any, + kwargs: dict, + ) -> Iterator[TracedCopyWrapper]: + """Handle RECORD mode for copy - wrap Copy object with tracing.""" + span_info = self._create_query_span(sdk, "copy") + + if not span_info: + # Fallback to original if span creation fails + with original_copy(query, params, **kwargs) as copy: + yield copy + return + + error = None + data_collected: list = [] + + try: + with SpanUtils.with_span(span_info): + with original_copy(query, params, **kwargs) as copy: + # Wrap the Copy object to capture data + wrapped_copy = TracedCopyWrapper(copy, data_collected) + yield wrapped_copy + except Exception as e: 
+ error = e + raise + finally: + try: + self._finalize_copy_span( + span_info.span, + query_str, + data_collected, + error, + ) + span_info.span.end() + except Exception as e: + logger.error(f"Error in copy span finalization: {e}") + try: + span_info.span.end() + except Exception: + pass + + @contextmanager + def _replay_copy(self, cursor: Any, sdk: TuskDrift, query_str: str) -> Iterator[MockCopy]: + """Handle REPLAY mode for copy - return mock Copy object.""" + span_info = self._create_query_span(sdk, "copy") + + if not span_info: + raise RuntimeError("Error creating span in replay mode") + + with SpanUtils.with_span(span_info): + mock_result = self._try_get_copy_mock(sdk, query_str, span_info.trace_id, span_info.span_id) + + if mock_result is None: + is_pre_app_start = not sdk.app_ready + raise RuntimeError( + f"[Tusk REPLAY] No mock found for psycopg copy operation. " + f"This {'pre-app-start ' if is_pre_app_start else ''}copy was not recorded. " + f"Query: {query_str[:100]}..." + ) + + # Yield a mock copy object with recorded data + mock_copy = MockCopy(mock_result.get("data", [])) + yield mock_copy + + span_info.span.end() + + def _try_get_copy_mock( + self, + sdk: TuskDrift, + query: str, + trace_id: str, + span_id: str, + ) -> dict[str, Any] | None: + """Try to get a mocked response for copy operation from CLI.""" + try: + # Determine operation type from query + query_upper = query.upper() + is_copy_to = "TO" in query_upper and "STDOUT" in query_upper + is_copy_from = "FROM" in query_upper and "STDIN" in query_upper + + input_value = { + "query": query.strip(), + "operation": "COPY_TO" if is_copy_to else "COPY_FROM" if is_copy_from else "COPY", + } + + # Use centralized mock finding utility + from ...core.mock_utils import find_mock_response_sync + + mock_response_output = find_mock_response_sync( + sdk=sdk, + trace_id=trace_id, + span_id=span_id, + name="psycopg.copy", + package_name="psycopg", + package_type=PackageType.PG, + 
instrumentation_name="PsycopgInstrumentation", + submodule_name="copy", + input_value=input_value, + kind=SpanKind.CLIENT, + is_pre_app_start=not sdk.app_ready, + ) + + if not mock_response_output or not mock_response_output.found: + logger.debug(f"No mock found for psycopg copy: {query[:100]}") + return None + + return mock_response_output.response + + except Exception as e: + logger.error(f"Error getting mock for psycopg copy: {e}") + return None + + def _finalize_copy_span( + self, + span: trace.Span, + query: str, + data_collected: list, + error: Exception | None, + ) -> None: + """Finalize span for copy operation.""" + try: + # Determine operation type from query + query_upper = query.upper() + is_copy_to = "TO" in query_upper and "STDOUT" in query_upper + is_copy_from = "FROM" in query_upper and "STDIN" in query_upper + + # Build input value + input_value = { + "query": query.strip(), + "operation": "COPY_TO" if is_copy_to else "COPY_FROM" if is_copy_from else "COPY", + } + + # Build output value + output_value = {} + + if error: + output_value = { + "errorName": type(error).__name__, + "errorMessage": str(error), + } + span.set_status(Status(OTelStatusCode.ERROR, str(error))) + else: + # Serialize the captured data + serialized_data = [serialize_value(d) for d in data_collected] + output_value = { + "data": serialized_data, + "chunk_count": len(data_collected), + } + + self._set_span_attributes(span, input_value, output_value) + + if not error: + span.set_status(Status(OTelStatusCode.OK)) + + logger.debug("[PSYCOPG] Copy span finalized successfully") + + except Exception as e: + logger.error(f"Error finalizing copy span: {e}") + + def _query_to_string(self, query: Any, cursor: Any) -> str: + """Convert query to string.""" + try: + from psycopg.sql import Composed + + if isinstance(query, Composed): + return query.as_string(cursor) + except ImportError: + pass + + return str(query) if not isinstance(query, str) else query + + def _create_query_span(self, sdk: 
TuskDrift, submodule: str = "query", is_pre_app_start: bool | None = None): + """Create a span for psycopg operations. + + This helper reduces code duplication across replay/record methods. + + Args: + sdk: The TuskDrift instance + submodule: The submodule name ("query" or "copy") + is_pre_app_start: Override for pre-app-start flag; if None, derived from sdk.app_ready + + Returns: + SpanInfo object or None if span creation fails + """ + if is_pre_app_start is None: + is_pre_app_start = not sdk.app_ready + span_name = f"psycopg.{submodule}" + return SpanUtils.create_span( + CreateSpanOptions( + name=span_name, + kind=OTelSpanKind.CLIENT, + attributes={ + TdSpanAttributes.NAME: span_name, + TdSpanAttributes.PACKAGE_NAME: "psycopg", + TdSpanAttributes.INSTRUMENTATION_NAME: "PsycopgInstrumentation", + TdSpanAttributes.SUBMODULE_NAME: submodule, + TdSpanAttributes.PACKAGE_TYPE: PackageType.PG.name, + TdSpanAttributes.IS_PRE_APP_START: is_pre_app_start, + }, + is_pre_app_start=is_pre_app_start, + ) + ) + + def _create_fetch_methods(self, cursor: Any, rows_attr: str, index_attr: str, transform_row=None): + """Create fetch method closures for cursor mocking. + + This helper reduces code duplication in mock/replay cursor setup. 
+ + Args: + cursor: The cursor object to operate on + rows_attr: Attribute name for stored rows (e.g., '_mock_rows', '_tusk_rows') + index_attr: Attribute name for current index (e.g., '_mock_index', '_tusk_index') + transform_row: Optional function to transform each row before returning + + Returns: + Tuple of (fetchone, fetchmany, fetchall) functions + """ + + def fetchone(): + rows = getattr(cursor, rows_attr) + idx = getattr(cursor, index_attr) + if idx < len(rows): + row = rows[idx] + setattr(cursor, index_attr, idx + 1) + return transform_row(row) if transform_row else row + return None + + def fetchmany(size=None): + if size is None: + size = cursor.arraysize + result = [] + for _ in range(size): + row = fetchone() + if row is None: + break + result.append(row) + return result + + def fetchall(): + rows = getattr(cursor, rows_attr) + idx = getattr(cursor, index_attr) + remaining = rows[idx:] + setattr(cursor, index_attr, len(rows)) + if transform_row: + return [transform_row(row) for row in remaining] + return list(remaining) + + return fetchone, fetchmany, fetchall + + def _create_scroll_method(self, cursor: Any, rows_attr: str, index_attr: str): + """Create scroll method closure for cursor mocking. + + Args: + cursor: The cursor object to operate on + rows_attr: Attribute name for stored rows + index_attr: Attribute name for current index + + Returns: + scroll function + """ + + def scroll(value: int, mode: str = "relative") -> None: + rows = getattr(cursor, rows_attr) + idx = getattr(cursor, index_attr) + if mode == "relative": + newpos = idx + value + elif mode == "absolute": + newpos = value + else: + raise ValueError(f"bad mode: {mode}. 
It should be 'relative' or 'absolute'") + + num_rows = len(rows) + if num_rows > 0: + if not (0 <= newpos < num_rows): + raise IndexError("position out of bound") + elif newpos != 0: + raise IndexError("position out of bound") + + setattr(cursor, index_attr, newpos) + + return scroll + + def _get_row_factory_from_cursor(self, cursor: Any): + """Get row_factory from cursor or its connection. + + Args: + cursor: The cursor object + + Returns: + The row_factory or None if not found + """ + row_factory = getattr(cursor, "row_factory", None) + if row_factory is None: + conn = getattr(cursor, "connection", None) + if conn: + row_factory = getattr(conn, "row_factory", None) + return row_factory + + def _set_cursor_description(self, cursor: Any, description_data: list | None) -> None: + """Set description on cursor from description data. + + Args: + cursor: The cursor object + description_data: List of column description dicts with 'name' and 'type_code' keys + """ + if not description_data: + return + + desc = [(col["name"], col.get("type_code"), None, None, None, None, None) for col in description_data] + try: + cursor._tusk_description = desc + except AttributeError: try: - cursor._tusk_description = desc + cursor.description = desc except AttributeError: - # For MockCursor, set description directly + pass + + def _create_row_transformer(self, row_factory_type: str, column_names: list | None): + """Create a row transformation function based on row factory type. + + Args: + row_factory_type: The detected row factory type ('dict', 'namedtuple', etc.) 
+ column_names: List of column names for the result set + + Returns: + A function that transforms a raw row into the appropriate format + """ + RowClass = None + if row_factory_type in ("namedtuple", "class") and column_names: + from collections import namedtuple + + RowClass = namedtuple("Row", column_names) + + def transform_row(row): + """Transform raw row data according to row factory type.""" + if row_factory_type == "kwargs": + return row + if row_factory_type == "scalar": + return row[0] if isinstance(row, list) and len(row) > 0 else row + values = tuple(row) if isinstance(row, list) else row + if row_factory_type == "dict" and column_names: + return dict(zip(column_names, values, strict=False)) + elif row_factory_type in ("namedtuple", "class") and RowClass is not None: + return RowClass(*values) + return values + + return transform_row + + def _set_span_attributes( + self, + span: trace.Span, + input_value: dict, + output_value: dict, + ) -> None: + """Set span attributes for input/output values with schemas and hashes. + + This helper method centralizes the repeated pattern of: + 1. Generating schemas and hashes for input/output values + 2. Setting all span attributes (INPUT_VALUE, OUTPUT_VALUE, schemas, hashes) + + Args: + span: The OpenTelemetry span to set attributes on + input_value: The input data dictionary (query, parameters, etc.) + output_value: The output data dictionary (rows, rowcount, error, etc.) 
+ """ + input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) + output_result = JsonSchemaHelper.generate_schema_and_hash(output_value, {}) + + span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(input_value)) + span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(output_value)) + span.set_attribute(TdSpanAttributes.INPUT_SCHEMA, json.dumps(input_result.schema.to_primitive())) + span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA, json.dumps(output_result.schema.to_primitive())) + span.set_attribute(TdSpanAttributes.INPUT_SCHEMA_HASH, input_result.decoded_schema_hash) + span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA_HASH, output_result.decoded_schema_hash) + span.set_attribute(TdSpanAttributes.INPUT_VALUE_HASH, input_result.decoded_value_hash) + span.set_attribute(TdSpanAttributes.OUTPUT_VALUE_HASH, output_result.decoded_value_hash) + + def _detect_row_factory_type(self, row_factory: Any) -> str: + """Detect the type of row factory for mock transformations. + + Returns: + "dict" for dict_row, "namedtuple" for namedtuple_row, + "class" for class_row, "tuple" otherwise + """ + if row_factory is None: + return "tuple" + + # Check by function/class name + factory_name = getattr(row_factory, "__name__", "") + if not factory_name: + factory_name = str(type(row_factory).__name__) + + factory_name_lower = factory_name.lower() + if "dict" in factory_name_lower: + return "dict" + elif "namedtuple" in factory_name_lower: + return "namedtuple" + elif "kwargs" in factory_name_lower: + return "kwargs" + elif "scalar" in factory_name_lower: + return "scalar" + elif "class" in factory_name_lower: + return "class" + + return "tuple" + + def _is_in_pipeline_mode(self, cursor: Any) -> bool: + """Check if the cursor's connection is currently in pipeline mode. + + In psycopg3, when conn.pipeline() is active, connection._pipeline is set. 
+ """ + try: + conn = getattr(cursor, "connection", None) + if conn is None: + return False + # MockConnection doesn't have real pipeline mode + if isinstance(conn, MockConnection): + return False + pipeline = getattr(conn, "_pipeline", None) + return pipeline is not None + except Exception: + return False + + def _get_connection_from_cursor(self, cursor: Any) -> Any: + """Get the connection object from a cursor.""" + return getattr(cursor, "connection", None) + + def _add_pending_pipeline_span( + self, + connection: Any, + span_info: Any, + cursor: Any, + query: str, + params: Any, + ) -> None: + """Add a pending span to be finalized when pipeline syncs.""" + if connection not in self._pending_pipeline_spans: + self._pending_pipeline_spans[connection] = [] + + self._pending_pipeline_spans[connection].append( + { + "span_info": span_info, + "cursor": cursor, + "query": query, + "params": params, + } + ) + logger.debug(f"[PIPELINE] Deferred span for query: {query[:50]}...") + + def _finalize_pending_pipeline_spans(self, connection: Any) -> None: + """Finalize all pending spans for a connection after pipeline sync.""" + if connection not in self._pending_pipeline_spans: + return + + pending = self._pending_pipeline_spans.pop(connection, []) + logger.debug(f"[PIPELINE] Finalizing {len(pending)} pending pipeline spans") + + for item in pending: + span_info = item["span_info"] + cursor = item["cursor"] + query = item["query"] + params = item["params"] + + try: + span_finalized = self._finalize_query_span(span_info.span, cursor, query, params, error=None) + if span_finalized: + span_info.span.end() + # If lazy capture was set up, span will be ended when user fetches + except Exception as e: + logger.error(f"[PIPELINE] Error finalizing deferred span: {e}") try: - cursor.description = desc - except AttributeError: + span_info.span.end() + except Exception: pass + def _try_get_mock( + self, + sdk: TuskDrift, + query: str, + params: Any, + trace_id: str, + span_id: str, + ) 
-> dict[str, Any] | None: + """Try to get a mocked response from CLI. + + Returns: + Mocked response data if found, None otherwise + """ + try: + # Build input value + input_value = { + "query": query.strip(), + } + if params is not None: + # Serialize parameters to ensure consistent hashing with RECORD mode + input_value["parameters"] = serialize_value(params) + + # Use centralized mock finding utility + from ...core.mock_utils import find_mock_response_sync + + mock_response_output = find_mock_response_sync( + sdk=sdk, + trace_id=trace_id, + span_id=span_id, + name="psycopg.query", + package_name="psycopg", + package_type=PackageType.PG, + instrumentation_name="PsycopgInstrumentation", + submodule_name="query", + input_value=input_value, + kind=SpanKind.CLIENT, + is_pre_app_start=not sdk.app_ready, + ) + + if not mock_response_output or not mock_response_output.found: + logger.debug(f"No mock found for psycopg query: {query[:100]}") + return None + + return mock_response_output.response + + except Exception as e: + logger.error(f"Error getting mock for psycopg query: {e}") + return None + + def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> None: + """Mock cursor execute by setting internal state.""" + # The SDK communicator already extracts response.body from the CLI's MockInteraction structure + # So mock_data should already contain: {"rowcount": ..., "description": [...], "rows": [...]} + actual_data = mock_data + + try: + cursor._rowcount = actual_data.get("rowcount", -1) + except AttributeError: + object.__setattr__(cursor, "rowcount", actual_data.get("rowcount", -1)) + + description_data = actual_data.get("description") + self._set_cursor_description(cursor, description_data) + + # Set mock statusmessage for replay + statusmessage = actual_data.get("statusmessage") + if statusmessage is not None: + cursor._mock_statusmessage = statusmessage + + # Get row_factory and detect type for row transformation + row_factory = 
self._get_row_factory_from_cursor(cursor) + row_factory_type = self._detect_row_factory_type(row_factory) + + # Extract column names from description for row factory transformations + column_names = [col["name"] for col in description_data] if description_data else None + + # Create row transformer using helper + transform_row = self._create_row_transformer(row_factory_type, column_names) + mock_rows = actual_data.get("rows", []) # Deserialize datetime strings back to datetime objects for consistent Flask serialization mock_rows = [deserialize_db_value(row) for row in mock_rows] cursor._mock_rows = mock_rows # pyright: ignore[reportAttributeAccessIssue] cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] - def mock_fetchone(): - if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue] - row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue] - cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue] - return tuple(row) if isinstance(row, list) else row + # Use helper methods to create fetch and scroll methods + fetchone, fetchmany, fetchall = self._create_fetch_methods(cursor, "_mock_rows", "_mock_index", transform_row) + cursor.fetchone = fetchone # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = fetchmany # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = fetchall # pyright: ignore[reportAttributeAccessIssue] + + cursor.scroll = self._create_scroll_method(cursor, "_mock_rows", "_mock_index") # pyright: ignore[reportAttributeAccessIssue] + + # Note: __iter__ and __next__ are handled at the class level in InstrumentedCursor + # and MockCursor classes, as Python looks up special methods on the type, not instance + + def _mock_executemany_returning_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> None: + """Mock cursor for executemany with returning=True - supports multiple result sets. 
+ + This method sets up the cursor to replay multiple result sets captured during + executemany with returning=True. It patches the cursor's results() method to + yield the cursor for each result set, allowing iteration. + """ + result_sets = mock_data.get("result_sets", []) + + if not result_sets: + # Fallback to empty result + cursor._mock_rows = [] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] + return + + # Get row_factory and detect type using helpers + row_factory = self._get_row_factory_from_cursor(cursor) + row_factory_type = self._detect_row_factory_type(row_factory) + + # Store all result sets for iteration + cursor._mock_result_sets = [] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_result_set_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + for result_set in result_sets: + description_data = result_set.get("description") + column_names = None + if description_data: + column_names = [col["name"] for col in description_data] + + # Deserialize rows + mock_rows = result_set.get("rows", []) + mock_rows = [deserialize_db_value(row) for row in mock_rows] + + cursor._mock_result_sets.append( # pyright: ignore[reportAttributeAccessIssue] + { + "description": description_data, + "column_names": column_names, + "rows": mock_rows, + "rowcount": result_set.get("rowcount", -1), + } + ) + + # Create row transformation helper + def create_row_class(col_names): + if row_factory_type == "namedtuple" and col_names: + from collections import namedtuple + + return namedtuple("Row", col_names) return None - def mock_fetchmany(size=cursor.arraysize): - rows = [] - for _ in range(size): - row = mock_fetchone() - if row is None: - break - rows.append(row) - return rows + def transform_row(row, col_names, RowClass): + """Transform raw row data according to row factory type.""" + if row_factory_type == "kwargs": + # kwargs_row: return stored dict as-is + return row + values = 
tuple(row) if isinstance(row, list) else row + if row_factory_type == "dict" and col_names: + return dict(zip(col_names, values, strict=False)) + elif row_factory_type == "namedtuple" and RowClass is not None: + return RowClass(*values) + return values + + def mock_results(): + """Generator that yields cursor for each result set.""" + while cursor._mock_result_set_index < len(cursor._mock_result_sets): # pyright: ignore[reportAttributeAccessIssue] + result_set = cursor._mock_result_sets[cursor._mock_result_set_index] # pyright: ignore[reportAttributeAccessIssue] + + # Set up cursor state for this result set + cursor._mock_rows = result_set["rows"] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Set description + description_data = result_set.get("description") + if description_data: + desc = [ + (col["name"], col.get("type_code"), None, None, None, None, None) for col in description_data + ] + try: + cursor._tusk_description = desc # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + try: + cursor.description = desc # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + pass + + # Set rowcount + try: + cursor._rowcount = result_set.get("rowcount", -1) # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + pass + + column_names = result_set.get("column_names") + RowClass = create_row_class(column_names) + + # Create fetch methods for this result set with closures capturing current values + def make_fetchone(cn, RC): + def fetchone(): + if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue] + row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue] + return transform_row(row, cn, RC) + return None + + return fetchone + + def make_fetchmany(cn, RC): + def fetchmany(size=cursor.arraysize): + rows = 
[] + for _ in range(size): + if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue] + row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue] + rows.append(transform_row(row, cn, RC)) + else: + break + return rows + + return fetchmany + + def make_fetchall(cn, RC): + def fetchall(): + rows = cursor._mock_rows[cursor._mock_index :] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index = len(cursor._mock_rows) # pyright: ignore[reportAttributeAccessIssue] + return [transform_row(row, cn, RC) for row in rows] + + return fetchall + + cursor.fetchone = make_fetchone(column_names, RowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = make_fetchmany(column_names, RowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = make_fetchall(column_names, RowClass) # pyright: ignore[reportAttributeAccessIssue] + + cursor._mock_result_set_index += 1 # pyright: ignore[reportAttributeAccessIssue] + yield cursor + + # Patch results() method onto cursor + cursor.results = mock_results # pyright: ignore[reportAttributeAccessIssue] + + # Also set up initial state for the first result set (in case user calls fetch without results()) + if cursor._mock_result_sets: # pyright: ignore[reportAttributeAccessIssue] + first_set = cursor._mock_result_sets[0] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_rows = first_set["rows"] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Set description for first result set using helper + self._set_cursor_description(cursor, first_set.get("description")) + + # Set up initial fetch methods for the first result set (for code that uses nextset() instead of results()) + first_column_names = first_set.get("column_names") + FirstRowClass = create_row_class(first_column_names) + + 
def make_fetchone_replay(cn, RC): + def fetchone(): + if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue] + row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue] + return transform_row(row, cn, RC) + return None + + return fetchone + + def make_fetchmany_replay(cn, RC): + def fetchmany(size=cursor.arraysize): + rows = [] + for _ in range(size): + if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue] + row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue] + rows.append(transform_row(row, cn, RC)) + else: + break + return rows + + return fetchmany + + def make_fetchall_replay(cn, RC): + def fetchall(): + rows = cursor._mock_rows[cursor._mock_index :] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index = len(cursor._mock_rows) # pyright: ignore[reportAttributeAccessIssue] + return [transform_row(row, cn, RC) for row in rows] + + return fetchall + + cursor.fetchone = make_fetchone_replay(first_column_names, FirstRowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = make_fetchmany_replay(first_column_names, FirstRowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = make_fetchall_replay(first_column_names, FirstRowClass) # pyright: ignore[reportAttributeAccessIssue] + + # Patch nextset() to work with _mock_result_sets + def patched_nextset(): + """Move to the next result set in _mock_result_sets.""" + next_index = cursor._mock_result_set_index + 1 # pyright: ignore[reportAttributeAccessIssue] + if next_index < len(cursor._mock_result_sets): # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_result_set_index = next_index # pyright: ignore[reportAttributeAccessIssue] + next_set = 
cursor._mock_result_sets[next_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_rows = next_set["rows"] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Update fetch methods for the new result set + next_column_names = next_set.get("column_names") + NextRowClass = create_row_class(next_column_names) + cursor.fetchone = make_fetchone_replay(next_column_names, NextRowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = make_fetchmany_replay(next_column_names, NextRowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = make_fetchall_replay(next_column_names, NextRowClass) # pyright: ignore[reportAttributeAccessIssue] + + # Update description for next result set + next_description_data = next_set.get("description") + if next_description_data: + next_desc = [ + (col["name"], col.get("type_code"), None, None, None, None, None) + for col in next_description_data + ] + try: + cursor._tusk_description = next_desc # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + try: + cursor.description = next_desc # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + pass + + return True + return None + + cursor.nextset = patched_nextset # pyright: ignore[reportAttributeAccessIssue] + + # Patch set_result() to work with _mock_result_sets + def patched_set_result(index: int): + """Navigate to a specific result set by index (supports negative indices).""" + num_results = len(cursor._mock_result_sets) # pyright: ignore[reportAttributeAccessIssue] + if not -num_results <= index < num_results: + raise IndexError(f"index {index} out of range: {num_results} result(s) available") + if index < 0: + index = num_results + index + + cursor._mock_result_set_index = index # pyright: ignore[reportAttributeAccessIssue] + target_set = cursor._mock_result_sets[index] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_rows = 
target_set["rows"] # pyright: ignore[reportAttributeAccessIssue] + cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Update fetch methods for the target result set + target_column_names = target_set.get("column_names") + TargetRowClass = create_row_class(target_column_names) + cursor.fetchone = make_fetchone_replay(target_column_names, TargetRowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = make_fetchmany_replay(target_column_names, TargetRowClass) # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = make_fetchall_replay(target_column_names, TargetRowClass) # pyright: ignore[reportAttributeAccessIssue] + + # Update description for target result set + target_description_data = target_set.get("description") + if target_description_data: + target_desc = [ + (col["name"], col.get("type_code"), None, None, None, None, None) + for col in target_description_data + ] + try: + cursor._tusk_description = target_desc # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + try: + cursor.description = target_desc # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + pass - def mock_fetchall(): - rows = cursor._mock_rows[cursor._mock_index :] # pyright: ignore[reportAttributeAccessIssue] - cursor._mock_index = len(cursor._mock_rows) # pyright: ignore[reportAttributeAccessIssue] - return [tuple(row) if isinstance(row, list) else row for row in rows] + return cursor - cursor.fetchone = mock_fetchone # pyright: ignore[reportAttributeAccessIssue] - cursor.fetchmany = mock_fetchmany # pyright: ignore[reportAttributeAccessIssue] - cursor.fetchall = mock_fetchall # pyright: ignore[reportAttributeAccessIssue] + cursor.set_result = patched_set_result # pyright: ignore[reportAttributeAccessIssue] def _finalize_query_span( self, @@ -715,24 +1609,13 @@ def _finalize_query_span( query: str, params: Any, error: Exception | None, - ) -> None: - """Finalize span with query data.""" - try: - # 
Helper function to serialize non-JSON types - import datetime - - def serialize_value(val): - """Convert non-JSON-serializable values to JSON-compatible types.""" - if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): - return val.isoformat() - elif isinstance(val, bytes): - return val.decode("utf-8", errors="replace") - elif isinstance(val, (list, tuple)): - return [serialize_value(v) for v in val] - elif isinstance(val, dict): - return {k: serialize_value(v) for k, v in val.items()} - return val + ) -> bool: + """Finalize span with query data. + Returns True if span was fully finalized, False if lazy capture was set up + (meaning caller should NOT end the span - it will be ended by lazy fetch). + """ + try: # Build input value input_value = { "query": query.strip(), @@ -753,8 +1636,8 @@ def serialize_value(val): else: # Get query results and capture for replay try: - rows = [] description = None + row_factory_type = "tuple" # default # Try to fetch results if available if hasattr(cursor, "description") and cursor.description: @@ -768,19 +1651,320 @@ def serialize_value(val): for desc in cursor.description ] - # Fetch all rows for recording - # We need to capture these for replay mode + # Get row factory from cursor or connection + row_factory = getattr(cursor, "row_factory", None) + if row_factory is None: + conn = getattr(cursor, "connection", None) + if conn: + row_factory = getattr(conn, "row_factory", None) + + # Detect row factory type BEFORE processing rows + row_factory_type = self._detect_row_factory_type(row_factory) + column_names = [d["name"] for d in description] + + # Use LAZY CAPTURE to avoid hanging with binary=True and other edge cases. + # Instead of calling fetchall() immediately (which can hang), we set up + # wrappers that capture results when the user first calls a fetch method. 
+ # Store context needed for lazy capture + cursor._tusk_lazy_span = span # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_lazy_input_value = input_value # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_lazy_description = description # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_lazy_row_factory_type = row_factory_type # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_lazy_column_names = column_names # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_lazy_instrumentation = self # pyright: ignore[reportAttributeAccessIssue] + + # Set up lazy capture wrappers + self._setup_lazy_capture(cursor) + + logger.debug("[PSYCOPG] Lazy capture set up, deferring span finalization") + return False # Signal caller NOT to end span + + # No description means no results expected (e.g., INSERT without RETURNING) + output_value = { + "rowcount": cursor.rowcount if hasattr(cursor, "rowcount") else -1, + } + + # Capture statusmessage for replay + if hasattr(cursor, "statusmessage") and cursor.statusmessage is not None: + output_value["statusmessage"] = cursor.statusmessage + + except Exception as e: + logger.debug(f"Error getting query metadata: {e}") + + self._set_span_attributes(span, input_value, output_value) + + if not error: + span.set_status(Status(OTelStatusCode.OK)) + + logger.debug("[PSYCOPG] Span finalized successfully") + return True # Span fully finalized + + except Exception as e: + logger.error(f"Error creating query span: {e}") + return True # Return True to end span on error + + def _setup_lazy_capture(self, cursor: Any) -> None: + """Set up lazy capture wrappers on cursor fetch methods. + + These wrappers defer the actual fetchall() call until the user's code + requests results. This avoids issues with binary format and other cases + where calling fetchall() immediately after execute() can hang. 
+ """ + # Get references to original fetch methods from the cursor's class + # (not instance methods which might already be patched) + cursor_class = type(cursor) + original_fetchall = cursor_class.fetchall + original_scroll = cursor_class.scroll if hasattr(cursor_class, "scroll") else None + + def do_lazy_capture(): + """Perform the actual capture - called on first fetch.""" + if hasattr(cursor, "_tusk_rows") and cursor._tusk_rows is not None: + return # Already captured + + try: + # Get the actual rows from psycopg + all_rows = original_fetchall(cursor) + + # Store for subsequent fetch calls + cursor._tusk_rows = all_rows # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Process rows for trace capture + description = cursor._tusk_lazy_description # pyright: ignore[reportAttributeAccessIssue] + row_factory_type = cursor._tusk_lazy_row_factory_type # pyright: ignore[reportAttributeAccessIssue] + column_names = cursor._tusk_lazy_column_names # pyright: ignore[reportAttributeAccessIssue] + + rows = [] + for row in all_rows: + if row_factory_type == "kwargs": + rows.append(row) + elif row_factory_type == "scalar": + rows.append([row]) + elif row_factory_type == "class" or hasattr(row, "__dataclass_fields__"): + # dataclass (from class_row): extract values by attribute name + rows.append([getattr(row, col, None) for col in column_names]) + elif isinstance(row, dict): + rows.append([row.get(col) for col in column_names]) + elif hasattr(row, "_fields"): + # namedtuple: extract values by field name + rows.append([getattr(row, col, None) for col in column_names]) + else: + rows.append(list(row)) + + # Finalize the span with captured data + span = cursor._tusk_lazy_span # pyright: ignore[reportAttributeAccessIssue] + input_value = cursor._tusk_lazy_input_value # pyright: ignore[reportAttributeAccessIssue] + instrumentation = cursor._tusk_lazy_instrumentation # pyright: 
ignore[reportAttributeAccessIssue] + + output_value = { + "rowcount": cursor.rowcount if hasattr(cursor, "rowcount") else -1, + } + + if description: + output_value["description"] = description + + if rows: + if row_factory_type == "kwargs": + serialized_rows = [serialize_value(row) for row in rows] + else: + serialized_rows = [[serialize_value(col) for col in row] for row in rows] + output_value["rows"] = serialized_rows + + if hasattr(cursor, "statusmessage") and cursor.statusmessage is not None: + output_value["statusmessage"] = cursor.statusmessage + + instrumentation._set_span_attributes(span, input_value, output_value) + + span.set_status(Status(OTelStatusCode.OK)) + span.end() + + logger.debug("[PSYCOPG] Lazy capture completed, span finalized") + + except Exception as e: + logger.error(f"Error in lazy capture: {e}") + # Try to end span even on error + try: + span = cursor._tusk_lazy_span + span.set_status(Status(OTelStatusCode.ERROR, str(e))) + span.end() + except Exception: + pass + # Re-raise the original exception so the user sees the actual database error + raise + + finally: + # Clean up lazy capture attributes + for attr in ( + "_tusk_lazy_span", + "_tusk_lazy_input_value", + "_tusk_lazy_description", + "_tusk_lazy_row_factory_type", + "_tusk_lazy_column_names", + "_tusk_lazy_instrumentation", + ): + if hasattr(cursor, attr): try: - all_rows = cursor.fetchall() - # Convert tuples to lists for JSON serialization - rows = [list(row) for row in all_rows] + delattr(cursor, attr) + except AttributeError: + pass + + def lazy_fetchone(): + do_lazy_capture() + if cursor._tusk_index < len(cursor._tusk_rows): # pyright: ignore[reportAttributeAccessIssue] + row = cursor._tusk_rows[cursor._tusk_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index += 1 # pyright: ignore[reportAttributeAccessIssue] + return row + return None - # CRITICAL: Re-populate cursor so user code can still fetch - # We'll store the rows and patch fetch methods - 
cursor._tusk_rows = all_rows # pyright: ignore[reportAttributeAccessIssue] + def lazy_fetchmany(size=None): + do_lazy_capture() + if size is None: + size = cursor.arraysize + result = cursor._tusk_rows[cursor._tusk_index : cursor._tusk_index + size] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index += len(result) # pyright: ignore[reportAttributeAccessIssue] + return result + + def lazy_fetchall(): + do_lazy_capture() + result = cursor._tusk_rows[cursor._tusk_index :] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index = len(cursor._tusk_rows) # pyright: ignore[reportAttributeAccessIssue] + return result + + def lazy_scroll(value: int, mode: str = "relative") -> None: + do_lazy_capture() + if mode == "relative": + newpos = cursor._tusk_index + value # pyright: ignore[reportAttributeAccessIssue] + elif mode == "absolute": + newpos = value + else: + raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") + + num_rows = len(cursor._tusk_rows) # pyright: ignore[reportAttributeAccessIssue] + if num_rows > 0: + if not (0 <= newpos < num_rows): + raise IndexError("cursor position out of range") + elif newpos != 0: + raise IndexError("cursor position out of range") + + cursor._tusk_index = newpos # pyright: ignore[reportAttributeAccessIssue] + + # Patch the cursor with lazy wrappers + cursor.fetchone = lazy_fetchone # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = lazy_fetchmany # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = lazy_fetchall # pyright: ignore[reportAttributeAccessIssue] + if original_scroll: + cursor.scroll = lazy_scroll # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_patched = True # pyright: ignore[reportAttributeAccessIssue] + + def _finalize_executemany_returning_span( + self, + span: trace.Span, + cursor: Any, + query: str, + params: Any, + error: Exception | None, + ) -> None: + """Finalize span for executemany with returning=True - captures multiple 
result sets. + + This method iterates through cursor.results() to capture all result sets + from executemany with returning=True, storing them in a format that can + be replayed with multiple result set iteration. + """ + try: + # Build input value + input_value = { + "query": query.strip(), + } + if params is not None: + input_value["parameters"] = serialize_value(params) + + # Build output value + output_value = {} + + if error: + output_value = { + "errorName": type(error).__name__, + "errorMessage": str(error), + } + span.set_status(Status(OTelStatusCode.ERROR, str(error))) + else: + # Iterate through cursor.results() to capture all result sets + result_sets = [] + all_rows_collected = [] # For re-populating cursor + + try: + # cursor.results() yields the cursor itself for each result set + for result_cursor in cursor.results(): + result_set_data = {} + + # Capture description for this result set + if hasattr(result_cursor, "description") and result_cursor.description: + description = [ + { + "name": desc[0] if hasattr(desc, "__getitem__") else desc.name, + "type_code": desc[1] + if hasattr(desc, "__getitem__") and len(desc) > 1 + else getattr(desc, "type_code", None), + } + for desc in result_cursor.description + ] + result_set_data["description"] = description + column_names = [d["name"] for d in description] + else: + description = None + column_names = None + + # Fetch all rows for this result set + rows = [] + raw_rows = result_cursor.fetchall() + all_rows_collected.append(raw_rows) + + for row in raw_rows: + if isinstance(row, dict): + rows.append( + [row.get(col) for col in column_names] if column_names else list(row.values()) + ) + elif hasattr(row, "_fields"): + rows.append( + [getattr(row, str(col), None) for col in column_names] + if column_names + else list(row) + ) + else: + rows.append(list(row)) + + result_set_data["rowcount"] = ( + result_cursor.rowcount if hasattr(result_cursor, "rowcount") else len(rows) + ) + result_set_data["rows"] = 
[[serialize_value(col) for col in row] for row in rows] + + result_sets.append(result_set_data) + + except Exception as results_error: + logger.debug(f"Could not iterate results(): {results_error}") + # Fallback: treat as single result set + result_sets = [] + + if result_sets: + output_value = { + "executemany_returning": True, + "result_sets": result_sets, + } + + # Re-populate cursor for user code + # Store all collected rows for replay via results() + cursor._tusk_result_sets = all_rows_collected # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_result_set_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Patch results() method to iterate stored result sets + def patched_results(): + while cursor._tusk_result_set_index < len(cursor._tusk_result_sets): # pyright: ignore[reportAttributeAccessIssue] + rows = cursor._tusk_result_sets[cursor._tusk_result_set_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_rows = rows # pyright: ignore[reportAttributeAccessIssue] cursor._tusk_index = 0 # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_result_set_index += 1 # pyright: ignore[reportAttributeAccessIssue] - # Replace with our versions that return stored rows + # Patch fetch methods for this result set def patched_fetchone(): if cursor._tusk_index < len(cursor._tusk_rows): # pyright: ignore[reportAttributeAccessIssue] row = cursor._tusk_rows[cursor._tusk_index] # pyright: ignore[reportAttributeAccessIssue] @@ -802,43 +1986,96 @@ def patched_fetchall(): cursor.fetchmany = patched_fetchmany # pyright: ignore[reportAttributeAccessIssue] cursor.fetchall = patched_fetchall # pyright: ignore[reportAttributeAccessIssue] - except Exception as fetch_error: - logger.debug(f"Could not fetch rows (query might not return rows): {fetch_error}") - rows = [] + yield cursor - output_value = { - "rowcount": cursor.rowcount if hasattr(cursor, "rowcount") else -1, - } + cursor.results = patched_results # pyright: 
ignore[reportAttributeAccessIssue] - if description: - output_value["description"] = description + # Set up the first result set immediately for user code that uses nextset() instead of results() + if all_rows_collected: + cursor._tusk_rows = all_rows_collected[0] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index = 0 # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_result_set_index = 0 # pyright: ignore[reportAttributeAccessIssue] - if rows: - # Convert rows to JSON-serializable format (handle datetime objects, etc.) - serialized_rows = [[serialize_value(col) for col in row] for row in rows] - output_value["rows"] = serialized_rows + # Create initial fetch methods for the first result set + def make_patched_fetchone_record(): + def patched_fetchone(): + if cursor._tusk_index < len(cursor._tusk_rows): # pyright: ignore[reportAttributeAccessIssue] + row = cursor._tusk_rows[cursor._tusk_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index += 1 # pyright: ignore[reportAttributeAccessIssue] + return row + return None - except Exception as e: - logger.debug(f"Error getting query metadata: {e}") + return patched_fetchone + + def make_patched_fetchmany_record(): + def patched_fetchmany(size=cursor.arraysize): + result = cursor._tusk_rows[cursor._tusk_index : cursor._tusk_index + size] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index += len(result) # pyright: ignore[reportAttributeAccessIssue] + return result + + return patched_fetchmany + + def make_patched_fetchall_record(): + def patched_fetchall(): + result = cursor._tusk_rows[cursor._tusk_index :] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index = len(cursor._tusk_rows) # pyright: ignore[reportAttributeAccessIssue] + return result + + return patched_fetchall + + cursor.fetchone = make_patched_fetchone_record() # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = make_patched_fetchmany_record() # pyright: 
ignore[reportAttributeAccessIssue] + cursor.fetchall = make_patched_fetchall_record() # pyright: ignore[reportAttributeAccessIssue] + + # Patch nextset() to work with _tusk_result_sets + def patched_nextset(): + """Move to the next result set in _tusk_result_sets.""" + next_index = cursor._tusk_result_set_index + 1 # pyright: ignore[reportAttributeAccessIssue] + if next_index < len(cursor._tusk_result_sets): # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_result_set_index = next_index # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_rows = cursor._tusk_result_sets[next_index] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Update fetch methods for the new result set + cursor.fetchone = make_patched_fetchone_record() # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = make_patched_fetchmany_record() # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = make_patched_fetchall_record() # pyright: ignore[reportAttributeAccessIssue] + return True + return None + + cursor.nextset = patched_nextset # pyright: ignore[reportAttributeAccessIssue] - # Generate schemas and hashes - input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) - output_result = JsonSchemaHelper.generate_schema_and_hash(output_value, {}) + # Patch set_result() to work with _tusk_result_sets + def patched_set_result_record(index: int): + """Navigate to a specific result set by index (supports negative indices).""" + num_results = len(cursor._tusk_result_sets) # pyright: ignore[reportAttributeAccessIssue] + if not -num_results <= index < num_results: + raise IndexError(f"index {index} out of range: {num_results} result(s) available") + if index < 0: + index = num_results + index - # Set span attributes - span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(input_value)) - span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(output_value)) - 
span.set_attribute(TdSpanAttributes.INPUT_SCHEMA, json.dumps(input_result.schema.to_primitive())) - span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA, json.dumps(output_result.schema.to_primitive())) - span.set_attribute(TdSpanAttributes.INPUT_SCHEMA_HASH, input_result.decoded_schema_hash) - span.set_attribute(TdSpanAttributes.OUTPUT_SCHEMA_HASH, output_result.decoded_schema_hash) - span.set_attribute(TdSpanAttributes.INPUT_VALUE_HASH, input_result.decoded_value_hash) - span.set_attribute(TdSpanAttributes.OUTPUT_VALUE_HASH, output_result.decoded_value_hash) + cursor._tusk_result_set_index = index # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_rows = cursor._tusk_result_sets[index] # pyright: ignore[reportAttributeAccessIssue] + cursor._tusk_index = 0 # pyright: ignore[reportAttributeAccessIssue] + + # Update fetch methods for the target result set + cursor.fetchone = make_patched_fetchone_record() # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchmany = make_patched_fetchmany_record() # pyright: ignore[reportAttributeAccessIssue] + cursor.fetchall = make_patched_fetchall_record() # pyright: ignore[reportAttributeAccessIssue] + + return cursor + + cursor.set_result = patched_set_result_record # pyright: ignore[reportAttributeAccessIssue] + + else: + output_value = {"rowcount": cursor.rowcount if hasattr(cursor, "rowcount") else -1} + + self._set_span_attributes(span, input_value, output_value) if not error: span.set_status(Status(OTelStatusCode.OK)) - logger.debug("[PSYCOPG] Span finalized successfully") + logger.debug("[PSYCOPG] Executemany returning span finalized successfully") except Exception as e: - logger.error(f"Error creating query span: {e}") + logger.error(f"Error finalizing executemany returning span: {e}") diff --git a/drift/instrumentation/psycopg/mocks.py b/drift/instrumentation/psycopg/mocks.py new file mode 100644 index 0000000..0709aed --- /dev/null +++ b/drift/instrumentation/psycopg/mocks.py @@ -0,0 +1,385 @@ +"""Mock 
classes for psycopg3 REPLAY mode. + +These mock classes provide a minimal interface for Django/Flask to work +without a real PostgreSQL database connection during replay. +""" + +from __future__ import annotations + +import logging +from collections.abc import Iterator +from contextlib import contextmanager +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ...core.drift_sdk import TuskDrift + from .instrumentation import PsycopgInstrumentation + +logger = logging.getLogger(__name__) + + +class MockLoader: + """Mock loader for psycopg3.""" + + def __init__(self): + self.timezone = None # Django expects this attribute + + def __call__(self, data): + """No-op load function.""" + return data + + +class MockDumper: + """Mock dumper for psycopg3.""" + + def __call__(self, obj): + """No-op dump function.""" + return str(obj).encode("utf-8") + + +class MockAdapters: + """Mock adapters for psycopg3 connection.""" + + def get_loader(self, oid, format): + """Return a mock loader.""" + return MockLoader() + + def get_dumper(self, obj, format): + """Return a mock dumper.""" + return MockDumper() + + def register_loader(self, oid, loader): + """No-op register loader for Django compatibility.""" + pass + + def register_dumper(self, oid, dumper): + """No-op register dumper for Django compatibility.""" + pass + + +class MockConnection: + """Mock database connection for REPLAY mode when postgres is not available. + + Provides minimal interface for Django/Flask to work without a real database. + All queries are mocked at the cursor.execute() level. 
+ """ + + def __init__( + self, + sdk: TuskDrift, + instrumentation: PsycopgInstrumentation, + cursor_factory, + row_factory=None, + ): + self.sdk = sdk + self.instrumentation = instrumentation + self.cursor_factory = cursor_factory + self.row_factory = row_factory # Store row_factory for cursor creation + self.closed = False + self.autocommit = False + + # Django/psycopg3 requires these for connection initialization + self.isolation_level = None + self.encoding = "UTF8" + self.adapters = MockAdapters() + self.pgconn = None # Mock pg connection object + + # Create a comprehensive mock info object for Django + class MockInfo: + vendor = "postgresql" + server_version = 150000 # PostgreSQL 15.0 as integer + encoding = "UTF8" + + def parameter_status(self, param): + """Return mock parameter status.""" + if param == "TimeZone": + return "UTC" + elif param == "server_version": + return "15.0" + return None + + self.info = MockInfo() + + logger.debug("[MOCK_CONNECTION] Created mock connection for REPLAY mode (psycopg3)") + + def cursor(self, name=None, *, cursor_factory=None, **kwargs): + """Create a cursor using the instrumented cursor factory. + + Accepts the same parameters as psycopg's Connection.cursor(), including + server cursor parameters like scrollable and withhold. 
+ """ + # For mock connections, we create a MockCursor directly + # The name parameter is accepted but not used since mock cursors + # behave the same for both regular and server cursors + cursor = MockCursor(self) + + # Wrap execute/executemany for mock cursor + instrumentation = self.instrumentation + sdk = self.sdk + + def mock_execute(query, params=None, **kwargs): + # For mock cursor, original_execute is just a no-op + def noop_execute(q, p, **kw): + return cursor + + return instrumentation._traced_execute(cursor, noop_execute, sdk, query, params, **kwargs) + + def mock_executemany(query, params_seq, **kwargs): + # For mock cursor, original_executemany is just a no-op + def noop_executemany(q, ps, **kw): + return cursor + + return instrumentation._traced_executemany(cursor, noop_executemany, sdk, query, params_seq, **kwargs) + + def mock_stream(query, params=None, **kwargs): + # For mock cursor, original_stream is just a no-op generator + def noop_stream(q, p, **kw): + return iter([]) + + return instrumentation._traced_stream(cursor, noop_stream, sdk, query, params, **kwargs) + + def mock_copy(query, params=None, **kwargs): + # For mock cursor, original_copy is a no-op context manager + @contextmanager + def noop_copy(q, p=None, **kw): + yield MockCopy([]) + + return instrumentation._traced_copy(cursor, noop_copy, sdk, query, params, **kwargs) + + # Monkey-patch mock functions onto cursor + cursor.execute = mock_execute # type: ignore[method-assign] + cursor.executemany = mock_executemany # type: ignore[method-assign] + cursor.stream = mock_stream # type: ignore[method-assign] + cursor.copy = mock_copy # type: ignore[method-assign] + + logger.debug("[MOCK_CONNECTION] Created cursor (psycopg3)") + return cursor + + def commit(self): + """Mock commit - no-op in REPLAY mode.""" + logger.debug("[MOCK_CONNECTION] commit() called (no-op)") + pass + + def rollback(self): + """Mock rollback - no-op in REPLAY mode.""" + logger.debug("[MOCK_CONNECTION] rollback() called 
class MockCursor:
    """Fallback cursor used when the psycopg3 connection is fully mocked.

    Stands in for a real cursor in REPLAY mode: the query methods here are
    placeholders that the instrumentation monkey-patches at cursor creation,
    and the fetch/iteration machinery serves rows injected into
    ``_mock_rows`` by the replay logic.
    """

    def __init__(self, connection):
        self.connection = connection
        self.rowcount = -1
        # Mocked `description`; exposed via the read-only property below.
        self._tusk_description = None
        self.arraysize = 1
        # Rows of the current result set and the fetch position within it.
        self._mock_rows = []
        self._mock_index = 0
        # Additional result sets (executemany with returning=True).
        self._mock_result_sets = []
        self._mock_result_set_index = 0
        self.adapters = MockAdapters()  # Django needs this
        logger.debug("[MOCK_CURSOR] Created fallback mock cursor (psycopg3)")

    @property
    def description(self):
        """Column description injected during replay (None until set)."""
        return self._tusk_description

    @property
    def rownumber(self):
        """Index of the next row to fetch, or None when there is no result."""
        return self._mock_index if self._mock_rows else None

    @property
    def statusmessage(self):
        """Mock status message when one was recorded, otherwise None."""
        return getattr(self, "_mock_statusmessage", None)

    def execute(self, query, params=None, **kwargs):
        """Placeholder; replaced by the instrumentation at cursor creation."""
        logger.debug(f"[MOCK_CURSOR] execute() called: {query[:100]}")
        return self

    def executemany(self, query, params_seq, **kwargs):
        """Placeholder; replaced by the instrumentation at cursor creation."""
        logger.debug(f"[MOCK_CURSOR] executemany() called: {query[:100]}")
        return self

    def fetchone(self):
        return None

    def fetchmany(self, size=None):
        return []

    def fetchall(self):
        return []

    def results(self):
        """Yield result sets for executemany with returning=True.

        Patched by _mock_executemany_returning_with_data during replay;
        the default implementation yields self once for backward
        compatibility.
        """
        yield self

    def nextset(self):
        """Advance to the next result set (True if one exists, else None).

        Patched during replay for executemany with returning=True.
        """
        return None

    def stream(self, query, params=None, **kwargs):
        """Placeholder; replaced by the instrumentation at cursor creation."""
        return iter([])

    def __iter__(self):
        """Allow direct ``for row in cursor`` iteration."""
        return self

    def __next__(self):
        """Return the next mocked row, converting JSON lists to tuples."""
        if self._mock_index >= len(self._mock_rows):
            raise StopIteration
        row = self._mock_rows[self._mock_index]
        self._mock_index += 1
        # Rows decoded from JSON arrive as lists; present them DB-API style.
        return tuple(row) if isinstance(row, list) else row

    def scroll(self, value: int, mode: str = "relative") -> None:
        """Move the fetch position, mirroring psycopg's scroll() semantics."""
        if mode not in ("relative", "absolute"):
            raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
        target = value if mode == "absolute" else self._mock_index + value
        total = len(self._mock_rows)
        # An empty result set only allows position 0.
        in_range = (0 <= target < total) if total else target == 0
        if not in_range:
            raise IndexError("position out of bound")
        self._mock_index = target

    def close(self):
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False


class MockCopy:
    """Mock Copy object for REPLAY mode.

    Minimal stand-in for psycopg's Copy: iteration/read() serve recorded
    COPY TO chunks, and the COPY FROM write methods are no-ops.
    """

    def __init__(self, data: list):
        """Store recorded COPY TO chunks.

        Args:
            data: Chunk list; JSON round-tripping stores bytes as strings,
                so items are re-encoded to bytes on the way out.
        """
        self._data = data
        self._index = 0

    @staticmethod
    def _to_bytes(item) -> bytes:
        """Coerce a recorded chunk back to bytes."""
        if isinstance(item, str):
            return item.encode("utf-8")
        if isinstance(item, bytes):
            return item
        return str(item).encode("utf-8")

    def __iter__(self) -> Iterator[bytes]:
        """Yield every recorded COPY TO chunk as bytes."""
        for item in self._data:
            yield self._to_bytes(item)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return False

    def read(self) -> bytes:
        """Return the next chunk as bytes, or b"" once exhausted."""
        if self._index >= len(self._data):
            return b""
        item = self._data[self._index]
        self._index += 1
        return self._to_bytes(item)

    def rows(self) -> Iterator[tuple]:
        """Yield parsed rows for COPY TO (JSON lists become tuples)."""
        yield from (tuple(r) if isinstance(r, list) else r for r in self._data)

    def write(self, buffer) -> None:
        """No-op for COPY FROM in replay mode."""

    def write_row(self, row) -> None:
        """No-op for COPY FROM in replay mode."""

    def set_types(self, types) -> None:
        """No-op for replay mode."""


class MockPipeline:
    """Mock Pipeline for REPLAY mode.

    Pipeline operations are no-ops here: in REPLAY mode every query already
    returns mocked data immediately, so there is nothing to batch or sync.
    """

    def __init__(self, connection: MockConnection):
        self._conn = connection

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return False

    def sync(self):
        """No-op sync for mock pipeline."""
class TracedCopyWrapper:
    """Wrapper around psycopg's Copy object to capture data in RECORD mode.

    Every read and write is forwarded to the real Copy object while a copy
    of the payload is appended to ``data_collected`` for later replay.

    NOTE(review): ``__exit__`` intentionally does not delegate to the
    wrapped Copy — presumably the surrounding traced-copy context manager
    owns the real Copy's lifecycle; confirm against the instrumentation.
    """

    def __init__(self, copy: Any, data_collected: list):
        """Wrap *copy*, recording captured chunks into *data_collected*.

        Args:
            copy: The real psycopg Copy object.
            data_collected: List that receives every captured chunk/row.
        """
        self._copy = copy
        self._data_collected = data_collected

    def _record(self, chunk):
        """Normalize a raw chunk (memoryview -> bytes) and record it."""
        if isinstance(chunk, memoryview):
            chunk = bytes(chunk)
        self._data_collected.append(chunk)
        return chunk

    def __iter__(self) -> Iterator[bytes]:
        """Iterate COPY TO output, recording each chunk as it passes."""
        for chunk in self._copy:
            yield self._record(chunk)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return False

    def read(self) -> bytes:
        """Read one raw COPY TO chunk, recording it when non-empty."""
        chunk = self._copy.read()
        if not chunk:
            return chunk
        return self._record(chunk)

    def read_row(self):
        """Read one parsed row, recording it unless the stream is done."""
        row = self._copy.read_row()
        if row is not None:
            self._data_collected.append(row)
        return row

    def rows(self):
        """Iterate parsed COPY TO rows, recording each one."""
        for row in self._copy.rows():
            self._data_collected.append(row)
            yield row

    def write(self, buffer):
        """Write raw COPY FROM data, recording a stable bytes copy.

        A memoryview is materialized first so later mutation of the source
        buffer cannot corrupt the recording.
        """
        return self._copy.write(self._record(buffer))

    def write_row(self, row):
        """Write one COPY FROM row, recording it."""
        self._data_collected.append(row)
        return self._copy.write_row(row)

    def set_types(self, types):
        """Proxy set_types to the underlying Copy object."""
        return self._copy.set_types(types)

    def __getattr__(self, name):
        """Fall back to the wrapped Copy object for any other attribute."""
        return getattr(self._copy, name)
...core.json_schema_helper import JsonSchemaHelper from ...core.mode_utils import handle_record_mode, handle_replay_mode from ...core.tracing import TdSpanAttributes from ...core.tracing.span_utils import CreateSpanOptions, SpanUtils from ...core.types import ( - CleanSpanData, - Duration, PackageType, SpanKind, - SpanStatus, - StatusCode, - Timestamp, TuskDriftMode, - replay_trace_id_context, ) from ..base import InstrumentationBase from ..utils.psycopg_utils import deserialize_db_value +from ..utils.serialization import serialize_value logger = logging.getLogger(__name__) @@ -454,9 +447,7 @@ def _replay_execute(self, cursor: Any, sdk: TuskDrift, query_str: str, params: A raise RuntimeError("Error creating span in replay mode") with SpanUtils.with_span(span_info): - mock_result = self._try_get_mock( - sdk, query_str, params, span_info.trace_id, span_info.span_id, span_info.parent_span_id - ) + mock_result = self._try_get_mock(sdk, query_str, params, span_info.trace_id, span_info.span_id) if mock_result is None: is_pre_app_start = not sdk.app_ready @@ -581,7 +572,7 @@ def _replay_executemany(self, cursor: Any, sdk: TuskDrift, query_str: str, param with SpanUtils.with_span(span_info): mock_result = self._try_get_mock( - sdk, query_str, {"_batch": params_list}, span_info.trace_id, span_info.span_id, span_info.parent_span_id + sdk, query_str, {"_batch": params_list}, span_info.trace_id, span_info.span_id ) if mock_result is None: @@ -661,7 +652,6 @@ def _try_get_mock( params: Any, trace_id: str, span_id: str, - parent_span_id: str | None, ) -> dict[str, Any] | None: """Try to get a mocked response from CLI. 
@@ -677,108 +667,27 @@ def _try_get_mock( if params is not None: input_value["parameters"] = params - # Generate schema and hashes for CLI matching - input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) + # Use centralized mock finding utility + from ...core.mock_utils import find_mock_response_sync - # Create mock span for matching - timestamp_ms = time.time() * 1000 - timestamp_seconds = int(timestamp_ms // 1000) - timestamp_nanos = int((timestamp_ms % 1000) * 1_000_000) - - # Create mock span for matching - # NOTE: Schemas must be None to avoid betterproto map serialization issues - # The CLI only needs the hashes for matching anyway, not the full schemas - mock_span = CleanSpanData( + mock_response_output = find_mock_response_sync( + sdk=sdk, trace_id=trace_id, span_id=span_id, - parent_span_id=parent_span_id or "", name="psycopg2.query", package_name="psycopg2", package_type=PackageType.PG, instrumentation_name="Psycopg2Instrumentation", submodule_name="query", input_value=input_value, - output_value=None, - input_schema=None, # type: ignore[arg-type] - output_schema=None, # type: ignore[arg-type] - input_schema_hash=input_result.decoded_schema_hash, - output_schema_hash="", - input_value_hash=input_result.decoded_value_hash, - output_value_hash="", kind=SpanKind.CLIENT, - status=SpanStatus(code=StatusCode.OK, message=""), - timestamp=Timestamp(seconds=timestamp_seconds, nanos=timestamp_nanos), - duration=Duration(seconds=0, nanos=0), - is_root_span=False, is_pre_app_start=not sdk.app_ready, ) - # Request mock from CLI - replay_trace_id = replay_trace_id_context.get() - - mock_request = MockRequestInput( - test_id=replay_trace_id or "", - outbound_span=mock_span, - ) - - logger.info("[MOCK_REQUEST] Requesting mock from CLI:") - logger.info(f" replay_trace_id={replay_trace_id}") - logger.info(f" trace_id={trace_id}") - logger.info(f" span_id={span_id}") - logger.info(f" parent_span_id={parent_span_id or 'None'}") - logger.info(f" 
package_name={mock_span.package_name}") - logger.info(f" package_type={mock_span.package_type}") - logger.info(f" instrumentation_name={mock_span.instrumentation_name}") - logger.info(f" input_value_hash={mock_span.input_value_hash}") - logger.info(f" input_schema_hash={mock_span.input_schema_hash}") - logger.info(f" is_pre_app_start={mock_span.is_pre_app_start}") - logger.info(f" query={query_str[:100]}") - - # Check if communicator is connected before requesting mock - if not sdk.communicator or not sdk.communicator.is_connected: - logger.warning("[MOCK_REQUEST] CLI communicator is not connected yet!") - logger.warning(f"[MOCK_REQUEST] is_pre_app_start={mock_span.is_pre_app_start}") - - if mock_span.is_pre_app_start: - # For pre-app-start queries, return None (will trigger empty result fallback) - logger.warning( - "[MOCK_REQUEST] Pre-app-start query and CLI not ready - returning None to use empty result" - ) - return None - else: - # For in-request queries, this is an error but we'll return None to be safe - logger.error("[MOCK_REQUEST] In-request query but CLI not connected - returning None") - return None - - logger.debug("[MOCK_REQUEST] Calling sdk.request_mock_sync()...") - mock_response_output = sdk.request_mock_sync(mock_request) - logger.info( - f"[MOCK_RESPONSE] CLI returned: found={mock_response_output.found}, response={mock_response_output.response is not None}" - ) - - if mock_response_output.response: - logger.debug( - f"[MOCK_RESPONSE] Response keys: {mock_response_output.response.keys() if isinstance(mock_response_output.response, dict) else 'not a dict'}" - ) - - if not mock_response_output.found: - logger.error( - f"No mock found for psycopg2 query:\n" - f" replay_trace_id={replay_trace_id}\n" - f" span_trace_id={trace_id}\n" - f" span_id={span_id}\n" - f" package_name={mock_span.package_name}\n" - f" package_type={mock_span.package_type}\n" - f" name={mock_span.name}\n" - f" input_value_hash={mock_span.input_value_hash}\n" - f" 
input_schema_hash={mock_span.input_schema_hash}\n" - f" query={query_str[:100]}" - ) + if not mock_response_output or not mock_response_output.found: + logger.debug(f"No mock found for psycopg2 query: {query_str[:100]}") return None - logger.info(f"[MOCK_FOUND] Found mock for psycopg2 query: {query_str[:100]}") - logger.info(f"[MOCK_FOUND] Mock response type: {type(mock_response_output.response)}") - logger.info(f"[MOCK_FOUND] Mock response: {mock_response_output.response}") return mock_response_output.response except Exception as e: @@ -904,21 +813,6 @@ def _finalize_query_span( ) -> None: """Finalize span with query data.""" try: - # Helper function to serialize non-JSON types - import datetime - - def serialize_value(val): - """Convert non-JSON-serializable values to JSON-compatible types.""" - if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): - return val.isoformat() - elif isinstance(val, bytes): - return val.decode("utf-8", errors="replace") - elif isinstance(val, (list, tuple)): - return [serialize_value(v) for v in val] - elif isinstance(val, dict): - return {k: serialize_value(v) for k, v in val.items()} - return val - # Build input value query_str = _query_to_str(query) input_value = { diff --git a/drift/instrumentation/redis/e2e-tests/src/test_requests.py b/drift/instrumentation/redis/e2e-tests/src/test_requests.py index 9cc893c..b923668 100644 --- a/drift/instrumentation/redis/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/redis/e2e-tests/src/test_requests.py @@ -1,24 +1,6 @@ """Execute test requests against the Redis Flask app.""" -import time - -import requests - -BASE_URL = "http://localhost:8000" - - -def make_request(method, endpoint, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"→ {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = requests.request(method, url, **kwargs) - print(f" Status: 
{response.status_code}") - time.sleep(0.5) # Small delay between requests - return response - +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary if __name__ == "__main__": print("Starting Redis test request sequence...\n") @@ -62,4 +44,4 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/transaction-watch") - print("\nAll requests completed successfully") + print_request_summary() diff --git a/drift/instrumentation/requests/e2e-tests/src/test_requests.py b/drift/instrumentation/requests/e2e-tests/src/test_requests.py index 5976ff4..67e9904 100644 --- a/drift/instrumentation/requests/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/requests/e2e-tests/src/test_requests.py @@ -1,24 +1,6 @@ """Execute test requests against the Flask app to exercise the requests instrumentation.""" -import time - -import requests - -BASE_URL = "http://localhost:8000" - - -def make_request(method, endpoint, **kwargs): - """Make HTTP request and log result.""" - url = f"{BASE_URL}{endpoint}" - print(f"-> {method} {endpoint}") - - # Set default timeout if not provided - kwargs.setdefault("timeout", 30) - response = requests.request(method, url, **kwargs) - print(f" Status: {response.status_code}") - time.sleep(0.5) # Small delay between requests - return response - +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary if __name__ == "__main__": print("Starting test request sequence for requests instrumentation...\n") @@ -76,4 +58,4 @@ def make_request(method, endpoint, **kwargs): make_request("GET", "/test/response-hooks") - print("\nAll requests completed successfully") + print_request_summary() diff --git a/drift/instrumentation/requests/instrumentation.py b/drift/instrumentation/requests/instrumentation.py index 88874eb..d5481d7 100644 --- a/drift/instrumentation/requests/instrumentation.py +++ b/drift/instrumentation/requests/instrumentation.py @@ -511,6 +511,7 @@ def 
_try_get_mock( input_value=input_value, kind=SpanKind.CLIENT, input_schema_merges=input_schema_merges, + is_pre_app_start=not sdk.app_ready, ) if not mock_response_output or not mock_response_output.found: diff --git a/drift/instrumentation/utils/psycopg_utils.py b/drift/instrumentation/utils/psycopg_utils.py index 1dc3859..d14813d 100644 --- a/drift/instrumentation/utils/psycopg_utils.py +++ b/drift/instrumentation/utils/psycopg_utils.py @@ -2,27 +2,77 @@ from __future__ import annotations +import base64 import datetime as dt +import uuid +from decimal import Decimal from typing import Any +# Try to import psycopg Range type for deserialization support +try: + from psycopg.types.range import Range as PsycopgRange # type: ignore[import-untyped] + + HAS_PSYCOPG_RANGE = True +except ImportError: + HAS_PSYCOPG_RANGE = False + PsycopgRange = None # type: ignore[misc, assignment] + def deserialize_db_value(val: Any) -> Any: - """Convert ISO datetime strings back to datetime objects for consistent serialization. + """Convert serialized values back to their original Python types. - During recording, datetime objects from the database are serialized to ISO format strings. - During replay, we need to convert them back to datetime objects so that Flask/Django - serializes them the same way (e.g., RFC 2822 vs ISO 8601 format). + During recording, database values are serialized for JSON storage: + - datetime objects -> ISO format strings + - bytes/memoryview -> {"__bytes__": ""} + - uuid.UUID -> {"__uuid__": ""} - Only parses strings that contain a time component (T or space separator with :) to avoid - incorrectly converting date-only strings or text that happens to look like dates. + During replay, we need to convert them back to their original types so that + application code (Flask/Django) handles them the same way. Args: val: A value from the mocked database rows. Can be a string, list, dict, or any other type. 
Returns: - The value with ISO datetime strings converted back to datetime objects. + The value with serialized types converted back to their original Python types. """ - if isinstance(val, str): + if isinstance(val, dict): + # Check for bytes tagged structure + if "__bytes__" in val and len(val) == 1: + # Decode base64 back to bytes + return base64.b64decode(val["__bytes__"]) + # Check for UUID tagged structure + if "__uuid__" in val and len(val) == 1: + return uuid.UUID(val["__uuid__"]) + # Check for Decimal tagged structure + if "__decimal__" in val and len(val) == 1: + return Decimal(val["__decimal__"]) + # Check for timedelta tagged structure + if "__timedelta__" in val and len(val) == 1: + return dt.timedelta(seconds=val["__timedelta__"]) + # Check for Range tagged structure (psycopg Range types) + if "__range__" in val and len(val) == 1: + range_data = val["__range__"] + if HAS_PSYCOPG_RANGE and PsycopgRange is not None: + if range_data.get("empty"): + return PsycopgRange(empty=True) + # Recursively deserialize the lower and upper bounds + # (they may contain datetime or other serialized types) + lower = deserialize_db_value(range_data.get("lower")) + upper = deserialize_db_value(range_data.get("upper")) + bounds = range_data.get("bounds", "[)") + # Convert floats back to ints if they represent whole numbers + # This is needed because JSON doesn't distinguish int/float + if isinstance(lower, float) and lower.is_integer(): + lower = int(lower) + if isinstance(upper, float) and upper.is_integer(): + upper = int(upper) + return PsycopgRange(lower, upper, bounds) + else: + # If psycopg is not available, return the dict as-is + return range_data + # Recursively deserialize dict values + return {k: deserialize_db_value(v) for k, v in val.items()} + elif isinstance(val, str): # Only parse strings that look like full datetime (must have time component) # This avoids converting date-only strings like "2024-01-15" or text columns # that happen to match date patterns @@ 
-35,6 +85,4 @@ def deserialize_db_value(val: Any) -> Any: pass elif isinstance(val, list): return [deserialize_db_value(v) for v in val] - elif isinstance(val, dict): - return {k: deserialize_db_value(v) for k, v in val.items()} return val diff --git a/drift/instrumentation/utils/serialization.py b/drift/instrumentation/utils/serialization.py new file mode 100644 index 0000000..e7fece0 --- /dev/null +++ b/drift/instrumentation/utils/serialization.py @@ -0,0 +1,99 @@ +"""Serialization utilities for instrumentation modules.""" + +from __future__ import annotations + +import base64 +import datetime +import ipaddress +import uuid +from decimal import Decimal +from typing import Any + +# Try to import psycopg Range type for serialization support +try: + from psycopg.types.range import Range as PsycopgRange # type: ignore[import-untyped] + + HAS_PSYCOPG_RANGE = True +except ImportError: + HAS_PSYCOPG_RANGE = False + PsycopgRange = None # type: ignore[misc, assignment] + + +def _serialize_bytes(val: bytes) -> Any: + """Serialize bytes to a JSON-compatible format. + + Attempts UTF-8 decode first for text data (like COPY output). + Falls back to base64 encoding with tagged structure for binary data + that contains invalid UTF-8 sequences (like bytea columns). + + Args: + val: The bytes value to serialize. + + Returns: + Either a string (if valid UTF-8) or a dict {"__bytes__": "base64_data"}. + """ + try: + # Try UTF-8 decode first - works for text data like COPY output + return val.decode("utf-8") + except UnicodeDecodeError: + # Fall back to base64 for binary data with invalid UTF-8 sequences + return {"__bytes__": base64.b64encode(val).decode("ascii")} + + +def serialize_value(val: Any) -> Any: + """Convert non-JSON-serializable values to JSON-compatible types. + + Handles datetime objects, bytes, Decimal, and nested structures (lists, tuples, dicts). + + Args: + val: The value to serialize. + + Returns: + A JSON-serializable version of the value. 
+ """ + if isinstance(val, (datetime.datetime, datetime.date, datetime.time)): + return val.isoformat() + elif isinstance(val, datetime.timedelta): + # Serialize timedelta as total seconds for consistent hashing + return {"__timedelta__": val.total_seconds()} + elif isinstance(val, Decimal): + # Serialize Decimal as string to preserve precision and ensure consistent hashing + return {"__decimal__": str(val)} + elif isinstance(val, uuid.UUID): + return {"__uuid__": str(val)} + elif HAS_PSYCOPG_RANGE and PsycopgRange is not None and isinstance(val, PsycopgRange): + # Serialize psycopg Range objects to a deterministic dict format + # This handles INT4RANGE, TSRANGE, and other PostgreSQL range types + if val.isempty: + return {"__range__": {"empty": True}} + return { + "__range__": { + "lower": serialize_value(val.lower), + "upper": serialize_value(val.upper), + "bounds": val.bounds, + } + } + elif isinstance( + val, + ( + ipaddress.IPv4Address, + ipaddress.IPv6Address, + ipaddress.IPv4Interface, + ipaddress.IPv6Interface, + ipaddress.IPv4Network, + ipaddress.IPv6Network, + ), + ): + # Serialize ipaddress types to string for inet/cidr PostgreSQL columns + # These are returned by psycopg when querying inet and cidr columns + return str(val) + elif isinstance(val, memoryview): + # Convert memoryview to bytes first, then serialize + return _serialize_bytes(bytes(val)) + elif isinstance(val, bytes): + return _serialize_bytes(val) + elif isinstance(val, (list, tuple)): + return [serialize_value(v) for v in val] + elif isinstance(val, dict): + return {k: serialize_value(v) for k, v in val.items()} + return val