|
| 1 | +""" |
| 2 | +Tests used to verify running PyWavelets transforms in parallel via |
| 3 | +concurrent.futures.ThreadPoolExecutor does not raise errors. |
| 4 | +""" |
| 5 | + |
| 6 | +from __future__ import division, print_function, absolute_import |
| 7 | + |
| 8 | +import sys |
| 9 | +import warnings |
| 10 | +import numpy as np |
| 11 | +from functools import partial |
| 12 | +from numpy.testing import dec, run_module_suite, assert_array_equal |
| 13 | + |
| 14 | +import pywt |
| 15 | + |
| 16 | +try: |
| 17 | + if sys.version_info[0] == 2: |
| 18 | + import futures |
| 19 | + else: |
| 20 | + from concurrent import futures |
| 21 | + futures_available = True |
| 22 | +except ImportError: |
| 23 | + futures_available = False |
| 24 | + |
| 25 | + |
| 26 | +def _assert_all_coeffs_equal(coefs1, coefs2): |
| 27 | + # return True only if all coefficients of SWT or DWT match over all levels |
| 28 | + if len(coefs1) != len(coefs2): |
| 29 | + return False |
| 30 | + for (c1, c2) in zip(coefs1, coefs2): |
| 31 | + if isinstance(c1, tuple): |
| 32 | + # for swt, swt2, dwt, dwt2, wavedec, wavedec2 |
| 33 | + for a1, a2 in zip(c1, c2): |
| 34 | + assert_array_equal(a1, a2) |
| 35 | + elif isinstance(c1, dict): |
| 36 | + # for swtn, dwtn, wavedecn |
| 37 | + for k, v in c1.items(): |
| 38 | + assert_array_equal(v, c2[k]) |
| 39 | + else: |
| 40 | + return False |
| 41 | + return True |
| 42 | + |
| 43 | + |
| 44 | +@dec.skipif(not futures_available) |
| 45 | +def test_concurrent_swt(): |
| 46 | + # tests error-free concurrent operation (see gh-288) |
| 47 | + # swt on 1D data calls the Cython swt |
| 48 | + # other cases call swt_axes |
| 49 | + with warnings.catch_warnings(): |
| 50 | + # can remove catch_warnings once the swt2 FutureWarning is removed |
| 51 | + warnings.simplefilter('ignore', FutureWarning) |
| 52 | + for swt_func, x in zip([pywt.swt, pywt.swt2, pywt.swtn], |
| 53 | + [np.ones(8), np.eye(16), np.eye(16)]): |
| 54 | + transform = partial(swt_func, wavelet='haar', level=1) |
| 55 | + for _ in range(10): |
| 56 | + arrs = [x.copy() for _ in range(100)] |
| 57 | + with futures.ThreadPoolExecutor() as ex: |
| 58 | + results = list(ex.map(transform, arrs)) |
| 59 | + |
| 60 | + # validate result from one of the concurrent runs |
| 61 | + expected_result = transform(x) |
| 62 | + _assert_all_coeffs_equal(expected_result, results[-1]) |
| 63 | + |
| 64 | + |
| 65 | +@dec.skipif(not futures_available) |
| 66 | +def test_concurrent_wavedec(): |
| 67 | + # wavedec on 1D data calls the Cython dwt_single |
| 68 | + # other cases call dwt_axes |
| 69 | + for wavedec_func, x in zip([pywt.wavedec, pywt.wavedec2, pywt.wavedecn], |
| 70 | + [np.ones(8), np.eye(16), np.eye(16)]): |
| 71 | + transform = partial(wavedec_func, wavelet='haar', level=1) |
| 72 | + for _ in range(10): |
| 73 | + arrs = [x.copy() for _ in range(100)] |
| 74 | + with futures.ThreadPoolExecutor() as ex: |
| 75 | + results = list(ex.map(transform, arrs)) |
| 76 | + |
| 77 | + # validate result from one of the concurrent runs |
| 78 | + expected_result = transform(x) |
| 79 | + _assert_all_coeffs_equal(expected_result, results[-1]) |
| 80 | + |
| 81 | + |
| 82 | +@dec.skipif(not futures_available) |
| 83 | +def test_concurrent_dwt(): |
| 84 | + # dwt on 1D data calls the Cython dwt_single |
| 85 | + # other cases call dwt_axes |
| 86 | + for dwt_func, x in zip([pywt.dwt, pywt.dwt2, pywt.dwtn], |
| 87 | + [np.ones(8), np.eye(16), np.eye(16)]): |
| 88 | + transform = partial(dwt_func, wavelet='haar') |
| 89 | + for _ in range(10): |
| 90 | + arrs = [x.copy() for _ in range(100)] |
| 91 | + with futures.ThreadPoolExecutor() as ex: |
| 92 | + results = list(ex.map(transform, arrs)) |
| 93 | + |
| 94 | + # validate result from one of the concurrent runs |
| 95 | + expected_result = transform(x) |
| 96 | + _assert_all_coeffs_equal([expected_result, ], [results[-1], ]) |
| 97 | + |
| 98 | + |
| 99 | +@dec.skipif(not futures_available) |
| 100 | +def test_concurrent_cwt(): |
| 101 | + time, sst = pywt.data.nino() |
| 102 | + dt = time[1]-time[0] |
| 103 | + transform = partial(pywt.cwt, scales=np.arange(1, 4), wavelet='cmor', |
| 104 | + sampling_period=dt) |
| 105 | + for _ in range(10): |
| 106 | + arrs = [sst.copy() for _ in range(50)] |
| 107 | + with futures.ThreadPoolExecutor() as ex: |
| 108 | + results = list(ex.map(transform, arrs)) |
| 109 | + |
| 110 | + # validate result from one of the concurrent runs |
| 111 | + expected_result = transform(sst) |
| 112 | + for a1, a2 in zip(expected_result, results[-1]): |
| 113 | + assert_array_equal(a1, a2) |
| 114 | + |
| 115 | + |
| 116 | +if __name__ == '__main__': |
| 117 | + run_module_suite() |
0 commit comments