Skip to content

Commit 9fea4e6

Browse files
authored
Merge branch 'master' into dask-interface
2 parents 0389820 + 1c0abc5 commit 9fea4e6

File tree

3 files changed

+37
-4
lines changed

3 files changed

+37
-4
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99
### Added
1010
* Added `mkl_fft` patching for NumPy, with `mkl_fft` context manager, `is_patched` query, and `patch_numpy_fft` and `restore_numpy_fft` calls to replace `numpy.fft` calls with calls from `mkl_fft.interfaces.numpy_fft` [gh-224](https://github.com/IntelPython/mkl_fft/pull/224)
1111

12+
### Changed
13+
* In `mkl_fft.fftn` and `mkl_fft.ifftn`, improved checking of the shape argument `s` to use faster direct transforms more often. This makes performance more consistent between `mkl_fft.fftn/ifftn` and `mkl.interfaces`. [gh-283](https://github.com/IntelPython/mkl_fft/pull/283)
14+
1215
### Removed
1316
* Dropped support for Python 3.9 [gh-243](https://github.com/IntelPython/mkl_fft/pull/243)
1417

mkl_fft/_fft_utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,12 @@ def _check_norm(norm):
4343
)
4444

4545

46-
def _check_shapes_for_direct(xs, shape, axes):
46+
def _check_shapes_for_direct(xs, shape, axes, check_complementary=False):
4747
if len(axes) > 7: # Intel MKL supports up to 7D
4848
return False
49-
if not (len(xs) == len(shape)):
50-
# full-dimensional transform
49+
if not (len(xs) == len(shape)) and not check_complementary:
50+
# full-dimensional transform is required for direct,
51+
# but less than full is OK for complimentary.
5152
return False
5253
if not (len(set(axes)) == len(axes)):
5354
# repeated axes
@@ -382,6 +383,7 @@ def _c2c_fftnd_impl(
382383
if direction not in [-1, +1]:
383384
raise ValueError("Direction of FFT should +1 or -1")
384385

386+
_complementary = s is None
385387
valid_dtypes = [np.complex64, np.complex128, np.float32, np.float64]
386388
# _direct_fftnd requires complex type, and full-dimensional transform
387389
if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1:
@@ -392,6 +394,12 @@ def _c2c_fftnd_impl(
392394
xs, xa = _cook_nd_args(x, s, axes)
393395
if _check_shapes_for_direct(xs, x.shape, xa):
394396
_direct = True
397+
# See if s matches the shape of x along the given axes.
398+
# If it does, we can use _iter_complementary rather than _iter_fftnd.
399+
if _check_shapes_for_direct(
400+
xs, x.shape, xa, check_complementary=True
401+
):
402+
_complementary = True
395403
_direct = _direct and x.dtype in valid_dtypes
396404
else:
397405
_direct = False
@@ -404,7 +412,7 @@ def _c2c_fftnd_impl(
404412
out=out,
405413
)
406414
else:
407-
if s is None and x.dtype in valid_dtypes:
415+
if _complementary and x.dtype in valid_dtypes:
408416
x = np.asarray(x)
409417
if out is None:
410418
res = np.empty_like(x, dtype=_output_dtype(x.dtype))

mkl_fft/tests/test_fftnd.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,28 @@ def test_s_axes_out(dtype, s, axes, func):
264264
assert_allclose(r1, r2, rtol=rtol, atol=atol)
265265

266266

267+
@requires_numpy_2
268+
@pytest.mark.parametrize("dtype", [complex, float])
269+
@pytest.mark.parametrize("axes", [(1, 2, 3), (-1, -2, -3), [2, 1, 3]])
270+
@pytest.mark.parametrize("func", ["fftn", "ifftn"])
271+
def test_s_none_vs_s_full(dtype, axes, func):
272+
shape = (2, 30, 20, 10)
273+
if dtype is complex:
274+
x = np.random.random(shape) + 1j * np.random.random(shape)
275+
else:
276+
x = np.random.random(shape)
277+
278+
implied_s = [shape[ax] for ax in axes]
279+
280+
r1 = getattr(np.fft, func)(x, axes=axes)
281+
r2 = getattr(mkl_fft, func)(x, axes=axes)
282+
r3 = getattr(mkl_fft, func)(x, s=implied_s, axes=axes)
283+
284+
rtol, atol = _get_rtol_atol(x)
285+
assert_allclose(r1, r2, rtol=rtol, atol=atol)
286+
assert_allclose(r1, r3, rtol=rtol, atol=atol)
287+
288+
267289
@pytest.mark.parametrize("dtype", [complex, float])
268290
@pytest.mark.parametrize("axes", [(2, 0, 2, 0), (0, 1, 1), (2, 0, 1, 3, 2, 1)])
269291
@pytest.mark.parametrize("func", ["rfftn", "irfftn"])

0 commit comments

Comments
 (0)