|
4 | 4 |
|
5 | 5 | import warnings |
6 | 6 | from copy import deepcopy |
7 | | -from itertools import combinations |
| 7 | +from itertools import combinations, permutations |
8 | 8 | import numpy as np |
9 | 9 | from numpy.testing import (run_module_suite, dec, assert_allclose, assert_, |
10 | 10 | assert_equal, assert_raises, assert_array_equal, |
11 | 11 | assert_warns) |
12 | 12 |
|
13 | 13 | import pywt |
14 | 14 | from pywt._extensions._swt import swt_axis |
15 | | -from pywt._extensions._pywt import _check_dtype |
16 | 15 |
|
17 | 16 | # Check that float32 and complex64 are preserved. Other real types get |
18 | 17 | # converted to float64. |
@@ -387,6 +386,21 @@ def test_iswtn_errors(): |
387 | 386 | assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes) |
388 | 387 |
|
389 | 388 |
|
| 389 | +def test_swtn_iswtn_unique_shape_per_axis(): |
| 390 | + # test case for gh-460 |
| 391 | + _shape = (1, 48, 32) # unique shape per axis |
| 392 | + wav = 'sym2' |
| 393 | + max_level = 3 |
| 394 | + rstate = np.random.RandomState(0) |
| 395 | + for shape in permutations(_shape): |
| 396 | + # transform only along the non-singleton axes |
| 397 | + axes = [ax for ax, s in enumerate(shape) if s != 1] |
| 398 | + x = rstate.standard_normal(shape) |
| 399 | + c = pywt.swtn(x, wav, max_level, axes=axes) |
| 400 | + r = pywt.iswtn(c, wav, axes=axes) |
| 401 | + assert_allclose(x, r, rtol=1e-10, atol=1e-10) |
| 402 | + |
| 403 | + |
390 | 404 | def test_per_axis_wavelets(): |
391 | 405 | # tests seperate wavelet for each axis. |
392 | 406 | rstate = np.random.RandomState(1234) |
|
0 commit comments