77
88import sys
99import warnings
10+ import multiprocessing
1011import numpy as np
1112from functools import partial
1213from numpy .testing import dec , run_module_suite , assert_array_equal
1819 import futures
1920 else :
2021 from concurrent import futures
22+ max_workers = multiprocessing .cpu_count ()
2123 futures_available = True
2224except ImportError :
2325 futures_available = False
@@ -54,7 +56,7 @@ def test_concurrent_swt():
5456 transform = partial (swt_func , wavelet = 'haar' , level = 1 )
5557 for _ in range (10 ):
5658 arrs = [x .copy () for _ in range (100 )]
57- with futures .ThreadPoolExecutor () as ex :
59+ with futures .ThreadPoolExecutor (max_workers = max_workers ) as ex :
5860 results = list (ex .map (transform , arrs ))
5961
6062 # validate result from one of the concurrent runs
@@ -65,13 +67,13 @@ def test_concurrent_swt():
6567@dec .skipif (not futures_available )
6668def test_concurrent_wavedec ():
6769 # wavedec on 1D data calls the Cython dwt_single
68- # other cases call dwt_axes
70+ # other cases call dwt_axis
6971 for wavedec_func , x in zip ([pywt .wavedec , pywt .wavedec2 , pywt .wavedecn ],
7072 [np .ones (8 ), np .eye (16 ), np .eye (16 )]):
7173 transform = partial (wavedec_func , wavelet = 'haar' , level = 1 )
7274 for _ in range (10 ):
7375 arrs = [x .copy () for _ in range (100 )]
74- with futures .ThreadPoolExecutor () as ex :
76+ with futures .ThreadPoolExecutor (max_workers = max_workers ) as ex :
7577 results = list (ex .map (transform , arrs ))
7678
7779 # validate result from one of the concurrent runs
@@ -82,13 +84,13 @@ def test_concurrent_wavedec():
8284@dec .skipif (not futures_available )
8385def test_concurrent_dwt ():
8486 # dwt on 1D data calls the Cython dwt_single
85- # other cases call dwt_axes
87+ # other cases call dwt_axis
8688 for dwt_func , x in zip ([pywt .dwt , pywt .dwt2 , pywt .dwtn ],
8789 [np .ones (8 ), np .eye (16 ), np .eye (16 )]):
8890 transform = partial (dwt_func , wavelet = 'haar' )
8991 for _ in range (10 ):
9092 arrs = [x .copy () for _ in range (100 )]
91- with futures .ThreadPoolExecutor () as ex :
93+ with futures .ThreadPoolExecutor (max_workers = max_workers ) as ex :
9294 results = list (ex .map (transform , arrs ))
9395
9496 # validate result from one of the concurrent runs
@@ -104,7 +106,7 @@ def test_concurrent_cwt():
104106 sampling_period = dt )
105107 for _ in range (10 ):
106108 arrs = [sst .copy () for _ in range (50 )]
107- with futures .ThreadPoolExecutor () as ex :
109+ with futures .ThreadPoolExecutor (max_workers = max_workers ) as ex :
108110 results = list (ex .map (transform , arrs ))
109111
110112 # validate result from one of the concurrent runs
0 commit comments