Skip to content

Commit 99e7245

Browse files
committed
asyncio: fix SSL connections by using native TLS transport
Python 3.8+ rejects ssl.SSLSocket in asyncio's sock_sendall/sock_recv with TypeError. This caused the driver to fail connecting to ScyllaDB clusters requiring TLS, manifesting as 'protocol version 21 not supported' errors (0x15 = TLS Alert byte misread as protocol version). Fix by using asyncio's native TLS transport (loop.create_connection with ssl= parameter) instead of wrapping sockets with ssl.SSLContext.wrap_socket(). This preserves shard-aware port binding done during _initiate_connection(). Add _AsyncioProtocol to bridge asyncio's transport/protocol API back to Connection.process_io_buffer() for SSL data reads. Non-SSL connections continue using the existing sock_recv path. Fixes #330
1 parent c56fbf8 commit 99e7245

1 file changed

Lines changed: 148 additions & 31 deletions

File tree

cassandra/io/asyncioreactor.py

Lines changed: 148 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
asyncio.run_coroutine_threadsafe
2424
except AttributeError:
2525
raise ImportError(
26-
'Cannot use asyncioreactor without access to '
27-
'asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)'
26+
"Cannot use asyncioreactor without access to "
27+
"asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)"
2828
)
2929

3030

@@ -38,12 +38,12 @@ class AsyncioTimer(object):
3838

3939
@property
4040
def end(self):
41-
raise NotImplementedError('{} is not compatible with TimerManager and '
42-
'does not implement .end()')
41+
raise NotImplementedError(
42+
"{} is not compatible with TimerManager and does not implement .end()"
43+
)
4344

4445
def __init__(self, timeout, callback, loop):
45-
delayed = self._call_delayed_coro(timeout=timeout,
46-
callback=callback)
46+
delayed = self._call_delayed_coro(timeout=timeout, callback=callback)
4747
self._handle = asyncio.run_coroutine_threadsafe(delayed, loop=loop)
4848

4949
@staticmethod
@@ -63,17 +63,51 @@ def cancel(self):
6363
def finish(self):
6464
# connection.Timer method not implemented here because we can't inspect
6565
# the Handle returned from call_later
66-
raise NotImplementedError('{} is not compatible with TimerManager and '
67-
'does not implement .finish()')
66+
raise NotImplementedError(
67+
"{} is not compatible with TimerManager and does not implement .finish()"
68+
)
69+
70+
71+
class _AsyncioProtocol(asyncio.Protocol):
72+
"""
73+
Protocol adapter for asyncio SSL connections. Bridges asyncio's
74+
transport/protocol API back to AsyncioConnection's buffer processing.
75+
"""
76+
77+
def __init__(self, connection):
78+
self._connection = connection
79+
self.transport = None
80+
81+
def connection_made(self, transport):
82+
self.transport = transport
83+
84+
def data_received(self, data):
85+
conn = self._connection
86+
conn._iobuf.write(data)
87+
if conn._iobuf.tell():
88+
conn.process_io_buffer()
89+
90+
def connection_lost(self, exc):
91+
conn = self._connection
92+
if exc:
93+
log.debug("Connection %s lost: %s", conn, exc)
94+
conn.defunct(exc)
95+
else:
96+
log.debug("Connection %s closed by server", conn)
97+
conn.close()
98+
99+
def eof_received(self):
100+
return False
68101

69102

70103
class AsyncioConnection(Connection):
71104
"""
72-
An experimental implementation of :class:`.Connection` that uses the
73-
``asyncio`` module in the Python standard library for its event loop.
105+
An implementation of :class:`.Connection` that uses the ``asyncio``
106+
module in the Python standard library for its event loop.
74107
75-
Note that it requires ``asyncio`` features that were only introduced in the
76-
3.4 line in 3.4.6, and in the 3.5 line in 3.5.1.
108+
Supports SSL connections via asyncio's native TLS transport, which
109+
avoids the incompatibility between ``ssl.SSLSocket`` and asyncio's
110+
low-level socket methods (``sock_sendall``, ``sock_recv``).
77111
"""
78112

