Skip to content

Commit 96b73ce

Browse files
authored
Merge pull request #296 from grlee77/052_tmp2
backporting bug fixes from master to 0.5.2
2 parents df39f6d + 3ee021c commit 96b73ce

14 files changed

Lines changed: 249 additions & 58 deletions

.travis.yml

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,37 @@ matrix:
1616
- PYFLAKES=1
1717
- PEP8=1
1818
- NUMPYSPEC=numpy
19+
- MPLSPEC=matplotlib
1920
before_install:
2021
- pip install pep8==1.5.1
2122
- pip install pyflakes
2223
script:
2324
- PYFLAKES_NODOCTEST=1 pyflakes pywt demo | grep -E -v 'unable to detect undefined names|assigned to but never used|imported but unused|redefinition of unused' > test.out; cat test.out; test \! -s test.out
2425
- pep8 pywt demo
25-
2626
- python: 3.5
2727
env:
2828
- NUMPYSPEC=numpy
29-
- python: 3.4
29+
- MPLSPEC=matplotlib
30+
- USE_WHEEL=1
31+
- os: linux
32+
python: 3.4
3033
env:
3134
- NUMPYSPEC=numpy
32-
- python: 2.6
35+
- MPLSPEC=matplotlib
36+
- USE_SDIST=1
37+
- os: linux
38+
python: 2.6
3339
env:
3440
- NUMPYSPEC="numpy==1.9.3"
41+
- MPLSPEC="matplotlib<2"
3542
- python: 2.7
3643
env:
3744
- NUMPYSPEC=numpy
45+
- MPLSPEC=matplotlib
3846
- python: 3.5
3947
env:
4048
- NUMPYSPEC=numpy
49+
- MPLSPEC=matplotlib
4150
- REFGUIDE_CHECK=1 # run doctests only
4251

4352
cache: pip
@@ -52,8 +61,9 @@ before_install:
5261
- pip install --upgrade wheel
5362
# Set numpy version first, other packages link against it
5463
- pip install $NUMPYSPEC
55-
- pip install Cython matplotlib nose coverage codecov
64+
- pip install Cython $MPLSPEC nose coverage codecov futures
5665
- set -o pipefail
66+
- if [ "${USE_WHEEL}" == "1" ]; then pip install wheel; fi
5767
- |
5868
if [ "${REFGUIDE_CHECK}" == "1" ]; then
5969
pip install sphinx numpydoc
@@ -62,7 +72,21 @@ before_install:
6272
script:
6373
# Define a fixed build dir so next step works
6474
- |
65-
if [ "${REFGUIDE_CHECK}" == "1" ]; then
75+
if [ "${USE_WHEEL}" == "1" ]; then
76+
# Need verbose output or TravisCI will terminate after 10 minutes
77+
pip wheel . -v
78+
pip install PyWavelets*.whl -v
79+
pushd demo
80+
nosetests pywt
81+
popd
82+
elif [ "${USE_SDIST}" == "1" ]; then
83+
python setup.py sdist
84+
# Move out of source directory to avoid finding local pywt
85+
pushd dist
86+
pip install PyWavelets* -v
87+
nosetests pywt
88+
popd
89+
elif [ "${REFGUIDE_CHECK}" == "1" ]; then
6690
pip install -e . -v
6791
python util/refguide_check.py --doctests
6892
else

appveyor.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ install:
2323
- "util\\appveyor\\build.cmd %PYTHON%\\python.exe -m pip install
2424
numpy --cache-dir c:\\tmp\\pip-cache"
2525
- "util\\appveyor\\build.cmd %PYTHON%\\python.exe -m pip install
26-
Cython nose coverage matplotlib --cache-dir c:\\tmp\\pip-cache"
26+
Cython nose coverage matplotlib futures --cache-dir c:\\tmp\\pip-cache"
2727

2828
test_script:
2929
- "util\\appveyor\\build.cmd %PYTHON%\\python.exe setup.py build --build-lib build\\lib\\"

