Skip to content

Commit ba00fea

Browse files
authored
Merge pull request #471 from grlee77/unravel_clarification
user-friendly error messages about multilevel DWT format
2 parents 9ce883e + f2a0790 commit ba00fea

2 files changed

Lines changed: 44 additions & 8 deletions

File tree

pywt/_multilevel.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,12 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1):
156156
a, ds = coeffs[0], coeffs[1:]
157157

158158
for d in ds:
159+
if d is not None and not isinstance(d, np.ndarray):
160+
raise ValueError((
161+
"Unexpected detail coefficient type: {}. Detail coefficients "
162+
"must be arrays as returned by wavedec. If you are using "
163+
"pywt.array_to_coeffs or pywt.unravel_coeffs, please specify "
164+
"output_format='wavedec'").format(type(d)))
159165
if (a is not None) and (d is not None):
160166
try:
161167
if a.shape[axis] == d.shape[axis] + 1:
@@ -164,10 +170,6 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1):
164170
raise ValueError("coefficient shape mismatch")
165171
except IndexError:
166172
raise ValueError("Axis greater than coefficient dimensions")
167-
except AttributeError:
168-
raise AttributeError(
169-
"Wrong coefficient format, if using 'array_to_coeffs' "
170-
"please specify the 'output_format' parameter")
171173
a = idwt(a, d, wavelet, mode, axis)
172174

173175
return a
@@ -310,6 +312,12 @@ def waverec2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
310312
a = np.asarray(a)
311313

312314
for d in ds:
315+
if not isinstance(d, (list, tuple)) or len(d) != 3:
316+
raise ValueError((
317+
"Unexpected detail coefficient type: {}. Detail coefficients "
318+
"must be a 3-tuple of arrays as returned by wavedec2. If you "
319+
"are using pywt.array_to_coeffs or pywt.unravel_coeffs, "
320+
"please specify output_format='wavedec2'").format(type(d)))
313321
d = tuple(np.asarray(coeff) if coeff is not None else None
314322
for coeff in d)
315323
d_shapes = (coeff.shape for coeff in d if coeff is not None)
@@ -511,6 +519,14 @@ def waverecn(coeffs, wavelet, mode='symmetric', axes=None):
511519

512520
a, ds = coeffs[0], coeffs[1:]
513521

522+
# this dictionary check must be prior to the call to _fix_coeffs
523+
if len(ds) > 0 and not all([isinstance(d, dict) for d in ds]):
524+
raise ValueError((
525+
"Unexpected detail coefficient type: {}. Detail coefficients "
526+
"must be a dicionary of arrays as returned by wavedecn. If "
527+
"you are using pywt.array_to_coeffs or pywt.unravel_coeffs, "
528+
"please specify output_format='wavedecn'").format(type(ds[0])))
529+
514530
# Raise error for invalid key combinations
515531
ds = list(map(_fix_coeffs, ds))
516532

@@ -827,7 +843,8 @@ def array_to_coeffs(arr, coeff_slices, output_format='wavedecn'):
827843
>>> cam = pywt.data.camera()
828844
>>> coeffs = pywt.wavedecn(cam, wavelet='db2', level=3)
829845
>>> arr, coeff_slices = pywt.coeffs_to_array(coeffs)
830-
>>> coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices)
846+
>>> coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices,
847+
... output_format='wavedecn')
831848
>>> cam_recon = pywt.waverecn(coeffs_from_arr, wavelet='db2')
832849
>>> assert_array_almost_equal(cam, cam_recon)
833850
@@ -1121,7 +1138,8 @@ def unravel_coeffs(arr, coeff_slices, coeff_shapes, output_format='wavedecn'):
11211138
>>> cam = pywt.data.camera()
11221139
>>> coeffs = pywt.wavedecn(cam, wavelet='db2', level=3)
11231140
>>> arr, coeff_slices, coeff_shapes = pywt.ravel_coeffs(coeffs)
1124-
>>> coeffs_from_arr = pywt.unravel_coeffs(arr, coeff_slices, coeff_shapes)
1141+
>>> coeffs_from_arr = pywt.unravel_coeffs(arr, coeff_slices, coeff_shapes,
1142+
... output_format='wavedecn')
11251143
>>> cam_recon = pywt.waverecn(coeffs_from_arr, wavelet='db2')
11261144
>>> assert_array_almost_equal(cam, cam_recon)
11271145

pywt/tests/test_multilevel.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ def test_waverec_invalid_inputs():
8080
coeffs = pywt.wavedec(x, 'db1')
8181
arr, coeff_slices = pywt.coeffs_to_array(coeffs)
8282
coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices)
83-
message = "Wrong coefficient format, if using 'array_to_coeffs' please specify the 'output_format' parameter"
84-
assert_raises_regex(AttributeError, message, pywt.waverec, coeffs_from_arr, 'haar')
83+
message = "Unexpected detail coefficient type"
84+
assert_raises_regex(ValueError, message, pywt.waverec, coeffs_from_arr,
85+
'haar')
8586

8687

8788
def test_waverec_accuracies():
@@ -208,6 +209,13 @@ def test_waverec2_invalid_inputs():
208209
# input list cannot be empty
209210
assert_raises(ValueError, pywt.waverec2, [], 'haar')
210211

212+
# coefficients from a difference decomposition used as input
213+
for dec_func in [pywt.wavedec, pywt.wavedecn]:
214+
coeffs = dec_func(np.ones((8, 8)), 'haar')
215+
message = "Unexpected detail coefficient type"
216+
assert_raises_regex(ValueError, message, pywt.waverec2, coeffs,
217+
'haar')
218+
211219

212220
def test_waverec2_coeff_shape_mismatch():
213221
x = np.ones((8, 8))
@@ -285,6 +293,16 @@ def test_waverecn_invalid_coeffs():
285293
assert_raises(ValueError, pywt.waverecn, [], 'haar')
286294

287295

296+
def test_waverecn_invalid_inputs():
297+
298+
# coefficients from a difference decomposition used as input
299+
for dec_func in [pywt.wavedec, pywt.wavedec2]:
300+
coeffs = dec_func(np.ones((8, 8)), 'haar')
301+
message = "Unexpected detail coefficient type"
302+
assert_raises_regex(ValueError, message, pywt.waverecn, coeffs,
303+
'haar')
304+
305+
288306
def test_waverecn_lists():
289307
# support coefficient arrays specified as lists instead of arrays
290308
coeffs = [[[1.0]], {'ad': [[0.0]], 'da': [[0.0]], 'dd': [[0.0]]}]

0 commit comments

Comments
 (0)