Skip to content

Commit c3afe1c

Browse files
nik-localstack and GitHub Copilot committed
fix(postgresql-proxy): prevent SSL COPY stalls by draining nonblocking reads
Co-authored-by: GitHub Copilot <copilot@github.com>
1 parent 37a1fee commit c3afe1c

2 files changed

Lines changed: 164 additions & 59 deletions

File tree

postgresql_proxy/connection.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@ def __init__(self, sock, address, name, events, context):
1616
self.out_bytes = b''
1717
self.in_bytes = b''
1818
self.terminated = False
19+
self.ssl_handshake_done = False
20+
self.ssl_negotiation_pending = False
1921

2022
def parse_length(self, length_bytes):
2123
return int.from_bytes(length_bytes, 'big')

postgresql_proxy/proxy.py

Lines changed: 162 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -128,14 +128,8 @@ def accept_wrapper(self, sock: socket.socket):
128128
:return:
129129
"""
130130

131-
# Accept the raw connection
131+
# Accept the raw connection and switch to non-blocking immediately.
132132
clientsocket, address = sock.accept()
133-
134-
# Check if SSL is enabled for this proxy
135-
if self.ssl_context:
136-
# Handle SSL negotiation - must happen before setblocking(False)
137-
clientsocket = self._handle_ssl_negotiation(clientsocket, self.ssl_context)
138-
139133
clientsocket.setblocking(False)
140134
self.num_clients += 1
141135
sock_name = f"{self.instance_config.listen.name}_{self.num_clients}"
@@ -156,6 +150,9 @@ def accept_wrapper(self, sock: socket.socket):
156150
events=events,
157151
context=context,
158152
)
153+
# SSL startup is handled in the selector loop to avoid blocking accept.
154+
conn.ssl_negotiation_pending = self.ssl_context is not None
155+
conn.ssl_handshake_done = self.ssl_context is None
159156

160157
pg_conn = self._create_pg_connection(address, context)
161158

@@ -180,41 +177,121 @@ def accept_wrapper(self, sock: socket.socket):
180177
self._register_conn(conn)
181178
self._register_conn(pg_conn)
182179

183-
def _handle_ssl_negotiation(
184-
self, client_socket: socket.socket, ssl_context: ssl.SSLContext
185-
) -> socket.socket:
180+
def _try_start_client_tls(self, conn: connection.Connection) -> bool:
181+
"""Resolve PostgreSQL SSLRequest in non-blocking mode.
182+
183+
Returns True when startup mode is decided (SSL or plain), False when more
184+
bytes are needed.
186185
"""
187-
Handle PostgreSQL SSL negotiation on an accepted socket.
188186

189-
PostgreSQL SSL flow:
190-
1. Client sends SSLRequest (8 bytes): length (4) + code 80877103 (4)
191-
2. Server responds 'S' (SSL supported) or 'N' (not supported)
192-
3. If 'S', TLS handshake follows
193-
4. After TLS, normal PostgreSQL protocol begins
187+
if not conn.ssl_negotiation_pending:
188+
return True
194189

195-
Returns the SSL-wrapped socket if negotiation succeeds, or the original socket.
196-
"""
190+
sock = conn.sock
191+
try:
192+
data = sock.recv(8, socket.MSG_PEEK)
193+
except BlockingIOError:
194+
return False
195+
except OSError as exc:
196+
LOG.debug("%s SSL startup read failed %s: %s", conn.name, conn.address, exc)
197+
self._unregister_conn(conn)
198+
sock.close()
199+
return False
200+
201+
if len(data) == 0:
202+
self._unregister_conn(conn)
203+
sock.close()
204+
return False
205+
206+
if len(data) < 8:
207+
return False
208+
209+
length = int.from_bytes(data[:4], "big")
210+
code = int.from_bytes(data[4:8], "big")
211+
conn.ssl_negotiation_pending = False
212+
213+
if length != 8 or code != 80877103:
214+
# Plain startup packet.
215+
conn.ssl_handshake_done = True
216+
return True
217+
218+
try:
219+
# Consume SSLRequest and acknowledge TLS support.
220+
sock.recv(8)
221+
sock.sendall(b"S")
222+
ssl_sock = self.ssl_context.wrap_socket(
223+
sock,
224+
server_side=True,
225+
do_handshake_on_connect=False,
226+
)
227+
ssl_sock.setblocking(False)
197228

198-
# Peek at the first 8 bytes to check for SSLRequest
199-
# Using MSG_PEEK so we don't consume the data if it's not SSLRequest
200-
data = client_socket.recv(8, socket.MSG_PEEK)
229+
if self._debug:
230+
self._registered_conn.discard(f"{conn.name}-{conn.sock.fileno()}")
231+
self.selector.unregister(conn.sock)
232+
conn.sock = ssl_sock
233+
self._register_conn(conn)
234+
conn.ssl_handshake_done = False
235+
LOG.debug("SSL requested, deferring TLS handshake to selector loop")
236+
return True
237+
except OSError as exc:
238+
LOG.debug("%s SSL startup upgrade failed %s: %s", conn.name, conn.address, exc)
239+
self._unregister_conn(conn)
240+
sock.close()
241+
return False
242+
243+
def _set_write_interest(self, conn: connection.Connection, enabled: bool):
244+
"""Enable or disable EVENT_WRITE for a connection while preserving read interest."""
245+
try:
246+
selector_key = self.selector.get_key(conn.sock)
247+
except KeyError:
248+
return
201249

202-
if len(data) == 8:
203-
length = int.from_bytes(data[:4], "big")
204-
code = int.from_bytes(data[4:8], "big")
250+
current_events = selector_key.events
251+
if enabled:
252+
new_events = current_events | selectors.EVENT_WRITE
253+
else:
254+
new_events = current_events & ~selectors.EVENT_WRITE
205255

