1717from __future__ import annotations
1818
1919import asyncio
20- import gc
2120import sys
2221import threading
2322import time
2625
2726from test .asynchronous import AsyncUnitTest , unittest
2827
29- from pymongo import periodic_executor
30- from pymongo .periodic_executor import (
31- AsyncPeriodicExecutor ,
32- _register_executor ,
33- _shutdown_executors ,
34- )
28+ from pymongo .periodic_executor import AsyncPeriodicExecutor
3529
3630_IS_SYNC = False
3731
@@ -49,11 +43,12 @@ async def target():
4943
5044class AsyncPeriodicExecutorTestBase (AsyncUnitTest ):
5145 async def asyncSetUp (self ):
52- self .executor = _make_executor ()
46+ self .executor = None
5347
5448 async def asyncTearDown (self ):
55- self .executor .close ()
56- await self .executor .join (timeout = 2 )
49+ if self .executor is not None :
50+ self .executor .close ()
51+ await self .executor .join (timeout = 2 )
5752
5853
5954class TestAsyncPeriodicExecutor (AsyncPeriodicExecutorTestBase ):
@@ -64,6 +59,7 @@ async def test_repr_contains_class_and_name(self):
6459 self .assertIn ("exec" , executor_repr )
6560
6661 async def test_join_without_open_is_safe (self ):
62+ self .executor = _make_executor ()
6763 try :
6864 await self .executor .join (timeout = 0.01 )
6965 except Exception as e :
@@ -109,18 +105,17 @@ def target():
109105 self .assertEqual (len (captured_exc ), 1 )
110106 self .assertIsInstance (captured_exc [0 ], RuntimeError )
111107 else :
112- ran = asyncio . Event ()
108+ call_count = 0
113109
114110 async def target ():
115- ran .set ()
111+ nonlocal call_count
112+ call_count += 1
116113 raise RuntimeError ("error" )
117114
118115 self .executor = _make_executor (target = target )
119116 self .executor .open ()
120117 await self .executor .join (timeout = 2 )
121- self .assertTrue (ran .is_set (), "target never ran" )
122- if self .executor ._task is not None and self .executor ._task .done ():
123- self .executor ._task .exception ()
118+ self .assertEqual (call_count , 1 , "target should stop after exception" )
124119
125120 async def test_skip_sleep_flag_skips_interval (self ):
126121 call_times = []
@@ -139,19 +134,18 @@ async def target():
139134 self .assertLess (call_times [1 ] - call_times [0 ], 5.0 )
140135
141136 async def test_wake_causes_early_run (self ):
142- call_count = [ 0 ]
137+ call_count = 0
143138 if _IS_SYNC :
144139 woken = threading .Event ()
145140 else :
146141 woken = asyncio .Event ()
147142
148143 async def target ():
149- call_count [0 ] += 1
150- if call_count [0 ] == 1 :
144+ nonlocal call_count
145+ call_count += 1
146+ if call_count == 1 :
151147 woken .set ()
152- if call_count [0 ] >= 2 :
153- return False
154- return True
148+ return call_count < 2
155149
156150 self .executor = _make_executor (interval = 30.0 , min_interval = 0.01 , target = target )
157151 self .executor .open ()
@@ -161,73 +155,22 @@ async def target():
161155 await asyncio .wait_for (woken .wait (), timeout = 2 )
162156 self .executor .wake ()
163157 await self .executor .join (timeout = 3 )
164- self .assertGreaterEqual (call_count [ 0 ] , 2 )
158+ self .assertGreaterEqual (call_count , 2 )
165159
166160 async def test_open_after_target_returns_false (self ):
167- called = [ 0 ]
161+ called = 0
168162
169163 async def target ():
170- called [0 ] += 1
164+ nonlocal called
165+ called += 1
171166 return False
172167
173168 self .executor = _make_executor (target = target )
174169 self .executor .open ()
175170 await self .executor .join (timeout = 2 )
176171 self .executor .open ()
177172 await self .executor .join (timeout = 2 )
178- self .assertGreaterEqual (called [0 ], 2 )
179-
180-
181- class TestShouldStop (AsyncUnitTest ):
182- if _IS_SYNC :
183-
184- def test_returns_false_when_not_stopped (self ):
185- executor = _make_executor ()
186- self .assertFalse (executor ._should_stop ())
187- self .assertFalse (executor ._thread_will_exit )
188-
189- def test_returns_true_and_sets_thread_will_exit (self ):
190- executor = _make_executor ()
191- executor ._stopped = True
192- self .assertTrue (executor ._should_stop ())
193- self .assertTrue (executor ._thread_will_exit )
194-
195-
196- class TestRegisterExecutor (AsyncUnitTest ):
197- if _IS_SYNC :
198-
199- def setUp (self ):
200- self ._orig = set (periodic_executor ._EXECUTORS )
201-
202- def tearDown (self ):
203- periodic_executor ._EXECUTORS .clear ()
204- periodic_executor ._EXECUTORS .update (self ._orig )
205-
206- def test_register_adds_weakref (self ):
207- executor = _make_executor ()
208- before = len (periodic_executor ._EXECUTORS )
209- _register_executor (executor )
210- self .assertEqual (len (periodic_executor ._EXECUTORS ), before + 1 )
211- ref = next (r for r in periodic_executor ._EXECUTORS if r () is executor )
212- del executor
213- gc .collect ()
214- self .assertNotIn (ref , periodic_executor ._EXECUTORS )
215-
216- def test_shutdown_executors_stops_running_executors (self ):
217- ran = threading .Event ()
218-
219- def target ():
220- ran .set ()
221- return True
222-
223- executor = _make_executor (target = target )
224- executor .open ()
225- self .assertTrue (ran .wait (timeout = 2 ), "target never ran" )
226- _shutdown_executors ()
227-
228- def test_shutdown_executors_safe_when_empty (self ):
229- periodic_executor ._EXECUTORS .clear ()
230- _shutdown_executors ()
173+ self .assertGreaterEqual (called , 2 )
231174
232175
233176if __name__ == "__main__" :
0 commit comments