Skip to content

Commit 1ac64bc

Browse files
committed
Update code and tests
1 parent 663406c commit 1ac64bc

5 files changed

Lines changed: 64 additions & 50 deletions

File tree

docs/source/installation.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,8 @@ Install it via ``pip`` with
356356
>> pip install devito
357357
358358
359-
FFTW
360-
----
359+
FFTW and MKL-FFT
360+
----------------
361361
Four different "engines" are provided by the :py:class:`pylops.signalprocessing.FFT` operator:
362362
``engine="numpy"`` (default), ``engine="scipy"``, ``engine="fftw"`` and ``engine="mkl_fft"``.
363363

pylops/signalprocessing/fft.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
import pyfftw
2323

2424
if mkl_fft_message is None:
25-
import mkl_fft.interfaces.numpy_fft as mkl_backend
25+
import mkl_fft.interfaces.scipy_fft as mkl_backend
26+
from mkl_fft.interfaces import _float_utils
2627

2728
logger = logging.getLogger(__name__)
2829

@@ -437,8 +438,10 @@ def __init__(
437438

438439
@reshaped
439440
def _matvec(self, x: NDArray) -> NDArray:
441+
x = _float_utils._downcast_float128_array(x)
442+
x = _float_utils._upcast_float16_array(x)
440443
if self.ifftshift_before:
441-
x = mkl_backend.ifftshift(x, axes=self.axis)
444+
x = scipy.fft.ifftshift(x, axes=self.axis)
442445
if not self.clinear:
443446
x = np.real(x)
444447
if self.real:
@@ -451,13 +454,15 @@ def _matvec(self, x: NDArray) -> NDArray:
451454
if self.norm is _FFTNorms.ONE_OVER_N:
452455
y *= self._scale
453456
if self.fftshift_after:
454-
y = mkl_backend.fftshift(y, axes=self.axis)
457+
y = scipy.fft.fftshift(y, axes=self.axis)
455458
return y
456459

457460
@reshaped
458461
def _rmatvec(self, x: NDArray) -> NDArray:
462+
x = _float_utils._downcast_float128_array(x)
463+
x = _float_utils._upcast_float16_array(x)
459464
if self.fftshift_after:
460-
x = mkl_backend.ifftshift(x, axes=self.axis)
465+
x = scipy.fft.ifftshift(x, axes=self.axis)
461466
if self.real:
462467
x = x.copy()
463468
x = np.swapaxes(x, -1, self.axis)
@@ -477,7 +482,7 @@ def _rmatvec(self, x: NDArray) -> NDArray:
477482
if not self.clinear:
478483
y = np.real(y)
479484
if self.ifftshift_before:
480-
y = mkl_backend.fftshift(y, axes=self.axis)
485+
y = scipy.fft.fftshift(y, axes=self.axis)
481486
return y
482487

483488
def __truediv__(self, y):

pylops/signalprocessing/fft2d.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
mkl_fft_message = deps.mkl_fft_import("the mkl fft module")
1717

1818
if mkl_fft_message is None:
19-
import mkl_fft.interfaces.numpy_fft as mkl_backend
19+
import mkl_fft.interfaces.scipy_fft as mkl_backend
20+
from mkl_fft.interfaces import _float_utils
2021

2122

2223
class _FFT2D_numpy(_BaseFFTND):
@@ -290,8 +291,10 @@ def __init__(
290291

291292
@reshaped
292293
def _matvec(self, x):
294+
x = _float_utils._downcast_float128_array(x)
295+
x = _float_utils._upcast_float16_array(x)
293296
if self.ifftshift_before.any():
294-
x = mkl_backend.ifftshift(x, axes=self.axes[self.ifftshift_before])
297+
x = scipy.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
295298
if not self.clinear:
296299
x = np.real(x)
297300
if self.real:
@@ -309,13 +312,15 @@ def _matvec(self, x):
309312
y *= self._scale
310313
y = y.astype(self.cdtype)
311314
if self.fftshift_after.any():
312-
y = mkl_backend.fftshift(y, axes=self.axes[self.fftshift_after])
315+
y = scipy.fft.fftshift(y, axes=self.axes[self.fftshift_after])
313316
return y
314317

315318
@reshaped
316319
def _rmatvec(self, x):
320+
x = _float_utils._downcast_float128_array(x)
321+
x = _float_utils._upcast_float16_array(x)
317322
if self.fftshift_after.any():
318-
x = mkl_backend.ifftshift(x, axes=self.axes[self.fftshift_after])
323+
x = scipy.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
319324
if self.real:
320325
x = x.copy()
321326
x = np.swapaxes(x, -1, self.axes[-1])
@@ -330,7 +335,6 @@ def _rmatvec(self, x):
330335
)
331336
if self.norm is _FFTNorms.NONE:
332337
y *= self._scale
333-
print(y.shape, self.dims[self.axes[0]])
334338
if self.nffts[0] > self.dims[self.axes[0]]:
335339
y = np.take(y, np.arange(self.dims[self.axes[0]]), axis=self.axes[0])
336340
if self.nffts[1] > self.dims[self.axes[1]]:
@@ -341,7 +345,7 @@ def _rmatvec(self, x):
341345
y = np.real(y)
342346
y = y.astype(self.rdtype)
343347
if self.ifftshift_before.any():
344-
y = np.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
348+
y = scipy.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
345349
return y
346350

347351
def __truediv__(self, y):

pylops/signalprocessing/fftnd.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
import numpy.typing as npt
8+
import scipy.fft
89

910
from pylops import LinearOperator
1011
from pylops.signalprocessing._baseffts import _BaseFFTND, _FFTNorms
@@ -16,7 +17,8 @@
1617
mkl_fft_message = deps.mkl_fft_import("the mkl fft module")
1718

1819
if mkl_fft_message is None:
19-
import mkl_fft.interfaces.numpy_fft as mkl_backend
20+
import mkl_fft.interfaces.scipy_fft as mkl_backend
21+
from mkl_fft.interfaces import _float_utils
2022

2123

2224
class _FFTND_numpy(_BaseFFTND):
@@ -260,8 +262,10 @@ def __init__(
260262

261263
@reshaped
262264
def _matvec(self, x: NDArray) -> NDArray:
265+
x = _float_utils._downcast_float128_array(x)
266+
x = _float_utils._upcast_float16_array(x)
263267
if self.ifftshift_before.any():
264-
x = mkl_backend.ifftshift(x, axes=self.axes[self.ifftshift_before])
268+
x = scipy.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
265269
if not self.clinear:
266270
x = np.real(x)
267271
if self.real:
@@ -279,13 +283,15 @@ def _matvec(self, x: NDArray) -> NDArray:
279283
if self.norm is _FFTNorms.ONE_OVER_N:
280284
y *= self._scale
281285
if self.fftshift_after.any():
282-
y = mkl_backend.fftshift(y, axes=self.axes[self.fftshift_after])
286+
y = scipy.fft.fftshift(y, axes=self.axes[self.fftshift_after])
283287
return y
284288

285289
@reshaped
286290
def _rmatvec(self, x: NDArray) -> NDArray:
291+
x = _float_utils._downcast_float128_array(x)
292+
x = _float_utils._upcast_float16_array(x)
287293
if self.fftshift_after.any():
288-
x = mkl_backend.ifftshift(x, axes=self.axes[self.fftshift_after])
294+
x = scipy.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
289295
if self.real:
290296
# Apply scaling to obtain a correct adjoint for this operator
291297
x = x.copy()
@@ -309,7 +315,7 @@ def _rmatvec(self, x: NDArray) -> NDArray:
309315
if not self.clinear:
310316
y = np.real(y)
311317
if self.ifftshift_before.any():
312-
y = mkl_backend.fftshift(y, axes=self.axes[self.ifftshift_before])
318+
y = scipy.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
313319
return y
314320

315321
def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike:

pytests/test_ffts.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import itertools
22
import os
3-
import sys
43

54
if int(os.environ.get("TEST_CUPY_PYLOPS", 0)):
65
import cupy as np
@@ -17,7 +16,7 @@
1716

1817
from pylops.optimization.basic import lsqr
1918
from pylops.signalprocessing import FFT, FFT2D, FFTND
20-
from pylops.utils import dottest
19+
from pylops.utils import dottest, mkl_fft_enabled
2120

2221

2322
# Utility function
@@ -323,8 +322,8 @@ def test_unknown_engine(par):
323322

324323
@pytest.mark.parametrize("par", pars_fft_small_real)
325324
def test_FFT_small_real(par):
326-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
327-
pytest.skip("mkl_fft not supported on macOS")
325+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
326+
pytest.skip("mkl_fft is not installed")
328327
np.random.seed(5)
329328

330329
if backend == "numpy" or (backend == "cupy" and par["engine"] == "numpy"):
@@ -402,8 +401,8 @@ def test_FFT_small_real(par):
402401
)
403402
@pytest.mark.parametrize("par", pars_fft_random_real)
404403
def test_FFT_random_real(par):
405-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
406-
pytest.skip("mkl_fft not supported on macOS")
404+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
405+
pytest.skip("mkl_fft is not installed")
407406
np.random.seed(5)
408407

409408
shape = par["shape"]
@@ -460,8 +459,8 @@ def test_FFT_random_real(par):
460459

461460
@pytest.mark.parametrize("par", pars_fft_small_cpx)
462461
def test_FFT_small_complex(par):
463-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
464-
pytest.skip("mkl_fft not supported on macOS")
462+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
463+
pytest.skip("mkl_fft is not installed")
465464
np.random.seed(5)
466465
dtype, decimal = par["dtype_precision"]
467466
norm = par["norm"]
@@ -534,8 +533,8 @@ def test_FFT_small_complex(par):
534533

535534
@pytest.mark.parametrize("par", pars_fft_random_cpx)
536535
def test_FFT_random_complex(par):
537-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
538-
pytest.skip("mkl_fft not supported on macOS")
536+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
537+
pytest.skip("mkl_fft is not installed")
539538
np.random.seed(5)
540539
if backend == "numpy" or (backend == "cupy" and par["engine"] == "numpy"):
541540
shape = par["shape"]
@@ -623,8 +622,8 @@ def test_FFT_random_complex(par):
623622
)
624623
@pytest.mark.parametrize("par", pars_fft2d_random_real)
625624
def test_FFT2D_random_real(par):
626-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
627-
pytest.skip("mkl_fft not supported on macOS")
625+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
626+
pytest.skip("mkl_fft is not installed")
628627
np.random.seed(5)
629628
if backend == "numpy" or (backend == "cupy" and par["engine"] == "numpy"):
630629
shape = par["shape"]
@@ -686,8 +685,8 @@ def test_FFT2D_random_real(par):
686685

