diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 4230e3ae6..8b8017cd7 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -331,7 +331,7 @@ async def ping(cls, address, *, deadline=None, pool_config=None): except (ServiceUnavailable, SessionExpired, BoltHandshakeError): return None else: - await AsyncBoltSocket.close_socket(s) + AsyncBoltSocket.close_socket(s) return protocol_version @staticmethod @@ -377,7 +377,7 @@ async def open( bolt_cls = protocol_handlers.get(protocol_version) if bolt_cls is None: log.debug("[#%04X] C: ", s.getsockname()[1]) - await AsyncBoltSocket.close_socket(s) + AsyncBoltSocket.close_socket(s) raise UnsupportedServerProduct( "The neo4j server does not support communication with this " "driver. This driver has support for Bolt protocols " @@ -386,21 +386,13 @@ async def open( try: auth = await AsyncUtil.callback(auth_manager.get_auth) - except asyncio.CancelledError as e: - log.debug( - "[#%04X] C: open auth manager failed: %r", - s.getsockname()[1], - e, - ) - s.kill() - raise - except Exception as e: + except (Exception, asyncio.CancelledError) as e: log.debug( "[#%04X] C: open auth manager failed: %r", s.getsockname()[1], e, ) - await s.close() + s.close() raise connection = bolt_cls( @@ -998,17 +990,20 @@ async def close(self): self.goodbye() try: await self._send_all() - except (OSError, BoltError, DriverError) as exc: + except ( + OSError, + BoltError, + DriverError, + SocketDeadlineExceededError, + ) as exc: log.debug( - "[#%04X] _: ignoring failed close %r", + "[#%04X] _: ignoring failed final flush %r", self.local_port, exc, ) log.debug("[#%04X] C: ", self.local_port) try: - await self.socket.close() - except OSError: - pass + self.socket.close() finally: self._closed = True @@ -1019,13 +1014,7 @@ def kill(self): log.debug("[#%04X] C: ", self.local_port) self._closing = True try: - self.socket.kill() - except OSError as exc: - log.debug( - "[#%04X] _: ignoring failed kill %r", - self.local_port, - exc, - ) + self.socket.close() finally: self._closed = True diff --git a/src/neo4j/_async/io/_bolt_socket.py b/src/neo4j/_async/io/_bolt_socket.py index 1bb4306e3..84016ec82 100644 --- a/src/neo4j/_async/io/_bolt_socket.py +++ b/src/neo4j/_async/io/_bolt_socket.py @@ -177,7 +177,7 @@ async def _handshake_read(self, ctx: HandshakeCtx, n: int) -> bytes: # If no data is returned after a successful select # response, the server has closed the connection log.debug("[#%04X] S: ", ctx.local_port) - await self.close() + self.close() raise ServiceUnavailable( f"Connection to {ctx.resolved_address} closed with incomplete " f"handshake response" @@ -185,7 +185,7 @@ async def _handshake_read(self, ctx: HandshakeCtx, n: int) -> bytes: if data_size != n: # Some garbled data has been received log.debug("[#%04X] S: @*#!", ctx.local_port) - await self.close() + self.close() raise BoltProtocolError( f"Expected {ctx.ctx} from {ctx.resolved_address!r}, received " f"{response!r} instead (so far {ctx.full_response!r}); " @@ -264,7 +264,7 @@ async def _handshake( if response == b"HTTP": log.debug("[#%04X] C: (received b'HTTP')", local_port) - await self.close() + self.close() raise ServiceUnavailable( f"Cannot to connect to Bolt service on {resolved_address!r} " "(looks like HTTP)" @@ -350,7 +350,7 @@ async def connect( err_str, ) if s: - await cls.close_socket(s) + cls.close_socket(s) errors.append(error) failed_addresses.append(resolved_address) except asyncio.CancelledError: @@ -362,12 +362,11 @@ async def connect( "[#%04X] C: %s", local_port, resolved_address ) if s: - with suppress(OSError): - s.kill() + s.close() raise except Exception: if s: - await cls.close_socket(s) + cls.close_socket(s) raise address_strs = tuple(map(str, failed_addresses)) # TODO: 7.0 - when Python 3.11+ is the minimum, use exception groups diff --git a/src/neo4j/_async_compat/network/_bolt_socket.py b/src/neo4j/_async_compat/network/_bolt_socket.py index 5d8a94df9..af67108a7 100644 --- a/src/neo4j/_async_compat/network/_bolt_socket.py +++ b/src/neo4j/_async_compat/network/_bolt_socket.py @@ -60,6 +60,7 @@ from ..._sync.io import Bolt _P = t.ParamSpec("_P") + _R = t.TypeVar("_R") log = logging.getLogger("neo4j.io") @@ -141,10 +142,10 @@ async def _wait_for_io( name: str, timeout: float | None, deadline: Deadline | None, - io_async_fn: t.Callable[_P, t.Coroutine], + io_async_fn: t.Callable[_P, t.Coroutine[t.Any, t.Any, _R]], *args: _P.args, **kwargs: _P.kwargs, - ) -> None: + ) -> _R: to_raise: type[Exception] = TimeoutError deadline_timeout = _non_expired_timeout(deadline, name) if deadline_timeout is not None and ( @@ -226,12 +227,11 @@ async def sendall(self, data): self._writer.write(data) return await self._wait_for_write(self._writer.drain) - async def close(self): - self._writer.close() - await self._writer.wait_closed() - - def kill(self): - self._writer.close() + def close(self) -> None: + # Simulate `SO_LINGER` off: + # flush data and close socket in the background, don't block + with suppress(OSError): + self._writer.close() @classmethod async def _connect_secure( @@ -383,15 +383,14 @@ async def connect( ) -> tuple[t.Self, BoltProtocolVersion]: ... @classmethod - async def close_socket(cls, socket_): + def close_socket(cls, socket_: t.Self | socket) -> None: if isinstance(socket_, AsyncBoltSocketBase): - with suppress(OSError): - await socket_.close() + socket_.close() else: cls._kill_raw_socket(socket_) @classmethod - def _kill_raw_socket(cls, socket_): + def _kill_raw_socket(cls, socket_: socket) -> None: with suppress(OSError): socket_.shutdown(SHUT_RDWR) with suppress(OSError): @@ -431,6 +430,7 @@ def _wait_for_read(self, func, *args, **kwargs): "read", self._read_timeout, self._read_deadline, + _non_expired_timeout, func, *args, **kwargs, @@ -441,6 +441,7 @@ def _wait_for_write(self, func, *args, **kwargs): "write", self._write_timeout, self._write_deadline, + _non_expired_timeout, func, *args, **kwargs, @@ -451,12 +452,13 @@ def _wait_for_io( name: str, timeout: float | None, deadline: Deadline | None, - func: t.Callable[_P, t.Any], + deadline_conversion: t.Callable[[Deadline | None, str], float | None], + func: t.Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs, - ) -> None: + ) -> _R: rewrite_error = False - deadline_timeout = _non_expired_timeout(deadline, name) + deadline_timeout = deadline_conversion(deadline, name) if deadline_timeout is not None and ( timeout is None or deadline_timeout <= timeout ): @@ -504,12 +506,9 @@ def recv_into(self, buffer, nbytes): def sendall(self, data): return self._wait_for_write(self._socket.sendall, data) - def close(self): + def close(self) -> None: self.close_socket(self._socket) - def kill(self): - self._socket.close() - @classmethod def _connect_secure( cls, @@ -643,13 +642,14 @@ def connect( ) -> tuple[t.Self, BoltProtocolVersion]: ... @classmethod - def close_socket(cls, socket_): + def close_socket(cls, socket_: t.Self | socket) -> None: if isinstance(socket_, BoltSocketBase): - socket_ = socket_._socket - cls._kill_raw_socket(socket_) + cls._kill_raw_socket(socket_._socket) + else: + cls._kill_raw_socket(socket_) @classmethod - def _kill_raw_socket(cls, socket_): + def _kill_raw_socket(cls, socket_: socket) -> None: with suppress(OSError): socket_.shutdown(SHUT_RDWR) with suppress(OSError): diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 080f0a298..028adabfb 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -386,15 +386,7 @@ def open( try: auth = Util.callback(auth_manager.get_auth) - except asyncio.CancelledError as e: - log.debug( - "[#%04X] C: open auth manager failed: %r", - s.getsockname()[1], - e, - ) - s.kill() - raise - except Exception as e: + except (Exception, asyncio.CancelledError) as e: log.debug( "[#%04X] C: open auth manager failed: %r", s.getsockname()[1], @@ -998,17 +990,20 @@ def close(self): self.goodbye() try: self._send_all() - except (OSError, BoltError, DriverError) as exc: + except ( + OSError, + BoltError, + DriverError, + SocketDeadlineExceededError, + ) as exc: log.debug( - "[#%04X] _: ignoring failed close %r", + "[#%04X] _: ignoring failed final flush %r", self.local_port, exc, ) log.debug("[#%04X] C: ", self.local_port) try: self.socket.close() - except OSError: - pass finally: self._closed = True @@ -1019,13 +1014,7 @@ def kill(self): log.debug("[#%04X] C: ", self.local_port) self._closing = True try: - self.socket.kill() - except OSError as exc: - log.debug( - "[#%04X] _: ignoring failed kill %r", - self.local_port, - exc, - ) + self.socket.close() finally: self._closed = True diff --git a/src/neo4j/_sync/io/_bolt_socket.py b/src/neo4j/_sync/io/_bolt_socket.py index ba501dd40..3f3da35c9 100644 --- a/src/neo4j/_sync/io/_bolt_socket.py +++ b/src/neo4j/_sync/io/_bolt_socket.py @@ -362,8 +362,7 @@ def connect( "[#%04X] C: %s", local_port, resolved_address ) if s: - with suppress(OSError): - s.kill() + s.close() raise except Exception: if s: diff --git a/tests/unit/async_/io/conftest.py b/tests/unit/async_/io/conftest.py index 90a84ae36..db20b8e0a 100644 --- a/tests/unit/async_/io/conftest.py +++ b/tests/unit/async_/io/conftest.py @@ -92,10 +92,7 @@ async def sendall(self, data): if callable(self.on_send): self.on_send(data) - async def close(self): - return - - def kill(self): + def close(self): return def inject(self, data): diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 629d9eaa0..c98de899f 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -273,7 +273,7 @@ async def test_cancel_auth_manager_in_open(mocker): with pytest.raises(asyncio.CancelledError): await AsyncBolt.open(address, auth_manager=auth_manager) - socket_mock.kill.assert_called_once_with() + socket_mock.close.assert_called_once_with() @AsyncTestDecorators.mark_async_only_test @@ -324,11 +324,7 @@ async def test_error_handler_bubbling( await handler(error) assert exc.value is error - if isinstance(error, asyncio.CancelledError): - connection.socket.kill.assert_called_once() - connection.socket.close.assert_not_called() - else: - connection.socket.close.assert_awaited_once() + connection.socket.close.assert_called_once() assert connection.closed() assert connection.defunct() @@ -368,7 +364,7 @@ async def test_error_handler_rewritten( with pytest.raises(expected_error) as exc: await handler(error) assert exc.value.__cause__ is error - connection.socket.close.assert_awaited_once() + connection.socket.close.assert_called_once() assert connection.closed() assert connection.defunct() diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index 7b2a73df7..70033dd1b 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -314,7 +314,7 @@ async def break_connection(): if break_on_close: cx1.close.assert_called() else: - cx1.close.assert_called_once() + cx1.close.assert_awaited_once() assert cx2 is not cx1 assert cx2.unresolved_address == cx1.unresolved_address assert cx1 not in pool.connections[cx1.unresolved_address] @@ -344,7 +344,7 @@ async def test_does_not_close_stale_connections_in_use(opener): cx3 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx3) - cx1.close.assert_called_once() + cx1.close.assert_awaited_once() assert cx2 is cx3 assert cx3.unresolved_address == cx1.unresolved_address assert cx1 not in pool.connections[cx1.unresolved_address] diff --git a/tests/unit/sync/io/conftest.py b/tests/unit/sync/io/conftest.py index 34c65dbed..2ad1c2b9c 100644 --- a/tests/unit/sync/io/conftest.py +++ b/tests/unit/sync/io/conftest.py @@ -95,9 +95,6 @@ def sendall(self, data): def close(self): return - def kill(self): - return - def inject(self, data): self.recv_buffer += data diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index cd58ba50c..b857a46b2 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -273,7 +273,7 @@ def test_cancel_auth_manager_in_open(mocker): with pytest.raises(asyncio.CancelledError): Bolt.open(address, auth_manager=auth_manager) - socket_mock.kill.assert_called_once_with() + socket_mock.close.assert_called_once_with() @TestDecorators.mark_async_only_test @@ -324,11 +324,7 @@ def test_error_handler_bubbling( handler(error) assert exc.value is error - if isinstance(error, asyncio.CancelledError): - connection.socket.kill.assert_called_once() - connection.socket.close.assert_not_called() - else: - connection.socket.close.assert_called_once() + connection.socket.close.assert_called_once() assert connection.closed() assert connection.defunct()