Skip to content

Commit 92f6bd0

Browse files
nik-localstack and GitHub Copilot committed
fix(postgresql-proxy): prevent SSL COPY stalls by draining nonblocking reads
Co-authored-by: GitHub Copilot <copilot@github.com>
1 parent 37a1fee commit 92f6bd0

1 file changed

Lines changed: 103 additions & 27 deletions

File tree

postgresql_proxy/proxy.py

Lines changed: 103 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -156,6 +156,7 @@ def accept_wrapper(self, sock: socket.socket):
156156
events=events,
157157
context=context,
158158
)
159+
conn.ssl_handshake_done = not isinstance(clientsocket, ssl.SSLSocket)
159160

160161
pg_conn = self._create_pg_connection(address, context)
161162

@@ -208,14 +209,68 @@ def _handle_ssl_negotiation(
208209
client_socket.recv(8)
209210
# Send 'S' to indicate SSL is supported
210211
client_socket.send(b"S")
211-
# Wrap socket with SSL
212-
ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True)
213-
LOG.debug("SSL handshake completed for PostgreSQL connection")
212+
# Wrap without immediate handshake so a stalled client can't block accept loop.
213+
ssl_socket = ssl_context.wrap_socket(
214+
client_socket,
215+
server_side=True,
216+
do_handshake_on_connect=False,
217+
)
218+
LOG.debug("SSL requested, deferring TLS handshake to selector loop")
214219
return ssl_socket
215220

216221
# Not an SSLRequest, return original socket
217222
return client_socket
218223

224+
def _set_write_interest(self, conn: connection.Connection, enabled: bool):
    """
    Turn EVENT_WRITE monitoring on or off for ``conn``, leaving its other event bits alone.

    The current mask is read from the selector registration of ``conn.sock``. If the
    socket is not registered, nothing happens. The selector entry is modified (with
    ``conn`` as its data) and ``conn.events`` is updated only when the mask changes.

    :param conn: connection whose socket registration should be adjusted
    :param enabled: truthy to add EVENT_WRITE to the mask, falsy to clear it
    """
    try:
        registered_mask = self.selector.get_key(conn.sock).events
    except KeyError:
        # The socket is not registered with the selector: nothing to adjust.
        return

    write_bit = selectors.EVENT_WRITE
    wanted_mask = (registered_mask | write_bit) if enabled else (registered_mask & ~write_bit)
    if wanted_mask == registered_mask:
        # Already in the requested state; skip the selector update.
        return

    self.selector.modify(conn.sock, wanted_mask, data=conn)
    conn.events = wanted_mask
240+
241+
def _flush_outgoing(
    self,
    source_conn: connection.Connection,
    source_sock: socket.socket,
    target_conn: connection.Connection | None,
):
    """
    Send as much of ``target_conn.out_bytes`` as the target socket accepts right now.

    Does nothing when there is no target connection or it has nothing buffered. Otherwise
    it keeps calling ``send`` on the target socket until the buffer is empty, ``send``
    reports 0 bytes written, or the socket signals that it cannot make progress.

    - If the socket would block (``BlockingIOError`` on a plain non-blocking socket, or
      ``SSLWantWriteError`` / ``SSLWantReadError`` on a TLS socket), the unsent bytes stay
      buffered and EVENT_WRITE is enabled for the target so the flush is retried later.
    - On any other ``OSError`` the target is treated as closed: the *source* connection is
      unregistered and ``source_sock`` is closed.
    - When the loop ends normally, EVENT_WRITE is left enabled only if bytes remain.

    :param source_conn: connection to unregister if sending to the target fails
    :param source_sock: socket to close if sending to the target fails
    :param target_conn: connection whose buffered ``out_bytes`` should be written, or None
    """
    if not target_conn or not target_conn.out_bytes:
        return

    try:
        while target_conn.out_bytes:
            LOG.debug('sending to %s:\n%s', target_conn.name, target_conn.out_bytes)
            sent = target_conn.sock.send(target_conn.out_bytes)
            if sent == 0:
                break
            # presumably drops the first ``sent`` bytes from out_bytes - verify against Connection.sent
            target_conn.sent(sent)
    except (ssl.SSLWantWriteError, ssl.SSLWantReadError, BlockingIOError):
        # The socket cannot take more data right now. This is back-pressure, not a closed
        # peer, and BlockingIOError is an OSError subclass, so it must be caught before the
        # generic handler below or a merely slow receiver would get the connection torn down.
        # Keep the remaining bytes and ask the selector to tell us when to retry.
        self._set_write_interest(target_conn, True)
        return
    except OSError:
        # If one side is closed, close the other one.
        # This can happen when the client disconnects while postgres still returns a response:
        # we read the response, then close the PG side of the socket.
        LOG.debug('error sending to %s: connection closed', target_conn.name)
        self._unregister_conn(source_conn)
        source_sock.close()
        return

    # Keep write interest only while there is still something left to send.
    self._set_write_interest(target_conn, bool(target_conn.out_bytes))