687686
@pytest.mark.parametrize("par", pars_fft2d_random_cpx)
688687
def test_FFT2D_random_complex(par):
689-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
690-
pytest.skip("mkl_fft not supported on macOS")
688+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
689+
pytest.skip("mkl_fft is not installed")
691690
np.random.seed(5)
692691
if backend == "numpy" or (backend == "cupy" and par["engine"] == "numpy"):
693692
shape = par["shape"]
@@ -769,8 +768,8 @@ def test_FFT2D_random_complex(par):
769768

770769
@pytest.mark.parametrize("par", pars_fftnd_random_real)
771770
def test_FFTND_random_real(par):
772-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
773-
pytest.skip("mkl_fft not supported on macOS")
771+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
772+
pytest.skip("mkl_fft is not installed")
774773
np.random.seed(5)
775774
if backend == "numpy" or (backend == "cupy" and par["engine"] == "numpy"):
776775
shape = par["shape"]
@@ -832,8 +831,8 @@ def test_FFTND_random_real(par):
832831

833832
@pytest.mark.parametrize("par", pars_fftnd_random_cpx)
834833
def test_FFTND_random_complex(par):
835-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
836-
pytest.skip("mkl_fft not supported on macOS")
834+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
835+
pytest.skip("mkl_fft is not installed")
837836
np.random.seed(5)
838837
shape = par["shape"]
839838
dtype, decimal = par["dtype_precision"]
@@ -910,8 +909,8 @@ def test_FFTND_random_complex(par):
910909

