@@ -128,14 +128,8 @@ def accept_wrapper(self, sock: socket.socket):
128128 :return:
129129 """
130130
131- # Accept the raw connection
131+ # Accept the raw connection and switch to non-blocking immediately.
132132 clientsocket , address = sock .accept ()
133-
134- # Check if SSL is enabled for this proxy
135- if self .ssl_context :
136- # Handle SSL negotiation - must happen before setblocking(False)
137- clientsocket = self ._handle_ssl_negotiation (clientsocket , self .ssl_context )
138-
139133 clientsocket .setblocking (False )
140134 self .num_clients += 1
141135 sock_name = f"{ self .instance_config .listen .name } _{ self .num_clients } "
@@ -156,6 +150,9 @@ def accept_wrapper(self, sock: socket.socket):
156150 events = events ,
157151 context = context ,
158152 )
153+ # SSL startup is handled in the selector loop to avoid blocking accept.
154+ conn .ssl_negotiation_pending = self .ssl_context is not None
155+ conn .ssl_handshake_done = self .ssl_context is None
159156
160157 pg_conn = self ._create_pg_connection (address , context )
161158
@@ -180,41 +177,121 @@ def accept_wrapper(self, sock: socket.socket):
180177 self ._register_conn (conn )
181178 self ._register_conn (pg_conn )
182179
183- def _handle_ssl_negotiation (
184- self , client_socket : socket .socket , ssl_context : ssl .SSLContext
185- ) -> socket .socket :
180+ def _try_start_client_tls (self , conn : connection .Connection ) -> bool :
181+ """Resolve PostgreSQL SSLRequest in non-blocking mode.
182+
183+ Returns True when startup mode is decided (SSL or plain), False when more
184+ bytes are needed.
186185 """
187- Handle PostgreSQL SSL negotiation on an accepted socket.
188186
189- PostgreSQL SSL flow:
190- 1. Client sends SSLRequest (8 bytes): length (4) + code 80877103 (4)
191- 2. Server responds 'S' (SSL supported) or 'N' (not supported)
192- 3. If 'S', TLS handshake follows
193- 4. After TLS, normal PostgreSQL protocol begins
187+ if not conn .ssl_negotiation_pending :
188+ return True
194189
195- Returns the SSL-wrapped socket if negotiation succeeds, or the original socket.
196- """
190+ sock = conn .sock
191+ try :
192+ data = sock .recv (8 , socket .MSG_PEEK )
193+ except BlockingIOError :
194+ return False
195+ except OSError as exc :
196+ LOG .debug ("%s SSL startup read failed %s: %s" , conn .name , conn .address , exc )
197+ self ._unregister_conn (conn )
198+ sock .close ()
199+ return False
200+
201+ if len (data ) == 0 :
202+ self ._unregister_conn (conn )
203+ sock .close ()
204+ return False
205+
206+ if len (data ) < 8 :
207+ return False
208+
209+ length = int .from_bytes (data [:4 ], "big" )
210+ code = int .from_bytes (data [4 :8 ], "big" )
211+ conn .ssl_negotiation_pending = False
212+
213+ if length != 8 or code != 80877103 :
214+ # Plain startup packet.
215+ conn .ssl_handshake_done = True
216+ return True
217+
218+ try :
219+ # Consume SSLRequest and acknowledge TLS support.
220+ sock .recv (8 )
221+ sock .sendall (b"S" )
222+ ssl_sock = self .ssl_context .wrap_socket (
223+ sock ,
224+ server_side = True ,
225+ do_handshake_on_connect = False ,
226+ )
227+ ssl_sock .setblocking (False )
197228
198- # Peek at the first 8 bytes to check for SSLRequest
199- # Using MSG_PEEK so we don't consume the data if it's not SSLRequest
200- data = client_socket .recv (8 , socket .MSG_PEEK )
229+ if self ._debug :
230+ self ._registered_conn .discard (f"{ conn .name } -{ conn .sock .fileno ()} " )
231+ self .selector .unregister (conn .sock )
232+ conn .sock = ssl_sock
233+ self ._register_conn (conn )
234+ conn .ssl_handshake_done = False
235+ LOG .debug ("SSL requested, deferring TLS handshake to selector loop" )
236+ return True
237+ except OSError as exc :
238+ LOG .debug ("%s SSL startup upgrade failed %s: %s" , conn .name , conn .address , exc )
239+ self ._unregister_conn (conn )
240+ sock .close ()
241+ return False
242+
243+ def _set_write_interest (self , conn : connection .Connection , enabled : bool ):
244+ """Enable or disable EVENT_WRITE for a connection while preserving read interest."""
245+ try :
246+ selector_key = self .selector .get_key (conn .sock )
247+ except KeyError :
248+ return
201249
202- if len (data ) == 8 :
203- length = int .from_bytes (data [:4 ], "big" )
204- code = int .from_bytes (data [4 :8 ], "big" )
250+ current_events = selector_key .events
251+ if enabled :
252+ new_events = current_events | selectors .EVENT_WRITE
253+ else :
254+ new_events = current_events & ~ selectors .EVENT_WRITE
205255
206- if length == 8 and code == 80877103 : # SSLRequest code
207- # Consume the SSLRequest
208- client_socket .recv (8 )
209- # Send 'S' to indicate SSL is supported
210- client_socket .send (b"S" )
211- # Wrap socket with SSL
212- ssl_socket = ssl_context .wrap_socket (client_socket , server_side = True )
213- LOG .debug ("SSL handshake completed for PostgreSQL connection" )
214- return ssl_socket
256+ if new_events != current_events :
257+ self .selector .modify (conn .sock , new_events , data = conn )
258+ conn .events = new_events
259+
260+ def _flush_outgoing (
261+ self ,
262+ source_conn : connection .Connection ,
263+ source_sock : socket .socket ,
264+ target_conn : connection .Connection | None ,
265+ ):
266+ if not target_conn or not target_conn .out_bytes :
267+ return
215268
216- # Not an SSLRequest, return original socket
217- return client_socket
269+ try :
270+ while target_conn .out_bytes :
271+ LOG .debug ('sending to %s:\n %s' , target_conn .name , target_conn .out_bytes )
272+ sent = target_conn .sock .send (target_conn .out_bytes )
273+ if sent == 0 :
274+ # send() returned 0: socket closed or buffer full. Enable write interest
275+ # so the next writable event will retry the send.
276+ self ._set_write_interest (target_conn , True )
277+ return
278+ target_conn .sent (sent )
279+ except (BlockingIOError , ssl .SSLWantWriteError ):
280+ self ._set_write_interest (target_conn , True )
281+ return
282+ except ssl .SSLWantReadError :
283+ self ._set_write_interest (target_conn , False )
284+ return
285+ except OSError :
286+ # If one side is closed, close the other one
287+ # this can happen in the case where the client disconnects, and postgres still return a response
288+ # we then read the response then close the PG side of the socket.
289+ LOG .debug ('error sending to %s: connection closed' , target_conn .name )
290+ self ._unregister_conn (source_conn )
291+ source_sock .close ()
292+ return
293+
294+ self ._set_write_interest (target_conn , bool (target_conn .out_bytes ))
218295
219296 def service_connection (self , key : SelectorKeyProxy , mask ):
220297 """
@@ -227,37 +304,63 @@ def service_connection(self, key: SelectorKeyProxy, mask):
227304 """
228305 sock = key .fileobj
229306 conn = key .data
307+
308+ if conn .ssl_negotiation_pending :
309+ if not self ._try_start_client_tls (conn ):
310+ return
311+ sock = conn .sock
312+
313+ # Drive TLS handshake in non-blocking mode so one slow client cannot block others.
314+ if isinstance (sock , ssl .SSLSocket ) and not conn .ssl_handshake_done :
315+ try :
316+ sock .do_handshake ()
317+ conn .ssl_handshake_done = True
318+ self ._set_write_interest (conn , bool (conn .out_bytes ))
319+ except ssl .SSLWantReadError :
320+ return
321+ except ssl .SSLWantWriteError :
322+ self ._set_write_interest (conn , True )
323+ return
324+ except OSError as e :
325+ LOG .debug ('%s SSL handshake failed %s: %s' , conn .name , conn .address , e )
326+ self ._unregister_conn (conn )
327+ sock .close ()
328+ return
329+
230330 if mask & selectors .EVENT_READ :
231331 LOG .debug ('%s can receive' , conn .name )
232- try :
233- recv_data = sock .recv (4096 ) # Should be ready to read
234- if recv_data :
235- LOG .debug ('%s received data:\n %s' , conn .name , recv_data )
236- conn .received (recv_data )
237- else :
332+ while True :
333+ try :
334+ if recv_data := sock .recv (4096 ):
335+ LOG .debug ('%s received data:\n %s' , conn .name , recv_data )
336+ conn .received (recv_data )
337+ # Keep draining bytes in the same readiness cycle until recv indicates no immediate data.
338+ continue
339+
238340 self ._unregister_conn (conn )
239341 LOG .debug ('%s connection closing %s' , conn .name , conn .address )
240342 # A file object shall be unregistered prior to being closed.
241343 sock .close ()
242- except OSError as e :
243- # it means the socket was closed by peer
244- LOG .debug ('%s connection closed by peer %s: %s' , conn .name , conn .address , e )
245- self ._unregister_conn (conn )
344+ return
345+ except ssl .SSLWantReadError :
346+ break
347+ except ssl .SSLWantWriteError :
348+ self ._set_write_interest (conn , True )
349+ break
350+ except BlockingIOError :
351+ break
352+ except OSError as e :
353+ # it means the socket was closed by peer
354+ LOG .debug ('%s connection closed by peer %s: %s' , conn .name , conn .address , e )
355+ self ._unregister_conn (conn )
356+ return
246357
247- next_conn = conn .redirect_conn
248- if next_conn and next_conn .out_bytes :
249- try :
250- while next_conn .out_bytes :
251- LOG .debug ('sending to %s:\n %s' , next_conn .name , next_conn .out_bytes )
252- sent = next_conn .sock .send (next_conn .out_bytes )
253- next_conn .sent (sent )
254- except OSError :
255- # If one side is closed, close the other one
256- # this can happen in the case where the client disconnects, and postgres still return a response
257- # we then read the response then close the PG side of the socket.
258- LOG .debug ('error sending to %s: connection closed' , next_conn .name )
259- self ._unregister_conn (conn )
260- sock .close ()
358+ next_conn = conn .redirect_conn
359+ if next_conn and next_conn .out_bytes :
360+ self ._flush_outgoing (conn , sock , next_conn )
361+
362+ if mask & selectors .EVENT_WRITE :
363+ self ._flush_outgoing (conn , sock , conn )
261364
262365 def listen (self , max_connections : int = 8 ):
263366 """
0 commit comments