Skip to content

Commit 645c3f7

Browse files
committed
Use the robust map() implementation from standard library.
1 parent 2af0894 commit 645c3f7

2 files changed

Lines changed: 92 additions & 17 deletions

File tree

src/qasync/__init__.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,17 @@ def wait(self):
170170
super().wait()
171171

172172

173+
def _result_or_cancel(fut, timeout=None):
174+
try:
175+
try:
176+
return fut.result(timeout)
177+
finally:
178+
fut.cancel()
179+
finally:
180+
# Break a reference cycle with the exception in self._exception
181+
del fut
182+
183+
173184
class QThreadExecutorBase:
174185
def __init__(self):
175186
self._been_shutdown = False
@@ -180,14 +191,29 @@ def submit(self, callback, *args, **kwargs):
180191

181192
def map(self, func, *iterables, timeout=None, chunksize=1):
182193
"""Map the function to the iterables in a blocking way."""
183-
# iterables are consumed immediately
184-
start = time.monotonic()
185-
futures = list(map(lambda *args: self.submit(func, *args), *iterables))
186-
for future in futures:
187-
if timeout is not None:
188-
yield future.result(timeout=time.monotonic() - start)
189-
else:
190-
yield future.result()
194+
# based on standard python implementation for BaseExecutor.map
195+
end_time = time.monotonic() + timeout if timeout is not None else None
196+
futures = [self.submit(func, *args) for args in zip(*iterables)]
197+
198+
# the generator must be an inner function so that map() and the submit
199+
# occurs immediately.
200+
def generator():
201+
# reverse and pop to not keep future references around
202+
# (for reference cycles in exceptions)
203+
try:
204+
futures.reverse()
205+
while futures:
206+
if end_time is not None:
207+
yield _result_or_cancel(
208+
futures.pop(), timeout=end_time - time.monotonic()
209+
)
210+
else:
211+
yield _result_or_cancel(futures.pop())
212+
finally:
213+
for future in futures:
214+
future.cancel()
215+
216+
return generator()
191217

192218
def shutdown(self, wait=True, *, cancel_futures=False):
193219
if self._been_shutdown:

tests/test_qthreadexec.py

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import time
88
import weakref
99
from concurrent.futures import TimeoutError
10+
from itertools import islice
1011

1112
import pytest
1213

@@ -140,9 +141,11 @@ def test_map(executor):
140141
results = list(executor.map(lambda x: x + 1, range(10)))
141142
assert results == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
142143

144+
results = list(executor.map(lambda x, y: x + y, range(10), range(9)))
145+
assert results == [0, 2, 4, 6, 8, 10, 12, 14, 16]
143146

144-
@pytest.mark.parametrize("cancel", [True, False])
145-
def test_map_timeout(executor, cancel):
147+
148+
def test_map_timeout(executor):
146149
"""Test that map with timeout raises TimeoutError and cancels futures"""
147150
results = []
148151

@@ -158,15 +161,61 @@ def func(x):
158161
duration = time.monotonic() - start
159162
assert duration < 0.05
160163

164+
executor.shutdown(wait=True)
165+
# only about half of the tasks should have completed
166+
# because the max number of workers is 5 and the rest of
167+
# the tasks were not started at the time of the cancel.
168+
assert set(results) != {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
169+
170+
171+
def test_map_error(executor):
172+
"""Test that map with an exception will raise, and remaining tasks are cancelled"""
173+
results = []
174+
175+
def func(x):
176+
nonlocal results
177+
time.sleep(0.05)
178+
if len(results) == 5:
179+
raise ValueError("Test error")
180+
results.append(x)
181+
return x
182+
183+
with pytest.raises(ValueError):
184+
list(executor.map(func, range(15)))
185+
186+
executor.shutdown(wait=True, cancel_futures=False)
187+
assert len(results) <= 10, "Final 5 at least should have been cancelled"
188+
189+
190+
@pytest.mark.parametrize("cancel", [True, False])
191+
def test_map_shutdown(executor, cancel):
192+
results = []
193+
194+
def func(x):
195+
nonlocal results
196+
time.sleep(0.05)
197+
results.append(x)
198+
return x
199+
200+
# Get the first few results.
201+
# Keep the iterator alive so that it isn't closed when its reference is dropped.
202+
m = executor.map(func, range(15))
203+
values = list(islice(m, 5))
204+
assert values == [0, 1, 2, 3, 4]
205+
161206
executor.shutdown(wait=True, cancel_futures=cancel)
162-
if not cancel:
163-
# they were not cancelled
164-
assert set(results) == {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
207+
if cancel:
208+
assert len(results) < 15, "Some tasks should have been cancelled"
165209
else:
166-
# only about half of the tasks should have completed
167-
# because the max number of workers is 5 and the rest of
168-
# the tasks were not started at the time of the cancel.
169-
assert set(results) != {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
210+
assert len(results) == 15, "All tasks should have been completed"
211+
212+
213+
def test_map_start(executor):
214+
"""Test that map starts tasks immediately, before iterating"""
215+
e = threading.Event()
216+
m = executor.map(lambda x: (e.set(), x), range(1))
217+
e.wait(timeout=0.1)
218+
assert list(m) == [(None, 0)]
170219

171220

172221
def test_context(executor):

0 commit comments

Comments
 (0)