Skip to content

Commit e875d8d

Browse files
nik-localstackGitHub Copilot
andauthored
Fix: Extended Query Protocol Parse packets corrupted by interceptor (#18)
* fix: correctly parse Extended Query Protocol Parse packets with binary OID suffix The Parse packet body (type 'P') has the format: statement_name\x00 + query\x00 + int16(param_count) + uint32[] OIDs The old handler used data[1:-2] / data[-2:] which treated only the last 2 bytes as the param-count field, leaking binary OID bytes into the query slice fed to _intercept_query(). When a parameter type OID contains a byte >= 0x80 (e.g. jsonb OID 3802 = 0x00000EDA), the subsequent UTF-8 decode raised UnicodeDecodeError, causing the connection to drop or hang. Fix: use find(b'\x00') to locate the true boundaries of statement name and query text, so the binary suffix is never touched by the text decoder. Adds a regression test using psycopg v3 (Extended Query Protocol) with a jsonb+text parameterized INSERT through the proxy. Co-authored-by: GitHub Copilot <copilot@github.com> * PR comments --------- Co-authored-by: GitHub Copilot <copilot@github.com>
1 parent 7a5f6e9 commit e875d8d

3 files changed

Lines changed: 124 additions & 14 deletions

File tree

postgresql_proxy/interceptors.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,14 @@ def intercept(self, packet_type, data):
5555
# Query, ends with b'\x00'
5656
data = self._intercept_query(data, ic_queries)
5757
elif packet_type == b"P":
58-
# Statement that needs parsing.
59-
# First byte of the body is some Statement flag. Ignore, don't lose
60-
# Next is the query, same as above, ends with an b'\x00'
61-
# Last 2 bytes are the number of parameters. Ignore, don't lose
62-
statement = data[0:1]
63-
query = self._intercept_query(data[1:-2], ic_queries)
64-
params = data[-2:]
65-
data = statement + query + params
58+
# Parse message (Extended Query Protocol). Body format:
59+
# statement_name\x00 + query\x00 + int16(param_count) + uint32[] OIDs
60+
# maxsplit=2 keeps the binary suffix intact even when OID values contain \x00 bytes.
61+
parts = data.split(b"\x00", 2)
62+
if len(parts) == 3:
63+
statement, raw_query, params = parts
64+
query = self._intercept_query(raw_query, ic_queries)
65+
data = statement + b"\x00" + query + params
6666

6767
if packet_type == b"":
6868
# Connection request / context. Ignore the first 4 bytes, keep it
@@ -100,8 +100,8 @@ def _intercept_context_data(self, data):
100100

101101
def _intercept_query(self, query, interceptors):
102102
logging.getLogger("intercept").debug("intercepting query\n%s", query)
103-
# Remove zero byte at the end
104-
query = query[:-1].decode("utf-8")
103+
# Remove null terminator
104+
query = query.rstrip(b"\x00").decode("utf-8")
105105
for interceptor in interceptors:
106106
func = self._get_plugin_interceptor_function(interceptor)
107107
query = func(query, self.context)

requirements-test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pytest==9.0.3
22
pytest-timeout==2.4.0
3+
psycopg[binary]==3.3.4

tests/test_proxy.py

Lines changed: 113 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Generator
12
import contextlib
23
import os
34
import shutil
@@ -8,13 +9,25 @@
89
import threading
910
import time
1011

12+
import psycopg
1113
import psycopg2
1214
import pytest
1315

1416
from postgresql_proxy import config_schema as cfg
1517
from postgresql_proxy.proxy import Proxy
1618

1719

20+
class _QuerySpy:
21+
"""Minimal query interceptor that records every query it sees."""
22+
23+
def __init__(self):
24+
self.captured: list[str] = []
25+
26+
def capture(self, query: str, context) -> str:
27+
self.captured.append(query)
28+
return query
29+
30+
1831
def _get_free_tcp_port() -> int:
1932
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
2033
sock.bind(("127.0.0.1", 0))
@@ -125,8 +138,18 @@ def _temporary_server_cert_pair():
125138

126139

127140
@contextlib.contextmanager
128-
def _run_proxy(postgres_settings, ssl_context: ssl.SSLContext | None = None):
141+
def _run_proxy(
142+
postgres_settings,
143+
ssl_context: ssl.SSLContext | None = None,
144+
*,
145+
plugins: dict | None = None,
146+
query_interceptors: list[dict] | None = None,
147+
) -> Generator[int, None, None]:
148+
"""Start a proxy in a background thread and yield its listening port."""
129149
proxy_port = _get_free_tcp_port()
150+
commands_config: dict = {}
151+
if query_interceptors:
152+
commands_config["queries"] = query_interceptors
130153
instance = cfg.InstanceSettings(
131154
{
132155
"listen": {"name": "proxy", "host": "127.0.0.1", "port": proxy_port},
@@ -135,14 +158,13 @@ def _run_proxy(postgres_settings, ssl_context: ssl.SSLContext | None = None):
135158
"host": postgres_settings["host"],
136159
"port": postgres_settings["port"],
137160
},
138-
# Keep interceptors active with default no-op behavior.
139-
"intercept": {"commands": {}, "responses": {}},
161+
"intercept": {"commands": commands_config, "responses": {}},
140162
}
141163
)
142164
if not hasattr(instance.intercept.responses, "parameter_status"):
143165
instance.intercept.responses.parameter_status = []
144166

