Skip to content

Commit a0eb304

Browse files
roydahandkropachev
authored andcommitted
asyncio: fix SSL connections by using native TLS transport
Python 3.8+ rejects ssl.SSLSocket in asyncio's sock_sendall/sock_recv with TypeError. This caused the driver to fail connecting to ScyllaDB clusters requiring TLS, manifesting as 'protocol version 21 not supported' errors (0x15 = TLS Alert byte misread as protocol version). Fix by using asyncio's native TLS transport (loop.create_connection with ssl= parameter) instead of wrapping sockets with ssl.SSLContext.wrap_socket(). This preserves shard-aware port binding done during _initiate_connection(). Add _AsyncioProtocol to bridge asyncio's transport/protocol API back to Connection.process_io_buffer() for SSL data reads. Non-SSL connections continue using the existing sock_recv path. Fixes #330
1 parent cf01c3f commit a0eb304

1 file changed

Lines changed: 168 additions & 31 deletions

File tree

cassandra/io/asyncioreactor.py

Lines changed: 168 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
asyncio.run_coroutine_threadsafe
2424
except AttributeError:
2525
raise ImportError(
26-
'Cannot use asyncioreactor without access to '
27-
'asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)'
26+
"Cannot use asyncioreactor without access to "
27+
"asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)"
2828
)
2929

3030

@@ -38,12 +38,12 @@ class AsyncioTimer(object):
3838

3939
@property
4040
def end(self):
41-
raise NotImplementedError('{} is not compatible with TimerManager and '
42-
'does not implement .end()')
41+
raise NotImplementedError(
42+
"{} is not compatible with TimerManager and does not implement .end()"
43+
)
4344

4445
def __init__(self, timeout, callback, loop):
45-
delayed = self._call_delayed_coro(timeout=timeout,
46-
callback=callback)
46+
delayed = self._call_delayed_coro(timeout=timeout, callback=callback)
4747
self._handle = asyncio.run_coroutine_threadsafe(delayed, loop=loop)
4848

4949
@staticmethod
@@ -63,17 +63,61 @@ def cancel(self):
6363
def finish(self):
6464
# connection.Timer method not implemented here because we can't inspect
6565
# the Handle returned from call_later
66-
raise NotImplementedError('{} is not compatible with TimerManager and '
67-
'does not implement .finish()')
66+
raise NotImplementedError(
67+
"{} is not compatible with TimerManager and does not implement .finish()"
68+
)
69+
70+
71+
class _AsyncioProtocol(asyncio.Protocol):
72+
"""
73+
Protocol adapter for asyncio SSL connections. Bridges asyncio's
74+
transport/protocol API back to AsyncioConnection's buffer processing.
75+
"""
76+
77+
def __init__(self, connection, loop_args=None):
78+
self._connection = connection
79+
self.transport = None
80+
self.write_ready = asyncio.Event(**(loop_args or {}))
81+
self.write_ready.set()
82+
83+
def connection_made(self, transport):
84+
self.transport = transport
85+
86+
def data_received(self, data):
87+
conn = self._connection
88+
conn._iobuf.write(data)
89+
if conn._iobuf.tell():
90+
conn.process_io_buffer()
91+
92+
def pause_writing(self):
93+
self.write_ready.clear()
94+
95+
def resume_writing(self):
96+
self.write_ready.set()
97+
98+
def connection_lost(self, exc):
99+
# Unblock any paused writer so shutdown does not hang
100+
self.write_ready.set()
101+
conn = self._connection
102+
if exc:
103+
log.debug("Connection %s lost: %s", conn, exc)
104+
conn.defunct(exc)
105+
else:
106+
log.debug("Connection %s closed by server", conn)
107+
conn.close()
108+
109+
def eof_received(self):
110+
return False
68111

69112

70113
class AsyncioConnection(Connection):
71114
"""
72-
An experimental implementation of :class:`.Connection` that uses the
73-
``asyncio`` module in the Python standard library for its event loop.
115+
An implementation of :class:`.Connection` that uses the ``asyncio``
116+
module in the Python standard library for its event loop.
74117
75-
Note that it requires ``asyncio`` features that were only introduced in the
76-
3.4 line in 3.4.6, and in the 3.5 line in 3.5.1.
118+
Supports SSL connections via asyncio's native TLS transport, which
119+
avoids the incompatibility between ``ssl.SSLSocket`` and asyncio's
120+
low-level socket methods (``sock_sendall``, ``sock_recv``).
77121
"""
78122

