Skip to content

Commit 3eb0779

Browse files
committed
Fixed expand dims bug
1 parent 3711ec5 commit 3eb0779

1 file changed

Lines changed: 1 addition & 2 deletions

File tree

pytorch_wavelets/dtcwt/lowlevel.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,7 @@ def prep_filt(h, c, transpose=False):
5959
""" Prepares an array to be of the correct format for pytorch.
6060
Can also specify whether to make it a row filter (set tranpose=True)"""
6161
h = _as_col_vector(h)[::-1]
62-
#h = np.reshape(h, [1, 1, *h.shape])
63-
h = np.expand_dims(h, (0,1))
62+
h = h[None, None, :]
6463
h = np.repeat(h, repeats=c, axis=0)
6564
if transpose:
6665
h = h.transpose((0,1,3,2))

0 commit comments

Comments
 (0)