145-
proxy = Proxy(instance, plugins={}, debug=True, ssl_context=ssl_context)
167+
proxy = Proxy(instance, plugins=plugins or {}, debug=True, ssl_context=ssl_context)
146168
thread = threading.Thread(
147169
target=proxy.listen, kwargs={"max_connections": 32}, daemon=True
148170
)
@@ -332,3 +354,90 @@ def test_psql_ssl_file_batch_stress_no_hang(postgres_settings, ssl_proxy_port):
332354
"psql -f batch succeeded but expected marker missing "
333355
f"(run={run_idx + 1}, {elapsed=:.2f}s) stdout_tail={out_tail}"
334356
)
357+
358+
359+
def test_extended_query_protocol_parse_packet_with_high_oid_params_passes_through_proxy(
360+
postgres_settings,
361+
):
362+
"""Regression: proxy must not corrupt Extended Query Protocol Parse packets.
363+
364+
psycopg v3 sends Parse → Bind → Execute for parameterized queries. The Parse body
365+
ends with binary uint32 OIDs; jsonb OID 3802 (0x00000EDA) contains 0xDA which is
366+
not valid UTF-8. The old interceptor sliced the body incorrectly and crashed on
367+
decode, causing the connection to hang or drop.
368+
369+
A _QuerySpy is wired into the proxy to verify the interceptor receives the correct
370+
SQL text — not corrupted bytes from the binary OID suffix.
371+
"""
372+
spy = _QuerySpy()
373+
with _run_proxy(
374+
postgres_settings,
375+
plugins={"spy": spy},
376+
query_interceptors=[{"plugin": "spy", "function": "capture"}],
377+
) as proxy_port:
378+
with psycopg.connect(
379+
host="127.0.0.1",
380+
port=proxy_port,
381+
user=postgres_settings["user"],
382+
password=postgres_settings["password"],
383+
dbname=postgres_settings["dbname"],
384+
sslmode="disable",
385+
) as conn:
386+
with conn.cursor() as cur:
387+
cur.execute(
388+
"DROP TABLE IF EXISTS _test_jsonb_proxy_params;"
389+
"CREATE TABLE _test_jsonb_proxy_params "
390+
"(id serial PRIMARY KEY, data jsonb, label text);"
391+
)
392+
393+
cur.execute(
394+
"INSERT INTO _test_jsonb_proxy_params (data, label) "
395+
"VALUES (%s, %s) RETURNING id",
396+
(psycopg.types.json.Jsonb({"key": "value"}), "hello"),
397+
)
398+
row = cur.fetchone()
399+
400+
assert row is not None and row[0] >= 1
401+
402+
# Verify the interceptor received clean SQL — no binary OID bytes leaked in.
403+
insert_queries = [
404+
q for q in spy.captured if "INSERT INTO _test_jsonb_proxy_params" in q
405+
]
406+
assert insert_queries
407+
assert all("\x00" not in q for q in insert_queries), (
408+
"null byte leaked into intercepted query"
409+
)
410+
411+
412+
def test_extended_query_protocol_named_prepared_statement_passes_through_proxy(
413+
postgres_settings, plain_proxy_port
414+
):
415+
"""Parse packets with a non-empty statement name must also be relayed correctly.
416+
417+
The statement_name field precedes the query text in the Parse body. The fix uses
418+
find(b'\\x00') to locate boundaries, so named statements work the same as anonymous
419+
ones (empty name).
420+
"""
421+
with psycopg.connect(
422+
host="127.0.0.1",
423+
port=plain_proxy_port,
424+
user=postgres_settings["user"],
425+
password=postgres_settings["password"],
426+
dbname=postgres_settings["dbname"],
427+
sslmode="disable",
428+
# Prepare after the first execution of the same query (i.e. on 2nd run).
429+
prepare_threshold=1,
430+
) as conn:
431+
with conn.cursor() as cur:
432+
# Execute twice so psycopg can promote the query to a named statement.
433+
for val in (1, 2):
434+
cur.execute("SELECT %s::int + 1", (val,))
435+
result = cur.fetchone()
436+
assert result == (val + 1,)
437+
438+
# Verify psycopg created a named prepared statement in this session.
439+
cur.execute(
440+
"SELECT count(*) FROM pg_prepared_statements WHERE name LIKE '_pg3_%'"
441+
)
442+
prepared_count = cur.fetchone()
443+
assert prepared_count is not None and prepared_count[0] >= 1

0 commit comments

Comments
 (0)