79113
_loop = None
@@ -88,26 +122,102 @@ class AsyncioConnection(Connection):
88122
def __init__(self, *args, **kwargs):
89123
Connection.__init__(self, *args, **kwargs)
90124
self._background_tasks = set()
125+
self._transport = None
126+
self._protocol = _AsyncioProtocol(self) if self.ssl_context else None
127+
self._using_ssl = bool(self.ssl_context)
128+
self._ssl_ready = asyncio.Event() if self.ssl_context else None
91129

92130
self._connect_socket()
93131
self._socket.setblocking(0)
94132
loop_args = dict()
95133
if sys.version_info[0] == 3 and sys.version_info[1] < 10:
96-
loop_args['loop'] = self._loop
134+
loop_args["loop"] = self._loop
97135
self._write_queue = asyncio.Queue(**loop_args)
98136
self._write_queue_lock = asyncio.Lock(**loop_args)
99137

100138
# see initialize_reactor -- loop is running in a separate thread, so we
101139
# have to use a threadsafe call
102-
self._read_watcher = asyncio.run_coroutine_threadsafe(
103-
self.handle_read(), loop=self._loop
104-
)
140+
if self._using_ssl:
141+
# For SSL: set up asyncio transport/protocol, then start writer
142+
self._read_watcher = asyncio.run_coroutine_threadsafe(
143+
self._setup_ssl_and_run(), loop=self._loop
144+
)
145+
else:
146+
# For non-SSL: use low-level sock_sendall/sock_recv as before
147+
self._read_watcher = asyncio.run_coroutine_threadsafe(
148+
self.handle_read(), loop=self._loop
149+
)
105150
self._write_watcher = asyncio.run_coroutine_threadsafe(
106151
self.handle_write(), loop=self._loop
107152
)
108153
self._send_options_message()
109154

155+
def _connect_socket(self):
156+
"""
157+
Override base class to skip SSL wrapping of the socket.
158+
For SSL connections, the plain TCP socket is connected here, and TLS
159+
is set up later via asyncio's native SSL transport in _setup_ssl_and_run().
160+
"""
161+
sockerr = None
162+
addresses = self._get_socket_addresses()
163+
for af, socktype, proto, _, sockaddr in addresses:
164+
try:
165+
self._socket = self._socket_impl.socket(af, socktype, proto)
166+
# Do NOT wrap with ssl_context here -- asyncio will handle TLS
167+
self._socket.settimeout(self.connect_timeout)
168+
self._initiate_connection(sockaddr)
169+
self._socket.settimeout(None)
170+
171+
local_addr = self._socket.getsockname()
172+
log.debug("Connection %s: '%s' -> '%s'", id(self), local_addr, sockaddr)
173+
sockerr = None
174+
break
175+
except socket.error as err:
176+
if self._socket:
177+
self._socket.close()
178+
self._socket = None
179+
sockerr = err
180+
181+
if sockerr:
182+
raise socket.error(
183+
sockerr.errno,
184+
"Tried connecting to %s. Last error: %s"
185+
% ([a[4] for a in addresses], sockerr.strerror or sockerr),
186+
)
110187

188+
async def _setup_ssl_and_run(self):
189+
"""
190+
Upgrade the plain TCP connection to TLS using asyncio's native SSL
191+
transport, then continuously read data via the protocol callbacks.
192+
"""
193+
try:
194+
ssl_context = self.ssl_context
195+
server_hostname = None
196+
if self.ssl_options:
197+
server_hostname = self.ssl_options.get("server_hostname", None)
198+
if not server_hostname:
199+
# asyncio's create_connection requires server_hostname when
200+
# ssl= is set, even if check_hostname is False
201+
server_hostname = self.endpoint.address
202+
203+
transport, protocol = await self._loop.create_connection(
204+
lambda: self._protocol,
205+
sock=self._socket,
206+
ssl=ssl_context,
207+
server_hostname=server_hostname,
208+
)
209+
self._transport = transport
210+
211+
if self._check_hostname:
212+
self._validate_hostname()
213+
214+
self._ssl_ready.set()
215+
except Exception as exc:
216+
log.debug("SSL setup failed for %s: %s", self, exc)
217+
self.defunct(exc)
218+
# Unblock handle_write so it can observe the defunct state and exit
219+
self._ssl_ready.set()
220+
return
111221

