Skip to content

Commit 18b20f5

Browse files
committed
TST: add tests for concurrent operation
tests the fix for the swt2 issue raised in gh-288. Also adds similar tests for DWT and CWT routines.
1 parent f4f9986 commit 18b20f5

3 files changed

Lines changed: 119 additions & 2 deletions

File tree

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ before_install:
5252
- pip install --upgrade wheel
5353
# Set numpy version first, other packages link against it
5454
- pip install $NUMPYSPEC
55-
- pip install Cython matplotlib nose coverage codecov
55+
- pip install Cython matplotlib nose coverage codecov futures
5656
- set -o pipefail
5757
- |
5858
if [ "${REFGUIDE_CHECK}" == "1" ]; then

appveyor.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ install:
2323
- "util\\appveyor\\build.cmd %PYTHON%\\python.exe -m pip install
2424
numpy --cache-dir c:\\tmp\\pip-cache"
2525
- "util\\appveyor\\build.cmd %PYTHON%\\python.exe -m pip install
26-
Cython nose coverage matplotlib --cache-dir c:\\tmp\\pip-cache"
26+
Cython nose coverage matplotlib futures --cache-dir c:\\tmp\\pip-cache"
2727

2828
test_script:
2929
- "util\\appveyor\\build.cmd %PYTHON%\\python.exe setup.py build --build-lib build\\lib\\"

pywt/tests/test_concurrent.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""
2+
Tests used to verify running PyWavelets transforms in parallel via
3+
concurrent.futures.ThreadPoolExecutor does not raise errors.
4+
"""
5+
6+
from __future__ import division, print_function, absolute_import
7+
8+
import sys
9+
import warnings
10+
import numpy as np
11+
from functools import partial
12+
from numpy.testing import dec, run_module_suite, assert_array_equal
13+
14+
import pywt
15+
16+
try:
17+
if sys.version_info[0] == 2:
18+
import futures
19+
else:
20+
from concurrent import futures
21+
futures_available = True
22+
except ImportError:
23+
futures_available = False
24+
25+
26+
def _assert_all_coeffs_equal(coefs1, coefs2):
27+
# return True only if all coefficients of SWT or DWT match over all levels
28+
if len(coefs1) != len(coefs2):
29+
return False
30+
for (c1, c2) in zip(coefs1, coefs2):
31+
if isinstance(c1, tuple):
32+
# for swt, swt2, dwt, dwt2, wavedec, wavedec2
33+
for a1, a2 in zip(c1, c2):
34+
assert_array_equal(a1, a2)
35+
elif isinstance(c1, dict):
36+
# for swtn, dwtn, wavedecn
37+
for k, v in c1.items():
38+
assert_array_equal(v, c2[k])
39+
else:
40+
return False
41+
return True
42+
43+
44+
@dec.skipif(not futures_available)
45+
def test_concurrent_swt():
46+
# tests error-free concurrent operation (see gh-288)
47+
# swt on 1D data calls the Cython swt
48+
# other cases call swt_axes
49+
with warnings.catch_warnings():
50+
# can remove catch_warnings once the swt2 FutureWarning is removed
51+
warnings.simplefilter('ignore', FutureWarning)
52+
for swt_func, x in zip([pywt.swt, pywt.swt2, pywt.swtn],
53+
[np.ones(8), np.eye(16), np.eye(16)]):
54+
transform = partial(swt_func, wavelet='haar', level=1)
55+
for _ in range(10):
56+
arrs = [x.copy() for _ in range(100)]
57+
with futures.ThreadPoolExecutor() as ex:
58+
results = list(ex.map(transform, arrs))
59+
60+
# validate result from one of the concurrent runs
61+
expected_result = transform(x)
62+
_assert_all_coeffs_equal(expected_result, results[-1])
63+
64+
65+
@dec.skipif(not futures_available)
66+
def test_concurrent_wavedec():
67+
# wavedec on 1D data calls the Cython dwt_single
68+
# other cases call dwt_axes
69+
for wavedec_func, x in zip([pywt.wavedec, pywt.wavedec2, pywt.wavedecn],
70+
[np.ones(8), np.eye(16), np.eye(16)]):
71+
transform = partial(wavedec_func, wavelet='haar', level=1)
72+
for _ in range(10):
73+
arrs = [x.copy() for _ in range(100)]
74+
with futures.ThreadPoolExecutor() as ex:
75+
results = list(ex.map(transform, arrs))
76+
77+
# validate result from one of the concurrent runs
78+
expected_result = transform(x)
79+
_assert_all_coeffs_equal(expected_result, results[-1])
80+
81+
82+
@dec.skipif(not futures_available)
83+
def test_concurrent_dwt():
84+
# dwt on 1D data calls the Cython dwt_single
85+
# other cases call dwt_axes
86+
for dwt_func, x in zip([pywt.dwt, pywt.dwt2, pywt.dwtn],
87+
[np.ones(8), np.eye(16), np.eye(16)]):
88+
transform = partial(dwt_func, wavelet='haar')
89+
for _ in range(10):
90+
arrs = [x.copy() for _ in range(100)]
91+
with futures.ThreadPoolExecutor() as ex:
92+
results = list(ex.map(transform, arrs))
93+
94+
# validate result from one of the concurrent runs
95+
expected_result = transform(x)
96+
_assert_all_coeffs_equal([expected_result, ], [results[-1], ])
97+
98+
99+
@dec.skipif(not futures_available)
100+
def test_concurrent_cwt():
101+
time, sst = pywt.data.nino()
102+
dt = time[1]-time[0]
103+
transform = partial(pywt.cwt, scales=np.arange(1, 4), wavelet='cmor',
104+
sampling_period=dt)
105+
for _ in range(10):
106+
arrs = [sst.copy() for _ in range(50)]
107+
with futures.ThreadPoolExecutor() as ex:
108+
results = list(ex.map(transform, arrs))
109+
110+
# validate result from one of the concurrent runs
111+
expected_result = transform(sst)
112+
for a1, a2 in zip(expected_result, results[-1]):
113+
assert_array_equal(a1, a2)
114+
115+
116+
if __name__ == '__main__':
117+
run_module_suite()

0 commit comments

Comments
 (0)