79123
_loop = None
@@ -88,26 +132,109 @@ class AsyncioConnection(Connection):
88132
def __init__(self, *args, **kwargs):
89133
Connection.__init__(self, *args, **kwargs)
90134
self._background_tasks = set()
135+
self._transport = None
136+
self._using_ssl = bool(self.ssl_context)
91137

92138
self._connect_socket()
93139
self._socket.setblocking(0)
94140
loop_args = dict()
95141
if sys.version_info[0] == 3 and sys.version_info[1] < 10:
96-
loop_args['loop'] = self._loop
142+
loop_args["loop"] = self._loop
143+
self._protocol = _AsyncioProtocol(self, loop_args) if self._using_ssl else None
144+
self._ssl_ready = asyncio.Event(**loop_args) if self._using_ssl else None
97145
self._write_queue = asyncio.Queue(**loop_args)
98146
self._write_queue_lock = asyncio.Lock(**loop_args)
99147

100148
# see initialize_reactor -- loop is running in a separate thread, so we
101149
# have to use a threadsafe call
102-
self._read_watcher = asyncio.run_coroutine_threadsafe(
103-
self.handle_read(), loop=self._loop
104-
)
150+
if self._using_ssl:
151+
# For SSL: set up asyncio transport/protocol, then start writer
152+
self._read_watcher = asyncio.run_coroutine_threadsafe(
153+
self._setup_ssl_and_run(), loop=self._loop
154+
)
155+
else:
156+
# For non-SSL: use low-level sock_sendall/sock_recv as before
157+
self._read_watcher = asyncio.run_coroutine_threadsafe(
158+
self.handle_read(), loop=self._loop
159+
)
105160
self._write_watcher = asyncio.run_coroutine_threadsafe(
106161
self.handle_write(), loop=self._loop
107162
)
108163
self._send_options_message()
109164

165+
def _connect_socket(self):
166+
"""
167+
Override base class to skip SSL wrapping of the socket.
168+
For SSL connections, the plain TCP socket is connected here, and TLS
169+
is set up later via asyncio's native SSL transport in _setup_ssl_and_run().
170+
"""
171+
sockerr = None
172+
addresses = self._get_socket_addresses()
173+
for af, socktype, proto, _, sockaddr in addresses:
174+
try:
175+
self._socket = self._socket_impl.socket(af, socktype, proto)
176+
# Do NOT wrap with ssl_context here -- asyncio will handle TLS
177+
self._socket.settimeout(self.connect_timeout)
178+
self._initiate_connection(sockaddr)
179+
self._socket.settimeout(None)
180+
181+
local_addr = self._socket.getsockname()
182+
log.debug("Connection %s: '%s' -> '%s'", id(self), local_addr, sockaddr)
183+
sockerr = None
184+
break
185+
except socket.error as err:
186+
if self._socket:
187+
self._socket.close()
188+
self._socket = None
189+
sockerr = err
190+
191+
if sockerr:
192+
raise socket.error(
193+
sockerr.errno,
194+
"Tried connecting to %s. Last error: %s"
195+
% ([a[4] for a in addresses], sockerr.strerror or sockerr),
196+
)
197+
198+
if self.sockopts:
199+
for args in self.sockopts:
200+
self._socket.setsockopt(*args)
201+
202+
async def _setup_ssl_and_run(self):
203+
"""
204+
Upgrade the plain TCP connection to TLS using asyncio's native SSL
205+
transport, then continuously read data via the protocol callbacks.
206+
"""
207+
try:
208+
ssl_context = self.ssl_context
209+
server_hostname = None
210+
if self.ssl_options:
211+
server_hostname = self.ssl_options.get("server_hostname", None)
212+
if server_hostname is None:
213+
# asyncio's create_connection requires server_hostname when
214+
# ssl= is set. Use endpoint address for SNI/verification when
215+
# check_hostname is enabled; otherwise pass "" to suppress SNI.
216+
server_hostname = (
217+
self.endpoint.address if ssl_context.check_hostname else ""
218+
)
219+
220+
transport, protocol = await self._loop.create_connection(
221+
lambda: self._protocol,
222+
sock=self._socket,
223+
ssl=ssl_context,
224+
server_hostname=server_hostname,
225+
)
226+
self._transport = transport
227+
228+
if self._check_hostname:
229+
self._validate_hostname()
110230

