Skip to content

Commit 0445e50

Browse files
authored
Fix socket timeouts during connection establishment (#1302)
* Prevent async socket from passing 0 as `ssl_handshake_timeout` leading to a `ValueError` being raised when the acquisition deadline has expired. Raise `ServiceUnavailable` instead. * Align async driver with sync driver: * (lazily) raise a `ValueError` when configured with a negative `connection_timeout`. * Treat `connection_timeout=0` as "no timeout". * Fix sync driver setting the socket timeout to `0` (non-blocking) on expired `connection_acquisition_timeout`. * Harden internal code against against `None` timeouts.
1 parent d504ea0 commit 0445e50

8 files changed

Lines changed: 437 additions & 64 deletions

File tree

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ isort>=5.11.5 # TODO: 6.0 - bump when support for Python 3.7 is dropped
1212
mypy>=1.4.1 # TODO: 6.0 - bump when support for Python 3.7 is dropped
1313
typing-extensions>=4.7.1
1414
types-pytz>=2023.3.1.1 # TODO: 6.0 - bump when support for Python 3.7 is dropped
15+
types-mock>=5.1.0.3 # TODO: 6.0 - bump when support for Python 3.7 is dropped
1516
ruff>=0.8.2
1617

1718
# needed for running tests

src/neo4j/_async/io/_bolt_socket.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ async def connect(
318318
)
319319
async for resolved_address in resolved_addresses:
320320
deadline_timeout = deadline.to_timeout()
321-
if (
321+
if tcp_timeout is None or (
322322
deadline_timeout is not None
323323
and deadline_timeout <= tcp_timeout
324324
):

src/neo4j/_async/io/_pool.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,7 @@ async def _re_auth_connection(self, connection, auth, force, unprepared):
301301
)
302302
except Exception as exc:
303303
log.debug(
304-
"[#%04X] _: <POOL> check re_auth failed %r auth=%s "
305-
"force=%s",
304+
"[#%04X] _: <POOL> check re_auth failed %r auth=%s force=%s",
306305
connection.local_port,
307306
exc,
308307
log_auth,
@@ -879,8 +878,7 @@ async def fetch_routing_table(
879878
# No readers
880879
if num_readers == 0:
881880
log.debug(
882-
"[#0000] _: <POOL> no read servers returned from "
883-
"server %s",
881+
"[#0000] _: <POOL> no read servers returned from server %s",
884882
address,
885883
)
886884
return None

src/neo4j/_async_compat/network/_bolt_socket.py

Lines changed: 96 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
from ..._async.io import AsyncBolt
6464
from ..._sync.io import Bolt
6565

66+
_P = te.ParamSpec("_P")
67+
6668

6769
log = logging.getLogger("neo4j.io")
6870

@@ -76,6 +78,25 @@ def _sanitize_deadline(deadline):
7678
return deadline
7779

7880

81+
def _validate_timeout(timeout):
82+
if timeout is not None and timeout < 0:
83+
raise ValueError("Timeout value out of range")
84+
85+
86+
def _non_expired_timeout(
87+
deadline: Deadline | None,
88+
operation: str,
89+
) -> float | None:
90+
if deadline is None:
91+
return None
92+
timeout = deadline.to_timeout()
93+
if timeout is None:
94+
return None
95+
if timeout <= 0:
96+
raise SocketDeadlineExceededError(f"{operation} timed out")
97+
return timeout
98+
99+
79100
class AsyncBoltSocketBase(abc.ABC):
80101
Bolt: te.Final[type[AsyncBolt]] = None # type: ignore[assignment]
81102

@@ -86,32 +107,40 @@ def __init__(self, reader, protocol, writer) -> None:
86107
# 0 - non-blocking
87108
# None infinitely blocking
88109
# int - seconds to wait for data
89-
self._timeout = None
90-
self._deadline = None
91-
92-
async def _wait_for_io(self, io_async_fn, *args, **kwargs):
110+
self._timeout: float | None = None
111+
self._deadline: Deadline | None = None
112+
113+
async def _wait_for_io(
114+
self,
115+
io_async_fn: t.Callable[_P, t.Coroutine],
116+
*args: _P.args,
117+
**kwargs: _P.kwargs,
118+
) -> None:
93119
timeout = self._timeout
94-
to_raise = SocketTimeout
95-
if self._deadline is not None:
96-
deadline_timeout = self._deadline.to_timeout()
97-
if deadline_timeout <= 0:
98-
raise SocketDeadlineExceededError("timed out")
99-
if timeout is None or deadline_timeout <= timeout:
100-
timeout = deadline_timeout
101-
to_raise = SocketDeadlineExceededError
102-
103-
io_fut = io_async_fn(*args, **kwargs)
120+
to_raise: type[Exception] = SocketTimeout
121+
deadline_timeout = _non_expired_timeout(self._deadline, "IO operation")
122+
if deadline_timeout is not None and (
123+
timeout is None or deadline_timeout <= timeout
124+
):
125+
timeout = deadline_timeout
126+
to_raise = SocketDeadlineExceededError
127+
128+
io_fut: t.Awaitable
104129
if timeout is not None and timeout <= 0:
105130
# give the io-operation time for one loop cycle to do its thing
106-
io_fut = asyncio.create_task(io_fut)
131+
io_fut = asyncio.create_task(io_async_fn(*args, **kwargs))
107132
try:
108133
await asyncio.sleep(0)
109134
except asyncio.CancelledError:
110135
# This is emulating non-blocking. There is no cancelling this!
111136
# Still, we don't want to silently swallow the cancellation.
112137
# Hence, we flag this task as cancelled again, so that the next
113138
# `await` will raise the CancelledError.
114-
asyncio.current_task().cancel()
139+
current_task = asyncio.current_task()
140+
if current_task is not None:
141+
current_task.cancel()
142+
else:
143+
io_fut = io_async_fn(*args, **kwargs)
115144
try:
116145
return await wait_for(io_fut, timeout)
117146
except asyncio.TimeoutError as e:
@@ -197,6 +226,9 @@ async def _connect_secure(
197226
raise ValueError(f"Unsupported address {resolved_address!r}")
198227
s.setblocking(False) # asyncio + blocking = no-no!
199228
log.debug("[#0000] C: <OPEN> %s", resolved_address)
229+
_validate_timeout(timeout)
230+
if timeout == 0: # socket timeout of 0 => non-blocking
231+
timeout = None
200232
await wait_for(loop.sock_connect(s, resolved_address), timeout)
201233
local_port = s.getsockname()[1]
202234

@@ -208,18 +240,26 @@ async def _connect_secure(
208240
if ssl_context is not None:
209241
hostname = resolved_address._host_name or None
210242
sni_host = hostname if HAS_SNI and hostname else None
211-
ssl_kwargs.update(
212-
ssl=ssl_context,
213-
server_hostname=sni_host,
214-
ssl_handshake_timeout=deadline.to_timeout(),
215-
)
243+
ssl_kwargs.update(ssl=ssl_context, server_hostname=sni_host)
216244
log.debug("[#%04X] C: <SECURE> %s", local_port, hostname)
217245

218246
reader = asyncio.StreamReader(
219247
limit=2**16, # 64 KiB,
220248
loop=loop,
221249
)
222250
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
251+
if ssl_context is not None:
252+
try:
253+
ssl_timeout = _non_expired_timeout(
254+
deadline, "SSL handshake"
255+
)
256+
except SocketDeadlineExceededError as error:
257+
raise BoltSecurityError(
258+
message="Failed to establish encrypted connection.",
259+
address=(hostname, local_port),
260+
) from error
261+
if ssl_timeout is not None:
262+
ssl_kwargs["ssl_handshake_timeout"] = ssl_timeout
223263
transport, _ = await loop.create_connection(
224264
lambda: protocol, sock=s, **ssl_kwargs
225265
)
@@ -262,6 +302,16 @@ async def _connect_secure(
262302
message="Failed to establish encrypted connection.",
263303
address=(resolved_address._host_name, local_port),
264304
) from error
305+
except BoltSecurityError as error:
306+
log.debug(
307+
"[#0000] S: <SECURE FAILURE> %s: %r",
308+
resolved_address,
309+
error,
310+
)
311+
if s:
312+
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
313+
cls._kill_raw_socket(s)
314+
raise
265315
except Exception as error:
266316
log.debug(
267317
"[#0000] S: <ERROR> %s %s",
@@ -315,14 +365,14 @@ class BoltSocketBase:
315365

316366
def __init__(self, socket_: socket):
317367
self._socket = socket_
318-
self._deadline = None
368+
self._deadline: Deadline | None = None
319369

320370
@property
321-
def _socket(self):
371+
def _socket(self) -> socket | SSLSocket:
322372
return self.__socket
323373

324374
@_socket.setter
325-
def _socket(self, socket_: socket | SSLSocket):
375+
def _socket(self, socket_: socket | SSLSocket) -> None:
326376
self.__socket = socket_
327377
self.getsockname = socket_.getsockname
328378
self.getpeername = socket_.getpeername
@@ -339,14 +389,17 @@ def _socket(self, socket_: socket | SSLSocket):
339389
gettimeout: t.Callable = None # type: ignore
340390
settimeout: t.Callable = None # type: ignore
341391

342-
def _wait_for_io(self, func, *args, **kwargs):
343-
if self._deadline is None:
344-
return func(*args, **kwargs)
345-
timeout = self._socket.gettimeout()
346-
deadline_timeout = self._deadline.to_timeout()
347-
if deadline_timeout <= 0:
348-
raise SocketDeadlineExceededError("timed out")
349-
if timeout is None or deadline_timeout <= timeout:
392+
def _wait_for_io(
393+
self,
394+
func: t.Callable[_P, t.Any],
395+
*args: _P.args,
396+
**kwargs: _P.kwargs,
397+
) -> None:
398+
timeout: float | None = self._socket.gettimeout()
399+
deadline_timeout = _non_expired_timeout(self._deadline, "IO operation")
400+
if deadline_timeout is not None and (
401+
timeout is None or deadline_timeout <= timeout
402+
):
350403
self._socket.settimeout(deadline_timeout)
351404
try:
352405
return func(*args, **kwargs)
@@ -443,11 +496,19 @@ def _connect_secure(
443496
log.debug("[#%04X] C: <SECURE> %s", local_port, hostname)
444497
try:
445498
t = s.gettimeout()
446-
if timeout:
447-
s.settimeout(deadline.to_timeout())
499+
ssl_timeout = _non_expired_timeout(
500+
deadline, "SSL handshake"
501+
)
502+
if ssl_timeout is not None:
503+
s.settimeout(ssl_timeout)
448504
s = ssl_context.wrap_socket(s, server_hostname=sni_host)
449505
s.settimeout(t)
450-
except (OSError, SSLError, CertificateError) as cause:
506+
except (
507+
OSError,
508+
SSLError,
509+
CertificateError,
510+
SocketDeadlineExceededError,
511+
) as cause:
451512
raise BoltSecurityError(
452513
message="Failed to establish encrypted connection.",
453514
address=(hostname, local_port),

src/neo4j/_deadline.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,76 +14,97 @@
1414
# limitations under the License.
1515

1616

17+
from __future__ import annotations
18+
19+
import typing as t
1720
from contextlib import contextmanager
1821
from time import monotonic
1922

2023

24+
if t.TYPE_CHECKING:
25+
import typing_extensions as te
26+
27+
from ._async.io import AsyncBolt
28+
from ._sync.io import Bolt
29+
30+
2131
class Deadline:
22-
def __init__(self, timeout):
32+
def __init__(self, timeout: float | None) -> None:
2333
if timeout is None or timeout == float("inf"):
2434
self._deadline = float("inf")
2535
else:
2636
self._deadline = monotonic() + timeout
2737
self._original_timeout = timeout
2838

2939
@property
30-
def original_timeout(self):
40+
def original_timeout(self) -> float | None:
3141
return self._original_timeout
3242

33-
def expired(self):
43+
def expired(self) -> bool:
3444
return self.to_timeout() == 0
3545

36-
def to_timeout(self):
46+
def to_timeout(self) -> float | None:
3747
if self._deadline == float("inf"):
3848
return None
3949
timeout = self._deadline - monotonic()
4050
return max(0, timeout)
4151

42-
def __eq__(self, other):
52+
def __eq__(self, other) -> bool:
4353
if isinstance(other, Deadline):
4454
return self._deadline == other._deadline
4555
return NotImplemented
4656

47-
def __gt__(self, other):
57+
def __gt__(self, other) -> bool:
4858
if isinstance(other, Deadline):
4959
return self._deadline > other._deadline
5060
return NotImplemented
5161

52-
def __ge__(self, other):
62+
def __ge__(self, other) -> bool:
5363
if isinstance(other, Deadline):
5464
return self._deadline >= other._deadline
5565
return NotImplemented
5666

57-
def __lt__(self, other):
67+
def __lt__(self, other) -> bool:
5868
if isinstance(other, Deadline):
5969
return self._deadline < other._deadline
6070
return NotImplemented
6171

62-
def __le__(self, other):
72+
def __le__(self, other) -> bool:
6373
if isinstance(other, Deadline):
6474
return self._deadline <= other._deadline
6575
return NotImplemented
6676

6777
@classmethod
68-
def from_timeout_or_deadline(cls, timeout):
69-
if isinstance(timeout, cls):
70-
return timeout
71-
return cls(timeout)
72-
73-
def __str__(self):
78+
def from_timeout_or_deadline(
79+
cls, timeout: te.Self | float | None
80+
) -> te.Self:
81+
if isinstance(timeout, (float, int)) or timeout is None:
82+
return cls(timeout)
83+
return timeout
84+
85+
def __str__(self) -> str:
7486
return f"Deadline(timeout={self._original_timeout})"
7587

88+
def __bool__(self) -> bool:
89+
"""Whether a deadline is set (:data:`True`) or not (:data:`False`)."""
90+
return self._deadline != float("inf")
91+
7692

7793
merge_deadlines = min
7894

7995

80-
def merge_deadlines_and_timeouts(*deadline):
96+
def merge_deadlines_and_timeouts(
97+
*deadline: Deadline | None,
98+
) -> Deadline:
8199
deadlines = map(Deadline.from_timeout_or_deadline, deadline)
82100
return merge_deadlines(deadlines)
83101

84102

85103
@contextmanager
86-
def connection_deadline(connection, deadline):
104+
def connection_deadline(
105+
connection: AsyncBolt | Bolt,
106+
deadline: Deadline | None,
107+
) -> t.Generator[None, None, None]:
87108
original_deadline = connection.socket.get_deadline()
88109
if deadline is None and original_deadline is not None:
89110
# nothing to do here

src/neo4j/_sync/io/_bolt_socket.py

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/neo4j/_sync/io/_pool.py

Lines changed: 2 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)