2424 "ny" : 11 ,
2525 "imag" : 0 ,
2626 "dtype" : "float64" ,
27- } # square real
27+ } # square real (fp64)
2828par2 = {
2929 "nt" : 41 ,
3030 "nx" : 21 ,
3131 "ny" : 11 ,
3232 "imag" : 0 ,
3333 "dtype" : "float64" ,
34- } # overdetermined real
34+ } # overdetermined real (fp64)
35+ par1s = {
36+ "nt" : 41 ,
37+ "nx" : 41 ,
38+ "ny" : 11 ,
39+ "imag" : 0 ,
40+ "dtype" : "float32" ,
41+ } # square real (fp32)
42+ par2s = {
43+ "nt" : 41 ,
44+ "nx" : 21 ,
45+ "ny" : 11 ,
46+ "imag" : 0 ,
47+ "dtype" : "float32" ,
48+ } # overdetermined real (fp32)
3549par1j = {
3650 "nt" : 41 ,
3751 "nx" : 41 ,
@@ -59,14 +73,16 @@ def test_unknown_engine(par):
5973 )
6074
6175
62- @pytest .mark .parametrize ("par" , [(par1 ), (par1j )])
76+ @pytest .mark .parametrize ("par" , [(par1 ), (par1s ), ( par1j )])
6377def test_Shift1D (par ):
64- """Dot-test and inversion for Shift operator on 1d data"""
78+ """Dot-test and forward/adjoint/ inversion for Shift operator on 1d data"""
6579 np .random .seed (0 )
80+ dtype = np .empty (0 , dtype = par ["dtype" ]).real .dtype
81+
6682 shift = 5.5
6783 x = np .asarray (
68- gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]
69- + par ["imag" ] * gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]
84+ gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]. astype ( dtype )
85+ + par ["imag" ] * gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]. astype ( dtype )
7086 )
7187
7288 Sop = Shift (
@@ -77,35 +93,43 @@ def test_Shift1D(par):
7793 par ["nt" ],
7894 par ["nt" ],
7995 complexflag = 0 if par ["imag" ] == 0 else 3 ,
96+ rtol = 1e-4 if dtype == np .float32 else 1e-6 ,
8097 backend = backend ,
8198 )
8299
100+ y = Sop * x
101+ xadj = Sop .H * y
83102 xlsqr = lsqr (
84103 Sop ,
85- Sop * x ,
104+ y ,
86105 x0 = np .zeros_like (x ),
87106 damp = 1e-20 ,
88107 niter = 200 ,
89108 atol = 1e-8 ,
90109 btol = 1e-8 ,
91110 show = 0 ,
92111 )[0 ]
93- assert_array_almost_equal (x , xlsqr , decimal = 1 )
112+
113+ assert y .dtype == par ["dtype" ]
114+ assert xadj .dtype == par ["dtype" ]
115+ assert_array_almost_equal (x , xlsqr , decimal = 2 if dtype == np .float32 else 4 )
94116
95117
96118@pytest .mark .skipif (
97119 int (os .environ .get ("TEST_CUPY_PYLOPS" , 0 )) == 1 ,
98120 reason = "SciPy engine not compatible with CuPy" ,
99121)
100- @pytest .mark .parametrize ("par" , [(par1 ), (par1j )])
122+ @pytest .mark .parametrize ("par" , [(par1 ), (par1s ), ( par1j )])
101123def test_Shift1D_scipy (par ):
102- """Dot-test and inversion for Shift operator on 1d data
124+ """Dot-test and forward/adjoint/ inversion for Shift operator on 1d data
103125 with scipy engine and workers"""
104126 np .random .seed (0 )
127+ dtype = np .empty (0 , dtype = par ["dtype" ]).real .dtype
128+
105129 shift = 5.5
106130 x = np .asarray (
107- gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]
108- + par ["imag" ] * gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]
131+ gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]. astype ( dtype )
132+ + par ["imag" ] * gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]. astype ( dtype )
109133 )
110134
111135 Sop = Shift (
@@ -121,34 +145,42 @@ def test_Shift1D_scipy(par):
121145 par ["nt" ],
122146 par ["nt" ],
123147 complexflag = 0 if par ["imag" ] == 0 else 3 ,
148+ rtol = 1e-4 if dtype == np .float32 else 1e-6 ,
124149 backend = backend ,
125150 )
126151
152+ y = Sop * x
153+ xadj = Sop .H * y
127154 xlsqr = lsqr (
128155 Sop ,
129- Sop * x ,
156+ y ,
130157 x0 = np .zeros_like (x ),
131158 damp = 1e-20 ,
132159 niter = 200 ,
133160 atol = 1e-8 ,
134161 btol = 1e-8 ,
135162 show = 0 ,
136163 )[0 ]
137- assert_array_almost_equal (x , xlsqr , decimal = 1 )
164+
165+ assert y .dtype == par ["dtype" ]
166+ assert xadj .dtype == par ["dtype" ]
167+ assert_array_almost_equal (x , xlsqr , decimal = 2 if dtype == np .float32 else 4 )
138168
139169
140- @pytest .mark .parametrize ("par" , [(par1 ), (par2 ), (par1j ), (par2j )])
170+ @pytest .mark .parametrize ("par" , [(par1 ), (par2 ), (par1s ), ( par2s ), ( par1j ), (par2j )])
141171def test_Shift2D (par ):
142- """Dot-test and inversion for Shift operator on 2d data"""
172+ """Dot-test and forward/adjoint/ inversion for Shift operator on 2d data"""
143173 np .random .seed (0 )
174+ dtype = np .empty (0 , dtype = par ["dtype" ]).real .dtype
175+
144176 shift = 5.5
145177
146178 # 1st axis
147179 x = np .asarray (
148- gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]
149- + par ["imag" ] * gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]
180+ gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]. astype ( dtype )
181+ + par ["imag" ] * gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]. astype ( dtype )
150182 )
151- x = np .outer (x , np .ones (par ["nx" ]))
183+ x = np .outer (x , np .ones (par ["nx" ], dtype = dtype ))
152184 Sop = Shift (
153185 (par ["nt" ], par ["nx" ]),
154186 shift ,
@@ -161,26 +193,33 @@ def test_Shift2D(par):
161193 par ["nt" ] * par ["nx" ],
162194 par ["nt" ] * par ["nx" ],
163195 complexflag = 0 if par ["imag" ] == 0 else 3 ,
196+ rtol = 1e-4 if dtype == np .float32 else 1e-6 ,
164197 backend = backend ,
165198 )
199+
200+ y = Sop * x .ravel ()
201+ xadj = Sop .H * y
166202 xlsqr = lsqr (
167203 Sop ,
168- Sop * x . ravel () ,
204+ y ,
169205 x0 = np .zeros_like (x ),
170206 damp = 1e-20 ,
171207 niter = 200 ,
172208 atol = 1e-8 ,
173209 btol = 1e-8 ,
174210 show = 0 ,
175211 )[0 ]
176- assert_array_almost_equal (x , xlsqr , decimal = 1 )
212+
213+ assert y .dtype == par ["dtype" ]
214+ assert xadj .dtype == par ["dtype" ]
215+ assert_array_almost_equal (x , xlsqr , decimal = 2 if dtype == np .float32 else 4 )
177216
178217 # 2nd axis
179218 x = np .asarray (
180- gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]
181- + par ["imag" ] * gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]
219+ gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]. astype ( dtype )
220+ + par ["imag" ] * gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]. astype ( dtype )
182221 )
183- x = np .outer (x , np .ones (par ["nx" ])).T
222+ x = np .outer (x , np .ones (par ["nx" ], dtype = dtype )).T
184223 Sop = Shift (
185224 (par ["nx" ], par ["nt" ]),
186225 shift ,
@@ -193,33 +232,42 @@ def test_Shift2D(par):
193232 par ["nt" ] * par ["nx" ],
194233 par ["nt" ] * par ["nx" ],
195234 complexflag = 0 if par ["imag" ] == 0 else 3 ,
235+ rtol = 1e-4 if dtype == np .float32 else 1e-6 ,
196236 backend = backend ,
197237 )
238+
239+ y = Sop * x .ravel ()
240+ xadj = Sop .H * y
198241 xlsqr = lsqr (
199242 Sop ,
200- Sop * x . ravel () ,
243+ y ,
201244 x0 = np .zeros_like (x ),
202245 damp = 1e-20 ,
203246 niter = 200 ,
204247 atol = 1e-8 ,
205248 btol = 1e-8 ,
206249 show = 0 ,
207250 )[0 ]
208- assert_array_almost_equal (x , xlsqr , decimal = 1 )
251+
252+ assert y .dtype == par ["dtype" ]
253+ assert xadj .dtype == par ["dtype" ]
254+ assert_array_almost_equal (x , xlsqr , decimal = 2 if dtype == np .float32 else 4 )
209255
210256
211- @pytest .mark .parametrize ("par" , [(par1 ), (par2 ), (par1j ), (par2j )])
257+ @pytest .mark .parametrize ("par" , [(par1 ), (par2 ), (par1s ), ( par2s ), ( par1j ), (par2j )])
212258def test_Shift2Dvariable (par ):
213- """Dot-test and inversion for Shift operator on 2d data with variable shift"""
259+ """Dot-test and forward/adjoint/ inversion for Shift operator on 2d data with variable shift"""
214260 np .random .seed (0 )
261+ dtype = np .empty (0 , dtype = par ["dtype" ]).real .dtype
262+
215263 shift = npp .arange (par ["nx" ])
216264
217265 # 1st axis
218266 x = np .asarray (
219- gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]
220- + par ["imag" ] * gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]
267+ gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]. astype ( dtype )
268+ + par ["imag" ] * gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]. astype ( dtype )
221269 )
222- x = np .outer (x , np .ones (par ["nx" ]))
270+ x = np .outer (x , np .ones (par ["nx" ], dtype = dtype ))
223271 Sop = Shift (
224272 (par ["nt" ], par ["nx" ]),
225273 shift ,
@@ -232,26 +280,33 @@ def test_Shift2Dvariable(par):
232280 par ["nt" ] * par ["nx" ],
233281 par ["nt" ] * par ["nx" ],
234282 complexflag = 0 if par ["imag" ] == 0 else 3 ,
283+ rtol = 1e-4 if dtype == np .float32 else 1e-6 ,
235284 backend = backend ,
236285 )
286+
287+ y = Sop * x .ravel ()
288+ xadj = Sop .H * y
237289 xlsqr = lsqr (
238290 Sop ,
239- Sop * x . ravel () ,
291+ y ,
240292 x0 = np .zeros_like (x ),
241293 damp = 1e-20 ,
242294 niter = 200 ,
243295 atol = 1e-8 ,
244296 btol = 1e-8 ,
245297 show = 0 ,
246298 )[0 ]
247- assert_array_almost_equal (x , xlsqr , decimal = 1 )
299+
300+ assert y .dtype == par ["dtype" ]
301+ assert xadj .dtype == par ["dtype" ]
302+ assert_array_almost_equal (x , xlsqr , decimal = 2 if dtype == np .float32 else 4 )
248303
249304 # 2nd axis
250305 x = np .asarray (
251- gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]
252- + par ["imag" ] * gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]
306+ gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]. astype ( dtype )
307+ + par ["imag" ] * gaussian (np .arange (par ["nt" ] // 2 + 1 ), 2.0 )[0 ]. astype ( dtype )
253308 )
254- x = np .outer (x , np .ones (par ["nx" ])).T
309+ x = np .outer (x , np .ones (par ["nx" ], dtype = dtype )).T
255310 Sop = Shift (
256311 (par ["nx" ], par ["nt" ]),
257312 shift ,
@@ -264,8 +319,12 @@ def test_Shift2Dvariable(par):
264319 par ["nt" ] * par ["nx" ],
265320 par ["nt" ] * par ["nx" ],
266321 complexflag = 0 if par ["imag" ] == 0 else 3 ,
322+ rtol = 1e-4 if dtype == np .float32 else 1e-6 ,
267323 backend = backend ,
268324 )
325+
326+ y = Sop * x .ravel ()
327+ xadj = Sop .H * y
269328 xlsqr = lsqr (
270329 Sop ,
271330 Sop * x .ravel (),
@@ -276,4 +335,7 @@ def test_Shift2Dvariable(par):
276335 btol = 1e-8 ,
277336 show = 0 ,
278337 )[0 ]
279- assert_array_almost_equal (x , xlsqr , decimal = 1 )
338+
339+ assert y .dtype == par ["dtype" ]
340+ assert xadj .dtype == par ["dtype" ]
341+ assert_array_almost_equal (x , xlsqr , decimal = 2 if dtype == np .float32 else 4 )
0 commit comments