6363 from ..._async .io import AsyncBolt
6464 from ..._sync .io import Bolt
6565
66+ _P = te .ParamSpec ("_P" )
67+
6668
6769log = logging .getLogger ("neo4j.io" )
6870
@@ -76,6 +78,25 @@ def _sanitize_deadline(deadline):
7678 return deadline
7779
7880
81+ def _validate_timeout (timeout ):
82+ if timeout is not None and timeout < 0 :
83+ raise ValueError ("Timeout value out of range" )
84+
85+
86+ def _non_expired_timeout (
87+ deadline : Deadline | None ,
88+ operation : str ,
89+ ) -> float | None :
90+ if deadline is None :
91+ return None
92+ timeout = deadline .to_timeout ()
93+ if timeout is None :
94+ return None
95+ if timeout <= 0 :
96+ raise SocketDeadlineExceededError (f"{ operation } timed out" )
97+ return timeout
98+
99+
79100class AsyncBoltSocketBase (abc .ABC ):
80101 Bolt : te .Final [type [AsyncBolt ]] = None # type: ignore[assignment]
81102
@@ -86,32 +107,40 @@ def __init__(self, reader, protocol, writer) -> None:
86107 # 0 - non-blocking
87108 # None infinitely blocking
88109 # int - seconds to wait for data
89- self ._timeout = None
90- self ._deadline = None
91-
92- async def _wait_for_io (self , io_async_fn , * args , ** kwargs ):
110+ self ._timeout : float | None = None
111+ self ._deadline : Deadline | None = None
112+
113+ async def _wait_for_io (
114+ self ,
115+ io_async_fn : t .Callable [_P , t .Coroutine ],
116+ * args : _P .args ,
117+ ** kwargs : _P .kwargs ,
118+ ) -> None :
93119 timeout = self ._timeout
94- to_raise = SocketTimeout
95- if self ._deadline is not None :
96- deadline_timeout = self ._deadline .to_timeout ()
97- if deadline_timeout <= 0 :
98- raise SocketDeadlineExceededError ("timed out" )
99- if timeout is None or deadline_timeout <= timeout :
100- timeout = deadline_timeout
101- to_raise = SocketDeadlineExceededError
102-
103- io_fut = io_async_fn (* args , ** kwargs )
120+ to_raise : type [Exception ] = SocketTimeout
121+ deadline_timeout = _non_expired_timeout (self ._deadline , "IO operation" )
122+ if deadline_timeout is not None and (
123+ timeout is None or deadline_timeout <= timeout
124+ ):
125+ timeout = deadline_timeout
126+ to_raise = SocketDeadlineExceededError
127+
128+ io_fut : t .Awaitable
104129 if timeout is not None and timeout <= 0 :
105130 # give the io-operation time for one loop cycle to do its thing
106- io_fut = asyncio .create_task (io_fut )
131+ io_fut = asyncio .create_task (io_async_fn ( * args , ** kwargs ) )
107132 try :
108133 await asyncio .sleep (0 )
109134 except asyncio .CancelledError :
110135 # This is emulating non-blocking. There is no cancelling this!
111136 # Still, we don't want to silently swallow the cancellation.
112137 # Hence, we flag this task as cancelled again, so that the next
113138 # `await` will raise the CancelledError.
114- asyncio .current_task ().cancel ()
139+ current_task = asyncio .current_task ()
140+ if current_task is not None :
141+ current_task .cancel ()
142+ else :
143+ io_fut = io_async_fn (* args , ** kwargs )
115144 try :
116145 return await wait_for (io_fut , timeout )
117146 except asyncio .TimeoutError as e :
@@ -197,6 +226,9 @@ async def _connect_secure(
197226 raise ValueError (f"Unsupported address { resolved_address !r} " )
198227 s .setblocking (False ) # asyncio + blocking = no-no!
199228 log .debug ("[#0000] C: <OPEN> %s" , resolved_address )
229+ _validate_timeout (timeout )
230+ if timeout == 0 : # socket timeout of 0 => non-blocking
231+ timeout = None
200232 await wait_for (loop .sock_connect (s , resolved_address ), timeout )
201233 local_port = s .getsockname ()[1 ]
202234
@@ -208,18 +240,26 @@ async def _connect_secure(
208240 if ssl_context is not None :
209241 hostname = resolved_address ._host_name or None
210242 sni_host = hostname if HAS_SNI and hostname else None
211- ssl_kwargs .update (
212- ssl = ssl_context ,
213- server_hostname = sni_host ,
214- ssl_handshake_timeout = deadline .to_timeout (),
215- )
243+ ssl_kwargs .update (ssl = ssl_context , server_hostname = sni_host )
216244 log .debug ("[#%04X] C: <SECURE> %s" , local_port , hostname )
217245
218246 reader = asyncio .StreamReader (
219247 limit = 2 ** 16 , # 64 KiB,
220248 loop = loop ,
221249 )
222250 protocol = asyncio .StreamReaderProtocol (reader , loop = loop )
251+ if ssl_context is not None :
252+ try :
253+ ssl_timeout = _non_expired_timeout (
254+ deadline , "SSL handshake"
255+ )
256+ except SocketDeadlineExceededError as error :
257+ raise BoltSecurityError (
258+ message = "Failed to establish encrypted connection." ,
259+ address = (hostname , local_port ),
260+ ) from error
261+ if ssl_timeout is not None :
262+ ssl_kwargs ["ssl_handshake_timeout" ] = ssl_timeout
223263 transport , _ = await loop .create_connection (
224264 lambda : protocol , sock = s , ** ssl_kwargs
225265 )
@@ -262,6 +302,16 @@ async def _connect_secure(
262302 message = "Failed to establish encrypted connection." ,
263303 address = (resolved_address ._host_name , local_port ),
264304 ) from error
305+ except BoltSecurityError as error :
306+ log .debug (
307+ "[#0000] S: <SECURE FAILURE> %s: %r" ,
308+ resolved_address ,
309+ error ,
310+ )
311+ if s :
312+ log .debug ("[#0000] C: <CLOSE> %s" , resolved_address )
313+ cls ._kill_raw_socket (s )
314+ raise
265315 except Exception as error :
266316 log .debug (
267317 "[#0000] S: <ERROR> %s %s" ,
@@ -315,14 +365,14 @@ class BoltSocketBase:
315365
316366 def __init__ (self , socket_ : socket ):
317367 self ._socket = socket_
318- self ._deadline = None
368+ self ._deadline : Deadline | None = None
319369
320370 @property
321- def _socket (self ):
371+ def _socket (self ) -> socket | SSLSocket :
322372 return self .__socket
323373
324374 @_socket .setter
325- def _socket (self , socket_ : socket | SSLSocket ):
375+ def _socket (self , socket_ : socket | SSLSocket ) -> None :
326376 self .__socket = socket_
327377 self .getsockname = socket_ .getsockname
328378 self .getpeername = socket_ .getpeername
@@ -339,14 +389,17 @@ def _socket(self, socket_: socket | SSLSocket):
339389 gettimeout : t .Callable = None # type: ignore
340390 settimeout : t .Callable = None # type: ignore
341391
342- def _wait_for_io (self , func , * args , ** kwargs ):
343- if self ._deadline is None :
344- return func (* args , ** kwargs )
345- timeout = self ._socket .gettimeout ()
346- deadline_timeout = self ._deadline .to_timeout ()
347- if deadline_timeout <= 0 :
348- raise SocketDeadlineExceededError ("timed out" )
349- if timeout is None or deadline_timeout <= timeout :
392+ def _wait_for_io (
393+ self ,
394+ func : t .Callable [_P , t .Any ],
395+ * args : _P .args ,
396+ ** kwargs : _P .kwargs ,
397+ ) -> None :
398+ timeout : float | None = self ._socket .gettimeout ()
399+ deadline_timeout = _non_expired_timeout (self ._deadline , "IO operation" )
400+ if deadline_timeout is not None and (
401+ timeout is None or deadline_timeout <= timeout
402+ ):
350403 self ._socket .settimeout (deadline_timeout )
351404 try :
352405 return func (* args , ** kwargs )
@@ -443,11 +496,19 @@ def _connect_secure(
443496 log .debug ("[#%04X] C: <SECURE> %s" , local_port , hostname )
444497 try :
445498 t = s .gettimeout ()
446- if timeout :
447- s .settimeout (deadline .to_timeout ())
499+ ssl_timeout = _non_expired_timeout (
500+ deadline , "SSL handshake"
501+ )
502+ if ssl_timeout is not None :
503+ s .settimeout (ssl_timeout )
448504 s = ssl_context .wrap_socket (s , server_hostname = sni_host )
449505 s .settimeout (t )
450- except (OSError , SSLError , CertificateError ) as cause :
506+ except (
507+ OSError ,
508+ SSLError ,
509+ CertificateError ,
510+ SocketDeadlineExceededError ,
511+ ) as cause :
451512 raise BoltSecurityError (
452513 message = "Failed to establish encrypted connection." ,
453514 address = (hostname , local_port ),
0 commit comments