Skip to content

Commit a858672

Browse files
committed
Full asyncio test suite
1 parent 4ab9f66 commit a858672

3 files changed

Lines changed: 348 additions & 3 deletions

File tree

cassandra/io/asyncioreactor.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from cassandra.connection import Connection, ConnectionShutdown
22

3+
import atexit
34
import asyncio
45
import logging
56
import os
@@ -11,6 +12,32 @@
1112
log = logging.getLogger(__name__)
1213

1314

15+
def _cleanup():
16+
"""
17+
Module-level cleanup called at interpreter shutdown via atexit.
18+
Stops the asyncio event loop and joins the loop thread.
19+
"""
20+
loop = AsyncioConnection._loop
21+
thread = AsyncioConnection._loop_thread
22+
if loop is not None:
23+
try:
24+
loop.call_soon_threadsafe(loop.stop)
25+
except RuntimeError:
26+
# loop may already be closed post-fork or during shutdown
27+
pass
28+
if thread is not None:
29+
thread.join(timeout=1.0)
30+
if thread.is_alive():
31+
log.warning(
32+
"Event loop thread could not be joined, so shutdown may not be clean. "
33+
"Please call Cluster.shutdown() to avoid this.")
34+
else:
35+
log.debug("Event loop thread was joined")
36+
37+
38+
atexit.register(_cleanup)
39+
40+
1441
class AsyncioTimer(object):
1542
"""
1643
An ``asyncioreactor``-specific Timer. Similar to :class:`.connection.Timer,
@@ -92,17 +119,33 @@ def __init__(self, *args, **kwargs):
92119
def initialize_reactor(cls):
93120
with cls._lock:
94121
if cls._pid != os.getpid():
95-
cls._loop = None
122+
log.debug("Detected fork, clearing and reinitializing reactor state")
123+
cls.handle_fork()
96124
if cls._loop is None:
97125
cls._loop = asyncio.new_event_loop()
98126

99-
if not cls._loop_thread:
127+
if not cls._loop_thread or not cls._loop_thread.is_alive():
100128
# daemonize so the loop will be shut down on interpreter
101129
# shutdown
102130
cls._loop_thread = Thread(target=cls._loop.run_forever,
103131
daemon=True, name="asyncio_thread")
104132
cls._loop_thread.start()
105133

134+
@classmethod
135+
def handle_fork(cls):
136+
"""
137+
Called after a fork. Cleans up any reactor state from the parent
138+
process so that a fresh event loop can be started in the child.
139+
"""
140+
if cls._loop is not None:
141+
try:
142+
cls._loop.call_soon_threadsafe(cls._loop.stop)
143+
except RuntimeError:
144+
pass
145+
cls._loop = None
146+
cls._loop_thread = None
147+
cls._pid = os.getpid()
148+
106149
@classmethod
107150
def create_timer(cls, timeout, callback):
108151
return AsyncioTimer(timeout, callback, loop=cls._loop)

tests/unit/io/test_asyncioreactor.py

Lines changed: 292 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@
99
from tests import is_monkey_patched, connection_class
1010
from tests.unit.io.utils import TimerCallback, TimerTestMixin, submit_and_wait_for_completion
1111

12-
from unittest.mock import patch, MagicMock
12+
from unittest.mock import patch, MagicMock, Mock, AsyncMock
1313

14+
import asyncio
15+
import socket as stdlib_socket
1416
import unittest
1517
import time
18+
import threading
1619

1720
skip_me = (is_monkey_patched() or
1821
(not ASYNCIO_AVAILABLE) or
@@ -99,3 +102,291 @@ def test_timer_cancellation(self):
99102
time.sleep(.2)
100103
# Assert that the cancellation was honored
101104
self.assertFalse(callback.was_invoked())
105+
106+
107+
@unittest.skipIf(is_monkey_patched(), 'runtime is monkey patched for another reactor')
108+
@unittest.skipIf(connection_class is not AsyncioConnection,
109+
'not running asyncio tests; current connection_class is {}'.format(connection_class))
110+
@unittest.skipUnless(ASYNCIO_AVAILABLE, "asyncio is not available for this runtime")
111+
class AsyncioConnectionTest(unittest.TestCase):
112+
"""
113+
Tests for AsyncioConnection covering write, read, close, and error
114+
handling at the reactor level. Unlike the ReactorTestMixin used by
115+
asyncore/libev, these tests exercise the public interface (push/close)
116+
because handle_read/handle_write are async coroutines running inside
117+
the event loop thread.
118+
"""
119+
120+
@classmethod
121+
def setUpClass(cls):
122+
if skip_me:
123+
return
124+
# Force a fresh reactor so we aren't affected by a previous test
125+
# class that may have stopped the shared event loop.
126+
if AsyncioConnection._loop is not None:
127+
try:
128+
AsyncioConnection._loop.call_soon_threadsafe(
129+
AsyncioConnection._loop.stop)
130+
except RuntimeError:
131+
pass
132+
if AsyncioConnection._loop_thread:
133+
AsyncioConnection._loop_thread.join(timeout=1.0)
134+
AsyncioConnection._loop = None
135+
AsyncioConnection._loop_thread = None
136+
AsyncioConnection.initialize_reactor()
137+
cls._loop = AsyncioConnection._loop
138+
# Save original loop methods so we can restore after each test
139+
cls._orig_sock_recv = cls._loop.sock_recv
140+
cls._orig_sock_sendall = cls._loop.sock_sendall
141+
142+
@classmethod
143+
def tearDownClass(cls):
144+
if skip_me:
145+
return
146+
cls._loop.sock_recv = cls._orig_sock_recv
147+
cls._loop.sock_sendall = cls._orig_sock_sendall
148+
149+
def _make_connection(self):
150+
"""
151+
Create an AsyncioConnection with mocked socket and _connect_socket.
152+
Loop socket methods are pre-mocked so that the handle_read/handle_write
153+
coroutines started in __init__ don't hit real I/O.
154+
"""
155+
mock_socket = MagicMock(spec=stdlib_socket.socket)
156+
mock_socket.fileno.return_value = 99
157+
mock_socket.setblocking = MagicMock()
158+
mock_socket.connect.return_value = None
159+
mock_socket.getsockopt.return_value = 0
160+
mock_socket.send.side_effect = lambda x: len(x)
161+
162+
def fake_connect_socket(self_inner):
163+
self_inner._socket = mock_socket
164+
165+
with patch.object(AsyncioConnection, '_connect_socket', fake_connect_socket):
166+
conn = AsyncioConnection(
167+
host='127.0.0.1',
168+
cql_version='3.0.1',
169+
connect_timeout=5,
170+
)
171+
return conn
172+
173+
def setUp(self):
174+
if skip_me:
175+
return
176+
177+
loop = self._loop
178+
179+
# Pre-mock sock_recv to block indefinitely (read loop won't spin)
180+
self._recv_unblock = threading.Event()
181+
182+
async def blocking_recv(sock, bufsize):
183+
while not self._recv_unblock.is_set():
184+
await asyncio.sleep(0.01)
185+
raise asyncio.CancelledError()
186+
187+
# Pre-mock sock_sendall to silently consume data (options message, etc.)
188+
self._sent_data = []
189+
190+
async def capturing_sendall(sock, data):
191+
self._sent_data.append(bytes(data))
192+
193+
loop.sock_recv = blocking_recv
194+
loop.sock_sendall = capturing_sendall
195+
196+
self.conn = self._make_connection()
197+
# Give the loop a moment to process __init__ tasks (options message)
198+
time.sleep(0.1)
199+
# Clear any data sent during init (options message)
200+
self._sent_data.clear()
201+
202+
def tearDown(self):
203+
if skip_me:
204+
return
205+
# Unblock the recv so the read loop can exit
206+
self._recv_unblock.set()
207+
try:
208+
self.conn.close()
209+
except Exception:
210+
pass
211+
time.sleep(0.05)
212+
# Restore default mocks for next test
213+
self._loop.sock_recv = self._orig_sock_recv
214+
self._loop.sock_sendall = self._orig_sock_sendall
215+
216+
def test_push_sends_data(self):
217+
"""
218+
Verify that push() enqueues data and the write loop sends it
219+
via sock_sendall on the event loop.
220+
"""
221+
test_data = b'hello world'
222+
self.conn.push(test_data)
223+
224+
# Wait for the event loop to drain the write queue
225+
time.sleep(0.2)
226+
227+
self.assertTrue(len(self._sent_data) > 0)
228+
self.assertEqual(b''.join(self._sent_data), test_data)
229+
230+
def test_push_chunking(self):
231+
"""
232+
Verify that data larger than out_buffer_size is chunked
233+
into multiple pieces before being sent.
234+
"""
235+
buf_size = self.conn.out_buffer_size
236+
# Send data that is 2.5x the buffer size
237+
test_data = b'x' * int(buf_size * 2.5)
238+
self.conn.push(test_data)
239+
240+
time.sleep(0.2)
241+
242+
# Should have been broken into at least 3 chunks
243+
self.assertGreaterEqual(len(self._sent_data), 3)
244+
self.assertEqual(b''.join(self._sent_data), test_data)
245+
246+
def test_write_error_defuncts_connection(self):
247+
"""
248+
Verify that a socket error during write causes the
249+
connection to become defunct.
250+
"""
251+
loop = self._loop
252+
253+
async def error_sendall(sock, data):
254+
raise stdlib_socket.error(32, "Broken pipe")
255+
256+
loop.sock_sendall = error_sendall
257+
258+
self.conn.push(b'trigger error')
259+
time.sleep(0.2)
260+
261+
self.assertTrue(self.conn.is_defunct)
262+
self.assertIsInstance(self.conn.last_error, stdlib_socket.error)
263+
264+
def test_read_eof_closes_connection(self):
265+
"""
266+
Verify that receiving an empty buffer (EOF / server close)
267+
causes the connection to close.
268+
"""
269+
loop = self._loop
270+
271+
# Cancel the existing read watcher so we can start a new one
272+
if self.conn._read_watcher:
273+
self.conn._read_watcher.cancel()
274+
time.sleep(0.05)
275+
276+
call_count = 0
277+
async def eof_recv(sock, bufsize):
278+
nonlocal call_count
279+
call_count += 1
280+
if call_count == 1:
281+
return b'' # EOF
282+
raise asyncio.CancelledError()
283+
284+
loop.sock_recv = eof_recv
285+
286+
self.conn._read_watcher = asyncio.run_coroutine_threadsafe(
287+
self.conn.handle_read(), loop=loop
288+
)
289+
290+
time.sleep(0.2)
291+
self.assertTrue(self.conn.is_closed)
292+
293+
def test_read_error_defuncts_connection(self):
294+
"""
295+
Verify that a socket error during read causes the
296+
connection to become defunct.
297+
"""
298+
loop = self._loop
299+
300+
if self.conn._read_watcher:
301+
self.conn._read_watcher.cancel()
302+
time.sleep(0.05)
303+
304+
async def error_recv(sock, bufsize):
305+
raise stdlib_socket.error(104, "Connection reset by peer")
306+
307+
loop.sock_recv = error_recv
308+
309+
self.conn._read_watcher = asyncio.run_coroutine_threadsafe(
310+
self.conn.handle_read(), loop=loop
311+
)
312+
313+
time.sleep(0.2)
314+
self.assertTrue(self.conn.is_defunct)
315+
self.assertIsInstance(self.conn.last_error, stdlib_socket.error)
316+
317+
def test_read_processes_data(self):
318+
"""
319+
Verify that data received via sock_recv is written to the
320+
IO buffer and process_io_buffer is called.
321+
"""
322+
loop = self._loop
323+
324+
if self.conn._read_watcher:
325+
self.conn._read_watcher.cancel()
326+
time.sleep(0.05)
327+
328+
call_count = 0
329+
async def data_then_eof_recv(sock, bufsize):
330+
nonlocal call_count
331+
call_count += 1
332+
if call_count == 1:
333+
return b'some data from server'
334+
return b''
335+
336+
loop.sock_recv = data_then_eof_recv
337+
338+
with patch.object(self.conn, 'process_io_buffer') as mock_process:
339+
self.conn._read_watcher = asyncio.run_coroutine_threadsafe(
340+
self.conn.handle_read(), loop=loop
341+
)
342+
time.sleep(0.2)
343+
mock_process.assert_called()
344+
345+
def test_close_cancels_watchers(self):
346+
"""
347+
Verify that closing the connection cancels both the
348+
read and write watchers.
349+
"""
350+
read_watcher = self.conn._read_watcher
351+
write_watcher = self.conn._write_watcher
352+
353+
self.conn.close()
354+
time.sleep(0.2)
355+
356+
self.assertTrue(self.conn.is_closed)
357+
# The watchers should have been cancelled
358+
if read_watcher:
359+
self.assertTrue(read_watcher.cancelled() or read_watcher.done())
360+
if write_watcher:
361+
self.assertTrue(write_watcher.cancelled() or write_watcher.done())
362+
363+
364+
@unittest.skipIf(is_monkey_patched(), 'runtime is monkey patched for another reactor')
365+
@unittest.skipUnless(ASYNCIO_AVAILABLE, "asyncio is not available for this runtime")
366+
class AsyncioForkTest(unittest.TestCase):
367+
"""
368+
Test that handle_fork() properly resets reactor state.
369+
"""
370+
371+
def test_handle_fork_resets_state(self):
372+
"""
373+
Verify handle_fork() clears loop, thread, and updates pid.
374+
"""
375+
AsyncioConnection.initialize_reactor()
376+
self.assertIsNotNone(AsyncioConnection._loop)
377+
self.assertIsNotNone(AsyncioConnection._loop_thread)
378+
379+
old_loop = AsyncioConnection._loop
380+
old_thread = AsyncioConnection._loop_thread
381+
382+
AsyncioConnection.handle_fork()
383+
384+
self.assertIsNone(AsyncioConnection._loop)
385+
self.assertIsNone(AsyncioConnection._loop_thread)
386+
387+
# Re-initialize for other tests
388+
AsyncioConnection.initialize_reactor()
389+
self.assertIsNotNone(AsyncioConnection._loop)
390+
self.assertIsNotNone(AsyncioConnection._loop_thread)
391+
# Should be a new loop and thread
392+
self.assertIsNot(AsyncioConnection._loop, old_loop)

tox.ini

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,14 @@ setenv = LIBEV_EMBED=0
4444
changedir = {envtmpdir}
4545
commands =
4646
pytest -v {toxinidir}/tests/unit/io/test_eventletreactor.py
47+
48+
49+
[testenv:asyncio_loop]
50+
deps = {[base]deps}
51+
52+
setenv = LIBEV_EMBED=0
53+
CARES_EMBED=0
54+
EVENT_LOOP_MANAGER=asyncio
55+
changedir = {envtmpdir}
56+
commands =
57+
pytest -v {toxinidir}/tests/unit/io/test_asyncioreactor.py

0 commit comments

Comments
 (0)