demo/wp_scalogram.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040
# Show spectrogram and wavelet packet coefficients
4141
fig2 = plt.figure()
4242
ax2 = fig2.add_subplot(211)
43-
ax2.specgram(data, NFFT=64, noverlap=32, cmap=cmap)
43+
ax2.specgram(data, NFFT=64, noverlap=32, Fs=2, cmap=cmap,
44+
interpolation='bilinear')
4445
ax2.set_title("Spectrogram of signal")
4546
ax3 = fig2.add_subplot(212)
4647
ax3.imshow(values, origin='upper', extent=[-1, 1, -1, 1],

pywt/_extensions/_swt.pyx

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,7 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
102102
# memory-views do not support n-dimensional arrays, use np.ndarray instead
103103
cdef common.ArrayInfo data_info, output_info
104104
cdef np.ndarray cD, cA
105-
# Explicit input_shape necessary to prevent memory leak
106-
cdef size_t[::1] input_shape, output_shape
105+
cdef size_t[::1] output_shape
107106
cdef size_t end_level = start_level + level
108107
cdef int i, retval
109108

@@ -122,28 +121,23 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
122121
raise ValueError(msg)
123122

124123
data = data.astype(_check_dtype(data), copy=False)
125-
126-
input_shape = <size_t [:data.ndim]> <size_t *> data.shape
127-
output_shape = input_shape.copy()
128-
output_shape[axis] = common.swt_buffer_length(data.shape[axis])
129-
if output_shape[axis] != input_shape[axis]:
130-
raise RuntimeError("swt_axis assumes output_shape is the same as "
131-
"input_shape")
124+
# For SWT, the output matches the shape of the input
125+
output_shape = <size_t [:data.ndim]> <size_t *> data.shape
132126

133127
data_info.ndim = data.ndim
134128
data_info.strides = <pywt_index_t *> data.strides
135129
data_info.shape = <size_t *> data.shape
136130

137-
cA = np.empty(output_shape, data.dtype)
138-
output_info.ndim = cA.ndim
139-
output_info.strides = <pywt_index_t *> cA.strides
140-
output_info.shape = <size_t *> cA.shape
131+
output_info.ndim = data.ndim
141132

142133
ret = []
143134
for i in range(start_level+1, end_level+1):
144-
135+
cA = np.empty(output_shape, dtype=data.dtype)
136+
cD = np.empty(output_shape, dtype=data.dtype)
137+
# strides won't match data_info.strides if data is not C-contiguous
138+
output_info.strides = <pywt_index_t *> cA.strides
139+
output_info.shape = <size_t *> cA.shape
145140
if data.dtype == np.float64:
146-
cA = np.zeros(output_shape, dtype=np.float64)
147141
with nogil:
148142
retval = c_wt.double_downcoef_axis(
149143
<double *> data.data, data_info,
@@ -152,8 +146,8 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
152146
common.COEF_APPROX, common.MODE_PERIODIZATION,
153147
i, common.SWT_TRANSFORM)
154148
if retval:
155-
raise RuntimeError("C wavelet transform failed")
156-
cD = np.zeros(output_shape, dtype=np.float64)
149+
raise RuntimeError(
150+
"C wavelet transform failed with error code %d" % retval)
157151
with nogil:
158152
retval = c_wt.double_downcoef_axis(
159153
<double *> data.data, data_info,
@@ -162,9 +156,9 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
162156
common.COEF_DETAIL, common.MODE_PERIODIZATION,
163157
i, common.SWT_TRANSFORM)
164158
if retval:
165-
raise RuntimeError("C wavelet transform failed")
159+
raise RuntimeError(
160+
"C wavelet transform failed with error code %d" % retval)
166161
elif data.dtype == np.float32:
167-
cA = np.zeros(output_shape, dtype=np.float32)
168162
with nogil:
169163
retval = c_wt.float_downcoef_axis(
170164
<float *> data.data, data_info,
@@ -173,8 +167,8 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
173167
common.COEF_APPROX, common.MODE_PERIODIZATION,
174168
i, common.SWT_TRANSFORM)
175169
if retval:
176-
raise RuntimeError("C wavelet transform failed")
177-
cD = np.zeros(output_shape, dtype=np.float32)
170+
raise RuntimeError(
171+
"C wavelet transform failed with error code %d" % retval)
178172
with nogil:
179173
retval = c_wt.float_downcoef_axis(
180174
<float *> data.data, data_info,
@@ -183,15 +177,18 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
183177
common.COEF_DETAIL, common.MODE_PERIODIZATION,
184178
i, common.SWT_TRANSFORM)
185179
if retval:
186-
raise RuntimeError("C wavelet transform failed")
180+
raise RuntimeError(
181+
"C wavelet transform failed with error code %d" % retval)
187182
else:
188183
raise TypeError("Array must be floating point, not {}"
189184
.format(data.dtype))
190185
ret.append((cA, cD))
191186

192187
# previous approx coeffs are the data for the next level
193188
data = cA
194-
data_info = output_info
189+
# update data_info to match the new data array
190+
data_info.strides = <pywt_index_t *> data.strides
191+
data_info.shape = <size_t *> data.shape
195192

196193
ret.reverse()
197194
return ret

pywt/_extensions/c/wt.template.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,25 +36,25 @@ int CAT(TYPE, _downcoef_axis)(const TYPE * const restrict input, const ArrayInfo
3636
if (input_info.ndim != output_info.ndim)
3737
return 1;
3838
if (axis >= input_info.ndim)
39-
return 1;
39+
return 2;
4040

4141
for (i = 0; i < input_info.ndim; ++i){
4242
if (i == axis){
4343
switch (transform) {
4444
case DWT_TRANSFORM:
4545
if (dwt_buffer_length(input_info.shape[i], wavelet->dec_len,
4646
dwt_mode) != output_info.shape[i])
47-
return 1;
47+
return 3;
4848
break;
4949
case SWT_TRANSFORM:
5050
if (swt_buffer_length(input_info.shape[i])
5151
!= output_info.shape[i])
52-
return 1;
52+
return 4;
5353
break;
5454
}
5555
} else {
5656
if (input_info.shape[i] != output_info.shape[i])
57-
return 1;
57+
return 5;
5858
}
5959
}
6060

@@ -160,7 +160,7 @@ int CAT(TYPE, _downcoef_axis)(const TYPE * const restrict input, const ArrayInfo
160160
cleanup:
161161
free(temp_input);
162162
free(temp_output);
163-
return 2;
163+
return 6;
164164
}
165165

166166

pywt/_multidim.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
7979
----------
8080
coeffs : tuple
8181
(cA, (cH, cV, cD)) A tuple with approximation coefficients and three
82-
details coefficients 2D arrays like from `dwt2()`
82+
details coefficients 2D arrays like from `dwt2`. If any of these
83+
components are set to ``None``, it will be treated as zeros.
8384
wavelet : Wavelet object or name string
8485
Wavelet to use
8586
mode : str, optional
@@ -106,10 +107,6 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
106107
raise ValueError("Expected 2 axes")
107108

108109
coeffs = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
109-
110-
# drop the keys corresponding to value = None
111-
coeffs = dict((k, v) for k, v in coeffs.items() if v is not None)
112-
113110
return idwtn(coeffs, wavelet, mode, axes)
114111

115112

@@ -215,8 +212,8 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
215212
Parameters
216213
----------
217214
coeffs: dict
218-
Dictionary as in output of `dwtn`. Missing or None items
219-
will be treated as zeroes.
215+
Dictionary as in output of `dwtn`. Missing or ``None`` items
216+
will be treated as zeros.
220217
wavelet : Wavelet object or name string
221218
Wavelet to use
222219
mode : str, optional
@@ -240,6 +237,9 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
240237
wavelet = Wavelet(wavelet)
241238
mode = Modes.from_object(mode)
242239

240+
# drop the keys corresponding to value = None
241+
coeffs = dict((k, v) for k, v in coeffs.items() if v is not None)
242+
243243
# Raise error for invalid key combinations
244244
coeffs = _fix_coeffs(coeffs)
245245

pywt/_multilevel.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,14 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1):
142142
a, ds = coeffs[0], coeffs[1:]
143143

144144
for d in ds:
145-
if (a is not None) and (d is not None) and (len(a) == len(d) + 1):
146-
a = a[:-1]
145+
if (a is not None) and (d is not None):
146+
try:
147+
if a.shape[axis] == d.shape[axis] + 1:
148+
a = a[[slice(s) for s in d.shape]]
149+
elif a.shape[axis] != d.shape[axis]:
150+
raise ValueError("coefficient shape mismatch")
151+
except IndexError:
152+
raise ValueError("Axis greater than coefficient dimensions")
147153
a = idwt(a, d, wavelet, mode, axis)
148154

149155
return a

pywt/data/_readers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,10 @@ def nino():
174174
sst_csv = np.load(fname)['sst_csv']
175175
# sst_csv = pd.read_csv("http://www.cpc.ncep.noaa.gov/data/indices/ersst4.nino.mth.81-10.ascii", sep=' ', skipinitialspace=True)
176176
# take only full years
177-
n = np.floor(sst_csv.shape[0]/12.)*12.
177+
n = int(np.floor(sst_csv.shape[0]/12.)*12.)
178178
# Building the mean of three mounth
179179
# the 4. column is nino 3
180-
sst = np.mean(np.reshape(np.array(sst_csv)[:n,4],(n/3,-1)),axis=1)
180+
sst = np.mean(np.reshape(np.array(sst_csv)[:n, 4], (n//3, -1)), axis=1)
181181
sst = (sst - np.mean(sst)) / np.std(sst, ddof=1)
182182

183183
dt = 0.25

0 commit comments

Comments
 (0)