@@ -102,8 +102,7 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
102102 # memory-views do not support n-dimensional arrays, use np.ndarray instead
103103 cdef common.ArrayInfo data_info, output_info
104104 cdef np.ndarray cD, cA
105- # Explicit input_shape necessary to prevent memory leak
106- cdef size_t[::1 ] input_shape, output_shape
105+ cdef size_t[::1 ] output_shape
107106 cdef size_t end_level = start_level + level
108107 cdef int i, retval
109108
@@ -122,28 +121,23 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
122121 raise ValueError (msg)
123122
124123 data = data.astype(_check_dtype(data), copy = False )
125-
126- input_shape = < size_t [:data.ndim]> < size_t * > data.shape
127- output_shape = input_shape.copy()
128- output_shape[axis] = common.swt_buffer_length(data.shape[axis])
129- if output_shape[axis] != input_shape[axis]:
130- raise RuntimeError (" swt_axis assumes output_shape is the same as "
131- " input_shape" )
124+ # For SWT, the output matches the shape of the input
125+ output_shape = < size_t [:data.ndim]> < size_t * > data.shape
132126
133127 data_info.ndim = data.ndim
134128 data_info.strides = < pywt_index_t * > data.strides
135129 data_info.shape = < size_t * > data.shape
136130
137- cA = np.empty(output_shape, data.dtype)
138- output_info.ndim = cA.ndim
139- output_info.strides = < pywt_index_t * > cA.strides
140- output_info.shape = < size_t * > cA.shape
131+ output_info.ndim = data.ndim
141132
142133 ret = []
143134 for i in range (start_level+ 1 , end_level+ 1 ):
144-
135+ cA = np.empty(output_shape, dtype = data.dtype)
136+ cD = np.empty(output_shape, dtype = data.dtype)
137+ # strides won't match data_info.strides if data is not C-contiguous
138+ output_info.strides = < pywt_index_t * > cA.strides
139+ output_info.shape = < size_t * > cA.shape
145140 if data.dtype == np.float64:
146- cA = np.zeros(output_shape, dtype = np.float64)
147141 with nogil:
148142 retval = c_wt.double_downcoef_axis(
149143 < double * > data.data, data_info,
@@ -152,8 +146,8 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
152146 common.COEF_APPROX, common.MODE_PERIODIZATION,
153147 i, common.SWT_TRANSFORM)
154148 if retval:
155- raise RuntimeError (" C wavelet transform failed " )
156- cD = np.zeros(output_shape, dtype = np.float64 )
149+ raise RuntimeError (
150+ " C wavelet transform failed with error code %d " % retval )
157151 with nogil:
158152 retval = c_wt.double_downcoef_axis(
159153 < double * > data.data, data_info,
@@ -162,9 +156,9 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
162156 common.COEF_DETAIL, common.MODE_PERIODIZATION,
163157 i, common.SWT_TRANSFORM)
164158 if retval:
165- raise RuntimeError (" C wavelet transform failed" )
159+ raise RuntimeError (
160+ " C wavelet transform failed with error code %d " % retval)
166161 elif data.dtype == np.float32:
167- cA = np.zeros(output_shape, dtype = np.float32)
168162 with nogil:
169163 retval = c_wt.float_downcoef_axis(
170164 < float * > data.data, data_info,
@@ -173,8 +167,8 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
173167 common.COEF_APPROX, common.MODE_PERIODIZATION,
174168 i, common.SWT_TRANSFORM)
175169 if retval:
176- raise RuntimeError (" C wavelet transform failed " )
177- cD = np.zeros(output_shape, dtype = np.float32 )
170+ raise RuntimeError (
171+ " C wavelet transform failed with error code %d " % retval )
178172 with nogil:
179173 retval = c_wt.float_downcoef_axis(
180174 < float * > data.data, data_info,
@@ -183,15 +177,18 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
183177 common.COEF_DETAIL, common.MODE_PERIODIZATION,
184178 i, common.SWT_TRANSFORM)
185179 if retval:
186- raise RuntimeError (" C wavelet transform failed" )
180+ raise RuntimeError (
181+ " C wavelet transform failed with error code %d " % retval)
187182 else :
188183 raise TypeError (" Array must be floating point, not {}"
189184 .format(data.dtype))
190185 ret.append((cA, cD))
191186
192187 # previous approx coeffs are the data for the next level
193188 data = cA
194- data_info = output_info
189+ # update data_info to match the new data array
190+ data_info.strides = < pywt_index_t * > data.strides
191+ data_info.shape = < size_t * > data.shape
195192
196193 ret.reverse()
197194 return ret
0 commit comments