1616mkl_fft_message = deps .mkl_fft_import ("the mkl fft module" )
1717
1818if mkl_fft_message is None :
19- import mkl_fft .interfaces .scipy_fft as mkl_backend
19+ import mkl_fft .interfaces .numpy_fft as mkl_backend
2020
2121
2222class _FFTND_numpy (_BaseFFTND ):
@@ -238,21 +238,6 @@ def __init__(
238238 dtype : DTypeLike = "complex128" ,
239239 ** kwargs_fft ,
240240 ) -> None :
241- if np .dtype (dtype ) == np .float16 :
242- warnings .warn (
243- "mkl_fft backend is unavailable with float16 dtype. Will use float32."
244- )
245- dtype = np .float32
246- elif np .dtype (dtype ) == np .complex256 :
247- warnings .warn (
248- "mkl_fft backend is unavailable with complex256 dtype. Will use complex128."
249- )
250- dtype = np .complex128
251- elif np .dtype (dtype ) == np .float128 :
252- warnings .warn (
253- "mkl_fft backend is unavailable with float128 dtype. Will use float64."
254- )
255- dtype = np .float64
256241 super ().__init__ (
257242 dims = dims ,
258243 axes = axes ,
@@ -275,12 +260,6 @@ def __init__(
275260
276261 @reshaped
277262 def _matvec (self , x : NDArray ) -> NDArray :
278- if x .dtype == np .float16 :
279- x = x .astype (np .float32 )
280- elif x .dtype == np .float128 :
281- x = x .astype (np .float64 )
282- elif x .dtype == np .complex256 :
283- x = x .astype (np .complex128 )
284263 if self .ifftshift_before .any ():
285264 x = mkl_backend .ifftshift (x , axes = self .axes [self .ifftshift_before ])
286265 if not self .clinear :
@@ -305,12 +284,6 @@ def _matvec(self, x: NDArray) -> NDArray:
305284
306285 @reshaped
307286 def _rmatvec (self , x : NDArray ) -> NDArray :
308- if x .dtype == np .float16 :
309- x = x .astype (np .float32 )
310- elif x .dtype == np .float128 :
311- x = x .astype (np .float64 )
312- elif x .dtype == np .complex256 :
313- x = x .astype (np .complex128 )
314287 if self .fftshift_after .any ():
315288 x = mkl_backend .ifftshift (x , axes = self .axes [self .fftshift_after ])
316289 if self .real :
0 commit comments