2323 asyncio .run_coroutine_threadsafe
2424except AttributeError :
2525 raise ImportError (
26- ' Cannot use asyncioreactor without access to '
27- ' asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)'
26+ " Cannot use asyncioreactor without access to "
27+ " asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)"
2828 )
2929
3030
@@ -38,12 +38,12 @@ class AsyncioTimer(object):
3838
3939 @property
4040 def end (self ):
41- raise NotImplementedError ('{} is not compatible with TimerManager and '
42- 'does not implement .end()' )
41+ raise NotImplementedError (
42+ "{} is not compatible with TimerManager and does not implement .end()"
43+ )
4344
4445 def __init__ (self , timeout , callback , loop ):
45- delayed = self ._call_delayed_coro (timeout = timeout ,
46- callback = callback )
46+ delayed = self ._call_delayed_coro (timeout = timeout , callback = callback )
4747 self ._handle = asyncio .run_coroutine_threadsafe (delayed , loop = loop )
4848
4949 @staticmethod
@@ -63,17 +63,51 @@ def cancel(self):
6363 def finish (self ):
6464 # connection.Timer method not implemented here because we can't inspect
6565 # the Handle returned from call_later
66- raise NotImplementedError ('{} is not compatible with TimerManager and '
67- 'does not implement .finish()' )
66+ raise NotImplementedError (
67+ "{} is not compatible with TimerManager and does not implement .finish()"
68+ )
69+
70+
71+ class _AsyncioProtocol (asyncio .Protocol ):
72+ """
73+ Protocol adapter for asyncio SSL connections. Bridges asyncio's
74+ transport/protocol API back to AsyncioConnection's buffer processing.
75+ """
76+
77+ def __init__ (self , connection ):
78+ self ._connection = connection
79+ self .transport = None
80+
81+ def connection_made (self , transport ):
82+ self .transport = transport
83+
84+ def data_received (self , data ):
85+ conn = self ._connection
86+ conn ._iobuf .write (data )
87+ if conn ._iobuf .tell ():
88+ conn .process_io_buffer ()
89+
90+ def connection_lost (self , exc ):
91+ conn = self ._connection
92+ if exc :
93+ log .debug ("Connection %s lost: %s" , conn , exc )
94+ conn .defunct (exc )
95+ else :
96+ log .debug ("Connection %s closed by server" , conn )
97+ conn .close ()
98+
99+ def eof_received (self ):
100+ return False
68101
69102
70103class AsyncioConnection (Connection ):
71104 """
72- An experimental implementation of :class:`.Connection` that uses the
73- ``asyncio`` module in the Python standard library for its event loop.
105+ An implementation of :class:`.Connection` that uses the ``asyncio``
106+ module in the Python standard library for its event loop.
74107
75- Note that it requires ``asyncio`` features that were only introduced in the
76- 3.4 line in 3.4.6, and in the 3.5 line in 3.5.1.
108+ Supports SSL connections via asyncio's native TLS transport, which
109+ avoids the incompatibility between ``ssl.SSLSocket`` and asyncio's
110+ low-level socket methods (``sock_sendall``, ``sock_recv``).
77111 """
78112
79113 _loop = None
@@ -88,26 +122,102 @@ class AsyncioConnection(Connection):
88122 def __init__ (self , * args , ** kwargs ):
89123 Connection .__init__ (self , * args , ** kwargs )
90124 self ._background_tasks = set ()
125+ self ._transport = None
126+ self ._protocol = _AsyncioProtocol (self ) if self .ssl_context else None
127+ self ._using_ssl = bool (self .ssl_context )
128+ self ._ssl_ready = asyncio .Event () if self .ssl_context else None
91129
92130 self ._connect_socket ()
93131 self ._socket .setblocking (0 )
94132 loop_args = dict ()
95133 if sys .version_info [0 ] == 3 and sys .version_info [1 ] < 10 :
96- loop_args [' loop' ] = self ._loop
134+ loop_args [" loop" ] = self ._loop
97135 self ._write_queue = asyncio .Queue (** loop_args )
98136 self ._write_queue_lock = asyncio .Lock (** loop_args )
99137
100138 # see initialize_reactor -- loop is running in a separate thread, so we
101139 # have to use a threadsafe call
102- self ._read_watcher = asyncio .run_coroutine_threadsafe (
103- self .handle_read (), loop = self ._loop
104- )
140+ if self ._using_ssl :
141+ # For SSL: set up asyncio transport/protocol, then start writer
142+ self ._read_watcher = asyncio .run_coroutine_threadsafe (
143+ self ._setup_ssl_and_run (), loop = self ._loop
144+ )
145+ else :
146+ # For non-SSL: use low-level sock_sendall/sock_recv as before
147+ self ._read_watcher = asyncio .run_coroutine_threadsafe (
148+ self .handle_read (), loop = self ._loop
149+ )
105150 self ._write_watcher = asyncio .run_coroutine_threadsafe (
106151 self .handle_write (), loop = self ._loop
107152 )
108153 self ._send_options_message ()
109154
155+ def _connect_socket (self ):
156+ """
157+ Override base class to skip SSL wrapping of the socket.
158+ For SSL connections, the plain TCP socket is connected here, and TLS
159+ is set up later via asyncio's native SSL transport in _setup_ssl_and_run().
160+ """
161+ sockerr = None
162+ addresses = self ._get_socket_addresses ()
163+ for af , socktype , proto , _ , sockaddr in addresses :
164+ try :
165+ self ._socket = self ._socket_impl .socket (af , socktype , proto )
166+ # Do NOT wrap with ssl_context here -- asyncio will handle TLS
167+ self ._socket .settimeout (self .connect_timeout )
168+ self ._initiate_connection (sockaddr )
169+ self ._socket .settimeout (None )
170+
171+ local_addr = self ._socket .getsockname ()
172+ log .debug ("Connection %s: '%s' -> '%s'" , id (self ), local_addr , sockaddr )
173+ sockerr = None
174+ break
175+ except socket .error as err :
176+ if self ._socket :
177+ self ._socket .close ()
178+ self ._socket = None
179+ sockerr = err
180+
181+ if sockerr :
182+ raise socket .error (
183+ sockerr .errno ,
184+ "Tried connecting to %s. Last error: %s"
185+ % ([a [4 ] for a in addresses ], sockerr .strerror or sockerr ),
186+ )
110187
188+ async def _setup_ssl_and_run (self ):
189+ """
190+ Upgrade the plain TCP connection to TLS using asyncio's native SSL
191+ transport, then continuously read data via the protocol callbacks.
192+ """
193+ try :
194+ ssl_context = self .ssl_context
195+ server_hostname = None
196+ if self .ssl_options :
197+ server_hostname = self .ssl_options .get ("server_hostname" , None )
198+ if not server_hostname :
199+ # asyncio's create_connection requires server_hostname when
200+ # ssl= is set, even if check_hostname is False
201+ server_hostname = self .endpoint .address
202+
203+ transport , protocol = await self ._loop .create_connection (
204+ lambda : self ._protocol ,
205+ sock = self ._socket ,
206+ ssl = ssl_context ,
207+ server_hostname = server_hostname ,
208+ )
209+ self ._transport = transport
210+
211+ if self ._check_hostname :
212+ self ._validate_hostname ()
213+
214+ self ._ssl_ready .set ()
215+ except Exception as exc :
216+ log .debug ("SSL setup failed for %s: %s" , self , exc )
217+ self .defunct (exc )
218+ # Unblock handle_write so it can observe the defunct state and exit
219+ self ._ssl_ready .set ()
220+ return
111221
112222 @classmethod
113223 def initialize_reactor (cls ):
@@ -126,8 +236,9 @@ def initialize_reactor(cls):
126236 cls ._loop = asyncio .new_event_loop ()
127237 # daemonize so the loop will be shut down on interpreter
128238 # shutdown
129- cls ._loop_thread = Thread (target = cls ._loop .run_forever ,
130- daemon = True , name = "asyncio_thread" )
239+ cls ._loop_thread = Thread (
240+ target = cls ._loop .run_forever , daemon = True , name = "asyncio_thread"
241+ )
131242 cls ._loop_thread .start ()
132243
133244 @classmethod
@@ -142,17 +253,18 @@ def close(self):
142253
143254 # close from the loop thread to avoid races when removing file
144255 # descriptors
145- asyncio .run_coroutine_threadsafe (
146- self ._close (), loop = self ._loop
147- )
256+ asyncio .run_coroutine_threadsafe (self ._close (), loop = self ._loop )
148257
149258 async def _close (self ):
150259 log .debug ("Closing connection (%s) to %s" % (id (self ), self .endpoint ))
151260 if self ._write_watcher :
152261 self ._write_watcher .cancel ()
153262 if self ._read_watcher :
154263 self ._read_watcher .cancel ()
155- if self ._socket :
264+ if self ._transport :
265+ self ._transport .close ()
266+ self ._transport = None
267+ elif self ._socket :
156268 self ._loop .remove_writer (self ._socket .fileno ())
157269 self ._loop .remove_reader (self ._socket .fileno ())
158270 self ._socket .close ()
@@ -172,15 +284,12 @@ def push(self, data):
172284 if len (data ) > buff_size :
173285 chunks = []
174286 for i in range (0 , len (data ), buff_size ):
175- chunks .append (data [i : i + buff_size ])
287+ chunks .append (data [i : i + buff_size ])
176288 else :
177289 chunks = [data ]
178290
179291 if self ._loop_thread != threading .current_thread ():
180- asyncio .run_coroutine_threadsafe (
181- self ._push_msg (chunks ),
182- loop = self ._loop
183- )
292+ asyncio .run_coroutine_threadsafe (self ._push_msg (chunks ), loop = self ._loop )
184293 else :
185294 # avoid races/hangs by just scheduling this, not using threadsafe
186295 task = self ._loop .create_task (self ._push_msg (chunks ))
@@ -194,13 +303,22 @@ async def _push_msg(self, chunks):
194303 for chunk in chunks :
195304 self ._write_queue .put_nowait (chunk )
196305
197-
198306 async def handle_write (self ):
307+ # For SSL connections, wait until the TLS handshake completes
308+ if self ._ssl_ready :
309+ await self ._ssl_ready .wait ()
310+ if self .is_defunct :
311+ return
199312 while True :
200313 try :
201314 next_msg = await self ._write_queue .get ()
202315 if next_msg :
203- await self ._loop .sock_sendall (self ._socket , next_msg )
316+ if self ._transport :
317+ # SSL: use asyncio transport (handles TLS transparently)
318+ self ._transport .write (next_msg )
319+ else :
320+ # Non-SSL: use low-level socket API
321+ await self ._loop .sock_sendall (self ._socket , next_msg )
204322 except socket .error as err :
205323 log .debug ("Exception in send for %s: %s" , self , err )
206324 self .defunct (err )
@@ -223,8 +341,7 @@ async def handle_read(self):
223341 await asyncio .sleep (0 )
224342 continue
225343 except socket .error as err :
226- log .debug ("Exception during socket recv for %s: %s" ,
227- self , err )
344+ log .debug ("Exception during socket recv for %s: %s" , self , err )
228345 self .defunct (err )
229346 return # leave the read loop
230347 except asyncio .CancelledError :
0 commit comments