Skip to content

Commit 68e8001

Browse files
authored
Merge pull request #465 from grlee77/backport_pr462
Backport gh-462 (iswtn axis fix)
2 parents f6b3eef + a17fedf commit 68e8001

2 files changed

Lines changed: 21 additions & 4 deletions

File tree

pywt/_swt.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -516,16 +516,19 @@ def iswtn(coeffs, wavelet, axes=None):
516516
[dt, ] + [v.dtype for v in details.values()]))
517517
if output.dtype != common_dtype:
518518
output = output.astype(common_dtype)
519+
519520
# We assume all coefficient arrays are of equal size
520521
shapes = [v.shape for k, v in details.items()]
521-
dshape = shapes[0]
522522
if len(set(shapes)) != 1:
523523
raise RuntimeError(
524524
"Mismatch in shape of intermediate coefficient arrays")
525525

526+
# shape of a single coefficient array, excluding non-transformed axes
527+
coeff_trans_shape = tuple([shapes[0][ax] for ax in axes])
528+
526529
# nested loop over all combinations of axis offsets at this level
527530
for firsts in product(*([range(last_index), ]*ndim_transform)):
528-
for first, sh, ax in zip(firsts, dshape, axes):
531+
for first, sh, ax in zip(firsts, coeff_trans_shape, axes):
529532
indices[ax] = slice(first, sh, step_size)
530533
even_indices[ax] = slice(first, sh, 2*step_size)
531534
odd_indices[ax] = slice(first+step_size, sh, 2*step_size)

pywt/tests/test_swt.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44

55
import warnings
66
from copy import deepcopy
7-
from itertools import combinations
7+
from itertools import combinations, permutations
88
import numpy as np
99
from numpy.testing import (run_module_suite, dec, assert_allclose, assert_,
1010
assert_equal, assert_raises, assert_array_equal,
1111
assert_warns)
1212

1313
import pywt
1414
from pywt._extensions._swt import swt_axis
15-
from pywt._extensions._pywt import _check_dtype
1615

1716
# Check that float32 and complex64 are preserved. Other real types get
1817
# converted to float64.
@@ -387,6 +386,21 @@ def test_iswtn_errors():
387386
assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes)
388387

389388

389+
def test_swtn_iswtn_unique_shape_per_axis():
390+
# test case for gh-460
391+
_shape = (1, 48, 32) # unique shape per axis
392+
wav = 'sym2'
393+
max_level = 3
394+
rstate = np.random.RandomState(0)
395+
for shape in permutations(_shape):
396+
# transform only along the non-singleton axes
397+
axes = [ax for ax, s in enumerate(shape) if s != 1]
398+
x = rstate.standard_normal(shape)
399+
c = pywt.swtn(x, wav, max_level, axes=axes)
400+
r = pywt.iswtn(c, wav, axes=axes)
401+
assert_allclose(x, r, rtol=1e-10, atol=1e-10)
402+
403+
390404
def test_per_axis_wavelets():
391405
# tests seperate wavelet for each axis.
392406
rstate = np.random.RandomState(1234)

0 commit comments

Comments
 (0)