22from __future__ import division , print_function , absolute_import
33
44from numpy .testing import (assert_allclose , assert_warns , assert_almost_equal ,
5- assert_raises )
5+ assert_raises , assert_equal )
66import numpy as np
77import pywt
88
@@ -345,20 +345,28 @@ def test_cwt_parameters_in_names():
345345
346346
347347def test_cwt_complex ():
348- for dtype in [np .float32 , np .float64 ]:
348+ for dtype , tol in [( np .float32 , 1e-5 ), ( np .float64 , 1e-13 ) ]:
349349 time , sst = pywt .data .nino ()
350350 sst = np .asarray (sst , dtype = dtype )
351351 dt = time [1 ] - time [0 ]
352352 wavelet = 'cmor1.5-1.0'
353353 scales = np .arange (1 , 32 )
354354
355- # real-valued tranfsorm
356- [cfs , f ] = pywt .cwt (sst , scales , wavelet , dt )
355+ for method in ['conv' , 'fft' ]:
356+ # real-valued tranfsorm as a reference
357+ [cfs , f ] = pywt .cwt (sst , scales , wavelet , dt , method = method )
357358
358- # complex-valued tranfsorm equals sum of the transforms of the real and
359- # imaginary components
360- [cfs_complex , f ] = pywt .cwt (sst + 1j * sst , scales , wavelet , dt )
361- assert_almost_equal (cfs + 1j * cfs , cfs_complex )
359+ # verify same precision
360+ assert_equal (cfs .real .dtype , sst .dtype )
361+
362+ # complex-valued transform equals sum of the transforms of the real
363+ # and imaginary components
364+ sst_complex = sst + 1j * sst
365+ [cfs_complex , f ] = pywt .cwt (sst_complex , scales , wavelet , dt ,
366+ method = method )
367+ assert_allclose (cfs + 1j * cfs , cfs_complex , atol = tol , rtol = tol )
368+ # verify dtype is preserved
369+ assert_equal (cfs_complex .dtype , sst_complex .dtype )
362370
363371
364372def test_cwt_small_scales ():
@@ -377,12 +385,12 @@ def test_cwt_method_fft():
377385 rstate = np .random .RandomState (1 )
378386 data = rstate .randn (50 )
379387 data [15 ] = 1.
380- scales = np .arange (1 , 64 )
381- wavelet = 'cmor1.5-1.0'
388+ scales = np .arange (1 , 64 )
389+ wavelet = 'cmor1.5-1.0'
382390
383391 # build a reference cwt with the legacy np.conv() method
384392 cfs_conv , _ = pywt .cwt (data , scales , wavelet , method = 'conv' )
385393
386394 # compare with the fft based convolution
387- cfs_fft , _ = pywt .cwt (data , scales , wavelet , method = 'fft' )
395+ cfs_fft , _ = pywt .cwt (data , scales , wavelet , method = 'fft' )
388396 assert_allclose (cfs_conv , cfs_fft , rtol = 0 , atol = 1e-13 )
0 commit comments