Skip to content

Commit 4d14737

Browse files
committed
test: added dtype checks for shift
1 parent 621c3bc commit 4d14737

1 file changed

Lines changed: 99 additions & 37 deletions

File tree

pytests/test_shift.py

Lines changed: 99 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,28 @@
2424
"ny": 11,
2525
"imag": 0,
2626
"dtype": "float64",
27-
} # square real
27+
} # square real (fp64)
2828
par2 = {
2929
"nt": 41,
3030
"nx": 21,
3131
"ny": 11,
3232
"imag": 0,
3333
"dtype": "float64",
34-
} # overdetermined real
34+
} # overdetermined real (fp64)
35+
par1s = {
36+
"nt": 41,
37+
"nx": 41,
38+
"ny": 11,
39+
"imag": 0,
40+
"dtype": "float32",
41+
} # square real (fp32)
42+
par2s = {
43+
"nt": 41,
44+
"nx": 21,
45+
"ny": 11,
46+
"imag": 0,
47+
"dtype": "float32",
48+
} # overdetermined real (fp32)
3549
par1j = {
3650
"nt": 41,
3751
"nx": 41,
@@ -59,14 +73,16 @@ def test_unknown_engine(par):
5973
)
6074

6175

62-
@pytest.mark.parametrize("par", [(par1), (par1j)])
76+
@pytest.mark.parametrize("par", [(par1), (par1s), (par1j)])
6377
def test_Shift1D(par):
64-
"""Dot-test and inversion for Shift operator on 1d data"""
78+
"""Dot-test and forward/adjoint/inversion for Shift operator on 1d data"""
6579
np.random.seed(0)
80+
dtype = np.empty(0, dtype=par["dtype"]).real.dtype
81+
6682
shift = 5.5
6783
x = np.asarray(
68-
gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0]
69-
+ par["imag"] * gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0]
84+
gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0].astype(dtype)
85+
+ par["imag"] * gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0].astype(dtype)
7086
)
7187

7288
Sop = Shift(
@@ -77,35 +93,43 @@ def test_Shift1D(par):
7793
par["nt"],
7894
par["nt"],
7995
complexflag=0 if par["imag"] == 0 else 3,
96+
rtol=1e-4 if dtype == np.float32 else 1e-6,
8097
backend=backend,
8198
)
8299

100+
y = Sop * x
101+
xadj = Sop.H * y
83102
xlsqr = lsqr(
84103
Sop,
85-
Sop * x,
104+
y,
86105
x0=np.zeros_like(x),
87106
damp=1e-20,
88107
niter=200,
89108
atol=1e-8,
90109
btol=1e-8,
91110
show=0,
92111
)[0]
93-
assert_array_almost_equal(x, xlsqr, decimal=1)
112+
113+
assert y.dtype == par["dtype"]
114+
assert xadj.dtype == par["dtype"]
115+
assert_array_almost_equal(x, xlsqr, decimal=2 if dtype == np.float32 else 4)
94116

95117

