Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 65 additions & 15 deletions pywt/_cwt.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from math import sqrt, log2, floor, ceil

import numpy as np

from ._extensions._pywt import (DiscreteContinuousWavelet, ContinuousWavelet,
Wavelet, _check_dtype)
from ._functions import integrate_wavelet, scale2frequency


__all__ = ["cwt"]


def cwt(data, scales, wavelet, sampling_period=1.):
def cwt(data, scales, wavelet, sampling_period=1., method='conv'):
"""
cwt(data, scales, wavelet)

Expand All @@ -29,6 +32,16 @@ def cwt(data, scales, wavelet, sampling_period=1.):
The values computed for ``coefs`` are independent of the choice of
``sampling_period`` (i.e. ``scales`` is not scaled by the sampling
period).
method : {'conv', 'fft', 'auto'}, optional
The method used to compute the CWT. Can be any of:
- ``conv`` uses ``numpy.convolve``.
- ``fft`` uses frequency domain convolution via ``numpy.fft.fft``.
- ``auto`` uses automatic selection based on an estimate of the
computational complexity at each scale.
The ``conv`` method complexity is ``O(len(scale) * len(data))``.
The ``fft`` method is ``O(N * log2(N))`` with
``N = len(scale) + len(data) - 1``. It is well suited for large size
signals but slower than ``conv`` on small ones.

Returns
-------
Expand Down Expand Up @@ -74,34 +87,71 @@ def cwt(data, scales, wavelet, sampling_period=1.):
wavelet = DiscreteContinuousWavelet(wavelet)
if np.isscalar(scales):
scales = np.array([scales])
dt_out = None # TODO: fix in/out dtype consistency in a subsequent PR
if data.ndim == 1:
if wavelet.complex_cwt:
out = np.zeros((np.size(scales), data.size), dtype=complex)
else:
out = np.zeros((np.size(scales), data.size))
dt_out = complex
out = np.empty((np.size(scales), data.size), dtype=dt_out)
precision = 10
int_psi, x = integrate_wavelet(wavelet, precision=precision)
for i in np.arange(np.size(scales)):

if method in ('auto', 'fft'):
# - to be as large as the sum of data length and and maximum
# wavelet support to avoid circular convolution effects
# - additional padding to reach a power of 2 for CPU-optimal FFT
size_pad = lambda s: 2**ceil(log2(s[0] + s[1]))
Comment thread
alsauve marked this conversation as resolved.
Outdated
size_scale0 = size_pad((data.size,
scales[0] * ((x[-1] - x[0]) + 1)))
fft_data = None
elif not method == 'conv':
raise ValueError("method must be in: 'conv', 'fft' or 'auto'")

for i, scale in enumerate(scales):
step = x[1] - x[0]
j = np.floor(
np.arange(scales[i] * (x[-1] - x[0]) + 1) / (scales[i] * step))
if np.max(j) >= np.size(int_psi):
j = np.delete(j, np.where((j >= np.size(int_psi)))[0])
coef = - np.sqrt(scales[i]) * np.diff(
np.convolve(data, int_psi[j.astype(np.int)][::-1]))
j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
j = j.astype(int) # floor
if j[-1] >= int_psi.size:
j = np.extract(j < int_psi.size, j)
int_psi_scale = int_psi[j][::-1]

if method == 'conv':
conv = np.convolve(data, int_psi_scale)
else:
size_scale = size_pad((data.size, int_psi_scale.size))
if size_scale != size_scale0:
# the fft of data changes when padding size changes thus
# it has to be recomputed
fft_data = None
size_scale0 = size_scale
if method == 'auto':
nops_conv = len(data) * len(int_psi_scale)
nops_fft = ((2 + (fft_data is None)) *
size_scale * log2(size_scale))
if (method == 'fft') or (
(method == 'auto') and (nops_fft < nops_conv)):
Comment thread
alsauve marked this conversation as resolved.
Outdated
if fft_data is None:
fft_data = np.fft.fft(data, size_scale)
fft_wav = np.fft.fft(int_psi_scale, size_scale)
conv = np.fft.ifft(fft_wav * fft_data)
conv = conv[:data.size + int_psi_scale.size - 1]
else:
conv = np.convolve(data, int_psi_scale)

coef = - sqrt(scale) * np.diff(conv)
if not np.iscomplexobj(out):
coef = np.real(coef)
d = (coef.size - data.size) / 2.
if d > 0:
out[i, :] = coef[int(np.floor(d)):int(-np.ceil(d))]
out[i, :] = coef[floor(d):-ceil(d)]
elif d == 0.:
out[i, :] = coef
else:
raise ValueError(
"Selected scale of {} too small.".format(scales[i]))
"Selected scale of {} too small.".format(scale))
frequencies = scale2frequency(wavelet, scales, precision)
if np.isscalar(frequencies):
frequencies = np.array([frequencies])
for i in np.arange(len(frequencies)):
frequencies[i] /= sampling_period
frequencies /= sampling_period
return out, frequencies
else:
raise ValueError("Only dim == 1 supported")
50 changes: 49 additions & 1 deletion pywt/tests/test_cwt_wavelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import division, print_function, absolute_import

from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal,
assert_raises)
assert_raises, assert_equal)
import numpy as np
import pywt

Expand Down Expand Up @@ -371,3 +371,51 @@ def test_cwt_small_scales():

# extremely short scale factors raise a ValueError
assert_raises(ValueError, pywt.cwt, data, scales=0.01, wavelet='mexh')


def test_cwt_method_fft():
Comment thread
alsauve marked this conversation as resolved.
data = np.zeros(32, dtype=np.float32)
data[15] = 1.
scales1 = 1
wavelet = 'cmor1.5-1.0'

# build a reference cwt with the legacy np.conv() method
cfs_conv, _ = pywt.cwt(data, scales1, wavelet, method='conv')

# compare with the fft based convolution
cfs_fft, _ = pywt.cwt(data, scales1, wavelet, method='fft')
assert_allclose(cfs_conv, cfs_fft, rtol=0, atol=1e-13)


def test_cwt_method_auto():
np.random.seed(1)
Comment thread
alsauve marked this conversation as resolved.
Outdated
data = np.random.randn(50)
scales = [1, 5, 25, 125]
wavelet = 'cmor1.5-1.0'

# build a reference cwt with the legacy np.conv() method
cfs_conv, _ = pywt.cwt(data, scales, wavelet, method='conv')

# 'fft' method switches for scale 2 with len(data)==50
cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='auto')
assert_allclose(cfs_conv, cfs_fft, rtol=0, atol=1e-13)


def test_cwt_dtype():
Comment thread
alsauve marked this conversation as resolved.
Outdated
"""Currently output dtype precision is fixed in version 1.0.2"""
wavelet = 'mexh'
scales = 1
dtype_expected = {
np.float16: np.float64,
np.float32: np.float64,
np.float64: np.float64,
np.float128: np.float64,
np.complex64: np.float64,
np.complex128: np.float64,
np.complex256: np.float64,
}
for dtype_in, dtype_out in dtype_expected.items():
data = np.zeros(2, dtype=dtype_in)
cfs, f = pywt.cwt(data, scales, wavelet)
assert_equal(dtype_out, cfs.dtype)