911910
@pytest.mark.parametrize("par", pars_fft2dnd_small_cpx)
912911
def test_FFT2D_small_complex(par):
913-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
914-
pytest.skip("mkl_fft not supported on macOS")
912+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
913+
pytest.skip("mkl_fft is not installed")
915914
np.random.seed(5)
916915
dtype, decimal = par["dtype_precision"]
917916
norm = par["norm"]
@@ -962,8 +961,8 @@ def test_FFT2D_small_complex(par):
962961

963962
@pytest.mark.parametrize("par", pars_fft2dnd_small_cpx)
964963
def test_FFTND_small_complex(par):
965-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
966-
pytest.skip("mkl_fft not supported on macOS")
964+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
965+
pytest.skip("mkl_fft is not installed")
967966
np.random.seed(5)
968967
dtype, decimal = par["dtype_precision"]
969968
norm = par["norm"]
@@ -1039,8 +1038,8 @@ def test_FFTND_small_complex(par):
10391038
],
10401039
)
10411040
def test_FFT_1dsignal(par):
1042-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
1043-
pytest.skip("mkl_fft not supported on macOS")
1041+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
1042+
pytest.skip("mkl_fft is not installed")
10441043
np.random.seed(5)
10451044
"""Dot-test and inversion for FFT operator for 1d signal"""
10461045
decimal = 3 if np.real(np.ones(1, par["dtype"])).dtype == np.float32 else 8
@@ -1158,8 +1157,8 @@ def test_FFT_2dsignal(par):
11581157
"""Dot-test and inversion for fft operator for 2d signal
11591158
(fft on single dimension)
11601159
"""
1161-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
1162-
pytest.skip("mkl_fft not supported on macOS")
1160+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
1161+
pytest.skip("mkl_fft is not installed")
11631162
np.random.seed(5)
11641163
decimal = 3 if np.real(np.ones(1, par["dtype"])).dtype == np.float32 else 8
11651164

