Skip to content

Commit f6b3eef

Browse files
authored
Merge pull request #464 from grlee77/backport_pr448
backport of #448 (fix coefficient shape mismatch in WaveletPacket reconstruction)
2 parents 4abf379 + c8d1204 commit f6b3eef

3 files changed

Lines changed: 29 additions & 4 deletions

File tree

pywt/_wavelet_packets.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ def __init__(self, parent, data, node_name):
6868

6969
# data - signal on level 0, coeffs on higher levels
7070
self.data = data
71+
# Need to retain original data size/shape so we can trim any excess
72+
# boundary coefficients from the inverse transform.
73+
if self.data is None:
74+
self._data_shape = None
75+
else:
76+
self._data_shape = np.asarray(data).shape
7177

7278
self._init_subnodes()
7379

@@ -436,6 +442,9 @@ def _reconstruct(self, update):
436442
" from subnodes.")
437443
else:
438444
rec = idwt(data_a, data_d, self.wavelet, self.mode)
445+
if self._data_shape is not None and (
446+
rec.shape != self._data_shape):
447+
rec = rec[tuple([slice(sz) for sz in self._data_shape])]
439448
if update:
440449
self.data = rec
441450
return rec
@@ -504,6 +513,9 @@ def _reconstruct(self, update):
504513
else:
505514
coeffs = data_ll, (data_hl, data_lh, data_hh)
506515
rec = idwt2(coeffs, self.wavelet, self.mode)
516+
if self._data_shape is not None and (
517+
rec.shape != self._data_shape):
518+
rec = rec[tuple([slice(sz) for sz in self._data_shape])]
507519
if update:
508520
self.data = rec
509521
return rec
@@ -568,8 +580,6 @@ def reconstruct(self, update=True):
568580
"""
569581
if self.has_any_subnode:
570582
data = super(WaveletPacket, self).reconstruct(update)
571-
if self.data_size is not None and len(data) > self.data_size:
572-
data = data[:self.data_size]
573583
if update:
574584
self.data = data
575585
return data
@@ -669,8 +679,6 @@ def reconstruct(self, update=True):
669679
"""
670680
if self.has_any_subnode:
671681
data = super(WaveletPacket2D, self).reconstruct(update)
672-
if self.data_size is not None and (data.shape != self.data_size):
673-
data = data[:self.data_size[0], :self.data_size[1]]
674682
if update:
675683
self.data = data
676684
return data

pywt/tests/test_wp.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,5 +189,13 @@ def test_wavelet_packet_dtypes():
189189
assert_allclose(r, x.astype(transform_dtype), atol=1e-5, rtol=1e-5)
190190

191191

192+
def test_db3_roundtrip():
193+
original = np.arange(512)
194+
wp = pywt.WaveletPacket(data=original, wavelet='db3', mode='smooth',
195+
maxlevel=3)
196+
r = wp.reconstruct()
197+
assert_allclose(original, r, atol=1e-12, rtol=1e-12)
198+
199+
192200
if __name__ == '__main__':
193201
run_module_suite()

pywt/tests/test_wp2d.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,5 +168,14 @@ def test_wavelet_packet_dtypes():
168168
assert_allclose(r, x, atol=1e-5, rtol=1e-5)
169169

170170

171+
def test_2d_roundtrip():
172+
# test case corresponding to PyWavelets issue 447
173+
original = pywt.data.camera()
174+
wp = pywt.WaveletPacket2D(data=original, wavelet='db3', mode='smooth',
175+
maxlevel=3)
176+
r = wp.reconstruct()
177+
assert_allclose(original, r, atol=1e-12, rtol=1e-12)
178+
179+
171180
if __name__ == '__main__':
172181
run_module_suite()

0 commit comments

Comments
 (0)