Skip to content

Commit 621c3bc

Browse files
committed
fix: fixed Seislet to work with correct dtype (and added tests)
1 parent 476ae2d commit 621c3bc

2 files changed

Lines changed: 22 additions & 14 deletions

File tree

pylops/signalprocessing/seislet.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ def __init__(
451451
super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name)
452452

453453
pad = [(0, ndimpow2 - self.dims[0])] + [(0, 0)] * (len(self.dims) - 1)
454-
self.pad = Pad(self.dims, pad)
454+
self.pad = Pad(self.dims, pad, dtype=self.dtype)
455455
self.nx, self.nt = self.dimsd
456456

457457
# define levels
@@ -473,7 +473,9 @@ def __init__(
473473
def _matvec(self, x: NDArray) -> NDArray:
474474
x = self.pad.matvec(x)
475475
x = np.reshape(x, self.dimsd)
476-
y = np.zeros((np.sum(self.levels_size) + self.levels_size[-1], self.nt))
476+
y = np.zeros(
477+
(np.sum(self.levels_size) + self.levels_size[-1], self.nt), dtype=self.dtype
478+
)
477479
for ilevel in range(self.level):
478480
odd = x[1::2]
479481
even = x[::2]
@@ -519,7 +521,7 @@ def _rmatvec(self, x: NDArray) -> NDArray:
519521
backward=False,
520522
adj=True,
521523
)
522-
y = np.zeros((2 * even.shape[0], self.nt))
524+
y = np.zeros((2 * even.shape[0], self.nt), dtype=self.dtype)
523525
y[1::2] = odd
524526
y[::2] = even
525527
y = self.pad.rmatvec(y.ravel())
@@ -542,7 +544,7 @@ def inverse(self, x: NDArray) -> NDArray:
542544
odd = res + self.predict(
543545
even, self.dt, self.dx, self.slopes, repeat=ilevel - 1, backward=False
544546
)
545-
y = np.zeros((2 * even.shape[0], self.nt))
547+
y = np.zeros((2 * even.shape[0], self.nt), dtype=self.dtype)
546548
y[1::2] = odd
547549
y[::2] = even
548550
y = self.pad.rmatvec(y.ravel())

pytests/test_seislet.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,20 @@
1515
"dx": 10,
1616
"dt": 0.004,
1717
"level": None,
18-
"dtype": "float32",
1918
} # nx power of 2, max level
2019
par2 = {
2120
"nx": 16,
2221
"nt": 30,
2322
"dx": 10,
2423
"dt": 0.004,
2524
"level": 2,
26-
"dtype": "float32",
2725
} # nx power of 2, smaller level
2826
par3 = {
2927
"nx": 13,
3028
"nt": 30,
3129
"dx": 10,
3230
"dt": 0.004,
3331
"level": 2,
34-
"dtype": "float32",
3532
} # nx not power of 2, max level
3633

3734
np.random.seed(10)
@@ -118,21 +115,30 @@ def _predict_reshape(
118115
int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled"
119116
)
120117
@pytest.mark.parametrize("par", [(par1), (par2), (par3)])
121-
def test_Seislet(par):
122-
"""Dot-test and forward-inverse for Seislet"""
123-
slope = np.random.normal(0, 0.1, (par["nx"], par["nt"]))
118+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
119+
def test_Seislet(par, dtype):
120+
"""Dot-test and forward/adjoint/inverse for Seislet"""
121+
slope = np.random.normal(0, 0.1, (par["nx"], par["nt"])).astype(dtype)
124122

125123
for kind in ("haar", "linear"):
126124
Sop = Seislet(
127125
slope,
128126
sampling=(par["dx"], par["dt"]),
129127
level=par["level"],
130128
kind=kind,
131-
dtype=par["dtype"],
129+
dtype=dtype,
130+
)
131+
dottest(
132+
Sop,
133+
Sop.shape[0],
134+
par["nx"] * par["nt"],
135+
rtol=1e-3 if dtype == np.float32 else 1e-6,
132136
)
133-
dottest(Sop, Sop.shape[0], par["nx"] * par["nt"])
134137

135-
x = np.random.normal(0, 0.1, par["nx"] * par["nt"])
138+
x = np.random.normal(0, 0.1, par["nx"] * par["nt"]).astype(dtype)
136139
y = Sop * x
140+
xadj = Sop.H * y
137141
xinv = Sop.inverse(y)
138-
assert_array_almost_equal(x, xinv)
142+
assert y.dtype == dtype
143+
assert xadj.dtype == dtype
144+
assert_array_almost_equal(x, xinv, decimal=3 if dtype == np.float32 else 6)

0 commit comments

Comments
 (0)