231+
self._ssl_ready.set()
232+
except Exception as exc:
233+
log.debug("SSL setup failed for %s: %s", self, exc)
234+
self.defunct(exc)
235+
# Unblock handle_write so it can observe the defunct state and exit
236+
self._ssl_ready.set()
237+
return
111238

112239
@classmethod
113240
def initialize_reactor(cls):
@@ -126,8 +253,9 @@ def initialize_reactor(cls):
126253
cls._loop = asyncio.new_event_loop()
127254
# daemonize so the loop will be shut down on interpreter
128255
# shutdown
129-
cls._loop_thread = Thread(target=cls._loop.run_forever,
130-
daemon=True, name="asyncio_thread")
256+
cls._loop_thread = Thread(
257+
target=cls._loop.run_forever, daemon=True, name="asyncio_thread"
258+
)
131259
cls._loop_thread.start()
132260

133261
@classmethod
@@ -142,17 +270,18 @@ def close(self):
142270

143271
# close from the loop thread to avoid races when removing file
144272
# descriptors
145-
asyncio.run_coroutine_threadsafe(
146-
self._close(), loop=self._loop
147-
)
273+
asyncio.run_coroutine_threadsafe(self._close(), loop=self._loop)
148274

149275
async def _close(self):
150276
log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint))
151277
if self._write_watcher:
152278
self._write_watcher.cancel()
153279
if self._read_watcher:
154280
self._read_watcher.cancel()
155-
if self._socket:
281+
if self._transport:
282+
self._transport.close()
283+
self._transport = None
284+
elif self._socket:
156285
self._loop.remove_writer(self._socket.fileno())
157286
self._loop.remove_reader(self._socket.fileno())
158287
self._socket.close()
@@ -172,15 +301,12 @@ def push(self, data):
172301
if len(data) > buff_size:
173302
chunks = []
174303
for i in range(0, len(data), buff_size):
175-
chunks.append(data[i:i + buff_size])
304+
chunks.append(data[i : i + buff_size])
176305
else:
177306
chunks = [data]
178307

179308
if self._loop_thread != threading.current_thread():
180-
asyncio.run_coroutine_threadsafe(
181-
self._push_msg(chunks),
182-
loop=self._loop
183-
)
309+
asyncio.run_coroutine_threadsafe(self._push_msg(chunks), loop=self._loop)
184310
else:
185311
# avoid races/hangs by just scheduling this, not using threadsafe
186312
task = self._loop.create_task(self._push_msg(chunks))
@@ -194,13 +320,25 @@ async def _push_msg(self, chunks):
194320
for chunk in chunks:
195321
self._write_queue.put_nowait(chunk)
196322

197-
198323
async def handle_write(self):
324+
# For SSL connections, wait until the TLS handshake completes
325+
if self._ssl_ready:
326+
await self._ssl_ready.wait()
327+
if self.is_defunct:
328+
return
199329
while True:
200330
try:
201331
next_msg = await self._write_queue.get()
202332
if next_msg:
203-
await self._loop.sock_sendall(self._socket, next_msg)
333+
if self._transport:
334+
# SSL: use asyncio transport (handles TLS transparently)
335+
await self._protocol.write_ready.wait()
336+
if self.is_closed or self.is_defunct or not self._transport:
337+
return
338+
self._transport.write(next_msg)
339+
else:
340+
# Non-SSL: use low-level socket API
341+
await self._loop.sock_sendall(self._socket, next_msg)
204342
except socket.error as err:
205343
log.debug("Exception in send for %s: %s", self, err)
206344
self.defunct(err)
@@ -223,8 +361,7 @@ async def handle_read(self):
223361
await asyncio.sleep(0)
224362
continue
225363
except socket.error as err:
226-
log.debug("Exception during socket recv for %s: %s",
227-
self, err)
364+
log.debug("Exception during socket recv for %s: %s", self, err)
228365
self.defunct(err)
229366
return # leave the read loop
230367
except asyncio.CancelledError:

0 commit comments

Comments
 (0)