@@ -44,11 +44,12 @@ def _check_norm(norm):
4444 )
4545
4646
47- def _check_shapes_for_direct (xs , shape , axes ):
47+ def _check_shapes_for_direct (xs , shape , axes , check_complimentary = False ):
4848 if len (axes ) > 7 : # Intel MKL supports up to 7D
4949 return False
50- if not (len (xs ) == len (shape )):
51- # full-dimensional transform
50+ if not (len (xs ) == len (shape )) and not check_complimentary :
51+ # full-dimensional transform is required for direct,
52+ # but less than full is OK for complimentary.
5253 return False
5354 if not (len (set (axes )) == len (axes )):
5455 # repeated axes
@@ -64,18 +65,6 @@ def _check_shapes_for_direct(xs, shape, axes):
6465 return True
6566
6667
67- def _check_shapes_equiv_s_none (s , shape , axes ):
68- for si , ai in zip (s , axes ):
69- try :
70- sh_ai = shape [ai ]
71- except IndexError :
72- raise ValueError ("Invalid axis (%d) specified" % ai )
73-
74- if si != sh_ai :
75- return False
76- return True
77-
78-
7968def _compute_fwd_scale (norm , n , shape ):
8069 _check_norm (norm )
8170 if norm in (None , "backward" ):
@@ -395,7 +384,7 @@ def _c2c_fftnd_impl(
395384 if direction not in [- 1 , + 1 ]:
396385 raise ValueError ("Direction of FFT should +1 or -1" )
397386
398- s_equiv_to_none = s is None
387+ _complementary = s is None
399388 valid_dtypes = [np .complex64 , np .complex128 , np .float32 , np .float64 ]
400389 # _direct_fftnd requires complex type, and full-dimensional transform
401390 if isinstance (x , np .ndarray ) and x .size != 0 and x .ndim > 1 :
@@ -408,8 +397,8 @@ def _c2c_fftnd_impl(
408397 _direct = True
409398 # See if s matches the shape of x along the given axes.
410399 # If it does, we can use _iter_complementary rather than _iter_fftnd.
411- if _check_shapes_equiv_s_none (xs , x .shape , xa ):
412- s_equiv_to_none = True
400+ if _check_shapes_for_direct (xs , x .shape , xa , check_complimentary = True ):
401+ _complementary = True
413402 _direct = _direct and x .dtype in valid_dtypes
414403 else :
415404 _direct = False
@@ -422,7 +411,7 @@ def _c2c_fftnd_impl(
422411 out = out ,
423412 )
424413 else :
425- if s_equiv_to_none and x .dtype in valid_dtypes :
414+ if _complementary and x .dtype in valid_dtypes :
426415 x = np .asarray (x )
427416 if out is None :
428417 res = np .empty_like (x , dtype = _output_dtype (x .dtype ))
0 commit comments