206-
if length == 8 and code == 80877103: # SSLRequest code
207-
# Consume the SSLRequest
208-
client_socket.recv(8)
209-
# Send 'S' to indicate SSL is supported
210-
client_socket.send(b"S")
211-
# Wrap socket with SSL
212-
ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True)
213-
LOG.debug("SSL handshake completed for PostgreSQL connection")
214-
return ssl_socket
256+
if new_events != current_events:
257+
self.selector.modify(conn.sock, new_events, data=conn)
258+
conn.events = new_events
259+
260+
def _flush_outgoing(
261+
self,
262+
source_conn: connection.Connection,
263+
source_sock: socket.socket,
264+
target_conn: connection.Connection | None,
265+
):
266+
if not target_conn or not target_conn.out_bytes:
267+
return
215268

216-
# Not an SSLRequest, return original socket
217-
return client_socket
269+
try:
270+
while target_conn.out_bytes:
271+
LOG.debug('sending to %s:\n%s', target_conn.name, target_conn.out_bytes)
272+
sent = target_conn.sock.send(target_conn.out_bytes)
273+
if sent == 0:
274+
# send() returned 0: socket closed or buffer full. Enable write interest
275+
# so the next writable event will retry the send.
276+
self._set_write_interest(target_conn, True)
277+
return
278+
target_conn.sent(sent)
279+
except (BlockingIOError, ssl.SSLWantWriteError):
280+
self._set_write_interest(target_conn, True)
281+
return
282+
except ssl.SSLWantReadError:
283+
self._set_write_interest(target_conn, False)
284+
return
285+
except OSError:
286+
# If one side is closed, close the other one
287+
# this can happen in the case where the client disconnects, and postgres still return a response
288+
# we then read the response then close the PG side of the socket.
289+
LOG.debug('error sending to %s: connection closed', target_conn.name)
290+
self._unregister_conn(source_conn)
291+
source_sock.close()
292+
return
293+
294+
self._set_write_interest(target_conn, bool(target_conn.out_bytes))
218295

219296
def service_connection(self, key: SelectorKeyProxy, mask):
220297
"""
@@ -227,37 +304,63 @@ def service_connection(self, key: SelectorKeyProxy, mask):
227304
"""
228305
sock = key.fileobj
229306
conn = key.data
307+
308+
if conn.ssl_negotiation_pending:
309+
if not self._try_start_client_tls(conn):
310+
return
311+
sock = conn.sock
312+
313+
# Drive TLS handshake in non-blocking mode so one slow client cannot block others.
314+
if isinstance(sock, ssl.SSLSocket) and not conn.ssl_handshake_done:
315+
try:
316+
sock.do_handshake()
317+
conn.ssl_handshake_done = True
318+
self._set_write_interest(conn, bool(conn.out_bytes))
319+
except ssl.SSLWantReadError:
320+
return
321+
except ssl.SSLWantWriteError:
322+
self._set_write_interest(conn, True)
323+
return
324+
except OSError as e:
325+
LOG.debug('%s SSL handshake failed %s: %s', conn.name, conn.address, e)
326+
self._unregister_conn(conn)
327+
sock.close()
328+
return
329+
230330
if mask & selectors.EVENT_READ:
231331
LOG.debug('%s can receive', conn.name)
232-
try:
233-
recv_data = sock.recv(4096) # Should be ready to read
234-
if recv_data:
235-
LOG.debug('%s received data:\n%s', conn.name, recv_data)
236-
conn.received(recv_data)
237-
else:
332+
while True:
333+
try:
334+
if recv_data := sock.recv(4096):
335+
LOG.debug('%s received data:\n%s', conn.name, recv_data)
336+
conn.received(recv_data)
337+
# Keep draining bytes in the same readiness cycle until recv indicates no immediate data.
338+
continue
339+
238340
self._unregister_conn(conn)
239341
LOG.debug('%s connection closing %s', conn.name, conn.address)
240342
# A file object shall be unregistered prior to being closed.
241343
sock.close()
242-
except OSError as e:
243-
# it means the socket was closed by peer
244-
LOG.debug('%s connection closed by peer %s: %s', conn.name, conn.address, e)
245-
self._unregister_conn(conn)
344+
return
345+
except ssl.SSLWantReadError:
346+
break
347+
except ssl.SSLWantWriteError:
348+
self._set_write_interest(conn, True)
349+
break
350+
except BlockingIOError:
351+
break
352+
except OSError as e:
353+
# it means the socket was closed by peer
354+
LOG.debug('%s connection closed by peer %s: %s', conn.name, conn.address, e)
355+
self._unregister_conn(conn)
356+
return
246357

247-
next_conn = conn.redirect_conn
248-
if next_conn and next_conn.out_bytes:
249-
try:
250-
while next_conn.out_bytes:
251-
LOG.debug('sending to %s:\n%s', next_conn.name, next_conn.out_bytes)
252-
sent = next_conn.sock.send(next_conn.out_bytes)
253-
next_conn.sent(sent)
254-
except OSError:
255-
# If one side is closed, close the other one
256-
# this can happen in the case where the client disconnects, and postgres still return a response
257-
# we then read the response then close the PG side of the socket.
258-
LOG.debug('error sending to %s: connection closed', next_conn.name)
259-
self._unregister_conn(conn)
260-
sock.close()
358+
next_conn = conn.redirect_conn
359+
if next_conn and next_conn.out_bytes:
360+
self._flush_outgoing(conn, sock, next_conn)
361+
362+
if mask & selectors.EVENT_WRITE:
363+
self._flush_outgoing(conn, sock, conn)
261364

262365
def listen(self, max_connections: int = 8):
263366
"""

0 commit comments

Comments
 (0)