112222
@classmethod
113223
def initialize_reactor(cls):
@@ -126,8 +236,9 @@ def initialize_reactor(cls):
126236
cls._loop = asyncio.new_event_loop()
127237
# daemonize so the loop will be shut down on interpreter
128238
# shutdown
129-
cls._loop_thread = Thread(target=cls._loop.run_forever,
130-
daemon=True, name="asyncio_thread")
239+
cls._loop_thread = Thread(
240+
target=cls._loop.run_forever, daemon=True, name="asyncio_thread"
241+
)
131242
cls._loop_thread.start()
132243

133244
@classmethod
@@ -142,17 +253,18 @@ def close(self):
142253

143254
# close from the loop thread to avoid races when removing file
144255
# descriptors
145-
asyncio.run_coroutine_threadsafe(
146-
self._close(), loop=self._loop
147-
)
256+
asyncio.run_coroutine_threadsafe(self._close(), loop=self._loop)
148257

149258
async def _close(self):
150259
log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint))
151260
if self._write_watcher:
152261
self._write_watcher.cancel()
153262
if self._read_watcher:
154263
self._read_watcher.cancel()
155-
if self._socket:
264+
if self._transport:
265+
self._transport.close()
266+
self._transport = None
267+
elif self._socket:
156268
self._loop.remove_writer(self._socket.fileno())
157269
self._loop.remove_reader(self._socket.fileno())
158270
self._socket.close()
@@ -172,15 +284,12 @@ def push(self, data):
172284
if len(data) > buff_size:
173285
chunks = []
174286
for i in range(0, len(data), buff_size):
175-
chunks.append(data[i:i + buff_size])
287+
chunks.append(data[i : i + buff_size])
176288
else:
177289
chunks = [data]
178290

179291
if self._loop_thread != threading.current_thread():
180-
asyncio.run_coroutine_threadsafe(
181-
self._push_msg(chunks),
182-
loop=self._loop
183-
)
292+
asyncio.run_coroutine_threadsafe(self._push_msg(chunks), loop=self._loop)
184293
else:
185294
# avoid races/hangs by just scheduling this, not using threadsafe
186295
task = self._loop.create_task(self._push_msg(chunks))
@@ -194,13 +303,22 @@ async def _push_msg(self, chunks):
194303
for chunk in chunks:
195304
self._write_queue.put_nowait(chunk)
196305

197-
198306
async def handle_write(self):
307+
# For SSL connections, wait until the TLS handshake completes
308+
if self._ssl_ready:
309+
await self._ssl_ready.wait()
310+
if self.is_defunct:
311+
return
199312
while True:
200313
try:
201314
next_msg = await self._write_queue.get()
202315
if next_msg:
203-
await self._loop.sock_sendall(self._socket, next_msg)
316+
if self._transport:
317+
# SSL: use asyncio transport (handles TLS transparently)
318+
self._transport.write(next_msg)
319+
else:
320+
# Non-SSL: use low-level socket API
321+
await self._loop.sock_sendall(self._socket, next_msg)
204322
except socket.error as err:
205323
log.debug("Exception in send for %s: %s", self, err)
206324
self.defunct(err)
@@ -223,8 +341,7 @@ async def handle_read(self):
223341
await asyncio.sleep(0)
224342
continue
225343
except socket.error as err:
226-
log.debug("Exception during socket recv for %s: %s",
227-
self, err)
344+
log.debug("Exception during socket recv for %s: %s", self, err)
228345
self.defunct(err)
229346
return # leave the read loop
230347
except asyncio.CancelledError:

0 commit comments

Comments
 (0)