273+
219274
def service_connection(self, key: SelectorKeyProxy, mask):
220275
"""
221276
This method proxies the messages between socket. It will use properties of the Connection object to
@@ -227,37 +282,58 @@ def service_connection(self, key: SelectorKeyProxy, mask):
227282
"""
228283
sock = key.fileobj
229284
conn = key.data
285+
286+
# Drive TLS handshake in non-blocking mode so one slow client cannot block others.
287+
if isinstance(sock, ssl.SSLSocket) and not getattr(conn, "ssl_handshake_done", False):
288+
try:
289+
sock.do_handshake()
290+
conn.ssl_handshake_done = True
291+
self._set_write_interest(conn, bool(conn.out_bytes))
292+
except ssl.SSLWantReadError:
293+
return
294+
except ssl.SSLWantWriteError:
295+
self._set_write_interest(conn, True)
296+
return
297+
except OSError as e:
298+
LOG.debug('%s SSL handshake failed %s: %s', conn.name, conn.address, e)
299+
self._unregister_conn(conn)
300+
sock.close()
301+
return
302+
230303
if mask & selectors.EVENT_READ:
231304
LOG.debug('%s can receive', conn.name)
232-
try:
233-
recv_data = sock.recv(4096) # Should be ready to read
234-
if recv_data:
235-
LOG.debug('%s received data:\n%s', conn.name, recv_data)
236-
conn.received(recv_data)
237-
else:
305+
while True:
306+
try:
307+
if recv_data := sock.recv(4096):
308+
LOG.debug('%s received data:\n%s', conn.name, recv_data)
309+
conn.received(recv_data)
310+
# Keep draining bytes in the same readiness cycle until recv indicates no immediate data.
311+
continue
312+
238313
self._unregister_conn(conn)
239314
LOG.debug('%s connection closing %s', conn.name, conn.address)
240315
# A file object shall be unregistered prior to being closed.
241316
sock.close()
242-
except OSError as e:
243-
# it means the socket was closed by peer
244-
LOG.debug('%s connection closed by peer %s: %s', conn.name, conn.address, e)
245-
self._unregister_conn(conn)
317+
return
318+
except ssl.SSLWantReadError:
319+
break
320+
except ssl.SSLWantWriteError:
321+
self._set_write_interest(conn, True)
322+
break
323+
except BlockingIOError:
324+
break
325+
except OSError as e:
326+
# it means the socket was closed by peer
327+
LOG.debug('%s connection closed by peer %s: %s', conn.name, conn.address, e)
328+
self._unregister_conn(conn)
329+
return
246330

247-
next_conn = conn.redirect_conn
248-
if next_conn and next_conn.out_bytes:
249-
try:
250-
while next_conn.out_bytes:
251-
LOG.debug('sending to %s:\n%s', next_conn.name, next_conn.out_bytes)
252-
sent = next_conn.sock.send(next_conn.out_bytes)
253-
next_conn.sent(sent)
254-
except OSError:
255-
# If one side is closed, close the other one
256-
# this can happen in the case where the client disconnects, and postgres still return a response
257-
# we then read the response then close the PG side of the socket.
258-
LOG.debug('error sending to %s: connection closed', next_conn.name)
259-
self._unregister_conn(conn)
260-
sock.close()
331+
next_conn = conn.redirect_conn
332+
if next_conn and next_conn.out_bytes:
333+
self._flush_outgoing(conn, sock, next_conn)
334+
335+
if mask & selectors.EVENT_WRITE:
336+
self._flush_outgoing(conn, sock, conn)
261337

262338
def listen(self, max_connections: int = 8):
263339
"""

0 commit comments

Comments
 (0)