Skip to content

Commit d3adacd

Browse files
PR comments
1 parent a143e8b commit d3adacd

2 files changed

Lines changed: 74 additions & 41 deletions

File tree

postgresql_proxy/interceptors.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,20 +55,14 @@ def intercept(self, packet_type, data):
5555
# Query, ends with b'\x00'
5656
data = self._intercept_query(data, ic_queries)
5757
elif packet_type == b"P":
58-
# Parse packet body:
59-
# statement_name\x00 + query\x00 + int16(param_count) + uint32[]
60-
# Keep the binary suffix untouched (count + OID array).
61-
statement_end = data.find(b"\x00")
62-
if statement_end != -1:
63-
query_start = statement_end + 1
64-
query_end = data.find(b"\x00", query_start)
65-
if query_end != -1:
66-
statement = data[:query_start]
67-
query = self._intercept_query(
68-
data[query_start : query_end + 1], ic_queries
69-
)
70-
params = data[query_end + 1 :]
71-
data = statement + query + params
58+
# Parse message (Extended Query Protocol). Body format:
59+
# statement_name\x00 + query\x00 + int16(param_count) + uint32[] OIDs
60+
# maxsplit=2 keeps the binary suffix intact even when OID values contain \x00 bytes.
61+
parts = data.split(b"\x00", 2)
62+
if len(parts) == 3:
63+
statement, raw_query, params = parts
64+
query = self._intercept_query(raw_query, ic_queries)
65+
data = statement + b"\x00" + query + params
7266

7367
if packet_type == b"":
7468
# Connection request / context. Ignore the first 4 bytes, keep it
@@ -106,8 +100,8 @@ def _intercept_context_data(self, data):
106100

107101
def _intercept_query(self, query, interceptors):
108102
logging.getLogger("intercept").debug("intercepting query\n%s", query)
109-
# Remove zero byte at the end
110-
query = query[:-1].decode("utf-8")
103+
# Remove null terminator
104+
query = query.rstrip(b"\x00").decode("utf-8")
111105
for interceptor in interceptors:
112106
func = self._get_plugin_interceptor_function(interceptor)
113107
query = func(query, self.context)

tests/test_proxy.py

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Generator
12
import contextlib
23
import os
34
import shutil
@@ -16,6 +17,17 @@
1617
from postgresql_proxy.proxy import Proxy
1718

1819

20+
class _QuerySpy:
21+
"""Minimal query interceptor that records every query it sees."""
22+
23+
def __init__(self):
24+
self.captured: list[str] = []
25+
26+
def capture(self, query: str, context) -> str:
27+
self.captured.append(query)
28+
return query
29+
30+
1931
def _get_free_tcp_port() -> int:
2032
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
2133
sock.bind(("127.0.0.1", 0))
@@ -126,8 +138,18 @@ def _temporary_server_cert_pair():
126138

127139

128140
@contextlib.contextmanager
129-
def _run_proxy(postgres_settings, ssl_context: ssl.SSLContext | None = None):
141+
def _run_proxy(
142+
postgres_settings,
143+
ssl_context: ssl.SSLContext | None = None,
144+
*,
145+
plugins: dict | None = None,
146+
query_interceptors: list[dict] | None = None,
147+
) -> Generator[int, None, None]:
148+
"""Start a proxy in a background thread and yield its listening port."""
130149
proxy_port = _get_free_tcp_port()
150+
commands_config: dict = {}
151+
if query_interceptors:
152+
commands_config["queries"] = query_interceptors
131153
instance = cfg.InstanceSettings(
132154
{
133155
"listen": {"name": "proxy", "host": "127.0.0.1", "port": proxy_port},
@@ -136,14 +158,13 @@ def _run_proxy(postgres_settings, ssl_context: ssl.SSLContext | None = None):
136158
"host": postgres_settings["host"],
137159
"port": postgres_settings["port"],
138160
},
139-
# Keep interceptors active with default no-op behavior.
140-
"intercept": {"commands": {}, "responses": {}},
161+
"intercept": {"commands": commands_config, "responses": {}},
141162
}
142163
)
143164
if not hasattr(instance.intercept.responses, "parameter_status"):
144165
instance.intercept.responses.parameter_status = []
145166

