Skip to content

Commit 1658317

Browse files
committed
TST: specify max_workers explicitly for concurrent tests
1 parent 18b20f5 commit 1658317

1 file changed

Lines changed: 8 additions & 6 deletions

File tree

pywt/tests/test_concurrent.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import sys
99
import warnings
10+
import multiprocessing
1011
import numpy as np
1112
from functools import partial
1213
from numpy.testing import dec, run_module_suite, assert_array_equal
@@ -18,6 +19,7 @@
1819
import futures
1920
else:
2021
from concurrent import futures
22+
max_workers = multiprocessing.cpu_count()
2123
futures_available = True
2224
except ImportError:
2325
futures_available = False
@@ -54,7 +56,7 @@ def test_concurrent_swt():
5456
transform = partial(swt_func, wavelet='haar', level=1)
5557
for _ in range(10):
5658
arrs = [x.copy() for _ in range(100)]
57-
with futures.ThreadPoolExecutor() as ex:
59+
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
5860
results = list(ex.map(transform, arrs))
5961

6062
# validate result from one of the concurrent runs
@@ -65,13 +67,13 @@ def test_concurrent_swt():
6567
@dec.skipif(not futures_available)
6668
def test_concurrent_wavedec():
6769
# wavedec on 1D data calls the Cython dwt_single
68-
# other cases call dwt_axes
70+
# other cases call dwt_axis
6971
for wavedec_func, x in zip([pywt.wavedec, pywt.wavedec2, pywt.wavedecn],
7072
[np.ones(8), np.eye(16), np.eye(16)]):
7173
transform = partial(wavedec_func, wavelet='haar', level=1)
7274
for _ in range(10):
7375
arrs = [x.copy() for _ in range(100)]
74-
with futures.ThreadPoolExecutor() as ex:
76+
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
7577
results = list(ex.map(transform, arrs))
7678

7779
# validate result from one of the concurrent runs
@@ -82,13 +84,13 @@ def test_concurrent_wavedec():
8284
@dec.skipif(not futures_available)
8385
def test_concurrent_dwt():
8486
# dwt on 1D data calls the Cython dwt_single
85-
# other cases call dwt_axes
87+
# other cases call dwt_axis
8688
for dwt_func, x in zip([pywt.dwt, pywt.dwt2, pywt.dwtn],
8789
[np.ones(8), np.eye(16), np.eye(16)]):
8890
transform = partial(dwt_func, wavelet='haar')
8991
for _ in range(10):
9092
arrs = [x.copy() for _ in range(100)]
91-
with futures.ThreadPoolExecutor() as ex:
93+
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
9294
results = list(ex.map(transform, arrs))
9395

9496
# validate result from one of the concurrent runs
@@ -104,7 +106,7 @@ def test_concurrent_cwt():
104106
sampling_period=dt)
105107
for _ in range(10):
106108
arrs = [sst.copy() for _ in range(50)]
107-
with futures.ThreadPoolExecutor() as ex:
109+
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
108110
results = list(ex.map(transform, arrs))
109111

110112
# validate result from one of the concurrent runs

0 commit comments

Comments
 (0)