Skip to content

Commit a17fedf

Browse files
committed
TST: add tests for round-trip swtn/iswtn with non-uniform shape
1 parent 00a5d1c commit a17fedf

1 file changed

Lines changed: 16 additions & 2 deletions

File tree

pywt/tests/test_swt.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44

55
import warnings
66
from copy import deepcopy
7-
from itertools import combinations
7+
from itertools import combinations, permutations
88
import numpy as np
99
from numpy.testing import (run_module_suite, dec, assert_allclose, assert_,
1010
assert_equal, assert_raises, assert_array_equal,
1111
assert_warns)
1212

1313
import pywt
1414
from pywt._extensions._swt import swt_axis
15-
from pywt._extensions._pywt import _check_dtype
1615

1716
# Check that float32 and complex64 are preserved. Other real types get
1817
# converted to float64.
@@ -387,6 +386,21 @@ def test_iswtn_errors():
387386
assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes)
388387

389388

389+
def test_swtn_iswtn_unique_shape_per_axis():
390+
# test case for gh-460
391+
_shape = (1, 48, 32) # unique shape per axis
392+
wav = 'sym2'
393+
max_level = 3
394+
rstate = np.random.RandomState(0)
395+
for shape in permutations(_shape):
396+
# transform only along the non-singleton axes
397+
axes = [ax for ax, s in enumerate(shape) if s != 1]
398+
x = rstate.standard_normal(shape)
399+
c = pywt.swtn(x, wav, max_level, axes=axes)
400+
r = pywt.iswtn(c, wav, axes=axes)
401+
assert_allclose(x, r, rtol=1e-10, atol=1e-10)
402+
403+
390404
def test_per_axis_wavelets():
391405
# tests seperate wavelet for each axis.
392406
rstate = np.random.RandomState(1234)

0 commit comments

Comments
 (0)