Skip to content
37 changes: 13 additions & 24 deletions src/neo4j/_async/io/_bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ async def ping(cls, address, *, deadline=None, pool_config=None):
except (ServiceUnavailable, SessionExpired, BoltHandshakeError):
return None
else:
await AsyncBoltSocket.close_socket(s)
AsyncBoltSocket.close_socket(s)
return protocol_version

@staticmethod
Expand Down Expand Up @@ -377,7 +377,7 @@ async def open(
bolt_cls = protocol_handlers.get(protocol_version)
if bolt_cls is None:
log.debug("[#%04X] C: <CLOSE>", s.getsockname()[1])
await AsyncBoltSocket.close_socket(s)
AsyncBoltSocket.close_socket(s)
raise UnsupportedServerProduct(
"The neo4j server does not support communication with this "
"driver. This driver has support for Bolt protocols "
Expand All @@ -386,21 +386,13 @@ async def open(

try:
auth = await AsyncUtil.callback(auth_manager.get_auth)
except asyncio.CancelledError as e:
log.debug(
"[#%04X] C: <KILL> open auth manager failed: %r",
s.getsockname()[1],
e,
)
s.kill()
raise
except Exception as e:
except (Exception, asyncio.CancelledError) as e:
log.debug(
"[#%04X] C: <CLOSE> open auth manager failed: %r",
s.getsockname()[1],
e,
)
await s.close()
s.close()
raise

connection = bolt_cls(
Expand Down Expand Up @@ -998,17 +990,20 @@ async def close(self):
self.goodbye()
try:
await self._send_all()
except (OSError, BoltError, DriverError) as exc:
except (
OSError,
BoltError,
DriverError,
SocketDeadlineExceededError,
) as exc:
log.debug(
"[#%04X] _: <CONNECTION> ignoring failed close %r",
"[#%04X] _: <CONNECTION> ignoring failed final flush %r",
self.local_port,
exc,
)
log.debug("[#%04X] C: <CLOSE>", self.local_port)
try:
await self.socket.close()
except OSError:
pass
self.socket.close()
finally:
self._closed = True

Expand All @@ -1019,13 +1014,7 @@ def kill(self):
log.debug("[#%04X] C: <KILL>", self.local_port)
self._closing = True
try:
self.socket.kill()
except OSError as exc:
log.debug(
"[#%04X] _: <CONNECTION> ignoring failed kill %r",
self.local_port,
exc,
)
self.socket.close()
finally:
self._closed = True

Expand Down
13 changes: 6 additions & 7 deletions src/neo4j/_async/io/_bolt_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,15 @@ async def _handshake_read(self, ctx: HandshakeCtx, n: int) -> bytes:
# If no data is returned after a successful select
# response, the server has closed the connection
log.debug("[#%04X] S: <CLOSE>", ctx.local_port)
await self.close()
self.close()
raise ServiceUnavailable(
f"Connection to {ctx.resolved_address} closed with incomplete "
f"handshake response"
)
if data_size != n:
# Some garbled data has been received
log.debug("[#%04X] S: @*#!", ctx.local_port)
await self.close()
self.close()
raise BoltProtocolError(
f"Expected {ctx.ctx} from {ctx.resolved_address!r}, received "
f"{response!r} instead (so far {ctx.full_response!r}); "
Expand Down Expand Up @@ -264,7 +264,7 @@ async def _handshake(

if response == b"HTTP":
log.debug("[#%04X] C: <CLOSE> (received b'HTTP')", local_port)
await self.close()
self.close()
raise ServiceUnavailable(
f"Cannot to connect to Bolt service on {resolved_address!r} "
"(looks like HTTP)"
Expand Down Expand Up @@ -350,7 +350,7 @@ async def connect(
err_str,
)
if s:
await cls.close_socket(s)
cls.close_socket(s)
errors.append(error)
failed_addresses.append(resolved_address)
except asyncio.CancelledError:
Expand All @@ -362,12 +362,11 @@ async def connect(
"[#%04X] C: <CANCELED> %s", local_port, resolved_address
)
if s:
with suppress(OSError):
s.kill()
s.close()
raise
except Exception:
if s:
await cls.close_socket(s)
cls.close_socket(s)
raise
address_strs = tuple(map(str, failed_addresses))
# TODO: 7.0 - when Python 3.11+ is the minimum, use exception groups
Expand Down
46 changes: 23 additions & 23 deletions src/neo4j/_async_compat/network/_bolt_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from ..._sync.io import Bolt

_P = t.ParamSpec("_P")
_R = t.TypeVar("_R")


log = logging.getLogger("neo4j.io")
Expand Down Expand Up @@ -141,10 +142,10 @@ async def _wait_for_io(
name: str,
timeout: float | None,
deadline: Deadline | None,
io_async_fn: t.Callable[_P, t.Coroutine],
io_async_fn: t.Callable[_P, t.Coroutine[t.Any, t.Any, _R]],
*args: _P.args,
**kwargs: _P.kwargs,
) -> None:
) -> _R:
to_raise: type[Exception] = TimeoutError
deadline_timeout = _non_expired_timeout(deadline, name)
if deadline_timeout is not None and (
Expand Down Expand Up @@ -226,12 +227,11 @@ async def sendall(self, data):
self._writer.write(data)
return await self._wait_for_write(self._writer.drain)

async def close(self):
self._writer.close()
await self._writer.wait_closed()

def kill(self):
self._writer.close()
def close(self) -> None:
# Simulate `SO_LINGER` off:
# flush data and close socket in the background, don't block
with suppress(OSError):
self._writer.close()

@classmethod
async def _connect_secure(
Expand Down Expand Up @@ -383,15 +383,14 @@ async def connect(
) -> tuple[t.Self, BoltProtocolVersion]: ...

@classmethod
async def close_socket(cls, socket_):
def close_socket(cls, socket_: t.Self | socket) -> None:
if isinstance(socket_, AsyncBoltSocketBase):
with suppress(OSError):
await socket_.close()
socket_.close()
else:
cls._kill_raw_socket(socket_)

@classmethod
def _kill_raw_socket(cls, socket_):
def _kill_raw_socket(cls, socket_: socket) -> None:
with suppress(OSError):
socket_.shutdown(SHUT_RDWR)
with suppress(OSError):
Expand Down Expand Up @@ -431,6 +430,7 @@ def _wait_for_read(self, func, *args, **kwargs):
"read",
self._read_timeout,
self._read_deadline,
_non_expired_timeout,
func,
*args,
**kwargs,
Expand All @@ -441,6 +441,7 @@ def _wait_for_write(self, func, *args, **kwargs):
"write",
self._write_timeout,
self._write_deadline,
_non_expired_timeout,
func,
*args,
**kwargs,
Expand All @@ -451,12 +452,13 @@ def _wait_for_io(
name: str,
timeout: float | None,
deadline: Deadline | None,
func: t.Callable[_P, t.Any],
deadline_conversion: t.Callable[[Deadline | None, str], float | None],
func: t.Callable[_P, _R],
*args: _P.args,
**kwargs: _P.kwargs,
) -> None:
) -> _R:
rewrite_error = False
deadline_timeout = _non_expired_timeout(deadline, name)
deadline_timeout = deadline_conversion(deadline, name)
if deadline_timeout is not None and (
timeout is None or deadline_timeout <= timeout
):
Expand Down Expand Up @@ -504,12 +506,9 @@ def recv_into(self, buffer, nbytes):
def sendall(self, data):
return self._wait_for_write(self._socket.sendall, data)

def close(self):
def close(self) -> None:
self.close_socket(self._socket)

def kill(self):
self._socket.close()

@classmethod
def _connect_secure(
cls,
Expand Down Expand Up @@ -643,13 +642,14 @@ def connect(
) -> tuple[t.Self, BoltProtocolVersion]: ...

@classmethod
def close_socket(cls, socket_):
def close_socket(cls, socket_: t.Self | socket) -> None:
if isinstance(socket_, BoltSocketBase):
socket_ = socket_._socket
cls._kill_raw_socket(socket_)
cls._kill_raw_socket(socket_._socket)
else:
cls._kill_raw_socket(socket_)

@classmethod
def _kill_raw_socket(cls, socket_):
def _kill_raw_socket(cls, socket_: socket) -> None:
with suppress(OSError):
socket_.shutdown(SHUT_RDWR)
with suppress(OSError):
Expand Down
29 changes: 9 additions & 20 deletions src/neo4j/_sync/io/_bolt.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions src/neo4j/_sync/io/_bolt_socket.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 1 addition & 4 deletions tests/unit/async_/io/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,7 @@ async def sendall(self, data):
if callable(self.on_send):
self.on_send(data)

async def close(self):
return

def kill(self):
def close(self):
return

def inject(self, data):
Expand Down
10 changes: 3 additions & 7 deletions tests/unit/async_/io/test_class_bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ async def test_cancel_auth_manager_in_open(mocker):
with pytest.raises(asyncio.CancelledError):
await AsyncBolt.open(address, auth_manager=auth_manager)

socket_mock.kill.assert_called_once_with()
socket_mock.close.assert_called_once_with()


@AsyncTestDecorators.mark_async_only_test
Expand Down Expand Up @@ -324,11 +324,7 @@ async def test_error_handler_bubbling(
await handler(error)
assert exc.value is error

if isinstance(error, asyncio.CancelledError):
connection.socket.kill.assert_called_once()
connection.socket.close.assert_not_called()
else:
connection.socket.close.assert_awaited_once()
connection.socket.close.assert_called_once()

assert connection.closed()
assert connection.defunct()
Expand Down Expand Up @@ -368,7 +364,7 @@ async def test_error_handler_rewritten(
with pytest.raises(expected_error) as exc:
await handler(error)
assert exc.value.__cause__ is error
connection.socket.close.assert_awaited_once()
connection.socket.close.assert_called_once()

assert connection.closed()
assert connection.defunct()
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/async_/io/test_neo4j_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ async def break_connection():
if break_on_close:
cx1.close.assert_called()
else:
cx1.close.assert_called_once()
cx1.close.assert_awaited_once()
assert cx2 is not cx1
assert cx2.unresolved_address == cx1.unresolved_address
assert cx1 not in pool.connections[cx1.unresolved_address]
Expand Down Expand Up @@ -344,7 +344,7 @@ async def test_does_not_close_stale_connections_in_use(opener):

cx3 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None)
await pool.release(cx3)
cx1.close.assert_called_once()
cx1.close.assert_awaited_once()
assert cx2 is cx3
assert cx3.unresolved_address == cx1.unresolved_address
assert cx1 not in pool.connections[cx1.unresolved_address]
Expand Down
3 changes: 0 additions & 3 deletions tests/unit/sync/io/conftest.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading