Skip to content

Commit 78e27dc

Browse files
committed
rename s_equiv_to_none and consolidate shape-checking logic
1 parent ac92374 commit 78e27dc

1 file changed

Lines changed: 8 additions & 19 deletions

File tree

mkl_fft/_fft_utils.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,12 @@ def _check_norm(norm):
4444
)
4545

4646

47-
def _check_shapes_for_direct(xs, shape, axes):
47+
def _check_shapes_for_direct(xs, shape, axes, check_complimentary=False):
4848
if len(axes) > 7: # Intel MKL supports up to 7D
4949
return False
50-
if not (len(xs) == len(shape)):
51-
# full-dimensional transform
50+
if not (len(xs) == len(shape)) and not check_complimentary:
51+
# full-dimensional transform is required for direct,
52+
# but less than full is OK for complimentary.
5253
return False
5354
if not (len(set(axes)) == len(axes)):
5455
# repeated axes
@@ -64,18 +65,6 @@ def _check_shapes_for_direct(xs, shape, axes):
6465
return True
6566

6667

67-
def _check_shapes_equiv_s_none(s, shape, axes):
68-
for si, ai in zip(s, axes):
69-
try:
70-
sh_ai = shape[ai]
71-
except IndexError:
72-
raise ValueError("Invalid axis (%d) specified" % ai)
73-
74-
if si != sh_ai:
75-
return False
76-
return True
77-
78-
7968
def _compute_fwd_scale(norm, n, shape):
8069
_check_norm(norm)
8170
if norm in (None, "backward"):
@@ -395,7 +384,7 @@ def _c2c_fftnd_impl(
395384
if direction not in [-1, +1]:
396385
raise ValueError("Direction of FFT should +1 or -1")
397386

398-
s_equiv_to_none = s is None
387+
_complementary = s is None
399388
valid_dtypes = [np.complex64, np.complex128, np.float32, np.float64]
400389
# _direct_fftnd requires complex type, and full-dimensional transform
401390
if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1:
@@ -408,8 +397,8 @@ def _c2c_fftnd_impl(
408397
_direct = True
409398
# See if s matches the shape of x along the given axes.
410399
# If it does, we can use _iter_complementary rather than _iter_fftnd.
411-
if _check_shapes_equiv_s_none(xs, x.shape, xa):
412-
s_equiv_to_none = True
400+
if _check_shapes_for_direct(xs, x.shape, xa, check_complimentary=True):
401+
_complementary = True
413402
_direct = _direct and x.dtype in valid_dtypes
414403
else:
415404
_direct = False
@@ -422,7 +411,7 @@ def _c2c_fftnd_impl(
422411
out=out,
423412
)
424413
else:
425-
if s_equiv_to_none and x.dtype in valid_dtypes:
414+
if _complementary and x.dtype in valid_dtypes:
426415
x = np.asarray(x)
427416
if out is None:
428417
res = np.empty_like(x, dtype=_output_dtype(x.dtype))

0 commit comments

Comments
 (0)