Skip to content

Commit b11dd7a

Browse files
committed
test: added dtype checks for oneway
1 parent d97db06 commit b11dd7a

1 file changed

Lines changed: 25 additions & 13 deletions

File tree

pytests/test_oneway.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
"nx": 10,
3939
"nt": 20,
4040
"kind": "p",
41-
"dtype": "float32",
4241
"fftengine": "numpy",
4342
"kwargs_fft": {},
4443
} # even, p, numpy
@@ -47,7 +46,6 @@
4746
"nx": 11,
4847
"nt": 21,
4948
"kind": "p",
50-
"dtype": "float32",
5149
"fftengine": "numpy",
5250
"kwargs_fft": {},
5351
} # odd, p, numpy
@@ -56,7 +54,6 @@
5654
"nx": 10,
5755
"nt": 20,
5856
"kind": "p",
59-
"dtype": "float32",
6057
"fftengine": "scipy",
6158
"kwargs_fft": dict(workers=4),
6259
} # even, p, scipy
@@ -65,7 +62,6 @@
6562
"nx": 10,
6663
"nt": 20,
6764
"kind": "p",
68-
"dtype": "float32",
6965
"fftengine": "fft",
7066
"kwargs_fft": {},
7167
} # even, p, fftw
@@ -74,14 +70,12 @@
7470
"nx": 10,
7571
"nt": 20,
7672
"kind": "vz",
77-
"dtype": "float32",
7873
} # even, vz, numpy
7974
par2v = {
8075
"ny": 9,
8176
"nx": 11,
8277
"nt": 21,
8378
"kind": "vz",
84-
"dtype": "float32",
8579
} # odd, vz, numpy
8680

8781
# deghosting params
@@ -116,7 +110,8 @@ def create_data2D(datakind):
116110

117111

118112
@pytest.mark.parametrize("par", [(par1), (par2), (par1s), (par1w)])
119-
def test_PhaseShift_2dsignal(par):
113+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
114+
def test_PhaseShift_2dsignal(par, dtype):
120115
"""Dot-test for PhaseShift of 2d signal"""
121116
vel = 1500.0
122117
zprop = 200
@@ -131,32 +126,49 @@ def test_PhaseShift_2dsignal(par):
131126
freq,
132127
kx,
133128
fftengine=par["fftengine"] if backend == "numpy" else "numpy",
134-
dtype=par["dtype"],
129+
dtype=dtype,
135130
**kwargs_fft,
136131
)
137132
assert dottest(
138-
Pop, par["nt"] * par["nx"], par["nt"] * par["nx"], rtol=1e-3, backend=backend
133+
Pop,
134+
par["nt"] * par["nx"],
135+
par["nt"] * par["nx"],
136+
rtol=1e-4 if dtype == np.float32 else 1e-6,
137+
backend=backend,
139138
)
140139

140+
x = np.ones((par["nt"], par["nx"]), dtype=dtype)
141+
y = Pop * x.ravel()
142+
xadj = Pop.H * y
143+
assert y.dtype == dtype
144+
assert xadj.dtype == dtype
145+
141146

142147
@pytest.mark.parametrize("par", [(par1), (par2)])
143-
def test_PhaseShift_3dsignal(par):
148+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
149+
def test_PhaseShift_3dsignal(par, dtype):
144150
"""Dot-test for PhaseShift of 3d signal"""
145151
vel = 1500.0
146152
zprop = 200
147153
freq = np.fft.rfftfreq(par["nt"], 1.0)
148154
kx = np.fft.fftshift(np.fft.fftfreq(par["nx"], 1.0))
149155
ky = np.fft.fftshift(np.fft.fftfreq(par["ny"], 1.0))
150156

151-
Pop = PhaseShift(vel, zprop, par["nt"], freq, kx, ky, dtype=par["dtype"])
157+
Pop = PhaseShift(vel, zprop, par["nt"], freq, kx, ky, dtype=dtype)
152158
assert dottest(
153159
Pop,
154160
par["nt"] * par["nx"] * par["ny"],
155161
par["nt"] * par["nx"] * par["ny"],
156-
rtol=1e-3,
162+
rtol=1e-4 if dtype == np.float32 else 1e-6,
157163
backend=backend,
158164
)
159165

166+
x = np.ones((par["nt"], par["nx"], par["ny"]), dtype=dtype)
167+
y = Pop * x.ravel()
168+
xadj = Pop.H * y
169+
assert y.dtype == dtype
170+
assert xadj.dtype == dtype
171+
160172

161173
@pytest.mark.parametrize("par", [(par1), (par2), (par1v), (par2v)])
162174
def test_Deghosting_2dsignal(par):
@@ -176,7 +188,7 @@ def test_Deghosting_2dsignal(par):
176188
npad=0,
177189
ntaper=0,
178190
solver=lsqr,
179-
dtype=par["dtype"],
191+
dtype=np.float32,
180192
**dict(damp=1e-10, niter=60),
181193
)
182194

0 commit comments

Comments
 (0)