Skip to content

Commit bca6dbc

Browse files
authored
Merge pull request #743 from MothNik/dev
Minor improvements for `signalprocessing.Interp` and `signalprocessing.InterpCubicSpline`
2 parents deb8240 + a9da70d commit bca6dbc

4 files changed

Lines changed: 106 additions & 10 deletions

File tree

pylops/signalprocessing/_interp_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def _ensure_iava_is_unique(
2828
def _clip_iava_above_last_sample_index(
2929
iava: NDArray,
3030
sample_size: int,
31-
) -> None:
31+
) -> NDArray:
3232
"""
3333
Ensures that elements in ``iava`` do not exceed the last sample index.
3434
Elements above the penultimate sample are clipped to the next closest float value
@@ -47,8 +47,9 @@ def _clip_iava_above_last_sample_index(
4747
# NOTE: ``numpy.nextafter(x, -np.inf)`` gives the closest float-value that is
4848
# less than ``x``, i.e., this logic clips ``iava`` to the highest possible
4949
# value that is still below the last sample
50+
iava = iava.copy() # to avoid silent input mutation
5051
iava[np.where(outside)] = np.nextafter(last_sample_index, -np.inf)
5152

5253
_ensure_iava_is_unique(iava=iava)
5354

54-
return
55+
return iava

pylops/signalprocessing/interp.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ def _linearinterp(
4747

4848
# ensure that samples are not beyond the last sample, in that case set to
4949
# penultimate sample and raise a warning
50-
_clip_iava_above_last_sample_index(iava=iava, sample_size=sample_size)
50+
iava = _clip_iava_above_last_sample_index( # type: ignore
51+
iava=iava, # type: ignore
52+
sample_size=sample_size,
53+
)
5154

5255
# find indices and weights
5356
iva_l = ncp.floor(iava).astype(int)
@@ -133,7 +136,7 @@ def Interp(
133136
polynomial fitted between ``np.floor(iava)`` and ``np.floor(iava) + 1``.
134137
It offers an excellent tradeoff between accuracy and computational complexity
135138
and its results oscillate less than those obtained from sinc interpolation.
136-
It can also be accessed directly via :class:`pylops.singalprocessing.InterpCubicSpline`.
139+
It can also be accessed directly via :class:`pylops.signalprocessing.InterpCubicSpline`.
137140
138141
.. note:: The vector ``iava`` should contain unique values. If the same
139142
index is repeated twice an error will be raised. This also applies when
@@ -179,7 +182,7 @@ def Interp(
179182
ValueError
180183
If the vector ``iava`` contains repeated values.
181184
NotImplementedError
182-
If ``kind`` is not ``nearest``, ``linear`` or ``sinc``
185+
If ``kind`` is not ``nearest``, ``linear``, ``sinc``, or ``cubic_spline``
183186
184187
See Also
185188
--------

pylops/signalprocessing/interpspline.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ class _BandedLUDecomposition:
210210
Represents the LU decomposition of a general banded matrix as performed by the
211211
LAPACK routines ``?gbtrf``.
212212
This class was implemented for spline interpolations between only 2 data points
213-
because the class :class:`_BandedLUDecomposition` uses the LAPACK routines
213+
because the class :class:`_TridiagonalLUDecomposition` uses the LAPACK routines
214214
``?gttrf`` that cannot handle 2 x 2 tridiagonal matrices.
215215
216216
"""
@@ -251,7 +251,11 @@ def from_tridiagonal_matrix(
251251
banded_representation[2, ::] = matrix.main_diagonal
252252
banded_representation[3, 0:-1] = matrix.sub_diagonal
253253

254-
(lu_banded, pivot_indices, info,) = lapack_factorizer(
254+
(
255+
lu_banded,
256+
pivot_indices,
257+
info,
258+
) = lapack_factorizer(
255259
ab=banded_representation,
256260
kl=1,
257261
ku=1,
@@ -389,7 +393,6 @@ def from_tridiagonal_matrix(
389393

390394
raise np.linalg.LinAlgError(
391395
f"Could not LU-factorize tridiagonal matrix! Got {info=}."
392-
f"Could not LU-factorize tridiagonal matrix! Got {info=}."
393396
)
394397

395398
def solve(
@@ -818,7 +821,10 @@ def __init__(
818821
)
819822

820823
iava = np.asarray(iava, dtype=np.float64)
821-
_clip_iava_above_last_sample_index(iava=iava, sample_size=num_cols)
824+
iava = _clip_iava_above_last_sample_index( # type: ignore
825+
iava=iava,
826+
sample_size=num_cols,
827+
)
822828

823829
if isinstance(bc_type, str) and bc_type.lower() in {"natural"}:
824830
self.bc_type = bc_type.lower()
@@ -950,7 +956,7 @@ def _rmatvec(self, x: InexactNDArray) -> InexactNDArray:
950956
x_mod[0 : self.num_cols]
951957
+ self._rmatmat_difference_method(
952958
self._lhs_B_matrix_transposed_lu.solve(
953-
rhs=x_mod[self.num_cols : x_mod.size],
959+
rhs=x_mod[self.num_cols :],
954960
lapack_solver=self._tridiag_lu_solve,
955961
)
956962
)

pytests/test_interpolation_spline.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from math import prod
12
from typing import Final, Tuple
23

34
import numpy as np
45
import pytest
56
from scipy.interpolate import CubicSpline
67

78
from pylops.signalprocessing import InterpCubicSpline
9+
from pylops.utils import dottest
810

911
TEST_ARRAY_SHAPE: Final[Tuple] = (
1012
20,
@@ -16,6 +18,90 @@
1618
MIN_NUM_TEST_SAMPLES: Final[int] = 1
1719

1820

21+
def test_cubic_spline_raises_on_not_supported_bc_type() -> None:
22+
"""
23+
Tests whether ``pylops.signalprocessing.InterpCubicSpline`` raises a
24+
``NotImplementedError`` for boundary conditions that are not supported.
25+
26+
"""
27+
28+
with pytest.raises(NotImplementedError):
29+
InterpCubicSpline(
30+
dims=(5, 2),
31+
iava=np.array([0.5, 2.3]),
32+
bc_type="erroneous", # type: ignore
33+
)
34+
35+
36+
@pytest.mark.parametrize(
37+
"with_complex",
38+
[
39+
pytest.param(False, id="real"),
40+
pytest.param(True, id="complex"),
41+
],
42+
)
43+
@pytest.mark.parametrize(
44+
"axis",
45+
[
46+
0,
47+
1,
48+
2,
49+
3,
50+
-1,
51+
-2,
52+
-3,
53+
],
54+
)
55+
@pytest.mark.parametrize(
56+
"subsample_fraction",
57+
[
58+
pytest.param(0.5, id="decimation"),
59+
pytest.param(5.0, id="upsampling"),
60+
],
61+
)
62+
def test_natural_cubic_spline_dottest(
63+
subsample_fraction: float,
64+
axis: int,
65+
with_complex: bool,
66+
) -> None:
67+
"""
68+
Tests ``pylops.signalprocessing.InterpCubicSpline`` with the ``dottest``.
69+
70+
"""
71+
72+
# Setup
73+
74+
num_samples = TEST_ARRAY_SHAPE[axis]
75+
x_eval_fractions = np.random.rand(
76+
max(
77+
round(num_samples * subsample_fraction),
78+
MIN_NUM_TEST_SAMPLES,
79+
)
80+
)
81+
x_eval_for_pylops = (num_samples - 1) * x_eval_fractions
82+
83+
shape_list = list(TEST_ARRAY_SHAPE)
84+
shape_list[axis] = x_eval_fractions.size # type: ignore
85+
num_rows = prod(shape_list, start=1)
86+
num_columns = prod(TEST_ARRAY_SHAPE, start=1)
87+
88+
# Test
89+
90+
splinop = InterpCubicSpline(
91+
dims=TEST_ARRAY_SHAPE,
92+
iava=x_eval_for_pylops,
93+
axis=axis,
94+
dtype="complex128" if with_complex else "float64",
95+
)
96+
97+
assert dottest(
98+
Op=splinop,
99+
nr=num_rows,
100+
nc=num_columns,
101+
complexflag=0 if not with_complex else 3,
102+
)
103+
104+
19105
@pytest.mark.parametrize(
20106
"with_complex",
21107
[

0 commit comments

Comments
 (0)