@@ -1375,8 +1374,8 @@ def test_FFT_3dsignal(par):
13751374
"""Dot-test and inversion for fft operator for 3d signal
13761375
(fft on single dimension)
13771376
"""
1378-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
1379-
pytest.skip("mkl_fft not supported on macOS")
1377+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
1378+
pytest.skip("mkl_fft is not installed")
13801379
np.random.seed(5)
13811380
decimal = 3 if np.real(np.ones(1, par["dtype"])).dtype == np.float32 else 8
13821381

@@ -1603,8 +1602,8 @@ def test_FFT_3dsignal(par):
16031602
)
16041603
def test_FFT2D(par):
16051604
"""Dot-test and inversion for FFT2D operator for 2d signal"""
1606-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
1607-
pytest.skip("mkl_fft not supported on macOS")
1605+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
1606+
pytest.skip("mkl_fft is not installed")
16081607
np.random.seed(5)
16091608
decimal = 3 if np.real(np.ones(1, par["dtype"])).dtype == np.float32 else 8
16101609

@@ -1740,8 +1739,8 @@ def test_FFT2D(par):
17401739
)
17411740
def test_FFT3D(par):
17421741
"""Dot-test and inversion for FFTND operator for 3d signal"""
1743-
if par["engine"] == "mkl_fft" and sys.platform == "darwin":
1744-
pytest.skip("mkl_fft not supported on macOS")
1742+
if par["engine"] == "mkl_fft" and not mkl_fft_enabled:
1743+
pytest.skip("mkl_fft is not installed")
17451744
np.random.seed(5)
17461745
decimal = 3 if np.real(np.ones(1, par["dtype"])).dtype == np.float32 else 8
17471746

0 commit comments

Comments
 (0)