Skip to content

Commit 1904642

Browse files
committed
FIX: shape adjustment in waverec should not assume a transform along axis 0
test case added. closes gh-293
1 parent b626ee4 commit 1904642

2 files changed

Lines changed: 15 additions & 2 deletions

File tree

pywt/_multilevel.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,11 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1):
142142
a, ds = coeffs[0], coeffs[1:]
143143

144144
for d in ds:
145-
if (a is not None) and (d is not None) and (len(a) == len(d) + 1):
146-
a = a[:-1]
145+
if (a is not None) and (d is not None):
146+
if a.shape[axis] == d.shape[axis] + 1:
147+
a = a[[slice(s) for s in d.shape]]
148+
elif a.shape[axis] != d.shape[axis]:
149+
raise RuntimeError("coefficient shape mismatch")
147150
a = idwt(a, d, wavelet, mode, axis)
148151

149152
return a

pywt/tests/test_multilevel.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,16 @@ def test_waverec_axes_subsets():
486486
assert_allclose(rec, data, atol=1e-14)
487487

488488

489+
def test_waverec_axis_db2():
490+
"""test for fix to issue gh-293"""
491+
rstate = np.random.RandomState(0)
492+
data = rstate.standard_normal((16, 16))
493+
for axis in [0, 1]:
494+
coefs = pywt.wavedec(data, 'db2', axis=axis)
495+
rec = pywt.waverec(coefs, 'db2', axis=axis)
496+
assert_allclose(rec, data, atol=1e-14)
497+
498+
489499
def test_waverec2_axes_subsets():
490500
rstate = np.random.RandomState(0)
491501
data = rstate.standard_normal((8, 8, 8))

0 commit comments

Comments
 (0)