Skip to content

Commit c250082

Browse files
Fix query encoding
1 parent cafe4d2 commit c250082

2 files changed

Lines changed: 26 additions & 2 deletions

File tree

postgresql_proxy/interceptors.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ def _intercept_context_data(self, data):
101101
def _intercept_query(self, query, interceptors):
102102
logging.getLogger("intercept").debug("intercepting query\n%s", query)
103103
# Remove null terminator
104-
query = query.rstrip(b"\x00").decode("utf-8")
104+
codec = self.get_codec()
105+
query = query.rstrip(b"\x00").decode(codec)
105106
for interceptor in interceptors:
106107
func = self._get_plugin_interceptor_function(interceptor)
107108
query = func(query, self.context)
@@ -113,7 +114,7 @@ def _intercept_query(self, query, interceptors):
113114
)
114115

115116
# Append the zero byte at the end
116-
return query.encode("utf-8") + b"\x00"
117+
return query.encode(codec) + b"\x00"
117118

118119

119120
class ResponseInterceptor(Interceptor):

tests/test_proxy.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,3 +441,26 @@ def test_extended_query_protocol_named_prepared_statement_passes_through_proxy(
441441
)
442442
prepared_count = cur.fetchone()
443443
assert prepared_count is not None and prepared_count[0] >= 1
444+
445+
446+
def test_non_utf8_client_encoding_query_text_through_proxy(
447+
postgres_settings, plain_proxy_port
448+
):
449+
"""Proxy query interception should honor client_encoding when decoding query text."""
450+
with psycopg2.connect(
451+
host="127.0.0.1",
452+
port=plain_proxy_port,
453+
user=postgres_settings["user"],
454+
password=postgres_settings["password"],
455+
dbname=postgres_settings["dbname"],
456+
sslmode="disable",
457+
connect_timeout=3,
458+
client_encoding="LATIN1",
459+
) as conn:
460+
conn.autocommit = True
461+
with conn.cursor() as cur:
462+
cur.execute("SHOW client_encoding")
463+
assert cur.fetchone() == ("LATIN1",)
464+
465+
cur.execute("SELECT 'olá'::text")
466+
assert cur.fetchone() == ("olá",)

0 commit comments

Comments
 (0)