Skip to content

Commit b60baef

Browse files
committed
Cache sockname and peername
Async ----- Instead of relying on asyncio's caching of these properties, we do it ourselves. The advantage is that asyncio populates the cache on a best-effort basis. This can lead to the values being `None` if retrieving them causes an `OSError`. This is a misalignment between the async and sync driver that this PR aims to remedy. Sync ---- While in then async driver we're introducing (custom) caching where implicit caching was already in place, we also introduce caching of these fields in the sync driver for parity.
1 parent 01d77c4 commit b60baef

2 files changed

Lines changed: 31 additions & 25 deletions

File tree

src/neo4j/_async_compat/network/_bolt_socket.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def _non_expired_timeout(
100100
class AsyncBoltSocketBase(abc.ABC):
101101
Bolt: te.Final[type[AsyncBolt]] = None # type: ignore[assignment]
102102

103-
def __init__(self, reader, protocol, writer) -> None:
103+
def __init__(self, reader, protocol, writer, sockname, peername) -> None:
104104
self._reader = reader # type: asyncio.StreamReader
105105
self._protocol = protocol # type: asyncio.StreamReaderProtocol
106106
self._writer = writer # type: asyncio.StreamWriter
@@ -109,6 +109,8 @@ def __init__(self, reader, protocol, writer) -> None:
109109
# int - seconds to wait for data
110110
self._timeout: float | None = None
111111
self._deadline: Deadline | None = None
112+
self._sockname = sockname
113+
self._peername = peername
112114

113115
async def _wait_for_io(
114116
self,
@@ -157,10 +159,10 @@ def _socket(self) -> socket:
157159
return self._writer.transport.get_extra_info("socket")
158160

159161
def getsockname(self):
160-
return self._writer.transport.get_extra_info("sockname")
162+
return self._sockname
161163

162164
def getpeername(self):
163-
return self._writer.transport.get_extra_info("peername")
165+
return self._peername
164166

165167
def getpeercert(self, *args, **kwargs):
166168
return self._writer.transport.get_extra_info("ssl_object").getpeercert(
@@ -230,7 +232,10 @@ async def _connect_secure(
230232
if timeout == 0: # socket timeout of 0 => non-blocking
231233
timeout = None
232234
await wait_for(loop.sock_connect(s, resolved_address), timeout)
233-
local_port = s.getsockname()[1]
235+
236+
sockname = s.getsockname()
237+
peername = s.getpeername()
238+
local_port = sockname[1]
234239

235240
keep_alive = 1 if keep_alive else 0
236241
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)
@@ -271,14 +276,13 @@ async def _connect_secure(
271276
"ssl_object"
272277
).getpeercert(binary_form=True)
273278
if der_encoded_server_certificate is None:
274-
local_port = s.getsockname()[1]
275279
raise BoltProtocolError(
276280
"When using an encrypted socket, the server should "
277281
"always provide a certificate",
278282
address=(resolved_address._host_name, local_port),
279283
)
280284

281-
return cls(reader, protocol, writer)
285+
return cls(reader, protocol, writer, sockname, peername)
282286

283287
except asyncio.TimeoutError:
284288
log.debug("[#0000] S: <TIMEOUT> %s", resolved_address)
@@ -363,8 +367,10 @@ def _kill_raw_socket(cls, socket_):
363367
class BoltSocketBase:
364368
Bolt: te.Final[type[Bolt]] = None # type: ignore[assignment]
365369

366-
def __init__(self, socket_: socket):
370+
def __init__(self, socket_: socket, sockname, peername):
367371
self._socket = socket_
372+
self._sockname = sockname
373+
self._peername = peername
368374
self._deadline: Deadline | None = None
369375

370376
@property
@@ -374,21 +380,24 @@ def _socket(self) -> socket | SSLSocket:
374380
@_socket.setter
375381
def _socket(self, socket_: socket | SSLSocket) -> None:
376382
self.__socket = socket_
377-
self.getsockname = socket_.getsockname
378-
self.getpeername = socket_.getpeername
379383
if hasattr(socket, "getpeercert"):
380384
self.getpeercert = t.cast(SSLSocket, socket_).getpeercert
381385
elif "getpeercert" in self.__dict__:
382386
del self.__dict__["getpeercert"]
387+
socket_.getsockname()
383388
self.gettimeout = socket_.gettimeout
384389
self.settimeout = socket_.settimeout
385390

386-
getsockname: t.Callable = None # type: ignore
387-
getpeername: t.Callable = None # type: ignore
388391
getpeercert: t.Callable = None # type: ignore
389392
gettimeout: t.Callable = None # type: ignore
390393
settimeout: t.Callable = None # type: ignore
391394

395+
def getsockname(self):
396+
return self._sockname
397+
398+
def getpeername(self):
399+
return self._peername
400+
392401
def _wait_for_io(
393402
self,
394403
func: t.Callable[_P, t.Any],
@@ -488,7 +497,9 @@ def _connect_secure(
488497
) from error
489498
raise
490499

491-
local_port = s.getsockname()[1]
500+
sockname = s.getsockname()
501+
peername = s.getpeername()
502+
local_port = sockname[1]
492503
# Secure the connection if an SSL context has been provided
493504
if ssl_context:
494505
hostname = resolved_address._host_name or None
@@ -529,7 +540,7 @@ def _connect_secure(
529540
cls._kill_raw_socket(s)
530541
raise
531542

532-
return cls(s)
543+
return cls(s, sockname, peername)
533544

534545
@abc.abstractmethod
535546
def _handshake(self, resolved_address, deadline): ...

tests/unit/fixtures/socket.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,23 +71,18 @@ async def drain():
7171
bytes_written.extend(write_buffer)
7272
write_buffer.clear()
7373

74-
def transport_get_extra(key):
75-
if key == "sockname":
76-
return "localhost", 0x1234
77-
if key == "peername":
78-
return "peer_name"
79-
raise KeyError(f"not mocked: {key}")
80-
8174
reader = mocker.Mock(spec=asyncio.StreamReader)
8275
writer = mocker.Mock(spec=asyncio.StreamWriter)
8376
protocol = mocker.Mock(spec=asyncio.StreamReaderProtocol)
8477

8578
reader.read.side_effect = read
8679
writer.write.side_effect = write
8780
writer.drain.side_effect = drain
88-
writer.transport.get_extra_info.side_effect = transport_get_extra
8981

90-
return AsyncBoltSocket(reader, protocol, writer)
82+
sockname = "localhost", 0x1234
83+
peername = "peer_name"
84+
85+
return AsyncBoltSocket(reader, protocol, writer, sockname, peername)
9186

9287
return factory
9388

@@ -120,10 +115,10 @@ def send_all(b):
120115
socket_mock.recv.side_effect = recv
121116
socket_mock.recv_into.side_effect = recv_into
122117
socket_mock.sendall.side_effect = send_all
123-
socket_mock.getsockname.return_value = ("localhost", 0x1234)
124-
socket_mock.getpeername.return_value = "peer_name"
118+
sockname = "localhost", 0x1234
119+
peername = "peer_name"
125120
socket_mock.gettimeout.return_value = None
126121

127-
return BoltSocket(socket_mock)
122+
return BoltSocket(socket_mock, sockname, peername)
128123

129124
return factory

0 commit comments

Comments
 (0)