96118
@pytest.mark.skipif(
97119
int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1,
98120
reason="SciPy engine not compatible with CuPy",
99121
)
100-
@pytest.mark.parametrize("par", [(par1), (par1j)])
122+
@pytest.mark.parametrize("par", [(par1), (par1s), (par1j)])
101123
def test_Shift1D_scipy(par):
102-
"""Dot-test and inversion for Shift operator on 1d data
124+
"""Dot-test and forward/adjoint/inversion for Shift operator on 1d data
103125
with scipy engine and workers"""
104126
np.random.seed(0)
127+
dtype = np.empty(0, dtype=par["dtype"]).real.dtype
128+
105129
shift = 5.5
106130
x = np.asarray(
107-
gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0]
108-
+ par["imag"] * gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0]
131+
gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0].astype(dtype)
132+
+ par["imag"] * gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0].astype(dtype)
109133
)
110134

111135
Sop = Shift(
@@ -121,34 +145,42 @@ def test_Shift1D_scipy(par):
121145
par["nt"],
122146
par["nt"],
123147
complexflag=0 if par["imag"] == 0 else 3,
148+
rtol=1e-4 if dtype == np.float32 else 1e-6,
124149
backend=backend,
125150
)
126151

152+
y = Sop * x
153+
xadj = Sop.H * y
127154
xlsqr = lsqr(
128155
Sop,
129-
Sop * x,
156+
y,
130157
x0=np.zeros_like(x),
131158
damp=1e-20,
132159
niter=200,
133160
atol=1e-8,
134161
btol=1e-8,
135162
show=0,
136163
)[0]
137-
assert_array_almost_equal(x, xlsqr, decimal=1)
164+
165+
assert y.dtype == par["dtype"]
166+
assert xadj.dtype == par["dtype"]
167+
assert_array_almost_equal(x, xlsqr, decimal=2 if dtype == np.float32 else 4)
138168

139169

140-
@pytest.mark.parametrize("par", [(par1), (par2), (par1j), (par2j)])
170+
@pytest.mark.parametrize("par", [(par1), (par2), (par1s), (par2s), (par1j), (par2j)])
141171
def test_Shift2D(par):
142-
"""Dot-test and inversion for Shift operator on 2d data"""
172+
"""Dot-test and forward/adjoint/inversion for Shift operator on 2d data"""
143173
np.random.seed(0)
174+
dtype = np.empty(0, dtype=par["dtype"]).real.dtype
175+
144176
shift = 5.5
145177

146178
# 1st axis
147179
x = np.asarray(
148-
gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0]
149-
+ par["imag"] * gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0]
180+
gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0].astype(dtype)
181+
+ par["imag"] * gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0].astype(dtype)
150182
)
151-
x = np.outer(x, np.ones(par["nx"]))
183+
x = np.outer(x, np.ones(par["nx"], dtype=dtype))
152184
Sop = Shift(
153185
(par["nt"], par["nx"]),
154186
shift,
@@ -161,26 +193,33 @@ def test_Shift2D(par):
161193
par["nt"] * par["nx"],
162194
par["nt"] * par["nx"],
163195
complexflag=0 if par["imag"] == 0 else 3,
196+
rtol=1e-4 if dtype == np.float32 else 1e-6,
164197
backend=backend,
165198
)
199+
200+
y = Sop * x.ravel()
201+
xadj = Sop.H * y
166202
xlsqr = lsqr(
167203
Sop,
168-
Sop * x.ravel(),
204+
y,
169205
x0=np.zeros_like(x),
170206
damp=1e-20,
171207
niter=200,
172208
atol=1e-8,
173209
btol=1e-8,
174210
show=0,
175211
)[0]
176-
assert_array_almost_equal(x, xlsqr, decimal=1)
212+
213+
assert y.dtype == par["dtype"]
214+
assert xadj.dtype == par["dtype"]
215+
assert_array_almost_equal(x, xlsqr, decimal=2 if dtype == np.float32 else 4)
177216

178217
# 2nd axis
179218
x = np.asarray(
180-
gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0]
181-
+ par["imag"] * gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0]
219+
gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0].astype(dtype)
220+
+ par["imag"] * gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0].astype(dtype)
182221
)
183-
x = np.outer(x, np.ones(par["nx"])).T
222+
x = np.outer(x, np.ones(par["nx"], dtype=dtype)).T
184223
Sop = Shift(
185224
(par["nx"], par["nt"]),
186225
shift,
@@ -193,33 +232,42 @@ def test_Shift2D(par):
193232
par["nt"] * par["nx"],
194233
par["nt"] * par["nx"],
195234
complexflag=0 if par["imag"] == 0 else 3,
235+
rtol=1e-4 if dtype == np.float32 else 1e-6,
196236
backend=backend,
197237
)
238+
239+
y = Sop * x.ravel()
240+
xadj = Sop.H * y
198241
xlsqr = lsqr(
199242
Sop,
200-
Sop * x.ravel(),
243+
y,
201244
x0=np.zeros_like(x),
202245
damp=1e-20,
203246
niter=200,
204247
atol=1e-8,
205248
btol=1e-8,
206249
show=0,
207250
)[0]
208-
assert_array_almost_equal(x, xlsqr, decimal=1)
251+
252+
assert y.dtype == par["dtype"]
253+
assert xadj.dtype == par["dtype"]
254+
assert_array_almost_equal(x, xlsqr, decimal=2 if dtype == np.float32 else 4)
209255

210256

211-
@pytest.mark.parametrize("par", [(par1), (par2), (par1j), (par2j)])
257+
@pytest.mark.parametrize("par", [(par1), (par2), (par1s), (par2s), (par1j), (par2j)])
212258
def test_Shift2Dvariable(par):
213-
"""Dot-test and inversion for Shift operator on 2d data with variable shift"""
259+
"""Dot-test and forward/adjoint/inversion for Shift operator on 2d data with variable shift"""
214260
np.random.seed(0)
261+
dtype = np.empty(0, dtype=par["dtype"]).real.dtype
262+
215263
shift = npp.arange(par["nx"])
216264

217265
# 1st axis
218266
x = np.asarray(
219-
gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0]
220-
+ par["imag"] * gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0]
267+
gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0].astype(dtype)
268+
+ par["imag"] * gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0].astype(dtype)
221269
)
222-
x = np.outer(x, np.ones(par["nx"]))
270+
x = np.outer(x, np.ones(par["nx"], dtype=dtype))
223271
Sop = Shift(
224272
(par["nt"], par["nx"]),
225273
shift,
@@ -232,26 +280,33 @@ def test_Shift2Dvariable(par):
232280
par["nt"] * par["nx"],
233281
par["nt"] * par["nx"],
234282
complexflag=0 if par["imag"] == 0 else 3,
283+
rtol=1e-4 if dtype == np.float32 else 1e-6,
235284
backend=backend,
236285
)
286+
287+
y = Sop * x.ravel()
288+
xadj = Sop.H * y
237289
xlsqr = lsqr(
238290
Sop,
239-
Sop * x.ravel(),
291+
y,
240292
x0=np.zeros_like(x),
241293
damp=1e-20,
242294
niter=200,
243295
atol=1e-8,
244296
btol=1e-8,
245297
show=0,
246298
)[0]
247-
assert_array_almost_equal(x, xlsqr, decimal=1)
299+
300+
assert y.dtype == par["dtype"]
301+
assert xadj.dtype == par["dtype"]
302+
assert_array_almost_equal(x, xlsqr, decimal=2 if dtype == np.float32 else 4)
248303

249304
# 2nd axis
250305
x = np.asarray(
251-
gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0]
252-
+ par["imag"] * gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0]
306+
gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0].astype(dtype)
307+
+ par["imag"] * gaussian(np.arange(par["nt"] // 2 + 1), 2.0)[0].astype(dtype)
253308
)
254-
x = np.outer(x, np.ones(par["nx"])).T
309+
x = np.outer(x, np.ones(par["nx"], dtype=dtype)).T
255310
Sop = Shift(
256311
(par["nx"], par["nt"]),
257312
shift,
@@ -264,8 +319,12 @@ def test_Shift2Dvariable(par):
264319
par["nt"] * par["nx"],
265320
par["nt"] * par["nx"],
266321
complexflag=0 if par["imag"] == 0 else 3,
322+
rtol=1e-4 if dtype == np.float32 else 1e-6,
267323
backend=backend,
268324
)
325+
326+
y = Sop * x.ravel()
327+
xadj = Sop.H * y
269328
xlsqr = lsqr(
270329
Sop,
271330
Sop * x.ravel(),
@@ -276,4 +335,7 @@ def test_Shift2Dvariable(par):
276335
btol=1e-8,
277336
show=0,
278337
)[0]
279-
assert_array_almost_equal(x, xlsqr, decimal=1)
338+
339+
assert y.dtype == par["dtype"]
340+
assert xadj.dtype == par["dtype"]
341+
assert_array_almost_equal(x, xlsqr, decimal=2 if dtype == np.float32 else 4)

0 commit comments

Comments
 (0)