|
| 1 | +"""Regression tests for process worker pickle safety (issue #161 follow-up). |
| 2 | +
|
| 3 | +Two layers of defence: |
| 4 | +
|
| 5 | +1. `CursorBasedAdapter.execute_query` sanitizes rows before returning, so |
| 6 | + `memoryview` (from psycopg2 bytea) becomes `bytes`. This is what PR #171 |
| 7 | + landed. We pin that wiring here with a mock cursor, and check the |
| 8 | + sanitized output is picklable. |
| 9 | +
|
| 10 | +2. `_WorkerState.send` surfaces a pickle failure as an error message |
| 11 | + instead of silently dropping it — otherwise the client's `recv()` |
| 12 | + waits forever. Without this, any non-picklable cell type (not just |
| 13 | + bytea) would still hang the TUI. |
| 14 | +""" |
| 15 | + |
| 16 | +from __future__ import annotations |
| 17 | + |
| 18 | +import multiprocessing |
| 19 | +import pickle |
| 20 | +from multiprocessing.connection import Connection |
| 21 | +from typing import Any |
| 22 | + |
| 23 | +from sqlit.domains.connections.providers.adapters.base import ( |
| 24 | + CursorBasedAdapter, |
| 25 | + _sanitize_row, |
| 26 | +) |
| 27 | +from sqlit.domains.process_worker.app.process_worker import _WorkerState |
| 28 | + |
| 29 | + |
| 30 | +class _FakeCursor: |
| 31 | + """Minimal stand-in for a DB-API cursor.""" |
| 32 | + |
| 33 | + def __init__(self, columns: list[str], rows: list[tuple]) -> None: |
| 34 | + self.description = [(name,) for name in columns] |
| 35 | + self._rows = list(rows) |
| 36 | + |
| 37 | + def execute(self, sql: str) -> None: # noqa: ARG002 |
| 38 | + pass |
| 39 | + |
| 40 | + def fetchall(self) -> list[tuple]: |
| 41 | + return self._rows |
| 42 | + |
| 43 | + def fetchmany(self, size: int) -> list[tuple]: |
| 44 | + head, self._rows = self._rows[:size], self._rows[size:] |
| 45 | + return head |
| 46 | + |
| 47 | + |
| 48 | +class _FakeConn: |
| 49 | + def __init__(self, cursor: _FakeCursor) -> None: |
| 50 | + self._cursor = cursor |
| 51 | + |
| 52 | + def cursor(self) -> _FakeCursor: |
| 53 | + return self._cursor |
| 54 | + |
| 55 | + |
| 56 | +def test_execute_query_sanitizes_memoryview_in_returned_rows() -> None: |
| 57 | + """Pins the `_sanitize_row` call in CursorBasedAdapter.execute_query. |
| 58 | +
|
| 59 | + Without the call site wiring, this test fails even if _sanitize_row |
| 60 | + itself is correct. |
| 61 | + """ |
| 62 | + # CursorBasedAdapter is abstract but execute_query doesn't touch self, |
| 63 | + # so call it unbound. |
| 64 | + cursor = _FakeCursor( |
| 65 | + columns=["id", "blob"], |
| 66 | + rows=[(1, memoryview(b"\xde\xad\xbe\xef"))], |
| 67 | + ) |
| 68 | + columns, rows, truncated = CursorBasedAdapter.execute_query( |
| 69 | + None, # type: ignore[arg-type] |
| 70 | + _FakeConn(cursor), |
| 71 | + "SELECT 1", |
| 72 | + ) |
| 73 | + |
| 74 | + assert columns == ["id", "blob"] |
| 75 | + assert truncated is False |
| 76 | + assert rows == [(1, b"\xde\xad\xbe\xef")] |
| 77 | + assert isinstance(rows[0][1], bytes) |
| 78 | + |
| 79 | + |
| 80 | +def test_sanitized_rows_are_picklable() -> None: |
| 81 | + """The actual failure mode in #161 was pickle failing on memoryview. |
| 82 | +
|
| 83 | + Pickle round-trip is the closest cheap stand-in for `Pipe.send()`, |
| 84 | + which is what hung the worker. |
| 85 | + """ |
| 86 | + raw = [(1, "row1", memoryview(b"\xca\xfe\xba\xbe"))] |
| 87 | + sanitized = [_sanitize_row(r) for r in raw] |
| 88 | + |
| 89 | + data = pickle.dumps(sanitized) |
| 90 | + assert pickle.loads(data) == [(1, "row1", b"\xca\xfe\xba\xbe")] |
| 91 | + |
| 92 | + |
| 93 | +def _make_state_with_pipe() -> tuple[_WorkerState, Connection]: |
| 94 | + """Build a _WorkerState attached to a real in-process pipe.""" |
| 95 | + ctx = multiprocessing.get_context("spawn") |
| 96 | + parent, child = ctx.Pipe(duplex=True) |
| 97 | + state = _WorkerState(conn=child) |
| 98 | + return state, parent |
| 99 | + |
| 100 | + |
| 101 | +def test_worker_send_non_picklable_payload_emits_error() -> None: |
| 102 | + """Defence-in-depth: if a future driver returns something non-picklable, |
| 103 | + the client should receive an error message, not hang on recv(). |
| 104 | + """ |
| 105 | + state, parent = _make_state_with_pipe() |
| 106 | + try: |
| 107 | + # memoryview is not picklable — simulates any unexpected non-picklable cell. |
| 108 | + payload: dict[str, Any] = { |
| 109 | + "type": "result", |
| 110 | + "id": 42, |
| 111 | + "kind": "query", |
| 112 | + "result": memoryview(b"not picklable"), |
| 113 | + } |
| 114 | + state.send(payload) |
| 115 | + |
| 116 | + assert parent.poll(timeout=2.0), "client would hang; no error message was sent" |
| 117 | + message = parent.recv() |
| 118 | + assert message["type"] == "error" |
| 119 | + assert message["id"] == 42 |
| 120 | + assert "could not be serialized" in message["message"].lower() |
| 121 | + finally: |
| 122 | + parent.close() |
| 123 | + state.conn.close() |
| 124 | + |
| 125 | + |
| 126 | +def test_worker_send_picklable_payload_passes_through() -> None: |
| 127 | + """Confirm the fallback path doesn't interfere with normal sends.""" |
| 128 | + state, parent = _make_state_with_pipe() |
| 129 | + try: |
| 130 | + payload = {"type": "result", "id": 1, "kind": "query", "result": [1, 2, 3]} |
| 131 | + state.send(payload) |
| 132 | + |
| 133 | + assert parent.poll(timeout=2.0) |
| 134 | + message = parent.recv() |
| 135 | + assert message == payload |
| 136 | + finally: |
| 137 | + parent.close() |
| 138 | + state.conn.close() |
0 commit comments