146-
proxy = Proxy(instance, plugins={}, debug=True, ssl_context=ssl_context)
167+
proxy = Proxy(instance, plugins=plugins or {}, debug=True, ssl_context=ssl_context)
147168
thread = threading.Thread(
148169
target=proxy.listen, kwargs={"max_connections": 32}, daemon=True
149170
)
@@ -336,39 +357,57 @@ def test_psql_ssl_file_batch_stress_no_hang(postgres_settings, ssl_proxy_port):
336357

337358

338359
def test_extended_query_protocol_parse_packet_with_high_oid_params_passes_through_proxy(
339-
postgres_settings, plain_proxy_port
360+
postgres_settings,
340361
):
341362
"""Regression: proxy must not corrupt Extended Query Protocol Parse packets.
342363
343364
psycopg v3 sends Parse → Bind → Execute for parameterized queries. The Parse body
344365
ends with binary uint32 OIDs; jsonb OID 3802 (0x00000EDA) contains 0xDA which is
345366
not valid UTF-8. The old interceptor sliced the body incorrectly and crashed on
346367
decode, causing the connection to hang or drop.
368+
369+
A _QuerySpy is wired into the proxy to verify the interceptor receives the correct
370+
SQL text — not corrupted bytes from the binary OID suffix.
347371
"""
348-
with psycopg.connect(
349-
host="127.0.0.1",
350-
port=plain_proxy_port,
351-
user=postgres_settings["user"],
352-
password=postgres_settings["password"],
353-
dbname=postgres_settings["dbname"],
354-
sslmode="disable",
355-
) as conn:
356-
with conn.cursor() as cur:
357-
cur.execute(
358-
"DROP TABLE IF EXISTS _test_jsonb_proxy_params;"
359-
"CREATE TABLE _test_jsonb_proxy_params "
360-
"(id serial PRIMARY KEY, data jsonb, label text);"
361-
)
372+
spy = _QuerySpy()
373+
with _run_proxy(
374+
postgres_settings,
375+
plugins={"spy": spy},
376+
query_interceptors=[{"plugin": "spy", "function": "capture"}],
377+
) as proxy_port:
378+
with psycopg.connect(
379+
host="127.0.0.1",
380+
port=proxy_port,
381+
user=postgres_settings["user"],
382+
password=postgres_settings["password"],
383+
dbname=postgres_settings["dbname"],
384+
sslmode="disable",
385+
) as conn:
386+
with conn.cursor() as cur:
387+
cur.execute(
388+
"DROP TABLE IF EXISTS _test_jsonb_proxy_params;"
389+
"CREATE TABLE _test_jsonb_proxy_params "
390+
"(id serial PRIMARY KEY, data jsonb, label text);"
391+
)
362392

363-
cur.execute(
364-
"INSERT INTO _test_jsonb_proxy_params (data, label) "
365-
"VALUES (%s, %s) RETURNING id",
366-
(psycopg.types.json.Jsonb({"key": "value"}), "hello"),
367-
)
368-
row = cur.fetchone()
393+
cur.execute(
394+
"INSERT INTO _test_jsonb_proxy_params (data, label) "
395+
"VALUES (%s, %s) RETURNING id",
396+
(psycopg.types.json.Jsonb({"key": "value"}), "hello"),
397+
)
398+
row = cur.fetchone()
369399

370400
assert row is not None and row[0] >= 1
371401

402+
# Verify the interceptor received clean SQL — no binary OID bytes leaked in.
403+
insert_queries = [
404+
q for q in spy.captured if "INSERT INTO _test_jsonb_proxy_params" in q
405+
]
406+
assert insert_queries
407+
assert all("\x00" not in q for q in insert_queries), (
408+
"null byte leaked into intercepted query"
409+
)
410+
372411

373412
def test_extended_query_protocol_named_prepared_statement_passes_through_proxy(
374413
postgres_settings, plain_proxy_port

0 commit comments

Comments
 (0)