Skip to content

Commit 476ae2d

Browse files
committed
test: switch to cgls in test_patching for stability of fp32 tests
1 parent d5df25f commit 476ae2d

1 file changed

Lines changed: 28 additions & 27 deletions

File tree

pytests/test_patching.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pytest
1414

1515
from pylops.basicoperators import MatrixMult
16+
from pylops.optimization.basic import cgls
1617
from pylops.signalprocessing import Patch2D, Patch3D
1718
from pylops.signalprocessing.patch2d import patch2d_design
1819
from pylops.signalprocessing.patch3d import patch3d_design
@@ -166,14 +167,14 @@ def test_Patch2D(par, dtype):
166167
rtol=1e-3 if dtype == np.float32 else 1e-6,
167168
backend=backend,
168169
)
169-
x = np.ones(par["ny"] * nwins[0] * par["nt"] * nwins[1], dtype=dtype)
170+
x = np.ones((nwins[0], nwins[1], par["ny"], par["nt"]), dtype=dtype)
170171
y = Pop * x.ravel()
171172
xadj = Pop.H * y
172-
xinv = Pop / y
173+
xinv = cgls(Pop, y, niter=50)[0]
173174

174175
assert y.dtype == dtype
175176
assert xadj.dtype == dtype
176-
assert_array_almost_equal(x.ravel(), xinv, decimal=3 if dtype == np.float32 else 8)
177+
assert_array_almost_equal(x, xinv, decimal=3 if dtype == np.float32 else 8)
177178

178179

179180
@pytest.mark.parametrize("par", [(par1), (par4)])
@@ -210,14 +211,14 @@ def test_Patch2D_scalings(par, dtype):
210211
rtol=1e-3 if dtype == np.float32 else 1e-6,
211212
backend=backend,
212213
)
213-
x = np.ones(par["ny"] * nwins[0] * par["nt"] * nwins[1], dtype=dtype)
214+
x = np.ones((nwins[0], nwins[1], par["ny"], par["nt"]), dtype=dtype)
214215
y = Pop * x.ravel()
215216
xadj = Pop.H * y
216-
xinv = Pop / y
217+
xinv = cgls(Pop, y, niter=50)[0]
217218

218219
assert y.dtype == dtype
219220
assert xadj.dtype == dtype
220-
assert_array_almost_equal(x.ravel(), xinv, decimal=3 if dtype == np.float32 else 8)
221+
assert_array_almost_equal(x, xinv, decimal=3 if dtype == np.float32 else 8)
221222

222223

223224
@pytest.mark.parametrize("par", [(par1), (par4)])
@@ -254,14 +255,14 @@ def test_Patch2D_singlepatch1(par, dtype):
254255
rtol=1e-3 if dtype == np.float32 else 1e-6,
255256
backend=backend,
256257
)
257-
x = np.ones(par["npy"] * nwins[0] * par["nt"] * nwins[1], dtype=dtype)
258+
x = np.ones((nwins[0], nwins[1], par["npy"], par["nt"]), dtype=dtype)
258259
y = Pop * x.ravel()
259260
xadj = Pop.H * y
260-
xinv = Pop / y
261+
xinv = cgls(Pop, y, niter=50)[0]
261262

262263
assert y.dtype == dtype
263264
assert xadj.dtype == dtype
264-
assert_array_almost_equal(x.ravel(), xinv, decimal=3 if dtype == np.float32 else 8)
265+
assert_array_almost_equal(x, xinv, decimal=3 if dtype == np.float32 else 8)
265266

266267

267268
@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
@@ -297,14 +298,14 @@ def test_Patch2D_singlepatch2(par, dtype):
297298
rtol=1e-3 if dtype == np.float32 else 1e-6,
298299
backend=backend,
299300
)
300-
x = np.ones(par["ny"] * nwins[0] * par["npt"] * nwins[1], dtype=dtype)
301+
x = np.ones((nwins[0], nwins[1], par["ny"], par["nt"]), dtype=dtype)
301302
y = Pop * x.ravel()
302303
xadj = Pop.H * y
303-
xinv = Pop / y
304+
xinv = cgls(Pop, y, niter=50)[0]
304305

305306
assert y.dtype == dtype
306307
assert xadj.dtype == dtype
307-
assert_array_almost_equal(x.ravel(), xinv, decimal=3 if dtype == np.float32 else 8)
308+
assert_array_almost_equal(x, xinv, decimal=3 if dtype == np.float32 else 8)
308309

309310

