@@ -100,7 +100,7 @@ def _non_expired_timeout(
100100class AsyncBoltSocketBase (abc .ABC ):
101101 Bolt : te .Final [type [AsyncBolt ]] = None # type: ignore[assignment]
102102
103- def __init__ (self , reader , protocol , writer ) -> None :
103+ def __init__ (self , reader , protocol , writer , sockname , peername ) -> None :
104104 self ._reader = reader # type: asyncio.StreamReader
105105 self ._protocol = protocol # type: asyncio.StreamReaderProtocol
106106 self ._writer = writer # type: asyncio.StreamWriter
@@ -109,6 +109,8 @@ def __init__(self, reader, protocol, writer) -> None:
109109 # int - seconds to wait for data
110110 self ._timeout : float | None = None
111111 self ._deadline : Deadline | None = None
112+ self ._sockname = sockname
113+ self ._peername = peername
112114
113115 async def _wait_for_io (
114116 self ,
@@ -157,10 +159,10 @@ def _socket(self) -> socket:
157159 return self ._writer .transport .get_extra_info ("socket" )
158160
159161 def getsockname (self ):
160- return self ._writer . transport . get_extra_info ( "sockname" )
162+ return self ._sockname
161163
162164 def getpeername (self ):
163- return self ._writer . transport . get_extra_info ( "peername" )
165+ return self ._peername
164166
165167 def getpeercert (self , * args , ** kwargs ):
166168 return self ._writer .transport .get_extra_info ("ssl_object" ).getpeercert (
@@ -230,7 +232,10 @@ async def _connect_secure(
230232 if timeout == 0 : # socket timeout of 0 => non-blocking
231233 timeout = None
232234 await wait_for (loop .sock_connect (s , resolved_address ), timeout )
233- local_port = s .getsockname ()[1 ]
235+
236+ sockname = s .getsockname ()
237+ peername = s .getpeername ()
238+ local_port = sockname [1 ]
234239
235240 keep_alive = 1 if keep_alive else 0
236241 s .setsockopt (SOL_SOCKET , SO_KEEPALIVE , keep_alive )
@@ -271,14 +276,13 @@ async def _connect_secure(
271276 "ssl_object"
272277 ).getpeercert (binary_form = True )
273278 if der_encoded_server_certificate is None :
274- local_port = s .getsockname ()[1 ]
275279 raise BoltProtocolError (
276280 "When using an encrypted socket, the server should "
277281 "always provide a certificate" ,
278282 address = (resolved_address ._host_name , local_port ),
279283 )
280284
281- return cls (reader , protocol , writer )
285+ return cls (reader , protocol , writer , sockname , peername )
282286
283287 except asyncio .TimeoutError :
284288 log .debug ("[#0000] S: <TIMEOUT> %s" , resolved_address )
@@ -363,8 +367,10 @@ def _kill_raw_socket(cls, socket_):
363367class BoltSocketBase :
364368 Bolt : te .Final [type [Bolt ]] = None # type: ignore[assignment]
365369
366- def __init__ (self , socket_ : socket ):
370+ def __init__ (self , socket_ : socket , sockname , peername ):
367371 self ._socket = socket_
372+ self ._sockname = sockname
373+ self ._peername = peername
368374 self ._deadline : Deadline | None = None
369375
370376 @property
@@ -374,21 +380,24 @@ def _socket(self) -> socket | SSLSocket:
374380 @_socket .setter
375381 def _socket (self , socket_ : socket | SSLSocket ) -> None :
376382 self .__socket = socket_
377- self .getsockname = socket_ .getsockname
378- self .getpeername = socket_ .getpeername
379383 if hasattr (socket , "getpeercert" ):
380384 self .getpeercert = t .cast (SSLSocket , socket_ ).getpeercert
381385 elif "getpeercert" in self .__dict__ :
382386 del self .__dict__ ["getpeercert" ]
387+ socket_ .getsockname ()
383388 self .gettimeout = socket_ .gettimeout
384389 self .settimeout = socket_ .settimeout
385390
386- getsockname : t .Callable = None # type: ignore
387- getpeername : t .Callable = None # type: ignore
388391 getpeercert : t .Callable = None # type: ignore
389392 gettimeout : t .Callable = None # type: ignore
390393 settimeout : t .Callable = None # type: ignore
391394
395+ def getsockname (self ):
396+ return self ._sockname
397+
398+ def getpeername (self ):
399+ return self ._peername
400+
392401 def _wait_for_io (
393402 self ,
394403 func : t .Callable [_P , t .Any ],
@@ -488,7 +497,9 @@ def _connect_secure(
488497 ) from error
489498 raise
490499
491- local_port = s .getsockname ()[1 ]
500+ sockname = s .getsockname ()
501+ peername = s .getpeername ()
502+ local_port = sockname [1 ]
492503 # Secure the connection if an SSL context has been provided
493504 if ssl_context :
494505 hostname = resolved_address ._host_name or None
@@ -529,7 +540,7 @@ def _connect_secure(
529540 cls ._kill_raw_socket (s )
530541 raise
531542
532- return cls (s )
543+ return cls (s , sockname , peername )
533544
534545 @abc .abstractmethod
535546 def _handshake (self , resolved_address , deadline ): ...
0 commit comments