3838 "nx" : 10 ,
3939 "nt" : 20 ,
4040 "kind" : "p" ,
41- "dtype" : "float32" ,
4241 "fftengine" : "numpy" ,
4342 "kwargs_fft" : {},
4443} # even, p, numpy
4746 "nx" : 11 ,
4847 "nt" : 21 ,
4948 "kind" : "p" ,
50- "dtype" : "float32" ,
5149 "fftengine" : "numpy" ,
5250 "kwargs_fft" : {},
5351} # odd, p, numpy
5654 "nx" : 10 ,
5755 "nt" : 20 ,
5856 "kind" : "p" ,
59- "dtype" : "float32" ,
6057 "fftengine" : "scipy" ,
6158 "kwargs_fft" : dict (workers = 4 ),
6259} # even, p, scipy
6562 "nx" : 10 ,
6663 "nt" : 20 ,
6764 "kind" : "p" ,
68- "dtype" : "float32" ,
6965 "fftengine" : "fft" ,
7066 "kwargs_fft" : {},
7167} # even, p, fftw
7470 "nx" : 10 ,
7571 "nt" : 20 ,
7672 "kind" : "vz" ,
77- "dtype" : "float32" ,
7873} # even, vz, numpy
7974par2v = {
8075 "ny" : 9 ,
8176 "nx" : 11 ,
8277 "nt" : 21 ,
8378 "kind" : "vz" ,
84- "dtype" : "float32" ,
8579} # odd, vz, numpy
8680
8781# deghosting params
@@ -116,7 +110,8 @@ def create_data2D(datakind):
116110
117111
118112@pytest .mark .parametrize ("par" , [(par1 ), (par2 ), (par1s ), (par1w )])
119- def test_PhaseShift_2dsignal (par ):
113+ @pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
114+ def test_PhaseShift_2dsignal (par , dtype ):
120115 """Dot-test for PhaseShift of 2d signal"""
121116 vel = 1500.0
122117 zprop = 200
@@ -131,32 +126,49 @@ def test_PhaseShift_2dsignal(par):
131126 freq ,
132127 kx ,
133128 fftengine = par ["fftengine" ] if backend == "numpy" else "numpy" ,
134- dtype = par [ " dtype" ] ,
129+ dtype = dtype ,
135130 ** kwargs_fft ,
136131 )
137132 assert dottest (
138- Pop , par ["nt" ] * par ["nx" ], par ["nt" ] * par ["nx" ], rtol = 1e-3 , backend = backend
133+ Pop ,
134+ par ["nt" ] * par ["nx" ],
135+ par ["nt" ] * par ["nx" ],
136+ rtol = 1e-4 if dtype == np .float32 else 1e-6 ,
137+ backend = backend ,
139138 )
140139
140+ x = np .ones ((par ["nt" ], par ["nx" ]), dtype = dtype )
141+ y = Pop * x .ravel ()
142+ xadj = Pop .H * y
143+ assert y .dtype == dtype
144+ assert xadj .dtype == dtype
145+
141146
142147@pytest .mark .parametrize ("par" , [(par1 ), (par2 )])
143- def test_PhaseShift_3dsignal (par ):
148+ @pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
149+ def test_PhaseShift_3dsignal (par , dtype ):
144150 """Dot-test for PhaseShift of 3d signal"""
145151 vel = 1500.0
146152 zprop = 200
147153 freq = np .fft .rfftfreq (par ["nt" ], 1.0 )
148154 kx = np .fft .fftshift (np .fft .fftfreq (par ["nx" ], 1.0 ))
149155 ky = np .fft .fftshift (np .fft .fftfreq (par ["ny" ], 1.0 ))
150156
151- Pop = PhaseShift (vel , zprop , par ["nt" ], freq , kx , ky , dtype = par [ " dtype" ] )
157+ Pop = PhaseShift (vel , zprop , par ["nt" ], freq , kx , ky , dtype = dtype )
152158 assert dottest (
153159 Pop ,
154160 par ["nt" ] * par ["nx" ] * par ["ny" ],
155161 par ["nt" ] * par ["nx" ] * par ["ny" ],
156- rtol = 1e-3 ,
162+ rtol = 1e-4 if dtype == np . float32 else 1e-6 ,
157163 backend = backend ,
158164 )
159165
166+ x = np .ones ((par ["nt" ], par ["nx" ], par ["ny" ]), dtype = dtype )
167+ y = Pop * x .ravel ()
168+ xadj = Pop .H * y
169+ assert y .dtype == dtype
170+ assert xadj .dtype == dtype
171+
160172
161173@pytest .mark .parametrize ("par" , [(par1 ), (par2 ), (par1v ), (par2v )])
162174def test_Deghosting_2dsignal (par ):
@@ -176,7 +188,7 @@ def test_Deghosting_2dsignal(par):
176188 npad = 0 ,
177189 ntaper = 0 ,
178190 solver = lsqr ,
179- dtype = par [ "dtype" ] ,
191+ dtype = np . float32 ,
180192 ** dict (damp = 1e-10 , niter = 60 ),
181193 )
182194
0 commit comments