310311
@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
@@ -351,15 +352,15 @@ def test_Patch3D(par, dtype):
351352
backend=backend,
352353
)
353354
x = np.ones(
354-
(par["ny"] * nwins[0], par["nx"] * nwins[1], par["nt"] * nwins[2]), dtype=dtype
355+
(nwins[0], nwins[1], nwins[2], par["ny"], par["nx"], par["nt"]),
355356
)
356357
y = Pop * x.ravel()
357358
xadj = Pop.H * y
358-
xinv = Pop / y
359+
xinv = cgls(Pop, y, niter=50)[0]
359360

360361
assert y.dtype == dtype
361362
assert xadj.dtype == dtype
362-
assert_array_almost_equal(x.ravel(), xinv, decimal=3 if dtype == np.float32 else 8)
363+
assert_array_almost_equal(x, xinv, decimal=3 if dtype == np.float32 else 8)
363364

364365

365366
@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
@@ -403,15 +404,15 @@ def test_Patch3D_singlepatch1(par, dtype):
403404
backend=backend,
404405
)
405406
x = np.ones(
406-
(par["npy"] * nwins[0], par["nx"] * nwins[1], par["nt"] * nwins[2]), dtype=dtype
407+
(nwins[0], nwins[1], nwins[2], par["npy"], par["nx"], par["nt"]),
407408
)
408409
y = Pop * x.ravel()
409410
xadj = Pop.H * y
410-
xinv = Pop / y
411+
xinv = cgls(Pop, y, niter=50)[0]
411412

412413
assert y.dtype == dtype
413414
assert xadj.dtype == dtype
414-
assert_array_almost_equal(x.ravel(), xinv, decimal=3 if dtype == np.float32 else 8)
415+
assert_array_almost_equal(x, xinv, decimal=3 if dtype == np.float32 else 8)
415416

416417

417418
@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
@@ -455,15 +456,15 @@ def test_Patch3D_singlepatch2(par, dtype):
455456
backend=backend,
456457
)
457458
x = np.ones(
458-
(par["ny"] * nwins[0], par["npx"] * nwins[1], par["nt"] * nwins[2]), dtype=dtype
459+
(nwins[0], nwins[1], nwins[2], par["ny"], par["npx"], par["nt"]),
459460
)
460461
y = Pop * x.ravel()
461462
xadj = Pop.H * y
462-
xinv = Pop / y
463+
xinv = cgls(Pop, y, niter=50)[0]
463464

464465
assert y.dtype == dtype
465466
assert xadj.dtype == dtype
466-
assert_array_almost_equal(x.ravel(), xinv, decimal=3 if dtype == np.float32 else 8)
467+
assert_array_almost_equal(x, xinv, decimal=3 if dtype == np.float32 else 8)
467468

468469

469470
@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
@@ -509,16 +510,16 @@ def test_Patch3D_singlepatch12(par, dtype):
509510
backend=backend,
510511
)
511512
x = np.ones(
512-
(par["npy"] * nwins[0], par["npx"] * nwins[1], par["nt"] * nwins[2]),
513+
(nwins[0], nwins[1], nwins[2], par["npy"], par["npx"], par["nt"]),
513514
dtype=dtype,
514515
)
515516
y = Pop * x.ravel()
516517
xadj = Pop.H * y
517-
xinv = Pop / y
518+
xinv = cgls(Pop, y, niter=50)[0]
518519

519520
assert y.dtype == dtype
520521
assert xadj.dtype == dtype
521-
assert_array_almost_equal(x.ravel(), xinv, decimal=3 if dtype == np.float32 else 8)
522+
assert_array_almost_equal(x, xinv, decimal=3 if dtype == np.float32 else 8)
522523

523524

524525
@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
@@ -561,12 +562,12 @@ def test_Patch3D_singlepatch3(par, dtype):
561562
backend=backend,
562563
)
563564
x = np.ones(
564-
(par["ny"] * nwins[0], par["nx"] * nwins[1], par["npt"] * nwins[2]), dtype=dtype
565+
(nwins[0], nwins[1], nwins[2], par["ny"], par["nx"], par["npt"]),
565566
)
566567
y = Pop * x.ravel()
567568
xadj = Pop.H * y
568-
xinv = Pop / y
569+
xinv = cgls(Pop, y, niter=50)[0]
569570

570571
assert y.dtype == dtype
571572
assert xadj.dtype == dtype
572-
assert_array_almost_equal(x.ravel(), xinv, decimal=3 if dtype == np.float32 else 8)
573+
assert_array_almost_equal(x, xinv, decimal=3 if dtype == np.float32 else 8)

0 commit comments

Comments
 (0)