Skip to content

Commit 2f5b736

Browse files
committed
FIX: fix memory leak in swt_axis Cython routine.
previously some output_info.shape, output_info.strides pointers could potentially point to a garbage-collected object. Also changed cA, cD to initialize as empty rather than zeros for efficiency.
1 parent c417f05 commit 2f5b736

1 file changed

Lines changed: 20 additions & 23 deletions

File tree

pywt/_extensions/_swt.pyx

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,7 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
102102
# memory-views do not support n-dimensional arrays, use np.ndarray instead
103103
cdef common.ArrayInfo data_info, output_info
104104
cdef np.ndarray cD, cA
105-
# Explicit input_shape necessary to prevent memory leak
106-
cdef size_t[::1] input_shape, output_shape
105+
cdef size_t[::1] output_shape
107106
cdef size_t end_level = start_level + level
108107
cdef int i, retval
109108

@@ -122,28 +121,23 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
122121
raise ValueError(msg)
123122

124123
data = data.astype(_check_dtype(data), copy=False)
125-
126-
input_shape = <size_t [:data.ndim]> <size_t *> data.shape
127-
output_shape = input_shape.copy()
128-
output_shape[axis] = common.swt_buffer_length(data.shape[axis])
129-
if output_shape[axis] != input_shape[axis]:
130-
raise RuntimeError("swt_axis assumes output_shape is the same as "
131-
"input_shape")
124+
# For SWT, the output matches the shape of the input
125+
output_shape = <size_t [:data.ndim]> <size_t *> data.shape
132126

133127
data_info.ndim = data.ndim
134128
data_info.strides = <pywt_index_t *> data.strides
135129
data_info.shape = <size_t *> data.shape
136130

137-
cA = np.empty(output_shape, data.dtype)
138-
output_info.ndim = cA.ndim
139-
output_info.strides = <pywt_index_t *> cA.strides
140-
output_info.shape = <size_t *> cA.shape
131+
output_info.ndim = data.ndim
141132

142133
ret = []
143134
for i in range(start_level+1, end_level+1):
144-
135+
cA = np.empty(output_shape, dtype=data.dtype)
136+
cD = np.empty(output_shape, dtype=data.dtype)
137+
# strides won't match data_info.strides if data is not C-contiguous
138+
output_info.strides = <pywt_index_t *> cA.strides
139+
output_info.shape = <size_t *> cA.shape
145140
if data.dtype == np.float64:
146-
cA = np.zeros(output_shape, dtype=np.float64)
147141
with nogil:
148142
retval = c_wt.double_downcoef_axis(
149143
<double *> data.data, data_info,
@@ -152,8 +146,8 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
152146
common.COEF_APPROX, common.MODE_PERIODIZATION,
153147
i, common.SWT_TRANSFORM)
154148
if retval:
155-
raise RuntimeError("C wavelet transform failed")
156-
cD = np.zeros(output_shape, dtype=np.float64)
149+
raise RuntimeError(
150+
"C wavelet transform failed with error code %d" % retval)
157151
with nogil:
158152
retval = c_wt.double_downcoef_axis(
159153
<double *> data.data, data_info,
@@ -162,9 +156,9 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
162156
common.COEF_DETAIL, common.MODE_PERIODIZATION,
163157
i, common.SWT_TRANSFORM)
164158
if retval:
165-
raise RuntimeError("C wavelet transform failed")
159+
raise RuntimeError(
160+
"C wavelet transform failed with error code %d" % retval)
166161
elif data.dtype == np.float32:
167-
cA = np.zeros(output_shape, dtype=np.float32)
168162
with nogil:
169163
retval = c_wt.float_downcoef_axis(
170164
<float *> data.data, data_info,
@@ -173,8 +167,8 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
173167
common.COEF_APPROX, common.MODE_PERIODIZATION,
174168
i, common.SWT_TRANSFORM)
175169
if retval:
176-
raise RuntimeError("C wavelet transform failed")
177-
cD = np.zeros(output_shape, dtype=np.float32)
170+
raise RuntimeError(
171+
"C wavelet transform failed with error code %d" % retval)
178172
with nogil:
179173
retval = c_wt.float_downcoef_axis(
180174
<float *> data.data, data_info,
@@ -183,15 +177,18 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
183177
common.COEF_DETAIL, common.MODE_PERIODIZATION,
184178
i, common.SWT_TRANSFORM)
185179
if retval:
186-
raise RuntimeError("C wavelet transform failed")
180+
raise RuntimeError(
181+
"C wavelet transform failed with error code %d" % retval)
187182
else:
188183
raise TypeError("Array must be floating point, not {}"
189184
.format(data.dtype))
190185
ret.append((cA, cD))
191186

192187
# previous approx coeffs are the data for the next level
193188
data = cA
194-
data_info = output_info
189+
# update data_info to match the new data array
190+
data_info.strides = <pywt_index_t *> data.strides
191+
data_info.shape = <size_t *> data.shape
195192

196193
ret.reverse()
197194
return ret

0 commit comments

Comments
 (0)